Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expression level rewriting #69

Merged
merged 43 commits into from
Oct 3, 2024
Merged
Show file tree
Hide file tree
Changes from 27 commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
0ab1ef1
Set version to 0.1.7
algebraicjuliabot Aug 20, 2024
e46fdec
Added more exports (#44)
GeorgeR227 Aug 23, 2024
87a9c5c
Add type rules for vectorfields
lukem12345 Aug 22, 2024
d921c70
Add musical overload resolution
lukem12345 Aug 22, 2024
804ca95
Take advantage of :infer in type rules
lukem12345 Aug 22, 2024
3e916c4
Initial attempt at rewriting
GeorgeR227 Aug 23, 2024
5fbe4b4
Added proof of concept
GeorgeR227 Aug 30, 2024
33db813
Added ability to do through-op rewrites
GeorgeR227 Sep 5, 2024
c9c6aef
Added Space import
GeorgeR227 Sep 5, 2024
8045545
Completed full pipeline
GeorgeR227 Sep 13, 2024
8097521
Remove metadata usage
GeorgeR227 Sep 13, 2024
e0ff9a8
Added DECQuantity types
GeorgeR227 Sep 16, 2024
b9b4146
Completed pipeline again
GeorgeR227 Sep 18, 2024
87f65fe
fixed bug where type-checking subtraction uses +(S1,S2), which is obs…
quffaro Sep 18, 2024
90b1adc
George and I debugged rewriting. Incorrect type passed to resulting t…
quffaro Sep 19, 2024
d4427b1
Cleaning up pipeline
GeorgeR227 Sep 20, 2024
6a3877f
Fixed order of inclusions
GeorgeR227 Sep 20, 2024
ea2d8c0
adding support for Parameters and Constants
quffaro Sep 23, 2024
20fdef6
Merge branch 'gr/acset2sym' of github.com:AlgebraicJulia/Diagrammatic…
quffaro Sep 23, 2024
9bb6269
Added tests for acset2symbolic
GeorgeR227 Sep 23, 2024
8661e2f
etc
quffaro Sep 23, 2024
b29e991
Merge branch 'gr/acset2sym' of github.com:AlgebraicJulia/Diagrammatic…
quffaro Sep 23, 2024
2228d7a
Literals testing
quffaro Sep 23, 2024
bc9ab00
parameters test passing after some debugging.
quffaro Sep 26, 2024
2b3198f
supporting Infer, better Base.nameof, better tests
quffaro Sep 27, 2024
31ad602
Clean out-of-order vector constructions
lukem12345 Sep 27, 2024
d408c26
Convert to symbolics inside merge_equations
lukem12345 Sep 27, 2024
3cd624e
Reduce cases of topological sort
lukem12345 Sep 27, 2024
67079cb
Reify via recursive function, not lambda case
lukem12345 Sep 27, 2024
367414d
Merge branch 'gr/acset2sym' of github.com:AlgebraicJulia/Diagrammatic…
quffaro Sep 27, 2024
5b84cc8
Further improvement of acset2symbolics
GeorgeR227 Sep 28, 2024
35f7b8e
Remove extraneous tangents
GeorgeR227 Sep 28, 2024
da0f81a
Remove redundant helper functions
lukem12345 Sep 28, 2024
fb4927c
Pass indexed names and types directly
lukem12345 Sep 28, 2024
2d158e8
Removed extraneous d arg
GeorgeR227 Sep 28, 2024
0b32bab
fixing work on tumor invasion
quffaro Sep 30, 2024
735edfa
Merge branch 'gr/acset2sym' of github.com:AlgebraicJulia/Diagrammatic…
quffaro Sep 30, 2024
fe21de4
macros which create export stmts will fail inside @testset due to Jul…
quffaro Oct 1, 2024
ffa7c8c
removed ghost emoji and added convenience function for rules. aqua's …
quffaro Oct 1, 2024
97e8b2d
Added more tests for acset2symbolics
GeorgeR227 Oct 2, 2024
2549322
Fixed persistence issue
GeorgeR227 Oct 2, 2024
1fae01c
Final touches
GeorgeR227 Oct 2, 2024
6125c1e
Remove unused fuctionality
GeorgeR227 Oct 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ name = "DiagrammaticEquations"
uuid = "6f00c28b-6bed-4403-80fa-30e0dc12f317"
license = "MIT"
authors = ["James Fairbanks", "Andrew Baas", "Evan Patterson", "Luke Morris", "George Rauta"]
version = "0.1.6"
version = "0.1.7"

[deps]
ACSets = "227ef7b5-1206-438b-ac65-934d6da304b8"
Expand Down
8 changes: 7 additions & 1 deletion src/DiagrammaticEquations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
"""
module DiagrammaticEquations

using Catlab

export
DerivOp, append_dot, normalize_unicode, infer_states, infer_types!,
# Deca
Expand All @@ -12,6 +14,7 @@ recursive_delete_parents, spacename, varname, unicode!, vec_to_dec!,
Collage, collate,
## composition
oapply, unique_by, unique_by!, OpenSummationDecapodeOb, OpenSummationDecapode, Open, default_composition_diagram,
apex, @relation, # Re-exported from Catlab
## acset
SchDecapode, SchNamedDecapode, AbstractDecapode, AbstractNamedDecapode, NamedDecapode, SummationDecapode,
contract_operators!, contract_operators, add_constant!, add_parameter, fill_names!, dot_rename!, is_expanded, expand_operators, infer_state_names, infer_terminal_names, recognize_types,
Expand All @@ -25,12 +28,12 @@ unique_lits!,
Plus, AppCirc1, Var, Tan, App1, App2,
## visualization
to_graphviz_property_graph, typename, draw_composition,
to_graphviz, # Re-exported from Catlab
## rewrite
average_rewrite,
## openoperators
transfer_parents!, transfer_children!, replace_op1!, replace_op2!, replace_all_op1s!, replace_all_op2s!

using Catlab
using Catlab.Theories
import Catlab.Theories: otimes, oplus, compose, ⊗, ⊕, ⋅, associate, associate_unit, Ob, Hom, dom, codom
using Catlab.Programs
Expand Down Expand Up @@ -62,11 +65,14 @@ include("pretty.jl")
include("colanguage.jl")
include("openoperators.jl")
include("symbolictheoryutils.jl")
include("graph_traversal.jl")
include("deca/Deca.jl")
include("learn/Learn.jl")
include("SymbolicUtilsInterop.jl")

@reexport using .Deca
@reexport using .SymbolicUtilsInterop

include("acset2symbolic.jl")

end
7 changes: 5 additions & 2 deletions src/SymbolicUtilsInterop.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
module SymbolicUtilsInterop

using ACSets
using ..DiagrammaticEquations: AbstractDecapode, Quantity
using ..DiagrammaticEquations: recognize_types, fill_names!, make_sum_mult_unique!
import ..DiagrammaticEquations: eval_eq!, SummationDecapode
using ..decapodes
using ..Deca
Expand All @@ -14,6 +16,7 @@ struct SymbolicEquation{E}
lhs::E
rhs::E
end
export SymbolicEquation

Base.show(io::IO, e::SymbolicEquation) = begin
print(io, e.lhs); print(io, " == "); print(io, e.rhs)
Expand Down Expand Up @@ -122,7 +125,7 @@ function SymbolicContext(d::decapodes.DecaExpr, __module__=@__MODULE__)
end

function eval_eq!(eq::SymbolicEquation, d::AbstractDecapode, syms::Dict{Symbol, Int}, deletions::Vector{Int})
eval_eq!(Equation(Term(eq.lhs), Term(eq.rhs)), d, syms, deletions)
eval_eq!(Eq(Term(eq.lhs), Term(eq.rhs)), d, syms, deletions)
end

""" function SummationDecapode(e::SymbolicContext) """
Expand All @@ -132,7 +135,7 @@ function SummationDecapode(e::SymbolicContext)

foreach(e.vars) do var
# convert Sort(var)::PrimalForm0 --> :Form0
var_id = add_part!(d, :Var, name=var.name, type=nameof(Sort(var)))
var_id = add_part!(d, :Var, name=var.name, type=nameof(symtype(var)))
symbol_table[var.name] = var_id
end

Expand Down
47 changes: 22 additions & 25 deletions src/acset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,13 +158,16 @@ end
# A collection of DecaType getters
# TODO: This should be replaced by using a type hierarchy
const ALL_TYPES = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2,
:Literal, :Parameter, :Constant, :infer]
:PVF, :DVF,
:Literal, :Parameter, :Constant, :infer]

const FORM_TYPES = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2]
const PRIMALFORM_TYPES = [:Form0, :Form1, :Form2]
const DUALFORM_TYPES = [:DualForm0, :DualForm1, :DualForm2]

const NONFORM_TYPES = [:Constant, :Parameter, :Literal, :infer]
const VECTORFIELD_TYPES = [:PVF, :DVF]

const NON_EC_TYPES = [:Constant, :Parameter, :Literal, :infer]
const USER_TYPES = [:Constant, :Parameter]
const NUMBER_TYPES = [:Literal]
const INFER_TYPES = [:infer]
Expand All @@ -184,6 +187,7 @@ function recognize_types(d::AbstractNamedDecapode)
isempty(unrecognized_types) ||
error("Types $unrecognized_types are not recognized. CHECK: $types")
end
export recognize_types

""" is_expanded(d::AbstractNamedDecapode)

Expand Down Expand Up @@ -427,12 +431,12 @@ function safe_modifytype!(d::SummationDecapode, var_idx::Int, new_type::Symbol)
end

"""
filterfor_forms(types::AbstractVector{Symbol})
filterfor_ec_types(types::AbstractVector{Symbol})

Return any form type symbols.
Return any form or vector-field type symbols.
"""
function filterfor_forms(types::AbstractVector{Symbol})
conditions = x -> !(x in NONFORM_TYPES)
function filterfor_ec_types(types::AbstractVector{Symbol})
conditions = x -> !(x in NON_EC_TYPES)
filter(conditions, types)
end

Expand All @@ -447,29 +451,26 @@ function infer_sum_types!(d::SummationDecapode, Σ_idx::Int)
types = d[idxs, :type]
all(t != :infer for t in types) && return applied # We need not infer

forms = unique(filterfor_forms(types))
ec_types = unique(filterfor_ec_types(types))

form = @match length(forms) begin
ec_type = @match length(ec_types) begin
0 => return applied # We can not infer
1 => only(forms)
_ => error("Type mismatch in summation $Σ_idx, all the following forms appear: $forms")
1 => only(ec_types)
_ => error("Type mismatch in summation $Σ_idx, all the following forms appear: $ec_types")
end

for idx in idxs
applied |= safe_modifytype!(d, idx, form)
applied |= safe_modifytype!(d, idx, ec_type)
end

return applied
end

function apply_inference_rule_op1!(d::SummationDecapode, op1_id, rule)
type_src = d[d[op1_id, :src], :type]
type_tgt = d[d[op1_id, :tgt], :type]
score_src = (rule.src_type == d[d[op1_id, :src], :type])
score_tgt = (rule.tgt_type == d[d[op1_id, :tgt], :type])

score_src = (rule.src_type == type_src)
score_tgt = (rule.tgt_type == type_tgt)
check_op = (d[op1_id, :op1] in rule.op_names)

if(check_op && (score_src + score_tgt == 1))
mod_src = safe_modifytype!(d, d[op1_id, :src], rule.src_type)
mod_tgt = safe_modifytype!(d, d[op1_id, :tgt], rule.tgt_type)
Expand All @@ -480,19 +481,15 @@ function apply_inference_rule_op1!(d::SummationDecapode, op1_id, rule)
end

function apply_inference_rule_op2!(d::SummationDecapode, op2_id, rule)
type_proj1 = d[d[op2_id, :proj1], :type]
type_proj2 = d[d[op2_id, :proj2], :type]
type_res = d[d[op2_id, :res], :type]
score_proj1 = (rule.proj1_type == d[d[op2_id, :proj1], :type])
score_proj2 = (rule.proj2_type == d[d[op2_id, :proj2], :type])
score_res = (rule.res_type == d[d[op2_id, :res], :type])

score_proj1 = (rule.proj1_type == type_proj1)
score_proj2 = (rule.proj2_type == type_proj2)
score_res = (rule.res_type == type_res)
check_op = (d[op2_id, :op2] in rule.op_names)

if(check_op && (score_proj1 + score_proj2 + score_res == 2))
if check_op && (score_proj1 + score_proj2 + score_res == 2)
mod_proj1 = safe_modifytype!(d, d[op2_id, :proj1], rule.proj1_type)
mod_proj2 = safe_modifytype!(d, d[op2_id, :proj2], rule.proj2_type)
mod_res = safe_modifytype!(d, d[op2_id, :res], rule.res_type)
mod_res = safe_modifytype!(d, d[op2_id, :res], rule.res_type)
return mod_proj1 || mod_proj2 || mod_res
end

Expand Down
107 changes: 107 additions & 0 deletions src/acset2symbolic.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
using DiagrammaticEquations
using ACSets
using MLStyle
using SymbolicUtils
using SymbolicUtils.Rewriters
using SymbolicUtils: BasicSymbolic, Symbolic

export symbolics_lookup, extract_symexprs, apply_rewrites, merge_equations, to_acset, symbolic_rewriting, symbolics_lookup

const DECA_EQUALITY_SYMBOL = (==)

function symbolics_lookup(d::SummationDecapode)
Dict{Symbol, BasicSymbolic}(map(parts(d, :Var)) do i
(d[i, :name], decavar_to_symbolics(d, i))
end)
end

function decavar_to_symbolics(d::SummationDecapode, idx::Int; space = :I)
new_type = SymbolicUtils.symtype(Deca.DECQuantity, d[idx, :type], space)
SymbolicUtils.Sym{new_type}(d[idx, :name])
end

function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, op_idx::Int, op_type::Symbol)
input_syms = getindex.(Ref(symvar_lookup), d[edge_inputs(d,op_idx,Val(op_type)), :name])
output_sym = getindex.(Ref(symvar_lookup), d[edge_output(d,op_idx,Val(op_type)), :name])

