# Reconstructed from the patch hunk adding src/gfn/gym/perfect_tree.py
# (the same patch also adds `.python-version` to .gitignore).
from typing import Callable

import torch

from gfn.env import Actions, DiscreteEnv, DiscreteStates
from gfn.states import States


class PerfectBinaryTree(DiscreteEnv):
    r"""Perfect binary tree environment with a bijection between
    trajectories and terminating states.

    Nodes are integers starting from 0 at the root, numbered in BFS order,
    so the children of node ``n`` are ``2n + 1`` (left) and ``2n + 2``
    (right).  A state is a single ``long`` tensor holding the node index.
    Actions: 0 = go to left child, 1 = go to right child, 2 = exit.

    e.g.::

                     0 (root)
                   /   \
                  1     2
                 / \   / \
                3   4 5   6
               /|  /| |\  |\
              7 8 9 ...  13 14   (terminating states if depth = 3)

    Recommended preprocessor: ``OneHotPreprocessor``.
    """

    def __init__(self, reward_fn: Callable, depth: int = 4):
        """Builds the tree and its transition tables.

        Args:
            reward_fn: callable mapping a tensor of final states to rewards.
            depth: depth of the tree; the root is at depth 0 and the
                ``2**depth`` leaves are the terminating states.
        """
        self.reward_fn = reward_fn
        self.depth = depth
        self.branching_factor = 2
        # One action per child plus the exit action.
        self.n_actions = self.branching_factor + 1
        self.n_nodes = 2 ** (self.depth + 1) - 1

        self.s0 = torch.zeros((1,), dtype=torch.long)
        self.sf = torch.full((1,), fill_value=-1, dtype=torch.long)
        super().__init__(self.n_actions, self.s0, (1,), sf=self.sf)

        (
            self.transition_table,
            self.inverse_transition_table,
            self.term_states,
        ) = self._build_tree()

    def _build_tree(self) -> tuple[dict, dict, DiscreteStates]:
        """Create the transition tables, ensuring a bijection between
        trajectories and last states.

        Returns:
            ``(node, action) -> child`` table, ``(child, action) -> node``
            inverse table, and the terminating (leaf) states.
        """
        transition_table = {}
        inverse_transition_table = {}
        node_index = 0
        queue = [(node_index, 0)]  # (current_node, depth)

        # A list keeps the leaves in deterministic BFS (ascending) order;
        # a set would make the ordering an implementation detail of hashing.
        terminating_states_id = []
        while queue:
            node, d = queue.pop(0)
            if d < self.depth:
                for a in range(self.branching_factor):
                    node_index += 1
                    transition_table[(node, a)] = node_index
                    inverse_transition_table[(node_index, a)] = node
                    queue.append((node_index, d + 1))
            else:
                terminating_states_id.append(node)
        leaves = torch.tensor(terminating_states_id).reshape(-1, 1)
        terminating_states = self.states_from_tensor(leaves)

        return transition_table, inverse_transition_table, terminating_states

    def backward_step(self, states: DiscreteStates, actions: Actions) -> torch.Tensor:
        """Return the parent of each state, given the backward action taken.

        An invalid (state, action) pair yields ``None`` from the table
        lookup, which makes ``torch.tensor`` raise — preserving the original
        failure mode for invalid backward actions (e.g. the exit action).
        """
        pairs = torch.hstack((states.tensor, actions.tensor)).tolist()
        parents = [
            self.inverse_transition_table.get((node, action))
            for node, action in pairs
        ]
        # A single conversion suffices (the original converted twice);
        # .long() matches the environment's state dtype.
        return torch.tensor(parents).reshape(-1, 1).long()

    def step(self, states: DiscreteStates, actions: Actions) -> torch.Tensor:
        """Return the child reached from each state by the chosen action.

        NOTE(review): assumes the DiscreteEnv wrapper (`_step`) filters out
        exit/sink transitions before calling this, since (leaf, exit) pairs
        are absent from the transition table — confirm against gfn.env.
        """
        pairs = torch.hstack((states.tensor, actions.tensor)).tolist()
        children = [
            self.transition_table.get((node, action)) for node, action in pairs
        ]
        return torch.tensor(children).reshape(-1, 1).long()

    def update_masks(self, states: DiscreteStates) -> None:
        """Set forward and backward action masks in place.

        Forward: interior nodes may branch but not exit; leaves must exit.
        Backward: by the BFS numbering, left children are odd (2n + 1) and
        right children are even (2n + 2), so parity determines the unique
        backward action; the root has no backward action.
        """
        terminating_states_mask = torch.isin(
            states.tensor, self.terminating_states.tensor
        ).flatten()
        initial_state_mask = (states.tensor == self.s0).flatten()
        even_states = (states.tensor % 2 == 0).flatten()

        # Interior nodes may not exit; leaves may do nothing BUT exit.
        states.forward_masks[~terminating_states_mask, -1] = False
        states.forward_masks[terminating_states_mask, :] = False
        states.forward_masks[terminating_states_mask, -1] = True

        # Even index => right child => was reached via action 1;
        # odd index => left child => was reached via action 0.
        states.backward_masks[even_states, 1] = True
        states.backward_masks[even_states, 0] = False
        states.backward_masks[~even_states, 1] = False
        states.backward_masks[~even_states, 0] = True

        # The root has no parent, hence no backward action.
        states.backward_masks[initial_state_mask, :] = False

    def get_states_indices(self, states: States) -> torch.Tensor:
        """Flat node indices of the given states."""
        return torch.flatten(states.tensor)

    @property
    def all_states(self) -> DiscreteStates:
        """All nodes of the tree, root first."""
        return self.states_from_tensor(torch.arange(self.n_nodes).reshape(-1, 1))

    @property
    def terminating_states(self) -> DiscreteStates:
        """The leaves: nodes in [2**depth - 1, 2**(depth + 1) - 1)."""
        lb = 2**self.depth - 1
        ub = 2 ** (self.depth + 1) - 1
        # Built through states_from_tensor for consistency with all_states
        # (the original called make_states_class()(...) directly).
        return self.states_from_tensor(torch.arange(lb, ub).reshape(-1, 1))

    def reward(self, final_states) -> torch.Tensor:
        """Reward of terminating states, delegated to ``reward_fn``."""
        return self.reward_fn(final_states.tensor)
