Skip to content

Commit

Permalink
Completed pipeline again
Browse files Browse the repository at this point in the history
Addition now works as well but rewriting seems to be janky, unrelated to this pipeline specifically I believe.
  • Loading branch information
GeorgeR227 committed Sep 18, 2024
1 parent e0ff9a8 commit b9b4146
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 34 deletions.
55 changes: 41 additions & 14 deletions src/acset2symbolic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ using SymbolicUtils.Rewriters
using SymbolicUtils.Code
using MLStyle

export extract_symexprs, apply_rewrites, merge_equations, to_acset
export extract_symexprs, apply_rewrites, merge_equations, to_acset, symbolic_rewriting, symbolics_lookup

const DECA_EQUALITY_SYMBOL = (==)

Expand All @@ -21,14 +21,19 @@ end

function decavar_to_symbolics(d::SummationDecapode, index::Int; space = :I)
var = d[index, :name]
new_type = symtype(Deca.DECQuantity, d[index, :type], space)
new_type = SymbolicUtils.symtype(Deca.DECQuantity, d[index, :type], space)
@info new_type
SymbolicUtils.Sym{new_type}(var)

Check warning on line 26 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L22-L26

Added lines #L22 - L26 were not covered by tests
end

function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, SymbolicUtils.BasicSymbolic}, op_index::Int, ::Val{:Op1})
input_sym = symvar_lookup[d[d[op_index, :src], :name]]
output_sym = symvar_lookup[d[d[op_index, :tgt], :name]]

Check warning on line 31 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L29-L31

Added lines #L29 - L31 were not covered by tests
op_sym = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{Number}, Number}}(d[op_index, :op1])
# op_sym = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{Number}, Number}}(d[op_index, :op1])

op_sym = getfield(@__MODULE__, d[op_index, :op1])

Check warning on line 34 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L34

Added line #L34 was not covered by tests

@info typeof(op_sym)

Check warning on line 36 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L36

Added line #L36 was not covered by tests

rhs = SymbolicUtils.Term{Number}(op_sym, [input_sym])
SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [output_sym, rhs])

Check warning on line 39 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L38-L39

Added lines #L38 - L39 were not covered by tests
Expand All @@ -38,26 +43,31 @@ function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, Symbolic
input1_sym = symvar_lookup[d[d[op_index, :proj1], :name]]
input2_sym = symvar_lookup[d[d[op_index, :proj2], :name]]
output_sym = symvar_lookup[d[d[op_index, :res], :name]]

Check warning on line 45 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L42-L45

Added lines #L42 - L45 were not covered by tests
op_sym = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{Number, Number}, Number}}(d[op_index, :op2])
# op_sym = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{Number, Number}, Number}}(d[op_index, :op2])

op_sym = getfield(@__MODULE__, d[op_index, :op2])

Check warning on line 48 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L48

Added line #L48 was not covered by tests

rhs = SymbolicUtils.Term{Number}(op_sym, [input1_sym, input2_sym])
SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [output_sym, rhs])

Check warning on line 51 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L50-L51

Added lines #L50 - L51 were not covered by tests
end

#XXX: Always converting + -> .+ here since summation doesn't store the style of addition
# function to_symbolics(d::SummationDecapode, op_index::Int, ::Val{:Σ})
# Expr(EQUALITY_SYMBOL, c.output, Expr(:call, Expr(:., :+), c.inputs...))
# end
function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, SymbolicUtils.BasicSymbolic}, op_index::Int, ::Val{:Σ})
syms_array = [symvar_lookup[var] for var in d[d[incident(d, op_index, :summation), :summand], :name]]
output_sym = symvar_lookup[d[d[op_index, :sum], :name]]

Check warning on line 57 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L55-L57

Added lines #L55 - L57 were not covered by tests

rhs = SymbolicUtils.Term{Number}(+, syms_array)
SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [output_sym, rhs])

Check warning on line 60 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L59-L60

Added lines #L59 - L60 were not covered by tests
end

