From 87f3c60e9d0fdcfe77c68f9255ad203f341faeec Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Thu, 8 Aug 2024 10:41:16 +0200 Subject: [PATCH] `@functor` ->`@layer` (#484) * `@functor` ->`@layer` * reinstate tests --- GNNlib/src/msgpass.jl | 2 +- docs/src/messagepassing.md | 2 +- docs/src/models.md | 4 +- .../introductory_tutorials/gnn_intro_pluto.jl | 2 +- .../node_classification_pluto.jl | 4 +- .../temporal_graph_classification_pluto.jl | 2 +- .../graph_classification_temporalbrains.jl | 2 +- notebooks/gnn_intro.ipynb | 2 +- notebooks/graph_classification_solved.ipynb | 2 +- src/GraphNeuralNetworks.jl | 2 +- src/layers/basic.jl | 4 +- src/layers/conv.jl | 51 +++++++++---------- src/layers/heteroconv.jl | 2 +- src/layers/pool.jl | 4 +- src/layers/temporalconv.jl | 10 ++-- test/runtests.jl | 7 +-- 16 files changed, 49 insertions(+), 53 deletions(-) diff --git a/GNNlib/src/msgpass.jl b/GNNlib/src/msgpass.jl index acab02217..1aa17437a 100644 --- a/GNNlib/src/msgpass.jl +++ b/GNNlib/src/msgpass.jl @@ -45,7 +45,7 @@ struct GNNConv <: GNNLayer σ end -Flux.@functor GNNConv +Flux.@layer GNNConv function GNNConv(ch::Pair{Int,Int}, σ=identity) in, out = ch diff --git a/docs/src/messagepassing.md b/docs/src/messagepassing.md index 9db0062e6..f59ad6561 100644 --- a/docs/src/messagepassing.md +++ b/docs/src/messagepassing.md @@ -109,7 +109,7 @@ struct GCN{A<:AbstractMatrix, B, F} <: GNNLayer σ::F end -Flux.@functor GCN # allow gpu movement, select trainable params etc... +Flux.@layer GCN # allow gpu movement, select trainable params etc... function GCN(ch::Pair{Int,Int}, σ=identity) in, out = ch diff --git a/docs/src/models.md b/docs/src/models.md index d0265964e..96e49055a 100644 --- a/docs/src/models.md +++ b/docs/src/models.md @@ -13,7 +13,7 @@ and the *implicit modeling* style based on [`GNNChain`](@ref), more concise but In the explicit modeling style, the model is created according to the following steps: 1. Define a new type for your model (`GNN` in the example below). Layers and submodels are fields. -2. Apply `Flux.@functor` to the new type to make it Flux's compatible (parameters' collection, gpu movement, etc...) +2. Apply `Flux.@layer` to the new type to make it Flux's compatible (parameters' collection, gpu movement, etc...) 3. Optionally define a convenience constructor for your model. 4. Define the forward pass by implementing the call method for your type. 5. Instantiate the model. @@ -30,7 +30,7 @@ struct GNN # step 1 dense end -Flux.@functor GNN # step 2 +Flux.@layer GNN # step 2 function GNN(din::Int, d::Int, dout::Int) # step 3 GNN(GCNConv(din => d), diff --git a/docs/tutorials/introductory_tutorials/gnn_intro_pluto.jl b/docs/tutorials/introductory_tutorials/gnn_intro_pluto.jl index 76e2e870e..977f621ce 100644 --- a/docs/tutorials/introductory_tutorials/gnn_intro_pluto.jl +++ b/docs/tutorials/introductory_tutorials/gnn_intro_pluto.jl @@ -182,7 +182,7 @@ begin layers::NamedTuple end - Flux.@functor GCN # provides parameter collection, gpu movement and more + Flux.@layer GCN # provides parameter collection, gpu movement and more function GCN(num_features, num_classes) layers = (conv1 = GCNConv(num_features => 4), diff --git a/docs/tutorials/introductory_tutorials/node_classification_pluto.jl b/docs/tutorials/introductory_tutorials/node_classification_pluto.jl index 9b3876b20..edf73d4fc 100644 --- a/docs/tutorials/introductory_tutorials/node_classification_pluto.jl +++ b/docs/tutorials/introductory_tutorials/node_classification_pluto.jl @@ -138,7 +138,7 @@ begin layers::NamedTuple end - Flux.@functor MLP + Flux.@layer :expand MLP function MLP(num_features, num_classes, hidden_channels; drop_rate = 0.5) layers = (hidden = Dense(num_features => hidden_channels), @@ -235,7 +235,7 @@ begin layers::NamedTuple end - Flux.@functor GCN # provides parameter collection, gpu movement and more + Flux.@layer GCN # provides parameter collection, gpu movement and more function GCN(num_features, num_classes, hidden_channels; drop_rate = 0.5) layers = (conv1 = GCNConv(num_features => hidden_channels), diff --git a/docs/tutorials_broken/temporal_graph_classification_pluto.jl b/docs/tutorials_broken/temporal_graph_classification_pluto.jl index b5460c1ec..6afd988c3 100644 --- a/docs/tutorials_broken/temporal_graph_classification_pluto.jl +++ b/docs/tutorials_broken/temporal_graph_classification_pluto.jl @@ -117,7 +117,7 @@ begin dense::Dense end - Flux.@functor GenderPredictionModel + Flux.@layer GenderPredictionModel function GenderPredictionModel(; nfeatures = 103, nhidden = 128, activation = relu) mlp = Chain(Dense(nfeatures, nhidden, activation), Dense(nhidden, nhidden, activation)) diff --git a/examples/graph_classification_temporalbrains.jl b/examples/graph_classification_temporalbrains.jl index 2aac2a3e8..e25e9c1f0 100644 --- a/examples/graph_classification_temporalbrains.jl +++ b/examples/graph_classification_temporalbrains.jl @@ -62,7 +62,7 @@ struct GenderPredictionModel dense::Dense end -Flux.@functor GenderPredictionModel +Flux.@layer GenderPredictionModel function GenderPredictionModel(; nfeatures = 103, nhidden = 128, activation = relu) mlp = Chain(Dense(nfeatures, nhidden, activation), Dense(nhidden, nhidden, activation)) diff --git a/notebooks/gnn_intro.ipynb b/notebooks/gnn_intro.ipynb index 3f9748f93..db3721ea5 100644 --- a/notebooks/gnn_intro.ipynb +++ b/notebooks/gnn_intro.ipynb @@ -354,7 +354,7 @@ " layers::NamedTuple\n", "end\n", "\n", - "Flux.@functor GCN # provides parameter collection, gpu movement and more\n", + "Flux.@layer :expand GCN # provides parameter collection, gpu movement and more\n", "\n", "function GCN(num_features, num_classes)\n", " layers = (conv1 = GCNConv(num_features => 4),\n", diff --git a/notebooks/graph_classification_solved.ipynb b/notebooks/graph_classification_solved.ipynb index af2e6bf38..a54c5b359 100644 --- a/notebooks/graph_classification_solved.ipynb +++ b/notebooks/graph_classification_solved.ipynb @@ -857,7 +857,7 @@ "\tact::F\n", "end\n", "\n", - "Flux.@functor MyConv\n", + "Flux.@layer MyConv\n", "\n", "function MyConv((nin, nout)::Pair, act=identity)\n", "\tW1 = Flux.glorot_uniform(nout, nin)\n", diff --git a/src/GraphNeuralNetworks.jl b/src/GraphNeuralNetworks.jl index 021d4d8b2..bf6991155 100644 --- a/src/GraphNeuralNetworks.jl +++ b/src/GraphNeuralNetworks.jl @@ -3,7 +3,7 @@ module GraphNeuralNetworks using Statistics: mean using LinearAlgebra, Random using Flux -using Flux: glorot_uniform, leakyrelu, GRUCell, @functor, batch +using Flux: glorot_uniform, leakyrelu, GRUCell, batch using MacroTools: @forward using NNlib using NNlib: scatter, gather diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 22fd029f9..4f99ddba4 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -45,7 +45,7 @@ end WithGraph(model, g::GNNGraph; traingraph = false) = WithGraph(model, g, traingraph) -@functor WithGraph +Flux.@layer :expand WithGraph Flux.trainable(l::WithGraph) = l.traingraph ? (; l.model, l.g) : (; l.model) (l::WithGraph)(g::GNNGraph, x...; kws...) = l.model(g, x...; kws...) @@ -107,7 +107,7 @@ struct GNNChain{T <: Union{Tuple, NamedTuple, AbstractVector}} <: GNNLayer layers::T end -@functor GNNChain +Flux.@layer :expand GNNChain GNNChain(xs...) = GNNChain(xs) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 4a9f31783..8c3565dce 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -76,7 +76,7 @@ struct GCNConv{W <: AbstractMatrix, B, F} <: GNNLayer use_edge_weight::Bool end -@functor GCNConv +Flux.@layer GCNConv function GCNConv(ch::Pair{Int, Int}, σ = identity; init = glorot_uniform, @@ -167,7 +167,7 @@ function ChebConv(ch::Pair{Int, Int}, k::Int; ChebConv(W, b, k) end -@functor ChebConv +Flux.@layer ChebConv (l::ChebConv)(g, x) = GNNlib.cheb_conv(l, g, x) @@ -225,7 +225,7 @@ struct GraphConv{W <: AbstractMatrix, B, F, A} <: GNNLayer aggr::A end -@functor GraphConv +Flux.@layer GraphConv function GraphConv(ch::Pair{Int, Int}, σ = identity; aggr = +, init = glorot_uniform, bias::Bool = true) @@ -300,8 +300,7 @@ l = GATConv(in_channel => out_channel, add_self_loops = false, bias = false; hea y = l(g, x) ``` """ -struct GATConv{DX <: Dense, DE <: Union{Dense, Nothing}, DV, T, A <: AbstractMatrix, F, B} <: - GNNLayer +struct GATConv{DX<:Dense,DE<:Union{Dense, Nothing},DV,T,A<:AbstractMatrix,F,B} <: GNNLayer dense_x::DX dense_e::DE bias::B @@ -315,8 +314,8 @@ struct GATConv{DX <: Dense, DE <: Union{Dense, Nothing}, DV, T, A <: AbstractMat dropout::DV end -@functor GATConv -Flux.trainable(l::GATConv) = (dense_x = l.dense_x, dense_e = l.dense_e, bias = l.bias, a = l.a) +Flux.@layer GATConv +Flux.trainable(l::GATConv) = (; l.dense_x, l.dense_e, l.bias, l.a) GATConv(ch::Pair{Int, Int}, args...; kws...) = GATConv((ch[1], 0) => ch[2], args...; kws...) @@ -420,7 +419,7 @@ struct GATv2Conv{T, A1, A2, A3, DV, B, C <: AbstractMatrix, F} <: GNNLayer dropout::DV end -@functor GATv2Conv +Flux.@layer GATv2Conv Flux.trainable(l::GATv2Conv) = (dense_i = l.dense_i, dense_j = l.dense_j, dense_e = l.dense_e, bias = l.bias, a = l.a) function GATv2Conv(ch::Pair{Int, Int}, args...; kws...) @@ -515,7 +514,7 @@ struct GatedGraphConv{W <: AbstractArray{<:Number, 3}, R, A} <: GNNLayer aggr::A end -@functor GatedGraphConv +Flux.@layer GatedGraphConv function GatedGraphConv(dims::Int, num_layers::Int; aggr = +, init = glorot_uniform) @@ -572,7 +571,7 @@ struct EdgeConv{NN, A} <: GNNLayer aggr::A end -@functor EdgeConv +Flux.@layer :expand EdgeConv EdgeConv(nn; aggr = max) = EdgeConv(nn, aggr) @@ -626,7 +625,7 @@ struct GINConv{R <: Real, NN, A} <: GNNLayer aggr::A end -@functor GINConv +Flux.@layer :expand GINConv Flux.trainable(l::GINConv) = (nn = l.nn,) GINConv(nn, ϵ; aggr = +) = GINConv(nn, ϵ, aggr) @@ -680,7 +679,7 @@ edim = 10 g = GNNGraph(s, t) # create dense layer -nn = Dense(edim, out_channel * in_channel) +nn = Dense(edim => out_channel * in_channel) # create layer l = NNConv(in_channel => out_channel, nn, tanh, bias = true, aggr = +) @@ -697,7 +696,7 @@ struct NNConv{W, B, NN, F, A} <: GNNLayer aggr::A end -@functor NNConv +Flux.@layer :expand NNConv function NNConv(ch::Pair{Int, Int}, nn, σ = identity; aggr = +, bias = true, init = glorot_uniform) @@ -763,7 +762,7 @@ struct SAGEConv{W <: AbstractMatrix, B, F, A} <: GNNLayer aggr::A end -@functor SAGEConv +Flux.@layer SAGEConv function SAGEConv(ch::Pair{Int, Int}, σ = identity; aggr = mean, init = glorot_uniform, bias::Bool = true) @@ -833,7 +832,7 @@ struct ResGatedGraphConv{W, B, F} <: GNNLayer σ::F end -@functor ResGatedGraphConv +Flux.@layer ResGatedGraphConv function ResGatedGraphConv(ch::Pair{Int, Int}, σ = identity; init = glorot_uniform, bias::Bool = true) @@ -907,7 +906,7 @@ struct CGConv{D1, D2} <: GNNLayer residual::Bool end -@functor CGConv +Flux.@layer CGConv CGConv(ch::Pair{Int, Int}, args...; kws...) = CGConv((ch[1], 0) => ch[2], args...; kws...) @@ -980,7 +979,7 @@ struct AGNNConv{A <: AbstractVector} <: GNNLayer trainable::Bool end -@functor AGNNConv +Flux.@layer AGNNConv Flux.trainable(l::AGNNConv) = l.trainable ? (; l.β) : (;) @@ -1027,7 +1026,7 @@ struct MEGNetConv{TE, TV, A} <: GNNLayer aggr::A end -@functor MEGNetConv +Flux.@layer :expand MEGNetConv MEGNetConv(ϕe, ϕv; aggr = mean) = MEGNetConv(ϕe, ϕv, aggr) @@ -1108,7 +1107,7 @@ struct GMMConv{A <: AbstractMatrix, B, F} <: GNNLayer residual::Bool end -@functor GMMConv +Flux.@layer GMMConv function GMMConv(ch::Pair{NTuple{2, Int}, Int}, σ = identity; @@ -1191,7 +1190,7 @@ struct SGConv{A <: AbstractMatrix, B} <: GNNLayer use_edge_weight::Bool end -@functor SGConv +Flux.@layer SGConv function SGConv(ch::Pair{Int, Int}, k = 1; init = glorot_uniform, @@ -1259,7 +1258,7 @@ struct TAGConv{A <: AbstractMatrix, B} <: GNNLayer use_edge_weight::Bool end -@functor TAGConv +Flux.@layer TAGConv function TAGConv(ch::Pair{Int, Int}, k = 3; init = glorot_uniform, @@ -1269,7 +1268,7 @@ function TAGConv(ch::Pair{Int, Int}, k = 3; in, out = ch W = init(out, in) b = bias ? Flux.create_bias(W, true, out) : false - TAGConv(W, b, k, add_self_loops, use_edge_weight) + return TAGConv(W, b, k, add_self_loops, use_edge_weight) end (l::TAGConv)(g, x, edge_weight = nothing) = GNNlib.tag_conv(l, g, x, edge_weight) @@ -1343,10 +1342,10 @@ struct EGNNConv{TE, TX, TH, NF} <: GNNLayer residual::Bool end -@functor EGNNConv +Flux.@layer EGNNConv function EGNNConv(ch::Pair{Int, Int}, hidden_size = 2 * ch[1]; residual = false) - EGNNConv((ch[1], 0) => ch[2]; hidden_size, residual) + return EGNNConv((ch[1], 0) => ch[2]; hidden_size, residual) end #Follows reference implementation at https://github.com/vgsatorras/egnn/blob/main/models/egnn_clean/egnn_clean.py @@ -1477,7 +1476,7 @@ struct TransformerConv{TW1, TW2, TW3, TW4, TW5, TW6, TFF, TBN1, TBN2} <: GNNLaye sqrt_out::Float32 end -@functor TransformerConv +Flux.@layer TransformerConv function Flux.trainable(l::TransformerConv) (; l.W1, l.W2, l.W3, l.W4, l.W5, l.W6, l.FF, l.BN1, l.BN2) @@ -1568,7 +1567,7 @@ struct DConv <: GNNLayer k::Int end -@functor DConv +Flux.@layer DConv function DConv(ch::Pair{Int, Int}, k::Int; init = glorot_uniform, bias = true) in, out = ch diff --git a/src/layers/heteroconv.jl b/src/layers/heteroconv.jl index b2603e455..a10ebb0c7 100644 --- a/src/layers/heteroconv.jl +++ b/src/layers/heteroconv.jl @@ -43,7 +43,7 @@ struct HeteroGraphConv aggr::Function end -Flux.@functor HeteroGraphConv +Flux.@layer HeteroGraphConv HeteroGraphConv(itr::Dict; aggr = +) = HeteroGraphConv(pairs(itr); aggr) HeteroGraphConv(itr::Pair...; aggr = +) = HeteroGraphConv(itr; aggr) diff --git a/src/layers/pool.jl b/src/layers/pool.jl index ed2f7eca6..59164e199 100644 --- a/src/layers/pool.jl +++ b/src/layers/pool.jl @@ -90,7 +90,7 @@ struct GlobalAttentionPool{G, F} ffeat::F end -@functor GlobalAttentionPool +Flux.@layer GlobalAttentionPool GlobalAttentionPool(fgate) = GlobalAttentionPool(fgate, identity) @@ -146,7 +146,7 @@ struct Set2Set{L} <: GNNLayer num_iters::Int end -@functor Set2Set +Flux.@layer Set2Set function Set2Set(n_in::Int, n_iters::Int, n_layers::Int = 1) @assert n_layers >= 1 diff --git a/src/layers/temporalconv.jl b/src/layers/temporalconv.jl index 44688cea4..443ef2a3a 100644 --- a/src/layers/temporalconv.jl +++ b/src/layers/temporalconv.jl @@ -18,7 +18,7 @@ struct TGCNCell <: GNNLayer out::Int end -Flux.@functor TGCNCell +Flux.@layer TGCNCell function TGCNCell(ch::Pair{Int, Int}; bias::Bool = true, @@ -156,7 +156,7 @@ struct A3TGCN <: GNNLayer out::Int end -Flux.@functor A3TGCN +Flux.@layer A3TGCN function A3TGCN(ch::Pair{Int, Int}, bias::Bool = true, @@ -200,7 +200,7 @@ struct GConvGRUCell <: GNNLayer out::Int end -Flux.@functor GConvGRUCell +Flux.@layer GConvGRUCell function GConvGRUCell(ch::Pair{Int, Int}, k::Int, n::Int; bias::Bool = true, @@ -302,7 +302,7 @@ struct GConvLSTMCell <: GNNLayer out::Int end -Flux.@functor GConvLSTMCell +Flux.@layer GConvLSTMCell function GConvLSTMCell(ch::Pair{Int, Int}, k::Int, n::Int; bias::Bool = true, @@ -411,7 +411,7 @@ struct DCGRUCell dconv_c::DConv end -Flux.@functor DCGRUCell +Flux.@layer DCGRUCell function DCGRUCell(ch::Pair{Int,Int}, k::Int, n::Int; bias = true, init = glorot_uniform, init_state = Flux.zeros32) in, out = ch diff --git a/test/runtests.jl b/test/runtests.jl index b32f8541c..05cb6fd5f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,7 +4,7 @@ using GNNGraphs: sort_edge_index using GNNGraphs: getn, getdata using Functors using Flux -using Flux: gpu, @functor +using Flux: gpu using LinearAlgebra, Statistics, Random using NNlib import MLUtils @@ -35,14 +35,11 @@ tests = [ !CUDA.functional() && @warn("CUDA unavailable, not testing GPU support") # @testset "GraphNeuralNetworks: graph format $graph_type" for graph_type in (:coo, :dense, :sparse) -# for graph_type in (:coo, :dense, :sparse) -for graph_type in (:dense,) +for graph_type in (:coo, :dense, :sparse) @info "Testing graph format :$graph_type" global GRAPH_T = graph_type global TEST_GPU = CUDA.functional() && (GRAPH_T != :sparse) - # global GRAPH_T = :sparse - # global TEST_GPU = false @testset "$t" for t in tests startswith(t, "examples") && GRAPH_T == :dense && continue # not testing :dense since causes OutOfMememory on github's CI