1
1
import logging
2
- from typing import Optional , Tuple , Union
2
+ import math
3
+ from typing import List , Optional , Tuple , Union
3
4
4
5
import numpy as np
5
6
import torch
7
+ import torch .nn as nn
6
8
from class_resolver import HintOrType , OptionalKwargs
7
9
8
10
try :
@@ -72,12 +74,11 @@ def _gcn_norm(
72
74
edge_index ,
73
75
num_nodes : int ,
74
76
edge_weight = None ,
75
- improved = True ,
77
+ fill_value = 2.0 ,
76
78
add_self_loops = True ,
77
79
flow = "source_to_target" ,
78
80
dtype = None ,
79
81
):
80
- fill_value = 2.0 if improved else 1.0
81
82
assert flow in ["source_to_target" , "target_to_source" ]
82
83
83
84
if edge_weight is None :
@@ -104,35 +105,119 @@ def _gcn_norm(
104
105
return edge_index , edge_weight
105
106
106
107
108
+ class BasicMessagePassing :
109
+ def __init__ (
110
+ self ,
111
+ edge_weight : float = 1.0 ,
112
+ self_loop_weight : float = 2.0 ,
113
+ aggr : str = "add" ,
114
+ ):
115
+ self .edge_weight = edge_weight
116
+ self .self_loop_weight = self_loop_weight
117
+ self .aggr = aggr
118
+
119
+ def forward (self , x : torch .Tensor , edge_index : torch .Tensor ) -> torch .Tensor :
120
+ edge_index_with_loops , edge_weights = _gcn_norm (
121
+ edge_index ,
122
+ num_nodes = len (x ),
123
+ edge_weight = torch .tensor ([self .edge_weight ] * len (edge_index [0 ])),
124
+ fill_value = self .self_loop_weight ,
125
+ )
126
+ return sparse_matmul (
127
+ SparseTensor .from_edge_index (edge_index_with_loops , edge_attr = edge_weights ),
128
+ x ,
129
+ reduce = self .aggr ,
130
+ )
131
+
132
+
133
+ def _glorot (value : torch .Tensor ):
134
+ # see https://github.com/pyg-team/pytorch_geometric/blob/3e55a4c263f04ed6676618226f9a0aaf406d99b9/torch_geometric/nn/inits.py#L30
135
+ stdv = math .sqrt (6.0 / (value .size (- 2 ) + value .size (- 1 )))
136
+ value .data .uniform_ (- stdv , stdv )
137
+
138
+
139
+ class FrozenGCNConv (BasicMessagePassing ):
140
+ def __init__ (
141
+ self ,
142
+ in_channels : int ,
143
+ out_channels : int ,
144
+ bias : bool = False ,
145
+ edge_weight : float = 1.0 ,
146
+ self_loop_weight : float = 2.0 ,
147
+ aggr : str = "add" ,
148
+ ):
149
+ super ().__init__ (
150
+ edge_weight = edge_weight , self_loop_weight = self_loop_weight , aggr = aggr
151
+ )
152
+ self .lin = nn .Linear (in_channels , out_channels , bias = bias )
153
+ for param in self .lin .parameters ():
154
+ param .requires_grad = False
155
+ # Use glorot initialization
156
+ _glorot (self .lin .weight )
157
+
158
+ def forward (self , x : torch .Tensor , edge_index : torch .Tensor ) -> torch .Tensor :
159
+ x = self .lin (x )
160
+ return super ().forward (x , edge_index )
161
+
162
+
107
163
class GCNFrameEncoder (RelationFrameEncoder ):
108
164
"""Use untrained GCN for aggregating neighboring embeddings with self.
109
165
110
166
Args:
111
167
depth: How many hops of neighbors should be incorporated
168
+ edge_weight: Weighting of non-self-loops
169
+ self_loop_weight: Weighting of self-loops
170
+ layer_dims: Dimensionality of layers if used
171
+ bias: Whether to use bias in layers
172
+ use_weight_layers: Whether to use randomly initialized layers in aggregation
173
+ aggr: Which aggregation to use. Can be :obj:`"sum"`, :obj:`"mean"`, :obj:`"min"` or :obj:`"max"`
112
174
attribute_encoder: HintOrType[TokenizedFrameEncoder]: Base encoder class
113
175
attribute_encoder_kwargs: OptionalKwargs: Keyword arguments for initializing encoder
114
176
"""
115
177
116
178
def __init__ (
117
179
self ,
118
180
depth : int = 2 ,
181
+ edge_weight : float = 1.0 ,
182
+ self_loop_weight : float = 2.0 ,
183
+ layer_dims : int = 300 ,
184
+ bias : bool = False ,
185
+ use_weight_layers : bool = True ,
186
+ aggr : str = "sum" ,
119
187
attribute_encoder : HintOrType [TokenizedFrameEncoder ] = None ,
120
188
attribute_encoder_kwargs : OptionalKwargs = None ,
121
189
):
122
190
if not TORCH_SCATTER :
123
191
logger .error ("Could not find torch_scatter and/or torch_sparse package!" )
124
192
self .depth = depth
193
+ self .edge_weight = edge_weight
194
+ self .self_loop_weight = self_loop_weight
125
195
self .device = resolve_device ()
126
196
self .attribute_encoder = tokenized_frame_encoder_resolver .make (
127
197
attribute_encoder , attribute_encoder_kwargs
128
198
)
129
-
130
- def _forward (self , x : torch .Tensor , edge_index : torch .Tensor ) -> torch .Tensor :
131
- edge_index_with_loops , edge_weights = _gcn_norm (edge_index , num_nodes = len (x ))
132
- return sparse_matmul (
133
- SparseTensor .from_edge_index (edge_index_with_loops , edge_attr = edge_weights ),
134
- x ,
135
- )
199
+ layers : List [BasicMessagePassing ]
200
+ if use_weight_layers :
201
+ layers = [
202
+ FrozenGCNConv (
203
+ in_channels = layer_dims ,
204
+ out_channels = layer_dims ,
205
+ edge_weight = edge_weight ,
206
+ self_loop_weight = self_loop_weight ,
207
+ aggr = aggr ,
208
+ )
209
+ for _ in range (self .depth )
210
+ ]
211
+ else :
212
+ layers = [
213
+ BasicMessagePassing (
214
+ edge_weight = edge_weight ,
215
+ self_loop_weight = self_loop_weight ,
216
+ aggr = aggr ,
217
+ )
218
+ for _ in range (self .depth )
219
+ ]
220
+ self .layers = layers
136
221
137
222
def _encode_rel (
138
223
self ,
@@ -143,6 +228,6 @@ def _encode_rel(
143
228
full_graph = np .concatenate ([rel_triples_left , rel_triples_right ])
144
229
edge_index = torch .from_numpy (full_graph [:, [0 , 2 ]]).t ()
145
230
x = ent_features .vectors
146
- for _ in range ( self .depth ) :
147
- x = self . _forward (x , edge_index )
231
+ for layer in self .layers :
232
+ x = layer . forward (x , edge_index )
148
233
return x
0 commit comments