diff --git a/rl4co/models/zoo/gfacs/model.py b/rl4co/models/zoo/gfacs/model.py index a2dbfdba..fd92052d 100644 --- a/rl4co/models/zoo/gfacs/model.py +++ b/rl4co/models/zoo/gfacs/model.py @@ -110,18 +110,17 @@ def calculate_loss( return tb_loss def calculate_log_pb_uniform(self, actions: torch.Tensor, n_ants: int): - match self.env.name: - case "tsp": - return math.log(1 / 2 * actions.size(1)) - case "cvrp": - _a1 = actions.detach().cpu().numpy() - # shape: (batch, max_tour_length) - n_nodes = np.count_nonzero(_a1, axis=1) - _a2 = _a1[:, 1:] - _a1[:, :-1] - n_routes = np.count_nonzero(_a2, axis=1) - n_nodes - _a3 = _a1[:, 2:] - _a1[:, :-2] - n_multinode_routes = np.count_nonzero(_a3, axis=1) - n_nodes - log_b_p = - scipy.special.gammaln(n_routes + 1) - n_multinode_routes * math.log(2) - return unbatchify(torch.from_numpy(log_b_p).to(actions.device), n_ants) - case _: - raise ValueError(f"Unknown environment: {self.env.name}") + if self.env.name == "tsp": + return math.log(1 / 2 * actions.size(1)) + elif self.env.name == "cvrp": + _a1 = actions.detach().cpu().numpy() + # shape: (batch, max_tour_length) + n_nodes = np.count_nonzero(_a1, axis=1) + _a2 = _a1[:, 1:] - _a1[:, :-1] + n_routes = np.count_nonzero(_a2, axis=1) - n_nodes + _a3 = _a1[:, 2:] - _a1[:, :-2] + n_multinode_routes = np.count_nonzero(_a3, axis=1) - n_nodes + log_b_p = - scipy.special.gammaln(n_routes + 1) - n_multinode_routes * math.log(2) + return unbatchify(torch.from_numpy(log_b_p).to(actions.device), n_ants) + else: + raise ValueError(f"Unknown environment: {self.env.name}")