Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Jul 30, 2024
1 parent fddb701 commit 67e5536
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 16 deletions.
2 changes: 1 addition & 1 deletion GNNLux/src/GNNLux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ export AGNNConv,
EGNNConv,
DConv,
GATConv,
# GATv2Conv,
GATv2Conv,
# GatedGraphConv,
GCNConv,
# GINConv,
Expand Down
121 changes: 110 additions & 11 deletions GNNLux/src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -342,16 +342,17 @@ function LuxCore.initialparameters(rng::AbstractRNG, l::DConv)
end
end

LuxCore.parameterlength(l::DConv) = l.use_bias ? l.in_dims * l.out_dims * l.k + l.out_dims :
l.in_dims * l.out_dims * l.k
LuxCore.outputsize(l::DConv) = (l.out_dims,)
LuxCore.parameterlength(l::DConv) = l.use_bias ? 2 * l.in_dims * l.out_dims * l.k + l.out_dims :
2 * l.in_dims * l.out_dims * l.k

function (l::DConv)(g, x, ps, st)
m = (; ps.weights, bias = _getbias(ps), l.k)
return GNNlib.d_conv(m, g, x), st
end

function Base.show(io::IO, l::DConv)
print(io, "DConv($(l.in) => $(l.out), k=$(l.k))")
print(io, "DConv($(l.in_dims) => $(l.out_dims), k=$(l.k))")
end

@concrete struct GATConv <: GNNLayer
Expand Down Expand Up @@ -389,25 +390,33 @@ function GATConv(ch::Pair{NTuple{2, Int}, Int}, σ = identity;
σ, negative_slope, ch, heads, concat, add_self_loops, dropout)
end

# Flux.trainable(l::GATConv) = (dense_x = l.dense_x, dense_e = l.dense_e, bias = l.bias, a = l.a)
LuxCore.outputsize(l::GATConv) = (l.concat ? l.channel[2]*l.heads : l.channel[2],)
##TODO: parameterlength

function LuxCore.initialparameters(rng::AbstractRNG, l::GATConv)
(in, ein), out = l.channel
dense_x = initialparameters(rng, l.dense_x)
a = init_weight(ein > 0 ? 3out : 2out, heads)
dense_x = LuxCore.initialparameters(rng, l.dense_x)
a = l.init_weight(ein > 0 ? 3out : 2out, l.heads)
ps = (; dense_x, a)
if ein > 0
ps = (ps..., dense_e = initialparameters(rng, l.dense_e))
ps = (ps..., dense_e = LuxCore.initialparameters(rng, l.dense_e))
end
if use_bias
ps = (ps..., bias = l.init_bias(rng, concat ? out * l.heads : out))
if l.use_bias
ps = (ps..., bias = l.init_bias(rng, l.concat ? out * l.heads : out))
end
return ps
end

(l::GATConv)(g, x, ps, st) = l(g, x, nothing, ps, st)

function (l::GATConv)(g, x, e, ps, st)
return GNNlib.gat_conv(l, g, x, e), st
function (l::GATConv)(g, x, e, ps, st)
dense_x = StatefulLuxLayer{true}(l.dense_x, ps.dense_x, _getstate(st, :dense_x))
dense_e = l.dense_e === nothing ? nothing :
StatefulLuxLayer{true}(l.dense_e, ps.dense_e, _getstate(st, :dense_e))

m = (; l.add_self_loops, l.channel, l.heads, l.concat, l.dropout, l.σ,
ps.a, bias = _getbias(ps), dense_x, dense_e, l.negative_slope)
return GNNlib.gat_conv(m, g, x, e), st
end

function Base.show(io::IO, l::GATConv)
Expand All @@ -417,3 +426,93 @@ function Base.show(io::IO, l::GATConv)
print(io, ", negative_slope=", l.negative_slope)
print(io, ")")
end

"""
    GATv2Conv <: GNNLayer

Graph attentional layer, v2 variant. Holds the inner dense transforms and the
hyperparameters; trainable parameters live in the `ps` NamedTuple produced by
`LuxCore.initialparameters` (Lux-style explicit parameters).
"""
@concrete struct GATv2Conv <: GNNLayer
    dense_i            # Dense(in => out*heads) with bias — presumably applied to the "i" (destination) node features; TODO confirm against GNNlib.gatv2_conv
    dense_j            # Dense(in => out*heads) without bias — presumably the "j" (source) side; TODO confirm
    dense_e            # Dense(ein => out*heads) for edge features, or `nothing` when ein == 0
    init_weight        # initializer for the attention vector `a` (e.g. glorot_uniform)
    init_bias          # initializer for the output bias (e.g. zeros32)
    use_bias::Bool     # whether an output bias parameter is created
    σ                  # output activation
    negative_slope     # leakyrelu slope used in the attention logits
    channel::Pair{NTuple{2, Int}, Int}   # (node_in, edge_in) => out
    heads::Int         # number of attention heads
    concat::Bool       # concatenate head outputs (true) or average them (false)
    add_self_loops::Bool   # add self-loops to the graph before message passing
    dropout            # dropout probability applied to the attention coefficients
