Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expression level rewriting #69

Merged
merged 43 commits into from
Oct 3, 2024
Merged
Show file tree
Hide file tree
Changes from 31 commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
0ab1ef1
Set version to 0.1.7
algebraicjuliabot Aug 20, 2024
e46fdec
Added more exports (#44)
GeorgeR227 Aug 23, 2024
87a9c5c
Add type rules for vectorfields
lukem12345 Aug 22, 2024
d921c70
Add musical overload resolution
lukem12345 Aug 22, 2024
804ca95
Take advantage of :infer in type rules
lukem12345 Aug 22, 2024
3e916c4
Initial attempt at rewriting
GeorgeR227 Aug 23, 2024
5fbe4b4
Added proof of concept
GeorgeR227 Aug 30, 2024
33db813
Added ability to do through-op rewrites
GeorgeR227 Sep 5, 2024
c9c6aef
Added Space import
GeorgeR227 Sep 5, 2024
8045545
Completed full pipeline
GeorgeR227 Sep 13, 2024
8097521
Remove metadata usage
GeorgeR227 Sep 13, 2024
e0ff9a8
Added DECQuantity types
GeorgeR227 Sep 16, 2024
b9b4146
Completed pipeline again
GeorgeR227 Sep 18, 2024
87f65fe
fixed bug where type-checking subtraction uses +(S1,S2), which is obs…
quffaro Sep 18, 2024
90b1adc
George and I debugged rewriting. Incorrect type passed to resulting t…
quffaro Sep 19, 2024
d4427b1
Cleaning up pipeline
GeorgeR227 Sep 20, 2024
6a3877f
Fixed order of inclusions
GeorgeR227 Sep 20, 2024
ea2d8c0
adding support for Parameters and Constants
quffaro Sep 23, 2024
20fdef6
Merge branch 'gr/acset2sym' of github.com:AlgebraicJulia/Diagrammatic…
quffaro Sep 23, 2024
9bb6269
Added tests for acset2symbolic
GeorgeR227 Sep 23, 2024
8661e2f
etc
quffaro Sep 23, 2024
b29e991
Merge branch 'gr/acset2sym' of github.com:AlgebraicJulia/Diagrammatic…
quffaro Sep 23, 2024
2228d7a
Literals testing
quffaro Sep 23, 2024
bc9ab00
parameters test passing after some debugging.
quffaro Sep 26, 2024
2b3198f
supporting Infer, better Base.nameof, better tests
quffaro Sep 27, 2024
31ad602
Clean out-of-order vector constructions
lukem12345 Sep 27, 2024
d408c26
Convert to symbolics inside merge_equations
lukem12345 Sep 27, 2024
3cd624e
Reduce cases of topological sort
lukem12345 Sep 27, 2024
67079cb
Reify via recursive function, not lambda case
lukem12345 Sep 27, 2024
367414d
Merge branch 'gr/acset2sym' of github.com:AlgebraicJulia/Diagrammatic…
quffaro Sep 27, 2024
5b84cc8
Further improvement of acset2symbolics
GeorgeR227 Sep 28, 2024
35f7b8e
Remove extraneous tangents
GeorgeR227 Sep 28, 2024
da0f81a
Remove redundant helper functions
lukem12345 Sep 28, 2024
fb4927c
Pass indexed names and types directly
lukem12345 Sep 28, 2024
2d158e8
Removed extraneous d arg
GeorgeR227 Sep 28, 2024
0b32bab
fixing work on tumor invasion
quffaro Sep 30, 2024
735edfa
Merge branch 'gr/acset2sym' of github.com:AlgebraicJulia/Diagrammatic…
quffaro Sep 30, 2024
fe21de4
macros which create export stmts will fail inside @testset due to Jul…
quffaro Oct 1, 2024
ffa7c8c
removed ghost emoji and added convenience function for rules. aqua's …
quffaro Oct 1, 2024
97e8b2d
Added more tests for acset2symbolics
GeorgeR227 Oct 2, 2024
2549322
Fixed persistence issue
GeorgeR227 Oct 2, 2024
1fae01c
Final touches
GeorgeR227 Oct 2, 2024
6125c1e
Remove unused fuctionality
GeorgeR227 Oct 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ name = "DiagrammaticEquations"
uuid = "6f00c28b-6bed-4403-80fa-30e0dc12f317"
license = "MIT"
authors = ["James Fairbanks", "Andrew Baas", "Evan Patterson", "Luke Morris", "George Rauta"]
version = "0.1.6"
version = "0.1.7"

[deps]
ACSets = "227ef7b5-1206-438b-ac65-934d6da304b8"
Expand Down
8 changes: 7 additions & 1 deletion src/DiagrammaticEquations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
"""
module DiagrammaticEquations

using Catlab

export
DerivOp, append_dot, normalize_unicode, infer_states, infer_types!,
# Deca
Expand All @@ -12,6 +14,7 @@ recursive_delete_parents, spacename, varname, unicode!, vec_to_dec!,
Collage, collate,
## composition
oapply, unique_by, unique_by!, OpenSummationDecapodeOb, OpenSummationDecapode, Open, default_composition_diagram,
apex, @relation, # Re-exported from Catlab
## acset
SchDecapode, SchNamedDecapode, AbstractDecapode, AbstractNamedDecapode, NamedDecapode, SummationDecapode,
contract_operators!, contract_operators, add_constant!, add_parameter, fill_names!, dot_rename!, is_expanded, expand_operators, infer_state_names, infer_terminal_names, recognize_types,
Expand All @@ -25,12 +28,12 @@ unique_lits!,
Plus, AppCirc1, Var, Tan, App1, App2,
## visualization
to_graphviz_property_graph, typename, draw_composition,
to_graphviz, # Re-exported from Catlab
## rewrite
average_rewrite,
## openoperators
transfer_parents!, transfer_children!, replace_op1!, replace_op2!, replace_all_op1s!, replace_all_op2s!

using Catlab
using Catlab.Theories
import Catlab.Theories: otimes, oplus, compose, ⊗, ⊕, ⋅, associate, associate_unit, Ob, Hom, dom, codom
using Catlab.Programs
Expand Down Expand Up @@ -62,11 +65,14 @@ include("pretty.jl")
include("colanguage.jl")
include("openoperators.jl")
include("symbolictheoryutils.jl")
include("graph_traversal.jl")
include("deca/Deca.jl")
include("learn/Learn.jl")
include("SymbolicUtilsInterop.jl")

@reexport using .Deca
@reexport using .SymbolicUtilsInterop

include("acset2symbolic.jl")

end
31 changes: 17 additions & 14 deletions src/SymbolicUtilsInterop.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
module SymbolicUtilsInterop

using ..DiagrammaticEquations: AbstractDecapode, Quantity
using ACSets
using ..DiagrammaticEquations: AbstractDecapode, Quantity, DerivOp
using ..DiagrammaticEquations: recognize_types, fill_names!, make_sum_mult_unique!
import ..DiagrammaticEquations: eval_eq!, SummationDecapode
using ..decapodes
using ..Deca
Expand All @@ -14,6 +16,7 @@ struct SymbolicEquation{E}
lhs::E
rhs::E
end
export SymbolicEquation

Base.show(io::IO, e::SymbolicEquation) = begin
print(io, e.lhs); print(io, " == "); print(io, e.rhs)
Expand Down Expand Up @@ -48,7 +51,7 @@ function decapodes.Term(t::SymbolicUtils.BasicSymbolic)
decapodes.Plus(termargs)
elseif op == *
decapodes.Mult(termargs)
elseif op == ∂ₜ
elseif op ∈ [DerivOp, ∂ₜ]
decapodes.Tan(only(termargs))
elseif length(args) == 1
decapodes.App1(nameof(op, symtype.(args)...), termargs...)
Expand Down Expand Up @@ -82,9 +85,9 @@ Example:
SymbolicUtils.BasicSymbolic(context, Term(a))
```
"""
function SymbolicUtils.BasicSymbolic(context::Dict{Symbol,DataType}, t::decapodes.Term, __module__=@__MODULE__)
function SymbolicUtils.BasicSymbolic(context::Dict{Symbol,DataType}, t::decapodes.Term)
# user must import symbols into scope
! = (f -> getfield(__module__, f))
! = (f -> getfield(@__MODULE__, f))
@match t begin
Var(name) => SymbolicUtils.Sym{context[name]}(name)
Lit(v) => Meta.parse(string(v))
Expand All @@ -95,17 +98,17 @@ function SymbolicUtils.BasicSymbolic(context::Dict{Symbol,DataType}, t::decapode
# see test/language.jl
(f, x) -> (!(f))(x),
fs;
init=BasicSymbolic(context, arg, __module__)
init=BasicSymbolic(context, arg)
)
App1(f, x) => (!(f))(BasicSymbolic(context, x, __module__))
App2(f, x, y) => (!(f))(BasicSymbolic(context, x, __module__), BasicSymbolic(context, y, __module__))
Plus(xs) => +(BasicSymbolic.(Ref(context), xs, Ref(__module__))...)
Mult(xs) => *(BasicSymbolic.(Ref(context), xs, Ref(__module__))...)
Tan(x) => ∂ₜ(BasicSymbolic(context, x, __module__))
App1(f, x) => (!(f))(BasicSymbolic(context, x))
App2(f, x, y) => (!(f))(BasicSymbolic(context, x), BasicSymbolic(context, y))
Plus(xs) => +(BasicSymbolic.(Ref(context), xs)...)
Mult(xs) => *(BasicSymbolic.(Ref(context), xs)...)
Tan(x) => (!(DerivOp))(BasicSymbolic(context, x))
end
end

function SymbolicContext(d::decapodes.DecaExpr, __module__=@__MODULE__)
function SymbolicContext(d::decapodes.DecaExpr)
# associates each var to its sort...
context = map(d.context) do j
j.var => symtype(Deca.DECQuantity, j.dim, j.space)
Expand All @@ -116,13 +119,13 @@ function SymbolicContext(d::decapodes.DecaExpr, __module__=@__MODULE__)
end
context = Dict{Symbol,DataType}(context)
eqs = map(d.equations) do eq
SymbolicEquation{Symbolic}(BasicSymbolic.(Ref(context), [eq.lhs, eq.rhs], Ref(__module__))...)
SymbolicEquation{Symbolic}(BasicSymbolic.(Ref(context), [eq.lhs, eq.rhs])...)
end
SymbolicContext(vars, eqs)
end

function eval_eq!(eq::SymbolicEquation, d::AbstractDecapode, syms::Dict{Symbol, Int}, deletions::Vector{Int})
eval_eq!(Equation(Term(eq.lhs), Term(eq.rhs)), d, syms, deletions)
eval_eq!(Eq(Term(eq.lhs), Term(eq.rhs)), d, syms, deletions)
end

""" function SummationDecapode(e::SymbolicContext) """
Expand All @@ -132,7 +135,7 @@ function SummationDecapode(e::SymbolicContext)

foreach(e.vars) do var
# convert Sort(var)::PrimalForm0 --> :Form0
var_id = add_part!(d, :Var, name=var.name, type=nameof(Sort(var)))
var_id = add_part!(d, :Var, name=var.name, type=nameof(symtype(var)))
symbol_table[var.name] = var_id
end

Expand Down
47 changes: 22 additions & 25 deletions src/acset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,13 +158,16 @@ end
# A collection of DecaType getters
# TODO: This should be replaced by using a type hierarchy
const ALL_TYPES = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2,
:Literal, :Parameter, :Constant, :infer]
:PVF, :DVF,
:Literal, :Parameter, :Constant, :infer]

