[GNNLux] more layers pt. 3 #471

Merged (2 commits) · Aug 1, 2024
11 changes: 7 additions & 4 deletions GNNLux/src/GNNLux.jl
@@ -1,8 +1,11 @@
 module GNNLux
 using ConcreteStructs: @concrete
 using NNlib: NNlib, sigmoid, relu, swish
-using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer
-using Lux: Lux, Chain, Dense, glorot_uniform, zeros32, StatefulLuxLayer
+using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer, parameterlength, statelength, outputsize,
+               initialparameters, initialstates
+using Lux: Lux, Chain, Dense, GRUCell,
+           glorot_uniform, zeros32,
+           StatefulLuxLayer
 using Reexport: @reexport
 using Random: AbstractRNG
 using GNNlib: GNNlib
@@ -22,9 +25,9 @@ export AGNNConv,
        DConv,
        GATConv,
        GATv2Conv,
-       # GatedGraphConv,
+       GatedGraphConv,
        GCNConv,
-       # GINConv,
+       GINConv,
        # GMMConv,
        GraphConv,
        # MEGNetConv,
72 changes: 64 additions & 8 deletions GNNLux/src/layers/conv.jl
@@ -38,7 +38,6 @@ function LuxCore.initialparameters(rng::AbstractRNG, l::GCNConv)
 end
 
 LuxCore.parameterlength(l::GCNConv) = l.use_bias ? l.in_dims * l.out_dims + l.out_dims : l.in_dims * l.out_dims
-LuxCore.statelength(d::GCNConv) = 0
 LuxCore.outputsize(d::GCNConv) = (d.out_dims,)
 
 function Base.show(io::IO, l::GCNConv)
@@ -549,7 +548,6 @@ function LuxCore.initialparameters(rng::AbstractRNG, l::SGConv)
 end
 
 LuxCore.parameterlength(l::SGConv) = l.use_bias ? l.in_dims * l.out_dims + l.out_dims : l.in_dims * l.out_dims
-LuxCore.statelength(d::SGConv) = 0
 LuxCore.outputsize(d::SGConv) = (d.out_dims,)
 
 function Base.show(io::IO, l::SGConv)
@@ -561,14 +559,72 @@ function Base.show(io::IO, l::SGConv)
     print(io, ")")
 end
 
-(l::SGConv)(g, x, ps, st; conv_weight=nothing, edge_weight=nothing) =
-    l(g, x, edge_weight, ps, st; conv_weight)
-
-function (l::SGConv)(g, x, edge_weight, ps, st;
-                     conv_weight=nothing, )
+(l::SGConv)(g, x, ps, st) = l(g, x, nothing, ps, st)
+
+function (l::SGConv)(g, x, edge_weight, ps, st)
     m = (; ps.weight, bias = _getbias(ps),
          l.add_self_loops, l.use_edge_weight, l.k)
     y = GNNlib.sg_conv(m, g, x, edge_weight)
     return y, st
 end

+@concrete struct GatedGraphConv <: GNNLayer
+    gru
+    init_weight
+    dims::Int
+    num_layers::Int
+    aggr
+end
+
+
+function GatedGraphConv(dims::Int, num_layers::Int;
+                        aggr = +, init_weight = glorot_uniform)
+    gru = GRUCell(dims => dims)
+    return GatedGraphConv(gru, init_weight, dims, num_layers, aggr)
+end
+
+LuxCore.outputsize(l::GatedGraphConv) = (l.dims,)
+
+function LuxCore.initialparameters(rng::AbstractRNG, l::GatedGraphConv)
+    gru = LuxCore.initialparameters(rng, l.gru)
+    weight = l.init_weight(rng, l.dims, l.dims, l.num_layers)
+    return (; gru, weight)
+end
+
+LuxCore.parameterlength(l::GatedGraphConv) = parameterlength(l.gru) + l.dims^2 * l.num_layers
+
+
+function (l::GatedGraphConv)(g, x, ps, st)
+    gru = StatefulLuxLayer{true}(l.gru, ps.gru, _getstate(st, :gru))
+    fgru = (h, x) -> gru((x, (h,)))  # make the forward compatible with the Flux.GRUCell style
+    m = (; gru = fgru, ps.weight, l.num_layers, l.aggr, l.dims)
+    return GNNlib.gated_graph_conv(m, g, x), st
+end
+
+function Base.show(io::IO, l::GatedGraphConv)
+    print(io, "GatedGraphConv($(l.dims), $(l.num_layers)")
+    print(io, ", aggr=", l.aggr)
+    print(io, ")")
+end
+
+@concrete struct GINConv <: GNNContainerLayer{(:nn,)}
+    nn <: AbstractExplicitLayer
+    ϵ <: Real
+    aggr
+end
+
+GINConv(nn, ϵ; aggr = +) = GINConv(nn, ϵ, aggr)
+
+function (l::GINConv)(g, x, ps, st)
+    nn = StatefulLuxLayer{true}(l.nn, ps, st)
+    m = (; nn, l.ϵ, l.aggr)
+    y = GNNlib.gin_conv(m, g, x)
+    stnew = _getstate(nn)
+    return y, stnew
+end
+
+function Base.show(io::IO, l::GINConv)
+    print(io, "GINConv($(l.nn)")
+    print(io, ", $(l.ϵ)")
+    print(io, ")")
+end
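For orientation, here is a minimal usage sketch of the two newly enabled Lux layers. This is not part of the diff; the toy graph, feature sizes, and `rng` setup are illustrative assumptions:

using GNNLux, GNNGraphs, Lux, LuxCore, Random
using NNlib: relu

rng = Random.default_rng()
g = rand_graph(10, 40)               # toy graph: 10 nodes, 40 edges
x = rand(Float32, 4, g.num_nodes)    # 4 input features per node

# GatedGraphConv zero-pads the input features up to `dims` internally
l1 = GatedGraphConv(8, 3)
ps, st = LuxCore.setup(rng, l1)
y1, _ = l1(g, x, ps, st)             # size(y1) == (8, 10)

# GINConv wraps an arbitrary Lux model; ps/st are those of the wrapped nn
nn = Chain(Dense(4 => 8, relu), Dense(8 => 8))
l2 = GINConv(nn, 0.5)
ps, st = LuxCore.setup(rng, l2)
y2, _ = l2(g, x, ps, st)             # size(y2) == (8, 10)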
12 changes: 11 additions & 1 deletion GNNLux/test/layers/conv_tests.jl
@@ -82,5 +82,15 @@
         l = SGConv(in_dims => out_dims, 2)
         test_lux_layer(rng, l, g, x, outputsize=(out_dims,))
     end
-end
+
+    @testset "GatedGraphConv" begin
+        l = GatedGraphConv(in_dims, 3)
+        test_lux_layer(rng, l, g, x, outputsize=(in_dims,))
+    end
+
+    @testset "GINConv" begin
+        nn = Chain(Dense(in_dims => out_dims, relu), Dense(out_dims => out_dims))
+        l = GINConv(nn, 0.5)
+        test_lux_layer(rng, l, g, x, sizey=(out_dims, g.num_nodes), container=true)
+    end
+end
1 change: 1 addition & 0 deletions GNNLux/test/shared_testsetup.jl
@@ -28,6 +28,7 @@ function test_lux_layer(rng::AbstractRNG, l, g::GNNGraph, x;
     @test LuxCore.statelength(l) == LuxCore.statelength(st)
 
     y, st′ = l(g, x, ps, st)
+    @test eltype(y) == eltype(x)
     if outputsize !== nothing
         @test LuxCore.outputsize(l) == outputsize
     end
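The added eltype test guards against silent Float32-to-Float64 promotion inside a layer's forward pass. A minimal sketch of the failure mode it catches (hypothetical values, not from this PR):

x = rand(Float32, 4, 10)
ϵ = 0.5                        # a Float64 scalar, as in GINConv(nn, 0.5)
y = (1 + ϵ) * x                # eltype(y) == Float64: the promotion the test would flag
y32 = (1 + Float32(ϵ)) * x     # eltype(y32) == Float32: what a layer should preserve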
35 changes: 17 additions & 18 deletions GNNlib/src/layers/conv.jl
@@ -28,7 +28,7 @@ function gcn_conv(l, g::AbstractGNNGraph, x, edge_weight::EW, norm_fn::F, conv_weight
         if edge_weight !== nothing
             # Pad weights with ones
             # TODO for ADJMAT_T the new edges are not generally at the end
-            edge_weight = [edge_weight; fill!(similar(edge_weight, g.num_nodes), 1)]
+            edge_weight = [edge_weight; ones_like(edge_weight, g.num_nodes)]
             @assert length(edge_weight) == g.num_edges
         end
     end
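For reference, the `gated_graph_conv` rewrite in the next hunk implements, per step ``l = 1, \dots, \texttt{num\_layers}``, the recurrence ``\mathbf{m}_i^{(l)} = \square_{j \in N(i)}\, \mathbf{W}^{(l)} \mathbf{h}_j^{(l-1)}`` followed by ``\mathbf{h}_i^{(l)} = \mathrm{GRU}(\mathbf{h}_i^{(l-1)}, \mathbf{m}_i^{(l)})``, where ``\square`` is the configured `aggr`, ``N(i)`` the neighbors of node ``i``, and ``\mathbf{h}_i^{(0)}`` the input features of node ``i`` zero-padded to `dims`.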
@@ -215,23 +215,22 @@ end
 
 ####################### GatedGraphConv ######################################
 
-# TODO PIRACY! remove after https://github.com/JuliaDiff/ChainRules.jl/pull/521
-@non_differentiable fill!(x...)
-
-function gated_graph_conv(l, g::GNNGraph, H::AbstractMatrix{S}) where {S <: Real}
-    check_num_nodes(g, H)
-    m, n = size(H)
-    @assert (m<=l.out_ch) "number of input features must less or equals to output features."
-    if m < l.out_ch
-        Hpad = similar(H, S, l.out_ch - m, n)
-        H = vcat(H, fill!(Hpad, 0))
+function gated_graph_conv(l, g::GNNGraph, x::AbstractMatrix)
+    check_num_nodes(g, x)
+    m, n = size(x)
+    @assert m <= l.dims "number of input features must be less than or equal to output features."
+    if m < l.dims
+        xpad = zeros_like(x, (l.dims - m, n))
+        x = vcat(x, xpad)
     end
+    h = x
     for i in 1:(l.num_layers)
-        M = view(l.weight, :, :, i) * H
-        M = propagate(copy_xj, g, l.aggr; xj = M)
-        H, _ = l.gru(H, M)
+        m = view(l.weight, :, :, i) * h
+        m = propagate(copy_xj, g, l.aggr; xj = m)
+        # in the gru forward, the hidden state is the first argument and the input the second
+        h, _ = l.gru(h, m)
     end
-    return H
+    return h
 end
 
 ####################### EdgeConv ######################################
@@ -419,7 +418,7 @@ function sgc_conv(l, g::GNNGraph, x::AbstractMatrix{T},
     if l.add_self_loops
         g = add_self_loops(g)
         if edge_weight !== nothing
-            edge_weight = [edge_weight; fill!(similar(edge_weight, g.num_nodes), 1)]
+            edge_weight = [edge_weight; ones_like(edge_weight, g.num_nodes)]
             @assert length(edge_weight) == g.num_edges
         end
     end
@@ -512,7 +511,7 @@ function sg_conv(l, g::GNNGraph, x::AbstractMatrix{T},
     if l.add_self_loops
         g = add_self_loops(g)
         if edge_weight !== nothing
-            edge_weight = [edge_weight; fill!(similar(edge_weight, g.num_nodes), 1)]
+            edge_weight = [edge_weight; ones_like(edge_weight, g.num_nodes)]
             @assert length(edge_weight) == g.num_edges
         end
     end
@@ -644,7 +643,7 @@ function tag_conv(l, g::GNNGraph, x::AbstractMatrix{T},
     if l.add_self_loops
         g = add_self_loops(g)
         if edge_weight !== nothing
-            edge_weight = [edge_weight; fill!(similar(edge_weight, g.num_nodes), 1)]
+            edge_weight = [edge_weight; ones_like(edge_weight, g.num_nodes)]
            @assert length(edge_weight) == g.num_edges
         end
     end
16 changes: 8 additions & 8 deletions src/layers/conv.jl
@@ -486,7 +486,7 @@ where ``\mathbf{h}^{(l)}_i`` denotes the ``l``-th hidden variables passing through
 
 # Arguments
 
 - `out`: The dimension of output features.
-- `num_layers`: The number of gated recurrent unit.
+- `num_layers`: The number of recursion steps.
 - `aggr`: Aggregation operator for the incoming messages (e.g. `+`, `*`, `max`, `min`, and `mean`).
 - `init`: Weight initialization function.

@@ -510,25 +510,25 @@ y = l(g, x)
 struct GatedGraphConv{W <: AbstractArray{<:Number, 3}, R, A} <: GNNLayer
     weight::W
     gru::R
-    out_ch::Int
+    dims::Int
     num_layers::Int
     aggr::A
 end
 
 @functor GatedGraphConv
 
-function GatedGraphConv(out_ch::Int, num_layers::Int;
+function GatedGraphConv(dims::Int, num_layers::Int;
                         aggr = +, init = glorot_uniform)
-    w = init(out_ch, out_ch, num_layers)
-    gru = GRUCell(out_ch, out_ch)
-    GatedGraphConv(w, gru, out_ch, num_layers, aggr)
+    w = init(dims, dims, num_layers)
+    gru = GRUCell(dims => dims)
+    GatedGraphConv(w, gru, dims, num_layers, aggr)
 end
 
 
 (l::GatedGraphConv)(g, H) = GNNlib.gated_graph_conv(l, g, H)
 
 function Base.show(io::IO, l::GatedGraphConv)
-    print(io, "GatedGraphConv(($(l.out_ch) => $(l.out_ch))^$(l.num_layers)")
+    print(io, "GatedGraphConv($(l.dims), $(l.num_layers)")
     print(io, ", aggr=", l.aggr)
     print(io, ")")
 end
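A quick sketch of the renamed Flux-side constructor (illustrative assumptions only; the graph and sizes are not from the diff):

using GraphNeuralNetworks

g = rand_graph(10, 40)
x = rand(Float32, 4, 10)     # 4 input features, zero-padded to `dims` internally

l = GatedGraphConv(16, 3)    # dims = 16, num_layers = 3; prints as GatedGraphConv(16, 3, aggr=+)
y = l(g, x)                  # size(y) == (16, 10)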
@@ -1201,7 +1201,7 @@ function SGConv(ch::Pair{Int, Int}, k = 1;
     in, out = ch
     W = init(out, in)
     b = bias ? Flux.create_bias(W, true, out) : false
-    SGConv(W, b, k, add_self_loops, use_edge_weight)
+    return SGConv(W, b, k, add_self_loops, use_edge_weight)
 end
 
 (l::SGConv)(g, x, edge_weight = nothing) = GNNlib.sg_conv(l, g, x, edge_weight)