minor refactorings
hyeok9855 committed Jan 13, 2025
1 parent 8270199 commit eaaba4e
Showing 5 changed files with 20 additions and 57 deletions.
57 changes: 7 additions & 50 deletions rl4co/models/zoo/deepaco/antsystem.py
@@ -26,7 +26,6 @@ class AntSystem:
decay: Rate at which pheromone evaporates. Should be between 0 and 1. Defaults to 0.95.
Q: Rate at which pheromone deposits. Defaults to `1 / n_ants`.
pheromone: Initial pheromone matrix. Defaults to `torch.ones_like(log_heuristic)`.
require_logprobs: Whether to require the log probability of actions. Defaults to False.
use_local_search: Whether to use local_search provided by the env. Defaults to False.
use_nls: Whether to use neural-guided local search provided by the env. Defaults to False.
n_perturbations: Number of perturbations to be used for nls. Defaults to 5.
@@ -43,7 +42,6 @@ def __init__(
decay: float = 0.95,
Q: Optional[float] = None,
pheromone: Optional[Tensor] = None,
require_logprobs: bool = False,
use_local_search: bool = False,
use_nls: bool = False,
n_perturbations: int = 5,
@@ -67,8 +65,7 @@ def __init__(
self.pheromone = pheromone

self.final_actions = self.final_reward = None
self.require_logprobs = require_logprobs
self.all_records = []
self.final_reward_cache = torch.zeros(self.batch_size, 0, device=log_heuristic.device)

self.use_local_search = use_local_search
assert not (use_nls and not use_local_search), "use_nls requires use_local_search"
@@ -123,8 +120,8 @@ def run(
td = td_initial.clone()
self._one_step(td, env)

assert self.final_reward is not None
action_matrix = self._convert_final_action_to_matrix()
assert action_matrix is not None and self.final_reward is not None
td, env = self._recreate_final_routes(td_initial, env, action_matrix)

return td, action_matrix, self.final_reward
@@ -184,9 +181,6 @@ def _sampling(
logprobs, actions, td, env = decode_strategy.post_decoder_hook(td, env)
reward = env.get_reward(td, actions)

if self.require_logprobs:
self.all_records.append((logprobs, actions, reward, td.get("mask", None)))

return td, env, actions, reward

def local_search(
@@ -246,6 +240,9 @@ def _update_results(self, actions, reward):
self.final_actions[index] = best_actions[index]
self.final_reward[require_update] = best_reward[require_update]

self.final_reward_cache = torch.cat(
[self.final_reward_cache, self.final_reward.unsqueeze(-1)], -1
)
return best_index

def _update_pheromone(self, actions, reward):
@@ -285,53 +282,14 @@ def _recreate_final_routes(self, td, env, action_matrix):
assert td["done"].all()
return td, env

def get_logp(self):
"""Get the log probability (logprobs) values recorded during the execution of the algorithm.
Returns:
results: Tuple containing the log probability values,
actions chosen, rewards obtained, and mask values (if available).
Raises:
AssertionError: If `require_logp` is not enabled.
"""

assert (
self.require_logprobs
), "Please enable `require_logp` to record logprobs values"

logprobs_list, actions_list, reward_list, mask_list = [], [], [], []

for logprobs, actions, reward, mask in self.all_records:
logprobs_list.append(logprobs)
actions_list.append(actions)
reward_list.append(reward)
mask_list.append(mask)

if mask_list[0] is None:
mask_list = None
else:
mask_list = torch.stack(mask_list, 0)

# reset records
self.all_records = []

return (
torch.stack(logprobs_list, 0),
torch.stack(actions_list, 0),
torch.stack(reward_list, 0),
mask_list,
)

@staticmethod
@lru_cache(5)
def _batch_action_indices(batch_size: int, n_actions: int, device: torch.device):
batchindex = torch.arange(batch_size, device=device)
return batchindex.unsqueeze(1).repeat(1, n_actions).view(-1)

def _convert_final_action_to_matrix(self) -> Optional[Tensor]:
if self.final_actions is None:
return None
def _convert_final_action_to_matrix(self) -> Tensor:
assert self.final_actions is not None
action_count = max(len(actions) for actions in self.final_actions)
mat_actions = torch.zeros(
(self.batch_size, action_count),
@@ -340,5 +298,4 @@ def _convert_final_action_to_matrix(self) -> Optional[Tensor]:
)
for index, action in enumerate(self.final_actions):
mat_actions[index, : len(action)] = action

return mat_actions
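
For context, here is a minimal sketch of the two mechanisms this file deals with: the textbook Ant System pheromone update that the `decay` and `Q` arguments control, and the per-iteration best-reward cache (`final_reward_cache`) this commit introduces in place of the removed logprob records. The update rule, shapes, and names below are illustrative assumptions, not the repository's exact implementation.

import torch

batch_size, n_nodes, n_iterations = 2, 5, 3
decay, Q = 0.95, 1.0 / 10  # Q defaults to 1 / n_ants

pheromone = torch.ones(batch_size, n_nodes, n_nodes)     # initial pheromone
final_reward = torch.full((batch_size,), float("-inf"))  # best reward found so far
final_reward_cache = torch.zeros(batch_size, 0)          # gains one column per iteration

for _ in range(n_iterations):
    # stand-ins for what one ACO iteration would produce
    reward = torch.rand(batch_size)
    used_edges = torch.randint(0, 2, (batch_size, n_nodes, n_nodes)).float()

    # evaporation plus deposit: tau <- decay * tau + Q * reward * delta_tau
    pheromone = decay * pheromone + Q * reward.view(-1, 1, 1) * used_edges

    # keep a running best and append it as a new column, mirroring _update_results
    final_reward = torch.maximum(final_reward, reward)
    final_reward_cache = torch.cat(
        [final_reward_cache, final_reward.unsqueeze(-1)], dim=-1
    )

print(final_reward_cache.shape)  # torch.Size([2, 3])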
9 changes: 7 additions & 2 deletions rl4co/models/zoo/deepaco/policy.py
@@ -13,8 +13,8 @@
from rl4co.models.zoo.deepaco.antsystem import AntSystem
from rl4co.models.zoo.nargnn.encoder import NARGNNEncoder
from rl4co.utils.decoding import modify_logits_for_top_k_filtering, modify_logits_for_top_p_filtering
from rl4co.utils.ops import batchify, get_distance_matrix, unbatchify
from rl4co.utils.utils import merge_with_defaults
from rl4co.utils.ops import batchify, unbatchify


class DeepACOPolicy(NonAutoregressivePolicy):
@@ -60,6 +60,9 @@ def __init__(
test_decode_type="multistart_sampling",
)

self.top_p = top_p
self.top_k = top_k

self.aco_class = AntSystem if aco_class is None else aco_class
self.aco_kwargs = aco_kwargs
self.train_with_local_search = train_with_local_search
@@ -124,7 +127,9 @@ def forward(
heatmap_logits = outdict["hidden"]
# TODO: Refactor this so that we don't need to use the aco object
aco = self.aco_class(heatmap_logits, n_ants=n_ants, **self.aco_kwargs)
_, ls_reward = aco.local_search(batchify(td_initial, n_ants), env, outdict["actions"])
_, ls_reward = aco.local_search(
batchify(td_initial, n_ants), env, outdict["actions"] # type:ignore
)
outdict["ls_reward"] = unbatchify(ls_reward, n_ants)

outdict["log_likelihood"] = unbatchify(outdict["log_likelihood"], n_ants)
2 changes: 1 addition & 1 deletion rl4co/models/zoo/gfacs/model.py
@@ -1,5 +1,5 @@
import math
from typing import Any, Optional, Union
from typing import Optional, Union

import numpy as np
import scipy
5 changes: 3 additions & 2 deletions rl4co/models/zoo/gfacs/policy.py
@@ -134,7 +134,9 @@ def forward(
if self.train_with_local_search:
# TODO: Refactor this so that we don't need to use the aco object
aco = self.aco_class(hidden, n_ants=n_ants, **self.aco_kwargs)
ls_actions, ls_reward = aco.local_search(batchify(td_initial, n_ants), env, actions)
ls_actions, ls_reward = aco.local_search(
batchify(td_initial, n_ants), env, actions # type:ignore
)
ls_logprobs, ls_actions, td, env = self.common_decoding(
"evaluate", td_initial, env, hidden, n_ants, ls_actions, **decoding_kwargs
)
@@ -149,7 +151,6 @@
)
}
)

if return_actions:
outdict["ls_actions"] = ls_actions
########################################################################
4 changes: 2 additions & 2 deletions rl4co/utils/decoding.py
@@ -291,13 +291,13 @@ def pre_decoder_hook(
if self.num_starts is None:
self.num_starts = env.get_num_starts(td)
if self.multisample:
log.warn(
log.warning(
f"num_starts is not provided for sampling, using num_starts={self.num_starts}"
)
else:
if self.num_starts is not None:
if self.num_starts >= 1:
log.warn(
log.warning(
f"num_starts={self.num_starts} is ignored for decode_type={self.name}"
)
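
For reference, the decoding.py change swaps log.warn for log.warning: Logger.warn is a deprecated alias of Logger.warning in Python's standard logging module, so the rename keeps the same output while avoiding the deprecated call. A minimal illustration:

import logging

logging.basicConfig(level=logging.WARNING)
log = logging.getLogger("rl4co.utils.decoding")

log.warning("num_starts=%s is ignored for decode_type=%s", 3, "greedy")  # preferred spelling
# log.warn(...) would emit the same record but goes through the deprecated alias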
