Skip to content

Commit 508f8d8

Browse files
committed
Fixed prefrence sampling bug
1 parent d13d210 commit 508f8d8

File tree

6 files changed

+38
-26
lines changed

6 files changed

+38
-26
lines changed

Investigation/condition_functions.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import numpy as np
1111
from .scoring.Last_score import final_score_cls
1212
from skimage.feature import blob_log
13-
13+
import time
1414

1515
def mock_peak_check(anchor,minc,maxc,configs,**kwags):
1616
a = configs.get('a',None)
@@ -29,6 +29,15 @@ def mock_peak_check(anchor,minc,maxc,configs,**kwags):
2929
c_peak = np.all(anchor<ub) and np.all(anchor>lb)
3030
if verb: print(c_peak)
3131
return c_peak, c_peak, None
32+
33+
34+
def mock_score_func(anchor,minc,maxc,configs,**kwags):
35+
a = np.array(configs.get('target',[-500,-500]))
36+
37+
score = 100/ np.linalg.norm(a-anchor)
38+
print(score)
39+
time.sleep(2)
40+
return score, False, None
3241

3342

3443

Playground/mock_device.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def __init__(self,shapes_list,ndim,origin=None,dir=None):
3838
self.shapes_list = shapes_list
3939

4040
def jump(self, params):
41-
self.params = np.array(params)
41+
self.params = np.array(params).squeeze()
4242
return params
4343

4444
def measure(self):

Sampler_factory.py

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,8 @@ def __init__(self,configs):
8585
self.sampler_hook = None #bodge
8686

8787
self.t.add(samples=None,point_selected=None,boundary_points=[],
88-
vols_poff=[],detected=[],vols_poff_axes=[],poff=[],
89-
cand_v=[],all_v=[],vols_pinchoff=[],d_vec=[],poff_vec=[],meas_each_axis=[],vols_each_axis=[],extra_measure=[],
88+
vols_poff=[],detected=[],vols_poff_axes=[],poff=[]
89+
,all_v=[],vols_pinchoff=[],d_vec=[],poff_vec=[],meas_each_axis=[],vols_each_axis=[],extra_measure=[],
9090
vols_pinchoff_axes=[],vols_detected_axes=[],changed_origin=[],conditional_idx=[],r_vals=[])
9191

9292

@@ -120,7 +120,7 @@ def do_iter(self):
120120
do_gpr_p1, do_gpc_p1 = (i-1>self.t['gpr_start']) and self.t['gpr_on'], (i-1>self.t['gpc_start']) and self.t['gpc_on']
121121
print("GPR:",do_gpr,"GPC:",do_gpc,"prune:",do_pruning,"GPR1:",do_gpr_p1,"GPC1:",do_gpc_p1,"Optim:",do_optim)
122122
#pick a uvec and start sampling
123-
u, r_est = select_point(self.gpr, self.gpc, *self.t.get('origin', 'cand_v', 'all_v', 'directions'), do_gpr_p1, do_gpc_p1)
123+
u, r_est = select_point(self.gpr, self.gpc, *self.t.get('origin', 'boundary_points', 'vols_pinchoff', 'directions'), do_gpr_p1, do_gpc_p1)
124124
self.timer.logtime()
125125
self.sampler_hook = start_sampling(self.gpr, *self.t.get('samples', 'origin', 'real_ub', 'real_lb',
126126
'directions', 'n_part', 'sigma', 'max_steps'),sampler_hook=self.sampler_hook) if do_gpr_p1 else None
@@ -161,9 +161,9 @@ def do_iter(self):
161161

162162
return self.t.getd(*self.t['verbose'])
163163

164-
def select_point(hypersurface, selection_model, origin, cand_v, all_v, directions, use_selection=True, estimate_r=True):
164+
def select_point(hypersurface, selection_model, origin, boundary_points, vols_pinchoff, directions, use_selection=True, estimate_r=True):
165165
"""selects a point to investigate using thompson sampling, uniform sampling or random angles
166-
depending on use_selection flag or is no cand_v are present
166+
depending on use_selection flag or is no samples are present
167167
Args:
168168
hypersurface: model of the hypersurface
169169
selection_model: model of probability of observing desirable features
@@ -175,17 +175,23 @@ def select_point(hypersurface, selection_model, origin, cand_v, all_v, direction
175175
unit vector
176176
"""
177177

