Skip to content

Commit

Permalink
Merge branch 'gr/acset2sym' of github.com:AlgebraicJulia/Diagrammatic…
Browse files Browse the repository at this point in the history
…Equations.jl into gr/acset2sym
  • Loading branch information
quffaro committed Sep 23, 2024
2 parents ea2d8c0 + 6a3877f commit 20fdef6
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 69 deletions.
3 changes: 2 additions & 1 deletion src/DiagrammaticEquations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,10 @@ include("graph_traversal.jl")
include("deca/Deca.jl")
include("learn/Learn.jl")
include("SymbolicUtilsInterop.jl")
include("acset2symbolic.jl")

@reexport using .Deca
@reexport using .SymbolicUtilsInterop

include("acset2symbolic.jl")

end
1 change: 1 addition & 0 deletions src/SymbolicUtilsInterop.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ struct SymbolicEquation{E}
lhs::E
rhs::E
end
export SymbolicEquation

Base.show(io::IO, e::SymbolicEquation) = begin
print(io, e.lhs); print(io, " == "); print(io, e.rhs)
Expand Down
87 changes: 45 additions & 42 deletions src/acset2symbolic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,16 @@ using SymbolicUtils.Rewriters
using SymbolicUtils.Code
using MLStyle

export extract_symexprs, apply_rewrites, merge_equations, to_acset, symbolic_rewriting, symbolics_lookup
import SymbolicUtils: BasicSymbolic, Symbolic

export symbolics_lookup, extract_symexprs, apply_rewrites, merge_equations, to_acset, symbolic_rewriting, symbolics_lookup

const DECA_EQUALITY_SYMBOL = (==)

to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, SymbolicUtils.BasicSymbolic}, node::TraversalNode) = to_symbolics(d, symvar_lookup, node.index, Val(node.name))
to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, node::TraversalNode) = to_symbolics(d, symvar_lookup, node.index, Val(node.name))

Check warning on line 14 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L14

Added line #L14 was not covered by tests

function symbolics_lookup(d::SummationDecapode)
lookup = Dict{Symbol, SymbolicUtils.BasicSymbolic}()
lookup = Dict{Symbol, BasicSymbolic}()
for i in parts(d, :Var)
push!(lookup, d[i, :name] => decavar_to_symbolics(d, i))
end
Expand All @@ -22,80 +24,69 @@ end
function decavar_to_symbolics(d::SummationDecapode, index::Int; space = :I)
var = d[index, :name]
new_type = SymbolicUtils.symtype(Deca.DECQuantity, d[index, :type], space)

Check warning on line 26 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L24-L26

Added lines #L24 - L26 were not covered by tests
@info new_type

SymbolicUtils.Sym{new_type}(var)

Check warning on line 28 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L28

Added line #L28 was not covered by tests
end

function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, SymbolicUtils.BasicSymbolic}, op_index::Int, ::Val{:Op1})
function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, op_index::Int, ::Val{:Op1})
input_sym = symvar_lookup[d[d[op_index, :src], :name]]
output_sym = symvar_lookup[d[d[op_index, :tgt], :name]]

Check warning on line 33 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L31-L33

Added lines #L31 - L33 were not covered by tests
# op_sym = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{Number}, Number}}(d[op_index, :op1])

op_sym = getfield(@__MODULE__, d[op_index, :op1])

Check warning on line 35 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L35

Added line #L35 was not covered by tests

S = promote_symtype(op_sym, input_sym)
rhs = SymbolicUtils.Term{S}(op_sym, [input_sym])
SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [output_sym, rhs])
SymbolicEquation{Symbolic}(output_sym, rhs)

Check warning on line 39 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L37-L39

Added lines #L37 - L39 were not covered by tests
end

# TODO add promote_symtype as Op1
function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, SymbolicUtils.BasicSymbolic}, op_index::Int, ::Val{:Op2})
function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, op_index::Int, ::Val{:Op2})
input1_sym = symvar_lookup[d[d[op_index, :proj1], :name]]
input2_sym = symvar_lookup[d[d[op_index, :proj2], :name]]
output_sym = symvar_lookup[d[d[op_index, :res], :name]]

Check warning on line 45 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L42-L45

Added lines #L42 - L45 were not covered by tests
# op_sym = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{Number, Number}, Number}}(d[op_index, :op2])

op_sym = getfield(@__MODULE__, d[op_index, :op2])

Check warning on line 47 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L47

Added line #L47 was not covered by tests

S = promote_symtype(op_sym, input1_sym, input2_sym)
rhs = SymbolicUtils.Term{S}(op_sym, [input1_sym, input2_sym])
SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [output_sym, rhs])
SymbolicEquation{Symbolic}(output_sym, rhs)

Check warning on line 51 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L49-L51

Added lines #L49 - L51 were not covered by tests
end

