Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -188,3 +188,4 @@ scripts.py

models/
*.DS_Store
.python-version
123 changes: 123 additions & 0 deletions src/gfn/gym/perfect_tree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
from typing import Callable

import torch

from gfn.env import Actions, DiscreteEnv, DiscreteStates
from gfn.states import States


class PerfectBinaryTree(DiscreteEnv):
    r"""
    Perfect Tree Environment where there is a bijection between trajectories and terminating states.

    Nodes are represented by integers, starting from 0 for the root.
    States are represented by a single long tensor holding the node index.
    Actions are integers: 0 (left child), 1 (right child), 2 (exit).

    e.g. (depth=3):

                   0 (root)
                 /          \
                1            2
              /   \        /   \
             3     4      5     6
            / \   / \    / \   / \
           7   8 9  10  11 12 13 14   (terminating states)

    Recommended preprocessor: `OneHotPreprocessor`.
    """

    def __init__(self, reward_fn: Callable, depth: int = 4):
        """
        Args:
            reward_fn: Callable mapping a tensor of final-state node indices
                to rewards.
            depth: Depth of the tree; leaves (the terminating states) sit at
                this depth.
        """
        self.reward_fn = reward_fn
        self.depth = depth
        self.branching_factor = 2
        # Two child moves plus the exit action.
        self.n_actions = self.branching_factor + 1
        # A perfect binary tree of depth d has 2^(d+1) - 1 nodes.
        self.n_nodes = 2 ** (self.depth + 1) - 1

        self.s0 = torch.zeros((1,), dtype=torch.long)
        self.sf = torch.full((1,), fill_value=-1, dtype=torch.long)
        super().__init__(self.n_actions, self.s0, (1,), sf=self.sf)

        (
            self.transition_table,
            self.inverse_transition_table,
            self.term_states,
        ) = self._build_tree()

    def _build_tree(self) -> tuple[dict, dict, DiscreteStates]:
        """Create a transition table ensuring a bijection between trajectories and last states.

        Returns:
            transition_table: maps (node, action) -> child node.
            inverse_transition_table: maps (child node, action) -> parent node.
            terminating_states: the leaf nodes, wrapped as DiscreteStates.
        """
        transition_table = {}
        inverse_transition_table = {}
        node_index = 0
        queue = [(node_index, 0)]  # (current_node, depth), BFS order

        terminating_states_id = set()
        while queue:
            node, d = queue.pop(0)
            if d < self.depth:
                for a in range(self.branching_factor):
                    node_index += 1
                    transition_table[(node, a)] = node_index
                    inverse_transition_table[(node_index, a)] = node
                    queue.append((node_index, d + 1))
            else:
                terminating_states_id.add(node)
        # Sort explicitly: relying on set iteration order is not guaranteed.
        terminating_states_id = torch.tensor(sorted(terminating_states_id)).reshape(-1, 1)
        terminating_states = self.states_from_tensor(terminating_states_id)

        return transition_table, inverse_transition_table, terminating_states

    def backward_step(self, states: DiscreteStates, actions: Actions) -> torch.Tensor:
        """Map each (child node, action) pair to its parent node via the inverse table."""
        pairs = torch.hstack((states.tensor, actions.tensor)).tolist()
        parents = [self.inverse_transition_table.get(tuple(pair)) for pair in pairs]
        # A single conversion suffices; re-wrapping an existing tensor with
        # torch.tensor(...) would be redundant and raises a copy warning.
        return torch.tensor(parents).reshape(-1, 1).long()

    def step(self, states: DiscreteStates, actions: Actions) -> torch.Tensor:
        """Map each (node, action) pair to the corresponding child node."""
        pairs = torch.hstack((states.tensor, actions.tensor)).tolist()
        children = [self.transition_table.get(tuple(pair)) for pair in pairs]
        return torch.tensor(children).reshape(-1, 1).long()

    def update_masks(self, states: DiscreteStates) -> None:
        """Set forward/backward action masks in place for a batch of states."""
        terminating_states_mask = torch.isin(
            states.tensor, self.terminating_states.tensor
        ).flatten()
        initial_state_mask = (states.tensor == self.s0).flatten()
        even_states = (states.tensor % 2 == 0).flatten()

        # Going from any node, we can choose action 0 or 1
        # Except terminating states where we must end the trajectory
        states.forward_masks[~terminating_states_mask, -1] = False
        states.forward_masks[terminating_states_mask, :] = False
        states.forward_masks[terminating_states_mask, -1] = True

        # BFS numbering makes right children even and left children odd, so
        # even states were reached via action 1 and odd states via action 0.
        states.backward_masks[even_states, 1] = True
        states.backward_masks[even_states, 0] = False
        states.backward_masks[~even_states, 1] = False
        states.backward_masks[~even_states, 0] = True

        # Initial state has no available backward action
        states.backward_masks[initial_state_mask, :] = False

    def get_states_indices(self, states: States):
        """Return the node index of each state (the state tensor, flattened)."""
        return torch.flatten(states.tensor)

    @property
    def all_states(self) -> DiscreteStates:
        """Every node in the tree, from the root (0) to the last leaf."""
        return self.states_from_tensor(torch.arange(self.n_nodes).reshape(-1, 1))

    @property
    def terminating_states(self) -> DiscreteStates:
        """The leaves: nodes in [2^depth - 1, 2^(depth+1) - 1)."""
        lb = 2**self.depth - 1
        ub = 2 ** (self.depth + 1) - 1
        return self.make_states_class()(torch.arange(lb, ub).reshape(-1, 1))

    def reward(self, final_states):
        """Delegate reward computation to the user-supplied reward_fn."""
        return self.reward_fn(final_states.tensor)