end

# Convenience constructor for the no-edge-features case: lifts `in => out`
# to the full `(in, ein) => out` form with edge dimension fixed at 0.
GATv2Conv(ch::Pair{Int, Int}, args...; kws...) =
    GATv2Conv((ch[1], 0) => ch[2], args...; kws...)

"""
    GATv2Conv((in, ein) => out, σ = identity; heads, concat, negative_slope,
              init_weight, init_bias, use_bias, add_self_loops, dropout)

Construct a GATv2 attention layer.

# Arguments
- `ch`: `(node_in_dims, edge_in_dims) => out_dims` channel sizes.
- `σ`: output activation (default `identity`).

# Keywords
- `heads`: number of attention heads (default 1).
- `concat`: concatenate head outputs if `true`, average if `false`.
- `negative_slope`: leakyrelu slope for the attention logits (default 0.2).
- `init_weight`, `init_bias`: parameter initializers.
- `use_bias`: create an output bias parameter (default `true`).
- `add_self_loops`: add self-loops before message passing (default `true`).
- `dropout`: dropout probability on attention coefficients (default 0.0).

# Throws
- `ArgumentError` if edge features (`ein > 0`) are combined with
  `add_self_loops = true` (self-loop edges would have no features).
"""
function GATv2Conv(ch::Pair{NTuple{2, Int}, Int},
                   σ = identity;
                   heads::Int = 1,
                   concat::Bool = true,
                   negative_slope = 0.2,
                   init_weight = glorot_uniform,
                   init_bias = zeros32,
                   use_bias::Bool = true,
                   add_self_loops = true,
                   dropout=0.0)

    (in, ein), out = ch

    # Validate at the API boundary with a real exception: `@assert` is for
    # internal invariants and may be disabled at higher optimization levels.
    if add_self_loops
        ein == 0 || throw(ArgumentError("Using edge features and setting add_self_loops=true at the same time is not yet supported."))
    end

    # dense_i carries the (optional) bias; dense_j and dense_e are bias-free,
    # matching the GATv2 formulation.
    dense_i = Dense(in => out * heads; use_bias, init_weight, init_bias)
    dense_j = Dense(in => out * heads; use_bias = false, init_weight)
    if ein > 0
        dense_e = Dense(ein => out * heads; use_bias = false, init_weight)
    else
        dense_e = nothing
    end
    return GATv2Conv(dense_i, dense_j, dense_e,
                     init_weight, init_bias, use_bias,
                     σ, negative_slope,
                     ch, heads, concat, add_self_loops, dropout)
end


# Output feature dimension: concatenated heads multiply the per-head size,
# averaged heads leave it unchanged.
function LuxCore.outputsize(l::GATv2Conv)
    out = l.channel[2]
    return (l.concat ? out * l.heads : out,)
end
##TODO: parameterlength

# Build the layer's parameter NamedTuple: the two (or three) inner Dense
# parameter sets, the attention vector `a`, and optionally an output bias.
function LuxCore.initialparameters(rng::AbstractRNG, l::GATv2Conv)
    (_, ein), out = l.channel
    # NOTE(review): `a` is initialized without passing `rng`, so it is not
    # reproducible from the supplied RNG — confirm this is intended.
    ps = (; dense_i = LuxCore.initialparameters(rng, l.dense_i),
            dense_j = LuxCore.initialparameters(rng, l.dense_j),
            a = l.init_weight(out, l.heads))
    if ein > 0
        ps = merge(ps, (; dense_e = LuxCore.initialparameters(rng, l.dense_e)))
    end
    if l.use_bias
        ps = merge(ps, (; bias = l.init_bias(rng, l.concat ? out * l.heads : out)))
    end
    return ps
end

# No-edge-features entry point: forward with `e = nothing`.
function (l::GATv2Conv)(g, x, ps, st)
    return l(g, x, nothing, ps, st)
end

