Add G-retriever (GNN+LLM) example (#9167)
1. #9462
2. #9480
3. #9481
4. **->** #9167

---

Repro: latest NVIDIA PyG container, plus:

`git config --global credential.helper store; huggingface-cli login; cd /opt/pyg; pip uninstall -y torch-geometric; rm -rf pytorch_geometric; git clone -b gnn-llm-model-integration https://github.com/pyg-team/pytorch_geometric.git; cd /opt/pyg/pytorch_geometric; pip install .; pip install peft datasets transformers pcst_fast sentencepiece; python3 examples/llm_plus_gnn/g_retriever.py`
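For a quicker sanity check, the script also exposes CLI flags (`--tiny_llama`, `--checkpointing`, `--epochs`, `--batch_size`, `--eval_batch_size`, `--lr`, `--gnn_hidden_channels`, `--num_gnn_layers`). As a rough sketch (flag values are illustrative; after this PR the script lives at `examples/llm/g_retriever.py`), something along these lines should work once the dependencies above are installed:

`python3 examples/llm/g_retriever.py --tiny_llama --epochs 1 --batch_size 4 --eval_batch_size 8`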

Old PR: #9154

Note: pure CPU is 220x slower than pure GPU on a single Grace Hopper (for Llama-2-7B).

Info: Gemma was also tried, but it performs worse on all train/val/test metrics. It most likely needs some tuning, which is left as future work for the community sprint that will try and tune many LLM and GNN combinations. The default therefore remains Llama 2.

The newer Gemma 2 is also much worse than Llama 2.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Akihiro Nitta <[email protected]>
Co-authored-by: rusty1s <[email protected]>
4 people authored Sep 13, 2024
1 parent bfc6d1a commit 12421c2
Showing 5 changed files with 296 additions and 8 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -8,7 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Added
 
 - Added the `WebQSPDataset` dataset ([#9481](https://github.com/pyg-team/pytorch_geometric/pull/9481))
-- Added the `GRetriever` model ([#9480](https://github.com/pyg-team/pytorch_geometric/pull/9480))
+- Added the `GRetriever` model and an example ([#9480](https://github.com/pyg-team/pytorch_geometric/pull/9480), [#9167](https://github.com/pyg-team/pytorch_geometric/pull/9167))
 - Added the `ClusterPooling` layer ([#9627](https://github.com/pyg-team/pytorch_geometric/pull/9627))
 - Added the `LinkPredMRR` metric ([#9632](https://github.com/pyg-team/pytorch_geometric/pull/9632))
 - Added PyTorch 2.4 support ([#9594](https://github.com/pyg-team/pytorch_geometric/pull/9594))
6 changes: 3 additions & 3 deletions examples/llm/README.md
@@ -1,5 +1,5 @@
 # Examples for Co-training LLMs and GNNs
 
-| Example | Description |
-| ------- | ----------- |
-| | |
+| Example | Description |
+| ------------------------------------ | ----------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| [`g_retriever.py`](./g_retriever.py) | Example for Retrieval-Augmented Generation (RAG) w/ GNN+LLM by co-training `LLAMA2` with `GAT` for answering questions based on knowledge graph information |
272 changes: 272 additions & 0 deletions examples/llm/g_retriever.py
@@ -0,0 +1,272 @@
"""This example implements the G-Retriever model
(https://arxiv.org/abs/2402.07630) using PyG.
G-Retriever significantly reduces hallucinations by 54% compared to the
stand-alone LLM baseline.
Requirements:
`pip install datasets transformers pcst_fast sentencepiece accelerate`
"""
import argparse
import math
import os.path as osp
import re
import time

import pandas as pd
import torch
from torch import Tensor
from torch.nn.utils import clip_grad_norm_
from tqdm import tqdm

from torch_geometric import seed_everything
from torch_geometric.datasets import WebQSPDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn.models import GAT, GRetriever
from torch_geometric.nn.nlp import LLM


def compute_metrics(eval_output):
    df = pd.concat([pd.DataFrame(d) for d in eval_output])
    all_hit = []
    all_precision = []
    all_recall = []
    all_f1 = []

    for pred, label in zip(df.pred.tolist(), df.label.tolist()):
        try:
            pred = pred.split('[/s]')[0].strip().split('|')
            hit = re.findall(pred[0], label)
            all_hit.append(len(hit) > 0)

            label = label.split('|')
            matches = set(pred).intersection(set(label))
            precision = len(matches) / len(set(label))
            recall = len(matches) / len(set(pred))
            if recall + precision == 0:
                f1 = 0
            else:
                f1 = 2 * precision * recall / (precision + recall)

            all_precision.append(precision)
            all_recall.append(recall)
            all_f1.append(f1)

        except Exception as e:
            print(f'Label: {label}')
            print(f'Pred: {pred}')
            print(f'Exception: {e}')
            print('------------------')

    hit = sum(all_hit) / len(all_hit)
    precision = sum(all_precision) / len(all_precision)
    recall = sum(all_recall) / len(all_recall)
    f1 = sum(all_f1) / len(all_f1)

    print(f'Hit: {hit:.4f}')
    print(f'Precision: {precision:.4f}')
    print(f'Recall: {recall:.4f}')
    print(f'F1: {f1:.4f}')


def save_params_dict(model, save_path):
    state_dict = model.state_dict()
    param_grad_dict = {
        k: v.requires_grad
        for (k, v) in model.named_parameters()
    }
    for k in list(state_dict.keys()):
        if k in param_grad_dict.keys() and not param_grad_dict[k]:
            del state_dict[k]  # Delete parameters that do not require gradient
    torch.save(state_dict, save_path)


def load_params_dict(model, save_path):
    state_dict = torch.load(save_path)
    model.load_state_dict(state_dict)
    return model


def get_loss(model, batch, model_save_name) -> Tensor:
    if model_save_name == 'llm':
        return model(batch.question, batch.label, batch.desc)
    else:
        return model(batch.question, batch.x, batch.edge_index, batch.batch,
                     batch.label, batch.edge_attr, batch.desc)


def inference_step(model, batch, model_save_name):
    if model_save_name == 'llm':
        return model.inference(batch.question, batch.desc)
    else:
        return model.inference(batch.question, batch.x, batch.edge_index,
                               batch.batch, batch.edge_attr, batch.desc)


def train(
    num_epochs,
    hidden_channels,
    num_gnn_layers,
    batch_size,
    eval_batch_size,
    lr,
    checkpointing=False,
    tiny_llama=False,
):
    def adjust_learning_rate(param_group, LR, epoch):
        # Decay the learning rate with half-cycle cosine after warmup
        min_lr = 5e-6
        warmup_epochs = 1
        if epoch < warmup_epochs:
            lr = LR
        else:
            lr = min_lr + (LR - min_lr) * 0.5 * (
                1.0 + math.cos(math.pi * (epoch - warmup_epochs) /
                               (num_epochs - warmup_epochs)))
        param_group['lr'] = lr
        return lr

    start_time = time.time()
    path = osp.dirname(osp.realpath(__file__))
    path = osp.join(path, '..', '..', 'data', 'WebQSPDataset')
    train_dataset = WebQSPDataset(path, split='train')
    val_dataset = WebQSPDataset(path, split='val')
    test_dataset = WebQSPDataset(path, split='test')

    seed_everything(42)

    train_loader = DataLoader(train_dataset, batch_size=batch_size,
                              drop_last=True, pin_memory=True, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=eval_batch_size,
                            drop_last=False, pin_memory=True, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=eval_batch_size,
                             drop_last=False, pin_memory=True, shuffle=False)

    gnn = GAT(
        in_channels=1024,
        hidden_channels=hidden_channels,
        out_channels=1024,
        num_layers=num_gnn_layers,
        heads=4,
    )
    if tiny_llama:
        llm = LLM(
            model_name='TinyLlama/TinyLlama-1.1B-Chat-v0.1',
            num_params=1,
        )
        model = GRetriever(llm=llm, gnn=gnn, mlp_out_channels=2048)
    else:
        llm = LLM(model_name='meta-llama/Llama-2-7b-chat-hf', num_params=7)
        model = GRetriever(llm=llm, gnn=gnn)

    model_save_name = 'gnn_llm' if num_gnn_layers is not None else 'llm'
    params = [p for _, p in model.named_parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW([
        {
            'params': params,
            'lr': lr,
            'weight_decay': 0.05
        },
    ], betas=(0.9, 0.95))
    grad_steps = 2

    best_epoch = 0
    best_val_loss = float('inf')
    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0
        if epoch == 0:
            print(f"Total Preparation Time: {time.time() - start_time:2f}s")
            start_time = time.time()
            print("Training beginning...")
        epoch_str = f'Epoch: {epoch + 1}|{num_epochs}'
        loader = tqdm(train_loader, desc=epoch_str)
        for step, batch in enumerate(loader):
            optimizer.zero_grad()
            loss = get_loss(model, batch, model_save_name)
            loss.backward()

            clip_grad_norm_(optimizer.param_groups[0]['params'], 0.1)

            if (step + 1) % grad_steps == 0:
                adjust_learning_rate(optimizer.param_groups[0], lr,
                                     step / len(train_loader) + epoch)

            optimizer.step()
            epoch_loss = epoch_loss + float(loss)

            if (step + 1) % grad_steps == 0:
                lr = optimizer.param_groups[0]['lr']
        train_loss = epoch_loss / len(train_loader)
        print(epoch_str + f', Train Loss: {train_loss:4f}')

        val_loss = 0
        eval_output = []
        model.eval()
        with torch.no_grad():
            for step, batch in enumerate(val_loader):
                loss = get_loss(model, batch, model_save_name)
                val_loss += loss.item()
            val_loss = val_loss / len(val_loader)
            print(epoch_str + f", Val Loss: {val_loss:4f}")
        if checkpointing and val_loss < best_val_loss:
            print("Checkpointing best model...")
            best_val_loss = val_loss
            best_epoch = epoch
            save_params_dict(model, f'{model_save_name}_best_val_loss_ckpt.pt')
        torch.cuda.empty_cache()
        torch.cuda.reset_max_memory_allocated()

    if checkpointing and best_epoch != num_epochs - 1:
        print("Loading best checkpoint...")
        model = load_params_dict(
            model,
            f'{model_save_name}_best_val_loss_ckpt.pt',
        )

    model.eval()
    eval_output = []
    print("Final evaluation...")
    progress_bar_test = tqdm(range(len(test_loader)))
    for step, batch in enumerate(test_loader):
        with torch.no_grad():
            pred = inference_step(model, batch, model_save_name)
            eval_data = {
                'pred': pred,
                'question': batch.question,
                'desc': batch.desc,
                'label': batch.label
            }
            eval_output.append(eval_data)
        progress_bar_test.update(1)

    compute_metrics(eval_output)
    print(f"Total Training Time: {time.time() - start_time:2f}s")
    save_params_dict(model, f'{model_save_name}.pt')
    torch.save(eval_output, f'{model_save_name}_eval_outs.pt')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--gnn_hidden_channels', type=int, default=1024)
    parser.add_argument('--num_gnn_layers', type=int, default=4)
    parser.add_argument('--lr', type=float, default=1e-5)
    parser.add_argument('--epochs', type=int, default=2)
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--eval_batch_size', type=int, default=16)
    parser.add_argument('--checkpointing', action='store_true')
    parser.add_argument('--tiny_llama', action='store_true')
    args = parser.parse_args()

    start_time = time.time()
    train(
        args.epochs,
        args.gnn_hidden_channels,
        args.num_gnn_layers,
        args.batch_size,
        args.eval_batch_size,
        args.lr,
        checkpointing=args.checkpointing,
        tiny_llama=args.tiny_llama,
    )
    print(f"Total Time: {time.time() - start_time:2f}s")
22 changes: 18 additions & 4 deletions torch_geometric/nn/models/g_retriever.py
@@ -3,7 +3,6 @@
 import torch
 from torch import Tensor
 
-from torch_geometric.nn.models import GAT
 from torch_geometric.nn.nlp.llm import BOS, LLM, MAX_NEW_TOKENS
 from torch_geometric.utils import scatter
 
@@ -43,7 +42,6 @@ def __init__(
         llm: LLM,
         gnn: torch.nn.Module,
         use_lora: bool = False,
-        gnn_to_use=GAT,
         mlp_out_channels: int = 4096,
     ) -> None:
         super().__init__()
@@ -126,7 +124,15 @@ def forward(
         """
         x = self.encode(x, edge_index, batch, edge_attr)
         x = self.projector(x)
-        xs = x.split(x.size(0), dim=0)
+        xs = x.split(1, dim=0)
 
+        # Handle questions without node features:
+        batch_unique = batch.unique()
+        batch_size = len(question)
+        if len(batch_unique) < batch_size:
+            xs = [
+                xs[i] if i in batch_unique else None for i in range(batch_size)
+            ]
+
         (
             inputs_embeds,
@@ -174,7 +180,15 @@ def inference(
         """
         x = self.encode(x, edge_index, batch, edge_attr)
         x = self.projector(x)
-        xs = x.split(x.size(0), dim=0)
+        xs = x.split(1, dim=0)
 
+        # Handle questions without node features:
+        batch_unique = batch.unique()
+        batch_size = len(question)
+        if len(batch_unique) < batch_size:
+            xs = [
+                xs[i] if i in batch_unique else None for i in range(batch_size)
+            ]
+
         inputs_embeds, attention_mask, _ = self.llm._get_embeds(
             question, additional_text_context, xs)
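To illustrate what the added block in `forward()` and `inference()` does, here is a small standalone sketch of just the padding logic. The tensors are made up for the illustration; in the model, `x` is the projector output with `mlp_out_channels` columns (4096 by default):

```python
import torch

# Four questions in the batch, but the retrieved subgraph for question 2 is
# empty, so graph index 2 never appears in `batch`:
question = ['q0', 'q1', 'q2', 'q3']
batch = torch.tensor([0, 0, 1, 3, 3, 3])

# Stand-in for the projected graph embeddings (one row per graph index):
x = torch.randn(4, 4096)
xs = x.split(1, dim=0)

# Same logic as the added lines above: questions whose graph index is missing
# from `batch` get `None` instead of a graph embedding.
batch_unique = batch.unique()
batch_size = len(question)
if len(batch_unique) < batch_size:
    xs = [
        xs[i] if i in batch_unique else None for i in range(batch_size)
    ]

print([None if e is None else tuple(e.shape) for e in xs])
# [(1, 4096), (1, 4096), None, (1, 4096)]
```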
2 changes: 2 additions & 0 deletions torch_geometric/nn/nlp/llm.py
@@ -1,3 +1,4 @@
+import warnings
 from contextlib import nullcontext
 from typing import Any, Dict, List, Optional
 
@@ -85,6 +86,7 @@ def __init__(
         self.word_embedding = self.llm.model.get_input_embeddings()
 
         if 'max_memory' not in kwargs:  # Pure CPU:
+            warnings.warn("LLM is being used on CPU, which may be slow")
             self.device = torch.device('cpu')
             self.autocast_context = nullcontext()
         else:
