-
Notifications
You must be signed in to change notification settings - Fork 47
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
9338ed7
commit 11515eb
Showing
15 changed files
with
293 additions
and
254 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
name: GNNlib | ||
on: | ||
pull_request: | ||
branches: | ||
- master | ||
push: | ||
branches: | ||
- master | ||
jobs: | ||
test: | ||
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} | ||
runs-on: ${{ matrix.os }} | ||
strategy: | ||
fail-fast: false | ||
matrix: | ||
version: | ||
- '1.10' # Replace this with the minimum Julia version that your package supports. | ||
# - '1' # '1' will automatically expand to the latest stable 1.x release of Julia. | ||
# - 'pre' | ||
os: | ||
- ubuntu-latest | ||
arch: | ||
- x64 | ||
|
||
steps: | ||
- uses: actions/checkout@v4 | ||
- uses: julia-actions/setup-julia@v2 | ||
with: | ||
version: ${{ matrix.version }} | ||
arch: ${{ matrix.arch }} | ||
- uses: julia-actions/cache@v2 | ||
- uses: julia-actions/julia-buildpkg@v1 | ||
- name: Install Julia dependencies and run tests | ||
shell: julia --project=monorepo {0} | ||
run: | | ||
using Pkg | ||
# dev mono repo versions | ||
pkg"registry up" | ||
Pkg.update() | ||
pkg"dev ./GNNGraphs ./GNNlib" | ||
Pkg.test("GNNlib"; coverage=true) | ||
- uses: julia-actions/julia-processcoverage@v1 | ||
with: | ||
# directories: ./GNNlib/src, ./GNNlib/ext | ||
directories: ./GNNlib/src | ||
- uses: codecov/codecov-action@v4 | ||
with: | ||
files: lcov.info |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
@testitem "msgpass" setup=[SharedTestSetup] begin | ||
#TODO test all graph types | ||
GRAPH_T = :coo | ||
in_channel = 10 | ||
out_channel = 5 | ||
num_V = 6 | ||
num_E = 14 | ||
T = Float32 | ||
|
||
adj = [0 1 0 0 0 0 | ||
1 0 0 1 1 1 | ||
0 0 0 0 0 1 | ||
0 1 0 0 1 0 | ||
0 1 0 1 0 1 | ||
0 1 1 0 1 0] | ||
|
||
X = rand(T, in_channel, num_V) | ||
E = rand(T, in_channel, num_E) | ||
|
||
g = GNNGraph(adj, graph_type = GRAPH_T) | ||
|
||
@testset "propagate" begin | ||
function message(xi, xj, e) | ||
@test xi === nothing | ||
@test e === nothing | ||
ones(T, out_channel, size(xj, 2)) | ||
end | ||
|
||
m = propagate(message, g, +, xj = X) | ||
|
||
@test size(m) == (out_channel, num_V) | ||
|
||
@testset "isolated nodes" begin | ||
x1 = rand(1, 6) | ||
g1 = GNNGraph(collect(1:5), collect(1:5), num_nodes = 6) | ||
y1 = propagate((xi, xj, e) -> xj, g, +, xj = x1) | ||
@test size(y1) == (1, 6) | ||
end | ||
end | ||
|
||
@testset "apply_edges" begin | ||
m = apply_edges(g, e = E) do xi, xj, e | ||
@test xi === nothing | ||
@test xj === nothing | ||
ones(out_channel, size(e, 2)) | ||
end | ||
|
||
@test m == ones(out_channel, num_E) | ||
|
||
# With NamedTuple input | ||
m = apply_edges(g, xj = (; a = X, b = 2X), e = E) do xi, xj, e | ||
@test xi === nothing | ||
@test xj.b == 2 * xj.a | ||
@test size(xj.a, 2) == size(xj.b, 2) == size(e, 2) | ||
ones(out_channel, size(e, 2)) | ||
end | ||
|
||
# NamedTuple output | ||
m = apply_edges(g, e = E) do xi, xj, e | ||
@test xi === nothing | ||
@test xj === nothing | ||
(; a = ones(out_channel, size(e, 2))) | ||
end | ||
|
||
@test m.a == ones(out_channel, num_E) | ||
|
||
@testset "sizecheck" begin | ||
x = rand(3, g.num_nodes - 1) | ||
@test_throws AssertionError apply_edges(copy_xj, g, xj = x) | ||
@test_throws AssertionError apply_edges(copy_xj, g, xi = x) | ||
|
||
x = (a = rand(3, g.num_nodes), b = rand(3, g.num_nodes + 1)) | ||
@test_throws AssertionError apply_edges(copy_xj, g, xj = x) | ||
@test_throws AssertionError apply_edges(copy_xj, g, xi = x) | ||
|
||
e = rand(3, g.num_edges - 1) | ||
@test_throws AssertionError apply_edges(copy_xj, g, e = e) | ||
end | ||
end | ||
|
||
@testset "copy_xj" begin | ||
n = 128 | ||
A = sprand(n, n, 0.1) | ||
Adj = map(x -> x > 0 ? 1 : 0, A) | ||
X = rand(10, n) | ||
|
||
g = GNNGraph(A, ndata = X, graph_type = GRAPH_T) | ||
|
||
function spmm_copyxj_fused(g) | ||
propagate(copy_xj, | ||
g, +; xj = g.ndata.x) | ||
end | ||
|
||
function spmm_copyxj_unfused(g) | ||
propagate((xi, xj, e) -> xj, | ||
g, +; xj = g.ndata.x) | ||
end | ||
|
||
@test spmm_copyxj_unfused(g) ≈ X * Adj | ||
@test spmm_copyxj_fused(g) ≈ X * Adj | ||
end | ||
|
||
@testset "e_mul_xj and w_mul_xj for weighted conv" begin | ||
n = 128 | ||
A = sprand(n, n, 0.1) | ||
Adj = map(x -> x > 0 ? 1 : 0, A) | ||
X = rand(10, n) | ||
|
||
g = GNNGraph(A, ndata = X, edata = A.nzval, graph_type = GRAPH_T) | ||
|
||
function spmm_unfused(g) | ||
propagate((xi, xj, e) -> reshape(e, 1, :) .* xj, | ||
g, +; xj = g.ndata.x, e = g.edata.e) | ||
end | ||
function spmm_fused(g) | ||
propagate(e_mul_xj, | ||
g, +; xj = g.ndata.x, e = g.edata.e) | ||
end | ||
|
||
function spmm_fused2(g) | ||
propagate(w_mul_xj, | ||
g, +; xj = g.ndata.x) | ||
end | ||
|
||
@test spmm_unfused(g) ≈ X * A | ||
@test spmm_fused(g) ≈ X * A | ||
@test spmm_fused2(g) ≈ X * A | ||
end | ||
|
||
@testset "aggregate_neighbors" begin | ||
@testset "sizecheck" begin | ||
m = rand(2, g.num_edges - 1) | ||
@test_throws AssertionError aggregate_neighbors(g, +, m) | ||
|
||
m = (a = rand(2, g.num_edges + 1), b = nothing) | ||
@test_throws AssertionError aggregate_neighbors(g, +, m) | ||
end | ||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
using GNNlib | ||
using Test | ||
using ReTestItems | ||
using Random, Statistics | ||
|
||
runtests(GNNlib) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
@testsetup module SharedTestSetup | ||
|
||
import Reexport: @reexport | ||
|
||
@reexport using GNNlib | ||
@reexport using GNNGraphs | ||
@reexport using NNlib | ||
@reexport using MLUtils | ||
@reexport using SparseArrays | ||
@reexport using Test, Random, Statistics | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
@testitem "utils" setup=[SharedTestSetup] begin | ||
# TODO test all graph types | ||
GRAPH_T = :coo | ||
De, Dx = 3, 2 | ||
g = MLUtils.batch([rand_graph(10, 60, bidirected=true, | ||
ndata = rand(Dx, 10), | ||
edata = rand(De, 30), | ||
graph_type = GRAPH_T) for i in 1:5]) | ||
x = g.ndata.x | ||
e = g.edata.e | ||
|
||
@testset "reduce_nodes" begin | ||
r = reduce_nodes(mean, g, x) | ||
@test size(r) == (Dx, g.num_graphs) | ||
@test r[:, 2] ≈ mean(getgraph(g, 2).ndata.x, dims = 2) | ||
|
||
r2 = reduce_nodes(mean, graph_indicator(g), x) | ||
@test r2 == r | ||
end | ||
|
||
@testset "reduce_edges" begin | ||
r = reduce_edges(mean, g, e) | ||
@test size(r) == (De, g.num_graphs) | ||
@test r[:, 2] ≈ mean(getgraph(g, 2).edata.e, dims = 2) | ||
end | ||
|
||
@testset "softmax_nodes" begin | ||
r = softmax_nodes(g, x) | ||
@test size(r) == size(x) | ||
@test r[:, 1:10] ≈ softmax(getgraph(g, 1).ndata.x, dims = 2) | ||
end | ||
|
||
@testset "softmax_edges" begin | ||
r = softmax_edges(g, e) | ||
@test size(r) == size(e) | ||
@test r[:, 1:60] ≈ softmax(getgraph(g, 1).edata.e, dims = 2) | ||
end | ||
|
||
@testset "broadcast_nodes" begin | ||
z = rand(4, g.num_graphs) | ||
r = broadcast_nodes(g, z) | ||
@test size(r) == (4, g.num_nodes) | ||
@test r[:, 1] ≈ z[:, 1] | ||
@test r[:, 10] ≈ z[:, 1] | ||
@test r[:, 11] ≈ z[:, 2] | ||
end | ||
|
||
@testset "broadcast_edges" begin | ||
z = rand(4, g.num_graphs) | ||
r = broadcast_edges(g, z) | ||
@test size(r) == (4, g.num_edges) | ||
@test r[:, 1] ≈ z[:, 1] | ||
@test r[:, 60] ≈ z[:, 1] | ||
@test r[:, 61] ≈ z[:, 2] | ||
end | ||
|
||
@testset "softmax_edge_neighbors" begin | ||
s = [1, 2, 3, 4] | ||
t = [5, 5, 6, 6] | ||
g2 = GNNGraph(s, t) | ||
e2 = randn(Float32, 3, g2.num_edges) | ||
z = softmax_edge_neighbors(g2, e2) | ||
@test size(z) == size(e2) | ||
@test z[:, 1:2] ≈ NNlib.softmax(e2[:, 1:2], dims = 2) | ||
@test z[:, 3:4] ≈ NNlib.softmax(e2[:, 3:4], dims = 2) | ||
end | ||
end | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.