Skip to content

Commit

Permalink
relu -> tanh in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Aug 7, 2024
1 parent 355d0f6 commit ccca0e6
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 16 deletions.
6 changes: 3 additions & 3 deletions GNNLux/test/layers/conv_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
x = randn(rng, Float32, in_dims, 10)

@testset "GCNConv" begin
l = GCNConv(in_dims => out_dims, relu)
l = GCNConv(in_dims => out_dims, tanh)
test_lux_layer(rng, l, g, x, outputsize=(out_dims,))
end

Expand All @@ -16,7 +16,7 @@
end

@testset "GraphConv" begin
l = GraphConv(in_dims => out_dims, relu)
l = GraphConv(in_dims => out_dims, tanh)
test_lux_layer(rng, l, g, x, outputsize=(out_dims,))
end

Expand All @@ -26,7 +26,7 @@
end

@testset "EdgeConv" begin
nn = Chain(Dense(2*in_dims => 5, relu), Dense(5 => out_dims))
nn = Chain(Dense(2*in_dims => 2, tanh), Dense(2 => out_dims))
l = EdgeConv(nn, aggr = +)
test_lux_layer(rng, l, g, x, sizey=(out_dims,10), container=true)
end
Expand Down
4 changes: 2 additions & 2 deletions test/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
@testset "constructor with names" begin
m = GNNChain(GCNConv(din => d),
LayerNorm(d),
x -> relu.(x),
x -> tanh.(x),
Dense(d, dout))

m2 = GNNChain(enc = m,
Expand All @@ -34,7 +34,7 @@
@testset "constructor with vector" begin
m = GNNChain(GCNConv(din => d),
LayerNorm(d),
x -> relu.(x),
x -> tanh.(x),
Dense(d, dout))
m2 = GNNChain([m.layers...])
@test m2(g, x) == m(g, x)
Expand Down
2 changes: 1 addition & 1 deletion test/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ end
test_layer(l, g, rtol = RTOL_HIGH, outsize = (out_channel, g.num_nodes))
end

l = GraphConv(in_channel => out_channel, relu, bias = false, aggr = mean)
l = GraphConv(in_channel => out_channel, tanh, bias = false, aggr = mean)
for g in test_graphs
test_layer(l, g, rtol = RTOL_HIGH, outsize = (out_channel, g.num_nodes))
end
Expand Down
16 changes: 8 additions & 8 deletions test/layers/heteroconv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@
end

@testset "Constructor from pairs" begin
layer = HeteroGraphConv((:A, :to, :B) => GraphConv(64 => 32, relu),
(:B, :to, :A) => GraphConv(64 => 32, relu));
layer = HeteroGraphConv((:A, :to, :B) => GraphConv(64 => 32, tanh),
(:B, :to, :A) => GraphConv(64 => 32, tanh));
@test length(layer.etypes) == 2
end

Expand Down Expand Up @@ -95,8 +95,8 @@

@testset "CGConv" begin
x = (A = rand(Float32, 4,2), B = rand(Float32, 4, 3))
layers = HeteroGraphConv( (:A, :to, :B) => CGConv(4 => 2, relu),
(:B, :to, :A) => CGConv(4 => 2, relu));
layers = HeteroGraphConv( (:A, :to, :B) => CGConv(4 => 2, tanh),
(:B, :to, :A) => CGConv(4 => 2, tanh));
y = layers(hg, x);
@test size(y.A) == (2,2) && size(y.B) == (2,3)
end
Expand All @@ -111,8 +111,8 @@

@testset "SAGEConv" begin
x = (A = rand(Float32, 4, 2), B = rand(Float32, 4, 3))
layers = HeteroGraphConv((:A, :to, :B) => SAGEConv(4 => 2, relu, bias = false, aggr = +),
(:B, :to, :A) => SAGEConv(4 => 2, relu, bias = false, aggr = +));
layers = HeteroGraphConv((:A, :to, :B) => SAGEConv(4 => 2, tanh, bias = false, aggr = +),
(:B, :to, :A) => SAGEConv(4 => 2, tanh, bias = false, aggr = +));
y = layers(hg, x);
@test size(y.A) == (2, 2) && size(y.B) == (2, 3)
end
Expand Down Expand Up @@ -152,8 +152,8 @@
@testset "GCNConv" begin
g = rand_bipartite_heterograph((2,3), 6)
x = (A = rand(Float32, 4,2), B = rand(Float32, 4, 3))
layers = HeteroGraphConv( (:A, :to, :B) => GCNConv(4 => 2, relu),
(:B, :to, :A) => GCNConv(4 => 2, relu));
layers = HeteroGraphConv( (:A, :to, :B) => GCNConv(4 => 2, tanh),
(:B, :to, :A) => GCNConv(4 => 2, tanh));
y = layers(g, x);
@test size(y.A) == (2,2) && size(y.B) == (2,3)
end
Expand Down
4 changes: 2 additions & 2 deletions test/layers/temporalconv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ end
end

@testset "ResGatedGraphConv" begin
resgatedconv = ResGatedGraphConv(in_channel => out_channel, relu)
resgatedconv = ResGatedGraphConv(in_channel => out_channel, tanh)
@test length(resgatedconv(tg, tg.ndata.x)) == S
@test size(resgatedconv(tg, tg.ndata.x)[1]) == (out_channel, N)
@test length(Flux.gradient(x ->sum(sum(resgatedconv(tg, x))), tg.ndata.x)[1]) == S
Expand All @@ -147,7 +147,7 @@ end
end

@testset "GraphConv" begin
graphconv = GraphConv(in_channel => out_channel,relu)
graphconv = GraphConv(in_channel => out_channel, tanh)
@test length(graphconv(tg, tg.ndata.x)) == S
@test size(graphconv(tg, tg.ndata.x)[1]) == (out_channel, N)
@test length(Flux.gradient(x ->sum(sum(graphconv(tg, x))), tg.ndata.x)[1]) == S
Expand Down

0 comments on commit ccca0e6

Please sign in to comment.