This repository contains the official code of the paper Understanding and Extending Subgraph GNNs by Rethinking Their Symmetries (NeurIPS 2022 Oral).
The master branch contains the most recent version of the code, using pyg=2.2.0 and pytorch=1.13.1.
For the version submitted at NeurIPS 2022, check out the neurips22 tag.
The code builds on top of the ESAN framework.
First create a conda environment
conda env create -f environment.yml
and activate it
conda activate subgraph
Then, set up wandb.
We provide the hyperparameter configurations to obtain the reported results on ogbg-molhiv (Table 2) and ZINC (Table 1).
Prepare the data
python data.py --dataset ZINC --policies ego_nets ego_nets_plus
python data.py --dataset ogbg-molhiv --policies ego_nets_plus
Obtain a sweep id <sweep-id> by running
wandb sweep configs/deterministic/<config-name>
where configs/deterministic/<config-name> is either configs/deterministic/SUN-ogbg-molhiv.yaml or configs/deterministic/SUN-ZINC.yaml.
Run the 10 seeds with
wandb agent <sweep-id>
and compute mean and std of Metric/test_mean over the runs in the sweep to obtain SUN results in Tables 1, 2.
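The aggregation over seeds can be sketched as follows. This is a minimal illustration, not part of the repository: the helper name `summarize_runs` is hypothetical, the input is assumed to be one Metric/test_mean value per run (for example copied from the wandb project page), and the use of the sample standard deviation is an assumption, since the text above does not say which estimator is used.

```python
from statistics import mean, stdev


def summarize_runs(test_means):
    """Return (mean, std) for a list holding one metric value per run.

    Assumption: the sample standard deviation (n - 1 denominator) is used;
    swap in statistics.pstdev if the population estimator is wanted.
    """
    values = [float(v) for v in test_means]
    if len(values) < 2:
        raise ValueError("need at least two runs to compute a standard deviation")
    return mean(values), stdev(values)


# Placeholder numbers for illustration only; they are NOT results from the paper.
example_values = [0.10, 0.12, 0.11, 0.09, 0.13]
avg, std = summarize_runs(example_values)
print(f"Metric/test_mean: {avg:.4f} +/- {std:.4f}")
```

With the 10 seeds of a sweep, `example_values` would hold the ten Metric/test_mean values of its runs.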
First, prepare the data. Run
python data.py --dataset $DATASET --policies $POLICY
where $DATASET is one of the following:
- TUDatasets (MUTAG, PTC, PROTEINS, NCI1, NCI109, IMDB-BINARY, IMDB-MULTI) - Table 4
- graphproperty - Table 5
- subgraphcount (aka counting substructures) - Table 1
and $POLICY is one of the following:
- ego_nets
- ego_nets_plus
- node_marked
- null
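To prepare several dataset and policy combinations at once, the command above can be generated programmatically. The sketch below only builds and prints the command lines; it does not run them. The helper `build_data_command` is hypothetical, and the dataset and policy spellings are copied from the lists above on the assumption that data.py accepts them exactly as written.

```python
import shlex

# Names copied from the lists above (assumed to match what data.py expects).
DATASETS = [
    "MUTAG", "PTC", "PROTEINS", "NCI1", "NCI109", "IMDB-BINARY", "IMDB-MULTI",
    "graphproperty", "subgraphcount",
]
POLICIES = ["ego_nets", "ego_nets_plus", "node_marked", "null"]


def build_data_command(dataset, policies):
    """Build the argv list for `python data.py` with one or more policies."""
    if dataset not in DATASETS:
        raise ValueError(f"unknown dataset: {dataset}")
    unknown = [p for p in policies if p not in POLICIES]
    if unknown:
        raise ValueError(f"unknown policies: {unknown}")
    return ["python", "data.py", "--dataset", dataset, "--policies", *policies]


# Print one command per policy for a single dataset.
for policy in POLICIES:
    print(shlex.join(build_data_command("MUTAG", [policy])))
```

Each printed line can be pasted into a shell, or the argv list can be passed to `subprocess.run` from the repository root.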
To perform hyperparameter tuning, make use of wandb:
- In the configs/deterministic folder, choose the yaml file corresponding to the dataset of interest, say <config-name>. This file contains the hyperparameter grid.
- Run
  wandb sweep configs/deterministic/<config-name>
  to obtain a sweep id <sweep-id>.
- Run the hyperparameter tuning with
  wandb agent <sweep-id>
  You can run the above command multiple times on each machine you would like to contribute to the grid search.
- Open your project in your wandb account in the browser to see the results:
  - For the TUDatasets, refer to Metric/valid_mean and Metric/valid_std to obtain the results.
  - For graphproperty and subgraphcount, compute mean and std of Metric/train_mean, Metric/valid_mean, Metric/test_mean by grouping over all hyperparameters and averaging over the different seeds. Then, take the results corresponding to the configuration obtaining the best validation metric.
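The selection procedure for graphproperty and subgraphcount (group runs by hyperparameters, average over seeds, pick the configuration with the best validation metric) can be sketched as below. This is an illustration under stated assumptions, not the repository's code: `select_best_config` is a hypothetical helper, each run is assumed to be a dict holding its hyperparameters (without the seed) under "config", hyperparameter values are assumed hashable, and whether lower or higher is better depends on the task, so it is left as a parameter.

```python
from collections import defaultdict
from statistics import mean, stdev


def select_best_config(runs, higher_is_better=False):
    """Group runs by hyperparameters, average over seeds, pick the best by validation."""
    grouped = defaultdict(list)
    for run in runs:
        # Runs sharing the same hyperparameters (seed excluded) fall in one group.
        key = tuple(sorted(run["config"].items()))
        grouped[key].append(run)

    summaries = []
    for key, group in grouped.items():
        summary = {"config": dict(key), "num_seeds": len(group)}
        for split in ("train", "valid", "test"):
            values = [r[f"Metric/{split}_mean"] for r in group]
            summary[f"{split}_mean"] = mean(values)
            # Sample std is an assumption; it is undefined for a single seed.
            summary[f"{split}_std"] = stdev(values) if len(values) > 1 else 0.0
        summaries.append(summary)

    pick = max if higher_is_better else min
    return pick(summaries, key=lambda s: s["valid_mean"])


# Placeholder runs for illustration only; the numbers are made up.
runs = [
    {"config": {"num_layer": 5, "emb_dim": 64}, "Metric/train_mean": 0.10, "Metric/valid_mean": 0.20, "Metric/test_mean": 0.22},
    {"config": {"num_layer": 5, "emb_dim": 64}, "Metric/train_mean": 0.12, "Metric/valid_mean": 0.22, "Metric/test_mean": 0.24},
    {"config": {"num_layer": 6, "emb_dim": 64}, "Metric/train_mean": 0.08, "Metric/valid_mean": 0.30, "Metric/test_mean": 0.31},
    {"config": {"num_layer": 6, "emb_dim": 64}, "Metric/train_mean": 0.09, "Metric/valid_mean": 0.28, "Metric/test_mean": 0.33},
]
best = select_best_config(runs)  # assumes an error-style metric, lower is better
print(best["config"], best["test_mean"], best["test_std"])
```

The reported numbers are then the train, validation and test means and stds stored in `best`.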
Note that in configs/deterministic/SUN-subgraphcount.yaml, the key task_idx indicates the target: 0, 1, 2, 3 correspond to Triangle, Tailed Tri., Star and 4-Cycle, respectively.
Similarly, in configs/deterministic/SUN-graphproperty.yaml, task_idx values 0, 1, 2 correspond to IsConnected, Diameter and Radius, respectively.
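When post-processing sweep results, it can help to translate task_idx back into a target name. The lookup below simply restates the mapping described above; the names `TASK_NAMES` and `task_name` are hypothetical and do not come from the repository.

```python
# task_idx -> target name, as described in the text above.
TASK_NAMES = {
    "subgraphcount": ["Triangle", "Tailed Tri.", "Star", "4-Cycle"],
    "graphproperty": ["IsConnected", "Diameter", "Radius"],
}


def task_name(dataset, task_idx):
    """Return the target name for a dataset and task_idx, validating both."""
    names = TASK_NAMES[dataset]
    if not 0 <= task_idx < len(names):
        raise IndexError(f"task_idx {task_idx} out of range for {dataset}")
    return names[task_idx]


print(task_name("subgraphcount", 3))
```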
Values for GIN and GNN-AK models are obtained with the GNN-AK code; DSS-GNN, DS-GNN and NGNN values can be obtained by running the code in this repo with the appropriate model.
We report results for these methods in the out/ folder.
SUN curves can be obtained as detailed below.
Prepare the data
python data.py --dataset subgraphcount --policies ego_nets
Run
for i in {1..10}; do python plot.py --batch_size=128 --channels=96 --dataset=subgraphcount --drop_ratio=0 --emb_dim=110 --epochs=250 --gnn_type=originalgin --jk=concat --learning_rate=0.001 --model=sun --num_layer=5 --policy=ego_nets --task_idx=3 --seed="$i"; done
Then, plot the curve in ego_nets-plot.pdf by running
python make_plot.py --policy ego_nets
Prepare the data
python data.py --dataset subgraphcount --policies ego_nets_plus
Run
for i in {1..10}; do python plot.py --batch_size=128 --channels=96 --dataset=subgraphcount --drop_ratio=0 --emb_dim=96 --epochs=250 --gnn_type=originalgin --jk=concat --learning_rate=0.001 --model=sun --num_layer=6 --policy=ego_nets_plus --task_idx=3 --seed="$i"; done
Then, plot the curve in ego_nets_plus-plot.pdf by running
python make_plot.py --policy ego_nets_plus
Prepare the data
python data.py --dataset ZINC --policies ego_nets
Run
for i in {1..10}; do python plot.py --batch_size=128 --channels=96 --dataset=ZINC --drop_ratio=0 --emb_dim=64 --epochs=400 --gnn_type=zincgin --learning_rate=0.001 --model=sun --num_layer=6 --patience=40 --policy=ego_nets --num_hops=3 --seed="$i"; done
Then, plot the curve in ego_nets-ZINC-plot.pdf by running
python make_plot_zinc.py
For attribution in academic contexts, please cite
@inproceedings{frasca2022understanding,
title={Understanding and Extending Subgraph GNNs by Rethinking Their Symmetries},
author={Frasca, Fabrizio and Bevilacqua, Beatrice and Bronstein, Michael M and Maron, Haggai},
booktitle={Advances in Neural Information Processing Systems},
year={2022},
}
