From d2fa4a4feacf924b0bf281cebc4c6ccab3bbba48 Mon Sep 17 00:00:00 2001 From: AlgebraicJulia Bot <129184742+algebraicjuliabot@users.noreply.github.com> Date: Tue, 20 Aug 2024 16:00:37 -0400 Subject: [PATCH 01/16] Set version to 0.1.7 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 21b3563..ab55d8e 100644 --- a/Project.toml +++ b/Project.toml @@ -2,7 +2,7 @@ name = "DiagrammaticEquations" uuid = "6f00c28b-6bed-4403-80fa-30e0dc12f317" license = "MIT" authors = ["James Fairbanks", "Andrew Baas", "Evan Patterson", "Luke Morris", "George Rauta"] -version = "0.1.6" +version = "0.1.7" [deps] ACSets = "227ef7b5-1206-438b-ac65-934d6da304b8" From b14f09d600828877a68a1002a4a7f50487a70c0c Mon Sep 17 00:00:00 2001 From: GeorgeR227 <78235421+GeorgeR227@users.noreply.github.com> Date: Fri, 23 Aug 2024 13:45:35 -0400 Subject: [PATCH 02/16] Added more exports (#44) Added `apex` and `@relation`, `to_graphviz` from Catlab Co-authored-by: James --- src/DiagrammaticEquations.jl | 5 ++++- test/composition.jl | 5 +---- test/language.jl | 6 ------ 3 files changed, 5 insertions(+), 11 deletions(-) diff --git a/src/DiagrammaticEquations.jl b/src/DiagrammaticEquations.jl index 0d0de25..ef5f2d6 100644 --- a/src/DiagrammaticEquations.jl +++ b/src/DiagrammaticEquations.jl @@ -2,6 +2,8 @@ """ module DiagrammaticEquations +using Catlab + export DerivOp, append_dot, normalize_unicode, infer_states, infer_types!, # Deca @@ -12,6 +14,7 @@ recursive_delete_parents, spacename, varname, unicode!, vec_to_dec!, Collage, collate, ## composition oapply, unique_by, unique_by!, OpenSummationDecapodeOb, OpenSummationDecapode, Open, default_composition_diagram, +apex, @relation, # Re-exported from Catlab ## acset SchDecapode, SchNamedDecapode, AbstractDecapode, AbstractNamedDecapode, NamedDecapode, SummationDecapode, contract_operators!, contract_operators, add_constant!, add_parameter, fill_names!, dot_rename!, is_expanded, expand_operators, infer_state_names, infer_terminal_names, recognize_types, @@ -25,12 +28,12 @@ unique_lits!, Plus, AppCirc1, Var, Tan, App1, App2, ## visualization to_graphviz_property_graph, typename, draw_composition, +to_graphviz, # Re-exported from Catlab ## rewrite average_rewrite, ## openoperators transfer_parents!, transfer_children!, replace_op1!, replace_op2!, replace_all_op1s!, replace_all_op2s! -using Catlab using Catlab.Theories import Catlab.Theories: otimes, oplus, compose, ⊗, ⊕, ⋅, associate, associate_unit, Ob, Hom, dom, codom using Catlab.Programs diff --git a/test/composition.jl b/test/composition.jl index f0c5c02..408cf37 100644 --- a/test/composition.jl +++ b/test/composition.jl @@ -2,15 +2,12 @@ using Test using DiagrammaticEquations using DiagrammaticEquations.Deca using Catlab -using Catlab.WiringDiagrams -using Catlab.Programs -using Catlab.CategoricalAlgebra # import DiagrammaticEquations: OpenSummationDecapode, Open, oapply, oapply_rename # @testset "Composition" begin # Simplest possible decapode relation. -Trivial = @decapode begin +Trivial = @decapode begin H::Form0{X} end diff --git a/test/language.jl b/test/language.jl index 3e3ab4e..c0c8680 100644 --- a/test/language.jl +++ b/test/language.jl @@ -1,11 +1,5 @@ using Test using Catlab -using Catlab.Theories -using Catlab.CategoricalAlgebra -using Catlab.WiringDiagrams -using Catlab.WiringDiagrams.DirectedWiringDiagrams -using Catlab.Graphics -using Catlab.Programs using LinearAlgebra using MLStyle using Base.Iterators From c9b8898eda622e8ba69a059ab00416104f4e4556 Mon Sep 17 00:00:00 2001 From: Luke Morris Date: Thu, 22 Aug 2024 14:35:57 -0400 Subject: [PATCH 03/16] Add type rules for vectorfields --- src/acset.jl | 40 +++++++++++------ src/deca/deca_acset.jl | 26 +++++++---- test/language.jl | 97 ++++++++++++++++++++++++++++++++++++------ 3 files changed, 129 insertions(+), 34 deletions(-) diff --git a/src/acset.jl b/src/acset.jl index 9665afd..5366459 100644 --- a/src/acset.jl +++ b/src/acset.jl @@ -158,13 +158,16 @@ end # A collection of DecaType getters # TODO: This should be replaced by using a type hierarchy const ALL_TYPES = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2, - :Literal, :Parameter, :Constant, :infer] + :PVF, :DVF, + :Literal, :Parameter, :Constant, :infer] const FORM_TYPES = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2] const PRIMALFORM_TYPES = [:Form0, :Form1, :Form2] const DUALFORM_TYPES = [:DualForm0, :DualForm1, :DualForm2] -const NONFORM_TYPES = [:Constant, :Parameter, :Literal, :infer] +const VECTORFIELD_TYPES = [:PVF, :DVF] + +const NON_EC_TYPES = [:Constant, :Parameter, :Literal, :infer] const USER_TYPES = [:Constant, :Parameter] const NUMBER_TYPES = [:Literal] const INFER_TYPES = [:infer] @@ -427,12 +430,12 @@ function safe_modifytype!(d::SummationDecapode, var_idx::Int, new_type::Symbol) end """ - filterfor_forms(types::AbstractVector{Symbol}) + filterfor_ec_types(types::AbstractVector{Symbol}) -Return any form type symbols. +Return any form or vector-field type symbols. """ -function filterfor_forms(types::AbstractVector{Symbol}) - conditions = x -> !(x in NONFORM_TYPES) +function filterfor_ec_types(types::AbstractVector{Symbol}) + conditions = x -> !(x in NON_EC_TYPES) filter(conditions, types) end @@ -447,16 +450,16 @@ function infer_sum_types!(d::SummationDecapode, Σ_idx::Int) types = d[idxs, :type] all(t != :infer for t in types) && return applied # We need not infer - forms = unique(filterfor_forms(types)) + ec_types = unique(filterfor_ec_types(types)) - form = @match length(forms) begin + ec_type = @match length(ec_types) begin 0 => return applied # We can not infer - 1 => only(forms) - _ => error("Type mismatch in summation $Σ_idx, all the following forms appear: $forms") + 1 => only(ec_types) + _ => error("Type mismatch in summation $Σ_idx, all the following forms appear: $ec_types") end for idx in idxs - applied |= safe_modifytype!(d, idx, form) + applied |= safe_modifytype!(d, idx, ec_type) end return applied @@ -489,11 +492,24 @@ function apply_inference_rule_op2!(d::SummationDecapode, op2_id, rule) score_res = (rule.res_type == type_res) check_op = (d[op2_id, :op2] in rule.op_names) - if(check_op && (score_proj1 + score_proj2 + score_res == 2)) + if check_op && (score_proj1 + score_proj2 + score_res == 2) mod_proj1 = safe_modifytype!(d, d[op2_id, :proj1], rule.proj1_type) mod_proj2 = safe_modifytype!(d, d[op2_id, :proj2], rule.proj2_type) mod_res = safe_modifytype!(d, d[op2_id, :res], rule.res_type) return mod_proj1 || mod_proj2 || mod_res + # Special logic for exponentiation: + elseif d[op2_id, :op2] == :^ && + (type_proj1 == :Form0 && (type_proj2 == :infer || type_res == :infer)) || + (type_res == :Form0 && (type_proj1 == :infer || type_proj2 == :infer)) + mod_proj1 = safe_modifytype!(d, d[op2_id, :proj1], :Form0) + mod_res = safe_modifytype!(d, d[op2_id, :res], :Form0) + return mod_proj1 || mod_res + elseif d[op2_id, :op2] == :^ && + (type_proj1 == :DualForm0 && (type_proj2 == :infer || type_res == :infer)) || + (type_res == :DualForm0 && (type_proj1 == :infer || type_proj2 == :infer)) + mod_proj1 = safe_modifytype!(d, d[op2_id, :proj1], :DualForm0) + mod_res = safe_modifytype!(d, d[op2_id, :res], :DualForm0) + return mod_proj1 || mod_res end return false diff --git a/src/deca/deca_acset.jl b/src/deca/deca_acset.jl index 55520e3..c5fa67a 100644 --- a/src/deca/deca_acset.jl +++ b/src/deca/deca_acset.jl @@ -32,10 +32,17 @@ op1_inf_rules_1D = [ # Rules for the averaging operator (src_type = :Form0, tgt_type = :Form1, op_names = [:avg₀₁, :avg_01]), + + # Rules for ♯. + (src_type = :Form1, tgt_type = :PVF, op_names = [:♯, :♯ᵖᵖ]), + (src_type = :DualForm1, tgt_type = :DVF, op_names = [:♯, :♯ᵈᵈ]), + + # Rules for ♭. + (src_type = :DVF, tgt_type = :Form1, op_names = [:♭, :♭ᵈᵖ]), # Rules for magnitude/ norm - (src_type = :Form0, tgt_type = :Form0, op_names = [:mag, :norm]), - (src_type = :Form1, tgt_type = :Form1, op_names = [:mag, :norm])] + (src_type = :PVF, tgt_type = :Form0, op_names = [:mag, :norm]), + (src_type = :DVF, tgt_type = :DualForm0, op_names = [:mag, :norm])] op2_inf_rules_1D = [ # Rules for ∧₀₀, ∧₁₀, ∧₀₁ @@ -133,13 +140,16 @@ op1_inf_rules_2D = [ (src_type = :DualForm1, tgt_type = :DualForm1, op_names = [:neg, :(-)]), (src_type = :DualForm2, tgt_type = :DualForm2, op_names = [:neg, :(-)]), + # Rules for ♯. + (src_type = :Form1, tgt_type = :PVF, op_names = [:♯, :♯ᵖᵖ]), + (src_type = :DualForm1, tgt_type = :DVF, op_names = [:♯, :♯ᵈᵈ]), + + # Rules for ♭. + (src_type = :DVF, tgt_type = :Form1, op_names = [:♭, :♭ᵈᵖ]), + # Rules for magnitude/ norm - (src_type = :Form0, tgt_type = :Form0, op_names = [:norm, :mag]), - (src_type = :Form1, tgt_type = :Form1, op_names = [:norm, :mag]), - (src_type = :Form2, tgt_type = :Form2, op_names = [:norm, :mag]), - (src_type = :DualForm0, tgt_type = :DualForm0, op_names = [:norm, :mag]), - (src_type = :DualForm1, tgt_type = :DualForm1, op_names = [:norm, :mag]), - (src_type = :DualForm2, tgt_type = :DualForm2, op_names = [:norm, :mag])] + (src_type = :PVF, tgt_type = :Form0, op_names = [:norm, :mag]), + (src_type = :DVF, tgt_type = :DualForm0, op_names = [:norm, :mag])] op2_inf_rules_2D = vcat(op2_inf_rules_1D, [ # Rules for ∧₁₁, ∧₂₀, ∧₀₂ diff --git a/test/language.jl b/test/language.jl index c0c8680..ec1ff3c 100644 --- a/test/language.jl +++ b/test/language.jl @@ -350,13 +350,14 @@ end @test issetequal([:V,:X,:k], infer_state_names(oscillator)) end -import DiagrammaticEquations: ALL_TYPES, FORM_TYPES, PRIMALFORM_TYPES, DUALFORM_TYPES, - NONFORM_TYPES, USER_TYPES, NUMBER_TYPES, INFER_TYPES, NONINFERABLE_TYPES +import DiagrammaticEquations: ALL_TYPES, FORM_TYPES, PRIMALFORM_TYPES, + DUALFORM_TYPES, VECTORFIELD_TYPES, NON_EC_TYPES, USER_TYPES, NUMBER_TYPES, + INFER_TYPES, NONINFERABLE_TYPES @testset "Type Retrival" begin type_groups = [ALL_TYPES, FORM_TYPES, PRIMALFORM_TYPES, DUALFORM_TYPES, - NONFORM_TYPES, USER_TYPES, NUMBER_TYPES, INFER_TYPES, NONINFERABLE_TYPES] + NON_EC_TYPES, USER_TYPES, NUMBER_TYPES, INFER_TYPES, NONINFERABLE_TYPES] # No repeated types @@ -368,12 +369,12 @@ import DiagrammaticEquations: ALL_TYPES, FORM_TYPES, PRIMALFORM_TYPES, DUALFORM_ no_overlaps(types_1, types_2) = isempty(intersect(types_1, types_2)) # Collections of these types should be the same - @test equal_types(ALL_TYPES, vcat(FORM_TYPES, NONFORM_TYPES)) - @test equal_types(FORM_TYPES, vcat(PRIMALFORM_TYPES, DUALFORM_TYPES)) - @test equal_types(NONINFERABLE_TYPES, vcat(USER_TYPES, NUMBER_TYPES)) + @test equal_types(ALL_TYPES, FORM_TYPES ∪ VECTORFIELD_TYPES ∪ NON_EC_TYPES) + @test equal_types(FORM_TYPES, PRIMALFORM_TYPES ∪ DUALFORM_TYPES) + @test equal_types(NONINFERABLE_TYPES, USER_TYPES ∪ NUMBER_TYPES) # Proper seperation of types - @test no_overlaps(FORM_TYPES, NONFORM_TYPES) + @test no_overlaps(FORM_TYPES ∪ VECTORFIELD_TYPES, NON_EC_TYPES) @test no_overlaps(PRIMALFORM_TYPES, DUALFORM_TYPES) @test no_overlaps(NONINFERABLE_TYPES, FORM_TYPES) @test INFER_TYPES == [:infer] @@ -394,9 +395,9 @@ end import DiagrammaticEquations: safe_modifytype @testset "Safe Type Modification" begin - all_types = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2, :Literal, :Constant, :Parameter, :infer] + all_types = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2, :Literal, :Constant, :Parameter, :PVF, :DVF, :infer] bad_sources = [:Literal, :Constant, :Parameter] - good_sources = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2, :infer] + good_sources = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2, :PVF, :DVF, :infer] for tgt in all_types for src in bad_sources @@ -419,13 +420,13 @@ import DiagrammaticEquations: safe_modifytype end end -import DiagrammaticEquations: filterfor_forms +import DiagrammaticEquations: filterfor_ec_types @testset "Form Type Retrieval" begin - all_types = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2, :Literal, :Constant, :Parameter, :infer] - @test filterfor_forms(all_types) == [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2] - @test isempty(filterfor_forms(Symbol[])) - @test isempty(filterfor_forms([:Literal, :Constant, :Parameter, :infer])) + all_types = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2, :Literal, :Constant, :Parameter, :PVF, :DVF, :infer] + @test filterfor_ec_types(all_types) == [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2, :PVF, :DVF] + @test isempty(filterfor_ec_types(Symbol[])) + @test isempty(filterfor_ec_types([:Literal, :Constant, :Parameter, :infer])) end @testset "Type Inference" begin @@ -825,6 +826,38 @@ end @test_throws "Type mismatch in summation" infer_types!(d) end + # Test #25: Infer between flattened and sharpened vector fields. + let + d = @decapode begin + A::Form1 + B::DualForm1 + C::PVF + D::DVF + + A == ♭(E) + B == ♭(F) + C == ♯(G) + D == ♯(H) + + I::Form1 + J::DualForm1 + K::PVF + L::DVF + + M == ♯(I) + N == ♯(J) + O == ♭(K) + P == ♭(L) + end + infer_types!(d) + + # TODO: Update this as more sharps and flats are released. + names_types_expected = Set([(:A, :Form1), (:B, :DualForm1), (:C, :PVF), (:D, :DVF), + (:E, :DVF), (:F, :infer), (:G, :Form1), (:H, :DualForm1), + (:I, :Form1), (:J, :DualForm1), (:K, :PVF), (:L, :DVF), + (:M, :PVF), (:N, :DVF), (:O, :infer), (:P, :Form1)]) + @test test_nametype_equality(d, names_types_expected) + end end @testset "Overloading Resolution" begin @@ -1042,6 +1075,42 @@ end op2s_hx = HeatXfer[:op2] op2s_expected_hx = [:*, :/, :/, :L₀, :/, :L₁, :*, :/, :*, :i₁, :/, :*, :*, :L₀] @test op2s_hx == op2s_expected_hx + + # Infer types and resolve overloads for the Halfar equation. + let + d = @decapode begin + h::Form0 + Γ::Form1 + n::Constant + + ∂ₜ(h) == ∘(⋆, d, ⋆)(Γ * d(h) ∧ (mag(♯(d(h)))^(n-1)) ∧ (h^(n+2))) + end + d = expand_operators(d) + infer_types!(d) + resolve_overloads!(d) + @test d == @acset SummationDecapode{Any, Any, Symbol} begin + Var = 19 + TVar = 1 + Op1 = 8 + Op2 = 6 + Σ = 1 + Summand = 2 + src = [1, 1, 1, 13, 12, 6, 18, 19] + tgt = [4, 9, 13, 12, 11, 18, 19, 4] + proj1 = [2, 3, 11, 8, 1, 7] + proj2 = [9, 15, 14, 10, 5, 16] + res = [8, 14, 10, 7, 16, 6] + incl = [4] + summand = [3, 17] + summation = [1, 1] + sum = [5] + op1 = [:∂ₜ, :d₀, :d₀, :♯, :mag, :⋆₁, :dual_d₁, :⋆₀⁻¹] + op2 = [:*, :-, :^, :∧₁₀, :^, :∧₁₀] + type = [:Form0, :Form1, :Constant, :Form0, :infer, :Form1, :Form1, :Form1, :Form1, :Form0, :Form0, :PVF, :Form1, :infer, :Literal, :Form0, :Literal, :DualForm1, :DualForm2] + name = [:h, :Γ, :n, :ḣ, :sum_1, Symbol("•2"), Symbol("•3"), Symbol("•4"), Symbol("•5"), Symbol("•6"), Symbol("•7"), Symbol("•8"), Symbol("•9"), Symbol("•10"), Symbol("1"), Symbol("•11"), Symbol("2"), Symbol("•_6_1"), Symbol("•_6_2")] + end + end + end @testset "Compilation Transformation" begin From 79f27e700322b1c8e73592caea04352e34bf8ac1 Mon Sep 17 00:00:00 2001 From: Luke Morris Date: Thu, 22 Aug 2024 15:38:45 -0400 Subject: [PATCH 04/16] Add musical overload resolution --- src/deca/deca_acset.jl | 5 +++++ test/language.jl | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/deca/deca_acset.jl b/src/deca/deca_acset.jl index c5fa67a..b557553 100644 --- a/src/deca/deca_acset.jl +++ b/src/deca/deca_acset.jl @@ -253,6 +253,11 @@ op2_inf_rules_2D = vcat(op2_inf_rules_1D, [ (src_type = :Form1, tgt_type = :Form0, resolved_name = :δ₁, op = :δ), (src_type = :Form2, tgt_type = :Form1, resolved_name = :δ₂, op = :codif), (src_type = :Form1, tgt_type = :Form0, resolved_name = :δ₁, op = :codif), + # Rules for ♯. + (src_type = :Form1, tgt_type = :PVF, resolved_name = :♯ᵖᵖ, op = :♯), + (src_type = :DualForm1, tgt_type = :DVF, resolved_name = :♯ᵈᵈ, op = :♯), + # Rules for ♭. + (src_type = :DVF, tgt_type = :Form1, resolved_name = :♭ᵈᵖ, op = :♭), # Rules for ∇². # TODO: Call this :nabla2 in ASCII? (src_type = :Form0, tgt_type = :Form0, resolved_name = :∇²₀, op = :∇²), diff --git a/test/language.jl b/test/language.jl index ec1ff3c..25f817e 100644 --- a/test/language.jl +++ b/test/language.jl @@ -1104,7 +1104,7 @@ end summand = [3, 17] summation = [1, 1] sum = [5] - op1 = [:∂ₜ, :d₀, :d₀, :♯, :mag, :⋆₁, :dual_d₁, :⋆₀⁻¹] + op1 = [:∂ₜ, :d₀, :d₀, :♯ᵖᵖ, :mag, :⋆₁, :dual_d₁, :⋆₀⁻¹] op2 = [:*, :-, :^, :∧₁₀, :^, :∧₁₀] type = [:Form0, :Form1, :Constant, :Form0, :infer, :Form1, :Form1, :Form1, :Form1, :Form0, :Form0, :PVF, :Form1, :infer, :Literal, :Form0, :Literal, :DualForm1, :DualForm2] name = [:h, :Γ, :n, :ḣ, :sum_1, Symbol("•2"), Symbol("•3"), Symbol("•4"), Symbol("•5"), Symbol("•6"), Symbol("•7"), Symbol("•8"), Symbol("•9"), Symbol("•10"), Symbol("1"), Symbol("•11"), Symbol("2"), Symbol("•_6_1"), Symbol("•_6_2")] From 23fbf3f727eb58822ec15dd25f89f7766d34f1a8 Mon Sep 17 00:00:00 2001 From: Luke Morris Date: Thu, 22 Aug 2024 17:45:52 -0400 Subject: [PATCH 05/16] Take advantage of :infer in type rules --- src/acset.jl | 32 ++++++-------------------------- src/deca/deca_acset.jl | 6 +++++- 2 files changed, 11 insertions(+), 27 deletions(-) diff --git a/src/acset.jl b/src/acset.jl index 5366459..34d1d3c 100644 --- a/src/acset.jl +++ b/src/acset.jl @@ -466,13 +466,10 @@ function infer_sum_types!(d::SummationDecapode, Σ_idx::Int) end function apply_inference_rule_op1!(d::SummationDecapode, op1_id, rule) - type_src = d[d[op1_id, :src], :type] - type_tgt = d[d[op1_id, :tgt], :type] + score_src = (rule.src_type == d[d[op1_id, :src], :type]) + score_tgt = (rule.tgt_type == d[d[op1_id, :tgt], :type]) - score_src = (rule.src_type == type_src) - score_tgt = (rule.tgt_type == type_tgt) check_op = (d[op1_id, :op1] in rule.op_names) - if(check_op && (score_src + score_tgt == 1)) mod_src = safe_modifytype!(d, d[op1_id, :src], rule.src_type) mod_tgt = safe_modifytype!(d, d[op1_id, :tgt], rule.tgt_type) @@ -483,33 +480,16 @@ function apply_inference_rule_op1!(d::SummationDecapode, op1_id, rule) end function apply_inference_rule_op2!(d::SummationDecapode, op2_id, rule) - type_proj1 = d[d[op2_id, :proj1], :type] - type_proj2 = d[d[op2_id, :proj2], :type] - type_res = d[d[op2_id, :res], :type] + score_proj1 = (rule.proj1_type == d[d[op2_id, :proj1], :type]) + score_proj2 = (rule.proj2_type == d[d[op2_id, :proj2], :type]) + score_res = (rule.res_type == d[d[op2_id, :res], :type]) - score_proj1 = (rule.proj1_type == type_proj1) - score_proj2 = (rule.proj2_type == type_proj2) - score_res = (rule.res_type == type_res) check_op = (d[op2_id, :op2] in rule.op_names) - if check_op && (score_proj1 + score_proj2 + score_res == 2) mod_proj1 = safe_modifytype!(d, d[op2_id, :proj1], rule.proj1_type) mod_proj2 = safe_modifytype!(d, d[op2_id, :proj2], rule.proj2_type) - mod_res = safe_modifytype!(d, d[op2_id, :res], rule.res_type) + mod_res = safe_modifytype!(d, d[op2_id, :res], rule.res_type) return mod_proj1 || mod_proj2 || mod_res - # Special logic for exponentiation: - elseif d[op2_id, :op2] == :^ && - (type_proj1 == :Form0 && (type_proj2 == :infer || type_res == :infer)) || - (type_res == :Form0 && (type_proj1 == :infer || type_proj2 == :infer)) - mod_proj1 = safe_modifytype!(d, d[op2_id, :proj1], :Form0) - mod_res = safe_modifytype!(d, d[op2_id, :res], :Form0) - return mod_proj1 || mod_res - elseif d[op2_id, :op2] == :^ && - (type_proj1 == :DualForm0 && (type_proj2 == :infer || type_res == :infer)) || - (type_res == :DualForm0 && (type_proj1 == :infer || type_proj2 == :infer)) - mod_proj1 = safe_modifytype!(d, d[op2_id, :proj1], :DualForm0) - mod_res = safe_modifytype!(d, d[op2_id, :res], :DualForm0) - return mod_proj1 || mod_res end return false diff --git a/src/deca/deca_acset.jl b/src/deca/deca_acset.jl index b557553..4334c4b 100644 --- a/src/deca/deca_acset.jl +++ b/src/deca/deca_acset.jl @@ -90,7 +90,11 @@ op2_inf_rules_1D = [ (proj1_type = :Constant, proj2_type = :DualForm0, res_type = :DualForm0, op_names = [:/, :./, :*, :.*, :^, :.^]), (proj1_type = :Constant, proj2_type = :DualForm1, res_type = :DualForm1, op_names = [:/, :./, :*, :.*, :^, :.^]), (proj1_type = :DualForm0, proj2_type = :Constant, res_type = :DualForm0, op_names = [:/, :./, :*, :.*, :^, :.^]), - (proj1_type = :DualForm1, proj2_type = :Constant, res_type = :DualForm1, op_names = [:/, :./, :*, :.*, :^, :.^])] + (proj1_type = :DualForm1, proj2_type = :Constant, res_type = :DualForm1, op_names = [:/, :./, :*, :.*, :^, :.^]), + + # These rules contain infer: + (proj1_type = :Form0, proj2_type = :infer, res_type = :Form0, op_names = [:^]), + (proj1_type = :DualForm0, proj2_type = :infer, res_type = :DualForm0, op_names = [:^])] """ These are the default rules used to do type inference in the 2D exterior calculus. From 2a6269cfd4a0647991a40cf101f77066d94e2120 Mon Sep 17 00:00:00 2001 From: GeorgeR227 <78235421+GeorgeR227@users.noreply.github.com> Date: Fri, 23 Aug 2024 16:34:08 -0400 Subject: [PATCH 06/16] Initial attempt at rewriting Converts ACSet to a series of Symbolic terms that can be rewritten with a provided rewriter --- Project.toml | 2 + src/DiagrammaticEquations.jl | 2 + src/acset2symbolic.jl | 60 ++++++++++++++++++++++++++++++ src/graph_traversal.jl | 72 ++++++++++++++++++++++++++++++++++++ test/graph_traversal.jl | 64 ++++++++++++++++++++++++++++++++ test/runtests.jl | 4 ++ 6 files changed, 204 insertions(+) create mode 100644 src/acset2symbolic.jl create mode 100644 src/graph_traversal.jl create mode 100644 test/graph_traversal.jl diff --git a/Project.toml b/Project.toml index ab55d8e..d79bbe9 100644 --- a/Project.toml +++ b/Project.toml @@ -9,6 +9,7 @@ ACSets = "227ef7b5-1206-438b-ac65-934d6da304b8" Catlab = "134e5e36-593f-5add-ad60-77f754baafbe" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078" +SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" Unicode = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" [compat] @@ -16,5 +17,6 @@ ACSets = "0.2" Catlab = "0.15, 0.16" DataStructures = "0.18.13" MLStyle = "0.4.17" +SymbolicUtils = "3.4" Unicode = "1.6" julia = "1.6" diff --git a/src/DiagrammaticEquations.jl b/src/DiagrammaticEquations.jl index ef5f2d6..5ab376a 100644 --- a/src/DiagrammaticEquations.jl +++ b/src/DiagrammaticEquations.jl @@ -62,6 +62,8 @@ include("rewrite.jl") include("pretty.jl") include("colanguage.jl") include("openoperators.jl") +include("graph_traversal.jl") +include("acset2symbolic.jl") include("deca/Deca.jl") include("learn/Learn.jl") diff --git a/src/acset2symbolic.jl b/src/acset2symbolic.jl new file mode 100644 index 0000000..a2b24e6 --- /dev/null +++ b/src/acset2symbolic.jl @@ -0,0 +1,60 @@ +using DiagrammaticEquations +using SymbolicUtils +using SymbolicUtils.Rewriters +using SymbolicUtils.Code +using MLStyle + +const DECA_EQUALITY_SYMBOL = (==) + +to_symbolics(d::SummationDecapode, node::TraversalNode) = to_symbolics(d, node.index, Val(node.name)) + +function to_symbolics(d::SummationDecapode, op_index::Int, ::Val{:Op1}) + input_sym = SymbolicUtils.Sym{Number}(d[d[op_index, :src], :name]) + output_sym = SymbolicUtils.Sym{Number}(d[d[op_index, :tgt], :name]) + op_sym = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{Number}, Number}}(d[op_index, :op1]) + + rhs = SymbolicUtils.Term{Number}(op_sym, [input_sym]) + SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [output_sym, rhs]) +end + +function to_symbolics(d::SummationDecapode, op_index::Int, ::Val{:Op2}) + input1_sym = SymbolicUtils.Sym{Number}(d[d[op_index, :proj1], :name]) + input2_sym = SymbolicUtils.Sym{Number}(d[d[op_index, :proj2], :name]) + output_sym = SymbolicUtils.Sym{Number}(d[d[op_index, :res], :name]) + op_sym = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{Number, Number}, Number}}(d[op_index, :op2]) + + rhs = SymbolicUtils.Term{Number}(op_sym, [input1_sym, input2_sym]) + SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [output_sym, rhs]) +end + +#XXX: Always converting + -> .+ here since summation doesn't store the style of addition +# function to_symbolics(d::SummationDecapode, op_index::Int, ::Val{:Σ}) +# Expr(EQUALITY_SYMBOL, c.output, Expr(:call, Expr(:., :+), c.inputs...)) +# end + +function extract_symexprs(d::SummationDecapode) + topo_list = topological_sort_edges(d) + sym_list = [] + for node in topo_list + retrieve_name(d, node) != DerivOp || continue + push!(sym_list, to_symbolics(d, node)) + end + sym_list +end + +function apply_rewrites(d::SummationDecapode, rewriter) + + rewritten_list = [] + for sym in extract_symexprs(d) + res_sym = rewriter(sym) + rewritten_sym = isnothing(res_sym) ? sym : res_sym + push!(rewritten_list, rewritten_sym) + end + + rewritten_list +end + +# TODO: We need a way to get information like the d and ⋆ even when not in the ACSet +# @syms Δ(x) d(x) ⋆(x) +# lap_0_rule = @rule Δ(~x) => ⋆(d(⋆(d(~x)))) +# rewriter = Postwalk(RestartedChain([lap_0_rule])) diff --git a/src/graph_traversal.jl b/src/graph_traversal.jl new file mode 100644 index 0000000..b9551c8 --- /dev/null +++ b/src/graph_traversal.jl @@ -0,0 +1,72 @@ +using DiagrammaticEquations +using ACSets + +export TraversalNode, topological_sort_edges, number_of_ops, retrieve_name + +struct TraversalNode{T} + index::Int + name::T +end + +function topological_sort_edges(d::SummationDecapode) + visited_Var = falses(nparts(d, :Var)) + visited_Var[start_nodes(d)] .= true + + # TODO: Collect these visited arrays into one structure indexed by :Op1, :Op2, and :Σ + visited_1 = falses(nparts(d, :Op1)) + visited_2 = falses(nparts(d, :Op2)) + visited_Σ = falses(nparts(d, :Σ)) + + # FIXME: this is a quadratic implementation of topological_sort inlined in here. + op_order = TraversalNode{Symbol}[] + + for _ in 1:number_of_ops(d) + for op in parts(d, :Op1) + if !visited_1[op] && visited_Var[d[op, :src]] + + visited_1[op] = true + visited_Var[d[op, :tgt]] = true + + push!(op_order, TraversalNode(op, :Op1)) + end + end + + for op in parts(d, :Op2) + if !visited_2[op] && visited_Var[d[op, :proj1]] && visited_Var[d[op, :proj2]] + visited_2[op] = true + visited_Var[d[op, :res]] = true + push!(op_order, TraversalNode(op, :Op2)) + end + end + + for op in parts(d, :Σ) + args = subpart(d, incident(d, op, :summation), :summand) + if !visited_Σ[op] && all(visited_Var[args]) + visited_Σ[op] = true + visited_Var[d[op, :sum]] = true + push!(op_order, TraversalNode(op, :Σ)) + end + end + end + + @assert length(op_order) == number_of_ops(d) + + op_order +end + +function number_of_ops(d::SummationDecapode) + return nparts(d, :Op1) + nparts(d, :Op2) + nparts(d, :Σ) +end + +function start_nodes(d::SummationDecapode) + return vcat(infer_states(d), incident(d, :Literal, :type)) +end + +function retrieve_name(d::SummationDecapode, tsr::TraversalNode) + @match tsr.name begin + :Op1 => d[tsr.index, :op1] + :Op2 => d[tsr.index, :op2] + :Σ => :+ + _ => error("$(tsr.name) is not a valid table for names") + end +end diff --git a/test/graph_traversal.jl b/test/graph_traversal.jl new file mode 100644 index 0000000..7a5d6ac --- /dev/null +++ b/test/graph_traversal.jl @@ -0,0 +1,64 @@ +using DiagrammaticEquations +using ACSets +using MLStyle +using Test + +function is_correct_length(d::SummationDecapode, result) + return length(result) == number_of_ops(d) +end + +@testset "Topological Sort on Edges" begin + no_edge = @decapode begin + F == S + end + @test isempty(topological_sort_edges(no_edge)) + + one_op1_deca = @decapode begin + F == f(S) + end + result = topological_sort_edges(one_op1_deca) + @test is_correct_length(one_op1_deca, result) + @test retrieve_name(one_op1_deca, only(result)) == :f + + multi_op1_deca = @decapode begin + F == c(b(a(S))) + end + result = topological_sort_edges(multi_op1_deca) + @test is_correct_length(multi_op1_deca, result) + for (edge, test_name) in zip(result, [:a, :b, :c]) + @test retrieve_name(multi_op1_deca, edge) == test_name + end + + cyclic = @decapode begin + B == g(A) + A == f(B) + end + @test_throws AssertionError topological_sort_edges(cyclic) + + just_op2 = @decapode begin + C == A * B + end + result = topological_sort_edges(just_op2) + @test is_correct_length(just_op2, result) + @test retrieve_name(just_op2, only(result)) == :* + + just_simple_sum = @decapode begin + C == A + B + end + result = topological_sort_edges(just_simple_sum) + @test is_correct_length(just_simple_sum, result) + @test retrieve_name(just_simple_sum, only(result)) == :+ + + just_multi_sum = @decapode begin + F == A + B + C + D + E + end + result = topological_sort_edges(just_multi_sum) + @test is_correct_length(just_multi_sum, result) + @test retrieve_name(just_multi_sum, only(result)) == :+ + + op_combo = @decapode begin + F == h(d(A) + f(g(B) * C) + D) + end + result = topological_sort_edges(op_combo) + @test is_correct_length(op_combo, result) +end diff --git a/test/runtests.jl b/test/runtests.jl index 0cb0f30..defc285 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -39,3 +39,7 @@ end @testset "Open Operators" begin include("openoperators.jl") end + +@testset "Symbolic Rewriting" begin + include("graph_traversal.jl") +end From 2512181ba22850e595aaba70cae52c8c949c97ed Mon Sep 17 00:00:00 2001 From: Luke Morris Date: Sat, 24 Aug 2024 02:25:29 -0400 Subject: [PATCH 07/16] Define Floyd-Warshall algorithm on Decapodes --- src/graph_traversal.jl | 93 ++++++++++++++++++++++------------------- test/graph_traversal.jl | 3 ++ 2 files changed, 53 insertions(+), 43 deletions(-) diff --git a/src/graph_traversal.jl b/src/graph_traversal.jl index b9551c8..322d3ed 100644 --- a/src/graph_traversal.jl +++ b/src/graph_traversal.jl @@ -1,5 +1,6 @@ using DiagrammaticEquations using ACSets +using DataStructures export TraversalNode, topological_sort_edges, number_of_ops, retrieve_name @@ -8,58 +9,63 @@ struct TraversalNode{T} name::T end -function topological_sort_edges(d::SummationDecapode) - visited_Var = falses(nparts(d, :Var)) - visited_Var[start_nodes(d)] .= true - - # TODO: Collect these visited arrays into one structure indexed by :Op1, :Op2, and :Σ - visited_1 = falses(nparts(d, :Op1)) - visited_2 = falses(nparts(d, :Op2)) - visited_Σ = falses(nparts(d, :Σ)) - - # FIXME: this is a quadratic implementation of topological_sort inlined in here. - op_order = TraversalNode{Symbol}[] +number_of_ops(d::SummationDecapode) = nparts(d, :Op1) + nparts(d, :Op2) + nparts(d, :Σ) - for _ in 1:number_of_ops(d) - for op in parts(d, :Op1) - if !visited_1[op] && visited_Var[d[op, :src]] +start_nodes(d::SummationDecapode) = vcat(infer_states(d), incident(d, :Literal, :type)) - visited_1[op] = true - visited_Var[d[op, :tgt]] = true - - push!(op_order, TraversalNode(op, :Op1)) - end - end - - for op in parts(d, :Op2) - if !visited_2[op] && visited_Var[d[op, :proj1]] && visited_Var[d[op, :proj2]] - visited_2[op] = true - visited_Var[d[op, :res]] = true - push!(op_order, TraversalNode(op, :Op2)) - end - end - - for op in parts(d, :Σ) - args = subpart(d, incident(d, op, :summation), :summand) - if !visited_Σ[op] && all(visited_Var[args]) - visited_Σ[op] = true - visited_Var[d[op, :sum]] = true - push!(op_order, TraversalNode(op, :Σ)) +#https://en.wikipedia.org/wiki/Floyd–Warshall_algorithm#Pseudocode +function floyd_warshall(d::SummationDecapode) + # Init dists. + V = nparts(d, :Var) + dist = fill(Inf, (V, V)) + foreach(parts(d,:Op1)) do e + dist[d[e,:src], d[e,:tgt]] = 1 + end + foreach(parts(d,:Op2)) do e + dist[d[e,:proj1], d[e,:res]] = 1 + dist[d[e,:proj2], d[e,:res]] = 1 + end + foreach(parts(d,:Summand)) do e + dist[d[e,:summand], d[e,[:summation, :sum]]] = 1 + end + for v in 1:V + dist[v,v] = 0 + end + # Floyd-Warshall + for k in 1:V + for i in 1:V + for j in 1:V + if dist[i,j] > dist[i,k] + dist[k,j] + dist[i,j] = dist[i,k] + dist[k,j] + end end end end - - @assert length(op_order) == number_of_ops(d) - - op_order + dist end -function number_of_ops(d::SummationDecapode) - return nparts(d, :Op1) + nparts(d, :Op2) + nparts(d, :Σ) +function topological_sort_verts(d::SummationDecapode) + m = floyd_warshall(d) + map(parts(d,:Var)) do v + maximum(filter(!isinf, m[start_nodes(d),v])) + end + # Call sortperm for the vertex ordering. end -function start_nodes(d::SummationDecapode) - return vcat(infer_states(d), incident(d, :Literal, :type)) +function topological_sort_edges(d::SummationDecapode) + tsv = topological_sort_verts(d) + op_order = [TraversalNode.(parts(d,:Op1), :Op1)..., + TraversalNode.(parts(d,:Op2), :Op2)..., + TraversalNode.(parts(d,:Σ), :Σ)...] + function by(x) + @match x.name begin + :Op1 => tsv[d[x.index,:src]] + :Op2 => max(tsv[d[x.index,:proj1]], tsv[d[x.index,:proj1]]) + :Σ => maximum(tsv[d[incident(d,x.index,:summation),:summand]]) + _ => error("Unknown function type") + end + end + sort(op_order, by = by) end function retrieve_name(d::SummationDecapode, tsr::TraversalNode) @@ -70,3 +76,4 @@ function retrieve_name(d::SummationDecapode, tsr::TraversalNode) _ => error("$(tsr.name) is not a valid table for names") end end + diff --git a/test/graph_traversal.jl b/test/graph_traversal.jl index 7a5d6ac..825de65 100644 --- a/test/graph_traversal.jl +++ b/test/graph_traversal.jl @@ -29,11 +29,14 @@ end @test retrieve_name(multi_op1_deca, edge) == test_name end + # XXX Do cycle-detection with FW by using ∞ on the diagonal. + #= cyclic = @decapode begin B == g(A) A == f(B) end @test_throws AssertionError topological_sort_edges(cyclic) + =# just_op2 = @decapode begin C == A * B From 31dc6708f302cb84c4eae30ef3a0ca88aea4a624 Mon Sep 17 00:00:00 2001 From: Luke Morris <70283489+lukem12345@users.noreply.github.com> Date: Sat, 24 Aug 2024 02:29:21 -0400 Subject: [PATCH 08/16] Remove unnecessary DataStructures dependency --- src/graph_traversal.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/graph_traversal.jl b/src/graph_traversal.jl index 322d3ed..1600729 100644 --- a/src/graph_traversal.jl +++ b/src/graph_traversal.jl @@ -1,6 +1,5 @@ using DiagrammaticEquations using ACSets -using DataStructures export TraversalNode, topological_sort_edges, number_of_ops, retrieve_name From d46cf86a61b8de7821cdc3dd4628c01328f69ade Mon Sep 17 00:00:00 2001 From: Luke Morris Date: Sat, 24 Aug 2024 12:30:37 -0400 Subject: [PATCH 09/16] Add docstrings and use multiple dispatch --- src/graph_traversal.jl | 63 +++++++++++++++++++++++++----------------- 1 file changed, 38 insertions(+), 25 deletions(-) diff --git a/src/graph_traversal.jl b/src/graph_traversal.jl index 1600729..e5cf2a7 100644 --- a/src/graph_traversal.jl +++ b/src/graph_traversal.jl @@ -3,16 +3,39 @@ using ACSets export TraversalNode, topological_sort_edges, number_of_ops, retrieve_name -struct TraversalNode{T} +abstract type TraversalNode end + +struct Op1TravNode <: TraversalNode + index::Int +end +struct Op2TravNode <: TraversalNode + index::Int +end +struct ΣTravNode <: TraversalNode index::Int - name::T end +retrieve_name(d, tsr::Op1TravNode) = d[tsr.index, :op1] +retrieve_name(d, tsr::Op2TravNode) = d[tsr.index, :op2] +retrieve_name(d, tsr::ΣTravNode) = :+ + +# Induce a topological ordering of operations from one of variables. +edge2cost(d, tsv, tsr::Op1TravNode) = tsv[d[tsr.index,:src]] +edge2cost(d, tsv, tsr::Op2TravNode) = max(tsv[d[tsr.index,:proj1]], tsv[d[tsr.index,:proj1]]) +edge2cost(d, tsv, tsr::ΣTravNode) = maximum(tsv[d[incident(d,tsr.index,:summation),:summand]]) + number_of_ops(d::SummationDecapode) = nparts(d, :Op1) + nparts(d, :Op2) + nparts(d, :Σ) start_nodes(d::SummationDecapode) = vcat(infer_states(d), incident(d, :Literal, :type)) -#https://en.wikipedia.org/wiki/Floyd–Warshall_algorithm#Pseudocode +""" function floyd_warshall(d::SummationDecapode) + +Return a |variable| × |variable| matrix of shortest paths via the Floyd-Warshall algorithm. + +Taking the maximum of the non-infinite short paths from state variables induces a topological ordering. + +https://en.wikipedia.org/wiki/Floyd–Warshall_algorithm +""" function floyd_warshall(d::SummationDecapode) # Init dists. V = nparts(d, :Var) @@ -43,36 +66,26 @@ function floyd_warshall(d::SummationDecapode) dist end +""" function topological_sort_verts(d::SummationDecapode) + +Topologically sort the variables in a Decapode. + +The vector returned by this function maps each vertex to the order that it would be traversed in a topological sort traversal. If you want a list of vertices in the order of traversal, call `sortperm` on the output. +""" function topological_sort_verts(d::SummationDecapode) m = floyd_warshall(d) map(parts(d,:Var)) do v maximum(filter(!isinf, m[start_nodes(d),v])) end - # Call sortperm for the vertex ordering. end +""" function topological_sort_edges(d::SummationDecapode) + +Topologically sort the edges in a Decapode. +""" function topological_sort_edges(d::SummationDecapode) tsv = topological_sort_verts(d) - op_order = [TraversalNode.(parts(d,:Op1), :Op1)..., - TraversalNode.(parts(d,:Op2), :Op2)..., - TraversalNode.(parts(d,:Σ), :Σ)...] - function by(x) - @match x.name begin - :Op1 => tsv[d[x.index,:src]] - :Op2 => max(tsv[d[x.index,:proj1]], tsv[d[x.index,:proj1]]) - :Σ => maximum(tsv[d[incident(d,x.index,:summation),:summand]]) - _ => error("Unknown function type") - end - end - sort(op_order, by = by) -end - -function retrieve_name(d::SummationDecapode, tsr::TraversalNode) - @match tsr.name begin - :Op1 => d[tsr.index, :op1] - :Op2 => d[tsr.index, :op2] - :Σ => :+ - _ => error("$(tsr.name) is not a valid table for names") - end + op_order = [Op1TravNode.(parts(d,:Op1))..., Op2TravNode.(parts(d,:Op2))..., ΣTravNode.(parts(d,:Σ))...] + sort(op_order, by = x -> edge2cost(d,tsv,x)) end From db988a29e4110167ec8fed58fa87105d37de5717 Mon Sep 17 00:00:00 2001 From: Luke Morris Date: Sat, 24 Aug 2024 14:12:32 -0400 Subject: [PATCH 10/16] Explicitly create hyperedge list --- src/graph_traversal.jl | 40 ++++++++++++++++++++++++++++++---------- 1 file changed, 30 insertions(+), 10 deletions(-) diff --git a/src/graph_traversal.jl b/src/graph_traversal.jl index e5cf2a7..dc60453 100644 --- a/src/graph_traversal.jl +++ b/src/graph_traversal.jl @@ -15,10 +15,12 @@ struct ΣTravNode <: TraversalNode index::Int end +# TODO: If the name of an operator is cached, only one definition is necessary. retrieve_name(d, tsr::Op1TravNode) = d[tsr.index, :op1] retrieve_name(d, tsr::Op2TravNode) = d[tsr.index, :op2] retrieve_name(d, tsr::ΣTravNode) = :+ +# TODO: If the codomain of an operator is cached, only one definition is necessary. # Induce a topological ordering of operations from one of variables. edge2cost(d, tsv, tsr::Op1TravNode) = tsv[d[tsr.index,:src]] edge2cost(d, tsv, tsr::Op2TravNode) = max(tsv[d[tsr.index,:proj1]], tsv[d[tsr.index,:proj1]]) @@ -28,6 +30,31 @@ number_of_ops(d::SummationDecapode) = nparts(d, :Op1) + nparts(d, :Op2) + nparts start_nodes(d::SummationDecapode) = vcat(infer_states(d), incident(d, :Literal, :type)) + +# TODO: Domain-Codomain pointers could be upstreamed into TraversalNode subtypes. Then this can be cleaned up with multiple dispatch. This would explicitly cache calls to incident. +# TODO: This could be Catlab'd. Hypergraph category? Migration to a DWD? +""" function hyper_edge_list(d::SummationDecapode) + +Represent a Decapode as a directed hyper-edge list. + +Interpret a: + - unary operation as a hyperedge of order (1,1) , + - binary operation as a hyperedge of order (2,1) , and + - summation as a hyperedge of order (|summands|,1) . +""" +function hyper_edge_list(d::SummationDecapode) + reduce(vcat, [ + map(parts(d,:Op1)) do e + ([d[e,:src]], d[e,:tgt]) + end, + map(parts(d,:Op2)) do e + ([d[e,:proj1],d[e,:proj2]], d[e,:res]) + end, + map(parts(d,:Σ)) do e + (d[incident(d,e,:summation),:summand], d[e,:sum]) + end]) +end + """ function floyd_warshall(d::SummationDecapode) Return a |variable| × |variable| matrix of shortest paths via the Floyd-Warshall algorithm. @@ -37,18 +64,11 @@ Taking the maximum of the non-infinite short paths from state variables induces https://en.wikipedia.org/wiki/Floyd–Warshall_algorithm """ function floyd_warshall(d::SummationDecapode) - # Init dists. + # Init dists V = nparts(d, :Var) dist = fill(Inf, (V, V)) - foreach(parts(d,:Op1)) do e - dist[d[e,:src], d[e,:tgt]] = 1 - end - foreach(parts(d,:Op2)) do e - dist[d[e,:proj1], d[e,:res]] = 1 - dist[d[e,:proj2], d[e,:res]] = 1 - end - foreach(parts(d,:Summand)) do e - dist[d[e,:summand], d[e,[:summation, :sum]]] = 1 + foreach(hyper_edge_list(d)) do e + dist[e...] .= 1 end for v in 1:V dist[v,v] = 0 From 0b52472f68ca162adf927bac5f047dcfaea33608 Mon Sep 17 00:00:00 2001 From: Luke Morris Date: Sat, 24 Aug 2024 14:51:24 -0400 Subject: [PATCH 11/16] Clean by caching hyperedge order --- src/graph_traversal.jl | 56 ++++++++++++++++++------------------------ 1 file changed, 24 insertions(+), 32 deletions(-) diff --git a/src/graph_traversal.jl b/src/graph_traversal.jl index dc60453..68a5dab 100644 --- a/src/graph_traversal.jl +++ b/src/graph_traversal.jl @@ -3,35 +3,33 @@ using ACSets export TraversalNode, topological_sort_edges, number_of_ops, retrieve_name -abstract type TraversalNode end - -struct Op1TravNode <: TraversalNode - index::Int -end -struct Op2TravNode <: TraversalNode - index::Int -end -struct ΣTravNode <: TraversalNode +struct TraversalNode{T} index::Int + name::T + dom::AbstractVector + cod::Int end -# TODO: If the name of an operator is cached, only one definition is necessary. -retrieve_name(d, tsr::Op1TravNode) = d[tsr.index, :op1] -retrieve_name(d, tsr::Op2TravNode) = d[tsr.index, :op2] -retrieve_name(d, tsr::ΣTravNode) = :+ +TraversalNode(i, d::SummationDecapode, ::Val{:Op1}) = + TraversalNode{Symbol}(i, d[i,:op1], [d[i,:src]], d[i,:tgt]) + +TraversalNode(i, d::SummationDecapode, ::Val{:Op2}) = + TraversalNode{Symbol}(i, d[i,:op2], [d[i,:proj1],d[i,:proj2]], d[i,:res]) + +TraversalNode(i, d::SummationDecapode, ::Val{:Σ}) = + TraversalNode{Symbol}(i, :+, d[incident(d,i,:summation),:summand], d[i,:sum]) + +retrieve_name(d::SummationDecapode, tsr::TraversalNode) = tsr.name -# TODO: If the codomain of an operator is cached, only one definition is necessary. -# Induce a topological ordering of operations from one of variables. -edge2cost(d, tsv, tsr::Op1TravNode) = tsv[d[tsr.index,:src]] -edge2cost(d, tsv, tsr::Op2TravNode) = max(tsv[d[tsr.index,:proj1]], tsv[d[tsr.index,:proj1]]) -edge2cost(d, tsv, tsr::ΣTravNode) = maximum(tsv[d[incident(d,tsr.index,:summation),:summand]]) +# Induce a topological ordering of operations from a topological ordering of variables. +# Taking Vᵢ(cod(e)ᵢ) like so is a structure preserving map. +edge2cost(tsv, tsr::TraversalNode) = maximum(tsv[tsr.dom]) number_of_ops(d::SummationDecapode) = nparts(d, :Op1) + nparts(d, :Op2) + nparts(d, :Σ) start_nodes(d::SummationDecapode) = vcat(infer_states(d), incident(d, :Literal, :type)) -# TODO: Domain-Codomain pointers could be upstreamed into TraversalNode subtypes. Then this can be cleaned up with multiple dispatch. This would explicitly cache calls to incident. # TODO: This could be Catlab'd. Hypergraph category? Migration to a DWD? """ function hyper_edge_list(d::SummationDecapode) @@ -43,16 +41,9 @@ Interpret a: - summation as a hyperedge of order (|summands|,1) . """ function hyper_edge_list(d::SummationDecapode) - reduce(vcat, [ - map(parts(d,:Op1)) do e - ([d[e,:src]], d[e,:tgt]) - end, - map(parts(d,:Op2)) do e - ([d[e,:proj1],d[e,:proj2]], d[e,:res]) - end, - map(parts(d,:Σ)) do e - (d[incident(d,e,:summation),:summand], d[e,:sum]) - end]) + [map(e -> TraversalNode(e, d, Val(:Op1)), parts(d, :Op1))..., + map(e -> TraversalNode(e, d, Val(:Op2)), parts(d, :Op2))..., + map(e -> TraversalNode(e, d, Val(:Σ )), parts(d, :Σ ))...] end """ function floyd_warshall(d::SummationDecapode) @@ -68,7 +59,7 @@ function floyd_warshall(d::SummationDecapode) V = nparts(d, :Var) dist = fill(Inf, (V, V)) foreach(hyper_edge_list(d)) do e - dist[e...] .= 1 + dist[(e.dom), e.cod] .= 1 end for v in 1:V dist[v,v] = 0 @@ -99,13 +90,14 @@ function topological_sort_verts(d::SummationDecapode) end end +# TODO: Add in-place version for sorting a given hyperedge list. """ function topological_sort_edges(d::SummationDecapode) Topologically sort the edges in a Decapode. """ function topological_sort_edges(d::SummationDecapode) tsv = topological_sort_verts(d) - op_order = [Op1TravNode.(parts(d,:Op1))..., Op2TravNode.(parts(d,:Op2))..., ΣTravNode.(parts(d,:Σ))...] - sort(op_order, by = x -> edge2cost(d,tsv,x)) + op_order = hyper_edge_list(d) + sort(op_order, by = x -> edge2cost(tsv,x)) end From 7daf15b9e4aad39be8295d01999683357cfa65fe Mon Sep 17 00:00:00 2001 From: Luke Morris <70283489+lukem12345@users.noreply.github.com> Date: Sat, 24 Aug 2024 14:55:11 -0400 Subject: [PATCH 12/16] Fix comment typo cod -> dom --- src/graph_traversal.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/graph_traversal.jl b/src/graph_traversal.jl index 68a5dab..5e8c794 100644 --- a/src/graph_traversal.jl +++ b/src/graph_traversal.jl @@ -22,7 +22,7 @@ TraversalNode(i, d::SummationDecapode, ::Val{:Σ}) = retrieve_name(d::SummationDecapode, tsr::TraversalNode) = tsr.name # Induce a topological ordering of operations from a topological ordering of variables. -# Taking Vᵢ(cod(e)ᵢ) like so is a structure preserving map. +# Taking Vᵢ(dom(e)ᵢ) like so is a structure preserving map. edge2cost(tsv, tsr::TraversalNode) = maximum(tsv[tsr.dom]) number_of_ops(d::SummationDecapode) = nparts(d, :Op1) + nparts(d, :Op2) + nparts(d, :Σ) From 303962d5b983891074165786f34d4522b457c8d5 Mon Sep 17 00:00:00 2001 From: Luke Morris Date: Sun, 25 Aug 2024 19:40:08 -0400 Subject: [PATCH 13/16] Compute longest paths from terminals --- src/DiagrammaticEquations.jl | 2 +- src/graph_traversal.jl | 7 ++++--- test/graph_traversal.jl | 6 ++++++ 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/src/DiagrammaticEquations.jl b/src/DiagrammaticEquations.jl index 5ab376a..978428c 100644 --- a/src/DiagrammaticEquations.jl +++ b/src/DiagrammaticEquations.jl @@ -5,7 +5,7 @@ module DiagrammaticEquations using Catlab export -DerivOp, append_dot, normalize_unicode, infer_states, infer_types!, +DerivOp, append_dot, normalize_unicode, infer_states, infer_terminals, infer_types!, # Deca op1_res_rules_1D, op2_res_rules_1D, op1_res_rules_2D, op2_res_rules_2D, op1_inf_rules_1D, op2_inf_rules_1D, op1_inf_rules_2D, op2_inf_rules_2D, diff --git a/src/graph_traversal.jl b/src/graph_traversal.jl index 5e8c794..4e61538 100644 --- a/src/graph_traversal.jl +++ b/src/graph_traversal.jl @@ -29,7 +29,6 @@ number_of_ops(d::SummationDecapode) = nparts(d, :Op1) + nparts(d, :Op2) + nparts start_nodes(d::SummationDecapode) = vcat(infer_states(d), incident(d, :Literal, :type)) - # TODO: This could be Catlab'd. Hypergraph category? Migration to a DWD? """ function hyper_edge_list(d::SummationDecapode) @@ -55,11 +54,13 @@ Taking the maximum of the non-infinite short paths from state variables induces https://en.wikipedia.org/wiki/Floyd–Warshall_algorithm """ function floyd_warshall(d::SummationDecapode) + # Define weights. + w(e) = (length(e.dom) == 1 && e.name ∈ [:∂ₜ,:dt]) ? -Inf : -1 # Init dists V = nparts(d, :Var) dist = fill(Inf, (V, V)) foreach(hyper_edge_list(d)) do e - dist[(e.dom), e.cod] .= 1 + dist[(e.dom), e.cod] .= w(e) end for v in 1:V dist[v,v] = 0 @@ -86,7 +87,7 @@ The vector returned by this function maps each vertex to the order that it would function topological_sort_verts(d::SummationDecapode) m = floyd_warshall(d) map(parts(d,:Var)) do v - maximum(filter(!isinf, m[start_nodes(d),v])) + minimum(filter(!isinf, m[v,infer_terminals(d)])) end end diff --git a/test/graph_traversal.jl b/test/graph_traversal.jl index 825de65..9160b05 100644 --- a/test/graph_traversal.jl +++ b/test/graph_traversal.jl @@ -64,4 +64,10 @@ end end result = topological_sort_edges(op_combo) @test is_correct_length(op_combo, result) + + sum_with_single_dependency = @decapode begin + F == A + f(A) + h(g(A)) + end + result = topological_sort_edges(sum_with_single_dependency) + @test is_correct_length(sum_with_single_dependency, result) end From dfe0e03e2cb123de3753db2eddce46eaf867fca4 Mon Sep 17 00:00:00 2001 From: Luke Morris Date: Tue, 27 Aug 2024 19:20:45 -0400 Subject: [PATCH 14/16] Just use -1 as weight --- src/graph_traversal.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/graph_traversal.jl b/src/graph_traversal.jl index 4e61538..96e2ccc 100644 --- a/src/graph_traversal.jl +++ b/src/graph_traversal.jl @@ -55,7 +55,7 @@ https://en.wikipedia.org/wiki/Floyd–Warshall_algorithm """ function floyd_warshall(d::SummationDecapode) # Define weights. - w(e) = (length(e.dom) == 1 && e.name ∈ [:∂ₜ,:dt]) ? -Inf : -1 + w(e) = -1 # Init dists V = nparts(d, :Var) dist = fill(Inf, (V, V)) From db2274f3d69866a4f142060655df918d7352eccf Mon Sep 17 00:00:00 2001 From: GeorgeR227 <78235421+GeorgeR227@users.noreply.github.com> Date: Wed, 28 Aug 2024 14:24:05 -0400 Subject: [PATCH 15/16] Addition of generic graph struct This struct organizes the data into a more generic hypergraph that can then be routed through generic graph algorithms, like topo sort or F-W, without relying on the underlying ACSet structure. --- src/DiagrammaticEquations.jl | 2 +- src/acset2symbolic.jl | 56 +++++++- src/graph_interface.jl | 133 ++++++++++++++++++ ...{graph_traversal.jl => graph_interface.jl} | 42 +++--- test/runtests.jl | 2 +- 5 files changed, 214 insertions(+), 21 deletions(-) create mode 100644 src/graph_interface.jl rename test/{graph_traversal.jl => graph_interface.jl} (58%) diff --git a/src/DiagrammaticEquations.jl b/src/DiagrammaticEquations.jl index 978428c..aa6468c 100644 --- a/src/DiagrammaticEquations.jl +++ b/src/DiagrammaticEquations.jl @@ -62,7 +62,7 @@ include("rewrite.jl") include("pretty.jl") include("colanguage.jl") include("openoperators.jl") -include("graph_traversal.jl") +include("graph_interface.jl") include("acset2symbolic.jl") include("deca/Deca.jl") include("learn/Learn.jl") diff --git a/src/acset2symbolic.jl b/src/acset2symbolic.jl index a2b24e6..101b2c5 100644 --- a/src/acset2symbolic.jl +++ b/src/acset2symbolic.jl @@ -4,9 +4,50 @@ using SymbolicUtils.Rewriters using SymbolicUtils.Code using MLStyle +import DiagrammaticEquations: HyperGraph, HyperGraphEdge, HyperGraphVertex, vertex_list, edge_list +import DiagrammaticEquations: topological_sort_edges + +export TableData, extract_symexprs, number_of_ops, retrieve_name + const DECA_EQUALITY_SYMBOL = (==) -to_symbolics(d::SummationDecapode, node::TraversalNode) = to_symbolics(d, node.index, Val(node.name)) +# Decapode graph conversion +struct TableData + table_index::Int + table_name::Symbol +end + +HyperGraphVertex(d::SummationDecapode, index::Int) = HyperGraphVertex(index, TableData(index, :Var)) + +HyperGraphEdge(d::SummationDecapode, index::Int, ::Val{:Op1}) = HyperGraphEdge(d[index, :tgt], [d[index, :src]], TableData(index, :Op1)) +HyperGraphEdge(d::SummationDecapode, index::Int, ::Val{:Op2}) = HyperGraphEdge(d[index, :res], [d[index,:proj1],d[index,:proj2]], TableData(index, :Op2)) +HyperGraphEdge(d::SummationDecapode, index::Int, ::Val{:Σ}) = HyperGraphEdge(d[index, :sum], d[incident(d, index, :summation), :summand], TableData(index, :Σ)) + +HyperGraph(d::SummationDecapode) = HyperGraph(vertex_list(d), edge_list(d), nothing) + +vertex_list(d::SummationDecapode) = map(id -> HyperGraphVertex(d, id), parts(d, :Var)) + +function edge_list(d::SummationDecapode) + edges = HyperGraphEdge[] + for op_table in [:Op1, :Op2, :Σ] + for op in parts(d, op_table) + if op_table == :Op1 && d[op, :op1] == DerivOp + continue + end + push!(edges, HyperGraphEdge(d, op, Val(op_table))) + end + end + edges +end + +table_data(v::HyperGraphVertex) = v.metadata +table_data(v::HyperGraphEdge) = v.metadata + +topological_sort_edges(d::SummationDecapode) = table_data.(topological_sort_edges(HyperGraph(d))) + +# Decapode ACSet symbolics conversion + +to_symbolics(d::SummationDecapode, data::TableData) = to_symbolics(d, data.table_index, Val(data.table_name)) function to_symbolics(d::SummationDecapode, op_index::Int, ::Val{:Op1}) input_sym = SymbolicUtils.Sym{Number}(d[d[op_index, :src], :name]) @@ -54,6 +95,19 @@ function apply_rewrites(d::SummationDecapode, rewriter) rewritten_list end +function number_of_ops(d::SummationDecapode) + return nparts(d, :Op1) + nparts(d, :Op2) + nparts(d, :Σ) +end + +function retrieve_name(d::SummationDecapode, data::TableData) + @match data.table_name begin + :Op1 => d[data.table_index, :op1] + :Op2 => d[data.table_index, :op2] + :Σ => :+ + _ => error("$(data.table_name) is a table without names") + end +end + # TODO: We need a way to get information like the d and ⋆ even when not in the ACSet # @syms Δ(x) d(x) ⋆(x) # lap_0_rule = @rule Δ(~x) => ⋆(d(⋆(d(~x)))) diff --git a/src/graph_interface.jl b/src/graph_interface.jl new file mode 100644 index 0000000..678f196 --- /dev/null +++ b/src/graph_interface.jl @@ -0,0 +1,133 @@ +using DiagrammaticEquations +using ACSets + +export HyperGraph, HyperGraphVertex, HyperGraphEdge, vertex_list, edge_list +export topological_sort_edges, floyd_warshall + +struct HyperGraphVertex + id::Int + metadata +end + +# Assuming we only have a single target +struct HyperGraphEdge + tgt::Int + srcs::AbstractVector{Int} + metadata +end + +struct HyperGraph + vertices::AbstractVector{HyperGraphVertex} + edges::AbstractVector{HyperGraphEdge} + metadata +end + +# Returns a list of all vertices from ACSet as HyperGraphVertex +function vertex_list() end + +# Returns a list of all edges from ACSet as HyperGraphEdge +function edge_list() end + +num_vertices(g::HyperGraph) = length(g.vertices) +num_edges(g::HyperGraph) = length(g.edges) + +# TODO: Clean this up to use better logic +function start_nodes(g::HyperGraph) + indices = HyperGraphVertex[] + + for vertex in g.vertices + v_id = vertex.id + + is_tgt = true + for edge in g.edges + if v_id == edge.tgt + is_tgt = false + break + end + end + + if is_tgt + push!(indices, vertex) + end + + end + + indices +end + +function has_unique_targets(g::HyperGraph) + seen_vertices = Set{Int}() + for edge in g.edges + if edge.tgt in seen_vertices + return false + end + push!(seen_vertices, edge.tgt) + end + return true +end + +vertex_id(v::HyperGraphVertex) = return v.id + +function topological_sort_edges(g::HyperGraph) + @assert has_unique_targets(g) + + visited_vertices = falses(num_vertices(g)) + visited_vertices[vertex_id.(start_nodes(g))] .= true + + visited_edges = falses(num_edges(g)) + + edge_order = HyperGraphEdge[] + + for _ in 1:num_edges(g) + for (idx, edge) in enumerate(g.edges) + if !visited_edges[idx] && all(visited_vertices[edge.srcs]) + visited_edges[idx] = true + visited_vertices[edge.tgt] = true + + push!(edge_order, edge) + end + end + end + + @assert length(edge_order) == num_edges(g) + + edge_order +end + +""" +floyd_warshall(g::HyperGraph) + +Return a |variable| × |variable| matrix of shortest paths via the Floyd-Warshall algorithm. + +Taking the maximum of the non-infinite short paths from state variables induces a topological ordering. + +https://en.wikipedia.org/wiki/Floyd–Warshall_algorithm +""" +function floyd_warshall(g::HyperGraph) + # Define weights. + w(e) = -1 + + # Init dists + V = num_vertices(g) + dist = fill(Inf, (V, V)) + foreach(g.edges) do e + dist[(e.srcs), e.tgt] .= w(e) + end + + for v in 1:V + dist[v,v] = 0 + end + + # Floyd-Warshall + for k in 1:V + for i in 1:V + for j in 1:V + if dist[i,j] > dist[i,k] + dist[k,j] + dist[i,j] = dist[i,k] + dist[k,j] + end + end + end + end + + dist +end diff --git a/test/graph_traversal.jl b/test/graph_interface.jl similarity index 58% rename from test/graph_traversal.jl rename to test/graph_interface.jl index 9160b05..ccc1ce8 100644 --- a/test/graph_traversal.jl +++ b/test/graph_interface.jl @@ -3,8 +3,17 @@ using ACSets using MLStyle using Test -function is_correct_length(d::SummationDecapode, result) - return length(result) == number_of_ops(d) +function is_topo_sort_ordered(result::AbstractVector{TableData}) + seen_edges = Dict{Symbol, Int}(:Op1 => 0, :Op2 => 0, :Σ => 0) + for entry in result + table = entry.table_name + prev_seen = seen_edges[table] + if !(prev_seen < entry.table_index) + return false + end + seen_edges[table] = entry.table_index + end + return true end @testset "Topological Sort on Edges" begin @@ -13,61 +22,58 @@ end end @test isempty(topological_sort_edges(no_edge)) - one_op1_deca = @decapode begin + one_op1 = @decapode begin F == f(S) end - result = topological_sort_edges(one_op1_deca) - @test is_correct_length(one_op1_deca, result) - @test retrieve_name(one_op1_deca, only(result)) == :f + result = topological_sort_edges(one_op1) + @test retrieve_name(one_op1, only(result)) == :f + @test is_topo_sort_ordered(result) - multi_op1_deca = @decapode begin + multi_op1 = @decapode begin F == c(b(a(S))) end - result = topological_sort_edges(multi_op1_deca) - @test is_correct_length(multi_op1_deca, result) + result = topological_sort_edges(multi_op1) for (edge, test_name) in zip(result, [:a, :b, :c]) - @test retrieve_name(multi_op1_deca, edge) == test_name + @test retrieve_name(multi_op1, edge) == test_name end + @test is_topo_sort_ordered(result) - # XXX Do cycle-detection with FW by using ∞ on the diagonal. - #= cyclic = @decapode begin B == g(A) A == f(B) end @test_throws AssertionError topological_sort_edges(cyclic) - =# just_op2 = @decapode begin C == A * B end result = topological_sort_edges(just_op2) - @test is_correct_length(just_op2, result) @test retrieve_name(just_op2, only(result)) == :* + @test is_topo_sort_ordered(result) just_simple_sum = @decapode begin C == A + B end result = topological_sort_edges(just_simple_sum) - @test is_correct_length(just_simple_sum, result) @test retrieve_name(just_simple_sum, only(result)) == :+ + @test is_topo_sort_ordered(result) just_multi_sum = @decapode begin F == A + B + C + D + E end result = topological_sort_edges(just_multi_sum) - @test is_correct_length(just_multi_sum, result) @test retrieve_name(just_multi_sum, only(result)) == :+ + @test is_topo_sort_ordered(result) op_combo = @decapode begin F == h(d(A) + f(g(B) * C) + D) end result = topological_sort_edges(op_combo) - @test is_correct_length(op_combo, result) + @test is_topo_sort_ordered(result) sum_with_single_dependency = @decapode begin F == A + f(A) + h(g(A)) end result = topological_sort_edges(sum_with_single_dependency) - @test is_correct_length(sum_with_single_dependency, result) + @test is_topo_sort_ordered(result) end diff --git a/test/runtests.jl b/test/runtests.jl index defc285..7dc8b14 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -41,5 +41,5 @@ end end @testset "Symbolic Rewriting" begin - include("graph_traversal.jl") + include("graph_interface.jl") end From a93295a3d5e247618f83675a486ae817ca0321e2 Mon Sep 17 00:00:00 2001 From: GeorgeR227 <78235421+GeorgeR227@users.noreply.github.com> Date: Wed, 28 Aug 2024 14:25:31 -0400 Subject: [PATCH 16/16] Remove old graph_traversal file --- src/graph_traversal.jl | 104 ----------------------------------------- 1 file changed, 104 deletions(-) delete mode 100644 src/graph_traversal.jl diff --git a/src/graph_traversal.jl b/src/graph_traversal.jl deleted file mode 100644 index 96e2ccc..0000000 --- a/src/graph_traversal.jl +++ /dev/null @@ -1,104 +0,0 @@ -using DiagrammaticEquations -using ACSets - -export TraversalNode, topological_sort_edges, number_of_ops, retrieve_name - -struct TraversalNode{T} - index::Int - name::T - dom::AbstractVector - cod::Int -end - -TraversalNode(i, d::SummationDecapode, ::Val{:Op1}) = - TraversalNode{Symbol}(i, d[i,:op1], [d[i,:src]], d[i,:tgt]) - -TraversalNode(i, d::SummationDecapode, ::Val{:Op2}) = - TraversalNode{Symbol}(i, d[i,:op2], [d[i,:proj1],d[i,:proj2]], d[i,:res]) - -TraversalNode(i, d::SummationDecapode, ::Val{:Σ}) = - TraversalNode{Symbol}(i, :+, d[incident(d,i,:summation),:summand], d[i,:sum]) - -retrieve_name(d::SummationDecapode, tsr::TraversalNode) = tsr.name - -# Induce a topological ordering of operations from a topological ordering of variables. -# Taking Vᵢ(dom(e)ᵢ) like so is a structure preserving map. -edge2cost(tsv, tsr::TraversalNode) = maximum(tsv[tsr.dom]) - -number_of_ops(d::SummationDecapode) = nparts(d, :Op1) + nparts(d, :Op2) + nparts(d, :Σ) - -start_nodes(d::SummationDecapode) = vcat(infer_states(d), incident(d, :Literal, :type)) - -# TODO: This could be Catlab'd. Hypergraph category? Migration to a DWD? -""" function hyper_edge_list(d::SummationDecapode) - -Represent a Decapode as a directed hyper-edge list. - -Interpret a: - - unary operation as a hyperedge of order (1,1) , - - binary operation as a hyperedge of order (2,1) , and - - summation as a hyperedge of order (|summands|,1) . -""" -function hyper_edge_list(d::SummationDecapode) - [map(e -> TraversalNode(e, d, Val(:Op1)), parts(d, :Op1))..., - map(e -> TraversalNode(e, d, Val(:Op2)), parts(d, :Op2))..., - map(e -> TraversalNode(e, d, Val(:Σ )), parts(d, :Σ ))...] -end - -""" function floyd_warshall(d::SummationDecapode) - -Return a |variable| × |variable| matrix of shortest paths via the Floyd-Warshall algorithm. - -Taking the maximum of the non-infinite short paths from state variables induces a topological ordering. - -https://en.wikipedia.org/wiki/Floyd–Warshall_algorithm -""" -function floyd_warshall(d::SummationDecapode) - # Define weights. - w(e) = -1 - # Init dists - V = nparts(d, :Var) - dist = fill(Inf, (V, V)) - foreach(hyper_edge_list(d)) do e - dist[(e.dom), e.cod] .= w(e) - end - for v in 1:V - dist[v,v] = 0 - end - # Floyd-Warshall - for k in 1:V - for i in 1:V - for j in 1:V - if dist[i,j] > dist[i,k] + dist[k,j] - dist[i,j] = dist[i,k] + dist[k,j] - end - end - end - end - dist -end - -""" function topological_sort_verts(d::SummationDecapode) - -Topologically sort the variables in a Decapode. - -The vector returned by this function maps each vertex to the order that it would be traversed in a topological sort traversal. If you want a list of vertices in the order of traversal, call `sortperm` on the output. -""" -function topological_sort_verts(d::SummationDecapode) - m = floyd_warshall(d) - map(parts(d,:Var)) do v - minimum(filter(!isinf, m[v,infer_terminals(d)])) - end -end - -# TODO: Add in-place version for sorting a given hyperedge list. -""" function topological_sort_edges(d::SummationDecapode) - -Topologically sort the edges in a Decapode. -""" -function topological_sort_edges(d::SummationDecapode) - tsv = topological_sort_verts(d) - op_order = hyper_edge_list(d) - sort(op_order, by = x -> edge2cost(tsv,x)) -end -