#XXX: Always converting + -> .+ here since summation doesn't store the style of addition
function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, SymbolicUtils.BasicSymbolic}, op_index::Int, ::Val{:Σ})
function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, op_index::Int, ::Val{:Σ})
syms_array = [symvar_lookup[var] for var in d[d[incident(d, op_index, :summation), :summand], :name]]
output_sym = symvar_lookup[d[d[op_index, :sum], :name]]

Check warning on line 57 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L55-L57

Added lines #L55 - L57 were not covered by tests

# TODO pls test
S = promote_symtype(+, syms_array...)
rhs = SymbolicUtils.Term{S}(+, syms_array)
SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [output_sym, rhs])
SymbolicEquation{Symbolic}(output_sym,rhs)

Check warning on line 62 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L60-L62

Added lines #L60 - L62 were not covered by tests
end

function symbolic_rewriting(old_d::SummationDecapode)
function symbolic_rewriting(old_d::SummationDecapode, rewriter = nothing)
d = deepcopy(old_d)

Check warning on line 66 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L65-L66

Added lines #L65 - L66 were not covered by tests

infer_types!(d)

Check warning on line 68 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L68

Added line #L68 was not covered by tests
# resolve_overloads!(d)

symvar_lookup = symbolics_lookup(d)
merge_equations(d, symvar_lookup, extract_symexprs(d, symvar_lookup))
end
eqns = merge_equations(d, symvar_lookup, extract_symexprs(d, symvar_lookup))

Check warning on line 71 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L70-L71

Added lines #L70 - L71 were not covered by tests

function extract_symexprs(d::SummationDecapode, symvar_lookup::Dict{Symbol, SymbolicUtils.BasicSymbolic})
topo_list = topological_sort_edges(d)
sym_list = []
for node in topo_list
retrieve_name(d, node) != DerivOp || continue
push!(sym_list, to_symbolics(d, symvar_lookup, node))
if !isnothing(rewriter)
eqns = map(rewriter, eqns)

Check warning on line 74 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L73-L74

Added lines #L73 - L74 were not covered by tests
end
sym_list
end

function apply_rewrites(symexprs, rewriter)
to_acset(d, eqns)

Check warning on line 77 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L77

Added line #L77 was not covered by tests
end

rewritten_list = []
for sym in symexprs
res_sym = rewriter(sym)
rewritten_sym = isnothing(res_sym) ? sym : res_sym
push!(rewritten_list, rewritten_sym)
function extract_symexprs(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic})
sym_list = SymbolicEquation{Symbolic}[]
for node in topological_sort_edges(d)
retrieve_name(d, node) != DerivOp || continue # This is not part of ThDEC
push!(sym_list, to_symbolics(d, symvar_lookup, node))
end

rewritten_list
sym_list

Check warning on line 86 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L80-L86

Added lines #L80 - L86 were not covered by tests
end

function merge_equations(d::SummationDecapode, symvar_lookup::Dict{Symbol, SymbolicUtils.BasicSymbolic}, rewritten_syms)
function merge_equations(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, symexpr_list::Vector{SymbolicEquation{Symbolic}})

Check warning on line 89 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L89

Added line #L89 was not covered by tests

eqn_lookup = Dict()

Check warning on line 91 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L91

Added line #L91 was not covered by tests

