Skip to content

Commit 844551e

Browse files
authored
Add more flexibility to GCNFrameEncoder (#5)
* Enhance GCNFrameEncoder * Added new GCN params to experiment and added random seed
1 parent 8ef6265 commit 844551e

File tree

3 files changed

+146
-27
lines changed

3 files changed

+146
-27
lines changed

experiment.py

+42-6
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,15 @@
44
import logging
55
import os
66
import pickle
7+
import random
78
import shutil
89
import time
910
from dataclasses import dataclass
1011
from typing import Any, Dict, List, Optional, Tuple, Type, get_args
1112

1213
import click
14+
import numpy as np
15+
import torch
1316
from nephelai import upload
1417
from sylloge import OAEI, MovieGraphBenchmark, OpenEA
1518
from sylloge.base import EADataset
@@ -56,6 +59,15 @@
5659
logger = logging.getLogger("KlinkerExperiment")
5760

5861

62+
def set_random_seed(seed: Optional[int] = None):
    """Seed the stdlib, NumPy and PyTorch random number generators.

    Args:
        seed: Value used for all three generators. When ``None``, a seed is
            drawn from NumPy's global generator in ``[0, 2**16)`` and logged
            so that the run can be reproduced afterwards.
    """
    if seed is None:
        seed = np.random.randint(0, 2**16)
        logger.info(f"No random seed provided. Using {seed}")
    # The three generators are independent, so seeding order does not matter.
    for apply_seed in (random.seed, torch.manual_seed, np.random.seed):
        apply_seed(seed)
69+
70+
5971
@dataclass
6072
class ExperimentInfo:
6173
params: Dict
@@ -194,21 +206,27 @@ def prepare(
194206
@click.option("--clean/--no-clean", default=True)
195207
@click.option("--wandb/--no-wandb", is_flag=True, default=False)
196208
@click.option("--nextcloud/--no-nextcloud", is_flag=True, default=False)
197-
def cli(clean: bool, wandb: bool, nextcloud: bool):
209+
@click.option("--random-seed", type=int, default=None)
210+
def cli(clean: bool, wandb: bool, nextcloud: bool, random_seed: Optional[int]):
198211
pass
199212

200213

201214
@cli.result_callback()
202215
def process_pipeline(
203-
blocker_and_dataset: List, clean: bool, wandb: bool, nextcloud: bool
216+
blocker_and_dataset: List,
217+
clean: bool,
218+
wandb: bool,
219+
nextcloud: bool,
220+
random_seed: Optional[int],
204221
):
222+
set_random_seed(random_seed)
205223
assert (
206224
len(blocker_and_dataset) == 2
207225
), "Only 1 dataset and 1 blocker command can be used!"
208226
if not isinstance(blocker_and_dataset[0][0], EADataset):
209227
raise ValueError("First command must be dataset command!")
210228
if not isinstance(blocker_and_dataset[1][0], Blocker):
211-
raise ValueError("First command must be blocker command!")
229+
raise ValueError("Second command must be blocker command!")
212230
dataset_with_params, blocker_with_params = blocker_and_dataset
213231
dataset, ds_params = dataset_with_params
214232
blocker, bl_params, blocker_creation_time = blocker_with_params
@@ -609,20 +627,32 @@ def light_ea_blocker(
609627

610628
@cli.command()
611629
@tokenized_frame_encoder_resolver.get_option(
612-
"--inner-encoder", default="TransformerTokenizedFrameEncoder", as_string=True
630+
"--inner-encoder", default="SIFEmbeddingTokenizedFrameEncoder", as_string=True
613631
)
632+
@click.option("--batch-size", type=int)
614633
@click.option("--embeddings", type=str, default="glove")
615634
@click.option("--depth", type=int, default=2)
616-
@click.option("--batch-size", type=int)
635+
@click.option("--edge-weight", type=float, default=1.0)
636+
@click.option("--self-loop-weight", type=float, default=2.0)
637+
@click.option("--layer-dims", type=int, default=300)
638+
@click.option("--bias", type=bool, default=True)
639+
@click.option("--use-weight-layers", type=bool, default=True)
640+
@click.option("--aggr", type=str, default="sum")
617641
@block_builder_resolver.get_option("--block-builder", default="kiez", as_string=True)
618642
@click.option("--block-builder-kwargs", type=str)
619643
@click.option("--n-neighbors", type=int, default=100)
620644
@click.option("--force", type=bool, default=True)
621645
def gcn_blocker(
622646
inner_encoder: Type[TokenizedFrameEncoder],
647+
batch_size: Optional[int],
623648
embeddings: str,
624649
depth: int,
625-
batch_size: Optional[int],
650+
edge_weight: float,
651+
self_loop_weight: float,
652+
layer_dims: int,
653+
bias: bool,
654+
use_weight_layers: bool,
655+
aggr: str,
626656
block_builder: Type[EmbeddingBlockBuilder],
627657
block_builder_kwargs: str,
628658
n_neighbors: int,
@@ -646,6 +676,12 @@ def gcn_blocker(
646676
blocker = EmbeddingBlocker(
647677
frame_encoder=GCNFrameEncoder(
648678
depth=depth,
679+
edge_weight=edge_weight,
680+
self_loop_weight=self_loop_weight,
681+
layer_dims=layer_dims,
682+
bias=bias,
683+
use_weight_layers=use_weight_layers,
684+
aggr=aggr,
649685
attribute_encoder=inner_encoder,
650686
attribute_encoder_kwargs=attribute_encoder_kwargs,
651687
),

src/klinker/encoders/gcn.py

+97-12
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import logging
2-
from typing import Optional, Tuple, Union
2+
import math
3+
from typing import List, Optional, Tuple, Union
34

45
import numpy as np
56
import torch
7+
import torch.nn as nn
68
from class_resolver import HintOrType, OptionalKwargs
79

810
try:
@@ -72,12 +74,11 @@ def _gcn_norm(
7274
edge_index,
7375
num_nodes: int,
7476
edge_weight=None,
75-
improved=True,
77+
fill_value=2.0,
7678
add_self_loops=True,
7779
flow="source_to_target",
7880
dtype=None,
7981
):
80-
fill_value = 2.0 if improved else 1.0
8182
assert flow in ["source_to_target", "target_to_source"]
8283

8384
if edge_weight is None:
@@ -104,35 +105,119 @@ def _gcn_norm(
104105
return edge_index, edge_weight
105106

106107

108+
class BasicMessagePassing:
    """Parameter-free GCN-style propagation of node features along edges.

    Every edge of the input graph gets the same weight, self-loops are
    weighted separately, and the result of ``_gcn_norm`` is used as a sparse
    adjacency matrix that is multiplied with the node features.

    Args:
        edge_weight: Weight assigned to every edge in ``edge_index``.
        self_loop_weight: Passed to ``_gcn_norm`` as ``fill_value``, i.e. the
            weight used for self-loops.
        aggr: Reduction passed to ``sparse_matmul`` as ``reduce``.
    """

    def __init__(
        self,
        edge_weight: float = 1.0,
        self_loop_weight: float = 2.0,
        aggr: str = "add",
    ):
        self.edge_weight = edge_weight
        self.self_loop_weight = self_loop_weight
        self.aggr = aggr

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        """Aggregate the features of neighbouring nodes.

        Args:
            x: Node features, one row per node (``len(x)`` is used as the
                number of nodes).
            edge_index: Tensor whose second dimension enumerates the edges.

        Returns:
            The aggregated node features produced by ``sparse_matmul``.
        """
        # ``torch.full`` avoids building a Python list with one entry per edge
        # and keeps the weights on the same device as ``edge_index``.
        # ``float(...)`` guarantees a floating point tensor even if an integer
        # weight was configured (NOTE(review): ``_gcn_norm`` presumably needs
        # floating point weights for its normalisation -- confirm).
        uniform_weights = torch.full(
            (edge_index.size(1),),
            float(self.edge_weight),
            device=edge_index.device,
        )
        edge_index_with_loops, edge_weights = _gcn_norm(
            edge_index,
            num_nodes=len(x),
            edge_weight=uniform_weights,
            fill_value=self.self_loop_weight,
        )
        adjacency = SparseTensor.from_edge_index(
            edge_index_with_loops, edge_attr=edge_weights
        )
        return sparse_matmul(adjacency, x, reduce=self.aggr)
131+
132+
133+
def _glorot(value: torch.Tensor):
    """Fill ``value`` in place with Glorot (Xavier) uniform samples.

    Samples are drawn from ``U(-a, a)`` with ``a = sqrt(6 / (m + n))``, where
    ``m`` and ``n`` are the sizes of the last two dimensions of ``value``.
    """
    # see https://github.com/pyg-team/pytorch_geometric/blob/3e55a4c263f04ed6676618226f9a0aaf406d99b9/torch_geometric/nn/inits.py#L30
    fan_total = value.size(-2) + value.size(-1)
    bound = math.sqrt(6.0 / fan_total)
    value.data.uniform_(-bound, bound)
137+
138+
139+
class FrozenGCNConv(BasicMessagePassing):
    """Message passing preceded by a randomly initialised, frozen linear map.

    Args:
        in_channels: Input size of the linear layer.
        out_channels: Output size of the linear layer.
        bias: Whether the linear layer has a bias term.
        edge_weight: Forwarded to :class:`BasicMessagePassing`.
        self_loop_weight: Forwarded to :class:`BasicMessagePassing`.
        aggr: Forwarded to :class:`BasicMessagePassing`.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        bias: bool = False,
        edge_weight: float = 1.0,
        self_loop_weight: float = 2.0,
        aggr: str = "add",
    ):
        super().__init__(
            edge_weight=edge_weight, self_loop_weight=self_loop_weight, aggr=aggr
        )
        projection = nn.Linear(in_channels, out_channels, bias=bias)
        # The layer is never trained, so exclude its parameters from autograd.
        projection.requires_grad_(False)
        # Use glorot initialization
        _glorot(projection.weight)
        self.lin = projection

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        """Project ``x`` with the frozen linear layer, then aggregate neighbours."""
        projected = self.lin(x)
        return super().forward(projected, edge_index)
161+
162+
107163
class GCNFrameEncoder(RelationFrameEncoder):
108164
"""Use untrained GCN for aggregating neighboring embeddings with self.
109165
110166
Args:
111167
depth: How many hops of neighbors should be incorporated
168+
edge_weight: Weighting of non-self-loops
169+
self_loop_weight: Weighting of self-loops
170+
layer_dims: Dimensionality of layers if used
171+
bias: Whether to use bias in layers
172+
use_weight_layers: Whether to use randomly initialized layers in aggregation
173+
aggr: Which aggregation to use. Can be :obj:`"sum"`, :obj:`"mean"`, :obj:`"min"` or :obj:`"max"`
112174
attribute_encoder: HintOrType[TokenizedFrameEncoder]: Base encoder class
113175
attribute_encoder_kwargs: OptionalKwargs: Keyword arguments for initializing encoder
114176
"""
115177

116178
def __init__(
    self,
    depth: int = 2,
    edge_weight: float = 1.0,
    self_loop_weight: float = 2.0,
    layer_dims: int = 300,
    bias: bool = False,
    use_weight_layers: bool = True,
    aggr: str = "sum",
    attribute_encoder: HintOrType[TokenizedFrameEncoder] = None,
    attribute_encoder_kwargs: OptionalKwargs = None,
):
    """Build the attribute encoder and one message passing layer per hop.

    Args:
        depth: Number of message passing layers that are created.
        edge_weight: Weight of regular edges, forwarded to every layer.
        self_loop_weight: Weight of self-loops, forwarded to every layer.
        layer_dims: Input and output size of each ``FrozenGCNConv``; only
            used when ``use_weight_layers`` is true.
        bias: Whether each ``FrozenGCNConv`` uses a bias term; only used
            when ``use_weight_layers`` is true.
        use_weight_layers: Create ``FrozenGCNConv`` layers if true, plain
            ``BasicMessagePassing`` layers otherwise.
        aggr: Aggregation forwarded to every layer.
        attribute_encoder: Base encoder class, resolved through
            ``tokenized_frame_encoder_resolver``.
        attribute_encoder_kwargs: Keyword arguments for that encoder.
    """
    if not TORCH_SCATTER:
        logger.error("Could not find torch_scatter and/or torch_sparse package!")
    self.depth = depth
    self.edge_weight = edge_weight
    self.self_loop_weight = self_loop_weight
    self.device = resolve_device()
    self.attribute_encoder = tokenized_frame_encoder_resolver.make(
        attribute_encoder, attribute_encoder_kwargs
    )
    layers: List[BasicMessagePassing]
    if use_weight_layers:
        layers = [
            FrozenGCNConv(
                in_channels=layer_dims,
                out_channels=layer_dims,
                # Previously ``bias`` was accepted but never forwarded, so the
                # option silently had no effect.
                bias=bias,
                edge_weight=edge_weight,
                self_loop_weight=self_loop_weight,
                aggr=aggr,
            )
            for _ in range(self.depth)
        ]
    else:
        layers = [
            BasicMessagePassing(
                edge_weight=edge_weight,
                self_loop_weight=self_loop_weight,
                aggr=aggr,
            )
            for _ in range(self.depth)
        ]
    self.layers = layers
136221

137222
def _encode_rel(
138223
self,
@@ -143,6 +228,6 @@ def _encode_rel(
143228
full_graph = np.concatenate([rel_triples_left, rel_triples_right])
144229
edge_index = torch.from_numpy(full_graph[:, [0, 2]]).t()
145230
x = ent_features.vectors
146-
for _ in range(self.depth):
147-
x = self._forward(x, edge_index)
231+
for layer in self.layers:
232+
x = layer.forward(x, edge_index)
148233
return x

tests/test_blockers.py

+7-9
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,7 @@
2121
SimpleRelationalTokenBlocker,
2222
concat_neighbor_attributes,
2323
)
24-
from klinker.data import (
25-
KlinkerBlockManager,
26-
KlinkerDaskFrame,
27-
KlinkerFrame,
28-
KlinkerPandasFrame,
29-
KlinkerTriplePandasFrame,
30-
from_klinker_frame,
31-
)
24+
from klinker.data import KlinkerBlockManager, KlinkerFrame, from_klinker_frame
3225
from klinker.encoders.base import _get_ids
3326

3427

@@ -287,7 +280,12 @@ def test_assign_embedding_blocker(
287280

288281

289282
@pytest.mark.parametrize(
290-
"cls, params", [("LightEAFrameEncoder", dict(mini_dim=3)), ("GCNFrameEncoder", {})]
283+
"cls, params",
284+
[
285+
("LightEAFrameEncoder", dict(mini_dim=3)),
286+
("GCNFrameEncoder", dict(layer_dims=3, use_weight_layers=True)),
287+
("GCNFrameEncoder", dict(layer_dims=3, use_weight_layers=False)),
288+
],
291289
)
292290
def test_assign_relation_frame_encoder(
293291
cls,

0 commit comments

Comments
 (0)