Skip to content

Commit c13bd43

Browse files
committed
version up
1 parent 8067832 commit c13bd43

File tree

2 files changed

+26
-27
lines changed

2 files changed

+26
-27
lines changed

pygda/nn/reweight_gnn.py

Lines changed: 25 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import torch
22
from torch import nn
33
import torch.nn.functional as F
4-
from torch_geometric.nn import MessagePassing
54
from torch_geometric.nn.dense.linear import Linear
65
from torch_geometric.nn.conv.gcn_conv import gcn_norm
76
from torch_sparse import SparseTensor
@@ -10,11 +9,9 @@
109
from torch.nn import Parameter
1110
from torch_sparse import SparseTensor
1211
from typing import Optional
13-
from torch_geometric.nn.conv import MessagePassing
1412
from torch_geometric.nn.inits import zeros
1513
from torch_geometric.typing import Adj, OptPairTensor, OptTensor
1614
from torch_sparse import matmul as torch_sparse_matmul
17-
from torch_geometric.nn.conv.gcn_conv import gcn_norm
1815

1916

2017
def spmm(src: Adj, other: Tensor, reduce: str = "sum") -> Tensor:
@@ -43,7 +40,7 @@ def spmm(src: Adj, other: Tensor, reduce: str = "sum") -> Tensor:
4340
raise ValueError(f"`{reduce}` reduction is not supported for "
4441
f"`torch.sparse.Tensor`.")
4542

46-
class GCN_reweight(MessagePassing):
43+
class GCN_reweight(pyg_nn.MessagePassing):
4744
r"""The graph convolutional operator from the `"Semi-supervised
4845
Classification with Graph Convolutional Networks"
4946
<https://arxiv.org/abs/1609.02907>`_ paper
@@ -103,13 +100,21 @@ class GCN_reweight(MessagePassing):
103100
_cached_edge_index: Optional[OptPairTensor]
104101
_cached_adj_t: Optional[SparseTensor]
105102

106-
def __init__(self, in_channels: int, out_channels: int, aggr: str,
107-
improved: bool = False, cached: bool = False,
108-
add_self_loops: bool = False, normalize: bool = True,
109-
bias: bool = True, **kwargs):
103+
def __init__(
104+
self,
105+
in_channels: int,
106+
out_channels: int,
107+
aggr: str,
108+
improved: bool = False,
109+
cached: bool = False,
110+
add_self_loops: bool = False,
111+
normalize: bool = True,
112+
bias: bool = True,
113+
**kwargs
114+
):
110115

111-
kwargs.setdefault('aggr', "add")
112-
super().__init__(**kwargs, flow ="target_to_source")
116+
# kwargs.setdefault('aggr', "add")
117+
super(GCN_reweight, self).__init__(aggr=aggr, flow ="target_to_source")
113118

114119
self.in_channels = in_channels
115120
self.out_channels = out_channels
@@ -120,13 +125,11 @@ def __init__(self, in_channels: int, out_channels: int, aggr: str,
120125
self.normalize = False
121126
else:
122127
self.normalize = True
123-
#self.normalize = normalize
124-
128+
125129
self._cached_edge_index = None
126130
self._cached_adj_t = None
127131

128-
self.lin = Linear(in_channels, out_channels, bias=False,
129-
weight_initializer='glorot')
132+
self.lin = Linear(in_channels, out_channels, bias=False, weight_initializer='glorot')
130133

131134
if bias:
132135
self.bias = Parameter(torch.Tensor(out_channels))
@@ -142,12 +145,9 @@ def reset_parameters(self):
142145
self._cached_adj_t = None
143146

144147

145-
def forward(self, x: Tensor, edge_index: Adj,
146-
edge_weight: OptTensor=None, lmda = 1) -> Tensor:
147-
""""""
148+
def forward(self, x, edge_index, edge_weight, lmda):
148149
edge_rw = edge_weight
149150
edge_weight = torch.ones_like(edge_rw)
150-
#edge_weight = None
151151
if self.normalize:
152152
if isinstance(edge_index, Tensor):
153153
cache = self._cached_edge_index
@@ -174,25 +174,24 @@ def forward(self, x: Tensor, edge_index: Adj,
174174
x = self.lin(x)
175175

176176
# propagate_type: (x: Tensor, edge_weight: OptTensor)
177-
out = self.propagate(edge_index, x=x, edge_weight=edge_weight,
178-
size=None, lmda = lmda, edge_rw = edge_rw)
177+
out = self.propagate(edge_index, size=None, x=x, edge_weight=edge_weight, edge_rw=edge_rw, lmda=lmda)
179178

180179
if self.bias is not None:
181180
out = out + self.bias
182181

183182
return out
184183

185-
def message(self, x_j: Tensor, edge_weight: OptTensor, lmda, edge_rw) -> Tensor:
184+
def message(self, x_j, edge_index, edge_weight, edge_rw, lmda):
186185
x_j = (edge_weight.view(-1, 1) * x_j)
187186
x_j = (1-lmda) * x_j + (lmda) * (edge_rw.view(-1, 1) * x_j)
188187
return x_j
189188

190189
def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
191190
return spmm(adj_t, x, reduce=self.aggr)
192191

192+
193193
class GS_reweight(pyg_nn.MessagePassing):
194-
def __init__(self, in_channels, out_channels, reducer,
195-
normalize_embedding=False):
194+
def __init__(self, in_channels, out_channels, reducer, normalize_embedding=False):
196195
super(GS_reweight, self).__init__(aggr=reducer, flow ="target_to_source")
197196
self.lin = torch.nn.Linear(in_channels, out_channels)
198197
self.agg_lin = torch.nn.Linear(out_channels + in_channels, out_channels)
@@ -201,12 +200,12 @@ def __init__(self, in_channels, out_channels, reducer,
201200

202201
def forward(self, x, edge_index, edge_weight, lmda):
203202
num_nodes = x.size(0)
204-
return self.propagate(edge_index, size=(num_nodes, num_nodes), x=x, edge_weight = edge_weight, lmda = lmda)
203+
return self.propagate(edge_index, size=(num_nodes, num_nodes), x=x, edge_weight=edge_weight, lmda=lmda)
205204

206205
def message(self, x_j, edge_index, edge_weight, lmda):
207206
x_j = self.lin(x_j)
208207
x_j = (1-lmda) * x_j + (lmda) * (edge_weight.view(-1, 1) * x_j)
209-
#print(lmda)
208+
210209
return x_j
211210

212211
def update(self, aggr_out, x):
@@ -274,7 +273,7 @@ def __init__(
274273
def forward(self, data, h):
275274
x, edge_index, edge_weight = h, data.edge_index, data.edge_weight
276275
for i, layer in enumerate(self.conv):
277-
x = layer(x, edge_index, edge_weight=edge_weight, lmda = self.lmda)
276+
x = layer(x, edge_index, edge_weight, self.lmda)
278277
# if self.bn and (i != len(self.conv) - 1):
279278
# x = self.bns[i](x)
280279
x = F.relu(x)

pygda/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '0.0.3'
1+
__version__ = '0.0.4'

0 commit comments

Comments
 (0)