From 1687abe27b46496034b1e97f7979d8d3d2c1401f Mon Sep 17 00:00:00 2001 From: Aurora Rossi <65721467+aurorarossi@users.noreply.github.com> Date: Wed, 25 Dec 2024 10:47:02 +0100 Subject: [PATCH] Add `Graph Classification` tutorial (#568) * Add GlobalPool * Add graph classification tutorial * Add `GlobalPool` pooling docs * Fix ref * Add `GlobalPool` test * Fix Co-authored-by: Carlo Lucibello * Fix text Co-authored-by: Carlo Lucibello * Add changes to the src file * Fix pooling layer --------- Co-authored-by: Carlo Lucibello --- GNNLux/docs/make.jl | 1 + GNNLux/docs/make_tutorials.jl | 4 +- GNNLux/docs/src/api/pool.md | 19 ++ .../src/tutorials/graph_classification.md | 304 ++++++++++++++++++ .../src_tutorials/graph_classification.jl | 201 ++++++++++++ GNNLux/src/GNNLux.jl | 3 + GNNLux/src/layers/pool.jl | 42 +++ GNNLux/test/layers/pool.jl | 15 + 8 files changed, 588 insertions(+), 1 deletion(-) create mode 100644 GNNLux/docs/src/api/pool.md create mode 100644 GNNLux/docs/src/tutorials/graph_classification.md create mode 100644 GNNLux/docs/src_tutorials/graph_classification.jl create mode 100644 GNNLux/src/layers/pool.jl create mode 100644 GNNLux/test/layers/pool.jl diff --git a/GNNLux/docs/make.jl b/GNNLux/docs/make.jl index a4a990756..8603ef94d 100644 --- a/GNNLux/docs/make.jl +++ b/GNNLux/docs/make.jl @@ -62,6 +62,7 @@ makedocs(; "Introductory tutorials" => [ "Hands on" => "tutorials/gnn_intro.md", "Node Classification" => "tutorials/node_classification.md", + "Graph Classification" => "tutorials/graph_classification.md", ], ], diff --git a/GNNLux/docs/make_tutorials.jl b/GNNLux/docs/make_tutorials.jl index a204d4b56..e4bfc9af6 100644 --- a/GNNLux/docs/make_tutorials.jl +++ b/GNNLux/docs/make_tutorials.jl @@ -2,4 +2,6 @@ using Literate Literate.markdown("src_tutorials/gnn_intro.jl", "src/tutorials/"; execute = true) -Literate.markdown("src_tutorials/node_classification.jl", "src/tutorials/"; execute = true) \ No newline at end of file 
+Literate.markdown("src_tutorials/graph_classification.jl", "src/tutorials/"; execute = true) + +Literate.markdown("src_tutorials/node_classification.jl", "src/tutorials/"; execute = true) diff --git a/GNNLux/docs/src/api/pool.md b/GNNLux/docs/src/api/pool.md new file mode 100644 index 000000000..a6d6d7f8b --- /dev/null +++ b/GNNLux/docs/src/api/pool.md @@ -0,0 +1,19 @@ +```@meta +CurrentModule = GNNLux +CollapsedDocStrings = true +``` + +# Pooling Layers + +## Index + +```@index +Order = [:type, :function] +Pages = ["pool.md"] +``` + +```@autodocs +Modules = [GNNLux] +Pages = ["layers/pool.jl"] +Private = false +``` diff --git a/GNNLux/docs/src/tutorials/graph_classification.md b/GNNLux/docs/src/tutorials/graph_classification.md new file mode 100644 index 000000000..e9e2443d6 --- /dev/null +++ b/GNNLux/docs/src/tutorials/graph_classification.md @@ -0,0 +1,304 @@ +# Graph Classification with Graph Neural Networks + +*This tutorial is a Julia adaptation of the Pytorch Geometric tutorial that can be found [here](https://pytorch-geometric.readthedocs.io/en/latest/notes/colabs.html).* + +In this tutorial session we will have a closer look at how to apply **Graph Neural Networks (GNNs) to the task of graph classification**. +Graph classification refers to the problem of classifying entire graphs (in contrast to nodes), given a **dataset of graphs**, based on some structural graph properties and possibly on some input node features. +Here, we want to embed entire graphs, and we want to embed those graphs in such a way so that they are linearly separable given a task at hand. + +A common graph classification task is **molecular property prediction**, in which molecules are represented as graphs, and the task may be to infer whether a molecule inhibits HIV virus replication or not. 
+ +The TU Dortmund University has collected a wide range of different graph classification datasets, known as the [**TUDatasets**](https://chrsmrrs.github.io/datasets/), which are also accessible via MLDatasets.jl. +Let's import the necessary packages. Then we'll load and inspect one of the smaller ones, the **MUTAG dataset**: + +````julia +using Lux, GNNLux +using MLDatasets, MLUtils +using LinearAlgebra, Random, Statistics +using Zygote, Optimisers, OneHotArrays + +ENV["DATADEPS_ALWAYS_ACCEPT"] = "true" # don't ask for dataset download confirmation +rng = Random.seed!(42); # for reproducibility + +dataset = TUDataset("MUTAG") +```` + +```` +dataset TUDataset: + name => MUTAG + metadata => Dict{String, Any} with 1 entry + graphs => 188-element Vector{MLDatasets.Graph} + graph_data => (targets = "188-element Vector{Int64}",) + num_nodes => 3371 + num_edges => 7442 + num_graphs => 188 +```` + +````julia +dataset.graph_data.targets |> union +```` + +```` +2-element Vector{Int64}: + 1 + -1 +```` + +````julia +g1, y1 = dataset[1] # get the first graph and target +```` + +```` +(graphs = Graph(17, 38), targets = 1) +```` + +````julia +reduce(vcat, g.node_data.targets for (g, _) in dataset) |> union +```` + +```` +7-element Vector{Int64}: + 0 + 1 + 2 + 3 + 4 + 5 + 6 +```` + +````julia +reduce(vcat, g.edge_data.targets for (g, _) in dataset) |> union +```` + +```` +4-element Vector{Int64}: + 0 + 1 + 2 + 3 +```` + +This dataset provides **188 different graphs**, and the task is to classify each graph into **one out of two classes**. + +By inspecting the first graph object of the dataset, we can see that it comes with **17 nodes** and **38 edges**. +It also comes with exactly **one graph label**, and provides additional node labels (7 classes) and edge labels (4 classes). +However, for the sake of simplicity, we will not make use of edge labels. 
+ +We now convert the `MLDatasets.jl` graph types to our `GNNGraph`s and we also onehot encode both the node labels (which will be used as input features) and the graph labels (what we want to predict): + +````julia +graphs = mldataset2gnngraph(dataset) +graphs = [GNNGraph(g, + ndata = Float32.(onehotbatch(g.ndata.targets, 0:6)), + edata = nothing) + for g in graphs] +y = onehotbatch(dataset.graph_data.targets, [-1, 1]) +```` + +```` +2×188 OneHotMatrix(::Vector{UInt32}) with eltype Bool: + ⋅ 1 1 ⋅ 1 ⋅ 1 ⋅ 1 ⋅ ⋅ ⋅ ⋅ 1 ⋅ ⋅ 1 ⋅ 1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅ 1 ⋅ 1 1 1 ⋅ 1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅ ⋅ 1 1 ⋅ ⋅ ⋅ 1 ⋅ ⋅ 1 ⋅ ⋅ 1 1 1 ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅ ⋅ ⋅ 1 1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅ 1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 1 ⋅ 1 1 ⋅ 1 ⋅ ⋅ 1 1 ⋅ ⋅ 1 1 ⋅ ⋅ ⋅ ⋅ 1 1 1 1 1 ⋅ 1 ⋅ ⋅ 1 1 ⋅ 1 1 1 1 ⋅ ⋅ 1 ⋅ ⋅ 1 ⋅ ⋅ ⋅ 1 1 1 ⋅ ⋅ ⋅ 1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅ ⋅ ⋅ 1 ⋅ 1 1 ⋅ ⋅ 1 1 ⋅ 1 + 1 ⋅ ⋅ 1 ⋅ 1 ⋅ 1 ⋅ 1 1 1 1 ⋅ 1 1 ⋅ 1 ⋅ 1 1 1 1 1 1 1 1 1 1 1 1 1 1 ⋅ 1 ⋅ 1 ⋅ ⋅ ⋅ 1 ⋅ 1 1 1 1 1 1 1 1 1 1 1 1 ⋅ 1 1 1 1 1 1 ⋅ 1 1 ⋅ ⋅ 1 1 1 ⋅ 1 1 ⋅ 1 1 ⋅ ⋅ ⋅ 1 1 1 1 1 ⋅ 1 1 1 ⋅ ⋅ 1 1 1 1 1 1 1 1 ⋅ 1 ⋅ 1 1 1 1 1 1 1 1 1 ⋅ ⋅ 1 ⋅ ⋅ 1 ⋅ 1 1 ⋅ ⋅ 1 1 ⋅ ⋅ 1 1 1 1 ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅ 1 1 ⋅ ⋅ 1 ⋅ ⋅ ⋅ ⋅ 1 1 ⋅ 1 1 ⋅ 1 1 1 ⋅ ⋅ ⋅ 1 1 1 ⋅ 1 1 1 1 1 1 1 ⋅ 1 1 1 1 1 1 ⋅ 1 1 1 ⋅ 1 ⋅ ⋅ 1 1 ⋅ ⋅ 1 ⋅ +```` + +We have some useful utilities for working with graph datasets, *e.g.*, we can shuffle the dataset and use the first 150 graphs as training graphs, while using the remaining ones for testing: + +````julia +train_data, test_data = splitobs((graphs, y), at = 150, shuffle = true) |> getobs + + +train_loader = DataLoader(train_data, batchsize = 32, shuffle = true) +test_loader = DataLoader(test_data, batchsize = 32, shuffle = false) +```` + +```` +2-element DataLoader(::Tuple{Vector{GNNGraphs.GNNGraph{Tuple{Vector{Int64}, Vector{Int64}, Nothing}}}, OneHotArrays.OneHotMatrix{UInt32, Vector{UInt32}}}, batchsize=32) + with first element: + (32-element Vector{GNNGraphs.GNNGraph{Tuple{Vector{Int64}, Vector{Int64}, Nothing}}}, 2×32 
OneHotMatrix(::Vector{UInt32}) with eltype Bool,) +```` + +Here, we opt for a `batch_size` of 32, leading to 5 (randomly shuffled) mini-batches, containing all $4 \cdot 32+22 = 150$ graphs. + +## Mini-batching of graphs + +Since graphs in graph classification datasets are usually small, a good idea is to **batch the graphs** before inputting them into a Graph Neural Network to guarantee full GPU utilization. +In the image or language domain, this procedure is typically achieved by **rescaling** or **padding** each example into a set of equally-sized shapes, and examples are then grouped in an additional dimension. +The length of this dimension is then equal to the number of examples grouped in a mini-batch and is typically referred to as the `batchsize`. + +However, for GNNs the two approaches described above are either not feasible or may result in a lot of unnecessary memory consumption. +Therefore, GNNLux.jl opts for another approach to achieve parallelization across a number of examples. Here, adjacency matrices are stacked in a diagonal fashion (creating a giant graph that holds multiple isolated subgraphs), and node and target features are simply concatenated in the node dimension (the last dimension). + +This procedure has some crucial advantages over other batching procedures: + +1. GNN operators that rely on a message passing scheme do not need to be modified since messages are not exchanged between two nodes that belong to different graphs. + +2. There is no computational or memory overhead since adjacency matrices are saved in a sparse fashion holding only non-zero entries, *i.e.*, the edges. 
+ +GNNLux.jl can **batch multiple graphs into a single giant graph**: + +````julia +vec_gs, _ = first(train_loader) +```` + +```` +(GNNGraphs.GNNGraph{Tuple{Vector{Int64}, Vector{Int64}, Nothing}}[GNNGraph(11, 22) with x: 7×11 data, GNNGraph(22, 50) with x: 7×22 data, GNNGraph(19, 44) with x: 7×19 data, GNNGraph(22, 50) with x: 7×22 data, GNNGraph(16, 36) with x: 7×16 data, GNNGraph(20, 44) with x: 7×20 data, GNNGraph(19, 42) with x: 7×19 data, GNNGraph(20, 44) with x: 7×20 data, GNNGraph(13, 26) with x: 7×13 data, GNNGraph(19, 40) with x: 7×19 data, GNNGraph(25, 56) with x: 7×25 data, GNNGraph(16, 34) with x: 7×16 data, GNNGraph(28, 66) with x: 7×28 data, GNNGraph(19, 40) with x: 7×19 data, GNNGraph(19, 44) with x: 7×19 data, GNNGraph(17, 36) with x: 7×17 data, GNNGraph(12, 24) with x: 7×12 data, GNNGraph(16, 34) with x: 7×16 data, GNNGraph(27, 66) with x: 7×27 data, GNNGraph(17, 38) with x: 7×17 data, GNNGraph(19, 42) with x: 7×19 data, GNNGraph(17, 36) with x: 7×17 data, GNNGraph(12, 26) with x: 7×12 data, GNNGraph(24, 50) with x: 7×24 data, GNNGraph(20, 46) with x: 7×20 data, GNNGraph(19, 42) with x: 7×19 data, GNNGraph(11, 22) with x: 7×11 data, GNNGraph(16, 34) with x: 7×16 data, GNNGraph(13, 28) with x: 7×13 data, GNNGraph(13, 28) with x: 7×13 data, GNNGraph(13, 28) with x: 7×13 data, GNNGraph(16, 36) with x: 7×16 data], Bool[1 0 0 0 0 0 0 0 1 1 0 0 0 0 0 1 0 1 0 0 0 0 0 0 0 0 1 1 1 1 1 0; 0 1 1 1 1 1 1 1 0 0 1 1 1 1 1 0 1 0 1 1 1 1 1 1 1 1 0 0 0 0 0 1]) +```` + +````julia +MLUtils.batch(vec_gs) +```` + +```` +GNNGraph: + num_nodes: 570 + num_edges: 1254 + num_graphs: 32 + ndata: + x = 7×570 Matrix{Float32} +```` + +Each batched graph object is equipped with a **`graph_indicator` vector**, which maps each node to its respective graph in the batch: + +```math +\textrm{graph\_indicator} = [1, \ldots, 1, 2, \ldots, 2, 3, \ldots ] +``` + +## Training a Graph Neural Network (GNN) + +Training a GNN for graph classification usually follows a simple recipe: + +1. 
Embed each node by performing multiple rounds of message passing +2. Aggregate node embeddings into a unified graph embedding (**readout layer**) +3. Train a final classifier on the graph embedding + +There exists multiple **readout layers** in literature, but the most common one is to simply take the average of node embeddings: + +```math +\mathbf{x}_{\mathcal{G}} = \frac{1}{|\mathcal{V}|} \sum_{v \in \mathcal{V}} \mathcal{x}^{(L)}_v +``` + +GNNLux.jl provides this functionality via `GlobalPool(mean)`, which takes in the node embeddings of all nodes in the mini-batch and the assignment vector `graph_indicator` to compute a graph embedding of size `[hidden_channels, batchsize]`. + +The final architecture for applying GNNs to the task of graph classification then looks as follows and allows for complete end-to-end training: + +````julia +function create_model(nin, nh, nout) + GNNChain(GCNConv(nin => nh, relu), + GCNConv(nh => nh, relu), + GCNConv(nh => nh), + GlobalPool(mean), + Dropout(0.5), + Dense(nh, nout)) +end; + +nin = 7 +nh = 64 +nout = 2 +model = create_model(nin, nh, nout) + +ps, st = LuxCore.initialparameters(rng, model), LuxCore.initialstates(rng, model); +```` + +```` +┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`. +└ @ LuxCore ~/.julia/packages/LuxCore/GlbG3/src/LuxCore.jl:18 + +```` + +Here, we again make use of the `GCNConv` with $\mathrm{ReLU}(x) = \max(x, 0)$ activation for obtaining localized node embeddings, before we apply our final classifier on top of a graph readout layer. 
+ +Let's train our network for a few epochs to see how well it performs on the training as well as test set: + +````julia +function custom_loss(model, ps, st, tuple) + g, x, y = tuple + logitcrossentropy = CrossEntropyLoss(; logits=Val(true)) + st = Lux.trainmode(st) + ŷ, st = model(g, x, ps, st) + return logitcrossentropy(ŷ, y), (; layers = st), 0 +end + +function eval_loss_accuracy(model, ps, st, data_loader) + loss = 0.0 + acc = 0.0 + ntot = 0 + logitcrossentropy = CrossEntropyLoss(; logits=Val(true)) + for (g, y) in data_loader + g = MLUtils.batch(g) + n = length(y) + ŷ, _ = model(g, g.ndata.x, ps, st) + loss += logitcrossentropy(ŷ, y) * n + acc += mean((ŷ .> 0) .== y) * n + ntot += n + end + return (loss = round(loss / ntot, digits = 4), + acc = round(acc * 100 / ntot, digits = 2)) +end + +function train_model!(model, ps, st; epochs = 500, infotime = 100) + train_state = Lux.Training.TrainState(model, ps, st, Adam(1e-2)) + + function report(epoch) + train = eval_loss_accuracy(model, ps, st, train_loader) + st = Lux.testmode(st) + test = eval_loss_accuracy(model, ps, st, test_loader) + st = Lux.trainmode(st) + @info (; epoch, train, test) + end + report(0) + for iter in 1:epochs + for (g, y) in train_loader + g = MLUtils.batch(g) + _, loss, _, train_state = Lux.Training.single_train_step!(AutoZygote(), custom_loss,(g, g.ndata.x, y), train_state) + end + + iter % infotime == 0 && report(iter) + end + return model, ps, st +end + +model, ps, st = train_model!(model, ps, st); +```` + +```` +┌ Warning: `training` is set to `Val{true}()` but is not being used within an autodiff call (gradient, jacobian, etc...). This will be slow. If you are using a `Lux.jl` model, set it to inference (test) mode using `LuxCore.testmode`. Reliance on this behavior is discouraged, and is not guaranteed by Semantic Versioning, and might be removed without a deprecation cycle. It is recommended to fix this issue in your code. 
+└ @ LuxLib.Utils ~/.julia/packages/LuxLib/ru5RQ/src/utils.jl:314 +┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`. +└ @ LuxCore ~/.julia/packages/LuxCore/GlbG3/src/LuxCore.jl:18 +[ Info: (epoch = 0, train = (loss = 0.6934, acc = 51.67), test = (loss = 0.6902, acc = 50.0)) +[ Info: (epoch = 100, train = (loss = 0.3979, acc = 81.33), test = (loss = 0.5769, acc = 69.74)) +[ Info: (epoch = 200, train = (loss = 0.3904, acc = 84.0), test = (loss = 0.6402, acc = 65.79)) +[ Info: (epoch = 300, train = (loss = 0.3813, acc = 85.33), test = (loss = 0.6331, acc = 69.74)) +[ Info: (epoch = 400, train = (loss = 0.3682, acc = 85.0), test = (loss = 0.7273, acc = 69.74)) +[ Info: (epoch = 500, train = (loss = 0.3561, acc = 86.67), test = (loss = 0.6825, acc = 73.68)) + +```` + +As one can see, our model reaches around **74% test accuracy**. +Reasons for the fluctuations in accuracy can be explained by the rather small dataset (only 38 test graphs), and usually disappear once one applies GNNs to larger datasets. + +## (Optional) Exercise + +Can we do better than this? +As multiple papers pointed out ([Xu et al. (2018)](https://arxiv.org/abs/1810.00826), [Morris et al. (2018)](https://arxiv.org/abs/1810.02244)), applying **neighborhood normalization decreases the expressivity of GNNs in distinguishing certain graph structures**. +An alternative formulation ([Morris et al. (2018)](https://arxiv.org/abs/1810.02244)) omits neighborhood normalization completely and adds a simple skip-connection to the GNN layer in order to preserve central node information: + +```math +\mathbf{x}_i^{(\ell+1)} = \mathbf{W}^{(\ell + 1)}_1 \mathbf{x}_i^{(\ell)} + \mathbf{W}^{(\ell + 1)}_2 \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j^{(\ell)} +``` + +This layer is implemented under the name `GraphConv` in GraphNeuralNetworks.jl. + +As an exercise, you are invited to complete the following code to the extent that it makes use of `GraphConv` rather than `GCNConv`. 
+This should bring you close to **82% test accuracy**. + +## Conclusion + +In this chapter, you have learned how to apply GNNs to the task of graph classification. +You have learned how graphs can be batched together for better GPU utilization, and how to apply readout layers for obtaining graph embeddings rather than node embeddings. + +--- + +*This page was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).* + diff --git a/GNNLux/docs/src_tutorials/graph_classification.jl b/GNNLux/docs/src_tutorials/graph_classification.jl new file mode 100644 index 000000000..caab4e409 --- /dev/null +++ b/GNNLux/docs/src_tutorials/graph_classification.jl @@ -0,0 +1,201 @@ +# # Graph Classification with Graph Neural Networks + +# *This tutorial is a Julia adaptation of the PyTorch Geometric tutorial that can be found [here](https://pytorch-geometric.readthedocs.io/en/latest/notes/colabs.html).* + +# In this tutorial session we will have a closer look at how to apply **Graph Neural Networks (GNNs) to the task of graph classification**. +# Graph classification refers to the problem of classifying entire graphs (in contrast to nodes), given a **dataset of graphs**, based on some structural graph properties and possibly on some input node features. +# Here, we want to embed entire graphs, and we want to embed those graphs in such a way so that they are linearly separable given a task at hand. +# We will use a graph convolutional network to create a vector embedding of the input graph, and then apply a simple linear classification head to perform the final classification. + + +# A common graph classification task is **molecular property prediction**, in which molecules are represented as graphs, and the task may be to infer whether a molecule inhibits HIV virus replication or not. 
+ +# The TU Dortmund University has collected a wide range of different graph classification datasets, known as the [**TUDatasets**](https://chrsmrrs.github.io/datasets/), which are also accessible via MLDatasets.jl. +# Let's import the necessary packages. Then we'll load and inspect one of the smaller ones, the **MUTAG dataset**: + +using Lux, GNNLux +using MLDatasets, MLUtils +using LinearAlgebra, Random, Statistics +using Zygote, Optimisers, OneHotArrays + +ENV["DATADEPS_ALWAYS_ACCEPT"] = "true" # don't ask for dataset download confirmation +rng = Random.seed!(42); # for reproducibility + +dataset = TUDataset("MUTAG") +# +dataset.graph_data.targets |> union +# +g1, y1 = dataset[1] # get the first graph and target +# +reduce(vcat, g.node_data.targets for (g, _) in dataset) |> union +# + +reduce(vcat, g.edge_data.targets for (g, _) in dataset) |> union + +# This dataset provides **188 different graphs**, and the task is to classify each graph into **one out of two classes**. + +# By inspecting the first graph object of the dataset, we can see that it comes with **17 nodes** and **38 edges**. +# It also comes with exactly **one graph label**, and provides additional node labels (7 classes) and edge labels (4 classes). +# However, for the sake of simplicity, we will not make use of edge labels. 
+ +# We now convert the `MLDatasets.jl` graph types to our `GNNGraph`s and we also onehot encode both the node labels (which will be used as input features) and the graph labels (what we want to predict): + +graphs = mldataset2gnngraph(dataset) +graphs = [GNNGraph(g, + ndata = Float32.(onehotbatch(g.ndata.targets, 0:6)), + edata = nothing) + for g in graphs] +y = onehotbatch(dataset.graph_data.targets, [-1, 1]) + + +# We have some useful utilities for working with graph datasets, *e.g.*, we can shuffle the dataset and use the first 150 graphs as training graphs, while using the remaining ones for testing: + +train_data, test_data = splitobs((graphs, y), at = 150, shuffle = true) |> getobs + + +train_loader = DataLoader(train_data, batchsize = 32, shuffle = true) +test_loader = DataLoader(test_data, batchsize = 32, shuffle = false) + + + +# Here, we opt for a `batch_size` of 32, leading to 5 (randomly shuffled) mini-batches, containing all $4 \cdot 32+22 = 150$ graphs. + + + +# ## Mini-batching of graphs + +# Since graphs in graph classification datasets are usually small, a good idea is to **batch the graphs** before inputting them into a Graph Neural Network to guarantee full GPU utilization. +# In the image or language domain, this procedure is typically achieved by **rescaling** or **padding** each example into a set of equally-sized shapes, and examples are then grouped in an additional dimension. +# The length of this dimension is then equal to the number of examples grouped in a mini-batch and is typically referred to as the `batchsize`. + +# However, for GNNs the two approaches described above are either not feasible or may result in a lot of unnecessary memory consumption. +# Therefore, GNNLux.jl opts for another approach to achieve parallelization across a number of examples. 
Here, adjacency matrices are stacked in a diagonal fashion (creating a giant graph that holds multiple isolated subgraphs), and node and target features are simply concatenated in the node dimension (the last dimension). + +# This procedure has some crucial advantages over other batching procedures: + +# 1. GNN operators that rely on a message passing scheme do not need to be modified since messages are not exchanged between two nodes that belong to different graphs. + +# 2. There is no computational or memory overhead since adjacency matrices are saved in a sparse fashion holding only non-zero entries, *i.e.*, the edges. + +# GNNLux.jl can **batch multiple graphs into a single giant graph**: + +vec_gs, _ = first(train_loader) +# +MLUtils.batch(vec_gs) + +# Each batched graph object is equipped with a **`graph_indicator` vector**, which maps each node to its respective graph in the batch: + +# ```math +# \textrm{graph\_indicator} = [1, \ldots, 1, 2, \ldots, 2, 3, \ldots ] +# ``` + +# ## Training a Graph Neural Network (GNN) + +# Training a GNN for graph classification usually follows a simple recipe: + +# 1. Embed each node by performing multiple rounds of message passing +# 2. Aggregate node embeddings into a unified graph embedding (**readout layer**) +# 3. Train a final classifier on the graph embedding + +# There exist multiple **readout layers** in the literature, but the most common one is to simply take the average of node embeddings: + +# ```math +# \mathbf{x}_{\mathcal{G}} = \frac{1}{|\mathcal{V}|} \sum_{v \in \mathcal{V}} \mathbf{x}^{(L)}_v +# ``` + +# GNNLux.jl provides this functionality via `GlobalPool(mean)`, which takes in the node embeddings of all nodes in the mini-batch and the assignment vector `graph_indicator` to compute a graph embedding of size `[hidden_channels, batchsize]`. 
+ +# The final architecture for applying GNNs to the task of graph classification then looks as follows and allows for complete end-to-end training: + +function create_model(nin, nh, nout) + GNNChain(GCNConv(nin => nh, relu), + GCNConv(nh => nh, relu), + GCNConv(nh => nh), + GlobalPool(mean), + Dropout(0.5), + Dense(nh, nout)) +end; + +nin = 7 +nh = 64 +nout = 2 +model = create_model(nin, nh, nout) + +ps, st = LuxCore.initialparameters(rng, model), LuxCore.initialstates(rng, model); + +# Here, we again make use of the `GCNConv` with $\mathrm{ReLU}(x) = \max(x, 0)$ activation for obtaining localized node embeddings, before we apply our final classifier on top of a graph readout layer. + +# Let's train our network for a few epochs to see how well it performs on the training as well as test set: + +function custom_loss(model, ps, st, tuple) + g, x, y = tuple + logitcrossentropy = CrossEntropyLoss(; logits=Val(true)) + st = Lux.trainmode(st) # the loss is always computed in training mode + ŷ, st = model(g, x, ps, st) + return logitcrossentropy(ŷ, y), (; layers = st), 0 # (loss, updated state, stats placeholder) +end + +function eval_loss_accuracy(model, ps, st, data_loader) + loss = 0.0 + acc = 0.0 + ntot = 0 + logitcrossentropy = CrossEntropyLoss(; logits=Val(true)) + for (g, y) in data_loader + g = MLUtils.batch(g) # merge the vector of graphs into one batched graph + n = length(y) + ŷ, _ = model(g, g.ndata.x, ps, st) + loss += logitcrossentropy(ŷ, y) * n # weight each batch by its size + acc += mean((ŷ .> 0) .== y) * n + ntot += n + end + return (loss = round(loss / ntot, digits = 4), + acc = round(acc * 100 / ntot, digits = 2)) +end + +function train_model!(model, ps, st; epochs = 500, infotime = 100) + train_state = Lux.Training.TrainState(model, ps, st, Adam(1e-2)) + + function report(epoch) + st = Lux.testmode(st) # score BOTH splits in inference mode, so Dropout is not applied while evaluating + train = eval_loss_accuracy(model, ps, st, train_loader) + test = eval_loss_accuracy(model, ps, st, test_loader) + st = Lux.trainmode(st) # switch the captured `st` back to training mode + @info (; epoch, train, test) + end + report(0) + for iter in 1:epochs + for (g, y) in train_loader + g = MLUtils.batch(g) + _, loss, _, train_state = 
Lux.Training.single_train_step!(AutoZygote(), custom_loss,(g, g.ndata.x, y), train_state) + end + + iter % infotime == 0 && report(iter) + end + return model, ps, st +end + +model, ps, st = train_model!(model, ps, st); + +# As one can see, our model reaches around **74% test accuracy**. +# Reasons for the fluctuations in accuracy can be explained by the rather small dataset (only 38 test graphs), and usually disappear once one applies GNNs to larger datasets. + +# ## (Optional) Exercise + +# Can we do better than this? +# As multiple papers pointed out ([Xu et al. (2018)](https://arxiv.org/abs/1810.00826), [Morris et al. (2018)](https://arxiv.org/abs/1810.02244)), applying **neighborhood normalization decreases the expressivity of GNNs in distinguishing certain graph structures**. +# An alternative formulation ([Morris et al. (2018)](https://arxiv.org/abs/1810.02244)) omits neighborhood normalization completely and adds a simple skip-connection to the GNN layer in order to preserve central node information: + +# ```math +# \mathbf{x}_i^{(\ell+1)} = \mathbf{W}^{(\ell + 1)}_1 \mathbf{x}_i^{(\ell)} + \mathbf{W}^{(\ell + 1)}_2 \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j^{(\ell)} +# ``` + +# This layer is implemented under the name `GraphConv` in GraphNeuralNetworks.jl. + +# As an exercise, you are invited to complete the following code to the extent that it makes use of `GraphConv` rather than `GCNConv`. +# This should bring you close to **82% test accuracy**. + +# ## Conclusion + +# In this chapter, you have learned how to apply GNNs to the task of graph classification. +# You have learned how graphs can be batched together for better GPU utilization, and how to apply readout layers for obtaining graph embeddings rather than node embeddings. 
+ diff --git a/GNNLux/src/GNNLux.jl b/GNNLux/src/GNNLux.jl index 0a8c4e290..163d315b5 100644 --- a/GNNLux/src/GNNLux.jl +++ b/GNNLux/src/GNNLux.jl @@ -48,5 +48,8 @@ export TGCN, DCGRU, EvolveGCNO +include("layers/pool.jl") +export GlobalPool + end #module \ No newline at end of file diff --git a/GNNLux/src/layers/pool.jl b/GNNLux/src/layers/pool.jl new file mode 100644 index 000000000..7fcc044f6 --- /dev/null +++ b/GNNLux/src/layers/pool.jl @@ -0,0 +1,42 @@ +@doc raw""" + GlobalPool(aggr) + +Global pooling layer for graph neural networks. +Takes a graph and node features as inputs +and performs the operation + +```math +\mathbf{u}_V = \square_{i \in V} \mathbf{x}_i +``` + +where ``V`` is the set of nodes of the input graph and +the type of aggregation represented by ``\square`` is selected by the `aggr` argument. +Commonly used aggregations are `mean`, `max`, and `+`. + +See also [`GNNlib.reduce_nodes`](@ref). + +# Examples + +```julia +using Lux, GNNLux, Graphs, MLUtils, Random, Statistics + +pool = GlobalPool(mean) +ps, st = Lux.setup(Random.default_rng(), pool) + +g = GNNGraph(erdos_renyi(10, 4)) +X = rand(32, 10) +y, st = pool(g, X, ps, st) # y is a 32x1 matrix + + +g = MLUtils.batch([GNNGraph(erdos_renyi(10, 4)) for _ in 1:5]) +X = rand(32, 50) +y, st = pool(g, X, ps, st) # y is a 32x5 matrix +``` +""" +struct GlobalPool{F} <: GNNLayer + aggr::F +end + +(l::GlobalPool)(g::GNNGraph, x::AbstractArray, ps, st) = GNNlib.global_pool(l, g, x), st + +(l::GlobalPool)(g::GNNGraph) = GNNGraph(g, gdata = GNNlib.global_pool(l, g, node_features(g))) # pooling uses neither `ps` nor `st`, so none are needed here \ No newline at end of file diff --git a/GNNLux/test/layers/pool.jl b/GNNLux/test/layers/pool.jl new file mode 100644 index 000000000..f1f7faeae --- /dev/null +++ b/GNNLux/test/layers/pool.jl @@ -0,0 +1,15 @@ +@testitem "Pooling" setup=[TestModuleLux] begin + using .TestModuleLux + @testset "GlobalPool" begin + + rng = StableRNG(1234) + g = rand_graph(rng, 10, 40) + in_dims = 3 + x = randn(rng, Float32, in_dims, 10) + + @testset "mean" begin + l = GlobalPool(mean) + test_lux_layer(rng, l, g, x, sizey=(in_dims,1)) 
+ end + end +end