op_sym = getfield(@__MODULE__, edge_function(d,op_idx,Val(op_type)))

S = promote_symtype(op_sym, input_syms...)
rhs = SymbolicUtils.Term{S}(op_sym, input_syms)
SymbolicEquation{Symbolic}(output_sym, rhs)
end

function symbolic_rewriting(old_d::SummationDecapode, rewriter = nothing)
d = infer_types!(deepcopy(old_d))
eqns = merge_equations(d)
to_acset(d, isnothing(rewriter) ? eqns : map(rewriter, eqns))
end

function extract_symexprs(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic})
non_tangents = filter(x -> retrieve_name(d, x) != DerivOp, topological_sort_edges(d))
map(non_tangents) do node
to_symbolics(d, symvar_lookup, node.index, node.name)
end
end

# XXX SymbolicUtils.substitute swaps the order of multiplication.
# example: @decapode begin
# u::Form0
# G::Form0
# κ::Constant
# ∂ₜ(G) == κ*★(d(★(d(u))))
# end
# will have the kappa*var term rewritten to var*kappa
function merge_equations(d::SummationDecapode)
symvar_lookup = symbolics_lookup(d)
symexpr_list = extract_symexprs(d, symvar_lookup)

eqn_lookup = Dict{Any,Any}(map(start_nodes(d)) do node
sym = symvar_lookup[d[node, :name]]
(sym, sym)
end)
foreach(symexpr_list) do expr
merged_rhs = SymbolicUtils.substitute(expr.rhs, eqn_lookup)
push!(eqn_lookup, (expr.lhs => merged_rhs))
end

