 import torch
 from torch import nn
 import torch.nn.functional as F
-from torch_geometric.nn import MessagePassing
 from torch_geometric.nn.dense.linear import Linear
 from torch_geometric.nn.conv.gcn_conv import gcn_norm
 from torch_sparse import SparseTensor
 from torch import Tensor
 import torch_geometric.nn as pyg_nn
 from torch.nn import Parameter
 from torch_sparse import SparseTensor
 from typing import Optional
-from torch_geometric.nn.conv import MessagePassing
 from torch_geometric.nn.inits import zeros
 from torch_geometric.typing import Adj, OptPairTensor, OptTensor
 from torch_sparse import matmul as torch_sparse_matmul
-from torch_geometric.nn.conv.gcn_conv import gcn_norm
 
 
 def spmm(src: Adj, other: Tensor, reduce: str = "sum") -> Tensor:
@@ -43,7 +40,7 @@ def spmm(src: Adj, other: Tensor, reduce: str = "sum") -> Tensor:
         raise ValueError(f"`{reduce}` reduction is not supported for "
                          f"`torch.sparse.Tensor`.")
 
-class GCN_reweight(MessagePassing):
+class GCN_reweight(pyg_nn.MessagePassing):
     r"""The graph convolutional operator from the `"Semi-supervised
     Classification with Graph Convolutional Networks"
     <https://arxiv.org/abs/1609.02907>`_ paper
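
For context, the operator the docstring cites computes the standard GCN rule

    X' = \hat{D}^{-1/2} (A + I) \hat{D}^{-1/2} X \Theta,   \hat{D}_{ii} = 1 + \sum_j A_{ij},

and the `_reweight` variant in this commit additionally scales each normalized message by
(1 - lmda) + lmda * edge_rw, as implemented in `message()` further down.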
@@ -103,13 +100,21 @@ class GCN_reweight(MessagePassing):
     _cached_edge_index: Optional[OptPairTensor]
     _cached_adj_t: Optional[SparseTensor]
 
-    def __init__(self, in_channels: int, out_channels: int, aggr: str,
-                 improved: bool = False, cached: bool = False,
-                 add_self_loops: bool = False, normalize: bool = True,
-                 bias: bool = True, **kwargs):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        aggr: str,
+        improved: bool = False,
+        cached: bool = False,
+        add_self_loops: bool = False,
+        normalize: bool = True,
+        bias: bool = True,
+        **kwargs
+    ):
 
-        kwargs.setdefault('aggr', "add")
-        super().__init__(**kwargs, flow="target_to_source")
+        # kwargs.setdefault('aggr', "add")
+        super(GCN_reweight, self).__init__(aggr=aggr, flow="target_to_source")
 
         self.in_channels = in_channels
         self.out_channels = out_channels
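
Note the behavioral change in this hunk: the aggregation is no longer defaulted to "add" via
kwargs.setdefault but forwarded from the now-required aggr argument, so callers must pass it
explicitly, e.g. (channel sizes are illustrative):

conv = GCN_reweight(in_channels=16, out_channels=32, aggr="add")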
@@ -120,13 +125,11 @@ def __init__(self, in_channels: int, out_channels: int, aggr: str,
             self.normalize = False
         else:
             self.normalize = True
-        #self.normalize = normalize
-
+
         self._cached_edge_index = None
         self._cached_adj_t = None
 
-        self.lin = Linear(in_channels, out_channels, bias=False,
-                          weight_initializer='glorot')
+        self.lin = Linear(in_channels, out_channels, bias=False, weight_initializer='glorot')
 
         if bias:
             self.bias = Parameter(torch.Tensor(out_channels))
@@ -142,12 +145,9 @@ def reset_parameters(self):
         self._cached_adj_t = None
 
 
-    def forward(self, x: Tensor, edge_index: Adj,
-                edge_weight: OptTensor = None, lmda=1) -> Tensor:
-        """"""
+    def forward(self, x, edge_index, edge_weight, lmda):
         edge_rw = edge_weight
         edge_weight = torch.ones_like(edge_rw)
-        #edge_weight = None
         if self.normalize:
             if isinstance(edge_index, Tensor):
                 cache = self._cached_edge_index
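
The rewritten forward keeps the raw weights aside as edge_rw and hands an all-ones tensor to
the normalization, so gcn_norm operates on the unweighted graph and the raw weights only enter
the per-edge blend in message(). A minimal sketch of that step on a toy graph (values are
illustrative; assumes the gcn_norm imported above):

import torch
from torch_geometric.nn.conv.gcn_conv import gcn_norm

edge_index = torch.tensor([[0, 1], [1, 0]])
edge_rw = torch.tensor([0.3, 0.7])     # raw weights, reserved for message()
ones = torch.ones_like(edge_rw)        # what the normalization actually sees
edge_index, norm = gcn_norm(edge_index, ones, num_nodes=2, add_self_loops=False)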
@@ -174,25 +174,24 @@ def forward(self, x: Tensor, edge_index: Adj,
         x = self.lin(x)
 
         # propagate_type: (x: Tensor, edge_weight: OptTensor)
-        out = self.propagate(edge_index, x=x, edge_weight=edge_weight,
-                             size=None, lmda=lmda, edge_rw=edge_rw)
+        out = self.propagate(edge_index, size=None, x=x, edge_weight=edge_weight, edge_rw=edge_rw, lmda=lmda)
 
         if self.bias is not None:
             out = out + self.bias
 
         return out
 
-    def message(self, x_j: Tensor, edge_weight: OptTensor, lmda, edge_rw) -> Tensor:
+    def message(self, x_j, edge_index, edge_weight, edge_rw, lmda):
         x_j = (edge_weight.view(-1, 1) * x_j)
         x_j = (1 - lmda) * x_j + (lmda) * (edge_rw.view(-1, 1) * x_j)
         return x_j
 
     def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
         return spmm(adj_t, x, reduce=self.aggr)
 
+
 class GS_reweight(pyg_nn.MessagePassing):
-    def __init__(self, in_channels, out_channels, reducer,
-                 normalize_embedding=False):
+    def __init__(self, in_channels, out_channels, reducer, normalize_embedding=False):
         super(GS_reweight, self).__init__(aggr=reducer, flow="target_to_source")
         self.lin = torch.nn.Linear(in_channels, out_channels)
         self.agg_lin = torch.nn.Linear(out_channels + in_channels, out_channels)
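
Putting the pieces together, a hypothetical call of the revised GCN_reweight layer (shapes,
weights, and lmda chosen for illustration only):

import torch

conv = GCN_reweight(in_channels=8, out_channels=16, aggr="add")
x = torch.randn(4, 8)                                    # 4 nodes, 8 features
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])
edge_weight = torch.rand(edge_index.size(1))             # raw per-edge weights
out = conv(x, edge_index, edge_weight, 0.5)              # lmda=0.5 blends plain and reweighted messages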
@@ -201,12 +200,12 @@ def __init__(self, in_channels, out_channels, reducer,
 
     def forward(self, x, edge_index, edge_weight, lmda):
         num_nodes = x.size(0)
-        return self.propagate(edge_index, size=(num_nodes, num_nodes), x=x, edge_weight=edge_weight, lmda=lmda)
+        return self.propagate(edge_index, size=(num_nodes, num_nodes), x=x, edge_weight=edge_weight, lmda=lmda)
 
     def message(self, x_j, edge_index, edge_weight, lmda):
         x_j = self.lin(x_j)
         x_j = (1 - lmda) * x_j + (lmda) * (edge_weight.view(-1, 1) * x_j)
-        #print(lmda)
+
         return x_j
 
     def update(self, aggr_out, x):
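
GS_reweight applies the same (1 - lmda) / lmda blend, but after its linear projection and with
the raw edge weights used directly. A hypothetical call mirrors the GCN variant (the reducer
value is illustrative):

conv = GS_reweight(in_channels=8, out_channels=16, reducer="mean")
out = conv(x, edge_index, edge_weight, 0.5)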
@@ -274,7 +273,7 @@ def __init__(
     def forward(self, data, h):
         x, edge_index, edge_weight = h, data.edge_index, data.edge_weight
         for i, layer in enumerate(self.conv):
-            x = layer(x, edge_index, edge_weight=edge_weight, lmda=self.lmda)
+            x = layer(x, edge_index, edge_weight, self.lmda)
             # if self.bn and (i != len(self.conv) - 1):
             #     x = self.bns[i](x)
             x = F.relu(x)
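
As a sanity check on the blend both layers use (a self-contained sketch; reweight_message is a
hypothetical helper mirroring the message() bodies above, not part of the commit):

import torch

def reweight_message(x_j, edge_rw, lmda):
    # (1 - lmda) * x_j + lmda * (edge_rw * x_j), as in message() above
    return (1 - lmda) * x_j + lmda * (edge_rw.view(-1, 1) * x_j)

x_j = torch.randn(4, 8)
r = torch.rand(4)
assert torch.allclose(reweight_message(x_j, r, 0.0), x_j)                  # lmda=0: plain messages
assert torch.allclose(reweight_message(x_j, r, 1.0), r.view(-1, 1) * x_j)  # lmda=1: fully reweighted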