@@ -99,11 +99,72 @@ def _cost_matrix_to_graph(self, batch_cost_matrix: Tensor, init_embeddings: Tens
                 x=init_embeddings[index],
                 edge_index=edge_index,
                 edge_attr=edge_attr,
+            )  # type: ignore
+            graph_data.append(graph)
+
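+        # Collate the per-instance graphs into a single PyG Batch and embed
+        # the raw edge attributes with the learned edge embedding layer.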
+        batch = Batch.from_data_list(graph_data)  # type: ignore
+        batch.edge_attr = self.edge_embed(batch.edge_attr)  # type: ignore
+        return batch
+
+
+class VRPPolarEdgeEmbedding(TSPEdgeEmbedding):
+    """Edge embedding module for VRP variants based on polar coordinates.
+
+    Each instance is sparsified by the angular difference between nodes as
+    seen from the depot; every edge carries two features, 1 - cos(delta_theta)
+    and the Euclidean distance between its endpoints.
+    """
+
+    node_dim = 2
+    def forward(self, td, init_embeddings: Tensor):
+        with torch.no_grad():
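+            # Reuse precomputed polar angles if the environment provides them;
+            # otherwise derive each node's angle around the depot (node 0).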
117+ if "polar_locs" in td .keys ():
118+ theta = td ["polar_locs" ][..., 1 ]
119+ else :
120+ shifted_locs = td ["locs" ] - td ["locs" ][..., 0 :1 , :]
121+ x , y = shifted_locs [..., 0 ], shifted_locs [..., 1 ]
122+ theta = torch .atan2 (y , x )
123+
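+            # Pairwise angular differences: 1 - cos(delta) is 0 for nodes in
+            # the same direction and 2 for opposite directions, giving a cheap
+            # angular cost alongside the Euclidean distance.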
+            delta_theta_matrix = theta[..., :, None] - theta[..., None, :]
+            edge_attr1 = 1 - torch.cos(delta_theta_matrix)
+            edge_attr2 = get_distance_matrix(td["locs"])
+            cost_matrix = torch.stack((edge_attr1, edge_attr2), dim=-1)
+            del edge_attr1, edge_attr2, delta_theta_matrix
+
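+            # Build the sparsified PyG batch from the two-channel cost matrix;
+            # large intermediates are freed eagerly to limit peak memory.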
+            batch = self._cost_matrix_to_graph(cost_matrix, init_embeddings)
+            del cost_matrix
+
+        batch.edge_attr = self.edge_embed(batch.edge_attr)  # type: ignore
+        return batch
+
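+    # Rough usage sketch (constructor arguments and shapes are assumptions,
+    # not shown in this diff): given a TensorDict `td` with "locs" of shape
+    # [batch, n, 2] and init embeddings of shape [batch, n, embed_dim],
+    #     emb = VRPPolarEdgeEmbedding(embed_dim=128, k_sparse=10)
+    #     pyg_batch = emb(td, init_embeddings)
+    # returns a torch_geometric Batch with embedded edge_attr.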
+    @torch.no_grad()
+    def _cost_matrix_to_graph(self, batch_cost_matrix: Tensor, init_embeddings: Tensor):
+        """Convert a batched cost matrix to a batched PyG graph.
+
+        Unlike the parent implementation, edge embeddings are applied later
+        in ``forward``, outside this no-grad helper.
+
+        Args:
+            batch_cost_matrix: Tensor of shape [batch_size, n, n, m]
+            init_embeddings: initial node embeddings of shape [batch_size, n, embed_dim]
+        """
+        graph_data = []
+        for index, cost_matrix in enumerate(batch_cost_matrix):
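+            # Keep the k_sparse neighbors with the smallest angular cost
+            # (channel 0), then drop every edge touching the depot; depot
+            # connectivity is rebuilt from distances below.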
+            edge_index, _ = sparsify_graph(
+                cost_matrix[..., 0], self.k_sparse, self_loop=False
+            )
+            edge_index = edge_index.T[torch.all(edge_index != 0, dim=0)].T
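+            # Connect the depot (node 0) to its k_sparse nearest neighbors by
+            # Euclidean distance (channel 1).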
+            _, depot_edge_index = torch.topk(
+                cost_matrix[0, :, 1], k=self.k_sparse, largest=False, sorted=False
             )
+            depot_edge_index = depot_edge_index[depot_edge_index != 0]
+            depot_edge_index = torch.stack(
+                (torch.zeros_like(depot_edge_index), depot_edge_index), dim=0
+            )
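+            # Prepend the depot edges, then gather both edge features for
+            # every retained edge.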
+            edge_index = torch.concat((depot_edge_index, edge_index), dim=-1).detach()
+            edge_attr = cost_matrix[edge_index[0], edge_index[1]].detach()
+
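+            # Wrap each instance as a PyG Data object, using the initial node
+            # embeddings as node features.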
+            graph = Data(
+                x=init_embeddings[index],
+                edge_index=edge_index,
+                edge_attr=edge_attr,
+            )  # type: ignore
             graph_data.append(graph)

-        batch = Batch.from_data_list(graph_data)
-        batch.edge_attr = self.edge_embed(batch.edge_attr)
+        batch = Batch.from_data_list(graph_data)  # type: ignore
         return batch


@@ -146,8 +207,8 @@ def forward(self, td, init_embeddings: Tensor):
                 x=node_embed,
                 edge_index=edge_index,
                 edge_attr=torch.zeros((m, self.embed_dim), device=device),
-            )
+            )  # type: ignore
             data_list.append(data)

-        batch = Batch.from_data_list(data_list)
+        batch = Batch.from_data_list(data_list)  # type: ignore
         return batch