From 90b1adc6f67c085a359df66e610e80dc167868f8 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 19 Sep 2024 17:29:29 -0400 Subject: [PATCH] George and I debugged rewriting. Incorrect type passed to resulting term meant typed rewriting would fail --- src/acset2symbolic.jl | 13 ++++++---- src/sym_rewrite.jl | 56 +++++++++++++++++++++++++++++++++++++++---- 2 files changed, 60 insertions(+), 9 deletions(-) diff --git a/src/acset2symbolic.jl b/src/acset2symbolic.jl index a813a3b..0bf24fd 100644 --- a/src/acset2symbolic.jl +++ b/src/acset2symbolic.jl @@ -33,12 +33,12 @@ function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, Symbolic op_sym = getfield(@__MODULE__, d[op_index, :op1]) - @info typeof(op_sym) - - rhs = SymbolicUtils.Term{Number}(op_sym, [input_sym]) + S = promote_symtype(op_sym, input_sym) + rhs = SymbolicUtils.Term{S}(op_sym, [input_sym]) SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [output_sym, rhs]) end +# TODO add promote_symtype as Op1 function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, SymbolicUtils.BasicSymbolic}, op_index::Int, ::Val{:Op2}) input1_sym = symvar_lookup[d[d[op_index, :proj1], :name]] input2_sym = symvar_lookup[d[d[op_index, :proj2], :name]] @@ -47,7 +47,8 @@ function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, Symbolic op_sym = getfield(@__MODULE__, d[op_index, :op2]) - rhs = SymbolicUtils.Term{Number}(op_sym, [input1_sym, input2_sym]) + S = promote_symtype(op_sym, input1_sym, input2_sym) + rhs = SymbolicUtils.Term{S}(op_sym, [input1_sym, input2_sym]) SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [output_sym, rhs]) end @@ -56,7 +57,9 @@ function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, Symbolic syms_array = [symvar_lookup[var] for var in d[d[incident(d, op_index, :summation), :summand], :name]] output_sym = symvar_lookup[d[d[op_index, :sum], :name]] - rhs = SymbolicUtils.Term{Number}(+, syms_array) + # TODO pls test + S = promote_symtype(+, syms_array...) + rhs = SymbolicUtils.Term{S}(+, syms_array) SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [output_sym, rhs]) end diff --git a/src/sym_rewrite.jl b/src/sym_rewrite.jl index a4b6f2c..73f4cd9 100644 --- a/src/sym_rewrite.jl +++ b/src/sym_rewrite.jl @@ -1,5 +1,7 @@ +using Test using DiagrammaticEquations using SymbolicUtils +using SymbolicUtils: Fixpoint, Prewalk, Postwalk, Chain, symtype, promote_symtype using MLStyle Heat = @decapode begin @@ -7,8 +9,7 @@ Heat = @decapode begin G::Form1 D::Constant ∂ₜ(G) == D*Δ(d(C)) -end - +end; infer_types!(Heat) Brusselator = @decapode begin @@ -52,13 +53,60 @@ test = to_acset(Phytodynamics, symbolic_rewriting(Phytodynamics)) # openers = [lap_0_rule, lap_1_rule, lap_2_rule] -r = rules(Δ, Val(1)) +# it seems that type-instability or improper type promotion is happening. expressions derived from this have BasicSymbolic{Number} type, which means we can't conditionally rewrite on forms. heat_exprs = symbolic_rewriting(Heat) +sub = heat_exprs[1].arguments[2].arguments[2] + +a, b = @syms a::Scalar b::Scalar +u, v = @syms u::PrimalForm{0, :X, 2} du::PrimalForm{1, :X, 2} + +r = rules(Δ, Val(1)) + +# rule without predication works +R = @rule Δ(~x) => ★(d(★(d(~x)))) +rwR = Fixpoint(Prewalk(Chain([R]))) + +R(Δ(d(u))) + +# since promote_symtype(d(u)) returns Any while promote_symtype(d, u). I wonder +# if `d(u)` is not subjected to `symtype` + +Rp = @rule Δ(~x::isForm1) => "Success" +Rp(Δ(v)) # works +Rp(Δ(d(u))) # works + +Rp1 = @rule Δ(~x::isForm1) => ★(d(★(d(~x)))) + +Rp1(Δ(v)) # works +Rp1(Δ(d(u))) # works +rwRp1 = Fixpoint(Prewalk(Chain([Rp1]))) +rwRp1(Δ(d(u))) + +rwr = Fixpoint(Prewalk(Chain(r))) +rwr(heat_exprs[1]) # THIS WORKS! + +rwr(Δ(d(u))) # rwr +rwr(heat_exprs[1].arguments[2]) + +r[2](Δ(d(u))) # works + + +# rwR(heat_exprs[1]) +# rwR(sub) + +# tilde? +R1 = @rule Δ(~~x::(x->isForm1(x))) => ★(d(★(d(~x)))) + +@macroexpand @rule Δ(~x::isForm1) => "Success" + +# pulling out the subexpression +rewriter = SymbolicUtils.Fixpoint(SymbolicUtils.Prewalk(SymbolicUtils.Chain(r))) + -rewriter = SymbolicUtils.Fixpoint(SymbolicUtils.Prewalk(test_rule)) res_exprs = apply_rewrites(heat_exprs, rewriter) +sub_exprs = apply_rewrites([sub], rewriter) optm_dd_0 = @rule d(d(~x)) => 0 star_0 = @rule ★(0) => 0