diff --git a/src/DiagrammaticEquations.jl b/src/DiagrammaticEquations.jl index 72cada6..558a839 100644 --- a/src/DiagrammaticEquations.jl +++ b/src/DiagrammaticEquations.jl @@ -69,9 +69,10 @@ include("graph_traversal.jl") include("deca/Deca.jl") include("learn/Learn.jl") include("SymbolicUtilsInterop.jl") -include("acset2symbolic.jl") @reexport using .Deca @reexport using .SymbolicUtilsInterop +include("acset2symbolic.jl") + end diff --git a/src/SymbolicUtilsInterop.jl b/src/SymbolicUtilsInterop.jl index 61502a8..28bde3f 100644 --- a/src/SymbolicUtilsInterop.jl +++ b/src/SymbolicUtilsInterop.jl @@ -14,6 +14,7 @@ struct SymbolicEquation{E} lhs::E rhs::E end +export SymbolicEquation Base.show(io::IO, e::SymbolicEquation) = begin print(io, e.lhs); print(io, " == "); print(io, e.rhs) diff --git a/src/acset2symbolic.jl b/src/acset2symbolic.jl index 0bf24fd..e5edd1c 100644 --- a/src/acset2symbolic.jl +++ b/src/acset2symbolic.jl @@ -5,14 +5,16 @@ using SymbolicUtils.Rewriters using SymbolicUtils.Code using MLStyle -export extract_symexprs, apply_rewrites, merge_equations, to_acset, symbolic_rewriting, symbolics_lookup +import SymbolicUtils: BasicSymbolic, Symbolic + +export symbolics_lookup, extract_symexprs, apply_rewrites, merge_equations, to_acset, symbolic_rewriting, symbolics_lookup const DECA_EQUALITY_SYMBOL = (==) -to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, SymbolicUtils.BasicSymbolic}, node::TraversalNode) = to_symbolics(d, symvar_lookup, node.index, Val(node.name)) +to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, node::TraversalNode) = to_symbolics(d, symvar_lookup, node.index, Val(node.name)) function symbolics_lookup(d::SummationDecapode) - lookup = Dict{Symbol, SymbolicUtils.BasicSymbolic}() + lookup = Dict{Symbol, BasicSymbolic}() for i in parts(d, :Var) push!(lookup, d[i, :name] => decavar_to_symbolics(d, i)) end @@ -22,80 +24,69 @@ end function decavar_to_symbolics(d::SummationDecapode, index::Int; space = :I) var = d[index, :name] new_type = SymbolicUtils.symtype(Deca.DECQuantity, d[index, :type], space) - @info new_type + SymbolicUtils.Sym{new_type}(var) end -function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, SymbolicUtils.BasicSymbolic}, op_index::Int, ::Val{:Op1}) +function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, op_index::Int, ::Val{:Op1}) input_sym = symvar_lookup[d[d[op_index, :src], :name]] output_sym = symvar_lookup[d[d[op_index, :tgt], :name]] - # op_sym = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{Number}, Number}}(d[op_index, :op1]) op_sym = getfield(@__MODULE__, d[op_index, :op1]) S = promote_symtype(op_sym, input_sym) rhs = SymbolicUtils.Term{S}(op_sym, [input_sym]) - SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [output_sym, rhs]) + SymbolicEquation{Symbolic}(output_sym, rhs) end -# TODO add promote_symtype as Op1 -function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, SymbolicUtils.BasicSymbolic}, op_index::Int, ::Val{:Op2}) +function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, op_index::Int, ::Val{:Op2}) input1_sym = symvar_lookup[d[d[op_index, :proj1], :name]] input2_sym = symvar_lookup[d[d[op_index, :proj2], :name]] output_sym = symvar_lookup[d[d[op_index, :res], :name]] - # op_sym = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{Number, Number}, Number}}(d[op_index, :op2]) op_sym = getfield(@__MODULE__, d[op_index, :op2]) S = promote_symtype(op_sym, input1_sym, input2_sym) rhs = SymbolicUtils.Term{S}(op_sym, [input1_sym, input2_sym]) - SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [output_sym, rhs]) + SymbolicEquation{Symbolic}(output_sym, rhs) end #XXX: Always converting + -> .+ here since summation doesn't store the style of addition -function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, SymbolicUtils.BasicSymbolic}, op_index::Int, ::Val{:Σ}) +function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, op_index::Int, ::Val{:Σ}) syms_array = [symvar_lookup[var] for var in d[d[incident(d, op_index, :summation), :summand], :name]] output_sym = symvar_lookup[d[d[op_index, :sum], :name]] # TODO pls test S = promote_symtype(+, syms_array...) rhs = SymbolicUtils.Term{S}(+, syms_array) - SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [output_sym, rhs]) + SymbolicEquation{Symbolic}(output_sym,rhs) end -function symbolic_rewriting(old_d::SummationDecapode) +function symbolic_rewriting(old_d::SummationDecapode, rewriter = nothing) d = deepcopy(old_d) infer_types!(d) - # resolve_overloads!(d) symvar_lookup = symbolics_lookup(d) - merge_equations(d, symvar_lookup, extract_symexprs(d, symvar_lookup)) -end + eqns = merge_equations(d, symvar_lookup, extract_symexprs(d, symvar_lookup)) -function extract_symexprs(d::SummationDecapode, symvar_lookup::Dict{Symbol, SymbolicUtils.BasicSymbolic}) - topo_list = topological_sort_edges(d) - sym_list = [] - for node in topo_list - retrieve_name(d, node) != DerivOp || continue - push!(sym_list, to_symbolics(d, symvar_lookup, node)) + if !isnothing(rewriter) + eqns = map(rewriter, eqns) end - sym_list -end -function apply_rewrites(symexprs, rewriter) + to_acset(d, eqns) +end - rewritten_list = [] - for sym in symexprs - res_sym = rewriter(sym) - rewritten_sym = isnothing(res_sym) ? sym : res_sym - push!(rewritten_list, rewritten_sym) +function extract_symexprs(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}) + sym_list = SymbolicEquation{Symbolic}[] + for node in topological_sort_edges(d) + retrieve_name(d, node) != DerivOp || continue # This is not part of ThDEC + push!(sym_list, to_symbolics(d, symvar_lookup, node)) end - - rewritten_list + sym_list end -function merge_equations(d::SummationDecapode, symvar_lookup::Dict{Symbol, SymbolicUtils.BasicSymbolic}, rewritten_syms) +function merge_equations(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, symexpr_list::Vector{SymbolicEquation{Symbolic}}) eqn_lookup = Dict() @@ -108,22 +99,35 @@ function merge_equations(d::SummationDecapode, symvar_lookup::Dict{Symbol, Symbo final_nodes = infer_terminal_names(d) - for expr in rewritten_syms + for expr in symexpr_list - merged_eqn = SymbolicUtils.substitute(expr, eqn_lookup) - lhs = merged_eqn.arguments[1] - rhs = merged_eqn.arguments[2] + merged_rhs = SymbolicUtils.substitute(expr.rhs, eqn_lookup) - push!(eqn_lookup, (lhs => rhs)) + push!(eqn_lookup, (expr.lhs => merged_rhs)) - if lhs.name in final_nodes - push!(final_list, merged_eqn) + if expr.lhs.name in final_nodes + push!(final_list, formed_deca_eqn(expr.lhs, merged_rhs)) end end final_list end +formed_deca_eqn(lhs, rhs) = SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [lhs, rhs]) + +function apply_rewrites(symexprs, rewriter) + + rewritten_list = [] + for sym in symexprs + res_sym = rewriter(sym) + rewritten_sym = isnothing(res_sym) ? sym : res_sym + push!(rewritten_list, rewritten_sym) + end + + rewritten_list +end + + function to_acset(og_d, sym_exprs) final_exprs = SymbolicUtils.Code.toexpr.(sym_exprs) @@ -159,7 +163,6 @@ function to_acset(og_d, sym_exprs) d = SummationDecapode(parse_decapode(deca_block)) infer_types!(d) - # resolve_overloads!(d) d end diff --git a/src/graph_traversal.jl b/src/graph_traversal.jl index 43b8860..f2875b2 100644 --- a/src/graph_traversal.jl +++ b/src/graph_traversal.jl @@ -1,7 +1,7 @@ using DiagrammaticEquations using ACSets -export TraversalNode, topological_sort_edges, number_of_ops, retrieve_name, start_nodes +export TraversalNode, topological_sort_edges, n_ops, retrieve_name, start_nodes struct TraversalNode{T} index::Int @@ -20,7 +20,7 @@ function topological_sort_edges(d::SummationDecapode) # FIXME: this is a quadratic implementation of topological_sort inlined in here. op_order = TraversalNode{Symbol}[] - for _ in 1:number_of_ops(d) + for _ in 1:n_ops(d) for op in parts(d, :Op1) if !visited_1[op] && visited_Var[d[op, :src]] @@ -49,12 +49,12 @@ function topological_sort_edges(d::SummationDecapode) end end - @assert length(op_order) == number_of_ops(d) + @assert length(op_order) == n_ops(d) op_order end -function number_of_ops(d::SummationDecapode) +function n_ops(d::SummationDecapode) return nparts(d, :Op1) + nparts(d, :Op2) + nparts(d, :Σ) end diff --git a/src/sym_rewrite.jl b/src/sym_rewrite.jl index 73f4cd9..cec55de 100644 --- a/src/sym_rewrite.jl +++ b/src/sym_rewrite.jl @@ -8,9 +8,15 @@ Heat = @decapode begin C::Form0 G::Form1 D::Constant - ∂ₜ(G) == D*Δ(d(C)) + ∂ₜ(G) == D*Δ(C) end; infer_types!(Heat) +test_heat_same = symbolic_rewriting(Heat) + +r = rules(Δ, Val(1)) + +rwr = Fixpoint(Prewalk(Chain(r))) +test_heat_open = symbolic_rewriting(Heat, rwr) Brusselator = @decapode begin (U, V)::Form0 @@ -35,24 +41,7 @@ Phytodynamics = @decapode begin ∂ₜ(n) == w + m*n + Δ(n) end infer_types!(Phytodynamics) -test = to_acset(Phytodynamics, symbolic_rewriting(Phytodynamics)) - -# resolve_overloads!(Heat) - -# lap_0_convert = @rule Δ₀(~x) => Δ(~x) -# lap_1_convert = @rule Δ₁(~x) => Δ(~x) -# lap_2_convert = @rule Δ₂(~x) => Δ(~x) - -# d_0_convert = @rule d₀(~x) => d(~x) - -# overloaders = [lap_0_convert, lap_1_convert, lap_2_convert, d_0_convert] - -# lap_0_rule = @rule Δ(~x) => ⋆(d(⋆(d(~x)))) -# lap_1_rule = @rule Δ(~x) => d(⋆(d(⋆(~x)))) + ⋆(d(⋆(d(~x)))) -# lap_2_rule = @rule Δ(~x) => d(⋆(d(⋆(~x)))) - -# openers = [lap_0_rule, lap_1_rule, lap_2_rule] - +test_phy = symbolic_rewriting(Phytodynamics) # it seems that type-instability or improper type promotion is happening. expressions derived from this have BasicSymbolic{Number} type, which means we can't conditionally rewrite on forms. heat_exprs = symbolic_rewriting(Heat) @@ -103,8 +92,6 @@ R1 = @rule Δ(~~x::(x->isForm1(x))) => ★(d(★(d(~x)))) # pulling out the subexpression rewriter = SymbolicUtils.Fixpoint(SymbolicUtils.Prewalk(SymbolicUtils.Chain(r))) - - res_exprs = apply_rewrites(heat_exprs, rewriter) sub_exprs = apply_rewrites([sub], rewriter) diff --git a/test/graph_traversal.jl b/test/graph_traversal.jl index 7a5d6ac..b09fd24 100644 --- a/test/graph_traversal.jl +++ b/test/graph_traversal.jl @@ -4,7 +4,7 @@ using MLStyle using Test function is_correct_length(d::SummationDecapode, result) - return length(result) == number_of_ops(d) + return length(result) == n_ops(d) end @testset "Topological Sort on Edges" begin