`@functor` -> `@layer` (#484)
* `@functor` -> `@layer`

* reinstate tests
CarloLucibello authored Aug 8, 2024
1 parent 6b58b75 commit 87f3c60
Showing 16 changed files with 49 additions and 53 deletions.
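
For anyone applying the same migration to their own layers, the pattern used throughout this commit is a one-line swap: `Flux.@functor MyType` becomes `Flux.@layer MyType` (Flux >= 0.14), which keeps parameter collection and GPU movement working and adds nicer printing plus per-layer control of trainable fields. A minimal sketch with a made-up `MyConv` layer (illustrative only, not part of this repository):

```julia
using Flux

# Toy layer used only to illustrate the macro swap; not a real
# GraphNeuralNetworks.jl layer.
struct MyConv{W, B, F}
    weight::W
    bias::B
    σ::F
end

# Previously: Flux.@functor MyConv
Flux.@layer MyConv   # registers the fields for params, gpu/cpu movement, printing

function MyConv((nin, nout)::Pair, σ = identity)
    return MyConv(Flux.glorot_uniform(nout, nin), zeros(Float32, nout), σ)
end

(l::MyConv)(x) = l.σ.(l.weight * x .+ l.bias)

l = MyConv(3 => 4, relu)
y = l(rand(Float32, 3, 8))    # forward pass, 4×8 output
ps = Flux.trainable(l)        # NamedTuple of child fields, as @functor used to provide
```
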
2 changes: 1 addition & 1 deletion GNNlib/src/msgpass.jl
@@ -45,7 +45,7 @@ struct GNNConv <: GNNLayer
σ
end
-Flux.@functor GNNConv
+Flux.@layer GNNConv
function GNNConv(ch::Pair{Int,Int}, σ=identity)
in, out = ch
2 changes: 1 addition & 1 deletion docs/src/messagepassing.md
@@ -109,7 +109,7 @@ struct GCN{A<:AbstractMatrix, B, F} <: GNNLayer
σ::F
end

-Flux.@functor GCN # allow gpu movement, select trainable params etc...
+Flux.@layer GCN # allow gpu movement, select trainable params etc...

function GCN(ch::Pair{Int,Int}, σ=identity)
in, out = ch
4 changes: 2 additions & 2 deletions docs/src/models.md
@@ -13,7 +13,7 @@ and the *implicit modeling* style based on [`GNNChain`](@ref), more concise but
In the explicit modeling style, the model is created according to the following steps:

1. Define a new type for your model (`GNN` in the example below). Layers and submodels are fields.
-2. Apply `Flux.@functor` to the new type to make it Flux's compatible (parameters' collection, gpu movement, etc...)
+2. Apply `Flux.@layer` to the new type to make it Flux's compatible (parameters' collection, gpu movement, etc...)
3. Optionally define a convenience constructor for your model.
4. Define the forward pass by implementing the call method for your type.
5. Instantiate the model.
@@ -30,7 +30,7 @@ struct GNN # step 1
dense
end

-Flux.@functor GNN # step 2
+Flux.@layer GNN # step 2

function GNN(din::Int, d::Int, dout::Int) # step 3
GNN(GCNConv(din => d),
2 changes: 1 addition & 1 deletion docs/tutorials/introductory_tutorials/gnn_intro_pluto.jl
@@ -182,7 +182,7 @@ begin
layers::NamedTuple
end

-Flux.@functor GCN # provides parameter collection, gpu movement and more
+Flux.@layer GCN # provides parameter collection, gpu movement and more

function GCN(num_features, num_classes)
layers = (conv1 = GCNConv(num_features => 4),
@@ -138,7 +138,7 @@ begin
layers::NamedTuple
end

-Flux.@functor MLP
+Flux.@layer :expand MLP

function MLP(num_features, num_classes, hidden_channels; drop_rate = 0.5)
layers = (hidden = Dense(num_features => hidden_channels),
@@ -235,7 +235,7 @@ begin
layers::NamedTuple
end

-Flux.@functor GCN # provides parameter collection, gpu movement and more
+Flux.@layer GCN # provides parameter collection, gpu movement and more

function GCN(num_features, num_classes, hidden_channels; drop_rate = 0.5)
layers = (conv1 = GCNConv(num_features => hidden_channels),
@@ -117,7 +117,7 @@ begin
dense::Dense
end

-Flux.@functor GenderPredictionModel
+Flux.@layer GenderPredictionModel

function GenderPredictionModel(; nfeatures = 103, nhidden = 128, activation = relu)
mlp = Chain(Dense(nfeatures, nhidden, activation), Dense(nhidden, nhidden, activation))
2 changes: 1 addition & 1 deletion examples/graph_classification_temporalbrains.jl
@@ -62,7 +62,7 @@ struct GenderPredictionModel
dense::Dense
end

-Flux.@functor GenderPredictionModel
+Flux.@layer GenderPredictionModel

function GenderPredictionModel(; nfeatures = 103, nhidden = 128, activation = relu)
mlp = Chain(Dense(nfeatures, nhidden, activation), Dense(nhidden, nhidden, activation))
2 changes: 1 addition & 1 deletion notebooks/gnn_intro.ipynb
@@ -354,7 +354,7 @@
" layers::NamedTuple\n",
"end\n",
"\n",
"Flux.@functor GCN # provides parameter collection, gpu movement and more\n",
"Flux.@layer :expand GCN # provides parameter collection, gpu movement and more\n",
"\n",
"function GCN(num_features, num_classes)\n",
" layers = (conv1 = GCNConv(num_features => 4),\n",
2 changes: 1 addition & 1 deletion notebooks/graph_classification_solved.ipynb
@@ -857,7 +857,7 @@
"\tact::F\n",
"end\n",
"\n",
"Flux.@functor MyConv\n",
"Flux.@layer MyConv\n",
"\n",
"function MyConv((nin, nout)::Pair, act=identity)\n",
"\tW1 = Flux.glorot_uniform(nout, nin)\n",
2 changes: 1 addition & 1 deletion src/GraphNeuralNetworks.jl
@@ -3,7 +3,7 @@ module GraphNeuralNetworks
using Statistics: mean
using LinearAlgebra, Random
using Flux
-using Flux: glorot_uniform, leakyrelu, GRUCell, @functor, batch
+using Flux: glorot_uniform, leakyrelu, GRUCell, batch
using MacroTools: @forward
using NNlib
using NNlib: scatter, gather
4 changes: 2 additions & 2 deletions src/layers/basic.jl
@@ -45,7 +45,7 @@ end

WithGraph(model, g::GNNGraph; traingraph = false) = WithGraph(model, g, traingraph)

-@functor WithGraph
+Flux.@layer :expand WithGraph
Flux.trainable(l::WithGraph) = l.traingraph ? (; l.model, l.g) : (; l.model)

(l::WithGraph)(g::GNNGraph, x...; kws...) = l.model(g, x...; kws...)
@@ -107,7 +107,7 @@ struct GNNChain{T <: Union{Tuple, NamedTuple, AbstractVector}} <: GNNLayer
layers::T
end

-@functor GNNChain
+Flux.@layer :expand GNNChain

GNNChain(xs...) = GNNChain(xs)

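Two recurring idioms in the `src/layers` changes above and below are worth a note: `Flux.@layer :expand T` makes container-like layers (here `WithGraph`, `GNNChain`, and the conv layers wrapping an inner `nn`) print their contents expanded the way `Chain` does, and a hand-written `Flux.trainable` method can still restrict which fields the optimiser sees, exactly as before. A self-contained sketch with a hypothetical `Wrapper` type (not from this package):

```julia
using Flux

# Hypothetical container: an inner model plus a constant that should not train.
struct Wrapper{M}
    model::M
    scale::Float32
end

# :expand => `show` prints the wrapped contents like a Chain would.
Flux.@layer :expand Wrapper

# Only the inner model is exposed to the optimiser; `scale` stays fixed.
Flux.trainable(w::Wrapper) = (; w.model)

(w::Wrapper)(x) = w.scale .* w.model(x)

w = Wrapper(Dense(2 => 3, tanh), 0.5f0)
Flux.trainable(w)    # (model = Dense(2 => 3, tanh),)
```
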
51 changes: 25 additions & 26 deletions src/layers/conv.jl
@@ -76,7 +76,7 @@ struct GCNConv{W <: AbstractMatrix, B, F} <: GNNLayer
use_edge_weight::Bool
end

-@functor GCNConv
+Flux.@layer GCNConv

function GCNConv(ch::Pair{Int, Int}, σ = identity;
init = glorot_uniform,
@@ -167,7 +167,7 @@ function ChebConv(ch::Pair{Int, Int}, k::Int;
ChebConv(W, b, k)
end

-@functor ChebConv
+Flux.@layer ChebConv

(l::ChebConv)(g, x) = GNNlib.cheb_conv(l, g, x)

@@ -225,7 +225,7 @@ struct GraphConv{W <: AbstractMatrix, B, F, A} <: GNNLayer
aggr::A
end

-@functor GraphConv
+Flux.@layer GraphConv

function GraphConv(ch::Pair{Int, Int}, σ = identity; aggr = +,
init = glorot_uniform, bias::Bool = true)
@@ -300,8 +300,7 @@ l = GATConv(in_channel => out_channel, add_self_loops = false, bias = false; hea
y = l(g, x)
```
"""
-struct GATConv{DX <: Dense, DE <: Union{Dense, Nothing}, DV, T, A <: AbstractMatrix, F, B} <:
-       GNNLayer
+struct GATConv{DX<:Dense,DE<:Union{Dense, Nothing},DV,T,A<:AbstractMatrix,F,B} <: GNNLayer
dense_x::DX
dense_e::DE
bias::B
@@ -315,8 +314,8 @@ struct GATConv{DX <: Dense, DE <: Union{Dense, Nothing}, DV, T, A <: AbstractMat
dropout::DV
end

-@functor GATConv
-Flux.trainable(l::GATConv) = (dense_x = l.dense_x, dense_e = l.dense_e, bias = l.bias, a = l.a)
+Flux.@layer GATConv
+Flux.trainable(l::GATConv) = (; l.dense_x, l.dense_e, l.bias, l.a)

GATConv(ch::Pair{Int, Int}, args...; kws...) = GATConv((ch[1], 0) => ch[2], args...; kws...)

@@ -420,7 +419,7 @@ struct GATv2Conv{T, A1, A2, A3, DV, B, C <: AbstractMatrix, F} <: GNNLayer
dropout::DV
end

-@functor GATv2Conv
+Flux.@layer GATv2Conv
Flux.trainable(l::GATv2Conv) = (dense_i = l.dense_i, dense_j = l.dense_j, dense_e = l.dense_e, bias = l.bias, a = l.a)

function GATv2Conv(ch::Pair{Int, Int}, args...; kws...)
@@ -515,7 +514,7 @@ struct GatedGraphConv{W <: AbstractArray{<:Number, 3}, R, A} <: GNNLayer
aggr::A
end

-@functor GatedGraphConv
+Flux.@layer GatedGraphConv

function GatedGraphConv(dims::Int, num_layers::Int;
aggr = +, init = glorot_uniform)
@@ -572,7 +571,7 @@ struct EdgeConv{NN, A} <: GNNLayer
aggr::A
end

-@functor EdgeConv
+Flux.@layer :expand EdgeConv

EdgeConv(nn; aggr = max) = EdgeConv(nn, aggr)

@@ -626,7 +625,7 @@ struct GINConv{R <: Real, NN, A} <: GNNLayer
aggr::A
end

-@functor GINConv
+Flux.@layer :expand GINConv
Flux.trainable(l::GINConv) = (nn = l.nn,)

GINConv(nn, ϵ; aggr = +) = GINConv(nn, ϵ, aggr)
@@ -680,7 +679,7 @@ edim = 10
g = GNNGraph(s, t)
# create dense layer
-nn = Dense(edim, out_channel * in_channel)
+nn = Dense(edim => out_channel * in_channel)
# create layer
l = NNConv(in_channel => out_channel, nn, tanh, bias = true, aggr = +)
@@ -697,7 +696,7 @@ struct NNConv{W, B, NN, F, A} <: GNNLayer
aggr::A
end

-@functor NNConv
+Flux.@layer :expand NNConv

function NNConv(ch::Pair{Int, Int}, nn, σ = identity; aggr = +, bias = true,
init = glorot_uniform)
@@ -763,7 +762,7 @@ struct SAGEConv{W <: AbstractMatrix, B, F, A} <: GNNLayer
aggr::A
end

-@functor SAGEConv
+Flux.@layer SAGEConv

function SAGEConv(ch::Pair{Int, Int}, σ = identity; aggr = mean,
init = glorot_uniform, bias::Bool = true)
@@ -833,7 +832,7 @@ struct ResGatedGraphConv{W, B, F} <: GNNLayer
σ::F
end

-@functor ResGatedGraphConv
+Flux.@layer ResGatedGraphConv

function ResGatedGraphConv(ch::Pair{Int, Int}, σ = identity;
init = glorot_uniform, bias::Bool = true)
@@ -907,7 +906,7 @@ struct CGConv{D1, D2} <: GNNLayer
residual::Bool
end

-@functor CGConv
+Flux.@layer CGConv

CGConv(ch::Pair{Int, Int}, args...; kws...) = CGConv((ch[1], 0) => ch[2], args...; kws...)

@@ -980,7 +979,7 @@ struct AGNNConv{A <: AbstractVector} <: GNNLayer
trainable::Bool
end

-@functor AGNNConv
+Flux.@layer AGNNConv

Flux.trainable(l::AGNNConv) = l.trainable ? (; l.β) : (;)

@@ -1027,7 +1026,7 @@ struct MEGNetConv{TE, TV, A} <: GNNLayer
aggr::A
end

-@functor MEGNetConv
+Flux.@layer :expand MEGNetConv

MEGNetConv(ϕe, ϕv; aggr = mean) = MEGNetConv(ϕe, ϕv, aggr)

@@ -1108,7 +1107,7 @@ struct GMMConv{A <: AbstractMatrix, B, F} <: GNNLayer
residual::Bool
end

-@functor GMMConv
+Flux.@layer GMMConv

function GMMConv(ch::Pair{NTuple{2, Int}, Int},
σ = identity;
@@ -1191,7 +1190,7 @@ struct SGConv{A <: AbstractMatrix, B} <: GNNLayer
use_edge_weight::Bool
end

-@functor SGConv
+Flux.@layer SGConv

function SGConv(ch::Pair{Int, Int}, k = 1;
init = glorot_uniform,
@@ -1259,7 +1258,7 @@ struct TAGConv{A <: AbstractMatrix, B} <: GNNLayer
use_edge_weight::Bool
end

-@functor TAGConv
+Flux.@layer TAGConv

function TAGConv(ch::Pair{Int, Int}, k = 3;
init = glorot_uniform,
@@ -1269,7 +1268,7 @@ function TAGConv(ch::Pair{Int, Int}, k = 3;
in, out = ch
W = init(out, in)
b = bias ? Flux.create_bias(W, true, out) : false
-TAGConv(W, b, k, add_self_loops, use_edge_weight)
+return TAGConv(W, b, k, add_self_loops, use_edge_weight)
end

(l::TAGConv)(g, x, edge_weight = nothing) = GNNlib.tag_conv(l, g, x, edge_weight)
@@ -1343,10 +1342,10 @@ struct EGNNConv{TE, TX, TH, NF} <: GNNLayer
residual::Bool
end

-@functor EGNNConv
+Flux.@layer EGNNConv

function EGNNConv(ch::Pair{Int, Int}, hidden_size = 2 * ch[1]; residual = false)
-EGNNConv((ch[1], 0) => ch[2]; hidden_size, residual)
+return EGNNConv((ch[1], 0) => ch[2]; hidden_size, residual)
end

#Follows reference implementation at https://github.com/vgsatorras/egnn/blob/main/models/egnn_clean/egnn_clean.py
Expand Down Expand Up @@ -1477,7 +1476,7 @@ struct TransformerConv{TW1, TW2, TW3, TW4, TW5, TW6, TFF, TBN1, TBN2} <: GNNLaye
sqrt_out::Float32
end

-@functor TransformerConv
+Flux.@layer TransformerConv

function Flux.trainable(l::TransformerConv)
(; l.W1, l.W2, l.W3, l.W4, l.W5, l.W6, l.FF, l.BN1, l.BN2)
@@ -1568,7 +1567,7 @@ struct DConv <: GNNLayer
k::Int
end

-@functor DConv
+Flux.@layer DConv

function DConv(ch::Pair{Int, Int}, k::Int; init = glorot_uniform, bias = true)
in, out = ch
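A small Julia idiom worth calling out from the updated `trainable` methods above (e.g. for `WithGraph` and `GATConv`): `(; l.a, l.b)` is NamedTuple field shorthand, equivalent to `(a = l.a, b = l.b)`. A quick illustrative snippet (values made up):

```julia
nt = (dense_x = "Dx", dense_e = "De", bias = false, extra = 42)

# Field shorthand: the names are taken from the accessed properties.
picked = (; nt.dense_x, nt.dense_e, nt.bias)
@assert picked == (dense_x = "Dx", dense_e = "De", bias = false)
```
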
2 changes: 1 addition & 1 deletion src/layers/heteroconv.jl
@@ -43,7 +43,7 @@ struct HeteroGraphConv
aggr::Function
end

-Flux.@functor HeteroGraphConv
+Flux.@layer HeteroGraphConv

HeteroGraphConv(itr::Dict; aggr = +) = HeteroGraphConv(pairs(itr); aggr)
HeteroGraphConv(itr::Pair...; aggr = +) = HeteroGraphConv(itr; aggr)
4 changes: 2 additions & 2 deletions src/layers/pool.jl
@@ -90,7 +90,7 @@ struct GlobalAttentionPool{G, F}
ffeat::F
end

-@functor GlobalAttentionPool
+Flux.@layer GlobalAttentionPool

GlobalAttentionPool(fgate) = GlobalAttentionPool(fgate, identity)

@@ -146,7 +146,7 @@ struct Set2Set{L} <: GNNLayer
num_iters::Int
end

-@functor Set2Set
+Flux.@layer Set2Set

function Set2Set(n_in::Int, n_iters::Int, n_layers::Int = 1)
@assert n_layers >= 1