75 changes: 75 additions & 0 deletions src/gfn/gym/set_addition.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from typing import Callable

import torch

from gfn.env import Actions, DiscreteEnv, DiscreteStates


class SetAddition(DiscreteEnv):
    """Append-only MDP, similar to what is described in Remark 8 of Shen et al. 2023,
    [Towards Understanding and Improving GFlowNet Training](https://proceedings.mlr.press/v202/shen23a.html)

    The state is a binary vector of length `n_items`, where 1 indicates the presence of an item.
    Actions are integers from 0 to `n_items - 1` to add the corresponding item, or `n_items` to exit.
    Adding an existing item is invalid. The trajectory must end when `max_items` are present.

    Recommended preprocessor: `IdentityPreprocessor`.
    """

    def __init__(self, n_items: int, max_items: int, reward_fn: Callable):
        """
        Args:
            n_items: Number of distinct items (length of the state vector).
            max_items: Number of present items at which the trajectory is
                forced to end.
            reward_fn: Callable mapping a batch of final-state tensors to rewards.
        """
        self.n_items = n_items
        self.reward_fn = reward_fn
        self.max_traj_len = max_items
        n_actions = n_items + 1  # one "add item i" action per item, plus exit
        s0 = torch.zeros(n_items)
        state_shape = (n_items,)

        super().__init__(n_actions, s0, state_shape)

    def get_states_indices(self, states: DiscreteStates):
        """Return a unique integer per state by reading the binary vector as a
        base-2 number (most significant bit first)."""
        states_raw = states.tensor

        canonical_base = 2 ** torch.arange(
            self.n_items - 1, -1, -1, device=states_raw.device
        )
        indices = (canonical_base * states_raw).sum(-1).long()
        return indices

    def update_masks(self, states: DiscreteStates) -> None:
        """Set forward/backward action masks in place for a batch of states."""
        trajs_that_must_end = states.tensor.sum(dim=1) >= self.max_traj_len
        trajs_that_may_continue = states.tensor.sum(dim=1) < self.max_traj_len

        # An item may be added only if it is currently absent.
        states.forward_masks[trajs_that_may_continue, : self.n_items] = (
            states.tensor[trajs_that_may_continue] == 0
        )

        # Disallow everything for trajs that must end
        states.forward_masks[trajs_that_must_end, : self.n_items] = 0
        states.forward_masks[..., -1] = 1  # Allow exit action

        # An item may be removed (backward) only if it is currently present.
        states.backward_masks[..., : self.n_items] = states.tensor != 0

        # Disallow exit action if at s_0
        at_initial_state = torch.all(states.tensor == 0, dim=1)
        states.forward_masks[at_initial_state, -1] = 0

    def step(self, states: DiscreteStates, actions: Actions) -> torch.Tensor:
        """Add the chosen item: additive scatter flips the entry from 0 to 1."""
        new_states_tensor = states.tensor.scatter(-1, actions.tensor, 1, reduce="add")
        return new_states_tensor

    def backward_step(self, states: DiscreteStates, actions: Actions):
        """Remove the chosen item: additive scatter flips the entry from 1 to 0."""
        new_states_tensor = states.tensor.scatter(-1, actions.tensor, -1, reduce="add")
        return new_states_tensor

    def reward(self, final_states: DiscreteStates) -> torch.Tensor:
        """Delegate reward computation to the user-supplied reward_fn."""
        return self.reward_fn(final_states.tensor)

    @property
    def all_states(self) -> DiscreteStates:
        """All 2**n_items binary vectors, enumerated via a cartesian product."""
        digits = torch.arange(0, 2, device=self.device)
        all_states = torch.cartesian_prod(*[digits] * self.n_items)
        # NOTE(review): other envs build states via `self.states_from_tensor`;
        # confirm that constructing `DiscreteStates` directly is equivalent here.
        return DiscreteStates(all_states)

    @property
    def terminating_states(self) -> DiscreteStates:
        # NOTE(review): when max_items < n_items this includes unreachable
        # states with more than max_items items — confirm intended.
        return self.all_states[1:]  # Remove initial state s_0