const FORM_TYPES = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2]
const PRIMALFORM_TYPES = [:Form0, :Form1, :Form2]
const DUALFORM_TYPES = [:DualForm0, :DualForm1, :DualForm2]

const NONFORM_TYPES = [:Constant, :Parameter, :Literal, :infer]
const VECTORFIELD_TYPES = [:PVF, :DVF]

const NON_EC_TYPES = [:Constant, :Parameter, :Literal, :infer]
const USER_TYPES = [:Constant, :Parameter]
const NUMBER_TYPES = [:Literal]
const INFER_TYPES = [:infer]
Expand All @@ -184,6 +187,7 @@ function recognize_types(d::AbstractNamedDecapode)
isempty(unrecognized_types) ||
error("Types $unrecognized_types are not recognized. CHECK: $types")
end
export recognize_types

""" is_expanded(d::AbstractNamedDecapode)

Expand Down Expand Up @@ -427,12 +431,12 @@ function safe_modifytype!(d::SummationDecapode, var_idx::Int, new_type::Symbol)
end

"""
filterfor_forms(types::AbstractVector{Symbol})
filterfor_ec_types(types::AbstractVector{Symbol})

Return any form type symbols.
Return any form or vector-field type symbols.
"""
function filterfor_forms(types::AbstractVector{Symbol})
conditions = x -> !(x in NONFORM_TYPES)
function filterfor_ec_types(types::AbstractVector{Symbol})
conditions = x -> !(x in NON_EC_TYPES)
filter(conditions, types)
end

Expand All @@ -447,29 +451,26 @@ function infer_sum_types!(d::SummationDecapode, Σ_idx::Int)
types = d[idxs, :type]
all(t != :infer for t in types) && return applied # We need not infer

forms = unique(filterfor_forms(types))
ec_types = unique(filterfor_ec_types(types))

form = @match length(forms) begin
ec_type = @match length(ec_types) begin
0 => return applied # We can not infer
1 => only(forms)
_ => error("Type mismatch in summation $Σ_idx, all the following forms appear: $forms")
1 => only(ec_types)
_ => error("Type mismatch in summation $Σ_idx, all the following forms appear: $ec_types")
end

for idx in idxs
applied |= safe_modifytype!(d, idx, form)
applied |= safe_modifytype!(d, idx, ec_type)
end

return applied
end

function apply_inference_rule_op1!(d::SummationDecapode, op1_id, rule)
type_src = d[d[op1_id, :src], :type]
type_tgt = d[d[op1_id, :tgt], :type]
score_src = (rule.src_type == d[d[op1_id, :src], :type])
score_tgt = (rule.tgt_type == d[d[op1_id, :tgt], :type])

score_src = (rule.src_type == type_src)
score_tgt = (rule.tgt_type == type_tgt)
check_op = (d[op1_id, :op1] in rule.op_names)

if(check_op && (score_src + score_tgt == 1))
mod_src = safe_modifytype!(d, d[op1_id, :src], rule.src_type)
mod_tgt = safe_modifytype!(d, d[op1_id, :tgt], rule.tgt_type)
Expand All @@ -480,19 +481,15 @@ function apply_inference_rule_op1!(d::SummationDecapode, op1_id, rule)
end

function apply_inference_rule_op2!(d::SummationDecapode, op2_id, rule)
type_proj1 = d[d[op2_id, :proj1], :type]
type_proj2 = d[d[op2_id, :proj2], :type]
type_res = d[d[op2_id, :res], :type]
score_proj1 = (rule.proj1_type == d[d[op2_id, :proj1], :type])
score_proj2 = (rule.proj2_type == d[d[op2_id, :proj2], :type])
score_res = (rule.res_type == d[d[op2_id, :res], :type])

score_proj1 = (rule.proj1_type == type_proj1)
score_proj2 = (rule.proj2_type == type_proj2)
score_res = (rule.res_type == type_res)
check_op = (d[op2_id, :op2] in rule.op_names)

if(check_op && (score_proj1 + score_proj2 + score_res == 2))
if check_op && (score_proj1 + score_proj2 + score_res == 2)
mod_proj1 = safe_modifytype!(d, d[op2_id, :proj1], rule.proj1_type)
mod_proj2 = safe_modifytype!(d, d[op2_id, :proj2], rule.proj2_type)
mod_res = safe_modifytype!(d, d[op2_id, :res], rule.res_type)
mod_res = safe_modifytype!(d, d[op2_id, :res], rule.res_type)
return mod_proj1 || mod_proj2 || mod_res
end

Expand Down
97 changes: 97 additions & 0 deletions src/acset2symbolic.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
using DiagrammaticEquations
using ACSets
using SymbolicUtils
using SymbolicUtils: BasicSymbolic, Symbolic

# TODO: Expose only the symbolic_rewriting function
export symbolics_lookup, extract_symexprs, apply_rewrites, merge_equations, to_acset, symbolic_rewriting

const DECA_EQUALITY_SYMBOL = (==)

function symbolics_lookup(d::SummationDecapode)
Dict{Symbol, BasicSymbolic}(map(parts(d, :Var)) do i
(d[i, :name], decavar_to_symbolics(d, i))
end)
end

function decavar_to_symbolics(d::SummationDecapode, idx::Int; space = :I)
new_type = SymbolicUtils.symtype(Deca.DECQuantity, d[idx, :type], space)
SymbolicUtils.Sym{new_type}(d[idx, :name])
end

function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, op_idx::Int, op_type::Symbol)
input_syms = getindex.(Ref(symvar_lookup), d[edge_inputs(d,op_idx,Val(op_type)), :name])
output_sym = getindex.(Ref(symvar_lookup), d[edge_output(d,op_idx,Val(op_type)), :name])
op_sym = getfield(@__MODULE__, edge_function(d,op_idx,Val(op_type)))

S = promote_symtype(op_sym, input_syms...)
SymbolicEquation{Symbolic}(output_sym, SymbolicUtils.Term{S}(op_sym, input_syms))
end

function symbolic_rewriting(old_d::SummationDecapode, rewriter = nothing)
d = infer_types!(deepcopy(old_d))
eqns = merge_equations(d)
to_acset(d, apply_rewrites(eqns, rewriter))
end

apply_rewrites(eqns, rewriter) = isnothing(rewriter) ? eqns : map(rewriter, eqns)

function extract_symexprs(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic})
map(topological_sort_edges(d)) do node
to_symbolics(d, symvar_lookup, node.index, node.name)
end
end

# XXX SymbolicUtils.substitute swaps the order of multiplication.
# e.g. @decapode begin
# ∂ₜ(G) == κ*u
# end
# will have the κ*u term rewritten to u*κ
function merge_equations(d::SummationDecapode)
symvar_lookup = symbolics_lookup(d)
symexpr_list = extract_symexprs(d, symvar_lookup)

eqn_lookup = Dict()

terminal_vars = infer_terminal_names(d)
terminal_eqns = SymbolicEquation{Symbolic}[]

foreach(symexpr_list) do x
push!(eqn_lookup, (x.lhs => SymbolicUtils.substitute(x.rhs, eqn_lookup)))
if x.lhs.name in terminal_vars
push!(terminal_eqns, SymbolicEquation{Symbolic}(x.lhs, eqn_lookup[x.lhs]))
end
end

formed_deca_eqn.(terminal_eqns)
end

formed_deca_eqn(symeqn::SymbolicEquation{Symbolic}) = SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [symeqn.lhs, symeqn.rhs])

function to_acset(d::SummationDecapode, sym_exprs)
outer_types = map([infer_states(d)..., infer_terminals(d)...]) do i
:($(d[i, :name])::$(d[i, :type]))
end

tangents = map(incident(d, DerivOp, :op1)) do op1
GeorgeR227 marked this conversation as resolved.
Show resolved Hide resolved
:($(d[d[op1, :tgt], :name]) == $DerivOp($(d[d[op1, :src], :name])))
end

#TODO: This step is breaking up summations
final_exprs = SymbolicUtils.Code.toexpr.(sym_exprs)
reify!(exprs) = foreach(exprs) do x
if typeof(x)==Expr && x.head == :call
x.args[1] = nameof(x.args[1])
reify!(x.args[2:end])
end
end
reify!(final_exprs)

deca_block = quote
$(outer_types...)
$(final_exprs...)
end

∘(infer_types!, SummationDecapode, parse_decapode)(deca_block)
end

Loading
Loading