Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CoGNN #9693

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions examples/cognn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import os.path as osp

import torch
from torch.nn import CrossEntropyLoss, ModuleList

from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCN
from torch_geometric.nn.conv import GCNConv
from torch_geometric.nn.models import CoGNN

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Cora')
dataset = Planetoid(root=path, name='cora')
data = dataset[0].to(device)
loss_fn = CrossEntropyLoss()
num_features = data.x.shape[1]
num_classes = data.y.max() + 1

# Use all message passing edges as training labels:
env_layer1 = GCNConv(
in_channels=num_features,
out_channels=num_features,
)

env_layer2 = GCNConv(
in_channels=num_features,
out_channels=num_classes,
)

action_net = GCN(
in_channels=num_features,
out_channels=4,
hidden_channels=num_features,
num_layers=2,
)

model = CoGNN(
env_net=ModuleList([env_layer1, env_layer2]),
action_net=action_net,
env_activation='relu',
temp=0.1,
dropout=0.1,
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)


def train() -> float:
model.train()
optimizer.zero_grad()
scores = model(data.x, data.edge_index)
loss = loss_fn(scores, data.y)
loss.backward()
optimizer.step()
return loss.item()


@torch.no_grad()
def test() -> float:
model.eval()
scores = model(data.x, data.edge_index)
predictions = scores.argmax(dim=1)
return (predictions == data.y).float().mean().item()


for epoch in range(1, 101):
loss = train()
accuracy = test()
print(f'Epoch: {epoch: 03d}, Loss: {loss: .4f}, accuracy: {accuracy}')
67 changes: 67 additions & 0 deletions test/nn/models/test_cognn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import torch
from torch.nn import ModuleList

from torch_geometric.nn import GCN
from torch_geometric.nn.conv import GCNConv
from torch_geometric.nn.models.cognn import CoGNN


def test_cognn() -> None:
num_nodes = 10
num_features = 128
x = torch.randn(num_nodes, num_features)
edge_index = torch.tensor([
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
[1, 2, 3, 4, 5, 6, 7, 8, 9, 0],
])

env_layer1 = GCNConv(
in_channels=num_features,
out_channels=num_features,
)

env_layer2 = GCNConv(
in_channels=num_features,
out_channels=num_features,
)

env_layer3 = GCNConv(
in_channels=num_features,
out_channels=1,
)

action_net = GCN(
in_channels=num_features,
out_channels=4,
hidden_channels=num_features,
num_layers=2,
)

model = CoGNN(
env_net=ModuleList([env_layer1, env_layer2, env_layer3]),
action_net=action_net,
env_activation='relu',
temp=0.1,
dropout=0.0,
)
assert str(model) == '''CoGNN(
(env_net): ModuleList(
(0-1): 2 x GCNConv(128, 128)
(2): GCNConv(128, 1)
)
(action_net): GCN(128, 4, num_layers=2)
(activation): ReLU()
(dropout): Dropout(p=0.0, inplace=False)
)'''

# Train:
model.train()
out = model(x, edge_index)
assert out.size() == (num_nodes, 1)
assert out.min().item() >= -0.1 and out.max().item() < 1.3

# Test:
model.eval()
out = model(x, edge_index)
assert out.size() == (num_nodes, 1)
assert out.min().item() >= -0.1 and out.max().item() < 1.3
1 change: 1 addition & 0 deletions torch_geometric/nn/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from .neural_fingerprint import NeuralFingerprint
from .visnet import ViSNet
from .g_retriever import GRetriever
from .cognn import CoGNN

# Deprecated:
from torch_geometric.explain.algorithm.captum import (to_captum_input,
Expand Down
73 changes: 73 additions & 0 deletions torch_geometric/nn/models/cognn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from typing import Any, Dict, List, Optional

import torch.nn.functional as F
from torch import Tensor
from torch.nn import Dropout, Module

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.resolver import activation_resolver
from torch_geometric.typing import Adj


class CoGNN(Module):
r"""The CoGNN model from the `"Cooperative Graph Neural Netowrks"
<https://arxiv.org/abs/2310.01267>`_ paper.

Args:
env_net (List[MessagePassing]): A list of MessagePassing modules, which compose the environment network.
action_net (Module): The action network.
env_activation (str or Callable, optional): The non-linear activation function to
use. (default: :obj:`"relu"`)
env_activation_kwargs (Dict[str, Any], optional): Arguments passed to the
respective activation function defined by :obj:`activation`.
(default: :obj:`None`)
temp (float, Optional): The gumbel softmax temperature. (default: :obj:`0.01`)
dropout (float, Optional): The dropout ratio. (default: :obj:`0.0`)

.. note::
The env_net is assumed to have the same input dimension and the action_net is assumed to have an output
dimension of 4, see `examples/cognn.py <https://github.com/pyg-team/
pytorch_geometric/blob/master/examples/cognn.py>`_ and `"Cooperative Graph Neural Netowrks"
<https://arxiv.org/abs/2310.01267>`_ paper.
"""
def __init__(
self,
env_net: List[MessagePassing],
action_net: Module,
env_activation: str = 'relu',
env_activation_kwargs: Optional[Dict[str, Any]] = None,
temp: Optional[float] = 0.01,
dropout: Optional[float] = 0.0,
):
super().__init__()

self.env_net = env_net
self.action_net = action_net
self.activation = activation_resolver(env_activation,
**(env_activation_kwargs or {}))
self.temp = temp
self.dropout = Dropout(p=dropout)

def forward(self, x: Tensor, edge_index: Adj) -> Tensor:
u, v = edge_index

for env_layer in self.env_net:
action_logits = self.action_net(x=x,
edge_index=edge_index) # (N, 4)

# sampling actions
incoming_edge_prob = F.gumbel_softmax(logits=action_logits[:, :2],
tau=self.temp, hard=True)
outgoing_edge_prob = F.gumbel_softmax(logits=action_logits[:, 2:],
tau=self.temp, hard=True)

# creating subgraph
keep_incoming_prob = incoming_edge_prob[:, 0]
keep_outgoing_prob = outgoing_edge_prob[:, 0]
edge_weight = keep_incoming_prob[v] * keep_outgoing_prob[u]

# message propagation
x = env_layer(x=x, edge_index=edge_index, edge_weight=edge_weight)
x = self.dropout(x)
x = self.activation(x)
return x
Loading