176 changes: 176 additions & 0 deletions testing/test_environments.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from gfn.env import NonValidActionsError
from gfn.gym import Box, DiscreteEBM, HyperGrid
from gfn.gym.graph_building import GraphBuilding
from gfn.gym.perfect_tree import PerfectBinaryTree
from gfn.gym.set_addition import SetAddition
from gfn.preprocessors import IdentityPreprocessor, KHotPreprocessor, OneHotPreprocessor
from gfn.states import GraphStates

Expand Down Expand Up @@ -505,3 +507,177 @@ def test_graph_env():
)
states = env._backward_step(states, actions)
assert states.tensor.x.shape == (0, FEATURE_DIM)


def test_set_addition_fwd_step():
    """Forward dynamics of SetAddition: additions, invalid repeats, the
    max-items cap, the exit action, and the reward of final states."""
    n_items, max_items, batch = 4, 3, 2

    env = SetAddition(
        n_items=n_items, max_items=max_items, reward_fn=lambda s: s.sum(-1)
    )
    states = env.reset(batch_shape=batch)
    assert states.tensor.shape == (batch, n_items)

    # Add items 0 and 1 respectively.
    states = env._step(states, env.actions_from_tensor(format_tensor([0, 1])))
    assert torch.equal(
        states.tensor, torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0]], dtype=torch.float)
    )

    # Add items 2 and 3 respectively.
    states = env._step(states, env.actions_from_tensor(format_tensor([2, 3])))
    assert torch.equal(
        states.tensor, torch.tensor([[1, 0, 1, 0], [0, 1, 0, 1]], dtype=torch.float)
    )

    # Re-adding items already in the set is invalid.
    with pytest.raises(NonValidActionsError):
        env._step(states, env.actions_from_tensor(format_tensor([0, 1])))

    # Add items 3 and 0 respectively; both sets now hold max_items items.
    states = env._step(states, env.actions_from_tensor(format_tensor([3, 0])))
    assert torch.equal(
        states.tensor, torch.tensor([[1, 0, 1, 1], [1, 1, 0, 1]], dtype=torch.float)
    )

    # The cap is reached, so any further addition is invalid.
    with pytest.raises(NonValidActionsError):
        env._step(states, env.actions_from_tensor(format_tensor([1, 2])))

    # Exiting remains valid and leads to the sink state.
    final_states = env._step(
        states, env.actions_from_tensor(format_tensor([n_items, n_items]))
    )
    assert torch.all(final_states.is_sink_state)

    # The reward is the set size.
    assert torch.allclose(env.reward(states), torch.tensor([3.0, 3.0]))


