Skip to content

Commit

Permalink
First draft
Browse files Browse the repository at this point in the history
  • Loading branch information
aurorarossi committed Jul 29, 2024
1 parent fc67808 commit fbbe84d
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 5 deletions.
4 changes: 3 additions & 1 deletion GNNLux/src/GNNLux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ export AGNNConv,
# SGConv,
# TAGConv,
# TransformerConv


include("layers/temporalconv.jl")
export TGCNCell

end #module

26 changes: 26 additions & 0 deletions GNNLux/src/layers/temporalconv.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
@concrete struct TGCNCell <: GNNContainerLayer{(:conv, :gru)}
in_dims::Int
out_dims::Int
conv
gru
end

function TGCNCell(ch::Pair{Int, Int}; use_bias = true, init_weight = glorot_uniform, init_state = zeros32, init_bias = zeros32, add_self_loops = false, use_edge_weight = true)
in_dims, out_dims = ch
conv = GCNConv(ch, sigmoid; init_weight, init_bias, use_bias, add_self_loops, use_edge_weight, allow_fast_activation= true)
gru = Lux.GRUCell(out_dims => out_dims; use_bias, init_weight = (init_weight, init_weight, init_weight), init_bias = (init_bias, init_bias, init_bias), init_state = init_state)
return TGCNCell(in_dims, out_dims, conv, gru)
end

LuxCore.outputsize(l::TGCNCell) = (l.out_dims,)

function (l::TGCNCell)(h, g, x)
conv = StatefulLuxLayer{true}(l.conv, ps.conv, _getstate(st, :conv))
gru = StatefulLuxLayer{true}(l.gru, ps.gru, _getstate(st, :gru))
m = (; conv, gru)
return GNNlib.tgcn_conv(m, h, g, x)
end

function Base.show(io::IO, tgcn::TGCNCell)
print(io, "TGCNCell($(tgcn.in_dims) => $(tgcn.out_dims))")
end
2 changes: 1 addition & 1 deletion GNNlib/src/GNNlib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ export agnn_conv,
transformer_conv

include("layers/temporalconv.jl")
export a3tgcn_conv
export tgcn_conv

include("layers/pool.jl")
export global_pool,
Expand Down
9 changes: 9 additions & 0 deletions GNNlib/src/layers/temporalconv.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,12 @@
####################### TGCN ######################################

function tgcn_conv(l, h, g::GNNGraph, x::AbstractArray)
= l.conv(g, x)
h, x̃ = l.gru(h, x̃)
return h, x̃
end


function a3tgcn_conv(a3tgcn, g::GNNGraph, x::AbstractArray)
h = a3tgcn.tgcn(g, x)
e = a3tgcn.dense1(h)
Expand Down
4 changes: 1 addition & 3 deletions src/layers/temporalconv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,7 @@ function TGCNCell(ch::Pair{Int, Int};
end

function (tgcn::TGCNCell)(h, g::GNNGraph, x::AbstractArray)
= tgcn.conv(g, x)
h, x̃ = tgcn.gru(h, x̃)
return h, x̃
return GNNlib.tgcn_conv(tgcn, h, g, x)
end

function Base.show(io::IO, tgcn::TGCNCell)
Expand Down

0 comments on commit fbbe84d

Please sign in to comment.