-
Notifications
You must be signed in to change notification settings - Fork 8
/
scripts.py
32 lines (29 loc) · 878 Bytes
/
scripts.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import time
import torch
import negative_sampling
import data_loader
import models
'''
data_path = '/home/kotnis/data/neg_sampling/freebase/'
params_path = data_path + 'rescal_params.pt'
results_dir = data_path + 'rescal_1/'
def sample(ns,ex):
return ns.sample(ex,True)
data = data_loader.read_dataset(data_path,results_dir,dev_mode=True,max_examples=float('inf'))
model = models.Rescal(data['num_ents'], data['num_rels'], 100)
state_dict = torch.load(params_path)
model.load_state_dict(state_dict)
ns = negative_sampling.NN_Sampler(data['train'],100,model,filtered=False)
batch = data['train'][:4000]
print("Start Profiling")
start = time.time()
samples = ns.batch_sample(batch,True,100)
end = time.time()
print("Time Taken {}".format(end-start))
'''
import numpy as np
# Quick Monte Carlo sanity check: draw NUM_TRIALS samples with
# np.random.uniform() and report the fraction that fall below HIT_PROB.
NUM_TRIALS = 500
HIT_PROB = 0.27

n_h = 0  # number of draws that landed below HIT_PROB
for _ in range(NUM_TRIALS):
    if np.random.uniform() < HIT_PROB:
        n_h += 1

# Divide by the number of trials, not by the last loop index (which is
# NUM_TRIALS - 1 once the loop has finished and would skew the estimate).
empirical_rate = n_h / float(NUM_TRIALS)
print(empirical_rate)