Skip to content

Commit

Permalink
George and I debugged rewriting. Incorrect type passed to resulting t…
Browse files Browse the repository at this point in the history
…erm meant typed rewriting would fail
  • Loading branch information
quffaro committed Sep 19, 2024
1 parent 87f65fe commit 90b1adc
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 9 deletions.
13 changes: 8 additions & 5 deletions src/acset2symbolic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,12 @@ function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, Symbolic

op_sym = getfield(@__MODULE__, d[op_index, :op1])

Check warning on line 34 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L34

Added line #L34 was not covered by tests

@info typeof(op_sym)

rhs = SymbolicUtils.Term{Number}(op_sym, [input_sym])
S = promote_symtype(op_sym, input_sym)
rhs = SymbolicUtils.Term{S}(op_sym, [input_sym])
SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [output_sym, rhs])

Check warning on line 38 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L36-L38

Added lines #L36 - L38 were not covered by tests
end

# TODO add promote_symtype as Op1
function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, SymbolicUtils.BasicSymbolic}, op_index::Int, ::Val{:Op2})
input1_sym = symvar_lookup[d[d[op_index, :proj1], :name]]
input2_sym = symvar_lookup[d[d[op_index, :proj2], :name]]
Expand All @@ -47,7 +47,8 @@ function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, Symbolic

op_sym = getfield(@__MODULE__, d[op_index, :op2])

Check warning on line 48 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L48

Added line #L48 was not covered by tests

rhs = SymbolicUtils.Term{Number}(op_sym, [input1_sym, input2_sym])
S = promote_symtype(op_sym, input1_sym, input2_sym)
rhs = SymbolicUtils.Term{S}(op_sym, [input1_sym, input2_sym])
SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [output_sym, rhs])

Check warning on line 52 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L50-L52

Added lines #L50 - L52 were not covered by tests
end

Expand All @@ -56,7 +57,9 @@ function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, Symbolic
syms_array = [symvar_lookup[var] for var in d[d[incident(d, op_index, :summation), :summand], :name]]
output_sym = symvar_lookup[d[d[op_index, :sum], :name]]

Check warning on line 58 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L56-L58

Added lines #L56 - L58 were not covered by tests

rhs = SymbolicUtils.Term{Number}(+, syms_array)
# TODO pls test
S = promote_symtype(+, syms_array...)
rhs = SymbolicUtils.Term{S}(+, syms_array)
SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [output_sym, rhs])

Check warning on line 63 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L61-L63

Added lines #L61 - L63 were not covered by tests
end

Expand Down
56 changes: 52 additions & 4 deletions src/sym_rewrite.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
using Test
using DiagrammaticEquations
using SymbolicUtils
using SymbolicUtils: Fixpoint, Prewalk, Postwalk, Chain, symtype, promote_symtype
using MLStyle

Heat = @decapode begin
C::Form0
G::Form1
D::Constant
∂ₜ(G) == D*Δ(d(C))
end

end;
infer_types!(Heat)

Brusselator = @decapode begin
Expand Down Expand Up @@ -52,13 +53,60 @@ test = to_acset(Phytodynamics, symbolic_rewriting(Phytodynamics))

# openers = [lap_0_rule, lap_1_rule, lap_2_rule]

r = rules(Δ, Val(1))

# it seems that type-instability or improper type promotion is happening. expressions derived from this have BasicSymbolic{Number} type, which means we can't conditionally rewrite on forms.
heat_exprs = symbolic_rewriting(Heat)
sub = heat_exprs[1].arguments[2].arguments[2]

a, b = @syms a::Scalar b::Scalar
u, v = @syms u::PrimalForm{0, :X, 2} du::PrimalForm{1, :X, 2}

r = rules(Δ, Val(1))

# rule without predication works
R = @rule Δ(~x) => (d((d(~x))))
rwR = Fixpoint(Prewalk(Chain([R])))

R(Δ(d(u)))

# since promote_symtype(d(u)) returns Any while promote_symtype(d, u). I wonder
# if `d(u)` is not subjected to `symtype`

Rp = @rule Δ(~x::isForm1) => "Success"
Rp(Δ(v)) # works
Rp(Δ(d(u))) # works

Rp1 = @rule Δ(~x::isForm1) => (d((d(~x))))

Rp1(Δ(v)) # works
Rp1(Δ(d(u))) # works
rwRp1 = Fixpoint(Prewalk(Chain([Rp1])))
rwRp1(Δ(d(u)))

rwr = Fixpoint(Prewalk(Chain(r)))
rwr(heat_exprs[1]) # THIS WORKS!

rwr(Δ(d(u))) # rwr
rwr(heat_exprs[1].arguments[2])

r[2](Δ(d(u))) # works


# rwR(heat_exprs[1])
# rwR(sub)

# tilde?
R1 = @rule Δ(~~x::(x->isForm1(x))) => (d((d(~x))))

Check warning on line 99 in src/sym_rewrite.jl

View check run for this annotation

Codecov / codecov/patch

src/sym_rewrite.jl#L99

Added line #L99 was not covered by tests

@macroexpand @rule Δ(~x::isForm1) => "Success"

# pulling out the subexpression
rewriter = SymbolicUtils.Fixpoint(SymbolicUtils.Prewalk(SymbolicUtils.Chain(r)))


rewriter = SymbolicUtils.Fixpoint(SymbolicUtils.Prewalk(test_rule))

res_exprs = apply_rewrites(heat_exprs, rewriter)
sub_exprs = apply_rewrites([sub], rewriter)

optm_dd_0 = @rule d(d(~x)) => 0
star_0 = @rule (0) => 0
Expand Down

0 comments on commit 90b1adc

Please sign in to comment.