Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Floyd-Warshall #70

Open
wants to merge 16 commits into
base: gr/acset2sym
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,21 @@ name = "DiagrammaticEquations"
uuid = "6f00c28b-6bed-4403-80fa-30e0dc12f317"
license = "MIT"
authors = ["James Fairbanks", "Andrew Baas", "Evan Patterson", "Luke Morris", "George Rauta"]
version = "0.1.6"
version = "0.1.7"

[deps]
ACSets = "227ef7b5-1206-438b-ac65-934d6da304b8"
Catlab = "134e5e36-593f-5add-ad60-77f754baafbe"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078"
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
Unicode = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"

[compat]
ACSets = "0.2"
Catlab = "0.15, 0.16"
DataStructures = "0.18.13"
MLStyle = "0.4.17"
SymbolicUtils = "3.4"
Unicode = "1.6"
julia = "1.6"
9 changes: 7 additions & 2 deletions src/DiagrammaticEquations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
"""
module DiagrammaticEquations

using Catlab

export
DerivOp, append_dot, normalize_unicode, infer_states, infer_types!,
DerivOp, append_dot, normalize_unicode, infer_states, infer_terminals, infer_types!,
# Deca
op1_res_rules_1D, op2_res_rules_1D, op1_res_rules_2D, op2_res_rules_2D,
op1_inf_rules_1D, op2_inf_rules_1D, op1_inf_rules_2D, op2_inf_rules_2D,
Expand All @@ -12,6 +14,7 @@ recursive_delete_parents, spacename, varname, unicode!, vec_to_dec!,
Collage, collate,
## composition
oapply, unique_by, unique_by!, OpenSummationDecapodeOb, OpenSummationDecapode, Open, default_composition_diagram,
apex, @relation, # Re-exported from Catlab
## acset
SchDecapode, SchNamedDecapode, AbstractDecapode, AbstractNamedDecapode, NamedDecapode, SummationDecapode,
contract_operators!, contract_operators, add_constant!, add_parameter, fill_names!, dot_rename!, is_expanded, expand_operators, infer_state_names, infer_terminal_names, recognize_types,
Expand All @@ -25,12 +28,12 @@ unique_lits!,
Plus, AppCirc1, Var, Tan, App1, App2,
## visualization
to_graphviz_property_graph, typename, draw_composition,
to_graphviz, # Re-exported from Catlab
## rewrite
average_rewrite,
## openoperators
transfer_parents!, transfer_children!, replace_op1!, replace_op2!, replace_all_op1s!, replace_all_op2s!

using Catlab
using Catlab.Theories
import Catlab.Theories: otimes, oplus, compose, ⊗, ⊕, ⋅, associate, associate_unit, Ob, Hom, dom, codom
using Catlab.Programs
Expand Down Expand Up @@ -59,6 +62,8 @@ include("rewrite.jl")
include("pretty.jl")
include("colanguage.jl")
include("openoperators.jl")
include("graph_interface.jl")
include("acset2symbolic.jl")
include("deca/Deca.jl")
include("learn/Learn.jl")

Expand Down
46 changes: 21 additions & 25 deletions src/acset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,13 +158,16 @@ end
# A collection of DecaType getters
# TODO: This should be replaced by using a type hierarchy
const ALL_TYPES = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2,
:Literal, :Parameter, :Constant, :infer]
:PVF, :DVF,
:Literal, :Parameter, :Constant, :infer]

const FORM_TYPES = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2]
const PRIMALFORM_TYPES = [:Form0, :Form1, :Form2]
const DUALFORM_TYPES = [:DualForm0, :DualForm1, :DualForm2]

const NONFORM_TYPES = [:Constant, :Parameter, :Literal, :infer]
const VECTORFIELD_TYPES = [:PVF, :DVF]

const NON_EC_TYPES = [:Constant, :Parameter, :Literal, :infer]
const USER_TYPES = [:Constant, :Parameter]
const NUMBER_TYPES = [:Literal]
const INFER_TYPES = [:infer]
Expand Down Expand Up @@ -427,12 +430,12 @@ function safe_modifytype!(d::SummationDecapode, var_idx::Int, new_type::Symbol)
end

"""
filterfor_forms(types::AbstractVector{Symbol})
filterfor_ec_types(types::AbstractVector{Symbol})

