-
Notifications
You must be signed in to change notification settings - Fork 4
/
adamic_utils.py
69 lines (63 loc) · 2.64 KB
/
adamic_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import torch
import numpy as np
from torch.utils.data import DataLoader
from tqdm import tqdm
import scipy.sparse as ssp
def get_A(adj, num_nodes):
    """Build a SciPy CSR adjacency matrix from a COO-style sparse object.

    Args:
        adj: Object whose ``coo()`` method returns a ``(row, col, value)``
            triple of tensors. ``value`` may be ``None``; presumably this is
            a ``torch_sparse.SparseTensor`` -- confirm against the caller.
        num_nodes: Side length of the square output matrix.

    Returns:
        A ``scipy.sparse.csr_matrix`` of shape ``(num_nodes, num_nodes)``.
    """
    row, col, val = adj.coo()
    rows = row.detach().cpu().numpy()
    cols = col.detach().cpu().numpy()
    if val is None:
        # No stored values: treat every stored entry as an edge of weight 1
        # instead of failing on ``None.cpu()``.
        vals = np.ones(rows.shape[0], dtype=np.float32)
    else:
        vals = val.detach().cpu().numpy()
    return ssp.csr_matrix((vals, (rows, cols)), shape=(num_nodes, num_nodes))
def AA(A, edge_index, batch_size=2000):
    """Score candidate links with the Adamic-Adar heuristic.

    For a pair ``(u, v)`` the score is the sum over ``w`` of
    ``A[u, w] * A[v, w] / log(d(w))``, where ``d(w)`` is the ``w``-th column
    sum of ``A``. Columns whose weight is infinite (``d(w) == 1``) contribute
    nothing.

    Args:
        A: SciPy sparse adjacency matrix (it must support ``sum``,
            ``multiply`` and row indexing).
        edge_index: Tensor with two rows; column ``i`` holds the
            ``(src, dst)`` node pair to score.
        batch_size: Number of pairs scored per sparse-matrix operation.

    Returns:
        A ``(scores, edge_index)`` tuple: ``scores`` is a 1-D float tensor
        with one entry per column of ``edge_index`` (empty when there are no
        pairs), and ``edge_index`` is returned unchanged.
    """
    # log(0) == -inf and 1 / log(1) == inf are expected for isolated and
    # degree-one columns, so keep them from emitting RuntimeWarnings.
    with np.errstate(divide='ignore'):
        multiplier = 1 / np.log(A.sum(0))
    multiplier[np.isinf(multiplier)] = 0
    # Broadcasting the 1 x N weights scales column w of A by 1 / log(d(w)).
    A_ = A.multiply(multiplier).tocsr()

    num_edges = edge_index.size(1)
    batch_scores = []
    # A progress bar is shown, one tick per batch.
    for start in tqdm(range(0, num_edges, batch_size)):
        stop = start + batch_size
        # SciPy indexing wants NumPy index arrays, not torch tensors.
        src = edge_index[0, start:stop].cpu().numpy()
        dst = edge_index[1, start:stop].cpu().numpy()
        cur_scores = np.asarray(A[src].multiply(A_[dst]).sum(1)).ravel()
        batch_scores.append(cur_scores)

    # np.concatenate raises on an empty list, so handle "no pairs" explicitly.
    if batch_scores:
        scores = np.concatenate(batch_scores, 0)
    else:
        scores = np.zeros(0)
    return torch.as_tensor(scores, dtype=torch.float), edge_index
def _subsample_perm(num_items, percent):
    """Return index tensor for a seeded random ``percent``% subsample."""
    # Re-seeding before every draw makes the permutation depend only on
    # num_items, so equally sized sets are subsampled with identical indices.
    # NOTE: this resets NumPy's global RNG state as a side effect.
    np.random.seed(123)
    perm = np.random.permutation(num_items)
    # Multiply before dividing: ``percent / 100 * n`` can land just below an
    # integer (29 / 100 * 100 == 28.999...) and drop a sample when truncated.
    keep = int(percent * num_items / 100)
    return torch.as_tensor(perm[:keep], dtype=torch.long)


def get_pos_neg_edges(split, split_edge, edge_index, num_nodes, percent=100):
    """Collect (optionally subsampled) positive and negative edges for a split.

    Two layouts of ``split_edge`` are supported, selected by the keys found
    in ``split_edge['train']``:

    * ``'edge'``: positives come from ``split_edge[split]['edge']``.
      Negatives are sampled for the ``'train'`` split and otherwise read
      from ``split_edge[split]['edge_neg']``.
    * ``'source_node'``: positives pair ``'source_node'`` with
      ``'target_node'``. Negatives are one random target per source for the
      ``'train'`` split and otherwise read from ``'target_node_neg'``.

    Args:
        split: Key into ``split_edge``; ``'train'`` triggers negative sampling.
        split_edge: Nested dict of edge tensors (presumably an OGB
            link-prediction split -- confirm against the caller).
        edge_index: Graph edges; only used to sample training negatives in
            the ``'edge'`` layout.
        num_nodes: Number of nodes; upper bound for sampled negative targets.
        percent: Percentage of edges (or source nodes) to keep, chosen with a
            fixed seed.

    Returns:
        ``(pos_edge, neg_edge)``: two tensors with two rows each, one column
        per edge.

    Raises:
        ValueError: If ``split_edge['train']`` has neither an ``'edge'`` nor a
            ``'source_node'`` key.
    """
    if 'edge' in split_edge['train']:
        pos_edge = split_edge[split]['edge'].t()
        if split == 'train':
            # NOTE(review): add_self_loops and negative_sampling are not
            # imported in the visible part of this file; presumably they come
            # from torch_geometric.utils -- confirm they are in scope.
            new_edge_index, _ = add_self_loops(edge_index)
            neg_edge = negative_sampling(
                new_edge_index, num_nodes=num_nodes,
                num_neg_samples=pos_edge.size(1))
        else:
            neg_edge = split_edge[split]['edge_neg'].t()
        # Subsample positives and negatives independently.
        pos_edge = pos_edge[:, _subsample_perm(pos_edge.size(1), percent)]
        neg_edge = neg_edge[:, _subsample_perm(neg_edge.size(1), percent)]
    elif 'source_node' in split_edge['train']:
        source = split_edge[split]['source_node']
        target = split_edge[split]['target_node']
        if split == 'train':
            target_neg = torch.randint(0, num_nodes, [target.size(0), 1],
                                       dtype=torch.long)
        else:
            target_neg = split_edge[split]['target_node_neg']
        # One shared permutation keeps each source aligned with its positive
        # target and its row of negative targets.
        perm = _subsample_perm(source.size(0), percent)
        source, target, target_neg = source[perm], target[perm], target_neg[perm, :]
        pos_edge = torch.stack([source, target])
        neg_per_target = target_neg.size(1)
        # Repeat each source once per negative target so the rows line up.
        neg_edge = torch.stack([source.repeat_interleave(neg_per_target),
                                target_neg.reshape(-1)])
    else:
        raise ValueError(
            "split_edge['train'] must contain an 'edge' or 'source_node' key")
    return pos_edge, neg_edge