terminals = filter(x -> x.lhs.name in infer_terminal_names(d), symexpr_list)
map(x -> formed_deca_eqn(x.lhs, eqn_lookup[x.lhs]), terminals)
end

formed_deca_eqn(lhs, rhs) = SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [lhs, rhs])

function apply_rewrites(symexprs, rewriter)
map(symexprs) do sym
res_sym = rewriter(sym)
isnothing(res_sym) ? sym : res_sym

Check warning on line 77 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L74-L77

Added lines #L74 - L77 were not covered by tests
end
end

function to_acset(d::SummationDecapode, sym_exprs)
outer_types = map([infer_states(d)..., infer_terminals(d)...]) do idx
:($(d[idx, :name])::$(d[idx, :type]))
end

tangents = map(incident(d, DerivOp, :op1)) do op1
GeorgeR227 marked this conversation as resolved.
Show resolved Hide resolved
:($(d[d[op1, :tgt], :name]) == $DerivOp($(d[d[op1, :src], :name])))
end

#TODO: This step is breaking up summations
final_exprs = SymbolicUtils.Code.toexpr.(sym_exprs)
recursive_descent = @λ begin
e::Expr => begin
if e.head == :call
e.args[1] = nameof(e.args[1])
map(recursive_descent, e.args[2:end])
end
end
sym => nothing
end
foreach(recursive_descent, final_exprs)

