diff --git a/GNNLux/docs/Project.toml b/GNNLux/docs/Project.toml index 63940977a..36822253f 100644 --- a/GNNLux/docs/Project.toml +++ b/GNNLux/docs/Project.toml @@ -1,5 +1,6 @@ [deps] CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" +ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" DocumenterInterLinks = "d12716ef-a0f6-4df4-a9f1-a5a34e75c656" GNNGraphs = "aed8fd31-079b-4b5a-b342-a13352159b8c" @@ -12,4 +13,5 @@ Lux = "b2108857-7c20-44ae-9111-449ecde12c47" MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" +TSne = "24678dba-d5e9-5843-a4c6-250288b04835" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/GNNLux/docs/make.jl b/GNNLux/docs/make.jl index feae0f3c5..a4a990756 100644 --- a/GNNLux/docs/make.jl +++ b/GNNLux/docs/make.jl @@ -61,6 +61,7 @@ makedocs(; "Tutorials" => [ "Introductory tutorials" => [ "Hands on" => "tutorials/gnn_intro.md", + "Node Classification" => "tutorials/node_classification.md", ], ], diff --git a/GNNLux/docs/make_tutorials.jl b/GNNLux/docs/make_tutorials.jl index ebeaaff9d..a204d4b56 100644 --- a/GNNLux/docs/make_tutorials.jl +++ b/GNNLux/docs/make_tutorials.jl @@ -1,3 +1,5 @@ using Literate -Literate.markdown("src_tutorials/gnn_intro.jl", "src/tutorials/"; execute = true) \ No newline at end of file +Literate.markdown("src_tutorials/gnn_intro.jl", "src/tutorials/"; execute = true) + +Literate.markdown("src_tutorials/node_classification.jl", "src/tutorials/"; execute = true) \ No newline at end of file diff --git a/GNNLux/docs/src/tutorials/node_classification.md b/GNNLux/docs/src/tutorials/node_classification.md new file mode 100644 index 000000000..8f332dac6 --- /dev/null +++ b/GNNLux/docs/src/tutorials/node_classification.md @@ -0,0 +1,5895 @@ +# Node Classification with Graph Neural Networks + +In this tutorial, we will be learning how to use Graph Neural Networks (GNNs) for node classification. Given the ground-truth labels of only a small subset of nodes, we want to infer the labels for all the remaining nodes (transductive learning). + +## Import +Let us start off by importing some libraries. We will be using `Lux.jl` and `GNNLux.jl` for our tutorial. + +````julia +using Lux, GNNLux +using MLDatasets +using Plots, TSne +using Random, Statistics +using Zygote, Optimisers, OneHotArrays, ConcreteStructs + + +ENV["DATADEPS_ALWAYS_ACCEPT"] = "true" # don't ask for dataset download confirmation +rng = Random.seed!(17); # for reproducibility +```` + +## Visualize +We want to visualize our results using t-distributed stochastic neighbor embedding (tsne) to project our output onto a 2D plane. + +````julia +function visualize_tsne(out, targets) + z = tsne(out, 2) + scatter(z[:, 1], z[:, 2], color = Int.(targets[1:size(z, 1)]), leg = false) +end; +```` + +## Dataset: Cora + +For our tutorial, we will be using the `Cora` dataset. `Cora` is a citation network of 2708 documents categorized into seven classes with 5,429 citation links. Each node represents an article or document, and edges between nodes indicate a citation relationship, where one cites the other. + +Each publication in the dataset is described by a 0/1-valued word vector indicating the absence/presence of the corresponding word from the dictionary. The dictionary consists of 1433 unique words. + +This dataset was first introduced by [Yang et al. 
(2016)](https://arxiv.org/abs/1603.08861) as one of the datasets of the `Planetoid` benchmark suite. We will be using [MLDatasets.jl](https://juliaml.github.io/MLDatasets.jl/stable/) for an easy access to this dataset. + +````julia +dataset = Cora() +```` + +```` +dataset Cora: + metadata => Dict{String, Any} with 3 entries + graphs => 1-element Vector{MLDatasets.Graph} +```` + +Datasets in MLDatasets.jl have `metadata` containing information about the dataset itself. + +````julia +dataset.metadata +```` + +```` +Dict{String, Any} with 3 entries: + "name" => "cora" + "classes" => [1, 2, 3, 4, 5, 6, 7] + "num_classes" => 7 +```` + +The `graphs` variable contains the graph. The `Cora` dataset contains only 1 graph. + +````julia +dataset.graphs +```` + +```` +1-element Vector{MLDatasets.Graph}: + Graph(2708, 10556) +```` + +There is only one graph of the dataset. The `node_data` contains `features` indicating if certain words are present or not and `targets` indicating the class for each document. We convert the single-graph dataset to a `GNNGraph`. + +````julia +g = mldataset2gnngraph(dataset) + + +println("Number of nodes: $(g.num_nodes)") +println("Number of edges: $(g.num_edges)") +println("Average node degree: $(g.num_edges / g.num_nodes)") +println("Number of training nodes: $(sum(g.ndata.train_mask))") +println("Training node label rate: $(mean(g.ndata.train_mask))") +println("Has isolated nodes: $(has_isolated_nodes(g))") +println("Has self-loops: $(has_self_loops(g))") +println("Is undirected: $(is_bidirected(g))") +```` + +```` +Number of nodes: 2708 +Number of edges: 10556 +Average node degree: 3.8980797636632203 +Number of training nodes: 140 +Training node label rate: 0.051698670605613 +Has isolated nodes: false +Has self-loops: false +Is undirected: true + +```` + +Overall, this dataset is quite similar to the previously used [`KarateClub`](https://juliaml.github.io/MLDatasets.jl/stable/datasets/graphs/#MLDatasets.KarateClub) network. +We can see that the `Cora` network holds 2,708 nodes and 10,556 edges, resulting in an average node degree of 3.9. +For training this dataset, we are given the ground-truth categories of 140 nodes (20 for each class). +This results in a training node label rate of only 5%. + +We can further see that this network is undirected, and that there exists no isolated nodes (each document has at least one citation). + +````julia +x = g.ndata.features # we onehot encode the node labels (what we want to predict): +y = onehotbatch(g.ndata.targets, 1:7) +train_mask = g.ndata.train_mask; +num_features = size(x)[1]; +hidden_channels = 16; +drop_rate = 0.5; +num_classes = dataset.metadata["num_classes"]; +```` + +## Multi-layer Perception Network (MLP) + +In theory, we should be able to infer the category of a document solely based on its content, *i.e.* its bag-of-words feature representation, without taking any relational information into account. + +Let's verify that by constructing a simple MLP that solely operates on input node features (using shared weights across all nodes): + +````julia +MLP = Chain(Dense(num_features => hidden_channels, relu), + Dropout(drop_rate), + Dense(hidden_channels => num_classes)) + +ps, st = Lux.setup(rng, MLP); +```` + +```` +┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`. 
+└ @ LuxCore ~/.julia/packages/LuxCore/GlbG3/src/LuxCore.jl:18
+
+````
+
+### Training a Multilayer Perceptron
+
+Our MLP is defined by two linear layers and enhanced by [ReLU](https://lux.csail.mit.edu/stable/api/NN_Primitives/ActivationFunctions#NNlib.relu) non-linearity and [Dropout](https://lux.csail.mit.edu/stable/api/Lux/layers#Lux.Dropout).
+Here, we first reduce the 1433-dimensional feature vector to a low-dimensional embedding (`hidden_channels=16`), while the second linear layer acts as a classifier that should map each low-dimensional node embedding to one of the 7 classes.
+
+Let's train our simple MLP by following a similar procedure as described in [the first part of this tutorial](https://juliagraphs.org/GraphNeuralNetworks.jl/docs/GNNLux.jl/stable/tutorials/gnn_intro/).
+We again make use of the **cross entropy loss** and **Adam optimizer**.
+This time, we also define an **`accuracy` function** to evaluate how well our final model performs on the test node set (whose labels have not been observed during training).
+
+````julia
+function loss(model, ps, st, x)
+    logitcrossentropy = CrossEntropyLoss(; logits=Val(true))
+    ŷ, st = model(x, ps, st)
+    return logitcrossentropy(ŷ[:, train_mask], y[:, train_mask]), (st), 0
+end
+
+function train_model!(MLP, ps, st, x, epochs)
+    train_state = Lux.Training.TrainState(MLP, ps, st, Adam(1e-3))
+    for iter in 1:epochs
+        _, loss_value, _, train_state = Lux.Training.single_train_step!(AutoZygote(), loss, x, train_state)
+
+        if iter % 100 == 0
+            println("Epoch: $(iter) Loss: $(loss_value)")
+        end
+    end
+end
+
+function accuracy(model, x, ps, st, y, mask)
+    st = Lux.testmode(st)
+    ŷ, st = model(x, ps, st)
+    mean(onecold(ŷ)[mask] .== onecold(y)[mask])
+end
+
+train_model!(MLP, ps, st, x, 2000)
+````
+
+````
+┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
+└ @ LuxCore ~/.julia/packages/LuxCore/GlbG3/src/LuxCore.jl:18
+Epoch: 100 Loss: 0.810594
+Epoch: 200 Loss: 0.48982772
+Epoch: 300 Loss: 0.31716076
+Epoch: 400 Loss: 0.2397098
+Epoch: 500 Loss: 0.20041731
+Epoch: 600 Loss: 0.11589075
+Epoch: 700 Loss: 0.21093586
+Epoch: 800 Loss: 0.18869051
+Epoch: 900 Loss: 0.15322906
+Epoch: 1000 Loss: 0.12451931
+Epoch: 1100 Loss: 0.13396983
+Epoch: 1200 Loss: 0.111468166
+Epoch: 1300 Loss: 0.17113678
+Epoch: 1400 Loss: 0.18155631
+Epoch: 1500 Loss: 0.17731342
+Epoch: 1600 Loss: 0.11386197
+Epoch: 1700 Loss: 0.09408201
+Epoch: 1800 Loss: 0.15806198
+Epoch: 1900 Loss: 0.104388796
+Epoch: 2000 Loss: 0.18465123
+
+````
+
+After training the model, we can call the `accuracy` function to see how well our model performs on unseen labels.
+Here, we are interested in the accuracy of the model, *i.e.*, the ratio of correctly classified nodes:
+
+````julia
+accuracy(MLP, x, ps, st, y, .!train_mask)
+````
+
+````
+0.5089563862928349
+````
+
+As one can see, our MLP performs rather badly, with only about 50% test accuracy.
+But why doesn't the MLP perform better?
+The main reason is that this model suffers from heavy overfitting, since it only has access to a **small number of training nodes**, and therefore it generalizes poorly to unseen node representations.
+
+It also fails to incorporate an important bias into the model: **Cited papers are very likely related to the category of a document**.
+That is exactly where Graph Neural Networks come into play and can help to boost the performance of our model.
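+
+Before moving on, we can make this overfitting explicit. The following is a quick check, a sketch that reuses the `accuracy` function defined above (the variable names are purely illustrative and its output is not reproduced here): the accuracy on the 140 labeled training nodes will typically be far higher than the test accuracy computed above, and that gap is exactly the overfitting we just described.
+
+````julia
+train_acc_mlp = accuracy(MLP, x, ps, st, y, train_mask)   # accuracy on the labeled training nodes
+test_acc_mlp  = accuracy(MLP, x, ps, st, y, .!train_mask) # roughly 0.51, as computed above
+println("MLP train accuracy: $(train_acc_mlp), test accuracy: $(test_acc_mlp)")
+````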
+ +## Training a Graph Convolutional Neural Network (GNN) + +Following-up on the first part of this tutorial, we replace the `Dense` linear layers by the [`GCNConv`](https://juliagraphs.org/GraphNeuralNetworks.jl/docs/GNNLux.jl/stable/api/conv/#GNNLux.GCNConv) module. +To recap, the **GCN layer** ([Kipf et al. (2017)](https://arxiv.org/abs/1609.02907)) is defined as + +```math +\mathbf{x}_v^{(\ell + 1)} = \mathbf{W}^{(\ell + 1)} \sum_{w \in \mathcal{N}(v) \, \cup \, \{ v \}} \frac{1}{c_{w,v}} \cdot \mathbf{x}_w^{(\ell)} +``` + +where $\mathbf{W}^{(\ell + 1)}$ denotes a trainable weight matrix of shape `[num_output_features, num_input_features]` and $c_{w,v}$ refers to a fixed normalization coefficient for each edge. +In contrast, a single `Linear` layer is defined as + +```math +\mathbf{x}_v^{(\ell + 1)} = \mathbf{W}^{(\ell + 1)} \mathbf{x}_v^{(\ell)} +``` + +which does not make use of neighboring node information. + +````julia +@concrete struct GCN <: GNNContainerLayer{(:conv1, :drop, :conv2)} + nf::Int + nc::Int + hd::Int + conv1 + conv2 + drop + use_bias::Bool + init_weight + init_bias +end; + +function GCN(num_features, num_classes, hidden_channels, drop_rate; use_bias = true, init_weight = glorot_uniform, init_bias = zeros32) # constructor + conv1 = GCNConv(num_features => hidden_channels) + conv2 = GCNConv(hidden_channels => num_classes) + drop = Dropout(drop_rate) + return GCN(num_features, num_classes, hidden_channels, conv1, conv2, drop, use_bias, init_weight, init_bias) +end; + +function (gcn::GCN)(g::GNNGraph, x, ps, st) # forward pass + x, stconv1 = gcn.conv1(g, x, ps.conv1, st.conv1) + x = relu.(x) + x, stdrop = gcn.drop(x, ps.drop, st.drop) + x, stconv2 = gcn.conv2(g, x, ps.conv2, st.conv2) + return x, (conv1 = stconv1, drop = stdrop, conv2 = stconv2) +end; +```` + +Now let's visualize the node embeddings of our **untrained** GCN network. 
+
+````julia
+gcn = GCN(num_features, num_classes, hidden_channels, drop_rate)
+ps, st = Lux.setup(rng, gcn)
+h_untrained, st = gcn(g, x, ps, st)
+h_untrained = h_untrained |> transpose
+visualize_tsne(h_untrained, g.ndata.targets)
+````
+
+*(Figure: t-SNE projection of the untrained GCN node embeddings, colored by document class. The inline SVG output is omitted here.)*
+
+We certainly can do better by training our model.
+The training and testing procedure is once again the same, but this time we make use of the node features `x` **and** the graph `g` as input to our GCN model.
+
+````julia
+function loss(gcn, ps, st, tuple)
+    g, x, y = tuple
+    logitcrossentropy = CrossEntropyLoss(; logits=Val(true))
+    ŷ, st = gcn(g, x, ps, st)
+    return logitcrossentropy(ŷ[:, train_mask], y[:, train_mask]), (st), 0
+end
+
+function train_model!(gcn, ps, st, g, x, y)
+    train_state = Lux.Training.TrainState(gcn, ps, st, Adam(1e-2))
+    for iter in 1:2000
+        _, loss_value, _, train_state = Lux.Training.single_train_step!(AutoZygote(), loss, (g, x, y), train_state)
+
+        if iter % 100 == 0
+            println("Epoch: $(iter) Loss: $(loss_value)")
+        end
+    end
+
+    return gcn, ps, st
+end
+
+gcn, ps, st = train_model!(gcn, ps, st, g, x, y);
+````
+
+````
+┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
+└ @ LuxCore ~/.julia/packages/LuxCore/GlbG3/src/LuxCore.jl:18
+Epoch: 100 Loss: 0.019381031
+Epoch: 200 Loss: 0.017426146
+Epoch: 300 Loss: 0.006051709
+Epoch: 400 Loss: 0.0015434261
+Epoch: 500 Loss: 0.0052008606
+Epoch: 600 Loss: 0.025294377
+Epoch: 700 Loss: 0.0012917791
+Epoch: 800 Loss: 0.005089373
+Epoch: 900 Loss: 0.00912053
+Epoch: 1000 Loss: 0.002442247
+Epoch: 1100 Loss: 0.00024606875
+Epoch: 1200 Loss: 0.00046606906
+Epoch: 1300 Loss: 0.002437515
+Epoch: 1400 Loss: 0.00019191795
+Epoch: 1500 Loss: 0.0056298207
+Epoch: 1600 Loss: 0.00020503976
+Epoch: 1700 Loss: 0.0028860446
+Epoch: 1800 Loss: 0.02319943
+Epoch: 1900 Loss: 0.00030635786
+Epoch: 2000 Loss: 0.00013437525
+
+````
+
+Now let's evaluate the accuracy of our trained GCN.
+
+````julia
+function accuracy(model, g, x, ps, st, y, mask)
+    st = Lux.testmode(st)
+    ŷ, st = model(g, x, ps, st)
+    mean(onecold(ŷ)[mask] .== onecold(y)[mask])
+end
+
+train_accuracy = accuracy(gcn, g, g.ndata.features, ps, st, y, train_mask)
+test_accuracy = accuracy(gcn, g, g.ndata.features, ps, st, y, .!train_mask)
+
+println("Train accuracy: $(train_accuracy)")
+println("Test accuracy: $(test_accuracy)")
+````
+
+````
+Train accuracy: 1.0
+Test accuracy: 0.7636292834890965
+
+````
+
+**There it is!**
+By simply swapping the linear layers with GNN layers, we can reach **76% test accuracy**!
+This is in stark contrast to the roughly 50% test accuracy obtained by our MLP, indicating that relational information plays a crucial role in obtaining better performance.
+
+We can verify this once again by looking at the output embeddings of our trained model, which now produce a far better clustering of nodes of the same category.
+
+````julia
+st = Lux.testmode(st) # inference mode
+
+out_trained, st = gcn(g, x, ps, st)
+out_trained = out_trained |> transpose
+visualize_tsne(out_trained, g.ndata.targets)
+````
+
+*(Figure: t-SNE projection of the trained GCN node embeddings, colored by document class. The inline SVG output is omitted here.)*
+
+## (Optional) Exercises
+
+1. To achieve better model performance and to avoid overfitting, it is usually a good idea to select the best model based on an additional validation set. The `Cora` dataset provides a validation node set as `g.ndata.val_mask`, but we haven't used it yet. Can you modify the code to select and test the model with the highest validation performance? This should bring test performance to **> 80% accuracy**.
+
+2. How does `GCN` behave when increasing the hidden feature dimensionality or the number of layers? Does increasing the number of layers help at all?
+
+3. You can try to use different GNN layers to see how model performance changes. What happens if you swap out all `GCNConv` instances with [`GATConv`](https://juliagraphs.org/GraphNeuralNetworks.jl/docs/GNNLux.jl/stable/api/conv/#GNNLux.GATConv) layers that make use of attention? Try to write a 2-layer `GAT` model that makes use of 8 attention heads in the first layer and 1 attention head in the second layer, uses a `dropout` ratio of `0.6` inside and outside each `GATConv` call, and uses a `hidden_channels` dimension of `8` per head.
+
+## Conclusion
+In this tutorial, we have seen how to apply GNNs to real-world problems, and, in particular, how they can effectively be used for boosting a model's performance. In the next tutorial, we will look into how GNNs can be used for the task of graph classification.
+
+---
+
+*This page was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).*
+
diff --git a/GNNLux/docs/src_tutorials/node_classification.jl b/GNNLux/docs/src_tutorials/node_classification.jl
new file mode 100644
index 000000000..3ba82fb93
--- /dev/null
+++ b/GNNLux/docs/src_tutorials/node_classification.jl
@@ -0,0 +1,255 @@
+# # Node Classification with Graph Neural Networks
+
+# In this tutorial, we will be learning how to use Graph Neural Networks (GNNs) for node classification. Given the ground-truth labels of only a small subset of nodes, we want to infer the labels for all the remaining nodes (transductive learning).
+
+
+# ## Import
+# Let us start off by importing some libraries. We will be using `Lux.jl` and `GNNLux.jl` for our tutorial.
+
+using Lux, GNNLux
+using MLDatasets
+using Plots, TSne
+using Random, Statistics
+using Zygote, Optimisers, OneHotArrays, ConcreteStructs
+
+
+ENV["DATADEPS_ALWAYS_ACCEPT"] = "true" # don't ask for dataset download confirmation
+rng = Random.seed!(17); # for reproducibility
+
+# ## Visualize
+# We want to visualize our results using t-distributed stochastic neighbor embedding (tsne) to project our output onto a 2D plane.
+
+function visualize_tsne(out, targets)
+    z = tsne(out, 2)
+    scatter(z[:, 1], z[:, 2], color = Int.(targets[1:size(z, 1)]), leg = false)
+end;
+
+
+# ## Dataset: Cora
+
+# For our tutorial, we will be using the `Cora` dataset. `Cora` is a citation network of 2708 documents categorized into seven classes with 5,429 citation links. Each node represents an article or document, and edges between nodes indicate a citation relationship, where one cites the other.
+
+# Each publication in the dataset is described by a 0/1-valued word vector indicating the absence/presence of the corresponding word from the dictionary.
The dictionary consists of 1433 unique words. + +# This dataset was first introduced by [Yang et al. (2016)](https://arxiv.org/abs/1603.08861) as one of the datasets of the `Planetoid` benchmark suite. We will be using [MLDatasets.jl](https://juliaml.github.io/MLDatasets.jl/stable/) for an easy access to this dataset. + +dataset = Cora() + +# Datasets in MLDatasets.jl have `metadata` containing information about the dataset itself. + +dataset.metadata + +# The `graphs` variable contains the graph. The `Cora` dataset contains only 1 graph. + +dataset.graphs + + +# There is only one graph of the dataset. The `node_data` contains `features` indicating if certain words are present or not and `targets` indicating the class for each document. We convert the single-graph dataset to a `GNNGraph`. + +g = mldataset2gnngraph(dataset) + + +println("Number of nodes: $(g.num_nodes)") +println("Number of edges: $(g.num_edges)") +println("Average node degree: $(g.num_edges / g.num_nodes)") +println("Number of training nodes: $(sum(g.ndata.train_mask))") +println("Training node label rate: $(mean(g.ndata.train_mask))") +println("Has isolated nodes: $(has_isolated_nodes(g))") +println("Has self-loops: $(has_self_loops(g))") +println("Is undirected: $(is_bidirected(g))") + +# Overall, this dataset is quite similar to the previously used [`KarateClub`](https://juliaml.github.io/MLDatasets.jl/stable/datasets/graphs/#MLDatasets.KarateClub) network. +# We can see that the `Cora` network holds 2,708 nodes and 10,556 edges, resulting in an average node degree of 3.9. +# For training this dataset, we are given the ground-truth categories of 140 nodes (20 for each class). +# This results in a training node label rate of only 5%. + +# We can further see that this network is undirected, and that there exists no isolated nodes (each document has at least one citation). + +x = g.ndata.features # we onehot encode the node labels (what we want to predict): +y = onehotbatch(g.ndata.targets, 1:7) +train_mask = g.ndata.train_mask; +num_features = size(x)[1]; +hidden_channels = 16; +drop_rate = 0.5; +num_classes = dataset.metadata["num_classes"]; + + +# ## Multi-layer Perception Network (MLP) + +# In theory, we should be able to infer the category of a document solely based on its content, *i.e.* its bag-of-words feature representation, without taking any relational information into account. + +# Let's verify that by constructing a simple MLP that solely operates on input node features (using shared weights across all nodes): + +MLP = Chain(Dense(num_features => hidden_channels, relu), + Dropout(drop_rate), + Dense(hidden_channels => num_classes)) + +ps, st = Lux.setup(rng, MLP); + +# ### Training a Multilayer Perceptron + +# Our MLP is defined by two linear layers and enhanced by [ReLU](https://lux.csail.mit.edu/stable/api/NN_Primitives/ActivationFunctions#NNlib.relu) non-linearity and [Dropout](https://lux.csail.mit.edu/stable/api/Lux/layers#Lux.Dropout). +# Here, we first reduce the 1433-dimensional feature vector to a low-dimensional embedding (`hidden_channels=16`), while the second linear layer acts as a classifier that should map each low-dimensional node embedding to one of the 7 classes. + +# Let's train our simple MLP by following a similar procedure as described in [the first part of this tutorial](https://juliagraphs.org/GraphNeuralNetworks.jl/docs/GNNLux.jl/stable/tutorials/gnn_intro/). +# We again make use of the **cross entropy loss** and **Adam optimizer**. 
+# This time, we also define an **`accuracy` function** to evaluate how well our final model performs on the test node set (whose labels have not been observed during training).
+
+
+function loss(model, ps, st, x)
+    logitcrossentropy = CrossEntropyLoss(; logits=Val(true))
+    ŷ, st = model(x, ps, st)
+    return logitcrossentropy(ŷ[:, train_mask], y[:, train_mask]), (st), 0
+end
+
+function train_model!(MLP, ps, st, x, epochs)
+    train_state = Lux.Training.TrainState(MLP, ps, st, Adam(1e-3))
+    for iter in 1:epochs
+        _, loss_value, _, train_state = Lux.Training.single_train_step!(AutoZygote(), loss, x, train_state)
+
+        if iter % 100 == 0
+            println("Epoch: $(iter) Loss: $(loss_value)")
+        end
+    end
+end
+
+function accuracy(model, x, ps, st, y, mask)
+    st = Lux.testmode(st)
+    ŷ, st = model(x, ps, st)
+    mean(onecold(ŷ)[mask] .== onecold(y)[mask])
+end
+
+train_model!(MLP, ps, st, x, 2000)
+
+# After training the model, we can call the `accuracy` function to see how well our model performs on unseen labels.
+# Here, we are interested in the accuracy of the model, *i.e.*, the ratio of correctly classified nodes:
+
+accuracy(MLP, x, ps, st, y, .!train_mask)
+
+# As one can see, our MLP performs rather badly, with only about 50% test accuracy.
+# But why doesn't the MLP perform better?
+# The main reason is that this model suffers from heavy overfitting, since it only has access to a **small number of training nodes**, and therefore it generalizes poorly to unseen node representations.
+
+# It also fails to incorporate an important bias into the model: **Cited papers are very likely related to the category of a document**.
+# That is exactly where Graph Neural Networks come into play and can help to boost the performance of our model.
+
+
+# ## Training a Graph Convolutional Neural Network (GNN)
+
+# Following-up on the first part of this tutorial, we replace the `Dense` linear layers by the [`GCNConv`](https://juliagraphs.org/GraphNeuralNetworks.jl/docs/GNNLux.jl/stable/api/conv/#GNNLux.GCNConv) module.
+# To recap, the **GCN layer** ([Kipf et al. (2017)](https://arxiv.org/abs/1609.02907)) is defined as
+
+# ```math
+# \mathbf{x}_v^{(\ell + 1)} = \mathbf{W}^{(\ell + 1)} \sum_{w \in \mathcal{N}(v) \, \cup \, \{ v \}} \frac{1}{c_{w,v}} \cdot \mathbf{x}_w^{(\ell)}
+# ```
+
+# where $\mathbf{W}^{(\ell + 1)}$ denotes a trainable weight matrix of shape `[num_output_features, num_input_features]` and $c_{w,v}$ refers to a fixed normalization coefficient for each edge.
+# In contrast, a single `Linear` layer is defined as
+
+# ```math
+# \mathbf{x}_v^{(\ell + 1)} = \mathbf{W}^{(\ell + 1)} \mathbf{x}_v^{(\ell)}
+# ```
+
+# which does not make use of neighboring node information.
+
+@concrete struct GCN <: GNNContainerLayer{(:conv1, :drop, :conv2)}
+    nf::Int
+    nc::Int
+    hd::Int
+    conv1
+    conv2
+    drop
+    use_bias::Bool
+    init_weight
+    init_bias
+end;
+
+function GCN(num_features, num_classes, hidden_channels, drop_rate; use_bias = true, init_weight = glorot_uniform, init_bias = zeros32) # constructor
+    conv1 = GCNConv(num_features => hidden_channels)
+    conv2 = GCNConv(hidden_channels => num_classes)
+    drop = Dropout(drop_rate)
+    return GCN(num_features, num_classes, hidden_channels, conv1, conv2, drop, use_bias, init_weight, init_bias)
+end;
+
+function (gcn::GCN)(g::GNNGraph, x, ps, st) # forward pass
+    x, stconv1 = gcn.conv1(g, x, ps.conv1, st.conv1)
+    x = relu.(x)
+    x, stdrop = gcn.drop(x, ps.drop, st.drop)
+    x, stconv2 = gcn.conv2(g, x, ps.conv2, st.conv2)
+    return x, (conv1 = stconv1, drop = stdrop, conv2 = stconv2)
+end;
+
+# Now let's visualize the node embeddings of our **untrained** GCN network.
+
+gcn = GCN(num_features, num_classes, hidden_channels, drop_rate)
+ps, st = Lux.setup(rng, gcn)
+h_untrained, st = gcn(g, x, ps, st)
+h_untrained = h_untrained |> transpose
+visualize_tsne(h_untrained, g.ndata.targets)
+
+
+# We certainly can do better by training our model.
+# The training and testing procedure is once again the same, but this time we make use of the node features `x` **and** the graph `g` as input to our GCN model.
+
+
+
+function loss(gcn, ps, st, tuple)
+    g, x, y = tuple
+    logitcrossentropy = CrossEntropyLoss(; logits=Val(true))
+    ŷ, st = gcn(g, x, ps, st)
+    return logitcrossentropy(ŷ[:, train_mask], y[:, train_mask]), (st), 0
+end
+
+function train_model!(gcn, ps, st, g, x, y)
+    train_state = Lux.Training.TrainState(gcn, ps, st, Adam(1e-2))
+    for iter in 1:2000
+        _, loss_value, _, train_state = Lux.Training.single_train_step!(AutoZygote(), loss, (g, x, y), train_state)
+
+        if iter % 100 == 0
+            println("Epoch: $(iter) Loss: $(loss_value)")
+        end
+    end
+
+    return gcn, ps, st
+end
+
+gcn, ps, st = train_model!(gcn, ps, st, g, x, y);
+
+# Now let's evaluate the accuracy of our trained GCN.
+
+function accuracy(model, g, x, ps, st, y, mask)
+    st = Lux.testmode(st)
+    ŷ, st = model(g, x, ps, st)
+    mean(onecold(ŷ)[mask] .== onecold(y)[mask])
+end
+
+train_accuracy = accuracy(gcn, g, g.ndata.features, ps, st, y, train_mask)
+test_accuracy = accuracy(gcn, g, g.ndata.features, ps, st, y, .!train_mask)
+
+println("Train accuracy: $(train_accuracy)")
+println("Test accuracy: $(test_accuracy)")
+# **There it is!**
+# By simply swapping the linear layers with GNN layers, we can reach **76% test accuracy**!
+# This is in stark contrast to the roughly 50% test accuracy obtained by our MLP, indicating that relational information plays a crucial role in obtaining better performance.
+
+# We can verify this once again by looking at the output embeddings of our trained model, which now produce a far better clustering of nodes of the same category.
+
+
+
+st = Lux.testmode(st) # inference mode
+
+out_trained, st = gcn(g, x, ps, st)
+out_trained = out_trained |> transpose
+visualize_tsne(out_trained, g.ndata.targets)
+
+# ## (Optional) Exercises
+
+# 1. To achieve better model performance and to avoid overfitting, it is usually a good idea to select the best model based on an additional validation set. The `Cora` dataset provides a validation node set as `g.ndata.val_mask`, but we haven't used it yet. Can you modify the code to select and test the model with the highest validation performance? This should bring test performance to **> 80% accuracy**.
+
+# 2. How does `GCN` behave when increasing the hidden feature dimensionality or the number of layers? Does increasing the number of layers help at all?
+
+# 3. You can try to use different GNN layers to see how model performance changes. What happens if you swap out all `GCNConv` instances with [`GATConv`](https://juliagraphs.org/GraphNeuralNetworks.jl/docs/GNNLux.jl/stable/api/conv/#GNNLux.GATConv) layers that make use of attention? Try to write a 2-layer `GAT` model that makes use of 8 attention heads in the first layer and 1 attention head in the second layer, uses a `dropout` ratio of `0.6` inside and outside each `GATConv` call, and uses a `hidden_channels` dimension of `8` per head.
+
+
+# ## Conclusion
+# In this tutorial, we have seen how to apply GNNs to real-world problems, and, in particular, how they can effectively be used for boosting a model's performance. In the next tutorial, we will look into how GNNs can be used for the task of graph classification.
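+
+# ## Appendix: A sketch for Exercise 1
+
+# Below is a minimal sketch of validation-based model selection (Exercise 1). It is not executed
+# as part of this tutorial and is only one possible approach: it reuses the GCN `loss` and
+# `accuracy` functions defined above, assumes the usual `g.ndata.val_mask` / `g.ndata.test_mask`
+# splits shipped with `Cora`, and keeps the parameters and states that score best on the
+# validation nodes. The helper name `train_with_validation!` is purely illustrative, and if the
+# `TrainState` field names differ in your Lux version, adjust the accesses accordingly.
+
+# ```julia
+# function train_with_validation!(gcn, ps, st, g, x, y; epochs = 2000)
+#     train_state = Lux.Training.TrainState(gcn, ps, st, Adam(1e-2))
+#     best_val, best_ps, best_st = 0.0, ps, st
+#     for iter in 1:epochs
+#         _, _, _, train_state = Lux.Training.single_train_step!(AutoZygote(), loss, (g, x, y), train_state)
+#         val_acc = accuracy(gcn, g, x, train_state.parameters, train_state.states, y, g.ndata.val_mask)
+#         if val_acc > best_val   # keep the best model seen so far on the validation split
+#             best_val, best_ps, best_st = val_acc, train_state.parameters, train_state.states
+#         end
+#     end
+#     return best_ps, best_st, best_val
+# end
+#
+# best_ps, best_st, best_val = train_with_validation!(gcn, ps, st, g, x, y)
+# accuracy(gcn, g, x, best_ps, best_st, y, g.ndata.test_mask)   # evaluate only the selected model
+# ```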