# Edge-aware forward pass: wrap the inner Dense layers as stateful closures so
# the functional GNNlib kernel can call them like plain functions, then
# delegate the actual message passing to `GNNlib.gatv2_conv`.
function (l::GATv2Conv)(g, x, e, ps, st)
    wrap(layer, name) = StatefulLuxLayer{true}(layer, getproperty(ps, name),
                                               _getstate(st, name))
    dense_i = wrap(l.dense_i, :dense_i)
    dense_j = wrap(l.dense_j, :dense_j)
    # Only wrap dense_e when the layer was built with edge features.
    dense_e = l.dense_e === nothing ? nothing : wrap(l.dense_e, :dense_e)

    m = (; l.add_self_loops, l.channel, l.heads, l.concat, l.dropout, l.σ,
         ps.a, bias = _getbias(ps), dense_i, dense_j, dense_e, l.negative_slope)
    return GNNlib.gatv2_conv(m, g, x, e), st
end

# Compact one-line display, e.g. `GATv2Conv(6 => 8, negative_slope=0.2)`.
function Base.show(io::IO, l::GATv2Conv)
    (in, ein), out = l.channel
    # NOTE(review): `out ÷ l.heads` assumes `channel[2]` stores the total
    # (heads-concatenated) output size, but the constructor stores `ch` as
    # given — verify this division against how `channel` is populated.
    src = ein == 0 ? in : (in, ein)
    print(io, "GATv2Conv(", src, " => ", out ÷ l.heads)
    if l.σ != identity
        print(io, ", ", l.σ)
    end
    print(io, ", negative_slope=", l.negative_slope)
    print(io, ")")
end
18 changes: 18 additions & 0 deletions GNNLux/test/layers/conv_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,26 @@

@testset "GATConv" begin
    # Node features: 6 input channels for each of the 10 nodes of `g`.
    x = randn(rng, Float32, 6, 10)

    # Two heads, concatenated (default): output size is 8 * 2 = 16.
    l = GATConv(6 => 8, heads=2)
    test_lux_layer(rng, l, g, x, outputsize=(16,))

    # Averaged heads (`concat=false`) keep the per-head size of 8;
    # also exercises the attention-dropout path.
    l = GATConv(6 => 8, heads=2, concat=false, dropout=0.5)
    test_lux_layer(rng, l, g, x, outputsize=(8,))

    #TODO test edge
end

@testset "GATv2Conv" begin
    # Node features: 6 input channels for each of the 10 nodes of `g`.
    x = randn(rng, Float32, 6, 10)

    # Two heads, concatenated (default): output size is 8 * 2 = 16.
    l = GATv2Conv(6 => 8, heads=2)
    test_lux_layer(rng, l, g, x, outputsize=(16,))

    # Averaged heads (`concat=false`) keep the per-head size of 8;
    # also exercises the attention-dropout path.
    l = GATv2Conv(6 => 8, heads=2, concat=false, dropout=0.5)
    test_lux_layer(rng, l, g, x, outputsize=(8,))

    #TODO test edge
end
end

6 changes: 4 additions & 2 deletions GNNlib/src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,8 @@ function gat_message(l, Wxi, Wxj, e)
Wxx = vcat(Wxi, Wxj, We)
end
aWW = sum(l.a .* Wxx, dims = 1) # 1 × nheads × nedges
logα = leakyrelu.(aWW, l.negative_slope)
slope = convert(eltype(aWW), l.negative_slope)
logα = leakyrelu.(aWW, slope)
return (; logα, Wxj)
end

Expand Down Expand Up @@ -207,7 +208,8 @@ function gatv2_message(l, Wxi, Wxj, e)
if e !== nothing
Wx += reshape(l.dense_e(e), out, heads, :)
end
logα = sum(l.a .* leakyrelu.(Wx, l.negative_slope), dims = 1) # 1 × heads × nedges
slope = convert(eltype(Wx), l.negative_slope)
logα = sum(l.a .* leakyrelu.(Wx, slope), dims = 1) # 1 × heads × nedges
return (; logα, Wxj)
end

Expand Down
4 changes: 2 additions & 2 deletions src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -451,8 +451,8 @@ function GATv2Conv(ch::Pair{NTuple{2, Int}, Int},
end
b = bias ? Flux.create_bias(dense_i.weight, true, concat ? out * heads : out) : false
a = init(out, heads)
negative_slope = convert(eltype(dense_i.weight), negative_slope)
GATv2Conv(dense_i, dense_j, dense_e, b, a, σ, negative_slope, ch, heads, concat,
return GATv2Conv(dense_i, dense_j, dense_e,
b, a, σ, negative_slope, ch, heads, concat,
add_self_loops, dropout)
end

Expand Down

0 comments on commit 67e5536

Please sign in to comment.