Return any form type symbols.
Return any form or vector-field type symbols.
"""
function filterfor_forms(types::AbstractVector{Symbol})
conditions = x -> !(x in NONFORM_TYPES)
function filterfor_ec_types(types::AbstractVector{Symbol})
conditions = x -> !(x in NON_EC_TYPES)
filter(conditions, types)
end

Expand All @@ -447,29 +450,26 @@ function infer_sum_types!(d::SummationDecapode, Σ_idx::Int)
types = d[idxs, :type]
all(t != :infer for t in types) && return applied # We need not infer

forms = unique(filterfor_forms(types))
ec_types = unique(filterfor_ec_types(types))

form = @match length(forms) begin
ec_type = @match length(ec_types) begin
0 => return applied # We can not infer
1 => only(forms)
_ => error("Type mismatch in summation $Σ_idx, all the following forms appear: $forms")
1 => only(ec_types)
_ => error("Type mismatch in summation $Σ_idx, all the following forms appear: $ec_types")
end

for idx in idxs
applied |= safe_modifytype!(d, idx, form)
applied |= safe_modifytype!(d, idx, ec_type)
end

return applied
end

function apply_inference_rule_op1!(d::SummationDecapode, op1_id, rule)
type_src = d[d[op1_id, :src], :type]
type_tgt = d[d[op1_id, :tgt], :type]
score_src = (rule.src_type == d[d[op1_id, :src], :type])
score_tgt = (rule.tgt_type == d[d[op1_id, :tgt], :type])

score_src = (rule.src_type == type_src)
score_tgt = (rule.tgt_type == type_tgt)
check_op = (d[op1_id, :op1] in rule.op_names)

if(check_op && (score_src + score_tgt == 1))
mod_src = safe_modifytype!(d, d[op1_id, :src], rule.src_type)
mod_tgt = safe_modifytype!(d, d[op1_id, :tgt], rule.tgt_type)
Expand All @@ -480,19 +480,15 @@ function apply_inference_rule_op1!(d::SummationDecapode, op1_id, rule)
end

function apply_inference_rule_op2!(d::SummationDecapode, op2_id, rule)
type_proj1 = d[d[op2_id, :proj1], :type]
type_proj2 = d[d[op2_id, :proj2], :type]
type_res = d[d[op2_id, :res], :type]
score_proj1 = (rule.proj1_type == d[d[op2_id, :proj1], :type])
score_proj2 = (rule.proj2_type == d[d[op2_id, :proj2], :type])
score_res = (rule.res_type == d[d[op2_id, :res], :type])

score_proj1 = (rule.proj1_type == type_proj1)
score_proj2 = (rule.proj2_type == type_proj2)
score_res = (rule.res_type == type_res)
check_op = (d[op2_id, :op2] in rule.op_names)

if(check_op && (score_proj1 + score_proj2 + score_res == 2))
if check_op && (score_proj1 + score_proj2 + score_res == 2)
mod_proj1 = safe_modifytype!(d, d[op2_id, :proj1], rule.proj1_type)
mod_proj2 = safe_modifytype!(d, d[op2_id, :proj2], rule.proj2_type)
mod_res = safe_modifytype!(d, d[op2_id, :res], rule.res_type)
mod_res = safe_modifytype!(d, d[op2_id, :res], rule.res_type)
return mod_proj1 || mod_proj2 || mod_res
end

Expand Down
114 changes: 114 additions & 0 deletions src/acset2symbolic.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
using DiagrammaticEquations
using SymbolicUtils
using SymbolicUtils.Rewriters
using SymbolicUtils.Code
using MLStyle

import DiagrammaticEquations: HyperGraph, HyperGraphEdge, HyperGraphVertex, vertex_list, edge_list
import DiagrammaticEquations: topological_sort_edges

export TableData, extract_symexprs, number_of_ops, retrieve_name

const DECA_EQUALITY_SYMBOL = (==)

# Decapode graph conversion
struct TableData
table_index::Int
table_name::Symbol
end

HyperGraphVertex(d::SummationDecapode, index::Int) = HyperGraphVertex(index, TableData(index, :Var))

HyperGraphEdge(d::SummationDecapode, index::Int, ::Val{:Op1}) = HyperGraphEdge(d[index, :tgt], [d[index, :src]], TableData(index, :Op1))
HyperGraphEdge(d::SummationDecapode, index::Int, ::Val{:Op2}) = HyperGraphEdge(d[index, :res], [d[index,:proj1],d[index,:proj2]], TableData(index, :Op2))
HyperGraphEdge(d::SummationDecapode, index::Int, ::Val{:Σ}) = HyperGraphEdge(d[index, :sum], d[incident(d, index, :summation), :summand], TableData(index, :Σ))

HyperGraph(d::SummationDecapode) = HyperGraph(vertex_list(d), edge_list(d), nothing)

vertex_list(d::SummationDecapode) = map(id -> HyperGraphVertex(d, id), parts(d, :Var))

function edge_list(d::SummationDecapode)
edges = HyperGraphEdge[]
for op_table in [:Op1, :Op2, :Σ]
for op in parts(d, op_table)
if op_table == :Op1 && d[op, :op1] == DerivOp
continue

Check warning on line 35 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L35

Added line #L35 was not covered by tests
end
push!(edges, HyperGraphEdge(d, op, Val(op_table)))
end
end
edges

Check warning on line 40 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L40

Added line #L40 was not covered by tests
end

table_data(v::HyperGraphVertex) = v.metadata

Check warning on line 43 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L43

Added line #L43 was not covered by tests
table_data(v::HyperGraphEdge) = v.metadata

topological_sort_edges(d::SummationDecapode) = table_data.(topological_sort_edges(HyperGraph(d)))

# Decapode ACSet symbolics conversion

to_symbolics(d::SummationDecapode, data::TableData) = to_symbolics(d, data.table_index, Val(data.table_name))

Check warning on line 50 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L50

Added line #L50 was not covered by tests

function to_symbolics(d::SummationDecapode, op_index::Int, ::Val{:Op1})
input_sym = SymbolicUtils.Sym{Number}(d[d[op_index, :src], :name])
output_sym = SymbolicUtils.Sym{Number}(d[d[op_index, :tgt], :name])
op_sym = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{Number}, Number}}(d[op_index, :op1])

rhs = SymbolicUtils.Term{Number}(op_sym, [input_sym])
SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [output_sym, rhs])
end

function to_symbolics(d::SummationDecapode, op_index::Int, ::Val{:Op2})
input1_sym = SymbolicUtils.Sym{Number}(d[d[op_index, :proj1], :name])
input2_sym = SymbolicUtils.Sym{Number}(d[d[op_index, :proj2], :name])
output_sym = SymbolicUtils.Sym{Number}(d[d[op_index, :res], :name])
op_sym = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{Number, Number}, Number}}(d[op_index, :op2])

rhs = SymbolicUtils.Term{Number}(op_sym, [input1_sym, input2_sym])
SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [output_sym, rhs])
end

#XXX: Always converting + -> .+ here since summation doesn't store the style of addition
# function to_symbolics(d::SummationDecapode, op_index::Int, ::Val{:Σ})
# Expr(EQUALITY_SYMBOL, c.output, Expr(:call, Expr(:., :+), c.inputs...))
# end

function extract_symexprs(d::SummationDecapode)
topo_list = topological_sort_edges(d)
sym_list = []
for node in topo_list
retrieve_name(d, node) != DerivOp || continue
push!(sym_list, to_symbolics(d, node))
end
sym_list
end

function apply_rewrites(d::SummationDecapode, rewriter)

rewritten_list = []
for sym in extract_symexprs(d)
res_sym = rewriter(sym)
rewritten_sym = isnothing(res_sym) ? sym : res_sym
push!(rewritten_list, rewritten_sym)
end

rewritten_list
end

function number_of_ops(d::SummationDecapode)
return nparts(d, :Op1) + nparts(d, :Op2) + nparts(d, :Σ)

Check warning on line 99 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L98-L99

Added lines #L98 - L99 were not covered by tests
end

function retrieve_name(d::SummationDecapode, data::TableData)
@match data.table_name begin
:Op1 => d[data.table_index, :op1]
:Op2 => d[data.table_index, :op2]
:Σ => :+
_ => error("$(data.table_name) is a table without names")

Check warning on line 107 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L107

Added line #L107 was not covered by tests
end
end

# TODO: We need a way to get information like the d and ⋆ even when not in the ACSet
# @syms Δ(x) d(x) ⋆(x)
# lap_0_rule = @rule Δ(~x) => ⋆(d(⋆(d(~x))))
# rewriter = Postwalk(RestartedChain([lap_0_rule]))
37 changes: 28 additions & 9 deletions src/deca/deca_acset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,17 @@ op1_inf_rules_1D = [

# Rules for the averaging operator
(src_type = :Form0, tgt_type = :Form1, op_names = [:avg₀₁, :avg_01]),

# Rules for ♯.
(src_type = :Form1, tgt_type = :PVF, op_names = [:♯, :♯ᵖᵖ]),
(src_type = :DualForm1, tgt_type = :DVF, op_names = [:♯, :♯ᵈᵈ]),

# Rules for ♭.
(src_type = :DVF, tgt_type = :Form1, op_names = [:♭, :♭ᵈᵖ]),

# Rules for magnitude/ norm
(src_type = :Form0, tgt_type = :Form0, op_names = [:mag, :norm]),
(src_type = :Form1, tgt_type = :Form1, op_names = [:mag, :norm])]
(src_type = :PVF, tgt_type = :Form0, op_names = [:mag, :norm]),
(src_type = :DVF, tgt_type = :DualForm0, op_names = [:mag, :norm])]

op2_inf_rules_1D = [
# Rules for ∧₀₀, ∧₁₀, ∧₀₁
Expand Down Expand Up @@ -83,7 +90,11 @@ op2_inf_rules_1D = [
(proj1_type = :Constant, proj2_type = :DualForm0, res_type = :DualForm0, op_names = [:/, :./, :*, :.*, :^, :.^]),
(proj1_type = :Constant, proj2_type = :DualForm1, res_type = :DualForm1, op_names = [:/, :./, :*, :.*, :^, :.^]),
(proj1_type = :DualForm0, proj2_type = :Constant, res_type = :DualForm0, op_names = [:/, :./, :*, :.*, :^, :.^]),
(proj1_type = :DualForm1, proj2_type = :Constant, res_type = :DualForm1, op_names = [:/, :./, :*, :.*, :^, :.^])]
(proj1_type = :DualForm1, proj2_type = :Constant, res_type = :DualForm1, op_names = [:/, :./, :*, :.*, :^, :.^]),

# These rules contain infer:
(proj1_type = :Form0, proj2_type = :infer, res_type = :Form0, op_names = [:^]),
(proj1_type = :DualForm0, proj2_type = :infer, res_type = :DualForm0, op_names = [:^])]

"""
These are the default rules used to do type inference in the 2D exterior calculus.
Expand Down Expand Up @@ -133,13 +144,16 @@ op1_inf_rules_2D = [
(src_type = :DualForm1, tgt_type = :DualForm1, op_names = [:neg, :(-)]),
(src_type = :DualForm2, tgt_type = :DualForm2, op_names = [:neg, :(-)]),

# Rules for ♯.
(src_type = :Form1, tgt_type = :PVF, op_names = [:♯, :♯ᵖᵖ]),
(src_type = :DualForm1, tgt_type = :DVF, op_names = [:♯, :♯ᵈᵈ]),

# Rules for ♭.
(src_type = :DVF, tgt_type = :Form1, op_names = [:♭, :♭ᵈᵖ]),

# Rules for magnitude/ norm
(src_type = :Form0, tgt_type = :Form0, op_names = [:norm, :mag]),
(src_type = :Form1, tgt_type = :Form1, op_names = [:norm, :mag]),
(src_type = :Form2, tgt_type = :Form2, op_names = [:norm, :mag]),
(src_type = :DualForm0, tgt_type = :DualForm0, op_names = [:norm, :mag]),
(src_type = :DualForm1, tgt_type = :DualForm1, op_names = [:norm, :mag]),
(src_type = :DualForm2, tgt_type = :DualForm2, op_names = [:norm, :mag])]
(src_type = :PVF, tgt_type = :Form0, op_names = [:norm, :mag]),
(src_type = :DVF, tgt_type = :DualForm0, op_names = [:norm, :mag])]

op2_inf_rules_2D = vcat(op2_inf_rules_1D, [
# Rules for ∧₁₁, ∧₂₀, ∧₀₂
Expand Down Expand Up @@ -243,6 +257,11 @@ op2_inf_rules_2D = vcat(op2_inf_rules_1D, [
(src_type = :Form1, tgt_type = :Form0, resolved_name = :δ₁, op = :δ),
(src_type = :Form2, tgt_type = :Form1, resolved_name = :δ₂, op = :codif),
(src_type = :Form1, tgt_type = :Form0, resolved_name = :δ₁, op = :codif),
# Rules for ♯.
(src_type = :Form1, tgt_type = :PVF, resolved_name = :♯ᵖᵖ, op = :♯),
(src_type = :DualForm1, tgt_type = :DVF, resolved_name = :♯ᵈᵈ, op = :♯),
# Rules for ♭.
(src_type = :DVF, tgt_type = :Form1, resolved_name = :♭ᵈᵖ, op = :♭),
# Rules for ∇².
# TODO: Call this :nabla2 in ASCII?
(src_type = :Form0, tgt_type = :Form0, resolved_name = :∇²₀, op = :∇²),
Expand Down
Loading
Loading