def test_set_addition_bwd_step():
    """Backward dynamics of SetAddition: valid removals, invalid removal of an
    absent item, and the walk back to the initial state."""
    n_items, max_items, batch = 5, 4, 2

    env = SetAddition(
        n_items=n_items, max_items=max_items, reward_fn=lambda s: s.sum(-1)
    )

    # Begin from two 3-item sets.
    states = env.states_from_tensor(
        torch.tensor([[1, 1, 0, 1, 0], [0, 1, 1, 0, 1]], dtype=torch.float)
    )

    # Remove items 1 and 2 respectively.
    states = env._backward_step(states, env.actions_from_tensor(format_tensor([1, 2])))
    assert torch.equal(
        states.tensor,
        torch.tensor([[1, 0, 0, 1, 0], [0, 1, 0, 0, 1]], dtype=torch.float),
    )

    # Removing an item that is not in the set is invalid.
    with pytest.raises(NonValidActionsError):
        env._backward_step(states, env.actions_from_tensor(format_tensor([2, 0])))

    # Remove items 0 and 4 respectively.
    states = env._backward_step(states, env.actions_from_tensor(format_tensor([0, 4])))
    assert torch.equal(
        states.tensor,
        torch.tensor([[0, 0, 0, 1, 0], [0, 1, 0, 0, 0]], dtype=torch.float),
    )

    # Remove the last items; both trajectories are back at s_0.
    states = env._backward_step(states, env.actions_from_tensor(format_tensor([3, 1])))
    assert torch.equal(states.tensor, torch.zeros((batch, n_items), dtype=torch.float))
    assert torch.all(states.is_initial_state)


def test_perfect_binary_tree_fwd_step():
    """Forward dynamics of PerfectBinaryTree: descent to a leaf, invalid moves
    from a leaf, the exit action, and the reward."""
    depth, batch = 3, 2
    exit_action = 2  # actions: 0=left, 1=right, 2=exit

    env = PerfectBinaryTree(depth=depth, reward_fn=lambda s: s.float() + 1)
    states = env.reset(batch_shape=batch)
    assert states.tensor.shape == (batch, 1)
    assert torch.all(states.tensor == 0)

    # Walk both trajectories down to leaves: 0 -> {1,2} -> {4,5} -> {9,11}.
    for moves, expected in [
        ([0, 1], [[1], [2]]),
        ([1, 0], [[4], [5]]),
        ([0, 0], [[9], [11]]),
    ]:
        states = env._step(states, env.actions_from_tensor(format_tensor(moves)))
        assert torch.equal(states.tensor, torch.tensor(expected, dtype=torch.long))
    assert torch.all(torch.isin(states.tensor, env.terminating_states.tensor))

    # Leaves admit no further child moves.
    with pytest.raises(NonValidActionsError):
        env._step(states, env.actions_from_tensor(format_tensor([0, 1])))

    # Only the exit action remains valid at a leaf.
    final_states = env._step(
        states, env.actions_from_tensor(format_tensor([exit_action, exit_action]))
    )
    assert torch.all(final_states.is_sink_state)

    # reward_fn is node_index + 1.
    assert torch.allclose(env.reward(states), torch.tensor([[10.0], [12.0]]))


def test_perfect_binary_tree_bwd_step():
    """Backward dynamics of PerfectBinaryTree: invalid backward exit, then the
    walk from two leaves back up to the root."""
    DEPTH = 3

    env = PerfectBinaryTree(depth=DEPTH, reward_fn=lambda s: s.float() + 1)

    # Start from leaf nodes 8 and 12
    initial_tensor = torch.tensor([[8], [12]], dtype=torch.long)
    states = env.states_from_tensor(initial_tensor)

    # Try backward exit action (invalid)
    actions = env.actions_from_tensor(format_tensor([2, 2]))
    with pytest.raises(RuntimeError):
        env._backward_step(states, actions)

    # Go up: node 8 is the right child of 3 (bwd action 1);
    # node 12 is the right child of 5 (bwd action 1).
    actions = env.actions_from_tensor(format_tensor([1, 1]))
    states = env._backward_step(states, actions)
    expected_states = torch.tensor([[3], [5]], dtype=torch.long)
    assert torch.equal(states.tensor, expected_states)

    # Go up: node 3 is the left child of 1 (bwd action 0);
    # node 5 is the left child of 2 (bwd action 0).
    actions = env.actions_from_tensor(format_tensor([0, 0]))
    states = env._backward_step(states, actions)
    expected_states = torch.tensor([[1], [2]], dtype=torch.long)
    assert torch.equal(states.tensor, expected_states)

    # Go up to root: node 1 is the left child of 0 (bwd action 0);
    # node 2 is the right child of 0 (bwd action 1).
    actions = env.actions_from_tensor(format_tensor([0, 1]))
    states = env._backward_step(states, actions)
    expected_states = torch.tensor([[0], [0]], dtype=torch.long)
    assert torch.equal(states.tensor, expected_states)
    assert torch.all(states.is_initial_state)