Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Floyd-Warshall #70

Open
wants to merge 16 commits into
base: gr/acset2sym
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,21 @@ name = "DiagrammaticEquations"
uuid = "6f00c28b-6bed-4403-80fa-30e0dc12f317"
license = "MIT"
authors = ["James Fairbanks", "Andrew Baas", "Evan Patterson", "Luke Morris", "George Rauta"]
version = "0.1.6"
version = "0.1.7"

[deps]
ACSets = "227ef7b5-1206-438b-ac65-934d6da304b8"
Catlab = "134e5e36-593f-5add-ad60-77f754baafbe"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078"
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
Unicode = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"

[compat]
ACSets = "0.2"
Catlab = "0.15, 0.16"
DataStructures = "0.18.13"
MLStyle = "0.4.17"
SymbolicUtils = "3.4"
Unicode = "1.6"
julia = "1.6"
9 changes: 7 additions & 2 deletions src/DiagrammaticEquations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
"""
module DiagrammaticEquations

using Catlab

export
DerivOp, append_dot, normalize_unicode, infer_states, infer_types!,
DerivOp, append_dot, normalize_unicode, infer_states, infer_terminals, infer_types!,
# Deca
op1_res_rules_1D, op2_res_rules_1D, op1_res_rules_2D, op2_res_rules_2D,
op1_inf_rules_1D, op2_inf_rules_1D, op1_inf_rules_2D, op2_inf_rules_2D,
Expand All @@ -12,6 +14,7 @@ recursive_delete_parents, spacename, varname, unicode!, vec_to_dec!,
Collage, collate,
## composition
oapply, unique_by, unique_by!, OpenSummationDecapodeOb, OpenSummationDecapode, Open, default_composition_diagram,
apex, @relation, # Re-exported from Catlab
## acset
SchDecapode, SchNamedDecapode, AbstractDecapode, AbstractNamedDecapode, NamedDecapode, SummationDecapode,
contract_operators!, contract_operators, add_constant!, add_parameter, fill_names!, dot_rename!, is_expanded, expand_operators, infer_state_names, infer_terminal_names, recognize_types,
Expand All @@ -25,12 +28,12 @@ unique_lits!,
Plus, AppCirc1, Var, Tan, App1, App2,
## visualization
to_graphviz_property_graph, typename, draw_composition,
to_graphviz, # Re-exported from Catlab
## rewrite
average_rewrite,
## openoperators
transfer_parents!, transfer_children!, replace_op1!, replace_op2!, replace_all_op1s!, replace_all_op2s!

using Catlab
using Catlab.Theories
import Catlab.Theories: otimes, oplus, compose, ⊗, ⊕, ⋅, associate, associate_unit, Ob, Hom, dom, codom
using Catlab.Programs
Expand Down Expand Up @@ -59,6 +62,8 @@ include("rewrite.jl")
include("pretty.jl")
include("colanguage.jl")
include("openoperators.jl")
include("graph_traversal.jl")
include("acset2symbolic.jl")
include("deca/Deca.jl")
include("learn/Learn.jl")

Expand Down
46 changes: 21 additions & 25 deletions src/acset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,13 +158,16 @@ end
# A collection of DecaType getters
# TODO: This should be replaced by using a type hierarchy
const ALL_TYPES = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2,
:Literal, :Parameter, :Constant, :infer]
:PVF, :DVF,
:Literal, :Parameter, :Constant, :infer]

const FORM_TYPES = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2]
const PRIMALFORM_TYPES = [:Form0, :Form1, :Form2]
const DUALFORM_TYPES = [:DualForm0, :DualForm1, :DualForm2]

const NONFORM_TYPES = [:Constant, :Parameter, :Literal, :infer]
const VECTORFIELD_TYPES = [:PVF, :DVF]

const NON_EC_TYPES = [:Constant, :Parameter, :Literal, :infer]
const USER_TYPES = [:Constant, :Parameter]
const NUMBER_TYPES = [:Literal]
const INFER_TYPES = [:infer]
Expand Down Expand Up @@ -427,12 +430,12 @@ function safe_modifytype!(d::SummationDecapode, var_idx::Int, new_type::Symbol)
end

"""
filterfor_forms(types::AbstractVector{Symbol})
filterfor_ec_types(types::AbstractVector{Symbol})

