From ccca0e6ea8ecc08712d2f657f8f83e66cad59869 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Wed, 7 Aug 2024 08:54:18 +0200 Subject: [PATCH] relu -> tanh in tests --- GNNLux/test/layers/conv_tests.jl | 6 +++--- test/layers/basic.jl | 4 ++-- test/layers/conv.jl | 2 +- test/layers/heteroconv.jl | 16 ++++++++-------- test/layers/temporalconv.jl | 4 ++-- 5 files changed, 16 insertions(+), 16 deletions(-) diff --git a/GNNLux/test/layers/conv_tests.jl b/GNNLux/test/layers/conv_tests.jl index d28ae8a7a..ab06c9445 100644 --- a/GNNLux/test/layers/conv_tests.jl +++ b/GNNLux/test/layers/conv_tests.jl @@ -6,7 +6,7 @@ x = randn(rng, Float32, in_dims, 10) @testset "GCNConv" begin - l = GCNConv(in_dims => out_dims, relu) + l = GCNConv(in_dims => out_dims, tanh) test_lux_layer(rng, l, g, x, outputsize=(out_dims,)) end @@ -16,7 +16,7 @@ end @testset "GraphConv" begin - l = GraphConv(in_dims => out_dims, relu) + l = GraphConv(in_dims => out_dims, tanh) test_lux_layer(rng, l, g, x, outputsize=(out_dims,)) end @@ -26,7 +26,7 @@ end @testset "EdgeConv" begin - nn = Chain(Dense(2*in_dims => 5, relu), Dense(5 => out_dims)) + nn = Chain(Dense(2*in_dims => 2, tanh), Dense(2 => out_dims)) l = EdgeConv(nn, aggr = +) test_lux_layer(rng, l, g, x, sizey=(out_dims,10), container=true) end diff --git a/test/layers/basic.jl b/test/layers/basic.jl index 9a3b6ee9f..2428865ae 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -21,7 +21,7 @@ @testset "constructor with names" begin m = GNNChain(GCNConv(din => d), LayerNorm(d), - x -> relu.(x), + x -> tanh.(x), Dense(d, dout)) m2 = GNNChain(enc = m, @@ -34,7 +34,7 @@ @testset "constructor with vector" begin m = GNNChain(GCNConv(din => d), LayerNorm(d), - x -> relu.(x), + x -> tanh.(x), Dense(d, dout)) m2 = GNNChain([m.layers...]) @test m2(g, x) == m(g, x) diff --git a/test/layers/conv.jl b/test/layers/conv.jl index 224b98697..5d624e4a5 100644 --- a/test/layers/conv.jl +++ b/test/layers/conv.jl @@ -104,7 +104,7 @@ end test_layer(l, g, rtol = RTOL_HIGH, outsize = (out_channel, g.num_nodes)) end - l = GraphConv(in_channel => out_channel, relu, bias = false, aggr = mean) + l = GraphConv(in_channel => out_channel, tanh, bias = false, aggr = mean) for g in test_graphs test_layer(l, g, rtol = RTOL_HIGH, outsize = (out_channel, g.num_nodes)) end diff --git a/test/layers/heteroconv.jl b/test/layers/heteroconv.jl index d6cfa390c..d9eaf0c7f 100644 --- a/test/layers/heteroconv.jl +++ b/test/layers/heteroconv.jl @@ -30,8 +30,8 @@ end @testset "Constructor from pairs" begin - layer = HeteroGraphConv((:A, :to, :B) => GraphConv(64 => 32, relu), - (:B, :to, :A) => GraphConv(64 => 32, relu)); + layer = HeteroGraphConv((:A, :to, :B) => GraphConv(64 => 32, tanh), + (:B, :to, :A) => GraphConv(64 => 32, tanh)); @test length(layer.etypes) == 2 end @@ -95,8 +95,8 @@ @testset "CGConv" begin x = (A = rand(Float32, 4,2), B = rand(Float32, 4, 3)) - layers = HeteroGraphConv( (:A, :to, :B) => CGConv(4 => 2, relu), - (:B, :to, :A) => CGConv(4 => 2, relu)); + layers = HeteroGraphConv( (:A, :to, :B) => CGConv(4 => 2, tanh), + (:B, :to, :A) => CGConv(4 => 2, tanh)); y = layers(hg, x); @test size(y.A) == (2,2) && size(y.B) == (2,3) end @@ -111,8 +111,8 @@ @testset "SAGEConv" begin x = (A = rand(Float32, 4, 2), B = rand(Float32, 4, 3)) - layers = HeteroGraphConv((:A, :to, :B) => SAGEConv(4 => 2, relu, bias = false, aggr = +), - (:B, :to, :A) => SAGEConv(4 => 2, relu, bias = false, aggr = +)); + layers = HeteroGraphConv((:A, :to, :B) => SAGEConv(4 => 2, tanh, bias = false, aggr = +), + (:B, :to, :A) => SAGEConv(4 => 2, tanh, bias = false, aggr = +)); y = layers(hg, x); @test size(y.A) == (2, 2) && size(y.B) == (2, 3) end @@ -152,8 +152,8 @@ @testset "GCNConv" begin g = rand_bipartite_heterograph((2,3), 6) x = (A = rand(Float32, 4,2), B = rand(Float32, 4, 3)) - layers = HeteroGraphConv( (:A, :to, :B) => GCNConv(4 => 2, relu), - (:B, :to, :A) => GCNConv(4 => 2, relu)); + layers = HeteroGraphConv( (:A, :to, :B) => GCNConv(4 => 2, tanh), + (:B, :to, :A) => GCNConv(4 => 2, tanh)); y = layers(g, x); @test size(y.A) == (2,2) && size(y.B) == (2,3) end diff --git a/test/layers/temporalconv.jl b/test/layers/temporalconv.jl index b55aff808..45c8acf04 100644 --- a/test/layers/temporalconv.jl +++ b/test/layers/temporalconv.jl @@ -133,7 +133,7 @@ end end @testset "ResGatedGraphConv" begin - resgatedconv = ResGatedGraphConv(in_channel => out_channel, relu) + resgatedconv = ResGatedGraphConv(in_channel => out_channel, tanh) @test length(resgatedconv(tg, tg.ndata.x)) == S @test size(resgatedconv(tg, tg.ndata.x)[1]) == (out_channel, N) @test length(Flux.gradient(x ->sum(sum(resgatedconv(tg, x))), tg.ndata.x)[1]) == S @@ -147,7 +147,7 @@ end end @testset "GraphConv" begin - graphconv = GraphConv(in_channel => out_channel,relu) + graphconv = GraphConv(in_channel => out_channel, tanh) @test length(graphconv(tg, tg.ndata.x)) == S @test size(graphconv(tg, tg.ndata.x)[1]) == (out_channel, N) @test length(Flux.gradient(x ->sum(sum(graphconv(tg, x))), tg.ndata.x)[1]) == S