Expand All @@ -108,22 +99,35 @@ function merge_equations(d::SummationDecapode, symvar_lookup::Dict{Symbol, Symbo

final_nodes = infer_terminal_names(d)

Check warning on line 100 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L100

Added line #L100 was not covered by tests

for expr in rewritten_syms
for expr in symexpr_list

Check warning on line 102 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L102

Added line #L102 was not covered by tests

merged_eqn = SymbolicUtils.substitute(expr, eqn_lookup)
lhs = merged_eqn.arguments[1]
rhs = merged_eqn.arguments[2]
merged_rhs = SymbolicUtils.substitute(expr.rhs, eqn_lookup)

Check warning on line 104 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L104

Added line #L104 was not covered by tests

push!(eqn_lookup, (lhs => rhs))
push!(eqn_lookup, (expr.lhs => merged_rhs))

Check warning on line 106 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L106

Added line #L106 was not covered by tests

if lhs.name in final_nodes
push!(final_list, merged_eqn)
if expr.lhs.name in final_nodes
push!(final_list, formed_deca_eqn(expr.lhs, merged_rhs))

Check warning on line 109 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L108-L109

Added lines #L108 - L109 were not covered by tests
end
end

Check warning on line 111 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L111

Added line #L111 was not covered by tests

final_list

Check warning on line 113 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L113

Added line #L113 was not covered by tests
end

formed_deca_eqn(lhs, rhs) = SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [lhs, rhs])

Check warning on line 116 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L116

Added line #L116 was not covered by tests

function apply_rewrites(symexprs, rewriter)

Check warning on line 118 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L118

Added line #L118 was not covered by tests

rewritten_list = []
for sym in symexprs
res_sym = rewriter(sym)
rewritten_sym = isnothing(res_sym) ? sym : res_sym
push!(rewritten_list, rewritten_sym)
end

Check warning on line 125 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L120-L125

Added lines #L120 - L125 were not covered by tests

rewritten_list

Check warning on line 127 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L127

Added line #L127 was not covered by tests
end


function to_acset(og_d, sym_exprs)
final_exprs = SymbolicUtils.Code.toexpr.(sym_exprs)

Check warning on line 132 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L131-L132

Added lines #L131 - L132 were not covered by tests

Expand Down Expand Up @@ -159,7 +163,6 @@ function to_acset(og_d, sym_exprs)
d = SummationDecapode(parse_decapode(deca_block))

Check warning on line 163 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L163

Added line #L163 was not covered by tests

infer_types!(d)

Check warning on line 165 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L165

Added line #L165 was not covered by tests
# resolve_overloads!(d)

d

Check warning on line 167 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L167

Added line #L167 was not covered by tests
end
8 changes: 4 additions & 4 deletions src/graph_traversal.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using DiagrammaticEquations
using ACSets

export TraversalNode, topological_sort_edges, number_of_ops, retrieve_name, start_nodes
export TraversalNode, topological_sort_edges, n_ops, retrieve_name, start_nodes

struct TraversalNode{T}
index::Int
Expand All @@ -20,7 +20,7 @@ function topological_sort_edges(d::SummationDecapode)
# FIXME: this is a quadratic implementation of topological_sort inlined in here.
op_order = TraversalNode{Symbol}[]

for _ in 1:number_of_ops(d)
for _ in 1:n_ops(d)
for op in parts(d, :Op1)
if !visited_1[op] && visited_Var[d[op, :src]]

Expand Down Expand Up @@ -49,12 +49,12 @@ function topological_sort_edges(d::SummationDecapode)
end
end

@assert length(op_order) == number_of_ops(d)
@assert length(op_order) == n_ops(d)

op_order

Check warning on line 54 in src/graph_traversal.jl

View check run for this annotation

Codecov / codecov/patch

src/graph_traversal.jl#L54

Added line #L54 was not covered by tests
end

function number_of_ops(d::SummationDecapode)
function n_ops(d::SummationDecapode)
return nparts(d, :Op1) + nparts(d, :Op2) + nparts(d, )
end

Expand Down
29 changes: 8 additions & 21 deletions src/sym_rewrite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,15 @@ Heat = @decapode begin
C::Form0
G::Form1
D::Constant
∂ₜ(G) == D*Δ(d(C))
∂ₜ(G) == D*Δ(C)
end;
infer_types!(Heat)
test_heat_same = symbolic_rewriting(Heat)

r = rules(Δ, Val(1))

rwr = Fixpoint(Prewalk(Chain(r)))
test_heat_open = symbolic_rewriting(Heat, rwr)

Brusselator = @decapode begin
(U, V)::Form0
Expand All @@ -35,24 +41,7 @@ Phytodynamics = @decapode begin
∂ₜ(n) == w + m*n + Δ(n)
end
infer_types!(Phytodynamics)
test = to_acset(Phytodynamics, symbolic_rewriting(Phytodynamics))

# resolve_overloads!(Heat)

# lap_0_convert = @rule Δ₀(~x) => Δ(~x)
# lap_1_convert = @rule Δ₁(~x) => Δ(~x)
# lap_2_convert = @rule Δ₂(~x) => Δ(~x)

# d_0_convert = @rule d₀(~x) => d(~x)

# overloaders = [lap_0_convert, lap_1_convert, lap_2_convert, d_0_convert]

# lap_0_rule = @rule Δ(~x) => ⋆(d(⋆(d(~x))))
# lap_1_rule = @rule Δ(~x) => d(⋆(d(⋆(~x)))) + ⋆(d(⋆(d(~x))))
# lap_2_rule = @rule Δ(~x) => d(⋆(d(⋆(~x))))

# openers = [lap_0_rule, lap_1_rule, lap_2_rule]

test_phy = symbolic_rewriting(Phytodynamics)

# it seems that type-instability or improper type promotion is happening. expressions derived from this have BasicSymbolic{Number} type, which means we can't conditionally rewrite on forms.
heat_exprs = symbolic_rewriting(Heat)
Expand Down Expand Up @@ -103,8 +92,6 @@ R1 = @rule Δ(~~x::(x->isForm1(x))) => ★(d(★(d(~x))))
# pulling out the subexpression
rewriter = SymbolicUtils.Fixpoint(SymbolicUtils.Prewalk(SymbolicUtils.Chain(r)))



res_exprs = apply_rewrites(heat_exprs, rewriter)
sub_exprs = apply_rewrites([sub], rewriter)

Expand Down
2 changes: 1 addition & 1 deletion test/graph_traversal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using MLStyle
using Test

function is_correct_length(d::SummationDecapode, result)
return length(result) == number_of_ops(d)
return length(result) == n_ops(d)
end

@testset "Topological Sort on Edges" begin
Expand Down

0 comments on commit 20fdef6

Please sign in to comment.