retriever.py
import torch
from torch import nn
import torch.nn.functional as F
from torchdrug import core
from gnn.model import *


class BasePNARetriever(nn.Module):
    '''
    Retrieve text information for KGL tokens by aggregating the embeddings of
    the text tokens they map to (via kgl2token).
    '''

    def __init__(self, config, text_embeddings, kgl2token, orig_vocab_size):
        super().__init__()
        self.config = config
        self.text_embeddings = text_embeddings
        self.kgl2token = kgl2token
        self.orig_vocab_size = orig_vocab_size
        # Project LLM-sized text embeddings down to the low-rank dimension r.
        self.down_scaling = nn.Linear(
            self.config.llm_hidden_dim, self.config.r, bias=False, dtype=torch.float)
        if self.config.text_encoder == 'pna':
            # PNA aggregation yields 4 aggregators x 3 degree scalers = 12 * r
            # features; map them back to dimension r.
            self.re_scaling = nn.Linear(config.r * 12, self.config.r)
    def aggregate_text(self, token_ids, text_embeddings, method='pna'):
        device = text_embeddings.device
        token_ids = token_ids.to(device)  # Batch x Length
        # Token id 0 is treated as padding.
        token_mask = (token_ids > 0).unsqueeze(-1).to(device)  # B x L x 1
        token_lengths = token_mask.half().sum(axis=1).to(device)  # B x 1
        degree = token_lengths
        token_embs = text_embeddings[token_ids]  # B x L x Hidden
        # Masked mean over the length dimension.
        mean = (token_embs * token_mask).sum(axis=1) / token_lengths
        if method == 'mean':
            result = mean
        else:
            # PNA-style aggregation: mean, max, min and std aggregators ...
            sq_mean = (token_embs ** 2 * token_mask).sum(axis=1) / token_lengths
            max, _ = (token_embs * token_mask).max(axis=1)
            min, _ = (token_embs * token_mask).min(axis=1)
            std = (sq_mean - mean ** 2).clamp(min=1e-6).sqrt()
            features = torch.cat([mean, max, min, std], dim=-1)  # B x 4*Hidden
            # ... crossed with identity, amplification and attenuation scalers
            # derived from the (log) sequence length.
            scale = degree.log()
            scale = scale / scale.mean()
            scales = torch.cat(
                [torch.ones_like(scale), scale, 1 / scale.clamp(min=1e-2)], dim=-1)  # B x 3
            result = (features.unsqueeze(-1) *
                      scales.unsqueeze(-2)).flatten(-2)  # B x 12*Hidden
        return result
    def retrieve_text(self, token_ids):
        # token_ids: num_kgl_tokens x num_tokens
        # Down-project the full text-embedding table, then aggregate per entry.
        R = self.down_scaling(self.text_embeddings)
        result = self.aggregate_text(token_ids, R, self.config.text_encoder)
        if self.config.text_encoder == 'pna':
            result = self.re_scaling(result)
        return self.norm(result)

    def norm(self, x):
        # L2-normalize each retrieved embedding.
        return F.normalize(x, p=2, dim=1)

    def forward(self, kgl_ids=None):
        if kgl_ids is not None:
            # KGL ids are appended after the original vocabulary; shift them
            # back so they index the kgl2token lookup table.
            kgl_ids = kgl_ids - self.orig_vocab_size
            token_ids = self.kgl2token[kgl_ids.cpu()]
        else:
            # Without ids, retrieve text for every entry in kgl2token.
            token_ids = self.kgl2token
        return self.retrieve_text(token_ids)
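
# The helper below is an illustrative, self-contained restatement of the
# PNA-style aggregation performed in BasePNARetriever.aggregate_text (it is
# not part of the original module): four aggregators (mean, max, min, std)
# crossed with three length-based scalers (identity, normalized log-length,
# its inverse) yield a 12 * hidden feature vector per sequence. The tensor
# sizes are placeholders chosen for the sketch.
def _pna_aggregation_sketch():
    torch.manual_seed(0)
    B, L, H = 2, 4, 8
    token_embs = torch.randn(B, L, H)
    token_mask = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 1]]).unsqueeze(-1)  # B x L x 1
    lengths = token_mask.float().sum(dim=1)  # B x 1

    mean = (token_embs * token_mask).sum(dim=1) / lengths
    sq_mean = (token_embs ** 2 * token_mask).sum(dim=1) / lengths
    max_agg, _ = (token_embs * token_mask).max(dim=1)
    min_agg, _ = (token_embs * token_mask).min(dim=1)
    std = (sq_mean - mean ** 2).clamp(min=1e-6).sqrt()
    features = torch.cat([mean, max_agg, min_agg, std], dim=-1)  # B x 4H

    scale = lengths.log()
    scale = scale / scale.mean()
    scales = torch.cat([torch.ones_like(scale), scale, 1 / scale.clamp(min=1e-2)], dim=-1)  # B x 3
    result = (features.unsqueeze(-1) * scales.unsqueeze(-2)).flatten(-2)  # B x 12H
    assert result.shape == (B, 12 * H)
    return result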

class ContextRetriever(BasePNARetriever):
    def __init__(self, config, *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        # Map the low-rank retrieved embedding back to the LLM hidden size.
        self.up_scaling = nn.Linear(
            self.config.r, self.config.llm_hidden_dim, bias=False, dtype=torch.float)

    def forward(self, kgl_ids, graph, all_index, all_kgl_index):
        # graph, all_index and all_kgl_index are accepted for interface
        # compatibility but are not used here.
        text_embs = super().forward(kgl_ids)
        context = self.up_scaling(text_embs)
        return context

class ScoreRetriever(BasePNARetriever):
    def __init__(self, config, *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        # Graph encoder instantiated from the torchdrug config dict.
        self.kg_retriever = core.Configurable.load_config_dict(config.kg_encoder)
        # self.kg_retriever = ConditionedPNA(config.kg_encoder)
        self.h_down_scaling = nn.Linear(
            self.config.llm_hidden_dim, self.config.r, bias=False, dtype=torch.float)
        self.r_down_scaling = nn.Linear(
            self.config.llm_hidden_dim, self.config.r, bias=False, dtype=torch.float)

    def forward(self, h_id, r_id, t_id, hidden_states, rel_hidden_states, graph, all_index, all_kgl_index):
        # Text embeddings for the candidate KGL tokens, consumed by the KG
        # encoder when scoring the query triples.
        score_text_embs = super().forward(all_kgl_index)
        head_embeds = self.h_down_scaling(hidden_states)
        rel_embeds = self.r_down_scaling(rel_hidden_states)
        score = self.kg_retriever(h_id, r_id, t_id, head_embeds, rel_embeds, graph, score_text_embs, all_index)
        return score

class RelScoreRetriever(BasePNARetriever):
    def __init__(self, config, *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        self.r_down_scaling = nn.Linear(
            self.config.llm_hidden_dim, self.config.r, bias=False, dtype=torch.float)

    def forward(self, rel_hidden_states, all_rel_kgl_index):
        score_text_embs = super().forward(all_rel_kgl_index)  # num_rel x r
        rel_embeds = self.r_down_scaling(rel_hidden_states)  # batch_size x r
        # Dot-product scores between each query relation and every relation's
        # retrieved text embedding: batch_size x num_rel.
        score = F.linear(rel_embeds, score_text_embs)
        return score
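
# Hedged usage sketch (not part of the original module): it shows how a
# ContextRetriever and a RelScoreRetriever could be wired up with a minimal
# stand-in config. The SimpleNamespace config, the random embedding table,
# the kgl2token mapping and all sizes below are assumptions for illustration;
# the real values come from the rest of the project. text_encoder='mean' is
# used to keep the sketch small ('pna' is the other supported option).
if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(llm_hidden_dim=32, r=8, text_encoder='mean')
    vocab_size, kg_entries, desc_len = 100, 5, 6
    text_embeddings = torch.randn(vocab_size, cfg.llm_hidden_dim)
    # Each KGL entry maps to a padded sequence of text token ids (0 = padding).
    kgl2token = torch.randint(1, vocab_size, (kg_entries, desc_len))

    context_retriever = ContextRetriever(cfg, text_embeddings, kgl2token, orig_vocab_size=vocab_size)
    kgl_ids = torch.tensor([vocab_size + 0, vocab_size + 3])  # KGL ids offset by the original vocab size
    context = context_retriever(kgl_ids, graph=None, all_index=None, all_kgl_index=None)
    print(context.shape)  # (2, llm_hidden_dim)

    rel_retriever = RelScoreRetriever(cfg, text_embeddings, kgl2token, orig_vocab_size=vocab_size)
    rel_hidden = torch.randn(3, cfg.llm_hidden_dim)
    scores = rel_retriever(rel_hidden, all_rel_kgl_index=None)  # None retrieves every kgl2token entry
    print(scores.shape)  # (3, kg_entries)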