function symbolic_rewriting(old_d::SummationDecapode)
d = deepcopy(old_d)

Check warning on line 64 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L63-L64

Added lines #L63 - L64 were not covered by tests

infer_types!(d)

Check warning on line 66 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L66

Added line #L66 was not covered by tests
resolve_overloads!(d)
# resolve_overloads!(d)

symvar_lookup = symbolics_lookup(d)

symexprs = extract_symexprs(d, symvar_lookup)
merge_equations(d, symvar_lookup, extract_symexprs(d, symvar_lookup))

Check warning on line 70 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L69-L70

Added lines #L69 - L70 were not covered by tests
end

function extract_symexprs(d::SummationDecapode, symvar_lookup::Dict{Symbol, SymbolicUtils.BasicSymbolic})
Expand All @@ -70,10 +80,10 @@ function extract_symexprs(d::SummationDecapode, symvar_lookup::Dict{Symbol, Symb
sym_list

Check warning on line 80 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L73-L80

Added lines #L73 - L80 were not covered by tests
end

function apply_rewrites(d::SummationDecapode, rewriter)
function apply_rewrites(symexprs, rewriter)

Check warning on line 83 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L83

Added line #L83 was not covered by tests

rewritten_list = []
for sym in extract_symexprs(d)
for sym in symexprs
res_sym = rewriter(sym)
rewritten_sym = isnothing(res_sym) ? sym : res_sym
push!(rewritten_list, rewritten_sym)
Expand Down Expand Up @@ -113,7 +123,18 @@ end

function to_acset(og_d, sym_exprs)
final_exprs = SymbolicUtils.Code.toexpr.(sym_exprs)

Check warning on line 125 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L124-L125

Added lines #L124 - L125 were not covered by tests
map(x -> x.args[1] = :(==), final_exprs)

recursive_descent = begin
e::Expr => begin
if e.head == :call
@show nameof(e.args[1])
e.args[1] = nameof(e.args[1])
map(recursive_descent, e.args[2:end])

Check warning on line 132 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L127-L132

Added lines #L127 - L132 were not covered by tests
end
end
sym => nothing

Check warning on line 135 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L135

Added line #L135 was not covered by tests
end
map(recursive_descent, final_exprs)

Check warning on line 137 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L137

Added line #L137 was not covered by tests

deca_block = quote end

Check warning on line 139 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L139

Added line #L139 was not covered by tests

Expand All @@ -124,12 +145,18 @@ function to_acset(og_d, sym_exprs)

append!(deca_block.args, map(deca_type_gen, vcat(states, terminals)))

Check warning on line 146 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L146

Added line #L146 was not covered by tests

for op1 in parts(og_d, :Op1)
if og_d[op1, :op1] == DerivOp
push!(deca_block.args, :($(og_d[og_d[op1, :tgt], :name]) == $DerivOp($(og_d[og_d[op1, :src], :name]))))

Check warning on line 150 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L148-L150

Added lines #L148 - L150 were not covered by tests
end
end

Check warning on line 152 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L152

Added line #L152 was not covered by tests

append!(deca_block.args, final_exprs)

Check warning on line 154 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L154

Added line #L154 was not covered by tests

d = SummationDecapode(parse_decapode(deca_block))

Check warning on line 156 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L156

Added line #L156 was not covered by tests

infer_types!(d)

Check warning on line 158 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L158

Added line #L158 was not covered by tests
resolve_overloads!(d)
# resolve_overloads!(d)

d

Check warning on line 161 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L161

Added line #L161 was not covered by tests
end
2 changes: 2 additions & 0 deletions src/deca/ThDEC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,8 @@ end
@rule Δ(~x::isForm1) => (d((d(~x)))) + d((d((~x))))
end

@alias (Δ₀, Δ₁, Δ₂) => Δ

@operator +(S1, S2)::DECQuantity begin
@match (S1, S2) begin
(PatScalar(_), PatScalar(_)) => Scalar
Expand Down
64 changes: 44 additions & 20 deletions src/sym_rewrite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,47 +4,71 @@ using MLStyle

Heat = @decapode begin
C::Form0
G::Form1
D::Constant
∂ₜ(C) == D*Δ(d(C))
∂ₜ(G) == D*Δ(d(C))
end

infer_types!(Heat)
resolve_overloads!(Heat)

@syms Δ(x) d(x) (x) Δ₀(x) Δ₁(x) Δ₂(x) d₀(x)
Brusselator = @decapode begin
(U, V)::Form0
U2V::Form0
(U̇, V̇)::Form0

lap_0_convert = @rule Δ₀(~x) => Δ(~x)
lap_1_convert = @rule Δ₁(~x) => Δ(~x)
lap_2_convert = @rule Δ₂(~x) => Δ(~x)
(α)::Constant
F::Parameter

d_0_convert = @rule d₀(~x) => d(~x)
U2V == (U .* U) .* V

overloaders = [lap_0_convert, lap_1_convert, lap_2_convert, d_0_convert]
== 1 + U2V - (4.4 * U) +* Δ(U)) + F
== (3.4 * U) - U2V +* Δ(V))
∂ₜ(U) ==
∂ₜ(V) ==
end
infer_types!(Brusselator)

