Skip to content

Commit

Permalink
QoL improvements and more docs (#126)
Browse files Browse the repository at this point in the history
This PR is (unfortunately) fairly massive, but is mostly renaming things, or removing things that aren't needed anymore.
Considering all the renaming and breaking, bumping the version to 0.2.0 🎉  

- Updated `README.md`
  - Added `docs/contributing.md` and `docs/getting_started.md`
  - Added (a link to) an introductory notebook, runnable on colab
- Renamed `FlatRewards` into `ObjectProperties`, as well as related function names
- Renamed `RewardScalar` into either `LogScalar` or `LinScalar`, as well as changed the function signature of functions accordingly
- Removed all passing around of `rng` instances, replaced by `get_worker_rng` to remove any confusion around worker seeding
   - Similarly, `dev` should now be acquired via `get_worker_device`
- `None`-proofed `cond_info` in models. There should now be a sensible default behavior. Eventually we should properly make sure the code support unconditional tasks
- Made `cfg.git_hash` computational conditional on being in a git repo
- Renamed generic functions and methods working with objects to `obj[s]` rather than `mol[s]`
- Made `terminate` close log files (to avoid issues when a single process runs multiple trials)
- Added a `read_all_results` function in `sqlite_log.py`. This is probably the fastest way to load log data
- Fixed a number of `mypy` issues.
  • Loading branch information
bengioe authored May 8, 2024
1 parent 02eaed3 commit f106cde
Show file tree
Hide file tree
Showing 36 changed files with 500 additions and 433 deletions.
54 changes: 30 additions & 24 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,34 +11,14 @@ GFlowNet-related training and environment code on graphs.

**Primer**

[GFlowNet](https://yoshuabengio.org/2022/03/05/generative-flow-networks/), short for Generative Flow Network, is a novel generative modeling framework, particularly suited for discrete, combinatorial objects. Here in particular it is implemented for graph generation.
GFlowNet [[1]](https://yoshuabengio.org/2022/03/05/generative-flow-networks/), [[2]](https://www.gflownet.org/), [[3]](https://github.com/zdhNarsil/Awesome-GFlowNets), short for Generative Flow Network, is a novel generative modeling framework, particularly suited for discrete, combinatorial objects. Here in particular it is implemented for graph generation.

The idea behind GFN is to estimate flows in a (graph-theoretic) directed acyclic network*. The network represents all possible ways of constructing an object, and so knowing the flow gives us a policy which we can follow to sequentially construct objects. Such a sequence of partially constructed objects is a _trajectory_. *Perhaps confusingly, the _network_ in GFN refers to the state space, not a neural network architecture.
The idea behind GFN is to estimate flows in a (graph-theoretic) directed acyclic network*. The network represents all possible ways of constructing objects, and so knowing the flow gives us a policy which we can follow to sequentially construct objects. Such a sequence of partially constructed objects is a _trajectory_. *Perhaps confusingly, the _network_ in GFN refers to the state space, not a neural network architecture.

Here the objects we construct are themselves graphs (e.g. graphs of atoms), which are constructed node by node. To make policy predictions, we use a graph neural network. This GNN outputs per-node logits (e.g. add an atom to this atom, or add a bond between these two atoms), as well as per-graph logits (e.g. stop/"done constructing this object").
The main focus of this library (although it can do other things) is to construct graphs (e.g. graphs of atoms), which are constructed node by node. To make policy predictions, we use a graph neural network. This GNN outputs per-node logits (e.g. add an atom to this atom, or add a bond between these two atoms), as well as per-graph logits (e.g. stop/"done constructing this object").

The GNN model can be trained on a mix of existing data (offline) and self-generated data (online), the latter being obtained by querying the model sequentially to obtain trajectories. For offline data, we can easily generate trajectories since we know the end state.
This library supports a variety of GFN algorithms (as well as some baselines), and supports training on a mix of existing data (offline) and self-generated data (online), the latter being obtained by querying the model sequentially to obtain trajectories.

## Repo overview

- [algo](src/gflownet/algo), contains GFlowNet algorithms implementations ([Trajectory Balance](https://arxiv.org/abs/2201.13259), [SubTB](https://arxiv.org/abs/2209.12782), [Flow Matching](https://arxiv.org/abs/2106.04399)), as well as some baselines. These implement how to sample trajectories from a model and compute the loss from trajectories.
- [data](src/gflownet/data), contains dataset definitions, data loading and data sampling utilities.
- [envs](src/gflownet/envs), contains environment classes; a graph-building environment base, and a molecular graph context class. The base environment is agnostic to what kind of graph is being made, and the context class specifies mappings from graphs to objects (e.g. molecules) and torch geometric Data.
- [examples](docs/examples), contains simple example implementations of GFlowNet.
- [models](src/gflownet/models), contains model definitions.
- [tasks](src/gflownet/tasks), contains training code.
- [qm9](src/gflownet/tasks/qm9/qm9.py), temperature-conditional molecule sampler based on QM9's HOMO-LUMO gap data as a reward.
- [seh_frag](src/gflownet/tasks/seh_frag.py), reproducing Bengio et al. 2021, fragment-based molecule design targeting the sEH protein
- [seh_frag_moo](src/gflownet/tasks/seh_frag_moo.py), same as the above, but with multi-objective optimization (incl. QED, SA, and molecule weight objectives).
- [utils](src/gflownet/utils), contains utilities (multiprocessing, metrics, conditioning).
- [`trainer.py`](src/gflownet/trainer.py), defines a general harness for training GFlowNet models.
- [`online_trainer.py`](src/gflownet/online_trainer.py), defines a typical online-GFN training loop.

See [implementation notes](docs/implementation_notes.md) for more.

## Getting started

A good place to get started is with the [sEH fragment-based MOO task](src/gflownet/tasks/seh_frag_moo.py). The file `seh_frag_moo.py` is runnable as-is (although you may want to change the default configuration in `main()`).

## Installation

Expand All @@ -62,6 +42,30 @@ pip install git+https://github.com/recursionpharma/[email protected] --find-l

If package dependencies seem not to work, you may need to install the exact frozen versions listed `requirements/`, i.e. `pip install -r requirements/main-3.10.txt`.

## Getting started

A good place to get started immediately is with the [sEH fragment-based MOO task](src/gflownet/tasks/seh_frag_moo.py). The file `seh_frag_moo.py` is runnable as-is (although you may want to change the default configuration in `main()`).

For a gentler introduction to the library, see [Getting Started](docs/getting_started.md). For a more in-depth look at the library, see [Implementation Notes](docs/implementation_notes.md).

## Repo overview

- [algo](src/gflownet/algo), contains GFlowNet algorithms implementations ([Trajectory Balance](https://arxiv.org/abs/2201.13259), [SubTB](https://arxiv.org/abs/2209.12782), [Flow Matching](https://arxiv.org/abs/2106.04399)), as well as some baselines. These implement how to sample trajectories from a model and compute the loss from trajectories.
- [data](src/gflownet/data), contains dataset definitions, data loading and data sampling utilities.
- [envs](src/gflownet/envs), contains environment classes; the base environment is agnostic to what kind of graph is being made, and context classes specify mappings from graphs to objects (e.g. molecules) and torch geometric Data.
- [examples](docs/examples), contains simple example implementations of GFlowNet.
- [models](src/gflownet/models), contains model definitions.
- [tasks](src/gflownet/tasks), contains training code.
- [qm9](src/gflownet/tasks/qm9/qm9.py), temperature-conditional molecule sampler based on QM9's HOMO-LUMO gap data as a reward.
- [seh_frag](src/gflownet/tasks/seh_frag.py), reproducing Bengio et al. 2021, fragment-based molecule design targeting the sEH protein
- [seh_frag_moo](src/gflownet/tasks/seh_frag_moo.py), same as the above, but with multi-objective optimization (incl. QED, SA, and molecule weight objectives).
- [utils](src/gflownet/utils), contains utilities (multiprocessing, metrics, conditioning).
- [`trainer.py`](src/gflownet/trainer.py), defines a general harness for training GFlowNet models.
- [`online_trainer.py`](src/gflownet/online_trainer.py), defines a typical online-GFN training loop.

See [implementation notes](docs/implementation_notes.md) for more.


## Developing & Contributing

External contributions are welcome.
Expand All @@ -73,3 +77,5 @@ pip install -e '.[dev]' --find-links https://data.pyg.org/whl/torch-2.1.2+cu121.

We use `tox` to run tests and linting, and `pre-commit` to run checks before committing.
To ensure that these checks pass, simply run `tox -e style` and `tox run` to run linters and tests, respectively.

For more information, see [Contributing](docs/contributing.md).
38 changes: 38 additions & 0 deletions docs/contributing.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Contributing

Contributions to the repository are welcome, and we encourage you to open issues and pull requests. In general, it is recommended to fork this repository and open a pull request from your fork to the `trunk` branch. PRs are encouraged to be short and focused, and to include tests and documentation where appropriate.

## Installation

To install the developers dependencies run:
```
pip install -e '.[dev]' --find-links https://data.pyg.org/whl/torch-2.1.2+cu121.html
```

## Dependencies

Dependencies are defined in `pyproject.toml`, and frozen versions that are known to work are provided in `requirements/`.

To regenerate the frozen versions, run `./generate_requirements.sh <ENV-NAME>`. See comments within.

## Linting and testing

We use `tox` to run tests and linting, and `pre-commit` to run checks before committing.
To ensure that these checks pass, simply run `tox -e style` and `tox run` to run linters and tests, respectively.

`tox` itself runs many linters, but the most important ones are `black`, `ruff`, `isort`, and `mypy`. The full list
of linting tools is found in `.pre-commit-config.yaml`, while `tox.ini` defines the environments under which these
linters (as well as tests) are run.

## Github Actions

We use Github Actions to run tests and linting on every push and pull request. The configuration for these actions is found in `.github/workflows/`.

The cascade of events is as follows:
- For `build-and-test`, `tox -> testenv:py310 -> pytest` is run.
- For `code-quality`, `tox -e style -> testenv:style -> pre-commit -> {isort, black, mypy, bandit, ruff, & others}`. This and the "others" are defined in `.pre-commit-config.yaml` and include things like checking for secrets and trailing whitespace.

## Style Guide

On top of `black`-as-a-style-guide, we generally adhere to the [Google Python Style Guide](https://google.github.io/styleguide/pyguide.html).
Our docstrings follow the [numpydoc](https://numpydoc.readthedocs.io/en/latest/format.html) format, and we use type hints throughout the codebase.
11 changes: 11 additions & 0 deletions docs/getting_started.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Getting Started

For an introduction to the library, see [this colab notebook](https://colab.research.google.com/drive/1wANyo6Y-ceYEto9-p50riCsGRb_6U6eH).

For an introduction to using `wandb` to log experiments, see [this demo](../src/gflownet/hyperopt/wandb_demo).

For more general introductions to GFlowNets, check out the following:
- The 2023 [GFlowNet workshop](https://gflownet.org/) has several introductory talks and colab tutorials.
- This high-level [GFlowNet colab tutorial](https://colab.research.google.com/drive/1fUMwgu2OhYpQagpzU5mhe9_Esib3Q2VR) (updated versions of which were written for the 2023 workshop, in particular for continuous GFNs).

A good place to get started immediately is with the [sEH fragment-based MOO task](src/gflownet/tasks/seh_frag_moo.py). The file `seh_frag_moo.py` is runnable as-is (although you may want to change the default configuration in `main()`).
52 changes: 35 additions & 17 deletions src/gflownet/__init__.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
from typing import Dict, List, NewType, Optional, Tuple
from typing import Any, Dict, List, NewType, Optional, Tuple

import torch_geometric.data as gd
from rdkit.Chem.rdchem import Mol as RDMol
from torch import Tensor, nn

from .config import Config

# This type represents an unprocessed list of reward signals/conditioning information
FlatRewards = NewType("FlatRewards", Tensor) # type: ignore
# This type represents a set of scalar properties attached to each object in a batch.
ObjectProperties = NewType("ObjectProperties", Tensor) # type: ignore

# This type represents the outcome for a multi-objective task of
# converting FlatRewards to a scalar, e.g. (sum R_i omega_i) ** beta
RewardScalar = NewType("RewardScalar", Tensor) # type: ignore
# This type represents log-scalars, in particular log-rewards at the scale we operate with with GFlowNets
# for example, converting a reward ObjectProperties to a log-scalar with log [(sum R_i omega_i) ** beta]
LogScalar = NewType("LogScalar", Tensor) # type: ignore
# This type represents linear-scalars
LinScalar = NewType("LinScalar", Tensor) # type: ignore


class GFNAlgorithm:
Expand Down Expand Up @@ -75,15 +76,15 @@ def get_random_action_prob(self, it: int):


class GFNTask:
def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar:
def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], obj_props: ObjectProperties) -> LogScalar:
"""Combines a minibatch of reward signal vectors and conditional information into a scalar reward.
Parameters
----------
cond_info: Dict[str, Tensor]
A dictionary with various conditional informations (e.g. temperature)
flat_reward: FlatRewards
A 2d tensor where each row represents a series of flat rewards.
obj_props: ObjectProperties
A 2d tensor where each row represents a series of object properties.
Returns
-------
Expand All @@ -92,18 +93,35 @@ def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: Flat
"""
raise NotImplementedError()

def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]:
"""Compute the flat rewards of mols according the the tasks' proxies
def compute_obj_properties(self, objs: List[Any]) -> Tuple[ObjectProperties, Tensor]:
"""Compute the flat rewards of objs according the the tasks' proxies
Parameters
----------
mols: List[RDMol]
A list of RDKit molecules.
objs: List[Any]
A list of n objects.
Returns
-------
reward: FlatRewards
A 2d tensor, a vector of scalar reward for valid each molecule.
obj_probs: ObjectProperties
A 2d tensor (m, p), a vector of scalar properties for the m <= n valid objects.
is_valid: Tensor
A 1d tensor, a boolean indicating whether the molecule is valid.
A 1d tensor (n,), a boolean indicating whether each object is valid.
"""
raise NotImplementedError()

def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Tensor]:
"""Sample conditional information for n objects
Parameters
----------
n: int
The number of objects to sample conditional information for.
train_it: int
The training iteration number.
Returns
-------
cond_info: Dict[str, Tensor]
A dictionary with various conditional informations (e.g. temperature)
"""
raise NotImplementedError()
13 changes: 4 additions & 9 deletions src/gflownet/algo/advantage_actor_critic.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import numpy as np
import torch
import torch.nn as nn
import torch_geometric.data as gd
Expand All @@ -16,7 +15,6 @@ def __init__(
self,
env: GraphBuildingEnv,
ctx: GraphBuildingEnvContext,
rng: np.random.RandomState,
cfg: Config,
):
"""Advantage Actor-Critic implementation, see
Expand All @@ -34,15 +32,12 @@ def __init__(
A graph environment.
ctx: GraphBuildingEnvContext
A context.
rng: np.random.RandomState
rng used to take random actions
cfg: Config
The experiment configuration
"""
self.ctx = ctx
self.env = env
self.rng = rng
self.max_len = cfg.algo.max_len
self.max_nodes = cfg.algo.max_nodes
self.illegal_action_logreward = cfg.algo.illegal_action_logreward
Expand All @@ -54,7 +49,7 @@ def __init__(
# Experimental flags
self.sample_temp = 1
self.do_q_prime_correction = False
self.graph_sampler = GraphSampler(ctx, env, self.max_len, self.max_nodes, rng, self.sample_temp)
self.graph_sampler = GraphSampler(ctx, env, self.max_len, self.max_nodes, self.sample_temp)

def create_training_data_from_own_samples(
self, model: nn.Module, n: int, cond_info: Tensor, random_action_prob: float
Expand Down Expand Up @@ -82,7 +77,7 @@ def create_training_data_from_own_samples(
"""
dev = get_worker_device()
cond_info = cond_info.to(dev)
data = self.graph_sampler.sample_from_model(model, n, cond_info, dev, random_action_prob)
data = self.graph_sampler.sample_from_model(model, n, cond_info, random_action_prob)
return data

def create_training_data_from_graphs(self, graphs):
Expand Down Expand Up @@ -152,12 +147,12 @@ def compute_batch_losses(self, model: nn.Module, batch: gd.Batch, num_bootstrap:
# of length 4, trajectory 1 of length 3, and so on.
batch_idx = torch.arange(num_trajs, device=dev).repeat_interleave(batch.traj_lens)

# Forward pass of the model, returns a GraphActionCategorical and per molecule predictions
# Forward pass of the model, returns a GraphActionCategorical and per graph predictions
# Here we will interpret the logits of the fwd_cat as Q values
policy, per_state_preds = model(batch, cond_info[batch_idx])
V = per_state_preds[:, 0]
G = rewards[batch_idx] # The return is the terminal reward everywhere, we're using gamma==1
G = G + (1 - batch.is_valid[batch_idx]) * self.invalid_penalty # Add in penalty for invalid mol
G = G + (1 - batch.is_valid[batch_idx]) * self.invalid_penalty # Add in penalty for invalid object
A = G - V
log_probs = policy.log_prob(batch.actions)

Expand Down
Loading

0 comments on commit f106cde

Please sign in to comment.