178-
if len(cand_v) > 0 and use_selection:
179-
points_candidate = rw.project_crosses_to_boundary(cand_v, hypersurface, origin)
180-
v = choose_next(points_candidate, all_v, selection_model, d_tooclose = 20.)
181-
elif len(cand_v) != 0:
182-
v = rw.pick_from_boundary_points(cand_v)
178+
boundary_points = [] if boundary_points is None else boundary_points
179+
180+
if len(boundary_points) > 0 and use_selection:
181+
points_candidate = rw.project_crosses_to_boundary(boundary_points, hypersurface, origin)
182+
v = choose_next(points_candidate, vols_pinchoff, selection_model, d_tooclose = 20.)
183+
elif len(boundary_points) != 0:
184+
v = rw.pick_from_boundary_points(boundary_points)
183185
else:
184186
print('WARNING: no boundary point is sampled')
185187
return random_angle_directions(len(origin), 1, np.array(directions))[0], None
186188
v_origin = v - origin
187189
u = v_origin / np.sqrt(np.sum(np.square(v_origin)))
188-
return u, hypersurface.predict(u) if estimate_r else None
190+
r_est,r_std = hypersurface.predict(u[np.newaxis,:])
191+
192+
r_est = np.maximum(r_est - 1.0*np.sqrt(r_std), 0.0)
193+
194+
return u.squeeze(), r_est.squeeze() if estimate_r else None
189195

190196

191197

@@ -236,6 +242,7 @@ def stop_sampling(sampler,stopper,listener):
236242
counter, samples, boundary_points = listener.recv()
237243
sampler.join()
238244
print("STOP")
245+
print(len(samples),len(boundary_points))
239246
return {'samples':samples,'boundary_points':boundary_points}
240247

241248
def project_samples_inside(hypersurface, samples, origin, ub, lb):
@@ -283,14 +290,9 @@ def unpack(key,list_of_dict):
283290

284291
def predict_probs(points, gpc_list):
285292

286-
probs = []
287-
for gpc in gpc_list:
288-
probs += [gpc.predict_prob(points)[:,0]]
289-
293+
total_probs = gpc_list.predict_comb_prob(points)
290294

291-
total_prob = np.prod(probs, axis=0)
292-
log_total_prob = np.sum(np.log(probs), axis=0)
293-
return total_prob, log_total_prob, probs
295+
return total_probs.squeeze(), None, None
294296

295297
def choose_next(points_candidate, points_observed, gpc_dict, d_tooclose = 100.):
296298
points_observed = np.array(points_observed)

Sampling/random_walk.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ def project_points_to_inside(v, gp, origin, factor=0.5):
252252

253253
def project_points_to_boundary(v, gp, origin):
254254
u, r = ur_from_v(v, origin)
255-
r_surf, _ = gp.predict_f(u)
255+
r_surf, _ = gp.predict(u)
256256

257257
v_boundary = u * r_surf + origin
258258
return v_boundary

Sampling/test_common.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ def search_line(self, voltages, detector, unit_vector, step_size, ignore_idxs, l
141141
# 2. break if poff detected, but check some signal afterwards within 'len_after'
142142
first_iter = True
143143
while check_inside_boundary(voltages, self.lb, self.ub) and L2_norm(voltages-voltages_from) < max_dist:
144+
144145
if first_iter:
145146
t = time.time()
146147
# big jump expected, swith on the big jump mode

tune.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -98,10 +98,10 @@ def jump(params,inv=False):
9898
inv_timer = Timer()
9999
investigation_stage = Investigation_stage(jump,measure,check,configs['investigation'],inv_timer)
100100

101-
results = tune(jump,measure,investigation_stage,configs)
101+
results,sampler = tune(jump,measure,investigation_stage,configs)
102102

103103
plot_conditional_idx_improvment(results['conditional_idx'])
104-
return results
104+
return results,sampler
105105

106106

107107

@@ -113,8 +113,8 @@ def tune_from_file(jump,measure,check,config_file):
113113

114114
inv_timer = Timer()
115115
investigation_stage = Investigation_stage(jump,measure,check,configs['investigation'],inv_timer)
116-
results = tune(jump,measure,investigation_stage,configs)
117-
return results
116+
results,sampler = tune(jump,measure,investigation_stage,configs)
117+
return results,sampler
118118

119119

120120

@@ -138,7 +138,7 @@ def tune(jump,measure,investigation_stage,configs):
138138
for key,item in results.items():
139139
print("%s:"%(key),item[-1])
140140

141-
return results
141+
return results, ps
142142

143143

144144
def tune_origin_variable(jump,measure,par_invstage,child_invstage,par_configs,child_configs):

0 commit comments

Comments
 (0)