Commit 264ec18

Furffico authored and cbhua committed
feat: add CVRPMVC and SHPP
Co-authored-by: Chuanbo Hua <[email protected]>
1 parent 4a63bc6 commit 264ec18

File tree

7 files changed: +423 −5 lines changed


rl4co/envs/__init__.py

Lines changed: 7 additions & 3 deletions
@@ -4,10 +4,14 @@
 # EDA
 from rl4co.envs.eda import DPPEnv, MDPPEnv
 
+# Graph
+from rl4co.envs.graph import FLPEnv, MCPEnv
+
 # Routing
 from rl4co.envs.routing import (
     ATSPEnv,
     CVRPEnv,
+    CVRPMVCEnv,
     CVRPTWEnv,
     DenseRewardTSPEnv,
     MDCPDPEnv,
@@ -18,6 +22,7 @@
     PDPEnv,
     PDPRuinRepairEnv,
     SDVRPEnv,
+    SHPPEnv,
     SPCTSPEnv,
     SVRPEnv,
     TSPEnv,
@@ -28,14 +33,12 @@
 from rl4co.envs.scheduling import FFSPEnv, FJSPEnv, SMTWTPEnv
 from rl4co.envs.scheduling.jssp.env import JSSPEnv
 
-# Graph
-from rl4co.envs.graph import MCPEnv, FLPEnv
-
 # Register environments
 ENV_REGISTRY = {
     "atsp": ATSPEnv,
     "cvrp": CVRPEnv,
     "cvrptw": CVRPTWEnv,
+    "cvrpmvc": CVRPMVCEnv,
     "dpp": DPPEnv,
     "ffsp": FFSPEnv,
     "jssp": JSSPEnv,
@@ -47,6 +50,7 @@
     "pdp": PDPEnv,
     "pdp_ruin_repair": PDPRuinRepairEnv,
     "sdvrp": SDVRPEnv,
+    "shpp": SHPPEnv,
     "svrp": SVRPEnv,
     "spctsp": SPCTSPEnv,
     "tsp": TSPEnv,

rl4co/envs/routing/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -2,6 +2,7 @@
 from rl4co.envs.routing.atsp.generator import ATSPGenerator
 from rl4co.envs.routing.cvrp.env import CVRPEnv
 from rl4co.envs.routing.cvrp.generator import CVRPGenerator
+from rl4co.envs.routing.cvrpmvc.env import CVRPMVCEnv
 from rl4co.envs.routing.cvrptw.env import CVRPTWEnv
 from rl4co.envs.routing.cvrptw.generator import CVRPTWGenerator
 from rl4co.envs.routing.mdcpdp.env import MDCPDPEnv
@@ -17,6 +18,8 @@
 from rl4co.envs.routing.pdp.env import PDPEnv, PDPRuinRepairEnv
 from rl4co.envs.routing.pdp.generator import PDPGenerator
 from rl4co.envs.routing.sdvrp.env import SDVRPEnv
+from rl4co.envs.routing.shpp.env import SHPPEnv
+from rl4co.envs.routing.shpp.generator import SHPPGenerator
 from rl4co.envs.routing.spctsp.env import SPCTSPEnv
 from rl4co.envs.routing.svrp.env import SVRPEnv
 from rl4co.envs.routing.svrp.generator import SVRPGenerator

rl4co/envs/routing/cvrpmvc/env.py

Lines changed: 101 additions & 0 deletions
@@ -0,0 +1,101 @@
+import torch
+
+from tensordict.tensordict import TensorDict
+
+from rl4co.envs.routing.cvrp.env import CVRPEnv
+from rl4co.utils.ops import gather_by_index
+from rl4co.utils.pylogger import get_pylogger
+
+log = get_pylogger(__name__)
+
+
+class CVRPMVCEnv(CVRPEnv):
+    """Capacitated Vehicle Routing Problem (CVRP) with maximum vehicle constraint environment."""
+
+    name = "cvrpmvc"
+
+    def _step(self, td: TensorDict) -> TensorDict:
+        vehicles_used = td["vehicles_used"] + (
+            (td["action"].unsqueeze(-1) == 0) & (td["current_node"] != 0)
+        )
+
+        current_node = td["action"][:, None]  # Add dimension for step
+        n_loc = td["demand"].size(-1)  # Excludes depot
+
+        # Note: selected_demand is the demand of the first node (by clamp), so it is incorrect for steps that visit the depot!
+        selected_demand = gather_by_index(
+            td["demand"], torch.clamp(current_node - 1, 0, n_loc - 1), squeeze=False
+        )
+
+        # Increase capacity if depot is not visited, otherwise set to 0
+        used_capacity = (td["used_capacity"] + selected_demand) * (
+            current_node != 0
+        ).float()
+
+        demand_remaining = td["demand_remaining"] - selected_demand
+
+        # Note: here we do not subtract one as we have to scatter so the first column allows scattering depot
+        # Add one dimension since we write a single value
+        visited = td["visited"].scatter(-1, current_node, 1)
+
+        # SECTION: get done
+        done = visited.sum(-1) == visited.size(-1)
+        reward = torch.zeros_like(done)
+
+        td.update(
+            {
+                "current_node": current_node,
+                "used_capacity": used_capacity,
+                "vehicles_used": vehicles_used,
+                "demand_remaining": demand_remaining,
+                "visited": visited,
+                "reward": reward,
+                "done": done,
+            }
+        )
+        td.set("action_mask", self.get_action_mask(td))
+        return td
+
+    def _reset(
+        self, td: TensorDict | None = None, batch_size: list | None = None
+    ) -> TensorDict:
+        td = super()._reset(td, batch_size)
+        batch_size = batch_size or list(td.batch_size)
+        td.set(
+            "vehicles_used",
+            torch.ones((*batch_size, 1), dtype=torch.int, device=td.device),
+        )
+        td.set("demand_remaining", td["demand"].sum(-1, keepdim=True))
+        td.set(
+            "max_vehicle", torch.ceil(td["demand_remaining"] / td["vehicle_capacity"]) + 1
+        )
+        return td
+
+    @staticmethod
+    def get_action_mask(td: TensorDict) -> torch.Tensor:
+        # For demand steps_dim is inserted by indexing with id, for used_capacity insert node dim for broadcasting
+        exceeds_cap = td["demand"] + td["used_capacity"] > td["vehicle_capacity"]
+
+        # Nodes that cannot be visited are already visited or have too much demand to be served now
+        mask_loc = td["visited"][..., 1:].to(exceeds_cap.dtype) | exceeds_cap
+
+        if "vehicles_used" in td.keys():
+            max_vehicle = td["max_vehicle"]
+            demand_remaining = td["demand_remaining"]
+            capacity_remaining = (max_vehicle - td["vehicles_used"]) * td[
+                "vehicle_capacity"
+            ]
+            mask_depot = (  # mask the depot
+                (td["current_node"] == 0)  # if the depot was just visited
+                | (
+                    demand_remaining > capacity_remaining
+                )  # or the unassigned vehicles' capacity can't satisfy the remaining demand
+            ) & ~torch.all(
+                mask_loc, dim=-1, keepdim=True
+            )  # unless there is no other choice
+        else:
+            # Cannot visit the depot if just visited and there are still unserved nodes
+            mask_depot = (td["current_node"] == 0) & ~torch.all(
+                mask_loc, dim=-1, keepdim=True
+            )
+        return ~torch.cat((mask_depot, mask_loc), -1)
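The maximum-vehicle constraint is enforced entirely through the depot mask in get_action_mask: _reset fixes the fleet size as max_vehicle = ceil(total_demand / vehicle_capacity) + 1, _step counts a new vehicle each time the agent leaves the depot, and the depot is masked whenever the not-yet-dispatched vehicles could no longer cover the remaining demand. An illustrative sketch of that arithmetic on plain tensors (the numbers are made up, not from the commit):

    import torch

    # Instance with total demand 2.6 and unit vehicle capacity
    vehicle_capacity = torch.tensor([[1.0]])
    total_demand = torch.tensor([[2.6]])
    max_vehicle = torch.ceil(total_demand / vehicle_capacity) + 1  # -> 4, as in _reset

    # Mid-rollout state: two vehicles dispatched, 1.7 demand still unserved
    vehicles_used = torch.tensor([[2.0]])
    demand_remaining = torch.tensor([[1.7]])

    # Capacity still available in the undispatched part of the fleet
    capacity_remaining = (max_vehicle - vehicles_used) * vehicle_capacity  # (4 - 2) * 1.0 = 2.0

    # Depot is masked when returning would strand more demand than the remaining fleet can serve
    mask_depot = demand_remaining > capacity_remaining
    print(mask_depot)  # tensor([[False]]) -> returning to the depot is still allowed here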

rl4co/envs/routing/shpp/env.py

Lines changed: 188 additions & 0 deletions
@@ -0,0 +1,188 @@
+from typing import Optional
+
+import torch
+
+from tensordict.tensordict import TensorDict
+from torchrl.data import (
+    BoundedTensorSpec,
+    CompositeSpec,
+    UnboundedContinuousTensorSpec,
+    UnboundedDiscreteTensorSpec,
+)
+
+from rl4co.envs.common.base import RL4COEnvBase
+from rl4co.utils.ops import gather_by_index, get_tour_length
+from rl4co.utils.pylogger import get_pylogger
+
+from .generator import SHPPGenerator
+from .render import render
+
+log = get_pylogger(__name__)
+
+
+class SHPPEnv(RL4COEnvBase):
+    """
+    Shortest Hamiltonian Path Problem (SHPP)
+    The SHPP is also referred to as the open-loop Traveling Salesman Problem (TSP) in the literature.
+    The goal of the SHPP is to find the shortest Hamiltonian path in a given graph with
+    fixed starting/terminating nodes (they can be different nodes). A Hamiltonian
+    path visits all other nodes exactly once. At each step, the agent chooses a city to visit.
+    The reward is 0 unless the agent has visited all the cities; in that case, the reward is
+    (minus) the length of the path: maximizing the reward is equivalent to minimizing the path length.
+
+    Observation:
+        - locations of each customer
+        - starting node and terminating node
+        - the current location of the vehicle
+
+    Constraints:
+        - the first node is the starting node
+        - the last node is the terminating node
+        - each node is visited exactly once
+
+    Finish condition:
+        - the agent has visited all the customers and reached the terminating node
+
+    Reward:
+        - (minus) the length of the path
+
+    Args:
+        generator: SHPPGenerator instance as the generator
+        generator_params: parameters for the generator
+    """
+
+    name = "shpp"
+
+    def __init__(
+        self,
+        generator: SHPPGenerator = None,
+        generator_params: dict = {},
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        if generator is None:
+            generator = SHPPGenerator(**generator_params)
+        self.generator = generator
+        self._make_spec(self.generator)
+
+    @staticmethod
+    def _step(td: TensorDict) -> TensorDict:
+        current_node = td["action"]
+        first_node = current_node if td["i"].all() == 0 else td["first_node"]
+
+        # Set not visited to 0 (i.e., we visited the node)
+        available = td["available"].scatter(
+            -1, current_node.unsqueeze(-1).expand_as(td["action_mask"]), 0
+        )
+
+        # If all other nodes are visited, the terminating node will be available
+        action_mask = available.clone()
+        action_mask[..., -1] = ~available[..., :-1].any(dim=-1)
+
+        # We are done when there are no unvisited locations
+        done = torch.sum(available, dim=-1) == 0
+
+        # The reward is calculated outside via get_reward for efficiency, so we set it to 0 here
+        reward = torch.zeros_like(done)
+
+        td.update(
+            {
+                "first_node": first_node,
+                "current_node": current_node,
+                "i": td["i"] + 1,
+                "available": available,
+                "action_mask": action_mask,
+                "reward": reward,
+                "done": done,
+            },
+        )
+        return td
+
+    def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict:
+        """Note: the first node is the starting node; the last node is the terminating node"""
+        device = td.device
+        locs = td["locs"]
+
+        # We do not enforce loading from self for flexibility
+        num_loc = locs.shape[-2]
+
+        # Other variables
+        current_node = torch.zeros((batch_size), dtype=torch.int64, device=device)
+        last_node = torch.full(
+            (batch_size), num_loc - 1, dtype=torch.int64, device=device
+        )
+        available = torch.ones(
+            (*batch_size, num_loc), dtype=torch.bool, device=device
+        )  # 1 means not visited, i.e. action is allowed
+        action_mask = torch.zeros((*batch_size, num_loc), dtype=torch.bool, device=device)
+        action_mask[..., 0] = 1  # Only the start point is available at the beginning
+        i = torch.zeros((*batch_size, 1), dtype=torch.int64, device=device)
+
+        return TensorDict(
+            {
+                "locs": locs,
+                "first_node": current_node,
+                "last_node": last_node,
+                "current_node": current_node,
+                "i": i,
+                "available": available,
+                "action_mask": action_mask,
+                "reward": torch.zeros((*batch_size, 1), dtype=torch.float32),
+            },
+            batch_size=batch_size,
+        )
+
+    def _get_reward(self, td, actions) -> TensorDict:
+        # Gather locations in order of tour and return distance between them (i.e., -reward)
+        locs_ordered = gather_by_index(td["locs"], actions)
+        return -get_tour_length(locs_ordered)
+
+    @staticmethod
+    def check_solution_validity(td: TensorDict, actions: torch.Tensor):
+        """Check that the solution is valid: nodes are visited exactly once"""
+        assert (
+            torch.arange(actions.size(1), out=actions.data.new())
+            .view(1, -1)
+            .expand_as(actions)
+            == actions.data.sort(1)[0]
+        ).all(), "Invalid tour"
+
+    @staticmethod
+    def render(td: TensorDict, actions: torch.Tensor = None, ax=None):
+        return render(td, actions, ax)
+
+    def _make_spec(self, generator):
+        """Make the observation and action specs from the parameters"""
+        self.observation_spec = CompositeSpec(
+            locs=BoundedTensorSpec(
+                low=generator.min_loc,
+                high=generator.max_loc,
+                shape=(generator.num_loc, 2),
+                dtype=torch.float32,
+            ),
+            first_node=UnboundedDiscreteTensorSpec(
+                shape=(1),
+                dtype=torch.int64,
+            ),
+            current_node=UnboundedDiscreteTensorSpec(
+                shape=(1),
+                dtype=torch.int64,
+            ),
+            i=UnboundedDiscreteTensorSpec(
+                shape=(1),
+                dtype=torch.int64,
+            ),
+            action_mask=UnboundedDiscreteTensorSpec(
+                shape=(generator.num_loc),
+                dtype=torch.bool,
+            ),
+            shape=(),
+        )
+        self.action_spec = BoundedTensorSpec(
+            shape=(1,),
+            dtype=torch.int64,
+            low=0,
+            high=generator.num_loc,
+        )
+        self.reward_spec = UnboundedContinuousTensorSpec(shape=(1,))
+        self.done_spec = UnboundedDiscreteTensorSpec(shape=(1,), dtype=torch.bool)
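A hypothetical rollout sketch for the new SHPP environment (not part of the commit). It assumes the torchrl-style env.step(td)["next"] stepping convention used elsewhere in rl4co and that SHPPGenerator accepts a num_loc parameter; since every node is visited exactly once, the loop runs for exactly num_loc steps, starting at node 0 and ending at node num_loc - 1.

    import torch
    from rl4co.envs import SHPPEnv

    num_loc = 10
    env = SHPPEnv(generator_params={"num_loc": num_loc})  # assumed generator parameter
    td = env.reset(batch_size=[4])

    actions = []
    for _ in range(num_loc):
        # Sample uniformly among the currently feasible nodes
        action = torch.multinomial(td["action_mask"].float(), 1).squeeze(-1)
        td.set("action", action)
        td = env.step(td)["next"]
        actions.append(action)

    actions = torch.stack(actions, dim=-1)  # each node appears exactly once per instance
    print(env.get_reward(td, actions))      # negative open-path length, shape [4]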