Phytodynamics = @decapode begin
(n,w)::Form0
m::Constant
∂ₜ(n) == w + m*n + Δ(n)
end
infer_types!(Phytodynamics)
test = to_acset(Phytodynamics, symbolic_rewriting(Phytodynamics))

# resolve_overloads!(Heat)

# lap_0_convert = @rule Δ₀(~x) => Δ(~x)
# lap_1_convert = @rule Δ₁(~x) => Δ(~x)
# lap_2_convert = @rule Δ₂(~x) => Δ(~x)

# d_0_convert = @rule d₀(~x) => d(~x)

# overloaders = [lap_0_convert, lap_1_convert, lap_2_convert, d_0_convert]

lap_0_rule = @rule Δ(~x) => (d((d(~x))))
lap_1_rule = @rule Δ(~x) => d((d((~x)))) + (d((d(~x))))
lap_2_rule = @rule Δ(~x) => d((d((~x))))
# lap_0_rule = @rule Δ(~x) => ⋆(d(⋆(d(~x))))
# lap_1_rule = @rule Δ(~x) => d(⋆(d(⋆(~x)))) + ⋆(d(⋆(d(~x))))
# lap_2_rule = @rule Δ(~x) => d(⋆(d(⋆(~x))))

openers = [lap_0_rule, lap_1_rule, lap_2_rule]
# openers = [lap_0_rule, lap_1_rule, lap_2_rule]

heat_exprs = extract_symexprs(Heat)
r = rules(Δ, Val(1))

rewriter = SymbolicUtils.Postwalk(
SymbolicUtils.Fixpoint(SymbolicUtils.Chain(vcat(overloaders, openers))))
heat_exprs = symbolic_rewriting(Heat)

res_exprs = apply_rewrites(Heat, rewriter)
rewriter = SymbolicUtils.Fixpoint(SymbolicUtils.Prewalk(test_rule))

merge_exprs = merge_equations(Heat, res_exprs)
res_exprs = apply_rewrites(heat_exprs, rewriter)

optm_dd_0 = @rule d(d(~x)) => 0
star_0 = @rule (0) => 0
star_0 = @rule (0) => 0
d_0 = @rule d(0) => 0

optm_rewriter = SymbolicUtils.Postwalk(
SymbolicUtils.Fixpoint(SymbolicUtils.Chain([optm_dd_0, star_0, d_0])))

res_merge_exprs = map(optm_rewriter, merge_exprs)
res_merge_exprs = map(optm_rewriter, res_exprs)

deca_test = to_acset(Heat, res_merge_exprs)
deca_test = to_acset(Heat, res_exprs)
infer_types!(deca_test)
resolve_overloads!(deca_test)

0 comments on commit b9b4146

Please sign in to comment.