Skip to content

Commit

Permalink
testing
Browse files Browse the repository at this point in the history
  • Loading branch information
rbSparky committed Aug 7, 2024
1 parent 3296f2e commit 73a3d0e
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 93 deletions.
4 changes: 2 additions & 2 deletions GNNGraphs/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ const ACUMatrix{T} = Union{CuMatrix{T}, CUDA.CUSPARSE.CuSparseMatrix{T}}
ENV["DATADEPS_ALWAYS_ACCEPT"] = true # for MLDatasets

include("test_utils.jl")
"""

tests = [
"chainrules",
"datastore",
Expand All @@ -39,7 +39,7 @@ tests = [
"mldatasets",
"ext/SimpleWeightedGraphs"
]
"""

!CUDA.functional() && @warn("CUDA unavailable, not testing GPU support")

for graph_type in (:coo, :dense, :sparse)
Expand Down
3 changes: 3 additions & 0 deletions GNNLux/test/layers/basic_tests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
@testitem "layers/basic" setup=[SharedTestSetup] begin
@test 1==1
"""
rng = StableRNG(17)
g = rand_graph(10, 40, seed=17)
x = randn(rng, Float32, 3, 10)
Expand All @@ -16,4 +18,5 @@
c = GNNChain(GraphConv(3 => 5, relu), GCNConv(5 => 3))
test_lux_layer(rng, c, g, x, outputsize=(3,), container=true)
end
"""
end
90 changes: 0 additions & 90 deletions GNNLux/test/layers/conv_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,56 +4,7 @@
in_dims = 3
out_dims = 5
x = randn(rng, Float32, in_dims, 10)
"""
@testset "GCNConv" begin
l = GCNConv(in_dims => out_dims, relu)
test_lux_layer(rng, l, g, x, outputsize=(out_dims,))
end
@testset "ChebConv" begin
l = ChebConv(in_dims => out_dims, 2)
test_lux_layer(rng, l, g, x, outputsize=(out_dims,))
end
@testset "GraphConv" begin
l = GraphConv(in_dims => out_dims, relu)
test_lux_layer(rng, l, g, x, outputsize=(out_dims,))
end
@testset "AGNNConv" begin
l = AGNNConv(init_beta=1.0f0)
test_lux_layer(rng, l, g, x, sizey=(in_dims, 10))
end
@testset "EdgeConv" begin
nn = Chain(Dense(2*in_dims => 5, relu), Dense(5 => out_dims))
l = EdgeConv(nn, aggr = +)
test_lux_layer(rng, l, g, x, sizey=(out_dims,10), container=true)
end
@testset "CGConv" begin
l = CGConv(in_dims => in_dims, residual = true)
test_lux_layer(rng, l, g, x, outputsize=(in_dims,), container=true)
end
@testset "DConv" begin
l = DConv(in_dims => out_dims, 2)
test_lux_layer(rng, l, g, x, outputsize=(5,))
end

@testset "EGNNConv" begin
hin = 6
hout = 7
hidden = 8
l = EGNNConv(hin => hout, hidden)
ps = LuxCore.initialparameters(rng, l)
st = LuxCore.initialstates(rng, l)
h = randn(rng, Float32, hin, g.num_nodes)
(hnew, xnew), stnew = l(g, h, x, ps, st)
@test size(hnew) == (hout, g.num_nodes)
@test size(xnew) == (in_dims, g.num_nodes)
end
"""
@testset "MEGNetConv" begin
in_dims = 6
out_dims = 8
Expand All @@ -68,45 +19,4 @@
@test size(x_new) == (out_dims, g.num_nodes)
@test size(e_new) == (out_dims, g.num_edges)
end
"""
@testset "GATConv" begin
x = randn(rng, Float32, 6, 10)
l = GATConv(6 => 8, heads=2)
test_lux_layer(rng, l, g, x, outputsize=(16,))
l = GATConv(6 => 8, heads=2, concat=false, dropout=0.5)
test_lux_layer(rng, l, g, x, outputsize=(8,))
#TODO test edge
end
@testset "GATv2Conv" begin
x = randn(rng, Float32, 6, 10)
l = GATv2Conv(6 => 8, heads=2)
test_lux_layer(rng, l, g, x, outputsize=(16,))
l = GATv2Conv(6 => 8, heads=2, concat=false, dropout=0.5)
test_lux_layer(rng, l, g, x, outputsize=(8,))
#TODO test edge
end
@testset "SGConv" begin
l = SGConv(in_dims => out_dims, 2)
test_lux_layer(rng, l, g, x, outputsize=(out_dims,))
end
@testset "GatedGraphConv" begin
l = GatedGraphConv(in_dims, 3)
test_lux_layer(rng, l, g, x, outputsize=(in_dims,))
end
@testset "GINConv" begin
nn = Chain(Dense(in_dims => out_dims, relu), Dense(out_dims => out_dims))
l = GINConv(nn, 0.5)
test_lux_layer(rng, l, g, x, sizey=(out_dims,g.num_nodes), container=true)
end
"""
end
2 changes: 1 addition & 1 deletion GNNlib/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@ using Test
using ReTestItems
using Random, Statistics

#runtests(GNNlib)
runtests(GNNlib)

0 comments on commit 73a3d0e

Please sign in to comment.