Return any form type symbols.
Return any form or vector-field type symbols.
"""
function filterfor_forms(types::AbstractVector{Symbol})
conditions = x -> !(x in NONFORM_TYPES)
function filterfor_ec_types(types::AbstractVector{Symbol})
conditions = x -> !(x in NON_EC_TYPES)
filter(conditions, types)
end

Expand All @@ -447,29 +450,26 @@ function infer_sum_types!(d::SummationDecapode, Σ_idx::Int)
types = d[idxs, :type]
all(t != :infer for t in types) && return applied # We need not infer

forms = unique(filterfor_forms(types))
ec_types = unique(filterfor_ec_types(types))

form = @match length(forms) begin
ec_type = @match length(ec_types) begin
0 => return applied # We can not infer
1 => only(forms)
_ => error("Type mismatch in summation $Σ_idx, all the following forms appear: $forms")
1 => only(ec_types)
_ => error("Type mismatch in summation $Σ_idx, all the following forms appear: $ec_types")
end

for idx in idxs
applied |= safe_modifytype!(d, idx, form)
applied |= safe_modifytype!(d, idx, ec_type)
end

return applied
end

function apply_inference_rule_op1!(d::SummationDecapode, op1_id, rule)
type_src = d[d[op1_id, :src], :type]
type_tgt = d[d[op1_id, :tgt], :type]
score_src = (rule.src_type == d[d[op1_id, :src], :type])
score_tgt = (rule.tgt_type == d[d[op1_id, :tgt], :type])

score_src = (rule.src_type == type_src)
score_tgt = (rule.tgt_type == type_tgt)
check_op = (d[op1_id, :op1] in rule.op_names)

if(check_op && (score_src + score_tgt == 1))
mod_src = safe_modifytype!(d, d[op1_id, :src], rule.src_type)
mod_tgt = safe_modifytype!(d, d[op1_id, :tgt], rule.tgt_type)
Expand All @@ -480,19 +480,15 @@ function apply_inference_rule_op1!(d::SummationDecapode, op1_id, rule)
end

function apply_inference_rule_op2!(d::SummationDecapode, op2_id, rule)
type_proj1 = d[d[op2_id, :proj1], :type]
type_proj2 = d[d[op2_id, :proj2], :type]
type_res = d[d[op2_id, :res], :type]
score_proj1 = (rule.proj1_type == d[d[op2_id, :proj1], :type])
score_proj2 = (rule.proj2_type == d[d[op2_id, :proj2], :type])
score_res = (rule.res_type == d[d[op2_id, :res], :type])

score_proj1 = (rule.proj1_type == type_proj1)
score_proj2 = (rule.proj2_type == type_proj2)
score_res = (rule.res_type == type_res)
check_op = (d[op2_id, :op2] in rule.op_names)

if(check_op && (score_proj1 + score_proj2 + score_res == 2))
if check_op && (score_proj1 + score_proj2 + score_res == 2)
mod_proj1 = safe_modifytype!(d, d[op2_id, :proj1], rule.proj1_type)
mod_proj2 = safe_modifytype!(d, d[op2_id, :proj2], rule.proj2_type)
mod_res = safe_modifytype!(d, d[op2_id, :res], rule.res_type)
mod_res = safe_modifytype!(d, d[op2_id, :res], rule.res_type)
return mod_proj1 || mod_proj2 || mod_res
end

Expand Down
60 changes: 60 additions & 0 deletions src/acset2symbolic.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
using DiagrammaticEquations
using SymbolicUtils
using SymbolicUtils.Rewriters
using SymbolicUtils.Code
using MLStyle

const DECA_EQUALITY_SYMBOL = (==)

to_symbolics(d::SummationDecapode, node::TraversalNode) = to_symbolics(d, node.index, Val(node.name))

function to_symbolics(d::SummationDecapode, op_index::Int, ::Val{:Op1})
input_sym = SymbolicUtils.Sym{Number}(d[d[op_index, :src], :name])
output_sym = SymbolicUtils.Sym{Number}(d[d[op_index, :tgt], :name])
op_sym = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{Number}, Number}}(d[op_index, :op1])

rhs = SymbolicUtils.Term{Number}(op_sym, [input_sym])
SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [output_sym, rhs])
end

function to_symbolics(d::SummationDecapode, op_index::Int, ::Val{:Op2})
input1_sym = SymbolicUtils.Sym{Number}(d[d[op_index, :proj1], :name])
input2_sym = SymbolicUtils.Sym{Number}(d[d[op_index, :proj2], :name])
output_sym = SymbolicUtils.Sym{Number}(d[d[op_index, :res], :name])
op_sym = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{Number, Number}, Number}}(d[op_index, :op2])

rhs = SymbolicUtils.Term{Number}(op_sym, [input1_sym, input2_sym])
SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [output_sym, rhs])
end

#XXX: Always converting + -> .+ here since summation doesn't store the style of addition
# function to_symbolics(d::SummationDecapode, op_index::Int, ::Val{:Σ})
# Expr(EQUALITY_SYMBOL, c.output, Expr(:call, Expr(:., :+), c.inputs...))
# end

function extract_symexprs(d::SummationDecapode)
topo_list = topological_sort_edges(d)
sym_list = []
for node in topo_list
retrieve_name(d, node) != DerivOp || continue
push!(sym_list, to_symbolics(d, node))
end
sym_list
end

function apply_rewrites(d::SummationDecapode, rewriter)

rewritten_list = []
for sym in extract_symexprs(d)
res_sym = rewriter(sym)
rewritten_sym = isnothing(res_sym) ? sym : res_sym
push!(rewritten_list, rewritten_sym)
end

rewritten_list
end

# TODO: We need a way to get information like the d and ⋆ even when not in the ACSet
# @syms Δ(x) d(x) ⋆(x)
# lap_0_rule = @rule Δ(~x) => ⋆(d(⋆(d(~x))))
# rewriter = Postwalk(RestartedChain([lap_0_rule]))
37 changes: 28 additions & 9 deletions src/deca/deca_acset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,17 @@ op1_inf_rules_1D = [

# Rules for the averaging operator
(src_type = :Form0, tgt_type = :Form1, op_names = [:avg₀₁, :avg_01]),

# Rules for ♯.
(src_type = :Form1, tgt_type = :PVF, op_names = [:♯, :♯ᵖᵖ]),
(src_type = :DualForm1, tgt_type = :DVF, op_names = [:♯, :♯ᵈᵈ]),

# Rules for ♭.
(src_type = :DVF, tgt_type = :Form1, op_names = [:♭, :♭ᵈᵖ]),

# Rules for magnitude/ norm
(src_type = :Form0, tgt_type = :Form0, op_names = [:mag, :norm]),
(src_type = :Form1, tgt_type = :Form1, op_names = [:mag, :norm])]
(src_type = :PVF, tgt_type = :Form0, op_names = [:mag, :norm]),
(src_type = :DVF, tgt_type = :DualForm0, op_names = [:mag, :norm])]

op2_inf_rules_1D = [
# Rules for ∧₀₀, ∧₁₀, ∧₀₁
Expand Down Expand Up @@ -83,7 +90,11 @@ op2_inf_rules_1D = [
(proj1_type = :Constant, proj2_type = :DualForm0, res_type = :DualForm0, op_names = [:/, :./, :*, :.*, :^, :.^]),
(proj1_type = :Constant, proj2_type = :DualForm1, res_type = :DualForm1, op_names = [:/, :./, :*, :.*, :^, :.^]),
(proj1_type = :DualForm0, proj2_type = :Constant, res_type = :DualForm0, op_names = [:/, :./, :*, :.*, :^, :.^]),
(proj1_type = :DualForm1, proj2_type = :Constant, res_type = :DualForm1, op_names = [:/, :./, :*, :.*, :^, :.^])]
(proj1_type = :DualForm1, proj2_type = :Constant, res_type = :DualForm1, op_names = [:/, :./, :*, :.*, :^, :.^]),

# These rules contain infer:
(proj1_type = :Form0, proj2_type = :infer, res_type = :Form0, op_names = [:^]),
(proj1_type = :DualForm0, proj2_type = :infer, res_type = :DualForm0, op_names = [:^])]

"""
These are the default rules used to do type inference in the 2D exterior calculus.
Expand Down Expand Up @@ -133,13 +144,16 @@ op1_inf_rules_2D = [
(src_type = :DualForm1, tgt_type = :DualForm1, op_names = [:neg, :(-)]),
(src_type = :DualForm2, tgt_type = :DualForm2, op_names = [:neg, :(-)]),

# Rules for ♯.
(src_type = :Form1, tgt_type = :PVF, op_names = [:♯, :♯ᵖᵖ]),
(src_type = :DualForm1, tgt_type = :DVF, op_names = [:♯, :♯ᵈᵈ]),

# Rules for ♭.
(src_type = :DVF, tgt_type = :Form1, op_names = [:♭, :♭ᵈᵖ]),

# Rules for magnitude/ norm
(src_type = :Form0, tgt_type = :Form0, op_names = [:norm, :mag]),
(src_type = :Form1, tgt_type = :Form1, op_names = [:norm, :mag]),
(src_type = :Form2, tgt_type = :Form2, op_names = [:norm, :mag]),
(src_type = :DualForm0, tgt_type = :DualForm0, op_names = [:norm, :mag]),
(src_type = :DualForm1, tgt_type = :DualForm1, op_names = [:norm, :mag]),
(src_type = :DualForm2, tgt_type = :DualForm2, op_names = [:norm, :mag])]
(src_type = :PVF, tgt_type = :Form0, op_names = [:norm, :mag]),
(src_type = :DVF, tgt_type = :DualForm0, op_names = [:norm, :mag])]

