Skip to content

Commit 4a63bc6

Browse files
committed
fix(GLOP): add VRPPolar init/edge embeddings (!not usable!)
1 parent c5d80fd commit 4a63bc6

File tree

4 files changed

+234
-85
lines changed

4 files changed

+234
-85
lines changed

rl4co/models/nn/env_embeddings/edge.py

Lines changed: 65 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,11 +99,72 @@ def _cost_matrix_to_graph(self, batch_cost_matrix: Tensor, init_embeddings: Tens
9999
x=init_embeddings[index],
100100
edge_index=edge_index,
101101
edge_attr=edge_attr,
102+
) # type: ignore
103+
graph_data.append(graph)
104+
105+
batch = Batch.from_data_list(graph_data) # type: ignore
106+
batch.edge_attr = self.edge_embed(batch.edge_attr) # type: ignore
107+
return batch
108+
109+
110+
class VRPPolarEdgeEmbedding(TSPEdgeEmbedding):
111+
"""TODO"""
112+
113+
node_dim = 2
114+
115+
def forward(self, td, init_embeddings: Tensor):
116+
with torch.no_grad():
117+
if "polar_locs" in td.keys():
118+
theta = td["polar_locs"][..., 1]
119+
else:
120+
shifted_locs = td["locs"] - td["locs"][..., 0:1, :]
121+
x, y = shifted_locs[..., 0], shifted_locs[..., 1]
122+
theta = torch.atan2(y, x)
123+
124+
delta_theta_matrix = theta[..., :, None] - theta[..., None, :]
125+
edge_attr1 = 1 - torch.cos(delta_theta_matrix)
126+
edge_attr2 = get_distance_matrix(td["locs"])
127+
cost_matrix = torch.stack((edge_attr1, edge_attr2), dim=-1)
128+
del edge_attr1, edge_attr2, delta_theta_matrix
129+
130+
batch = self._cost_matrix_to_graph(cost_matrix, init_embeddings)
131+
del cost_matrix
132+
133+
batch.edge_attr = self.edge_embed(batch.edge_attr) # type: ignore
134+
return batch
135+
136+
@torch.no_grad()
137+
def _cost_matrix_to_graph(self, batch_cost_matrix: Tensor, init_embeddings: Tensor):
138+
"""Convert batched cost_matrix to batched PyG graph, and calculate edge embeddings.
139+
140+
Args:
141+
batch_cost_matrix: Tensor of shape [batch_size, n, n, m]
142+
init_embedding: init embeddings of shape [batch_size, n, m]
143+
"""
144+
graph_data = []
145+
for index, cost_matrix in enumerate(batch_cost_matrix):
146+
edge_index, _ = sparsify_graph(
147+
cost_matrix[..., 0], self.k_sparse, self_loop=False
148+
)
149+
edge_index = edge_index.T[torch.all(edge_index != 0, dim=0)].T
150+
_, depot_edge_index = torch.topk(
151+
cost_matrix[0, :, 1], k=self.k_sparse, largest=False, sorted=False
102152
)
153+
depot_edge_index = depot_edge_index[depot_edge_index != 0]
154+
depot_edge_index = torch.stack(
155+
(torch.zeros_like(depot_edge_index), depot_edge_index), dim=0
156+
)
157+
edge_index = torch.concat((depot_edge_index, edge_index), dim=-1).detach()
158+
edge_attr = cost_matrix[edge_index[0], edge_index[1]].detach()
159+
160+
graph = Data(
161+
x=init_embeddings[index],
162+
edge_index=edge_index,
163+
edge_attr=edge_attr,
164+
) # type: ignore
103165
graph_data.append(graph)
104166

105-
batch = Batch.from_data_list(graph_data)
106-
batch.edge_attr = self.edge_embed(batch.edge_attr)
167+
batch = Batch.from_data_list(graph_data) # type: ignore
107168
return batch
108169

109170

@@ -146,8 +207,8 @@ def forward(self, td, init_embeddings: Tensor):
146207
x=node_embed,
147208
edge_index=edge_index,
148209
edge_attr=torch.zeros((m, self.embed_dim), device=device),
149-
)
210+
) # type: ignore
150211
data_list.append(data)
151212

152-
batch = Batch.from_data_list(data_list)
213+
batch = Batch.from_data_list(data_list) # type: ignore
153214
return batch

rl4co/models/nn/env_embeddings/init.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from tensordict.tensordict import TensorDict
55

66
from rl4co.models.nn.ops import PositionalEncoding
7+
from rl4co.utils.ops import cartesian_to_polar
78

89

910
def env_init_embedding(env_name: str, config: dict) -> nn.Module:
@@ -152,6 +153,55 @@ def forward(self, td):
152153
return torch.cat((depot_embedding, node_embeddings), -2)
153154

154155

156+
class VRPPolarInitEmbedding(nn.Module):
157+
"""Initial embedding for the Vehicle Routing Problems (VRP).
158+
Embed the following node features to the embedding space, based on polar coordinates:
159+
- locs: r, theta coordinates of the nodes, with the depot as the origin
160+
- demand: demand of the customers
161+
"""
162+
163+
def __init__(
164+
self,
165+
embed_dim,
166+
linear_bias=True,
167+
node_dim: int = 3,
168+
attach_cartesian_coords=False,
169+
):
170+
super(VRPPolarInitEmbedding, self).__init__()
171+
self.node_dim = node_dim + (
172+
2 if attach_cartesian_coords else 0
173+
) # 3: r, theta, demand; 5: r, theta, demand, x, y;
174+
self.attach_cartesian_coords = attach_cartesian_coords
175+
self.init_embed = nn.Linear(self.node_dim, embed_dim, linear_bias)
176+
self.init_embed_depot = nn.Linear(
177+
self.node_dim, embed_dim, linear_bias
178+
) # depot embedding
179+
180+
def forward(self, td):
181+
with torch.no_grad():
182+
locs = td["locs"]
183+
polar_locs = cartesian_to_polar(locs, locs[..., 0:1, :])
184+
td["polar_locs"] = polar_locs
185+
186+
demand = td["demand"]
187+
demand_with_depot = torch.concat(
188+
(torch.zeros(demand.shape[0], 1, device=demand.device), demand),
189+
dim=-1,
190+
).unsqueeze(-1)
191+
192+
if self.attach_cartesian_coords:
193+
x = torch.concat((polar_locs, demand_with_depot, locs), dim=-1)
194+
else:
195+
x = torch.concat((polar_locs, demand_with_depot), dim=-1)
196+
197+
depot, cities = x[:, :1, :], x[:, 1:, :]
198+
depot_embedding = self.init_embed_depot(depot)
199+
node_embeddings = self.init_embed(cities)
200+
201+
out = torch.cat((depot_embedding, node_embeddings), -2)
202+
return out
203+
204+
155205
class SVRPInitEmbedding(nn.Module):
156206
def __init__(self, embed_dim, linear_bias=True, node_dim: int = 3):
157207
super(SVRPInitEmbedding, self).__init__()

0 commit comments

Comments
 (0)