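"""Relational Graph Convolutional Network (R-GCN) with basis decomposition.

Built on DGL's message-passing API. Each layer appends its output to a running
jumping-knowledge concatenation ('repr'), which RGCN.jk_linear projects back to
the embedding dimension.
"""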
import torch
from torch import nn
import torch.nn.functional as F


class Identity(nn.Module):
    """A no-op module, useful as a placeholder (e.g. for a disabled activation)."""

    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        """Return the input unchanged."""
        return x
class Aggregator(nn.Module):
    """Reduce function: degree-normalized sum of neighbor messages plus the self-loop term."""

    def __init__(self):
        super(Aggregator, self).__init__()

    def forward(self, node):
        # Every incoming edge carries the same self-loop term; take the first copy.
        curr_emb = node.mailbox['curr_emb'][:, 0, :]  # (B, F)
        # Weighted sum of neighbor messages: alpha is (B, deg, 1), msg is (B, deg, F).
        nei_msg = torch.bmm(node.mailbox['alpha'].transpose(1, 2), node.mailbox['msg']).squeeze(1)  # (B, F)
        new_emb = self.update_embedding(curr_emb, nei_msg)
        return {'h': new_emb}

    def update_embedding(self, curr_emb, nei_msg):
        new_emb = nei_msg + curr_emb
        return new_emb
class RGCNLayer(nn.Module):
    def __init__(self, in_dim, out_dim, num_rels, num_bases=None, has_bias=False, activation=None,
                 is_input_layer=False):
        super(RGCNLayer, self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.num_rels = num_rels
        self.num_bases = num_bases
        if self.num_bases is None or self.num_bases > self.num_rels or self.num_bases <= 0:
            self.num_bases = self.num_rels
        # Set per forward pass, consumed by msg_func.
        self.rel_weight = None
        self.input_ = None
        self.has_bias = has_bias
        self.activation = activation
        self.is_input_layer = is_input_layer
        # Basis decomposition: num_bases shared basis matrices, combined per
        # relation; 2 * num_rels coefficient rows cover each relation and its inverse.
        self.weight = nn.Parameter(torch.Tensor(self.num_bases, self.in_dim, self.out_dim))
        self.w_comp = nn.Parameter(torch.Tensor(self.num_rels * 2, self.num_bases))
        self.self_loop_weight = nn.Parameter(torch.Tensor(self.in_dim, self.out_dim))
        nn.init.xavier_uniform_(self.weight, gain=nn.init.calculate_gain('relu'))
        nn.init.xavier_uniform_(self.w_comp, gain=nn.init.calculate_gain('relu'))
        nn.init.xavier_uniform_(self.self_loop_weight, gain=nn.init.calculate_gain('relu'))
        self.aggregator = Aggregator()
        # bias
        if self.has_bias:
            self.bias = nn.Parameter(torch.Tensor(self.out_dim))
            nn.init.zeros_(self.bias)
    def msg_func(self, edges):
        # Select each edge's relation-specific weight: (E, in_dim, out_dim).
        w = self.rel_weight.index_select(0, edges.data['type'])
        msg = torch.bmm(edges.src[self.input_].unsqueeze(1), w).squeeze(1)  # (E, F)
        curr_emb = torch.mm(edges.dst[self.input_], self.self_loop_weight)  # (E, F)
        # Normalization constant: inverse in-degree of the destination node.
        a = 1 / edges.dst['in_d'].to(torch.float32).to(device=w.device).reshape(-1, 1)
        return {'curr_emb': curr_emb, 'msg': msg, 'alpha': a}
    def apply_node_func(self, nodes):
        node_repr = nodes.data['h']
        if self.has_bias:
            node_repr = node_repr + self.bias
        if self.activation:
            node_repr = self.activation(node_repr)
        return {'h': node_repr}
    def forward(self, g):
        # Generate every relation's weight from the shared bases:
        # W_r = sum_b w_comp[r, b] * weight[b].
        weight = self.weight.view(self.num_bases, self.in_dim * self.out_dim)
        self.rel_weight = torch.matmul(self.w_comp, weight).view(
            self.num_rels * 2, self.in_dim, self.out_dim)
        # Normalization constant used by the aggregator.
        g.dstdata['in_d'] = g.in_degrees()
        self.input_ = 'feat' if self.is_input_layer else 'h'
        g.update_all(self.msg_func, self.aggregator, self.apply_node_func)
        # Jumping-knowledge style: append this layer's output to the running
        # concatenation of all layer outputs.
        if self.is_input_layer:
            g.ndata['repr'] = torch.cat([g.ndata['feat'], g.ndata['h']], dim=1)
        else:
            g.ndata['repr'] = torch.cat([g.ndata['repr'], g.ndata['h']], dim=1)
class RGCN(nn.Module):
    def __init__(self, args):
        super(RGCN, self).__init__()
        self.emb_dim = args.ent_dim
        self.num_rel = args.num_rel
        self.num_bases = args.num_bases
        self.num_layers = args.num_layers
        self.device = args.gpu
        # create rgcn layers
        self.layers = nn.ModuleList()
        self.build_model()
        # 'repr' holds the input features plus every layer's output,
        # i.e. (num_layers + 1) blocks of emb_dim each.
        self.jk_linear = nn.Linear(self.emb_dim * (self.num_layers + 1), self.emb_dim)
    def build_model(self):
        # i2h
        i2h = self.build_input_layer()
        self.layers.append(i2h)
        # h2h
        for _ in range(self.num_layers - 1):
            h2h = self.build_hidden_layer()
            self.layers.append(h2h)
    def build_input_layer(self):
        return RGCNLayer(self.emb_dim,
                         self.emb_dim,
                         self.num_rel,
                         self.num_bases,
                         has_bias=True,
                         activation=F.relu,
                         is_input_layer=True)

    def build_hidden_layer(self):
        return RGCNLayer(self.emb_dim,
                         self.emb_dim,
                         self.num_rel,
                         self.num_bases,
                         has_bias=True,
                         activation=F.relu)
    def forward(self, g):
        for layer in self.layers:
            layer(g)
        # Project the jumping-knowledge concatenation back to emb_dim.
        g.ndata['h'] = self.jk_linear(g.ndata['repr'])
        return g.ndata['h']
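

if __name__ == "__main__":
    # Minimal smoke test: a sketch only, assuming DGL is installed and that
    # edge relation ids live in g.edata['type'] with values in [0, 2 * num_rel),
    # as msg_func expects. The args fields mirror what RGCN.__init__ reads;
    # the concrete values are illustrative, not taken from the original project.
    import argparse

    import dgl

    args = argparse.Namespace(ent_dim=8, num_rel=3, num_bases=2, num_layers=2, gpu=-1)
    g = dgl.graph(([0, 1, 2, 3, 0], [1, 2, 3, 0, 2]))  # toy 4-node graph
    g.edata['type'] = torch.tensor([0, 1, 2, 3, 4])
    g.ndata['feat'] = torch.randn(g.num_nodes(), args.ent_dim)

    model = RGCN(args)
    out = model(g)
    print(out.shape)  # expected: torch.Size([4, 8])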