
Commit ccefd86: ruff & mypy
bengioe committed Oct 8, 2024
1 parent 30bd2e3

Showing 9 changed files with 33 additions and 45 deletions.

src/gflownet/algo/graph_sampling.py (34 changes: 19 additions & 15 deletions)

@@ -136,16 +136,17 @@ def sample_from_model(self, model: nn.Module, n: int, cond_info: Optional[Tensor
             # If we're not bootstrapping, we could query the reward
             # model here, but this is expensive/impractical. Instead
             # just report forward and backward logprobs
+            # TODO: stop using dicts and use typed objects
             data[i]["fwd_logprobs"] = torch.stack(data[i]["fwd_logprobs"]).reshape(-1)
             data[i]["U_bck_logprobs"] = torch.stack(data[i]["U_bck_logprobs"]).reshape(-1)
-            data[i]["fwd_logprob"] = data[i]["fwd_logprobs"].sum()
-            data[i]["U_bck_logprob"] = data[i]["U_bck_logprobs"].sum()
+            data[i]["fwd_logprob"] = data[i]["fwd_logprobs"].sum()  # type: ignore
+            data[i]["U_bck_logprob"] = data[i]["U_bck_logprobs"].sum()  # type: ignore
             data[i]["result"] = graphs[i]
             if self.pad_with_terminal_state:
-                data[i]["traj"].append((graphs[i], GraphAction(GraphActionType.Pad)))
+                data[i]["traj"].append((graphs[i], GraphAction(GraphActionType.Pad)))  # type: ignore
                 data[i]["U_bck_logprobs"] = torch.cat([data[i]["U_bck_logprobs"], torch.tensor([0.0], device=dev)])
-                data[i]["is_sink"].append(1)
-                assert len(data[i]["U_bck_logprobs"]) == len(data[i]["bck_a"])
+                data[i]["is_sink"].append(1)  # type: ignore
+                assert len(data[i]["U_bck_logprobs"]) == len(data[i]["bck_a"])  # type: ignore
         return data

@@ -198,16 +199,19 @@ def sample_backward_from_graphs(

         for i in range(n):
             # See comments in sample_from_model
-            data[i]["traj"] = data[i]["traj"][::-1]
+            # TODO: stop using dicts and use typed objects
+            data[i]["traj"] = data[i]["traj"][::-1]  # type: ignore
             # I think this pad is only necessary if we're padding terminal states???
-            data[i]["bck_a"] = [GraphAction(GraphActionType.Pad)] + data[i]["bck_a"][::-1]
-            data[i]["is_sink"] = data[i]["is_sink"][::-1]
-            data[i]["U_bck_logprobs"] = torch.tensor([0] + data[i]["U_bck_logprobs"][::-1], device=dev).reshape(-1)
+            data[i]["bck_a"] = [GraphAction(GraphActionType.Pad)] + data[i]["bck_a"][::-1]  # type: ignore
+            data[i]["is_sink"] = data[i]["is_sink"][::-1]  # type: ignore
+            data[i]["U_bck_logprobs"] = torch.tensor(
+                [0] + data[i]["U_bck_logprobs"][::-1], device=dev  # type: ignore
+            ).reshape(-1)
             if self.pad_with_terminal_state:
-                data[i]["traj"].append((starting_graphs[i], GraphAction(GraphActionType.Pad)))
+                data[i]["traj"].append((starting_graphs[i], GraphAction(GraphActionType.Pad)))  # type: ignore
                 data[i]["U_bck_logprobs"] = torch.cat([data[i]["U_bck_logprobs"], torch.tensor([0.0], device=dev)])
-                data[i]["is_sink"].append(1)
-                assert len(data[i]["U_bck_logprobs"]) == len(data[i]["bck_a"])
+                data[i]["is_sink"].append(1)  # type: ignore
+                assert len(data[i]["U_bck_logprobs"]) == len(data[i]["bck_a"])  # type: ignore
         return data

@@ -249,7 +253,7 @@ def local_search_sample_from_model(
         ]  # type: ignore
         graphs = [i["traj"][-1][0] for i in current_trajs]
         done = [False] * n
-        fwd_a = []
+        fwd_a: List[GraphAction] = []
         for i in range(cfg.num_bck_steps):
             # This modifies `bck_trajs` & `graphs` in place, passing fwd_a computes P_F(s|s') for the previous step
             self._backward_step(model, bck_trajs, graphs, cond_info, done, dev, fwd_a)
@@ -264,7 +268,7 @@
             {"traj": [], "bck_a": [], "is_sink": [], "bck_logprobs": [], "fwd_logprobs": []} for _ in current_trajs
         ]  # type: ignore
         done = [False] * n
-        bck_a = []
+        bck_a: List[GraphAction] = []
         while not all(done):
             self._forward_step(model, fwd_trajs, graphs, cond_info, 0, done, rng, dev, random_action_prob, bck_a)
             done = [d or (len(t["traj"]) + T) >= self.max_len for d, t, T in zip(done, fwd_trajs, trunc_lens)]
@@ -281,7 +285,7 @@
         sampled_terminals.extend(terminals)
         for traj, term in zip(fwd_trajs, terminals):
             traj["result"] = term
-            traj["is_accept"] = False
+            traj["is_accept"] = False  # type: ignore
         # Compute rewards for the acceptance
         if compute_reward is not None:
             compute_reward(fwd_trajs, cond_info)
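
The TODO repeated in both hunks ("stop using dicts and use typed objects") is what forces the `# type: ignore` comments: the per-trajectory dicts mix tensors, lists, and graphs, so mypy cannot give them a single value type. A minimal sketch of the kind of typed container the TODO points at (hypothetical names, not part of this commit):

    from dataclasses import dataclass, field
    from typing import Any, List, Optional, Tuple

    import torch


    @dataclass
    class SampledTrajectory:
        # (graph, action) pairs; Any stands in for the repo's Graph/GraphAction types
        traj: List[Tuple[Any, Any]] = field(default_factory=list)
        bck_a: List[Any] = field(default_factory=list)
        is_sink: List[int] = field(default_factory=list)
        fwd_logprobs: List[torch.Tensor] = field(default_factory=list)
        U_bck_logprobs: Optional[torch.Tensor] = None  # stacked after sampling
        fwd_logprob: Optional[torch.Tensor] = None
        result: Any = None

With such a class, each field has one static type and the scattered ignores become unnecessary.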

src/gflownet/algo/local_search_tb.py (3 changes: 0 additions & 3 deletions)

@@ -1,9 +1,6 @@
 import torch

-from gflownet import GFNTask
 from gflownet.algo.trajectory_balance import TrajectoryBalance
-from gflownet.data.data_source import DataSource
-from gflownet.utils.misc import get_worker_device


 class LocalSearchTB(TrajectoryBalance):

src/gflownet/algo/trajectory_balance.py (4 changes: 0 additions & 4 deletions)

@@ -620,10 +620,6 @@ def compute_batch_losses(
         if self.cfg.mle_loss_multiplier != 0:
             info["mle_loss"] = mle_loss.item()

-        if not torch.isfinite(loss):
-            import pdb
-
-            pdb.set_trace()
         return loss, info

     def analytical_maxent_backward(self, batch, first_graph_idx):
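
The deleted block dropped into pdb whenever the loss went non-finite, which hangs headless training runs. A fail-fast alternative, shown only as a standalone sketch (not what the commit adds):

    import torch

    def check_finite(loss: torch.Tensor) -> torch.Tensor:
        # Raising keeps the stack trace without blocking on an interactive debugger
        if not torch.isfinite(loss):
            raise FloatingPointError(f"non-finite loss: {loss}")
        return loss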

src/gflownet/data/data_source.py (11 changes: 1 addition & 10 deletions)

@@ -1,3 +1,4 @@
+import traceback
 import warnings
 from typing import Callable, Generator, List, Optional

@@ -92,19 +93,9 @@ def __iter__(self):
                     raise e
                 print(f"Error in DataSource: {e} [tol={self._err_tol}]")
-                # print full traceback
-                import sys
-                import traceback
-
-                traceback.print_exc()
-                continue
-            except:
-                print("Unknown error in DataSource")
-                import sys
-                import traceback

                 traceback.print_exc()
                 self._err_tol -= 1
                 continue

     def validate_batch(self, batch, trajs):
         for actions, atypes in [(batch.actions, self.ctx.action_type_order)] + (
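
With `traceback` imported once at module level and the bare `except:` handler gone, the error-tolerance pattern in `__iter__` reduces to a single narrowed handler. A self-contained illustration of that pattern (toy functions, not the repo's code):

    import traceback

    def flaky(x):  # stand-in for batch construction that may fail
        if x == 3:
            raise ValueError("bad item")
        return 2 * x

    def tolerant_iter(items, err_tol=2):
        for x in items:
            try:
                yield flaky(x)
            except Exception as e:  # narrowed, unlike the removed bare `except:`
                err_tol -= 1
                if err_tol <= 0:
                    raise
                print(f"Error: {e} [tol={err_tol}]")
                traceback.print_exc()

    print(list(tolerant_iter([1, 2, 3, 4])))  # -> [2, 4, 8], tolerating one error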

src/gflownet/data/replay_buffer.py (10 changes: 5 additions & 5 deletions)

@@ -1,6 +1,6 @@
 import heapq
 from threading import Lock
-from typing import List
+from typing import Any, List

 import numpy as np
 import torch
@@ -15,16 +15,16 @@ def __init__(self, cfg: Config):
         Replay buffer for storing and sampling arbitrary data (e.g. transitions or trajectories)
         In self.push(), the buffer detaches any torch tensor and sends it to the CPU.
         """
-        self.capacity = cfg.replay.capacity
-        self.warmup = cfg.replay.warmup
+        self.capacity = cfg.replay.capacity or int(1e6)
+        self.warmup = cfg.replay.warmup or 0
         assert self.warmup <= self.capacity, "ReplayBuffer warmup must be smaller than capacity"

         self.buffer: List[tuple] = []
         self.position = 0

         self.treat_as_heap = cfg.replay.keep_highest_rewards
         self.filter_uniques = cfg.replay.keep_only_uniques
-        self._uniques = set()
+        self._uniques: set[Any] = set()

         self._lock = Lock()

@@ -56,7 +56,7 @@ def push(self, *args, unique_obj=None, priority=None):
                 self._uniques.add(unique_obj)
         else:
             if len(self.buffer) < self.capacity:
-                self.buffer.append(None)
+                self.buffer.append(())
             if self.filter_uniques:
                 if self.position == 0 and len(self.buffer) == self.capacity:
                     # We're about to wrap around, so remove the oldest element
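
The `or` fallbacks above cover the case where `cfg.replay.capacity` or `cfg.replay.warmup` is None. Note that `x or default` also replaces 0, since 0 is falsy; harmless here, but worth remembering for fields where 0 is meaningful. A standalone illustration of the pattern (not repo code):

    from typing import Optional

    def resolve(capacity: Optional[int], warmup: Optional[int]) -> tuple:
        cap = capacity or int(1e6)  # None -> 1e6, but 0 -> 1e6 too (0 is falsy)
        wu = warmup or 0
        assert wu <= cap, "warmup must be smaller than capacity"
        return cap, wu

    print(resolve(None, None))  # (1000000, 0)
    print(resolve(5000, 100))   # (5000, 100)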

src/gflownet/models/graph_transformer.py (2 changes: 1 addition & 1 deletion)

@@ -310,7 +310,7 @@ def _make_cat(self, g: gd.Batch, emb: Dict[str, Tensor], action_types: list[Grap
         sc = self.logit_scaler(
             g.cond_info if g.cond_info is not None else torch.ones((g.num_graphs, 1), device=g.x.device)
         )
-        cat.logits = [l * sc[b] for l, b in zip(cat.raw_logits, cat.batch)]  # Setting .logits masks them
+        cat.logits = [lg * sc[b] for lg, b in zip(cat.raw_logits, cat.batch)]  # Setting .logits masks them
         return cat

     def forward(self, g: gd.Batch):

src/gflownet/online_trainer.py (4 changes: 2 additions & 2 deletions)

@@ -1,6 +1,6 @@
 import copy
 import os
-import pathlib
+from typing import Any

 import git
 import torch

@@ -137,7 +137,7 @@ def setup(self):

     def step(self, loss: Tensor, train_it: int):
         loss.backward()
-        info = {}
+        info: dict[str, Any] = {}
         if train_it % self.cfg.algo.grad_acc_steps != 0:
             return info
         if self.cfg.opt.clip_grad_type is not None:

src/gflownet/trainer.py (4 changes: 2 additions & 2 deletions)

@@ -272,11 +272,11 @@ def _send_models_to_device(self):
         self.model.to(self.device)
         self.sampling_model.to(self.device)
         if self.world_size > 1:
-            self.model = DistributedDataParallel(
+            self.model = nn.parallel.DistributedDataParallel(
                 self.model.to(self.rank), device_ids=[self.rank], output_device=self.rank
             )
             if self.sampling_model is not self.model:
-                self.sampling_model = DistributedDataParallel(
+                self.sampling_model = nn.parallel.DistributedDataParallel(
                     self.sampling_model.to(self.rank), device_ids=[self.rank], output_device=self.rank
                 )

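Referencing the wrapper as `nn.parallel.DistributedDataParallel` avoids a separate top-level import. For context, a minimal sketch of how PyTorch's DDP wrapper is typically applied per rank (standalone, and assumes the process group is already initialized):

    import torch.nn as nn

    def wrap_for_rank(model: nn.Module, rank: int) -> nn.Module:
        # Assumes torch.distributed.init_process_group(...) ran in this process
        return nn.parallel.DistributedDataParallel(
            model.to(rank), device_ids=[rank], output_device=rank
        )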

src/gflownet/utils/multiprocessing_proxy.py (6 changes: 3 additions & 3 deletions)

@@ -268,8 +268,8 @@ def run(self):
                 break
             timeouts = 0
             attr, args, kwargs = r
-            if hasattr(self.obj, "lock"):
-                f.lock.acquire()
+            if hasattr(self.obj, "lock"):  # TODO: this is not used anywhere?
+                self.obj.lock.acquire()
             f = getattr(self.obj, attr)
             args = [i.to(self.device) if isinstance(i, self.cuda_types) else i for i in args]
             kwargs = {k: i.to(self.device) if isinstance(i, self.cuda_types) else i for k, i in kwargs.items()}

@@ -293,7 +293,7 @@ def run(self):
                 msg = self.to_cpu(result)
             self.out_queues[qi].put(self.encode(msg))
             if hasattr(self.obj, "lock"):
-                f.lock.release()
+                self.obj.lock.release()

     def terminate(self):
         self.stop.set()
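
The fix acquires and releases the lock on the wrapped object (`self.obj`) rather than on `f`, which is only bound on the following line. A common way to make such acquire/release pairs exception-safe is a `with` block; a hypothetical standalone sketch of the idea, not part of this commit:

    from contextlib import nullcontext

    def call_locked(obj, attr, *args, **kwargs):
        # Locks support the context-manager protocol; nullcontext() is the no-op fallback,
        # and the `with` block releases the lock even if the call raises
        with getattr(obj, "lock", None) or nullcontext():
            return getattr(obj, attr)(*args, **kwargs)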