op2_inf_rules_2D = vcat(op2_inf_rules_1D, [
# Rules for ∧₁₁, ∧₂₀, ∧₀₂
Expand Down Expand Up @@ -243,6 +257,11 @@ op2_inf_rules_2D = vcat(op2_inf_rules_1D, [
(src_type = :Form1, tgt_type = :Form0, resolved_name = :δ₁, op = :δ),
(src_type = :Form2, tgt_type = :Form1, resolved_name = :δ₂, op = :codif),
(src_type = :Form1, tgt_type = :Form0, resolved_name = :δ₁, op = :codif),
# Rules for ♯.
(src_type = :Form1, tgt_type = :PVF, resolved_name = :♯ᵖᵖ, op = :♯),
(src_type = :DualForm1, tgt_type = :DVF, resolved_name = :♯ᵈᵈ, op = :♯),
# Rules for ♭.
(src_type = :DVF, tgt_type = :Form1, resolved_name = :♭ᵈᵖ, op = :♭),
# Rules for ∇².
# TODO: Call this :nabla2 in ASCII?
(src_type = :Form0, tgt_type = :Form0, resolved_name = :∇²₀, op = :∇²),
Expand Down
104 changes: 104 additions & 0 deletions src/graph_traversal.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
using DiagrammaticEquations
using ACSets

export TraversalNode, topological_sort_edges, number_of_ops, retrieve_name

struct TraversalNode{T}
index::Int
name::T
dom::AbstractVector
cod::Int
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Call these src/tgt since we're talking about hypergraphs.

end

TraversalNode(i, d::SummationDecapode, ::Val{:Op1}) =
TraversalNode{Symbol}(i, d[i,:op1], [d[i,:src]], d[i,:tgt])

TraversalNode(i, d::SummationDecapode, ::Val{:Op2}) =
TraversalNode{Symbol}(i, d[i,:op2], [d[i,:proj1],d[i,:proj2]], d[i,:res])

TraversalNode(i, d::SummationDecapode, ::Val{:Σ}) =
TraversalNode{Symbol}(i, :+, d[incident(d,i,:summation),:summand], d[i,:sum])

retrieve_name(d::SummationDecapode, tsr::TraversalNode) = tsr.name

# Induce a topological ordering of operations from a topological ordering of variables.
# Taking Vᵢ(dom(e)ᵢ) like so is a structure preserving map.
edge2cost(tsv, tsr::TraversalNode) = maximum(tsv[tsr.dom])

number_of_ops(d::SummationDecapode) = nparts(d, :Op1) + nparts(d, :Op2) + nparts(d, :Σ)

start_nodes(d::SummationDecapode) = vcat(infer_states(d), incident(d, :Literal, :type))

Check warning on line 30 in src/graph_traversal.jl

View check run for this annotation

Codecov / codecov/patch

src/graph_traversal.jl#L30

Added line #L30 was not covered by tests

# TODO: This could be Catlab'd. Hypergraph category? Migration to a DWD?
""" function hyper_edge_list(d::SummationDecapode)

Represent a Decapode as a directed hyper-edge list.

Interpret a:
- unary operation as a hyperedge of order (1,1) ,
- binary operation as a hyperedge of order (2,1) , and
- summation as a hyperedge of order (|summands|,1) .
"""
function hyper_edge_list(d::SummationDecapode)
[map(e -> TraversalNode(e, d, Val(:Op1)), parts(d, :Op1))...,
map(e -> TraversalNode(e, d, Val(:Op2)), parts(d, :Op2))...,
map(e -> TraversalNode(e, d, Val(:Σ )), parts(d, :Σ ))...]
end
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have a function that returns all of the "edge" symbols, so in this case Op1, Op2, Σ and just iterate over those with the interior functionality. Can also do this with all the "vertex" symbols.


""" function floyd_warshall(d::SummationDecapode)

Return a |variable| × |variable| matrix of shortest paths via the Floyd-Warshall algorithm.

Taking the maximum of the non-infinite short paths from state variables induces a topological ordering.

https://en.wikipedia.org/wiki/Floyd–Warshall_algorithm
"""
function floyd_warshall(d::SummationDecapode)
# Define weights.
w(e) = (length(e.dom) == 1 && e.name ∈ [:∂ₜ,:dt]) ? -Inf : -1
# Init dists
V = nparts(d, :Var)
dist = fill(Inf, (V, V))
foreach(hyper_edge_list(d)) do e
dist[(e.dom), e.cod] .= w(e)
end
for v in 1:V
dist[v,v] = 0
end
# Floyd-Warshall
for k in 1:V
for i in 1:V
for j in 1:V
if dist[i,j] > dist[i,k] + dist[k,j]
dist[i,j] = dist[i,k] + dist[k,j]
end
end
end
end
dist

Check warning on line 78 in src/graph_traversal.jl

View check run for this annotation

Codecov / codecov/patch

src/graph_traversal.jl#L78

Added line #L78 was not covered by tests
end

""" function topological_sort_verts(d::SummationDecapode)

Topologically sort the variables in a Decapode.

The vector returned by this function maps each vertex to the order that it would be traversed in a topological sort traversal. If you want a list of vertices in the order of traversal, call `sortperm` on the output.
"""
function topological_sort_verts(d::SummationDecapode)
m = floyd_warshall(d)
map(parts(d,:Var)) do v
minimum(filter(!isinf, m[v,infer_terminals(d)]))
end
end

# TODO: Add in-place version for sorting a given hyperedge list.
""" function topological_sort_edges(d::SummationDecapode)

Topologically sort the edges in a Decapode.
"""
function topological_sort_edges(d::SummationDecapode)
tsv = topological_sort_verts(d)
op_order = hyper_edge_list(d)
sort(op_order, by = x -> edge2cost(tsv,x))
end

Loading
Loading