deca_block = quote end
deca_block.args = [outer_types..., tangents..., final_exprs...]
infer_types!(SummationDecapode(parse_decapode(deca_block)))
end

29 changes: 21 additions & 8 deletions src/deca/ThDEC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,13 @@
# this ensures symtype doesn't recurse endlessly
SymbolicUtils.symtype(::Type{S}) where S<:DECQuantity = S

struct Scalar <: DECQuantity end
export Scalar
abstract type AbstractScalar <: DECQuantity end

struct Scalar <: AbstractScalar end
struct Parameter <: AbstractScalar end
struct ConstScalar <: AbstractScalar end
struct Literal <: AbstractScalar end
export Scalar, Parameter, ConstScalar, Literal

struct FormParams
dim::Int
Expand Down Expand Up @@ -107,7 +112,7 @@
export PatFormDim

@active PatScalar(T) begin
if T <: Scalar
if T <: AbstractScalar
Some(T)
end
end
Expand Down Expand Up @@ -178,6 +183,8 @@
@rule Δ(~x::isForm1) => ★(d(★(d(~x)))) + d(★(d(★(~x))))
end

@alias (Δ₀, Δ₁, Δ₂) => Δ

@operator +(S1, S2)::DECQuantity begin
@match (S1, S2) begin
(PatScalar(_), PatScalar(_)) => Scalar
Expand All @@ -193,7 +200,9 @@
end
end

