diff --git a/src/DiagrammaticEquations.jl b/src/DiagrammaticEquations.jl index 0d0de25..e708043 100644 --- a/src/DiagrammaticEquations.jl +++ b/src/DiagrammaticEquations.jl @@ -51,6 +51,7 @@ DerivOp = Symbol("∂ₜ") append_dot(s::Symbol) = Symbol(string(s)*'\U0307') include("acset.jl") +include("query.jl") include("language.jl") include("composition.jl") include("collages.jl") diff --git a/src/acset.jl b/src/acset.jl index 9665afd..ed97543 100644 --- a/src/acset.jl +++ b/src/acset.jl @@ -235,10 +235,7 @@ See also: [`infer_terminals`](@ref). """ function infer_states(d::SummationDecapode) parentless = filter(parts(d, :Var)) do v - length(incident(d, v, :tgt)) == 0 && - length(incident(d, v, :res)) == 0 && - length(incident(d, v, :sum)) == 0 && - d[v, :type] != :Literal + !is_var_target(d, v) && d[v, :type] != :Literal end parents_of_tvars = union(d[incident(d,:∂ₜ, :op1), :src], @@ -259,10 +256,7 @@ See also: [`infer_states`](@ref). """ function infer_terminals(d::SummationDecapode) filter(parts(d, :Var)) do v - length(incident(d, v, :src)) == 0 && - length(incident(d, v, :proj1)) == 0 && - length(incident(d, v, :proj2)) == 0 && - length(incident(d, v, :summand)) == 0 + !is_var_source(d, v) end end diff --git a/src/deca/Deca.jl b/src/deca/Deca.jl index 8202704..b5b19cb 100644 --- a/src/deca/Deca.jl +++ b/src/deca/Deca.jl @@ -10,6 +10,7 @@ export normalize_unicode, varname, infer_types!, resolve_overloads!, typename, s include("deca_acset.jl") include("deca_visualization.jl") +include("deca_query.jl") """ function recursive_delete_parents!(d::SummationDecapode, to_delete::Vector{Int64}) diff --git a/src/deca/deca_query.jl b/src/deca/deca_query.jl new file mode 100644 index 0000000..b4c0816 --- /dev/null +++ b/src/deca/deca_query.jl @@ -0,0 +1,25 @@ +using DiagrammaticEquations +using ACSets + +export is_var_target, is_var_source, get_variable_parents, get_next_op1s, get_next_op2s + +function is_var_target(d::SummationDecapode, var::Int) + return !isempty(collected_incident(d, var, [:tgt, :res, :sum])) +end + +function is_var_source(d::SummationDecapode, var::Int) + return !isempty(collected_incident(d, var, [:src, :proj1, :proj2, :summand])) +end + +function get_variable_parents(d::SummationDecapode, var::Int) + return collected_incident(d, var, [:tgt, :res, :res, [:summation, :sum]], [:src, :proj1, :proj2, :summand]) +end + +function get_next_op1s(d::SummationDecapode, var::Int) + collected_incident(d, var, [:src]) +end + +function get_next_op2s(d::SummationDecapode, var::Int) + collected_incident(d, var, [:proj1, :proj2]) +end + diff --git a/src/query.jl b/src/query.jl new file mode 100644 index 0000000..3522720 --- /dev/null +++ b/src/query.jl @@ -0,0 +1,41 @@ +using DiagrammaticEquations +using ACSets + +export collected_incident + +function collected_incident(d::ACSet, searches::AbstractVector, args...) + + isempty(searches) && error("Cannot have an empty search") + + query_result = mapreduce(vcat, searches) do search + collected_incident(d, search, args...) + end + + return unique!(query_result) +end + +function collected_incident(d::ACSet, search, lookup_array) + numof_channels = length(lookup_array) + empty_outputchannels = fill(nothing, numof_channels) + return collected_incident(d, search, lookup_array, empty_outputchannels) +end + + +function collected_incident(d::ACSet, search, lookup_array, output_array) + length(lookup_array) == length(output_array) || error("Input and output channels are different lengths") + isempty(lookup_array) && error("Cannot have an empty lookup") + + query_result = mapreduce(vcat, zip(lookup_array, output_array)) do (lookup, output) + runincident_output_result(d, search, lookup, output) + end + + return unique!(query_result) +end + +function runincident_output_result(d::ACSet, search, lookup::Union{Symbol, AbstractVector{Symbol}}, output_channel::Union{Symbol, Nothing}) + index_result = incident(d, search, lookup) + isnothing(output_channel) ? index_result : d[index_result, output_channel] +end + + + \ No newline at end of file diff --git a/test/composition.jl b/test/composition.jl index f0c5c02..e1f3a13 100644 --- a/test/composition.jl +++ b/test/composition.jl @@ -25,7 +25,7 @@ trivial_comp_from_vector = oapply(trivial_relation, [otrivial]) trivial_comp_from_single = oapply(trivial_relation, otrivial) # Test the oapply is correct. -@test apex(trivial_comp_from_vector) == Trivial +@test apex(trivial_comp_from_vector) == Trivial @test apex(trivial_comp_from_single) == Trivial # Test none of the decapodes were mutated @test isequal(otrivial, deep_copies) diff --git a/test/deca_query.jl b/test/deca_query.jl new file mode 100644 index 0000000..cf8ae68 --- /dev/null +++ b/test/deca_query.jl @@ -0,0 +1,111 @@ +using Test +using DiagrammaticEquations +using DiagrammaticEquations.Deca +using ACSets + +function array_contains_same(test, expected) + sort(test) == sort(expected) +end + +get_index_from_name(d::SummationDecapode, varname::Symbol) = only(incident(d, varname, :name)) + +@testset "Check sources and targets" begin + singleton_deca = @decapode begin + V::infer + end + + @test !is_var_source(singleton_deca, 1) + @test !is_var_target(singleton_deca, 1) + + + path_op1_deca = @decapode begin + (X,Z)::infer + X == d(d(Z)) + end + + idxX = get_index_from_name(path_op1_deca, :X) + idxZ = get_index_from_name(path_op1_deca, :Z) + + @test is_var_source(path_op1_deca, idxZ) + @test !is_var_target(path_op1_deca, idxZ) + + @test !is_var_source(path_op1_deca, idxX) + @test is_var_target(path_op1_deca, idxX) + + + path_op2_deca = @decapode begin + X == ∧(Y,Z) + end + + idxX = get_index_from_name(path_op2_deca, :X) + idxY = get_index_from_name(path_op2_deca, :Y) + idxZ = get_index_from_name(path_op2_deca, :Z) + + idxsYZ = [idxY, idxZ] + + for idx in idxsYZ + @test is_var_source(path_op2_deca, idx) + @test !is_var_target(path_op2_deca, idx) + end + @test !is_var_source(path_op2_deca, idxX) + @test is_var_target(path_op2_deca, idxX) + + + path_sum_deca = @decapode begin + X == Y + Z + end + + idxX = get_index_from_name(path_sum_deca, :X) + idxY = get_index_from_name(path_sum_deca, :Y) + idxZ = get_index_from_name(path_sum_deca, :Z) + + idxsYZ = [idxY, idxZ] + + for idx in idxsYZ + @test is_var_source(path_sum_deca, idx) + @test !is_var_target(path_sum_deca, idx) + end + @test !is_var_source(path_sum_deca, idxX) + @test is_var_target(path_sum_deca, idxX) + + mixedop_deca = @decapode begin + Inter == d(X) + ∧(Y, Z) + Res == d(Inter) + end + + idxX = get_index_from_name(mixedop_deca, :X) + idxY = get_index_from_name(mixedop_deca, :Y) + idxZ = get_index_from_name(mixedop_deca, :Z) + idxInter = get_index_from_name(mixedop_deca, :Inter) + idxRes = get_index_from_name(mixedop_deca, :Res) + + @test is_var_source(mixedop_deca, idxX) + @test is_var_source(mixedop_deca, idxY) + @test is_var_source(mixedop_deca, idxZ) + + @test is_var_target(mixedop_deca, idxRes) + + @test is_var_target(mixedop_deca, idxInter) && is_var_source(mixedop_deca, idxInter) +end + +# TODO: Finish writing these tests +@testset "Get states and terminals" begin + singleton_deca = @decapode begin + V::Form1 + end + @test infer_state_names(singleton_deca) == infer_terminal_names(singleton_deca) + + path_op1_deca = @decapode begin + (X,Z)::infer + X == d(d(Z)) + end + @test array_contains_same(infer_state_names(path_op1_deca), [:Z]) + @test array_contains_same(infer_terminal_names(path_op1_deca), [:X]) + + path_op2_deca = @decapode begin + (X,Y,Z)::infer + X == ∧(Y,Z) + end + @test array_contains_same(infer_state_names(path_op2_deca), [:Y, :Z]) + @test array_contains_same(infer_terminal_names(path_op2_deca), [:X]) +end \ No newline at end of file diff --git a/test/query.jl b/test/query.jl new file mode 100644 index 0000000..705d13e --- /dev/null +++ b/test/query.jl @@ -0,0 +1,261 @@ +using Test +using ACSets +using DiagrammaticEquations + +# Prevent output order invariance +function check_queryoutput(query, expected) + @test sort(query) == sort(expected) + @test allunique(query) +end + +SchTestBasicQueryACSet = BasicSchema([:src,:tgt], [(:f,:src,:tgt)]) +@acset_type TestBasicQueryACSet(SchTestBasicQueryACSet, index=[:f]) + +SchTestDeepQueryACSet = BasicSchema([:Lvl1, :Lvl2, :Lvl3], [(:Map1,:Lvl1, :Lvl2), (:Map2,:Lvl2, :Lvl3)]) +@acset_type TestDeepQueryACSet(SchTestDeepQueryACSet, index=[:Map1, :Map2]) + +SchTestMultiTableQueryACSet = BasicSchema([:x, :y, :a, :b], [(:f,:x,:y), (:g,:a,:b)]) +@acset_type TestMultiTableQueryACSet(SchTestMultiTableQueryACSet, index=[:f, :g]) + +SchTestDecGraphQueryACSet = BasicSchema([:E,:V], [(:src,:E,:V),(:tgt,:E,:V)], [:X], [(:dec,:E,:X)]) +@acset_type TestDecGraphQueryACSet(SchTestDecGraphQueryACSet, index=[:src,:tgt]) + +@testset "Basic single queries" begin + singlesrctgt_example = @acset TestBasicQueryACSet begin + src = 1 + tgt = 1 + f = [1] + end + + check_queryoutput(collected_incident(singlesrctgt_example, 1, [:f]), [1]) + check_queryoutput(collected_incident(singlesrctgt_example, 1, [:f], [:f]), [1]) + + doublesrctgt_example = @acset TestBasicQueryACSet begin + src = 2 + tgt = 2 + f = [2, 1] + end + + check_queryoutput(collected_incident(doublesrctgt_example, 1, [:f]), [2]) + check_queryoutput(collected_incident(doublesrctgt_example, 1, [:f], [:f]), [1]) + + check_queryoutput(collected_incident(doublesrctgt_example, 2, [:f]), [1]) + check_queryoutput(collected_incident(doublesrctgt_example, 2, [:f], [:f]), [2]) + + noresult_example = @acset TestBasicQueryACSet begin + src = 1 + tgt = 2 + f = [1] + end + + @test isempty(collected_incident(noresult_example, 2, [:f])) + @test isempty(collected_incident(noresult_example, 2, [:f], [:f])) + + multipleresult_example = @acset TestBasicQueryACSet begin + src = 3 + tgt = 2 + f = [1, 2, 2] + end + + check_queryoutput(collected_incident(multipleresult_example, 2, [:f]), [2, 3]) + check_queryoutput(collected_incident(multipleresult_example, 2, [:f], [:f]), [2]) + + # Check that using arrays does not affect query + check_queryoutput(collected_incident(multipleresult_example, [2], [:f]), [2, 3]) + check_queryoutput(collected_incident(multipleresult_example, [2], [:f], [:f]), [2]) +end + +@testset "Deep single queries" begin + singlepath_example = @acset TestDeepQueryACSet begin + Lvl1=1 + Lvl2=1 + Lvl3=1 + Map1=[1] + Map2=[1] + end + + check_queryoutput(collected_incident(singlepath_example, 1, [[:Map1, :Map2]]), [1]) + check_queryoutput(collected_incident(singlepath_example, 1, [[:Map1, :Map2]], [:Map1]), [1]) + + manysinglepaths_example = @acset TestDeepQueryACSet begin + Lvl1=2 + Lvl2=2 + Lvl3=2 + Map1=[2,1] + Map2=[2,1] + end + + check_queryoutput(collected_incident(manysinglepaths_example, 2, [[:Map1, :Map2]]), [2]) + check_queryoutput(collected_incident(manysinglepaths_example, 2, [[:Map1, :Map2]], [:Map1]), [1]) + + manyresults_firstquery_example = @acset TestDeepQueryACSet begin + Lvl1=3 + Lvl2=3 + Lvl3=2 + Map1=[3, 2, 1] + Map2=[1, 2, 2] + end + + check_queryoutput(collected_incident(manyresults_firstquery_example, 2, [[:Map1, :Map2]]), [1, 2]) + check_queryoutput(collected_incident(manyresults_firstquery_example, 2, [[:Map1, :Map2]], [:Map1]), [3, 2]) + + manyresults_lastquery_example = @acset TestDeepQueryACSet begin + Lvl1=3 + Lvl2=2 + Lvl3=3 + Map1=[1, 2, 2] + Map2=[1, 2] + end + + check_queryoutput(collected_incident(manyresults_lastquery_example, 2, [[:Map1, :Map2]]), [2,3]) + check_queryoutput(collected_incident(manyresults_lastquery_example, 2, [[:Map1, :Map2]], [:Map1]), [2]) + + manyresults_allqueries_example = @acset TestDeepQueryACSet begin + Lvl1=3 + Lvl2=3 + Lvl3=2 + Map1=[1, 2, 2] + Map2=[2, 2, 1] + end + + check_queryoutput(collected_incident(manyresults_allqueries_example, 2, [[:Map1, :Map2]]), [1,2,3]) +end + +@testset "Multi-Table queries" begin + presentinboth_example = @acset TestMultiTableQueryACSet begin + x=1 + y=1 + a=1 + b=1 + f=[1] + g=[1] + end + + check_queryoutput(collected_incident(presentinboth_example, 1, [:f, :g]), [1]) + + multires_example = @acset TestMultiTableQueryACSet begin + x=3 + y=2 + a=2 + b=1 + f=[1,1,2] + g=[1,1] + end + check_queryoutput(collected_incident(multires_example, 1, [:f, :g]), [1,2]) + # Check that querying works if no results in one table + check_queryoutput(collected_incident(multires_example, 2, [:f, :g]), [3]) + +end + +@testset "Combined queries" begin + doublesrctgt_example = @acset TestBasicQueryACSet begin + src = 2 + tgt = 2 + f = [2, 1] + end + + check_queryoutput(collected_incident(doublesrctgt_example, [1, 2], [:f]), [2, 1]) + check_queryoutput(collected_incident(doublesrctgt_example, [1, 2], [:f], [:f]), [1, 2]) + + # Check that input order does not affect queries + check_queryoutput(collected_incident(doublesrctgt_example, [2, 1], [:f]), [2, 1]) + check_queryoutput(collected_incident(doublesrctgt_example, [2, 1], [:f], [:f]), [1, 2]) + + manyresults_deep_example = @acset TestDeepQueryACSet begin + Lvl1=3 + Lvl2=3 + Lvl3=3 + Map1=[2, 1, 3] + Map2=[3, 2, 1] + end + + check_queryoutput(collected_incident(manyresults_deep_example, [1, 2], [[:Map1, :Map2]]), [1, 3]) + + stargraph_example = @acset TestDecGraphQueryACSet{Symbol} begin + V = 4 + E = 3 + src = [2,3,4] + tgt = [1,1,1] + dec = [:a, :b, :c] + end + names_ofedges_tgtcenter = collected_incident(stargraph_example, 1, [:tgt], [:dec]) + check_queryoutput(names_ofedges_tgtcenter, [:a, :b, :c]) + names_ofedges_srcrest = collected_incident(stargraph_example, [2,3,4], [:src], [:dec]) + check_queryoutput(names_ofedges_srcrest, names_ofedges_tgtcenter) + + #Collect vertices of edge named :a + check_queryoutput(collected_incident(stargraph_example, :a, [:dec, :dec], [:src, :tgt]), [1,2]) + + clustergraph_example = @acset TestDecGraphQueryACSet{Symbol} begin + V = 6 + E = 7 + src = [1,2,3,4,5,6,1] + tgt = [2,3,1,5,6,4,4] + dec = [:a, :b, :c, :d, :e, :f, :g] + end + + names_ofedges_src1 = collected_incident(clustergraph_example, 1, [:src], [:dec]) + check_queryoutput(names_ofedges_src1, [:a, :g]) + names_ofedges_withvertex1 = collected_incident(clustergraph_example, 1, [:src, :tgt], [:dec, :dec]) + check_queryoutput(names_ofedges_withvertex1, [:a, :c, :g]) + + tgts_with_src1 = collected_incident(clustergraph_example, 1, [:src], [:tgt]) + check_queryoutput(collected_incident(clustergraph_example, tgts_with_src1, [:src], [:tgt]), [3, 5]) + + # Get vertex 4 neighbors + neighbors_of4 = collected_incident(clustergraph_example, 4, [:src, :tgt], [:tgt, :src]) + check_queryoutput(neighbors_of4, [1,5,6]) + + treegraph_exaxmple = @acset TestDecGraphQueryACSet{Symbol} begin + V = 5 + E = 4 + src = [1,2,3,5] + tgt = [2,3,4,4] + dec = [:a, :b, :c, :d] + end + + distance2_from4 = [4] + distance1_from4 = collected_incident(treegraph_exaxmple, 4, [:tgt], [:src]) + check_queryoutput(distance1_from4, [3,5]) + distance2_from4 = collected_incident(treegraph_exaxmple, distance1_from4, [:tgt], [:src]) + check_queryoutput(distance2_from4, [2]) +end + +@testset "Edge case queries" begin + singlesrctgt_example = @acset TestBasicQueryACSet begin + src = 1 + tgt = 1 + f = [1] + end + + @test_throws "empty lookup" collected_incident(singlesrctgt_example, 1, []) + @test_throws "empty lookup" isempty(collected_incident(singlesrctgt_example, 1, [], [])) + @test_throws "empty search" isempty(collected_incident(singlesrctgt_example, [], [], [])) + @test_throws "empty search" collected_incident(singlesrctgt_example, [], [:f], [:f]) + + @test_throws "different lengths" collected_incident(singlesrctgt_example, 1, [], [:f]) + @test_throws "different lengths" collected_incident(singlesrctgt_example, 1, [:f], [:f, :f]) + @test_throws "different lengths" collected_incident(singlesrctgt_example, 1, [:f], []) +end + +@testset "Proper result typing" begin + singlesrctgt_example = @acset TestBasicQueryACSet begin + src = 1 + tgt = 1 + f = [1] + end + + @test collected_incident(singlesrctgt_example, 1, [:f]) isa Vector{Int} + @test collected_incident(singlesrctgt_example, 1, [:f], [:f]) isa Vector{Int} + + stargraph_example = @acset TestDecGraphQueryACSet{Symbol} begin + V = 4 + E = 3 + src = [2,3,4] + tgt = [1,1,1] + dec = [:a, :b, :c] + end + @test collected_incident(stargraph_example, 1, [:tgt], [:dec]) isa Vector{Symbol} + @test collected_incident(stargraph_example, 1, [:tgt]) isa Vector{Int} + @test collected_incident(stargraph_example, 1, [:tgt, :tgt], [:src, :dec]) isa Vector{Any} +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 0cb0f30..af1cd90 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,6 +8,11 @@ include("aqua.jl") include("core.jl") end +@testset "Querying" begin + include("query.jl") + include("deca_query.jl") +end + @testset "Composition" begin include("composition.jl") end @@ -38,4 +43,4 @@ end @testset "Open Operators" begin include("openoperators.jl") -end +end \ No newline at end of file