# Reconstructed from the patch hunk adding src/gfn/gym/set_addition.py.
class SetAddition(DiscreteEnv):
    """Append-only set MDP, similar to Remark 8 of Shen et al. 2023,
    [Towards Understanding and Improving GFlowNet Training](https://proceedings.mlr.press/v202/shen23a.html).

    The state is a binary vector of length ``n_items``, where 1 indicates
    the presence of an item.  Actions 0 .. ``n_items - 1`` add the
    corresponding item; action ``n_items`` exits.  Adding an already-present
    item is invalid, and the trajectory must end once ``max_items`` items
    are present.

    Recommended preprocessor: ``IdentityPreprocessor``.
    """

    def __init__(self, n_items: int, max_items: int, reward_fn: Callable):
        """Args:
            n_items: size of the item universe (length of the state vector).
            max_items: maximum set size; trajectories must terminate once
                this many items are present.
            reward_fn: callable mapping final state tensors to rewards.
        """
        self.n_items = n_items
        self.reward_fn = reward_fn
        self.max_traj_len = max_items
        n_actions = n_items + 1  # one "add" action per item, plus exit
        s0 = torch.zeros(n_items)  # float zeros: all states share this dtype
        state_shape = (n_items,)

        super().__init__(n_actions, s0, state_shape)

    def get_states_indices(self, states: DiscreteStates) -> torch.Tensor:
        """Unique index per state: read the binary vector as a base-2 number
        (most significant bit first)."""
        states_raw = states.tensor
        canonical_base = 2 ** torch.arange(
            self.n_items - 1, -1, -1, device=states_raw.device
        )
        return (canonical_base * states_raw).sum(-1).long()

    def update_masks(self, states: DiscreteStates) -> None:
        """Set forward and backward action masks in place.

        Forward: absent items may be added while below ``max_items``; the
        exit action is always allowed except at s0.  Backward: only present
        items may be removed.
        """
        item_counts = states.tensor.sum(dim=1)
        trajs_that_must_end = item_counts >= self.max_traj_len
        trajs_that_may_continue = item_counts < self.max_traj_len

        # Below the cap: may add exactly the items that are absent.
        states.forward_masks[trajs_that_may_continue, : self.n_items] = (
            states.tensor[trajs_that_may_continue] == 0
        )
        # At the cap: no add action is available ...
        states.forward_masks[trajs_that_must_end, : self.n_items] = 0
        # ... and exit is allowed everywhere,
        states.forward_masks[..., -1] = 1
        # except at s0, which may not terminate immediately.
        at_initial_state = torch.all(states.tensor == 0, dim=1)
        states.forward_masks[at_initial_state, -1] = 0

        # Backward: an item can only be removed if it is present.
        states.backward_masks[..., : self.n_items] = states.tensor != 0

    def step(self, states: DiscreteStates, actions: Actions) -> torch.Tensor:
        """Add the chosen item to each set (sets the bit via scatter-add)."""
        return states.tensor.scatter(-1, actions.tensor, 1, reduce="add")

    def backward_step(self, states: DiscreteStates, actions: Actions) -> torch.Tensor:
        """Remove the chosen item from each set."""
        return states.tensor.scatter(-1, actions.tensor, -1, reduce="add")

    def reward(self, final_states: DiscreteStates) -> torch.Tensor:
        """Reward of final states, delegated to ``reward_fn``."""
        return self.reward_fn(final_states.tensor)

    @property
    def all_states(self) -> DiscreteStates:
        """Every binary vector of length ``n_items`` (s0 is the first row)."""
        digits = torch.arange(0, 2, device=self.device)
        combos = torch.cartesian_prod(*[digits] * self.n_items)
        # Cast to s0's (float) dtype so these states compare equal to states
        # produced by stepping, and build via the canonical env constructor
        # (the original returned long-dtype DiscreteStates directly).
        return self.states_from_tensor(combos.to(self.s0.dtype))

    @property
    def terminating_states(self) -> DiscreteStates:
        """All non-empty sets (s0 removed).

        NOTE(review): when ``max_items < n_items`` this also includes sets
        larger than ``max_items``, which are unreachable — confirm this is
        intended before relying on it for exact distribution checks.
        """
        return self.all_states[1:]  # drop the initial state s_0
