Documentation | Paper | External Resources | Datasets
PyTorch Geometric Temporal is a temporal (dynamic) extension library for PyTorch Geometric.
The library consists of various dynamic and temporal geometric deep learning, embedding, and spatio-temporal regression methods from a variety of published research papers. In addition, it consists of an easy-to-use dataset loader and iterator for dynamic and temporal graphs, gpu-support. It also comes with a number of benchmark datasets with temporal and dynamic graphs (you can also create your own datasets).
A simple example
PyTorch Geometric Temporal makes implementing Dynamic and Temporal Graph Neural Networks quite easy -- see the accompanying tutorial. For example, this is all it takes to implement a recurrent graph convolutional network with two consecutive graph convolutional GRU cells and a linear layer:
import torch
import torch.nn.functional as F
from torch_geometric_temporal.nn.recurrent import GConvGRU
class RecurrentGCN(torch.nn.Module):
def __init__(self, node_features, num_classes):
super(RecurrentGCN, self).__init__()
self.recurrent_1 = GConvGRU(node_features, 32, 5)
self.recurrent_2 = GConvGRU(32, 16, 5)
self.linear = torch.nn.Linear(16, num_classes)
def forward(self, x, edge_index, edge_weight):
x = self.recurrent_1(x, edge_index, edge_weight)
x = F.relu(x)
x = F.dropout(x, training=self.training)
x = self.recurrent_2(x, edge_index, edge_weight)
x = F.relu(x)
x = F.dropout(x, training=self.training)
x = self.linear(x)
return F.log_softmax(x, dim=1)
Methods Included
In detail, the following temporal graph neural networks were implemented.
Discrete Recurrent Graph Convolutions
-
GConvGRU from Seo et al.: Structured Sequence Modeling with Graph Convolutional Recurrent Networks (ICONIP 2018)
-
GConvLSTM from Seo et al.: Structured Sequence Modeling with Graph Convolutional Recurrent Networks (ICONIP 2018)
-
GC-LSTM from Chen et al.: GC-LSTM: Graph Convolution Embedded LSTM for Dynamic Link Prediction (CoRR 2018)
-
LRGCN from Li et al.: Predicting Path Failure In Time-Evolving Graphs (KDD 2019)
-
DyGrEncoder from Taheri et al.: Learning to Represent the Evolution of Dynamic Graphs with Recurrent Models (WWW 2019)
-
EvolveGCNH from Pareja et al.: EvolveGCN: Evolving Graph Convolutional Networks for Dynamic Graphs (AAAI 2020)
-
EvolveGCNO from Pareja et al.: EvolveGCN: Evolving Graph Convolutional Networks for Dynamic Graphs (AAAI 2020)
Head over to our documentation to find out more about installation, creation of datasets and a full list of implemented methods and available datasets.
For a quick start, check out the examples in the examples/
directory.
If you notice anything unexpected, please open an issue. If you are missing a specific method, feel free to open a feature request.
Citing
If you find PyTorch Geometric Temporal and the new datasets useful in your research, please consider adding the folowing citation:
@misc{rozemberczki_temporal,
author = {Benedek, Rozemberczki},
title = {{PyTorch Geometric Temporal}},
year = {2020},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/benedekrozemberczki/pytorch_geometric_temporal}},
}
Installation
PyTorch 1.6.0
To install the binaries for PyTorch 1.6.0, simply run
$ pip install torch-scatter==latest+${CUDA} -f https://pytorch-geometric.com/whl/torch-1.6.0.html
$ pip install torch-sparse==latest+${CUDA} -f https://pytorch-geometric.com/whl/torch-1.6.0.html
$ pip install torch-cluster==latest+${CUDA} -f https://pytorch-geometric.com/whl/torch-1.6.0.html
$ pip install torch-spline-conv==latest+${CUDA} -f https://pytorch-geometric.com/whl/torch-1.6.0.html
$ pip install torch-geometric
$ pip install torch-geometric-temporal
where ${CUDA}
should be replaced by either cpu
, cu92
, cu101
or cu102
depending on your PyTorch installation.
cpu |
cu92 |
cu101 |
cu102 |
|
---|---|---|---|---|
Linux | ✅ | ✅ | ✅ | ✅ |
Windows | ✅ | ❌ | ✅ | ✅ |
macOS | ✅ |
PyTorch 1.5.0/1.5.1
To install the binaries for PyTorch 1.5.0/1.5.1, simply run
$ pip install torch-scatter==latest+${CUDA} -f https://pytorch-geometric.com/whl/torch-1.5.0.html
$ pip install torch-sparse==latest+${CUDA} -f https://pytorch-geometric.com/whl/torch-1.5.0.html
$ pip install torch-cluster==latest+${CUDA} -f https://pytorch-geometric.com/whl/torch-1.5.0.html
$ pip install torch-spline-conv==latest+${CUDA} -f https://pytorch-geometric.com/whl/torch-1.5.0.html
$ pip install torch-geometric
$ pip install torch-geometric-temporal
where ${CUDA}
should be replaced by either cpu
, cu92
, cu101
or cu102
depending on your PyTorch installation.
cpu |
cu92 |
cu101 |
cu102 |
|
---|---|---|---|---|
Linux | ✅ | ✅ | ✅ | ✅ |
Windows | ✅ | ❌ | ✅ | ✅ |
macOS | ✅ |
PyTorch 1.4.0
To install the binaries for PyTorch 1.4.0, simply run
$ pip install torch-scatter==latest+${CUDA} -f https://pytorch-geometric.com/whl/torch-1.4.0.html
$ pip install torch-sparse==latest+${CUDA} -f https://pytorch-geometric.com/whl/torch-1.4.0.html
$ pip install torch-cluster==latest+${CUDA} -f https://pytorch-geometric.com/whl/torch-1.4.0.html
$ pip install torch-spline-conv==latest+${CUDA} -f https://pytorch-geometric.com/whl/torch-1.4.0.html
$ pip install torch-geometric
$ pip install torch-geometric-temporal
where ${CUDA}
should be replaced by either cpu
, cu92
, cu100
or cu101
depending on your PyTorch installation.
cpu |
cu92 |
cu100 |
cu101 |
|
---|---|---|---|---|
Linux | ✅ | ✅ | ✅ | ✅ |
Windows | ✅ | ❌ | ❌ | ✅ |
macOS | ✅ |
Running tests
$ python setup.py test
Running examples
$ cd examples
$ python gconvgru_example.py
License