@operator -(S1, S2)::DECQuantity begin +(S1, S2) end
@operator -(S1, S2)::DECQuantity begin
promote_symtype(+, S1, S2)
end

@operator *(S1, S2)::DECQuantity begin
@match (S1, S2) begin
Expand All @@ -219,9 +228,10 @@

abstract type SortError <: Exception end

# struct WedgeDimError <: SortError end

Base.nameof(s::Scalar) = :Constant
Base.nameof(s::Literal) = :Literal
Base.nameof(s::ConstScalar) = :ConstScalar
Base.nameof(s::Parameter) = :Parameter
Base.nameof(s::Scalar) = :Scalar

Check warning on line 234 in src/deca/ThDEC.jl

View check run for this annotation

Codecov / codecov/patch

src/deca/ThDEC.jl#L231-L234

Added lines #L231 - L234 were not covered by tests

function Base.nameof(f::Form; with_dim_parameter=false)
dual = isdual(f) ? "Dual" : ""
Expand Down Expand Up @@ -265,7 +275,10 @@

function SymbolicUtils.symtype(::Type{<:Quantity}, qty::Symbol, space::Symbol)
@match qty begin
:Scalar || :Constant => Scalar
:Scalar => Scalar
:Constant => ConstScalar
:Parameter => Parameter
:Literal => Literal
:Form0 => PrimalForm{0, space, 1}
:Form1 => PrimalForm{1, space, 1}
:Form2 => PrimalForm{2, space, 1}
Expand Down
Loading
Loading