def test_set_addition_fwd_step():
    """Forward transitions, invalid-action rejection, and rewards for SetAddition."""
    N_ITEMS = 4
    MAX_ITEMS = 3
    BATCH_SIZE = 2

    env = SetAddition(
        n_items=N_ITEMS, max_items=MAX_ITEMS, reward_fn=lambda s: s.sum(-1)
    )
    states = env.reset(batch_shape=BATCH_SIZE)
    assert states.tensor.shape == (BATCH_SIZE, N_ITEMS)

    # Add item 0 and 1.
    actions = env.actions_from_tensor(format_tensor([0, 1]))
    states = env._step(states, actions)
    expected_states = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0]], dtype=torch.float)
    assert torch.equal(states.tensor, expected_states)

    # Add item 2 and 3.
    actions = env.actions_from_tensor(format_tensor([2, 3]))
    states = env._step(states, actions)
    expected_states = torch.tensor([[1, 0, 1, 0], [0, 1, 0, 1]], dtype=torch.float)
    assert torch.equal(states.tensor, expected_states)

    # Adding already-present items is invalid.
    actions = env.actions_from_tensor(format_tensor([0, 1]))
    with pytest.raises(NonValidActionsError):
        env._step(states, actions)

    # Add item 3 and 0 — each set now holds max_items = 3 items.
    actions = env.actions_from_tensor(format_tensor([3, 0]))
    states = env._step(states, actions)
    expected_states = torch.tensor([[1, 0, 1, 1], [1, 1, 0, 1]], dtype=torch.float)
    assert torch.equal(states.tensor, expected_states)

    # Adding a further item is invalid: max_items reached.
    actions = env.actions_from_tensor(format_tensor([1, 2]))
    with pytest.raises(NonValidActionsError):
        env._step(states, actions)

    # Exit action is valid and leads to the sink state.
    actions = env.actions_from_tensor(format_tensor([N_ITEMS, N_ITEMS]))
    final_states = env._step(states, actions)
    assert torch.all(final_states.is_sink_state)

    # Reward is the item count (3 per trajectory).
    rewards = env.reward(states)
    expected_rewards = torch.tensor([3.0, 3.0])
    assert torch.allclose(rewards, expected_rewards)


def test_set_addition_bwd_step():
    """Backward transitions and invalid-removal rejection for SetAddition."""
    N_ITEMS = 5
    MAX_ITEMS = 4
    BATCH_SIZE = 2

    env = SetAddition(
        n_items=N_ITEMS, max_items=MAX_ITEMS, reward_fn=lambda s: s.sum(-1)
    )

    # Start from states with 3 items each.
    initial_tensor = torch.tensor([[1, 1, 0, 1, 0], [0, 1, 1, 0, 1]], dtype=torch.float)
    states = env.states_from_tensor(initial_tensor)

    # Remove item 1 and 2.
    actions = env.actions_from_tensor(format_tensor([1, 2]))
    states = env._backward_step(states, actions)
    expected_states = torch.tensor([[1, 0, 0, 1, 0], [0, 1, 0, 0, 1]], dtype=torch.float)
    assert torch.equal(states.tensor, expected_states)

    # Removing an absent item is invalid.
    actions = env.actions_from_tensor(format_tensor([2, 0]))
    with pytest.raises(NonValidActionsError):
        env._backward_step(states, actions)

    # Remove item 0 and 4.
    actions = env.actions_from_tensor(format_tensor([0, 4]))
    states = env._backward_step(states, actions)
    expected_states = torch.tensor([[0, 0, 0, 1, 0], [0, 1, 0, 0, 0]], dtype=torch.float)
    assert torch.equal(states.tensor, expected_states)

    # Remove the last items (3 and 1) to land back on s0.
    actions = env.actions_from_tensor(format_tensor([3, 1]))
    states = env._backward_step(states, actions)
    expected_states = torch.zeros((BATCH_SIZE, N_ITEMS), dtype=torch.float)
    assert torch.equal(states.tensor, expected_states)
    assert torch.all(states.is_initial_state)


