Skip to content

Commit

Permalink
[Chore] backwards compatibility Python 3.9 (match -> if-elif-else)
Browse files Browse the repository at this point in the history
  • Loading branch information
fedebotu committed Jan 14, 2025
1 parent eaaba4e commit 8012b2c
Showing 1 changed file with 14 additions and 15 deletions.
29 changes: 14 additions & 15 deletions rl4co/models/zoo/gfacs/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,18 +110,17 @@ def calculate_loss(
return tb_loss

def calculate_log_pb_uniform(self, actions: torch.Tensor, n_ants: int):
match self.env.name:
case "tsp":
return math.log(1 / 2 * actions.size(1))
case "cvrp":
_a1 = actions.detach().cpu().numpy()
# shape: (batch, max_tour_length)
n_nodes = np.count_nonzero(_a1, axis=1)
_a2 = _a1[:, 1:] - _a1[:, :-1]
n_routes = np.count_nonzero(_a2, axis=1) - n_nodes
_a3 = _a1[:, 2:] - _a1[:, :-2]
n_multinode_routes = np.count_nonzero(_a3, axis=1) - n_nodes
log_b_p = - scipy.special.gammaln(n_routes + 1) - n_multinode_routes * math.log(2)
return unbatchify(torch.from_numpy(log_b_p).to(actions.device), n_ants)
case _:
raise ValueError(f"Unknown environment: {self.env.name}")
if self.env.name == "tsp":
return math.log(1 / 2 * actions.size(1))
elif self.env.name == "cvrp":
_a1 = actions.detach().cpu().numpy()
# shape: (batch, max_tour_length)
n_nodes = np.count_nonzero(_a1, axis=1)
_a2 = _a1[:, 1:] - _a1[:, :-1]
n_routes = np.count_nonzero(_a2, axis=1) - n_nodes
_a3 = _a1[:, 2:] - _a1[:, :-2]
n_multinode_routes = np.count_nonzero(_a3, axis=1) - n_nodes
log_b_p = - scipy.special.gammaln(n_routes + 1) - n_multinode_routes * math.log(2)
return unbatchify(torch.from_numpy(log_b_p).to(actions.device), n_ants)
else:
raise ValueError(f"Unknown environment: {self.env.name}")

0 comments on commit 8012b2c

Please sign in to comment.