diff --git a/src/acset2symbolic.jl b/src/acset2symbolic.jl index 802d182..a813a3b 100644 --- a/src/acset2symbolic.jl +++ b/src/acset2symbolic.jl @@ -5,7 +5,7 @@ using SymbolicUtils.Rewriters using SymbolicUtils.Code using MLStyle -export extract_symexprs, apply_rewrites, merge_equations, to_acset +export extract_symexprs, apply_rewrites, merge_equations, to_acset, symbolic_rewriting, symbolics_lookup const DECA_EQUALITY_SYMBOL = (==) @@ -21,14 +21,19 @@ end function decavar_to_symbolics(d::SummationDecapode, index::Int; space = :I) var = d[index, :name] - new_type = symtype(Deca.DECQuantity, d[index, :type], space) + new_type = SymbolicUtils.symtype(Deca.DECQuantity, d[index, :type], space) + @info new_type SymbolicUtils.Sym{new_type}(var) end function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, SymbolicUtils.BasicSymbolic}, op_index::Int, ::Val{:Op1}) input_sym = symvar_lookup[d[d[op_index, :src], :name]] output_sym = symvar_lookup[d[d[op_index, :tgt], :name]] - op_sym = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{Number}, Number}}(d[op_index, :op1]) + # op_sym = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{Number}, Number}}(d[op_index, :op1]) + + op_sym = getfield(@__MODULE__, d[op_index, :op1]) + + @info typeof(op_sym) rhs = SymbolicUtils.Term{Number}(op_sym, [input_sym]) SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [output_sym, rhs]) @@ -38,26 +43,31 @@ function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, Symbolic input1_sym = symvar_lookup[d[d[op_index, :proj1], :name]] input2_sym = symvar_lookup[d[d[op_index, :proj2], :name]] output_sym = symvar_lookup[d[d[op_index, :res], :name]] - op_sym = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{Number, Number}, Number}}(d[op_index, :op2]) + # op_sym = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{Number, Number}, Number}}(d[op_index, :op2]) + + op_sym = getfield(@__MODULE__, d[op_index, :op2]) rhs = SymbolicUtils.Term{Number}(op_sym, [input1_sym, input2_sym]) SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [output_sym, rhs]) end #XXX: Always converting + -> .+ here since summation doesn't store the style of addition -# function to_symbolics(d::SummationDecapode, op_index::Int, ::Val{:Σ}) -# Expr(EQUALITY_SYMBOL, c.output, Expr(:call, Expr(:., :+), c.inputs...)) -# end +function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, SymbolicUtils.BasicSymbolic}, op_index::Int, ::Val{:Σ}) + syms_array = [symvar_lookup[var] for var in d[d[incident(d, op_index, :summation), :summand], :name]] + output_sym = symvar_lookup[d[d[op_index, :sum], :name]] + + rhs = SymbolicUtils.Term{Number}(+, syms_array) + SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [output_sym, rhs]) +end function symbolic_rewriting(old_d::SummationDecapode) d = deepcopy(old_d) infer_types!(d) - resolve_overloads!(d) + # resolve_overloads!(d) symvar_lookup = symbolics_lookup(d) - - symexprs = extract_symexprs(d, symvar_lookup) + merge_equations(d, symvar_lookup, extract_symexprs(d, symvar_lookup)) end function extract_symexprs(d::SummationDecapode, symvar_lookup::Dict{Symbol, SymbolicUtils.BasicSymbolic}) @@ -70,10 +80,10 @@ function extract_symexprs(d::SummationDecapode, symvar_lookup::Dict{Symbol, Symb sym_list end -function apply_rewrites(d::SummationDecapode, rewriter) +function apply_rewrites(symexprs, rewriter) rewritten_list = [] - for sym in extract_symexprs(d) + for sym in symexprs res_sym = rewriter(sym) rewritten_sym = isnothing(res_sym) ? sym : res_sym push!(rewritten_list, rewritten_sym) @@ -113,7 +123,18 @@ end function to_acset(og_d, sym_exprs) final_exprs = SymbolicUtils.Code.toexpr.(sym_exprs) - map(x -> x.args[1] = :(==), final_exprs) + + recursive_descent = @λ begin + e::Expr => begin + if e.head == :call + @show nameof(e.args[1]) + e.args[1] = nameof(e.args[1]) + map(recursive_descent, e.args[2:end]) + end + end + sym => nothing + end + map(recursive_descent, final_exprs) deca_block = quote end @@ -124,12 +145,18 @@ function to_acset(og_d, sym_exprs) append!(deca_block.args, map(deca_type_gen, vcat(states, terminals))) + for op1 in parts(og_d, :Op1) + if og_d[op1, :op1] == DerivOp + push!(deca_block.args, :($(og_d[og_d[op1, :tgt], :name]) == $DerivOp($(og_d[og_d[op1, :src], :name])))) + end + end + append!(deca_block.args, final_exprs) d = SummationDecapode(parse_decapode(deca_block)) infer_types!(d) - resolve_overloads!(d) + # resolve_overloads!(d) d end diff --git a/src/deca/ThDEC.jl b/src/deca/ThDEC.jl index 383274e..84186f8 100644 --- a/src/deca/ThDEC.jl +++ b/src/deca/ThDEC.jl @@ -178,6 +178,8 @@ end @rule Δ(~x::isForm1) => ★(d(★(d(~x)))) + d(★(d(★(~x)))) end +@alias (Δ₀, Δ₁, Δ₂) => Δ + @operator +(S1, S2)::DECQuantity begin @match (S1, S2) begin (PatScalar(_), PatScalar(_)) => Scalar diff --git a/src/sym_rewrite.jl b/src/sym_rewrite.jl index de21b89..a4b6f2c 100644 --- a/src/sym_rewrite.jl +++ b/src/sym_rewrite.jl @@ -4,47 +4,71 @@ using MLStyle Heat = @decapode begin C::Form0 + G::Form1 D::Constant - ∂ₜ(C) == D*Δ(d(C)) + ∂ₜ(G) == D*Δ(d(C)) end infer_types!(Heat) -resolve_overloads!(Heat) -@syms Δ(x) d(x) ⋆(x) Δ₀(x) Δ₁(x) Δ₂(x) d₀(x) +Brusselator = @decapode begin + (U, V)::Form0 + U2V::Form0 + (U̇, V̇)::Form0 -lap_0_convert = @rule Δ₀(~x) => Δ(~x) -lap_1_convert = @rule Δ₁(~x) => Δ(~x) -lap_2_convert = @rule Δ₂(~x) => Δ(~x) + (α)::Constant + F::Parameter -d_0_convert = @rule d₀(~x) => d(~x) + U2V == (U .* U) .* V -overloaders = [lap_0_convert, lap_1_convert, lap_2_convert, d_0_convert] + U̇ == 1 + U2V - (4.4 * U) + (α * Δ(U)) + F + V̇ == (3.4 * U) - U2V + (α * Δ(V)) + ∂ₜ(U) == U̇ + ∂ₜ(V) == V̇ +end +infer_types!(Brusselator) + +Phytodynamics = @decapode begin + (n,w)::Form0 + m::Constant + ∂ₜ(n) == w + m*n + Δ(n) +end +infer_types!(Phytodynamics) +test = to_acset(Phytodynamics, symbolic_rewriting(Phytodynamics)) + +# resolve_overloads!(Heat) + +# lap_0_convert = @rule Δ₀(~x) => Δ(~x) +# lap_1_convert = @rule Δ₁(~x) => Δ(~x) +# lap_2_convert = @rule Δ₂(~x) => Δ(~x) + +# d_0_convert = @rule d₀(~x) => d(~x) + +# overloaders = [lap_0_convert, lap_1_convert, lap_2_convert, d_0_convert] -lap_0_rule = @rule Δ(~x) => ⋆(d(⋆(d(~x)))) -lap_1_rule = @rule Δ(~x) => d(⋆(d(⋆(~x)))) + ⋆(d(⋆(d(~x)))) -lap_2_rule = @rule Δ(~x) => d(⋆(d(⋆(~x)))) +# lap_0_rule = @rule Δ(~x) => ⋆(d(⋆(d(~x)))) +# lap_1_rule = @rule Δ(~x) => d(⋆(d(⋆(~x)))) + ⋆(d(⋆(d(~x)))) +# lap_2_rule = @rule Δ(~x) => d(⋆(d(⋆(~x)))) -openers = [lap_0_rule, lap_1_rule, lap_2_rule] +# openers = [lap_0_rule, lap_1_rule, lap_2_rule] -heat_exprs = extract_symexprs(Heat) +r = rules(Δ, Val(1)) -rewriter = SymbolicUtils.Postwalk( - SymbolicUtils.Fixpoint(SymbolicUtils.Chain(vcat(overloaders, openers)))) +heat_exprs = symbolic_rewriting(Heat) -res_exprs = apply_rewrites(Heat, rewriter) +rewriter = SymbolicUtils.Fixpoint(SymbolicUtils.Prewalk(test_rule)) -merge_exprs = merge_equations(Heat, res_exprs) +res_exprs = apply_rewrites(heat_exprs, rewriter) optm_dd_0 = @rule d(d(~x)) => 0 -star_0 = @rule ⋆(0) => 0 +star_0 = @rule ★(0) => 0 d_0 = @rule d(0) => 0 optm_rewriter = SymbolicUtils.Postwalk( SymbolicUtils.Fixpoint(SymbolicUtils.Chain([optm_dd_0, star_0, d_0]))) -res_merge_exprs = map(optm_rewriter, merge_exprs) +res_merge_exprs = map(optm_rewriter, res_exprs) -deca_test = to_acset(Heat, res_merge_exprs) +deca_test = to_acset(Heat, res_exprs) infer_types!(deca_test) resolve_overloads!(deca_test)