def test_perfect_binary_tree_fwd_step():
    """Forward transitions down to the leaves, exit, and rewards."""
    DEPTH = 3
    BATCH_SIZE = 2
    N_ACTIONS = 3  # 0=left, 1=right, 2=exit

    env = PerfectBinaryTree(depth=DEPTH, reward_fn=lambda s: s.float() + 1)
    states = env.reset(batch_shape=BATCH_SIZE)
    assert states.tensor.shape == (BATCH_SIZE, 1)
    assert torch.all(states.tensor == 0)

    # Go left / go right from the root.
    actions = env.actions_from_tensor(format_tensor([0, 1]))
    states = env._step(states, actions)
    expected_states = torch.tensor([[1], [2]], dtype=torch.long)
    assert torch.equal(states.tensor, expected_states)

    # Go right / go left.
    actions = env.actions_from_tensor(format_tensor([1, 0]))
    states = env._step(states, actions)
    expected_states = torch.tensor([[4], [5]], dtype=torch.long)
    assert torch.equal(states.tensor, expected_states)

    # Go left / go left — both trajectories reach leaf nodes.
    actions = env.actions_from_tensor(format_tensor([0, 0]))
    states = env._step(states, actions)
    expected_states = torch.tensor([[9], [11]], dtype=torch.long)
    assert torch.equal(states.tensor, expected_states)
    assert torch.all(torch.isin(states.tensor, env.terminating_states.tensor))

    # Branching from a leaf is invalid.
    actions = env.actions_from_tensor(format_tensor([0, 1]))
    with pytest.raises(NonValidActionsError):
        env._step(states, actions)

    # Exit from a leaf is valid and leads to the sink state.
    actions = env.actions_from_tensor(format_tensor([N_ACTIONS - 1, N_ACTIONS - 1]))
    final_states = env._step(states, actions)
    assert torch.all(final_states.is_sink_state)

    # reward_fn is node_index + 1, so leaves 9 and 11 give 10 and 12.
    rewards = env.reward(states)
    expected_rewards = torch.tensor([[10.0], [12.0]])
    assert torch.allclose(rewards, expected_rewards)


def test_perfect_binary_tree_bwd_step():
    """Backward transitions from leaves up to the root."""
    DEPTH = 3

    env = PerfectBinaryTree(depth=DEPTH, reward_fn=lambda s: s.float() + 1)

    # Start from leaf nodes 8 and 12.
    initial_tensor = torch.tensor([[8], [12]], dtype=torch.long)
    states = env.states_from_tensor(initial_tensor)

    # The exit action has no backward counterpart.
    actions = env.actions_from_tensor(format_tensor([2, 2]))
    with pytest.raises(RuntimeError):
        env._backward_step(states, actions)

    # Node 8 is the right child of 3 and node 12 is the right child of 5
    # (right children have even indices), so both backward actions are 1.
    actions = env.actions_from_tensor(format_tensor([1, 1]))
    states = env._backward_step(states, actions)
    expected_states = torch.tensor([[3], [5]], dtype=torch.long)
    assert torch.equal(states.tensor, expected_states)

    # Node 3 is the left child of 1 and node 5 is the left child of 2
    # (left children have odd indices), so both backward actions are 0.
    actions = env.actions_from_tensor(format_tensor([0, 0]))
    states = env._backward_step(states, actions)
    expected_states = torch.tensor([[1], [2]], dtype=torch.long)
    assert torch.equal(states.tensor, expected_states)

    # Node 1 is the left child of the root (action 0); node 2 is the right
    # child (action 1). Both trajectories arrive back at s0.
    actions = env.actions_from_tensor(format_tensor([0, 1]))
    states = env._backward_step(states, actions)
    expected_states = torch.tensor([[0], [0]], dtype=torch.long)
    assert torch.equal(states.tensor, expected_states)
    assert torch.all(states.is_initial_state)