Skip to content

Commit d3c90c0

Browse files
authored
jitter crash fix
See huawei-noah#61
1 parent 33fc831 commit d3c90c0

File tree

1 file changed

+31
-8
lines changed

1 file changed

+31
-8
lines changed

HEBO/hebo/models/gp/gp.py

+31-8
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from gpytorch.means import ConstantMean
2222
from gpytorch.distributions import MultivariateNormal
2323
from gpytorch.constraints import GreaterThan
24+
from gpytorch.settings import cholesky_jitter
2425

2526
from ..util import filter_nan
2627
from ..base_model import BaseModel
@@ -99,13 +100,24 @@ def fit(self, Xc : Tensor, Xe : Tensor, y : Tensor):
99100
opt = torch.optim.Adam(self.gp.parameters(), lr = self.lr)
100101
mll = gpytorch.mlls.ExactMarginalLogLikelihood(self.lik, self.gp)
101102
for epoch in range(self.num_epochs):
102-
def closure():
103-
dist = self.gp(self.Xc, self.Xe)
104-
loss = -1 * mll(dist, self.y.squeeze())
105-
opt.zero_grad()
106-
loss.backward()
107-
return loss
108-
opt.step(closure)
103+
jitter = 10 ** -8
104+
cont = True
105+
while cont:
106+
cont = False
107+
cholesky_jitter._set_value(
108+
double_value=jitter, float_value=100*jitter, half_value=10000*jitter)
109+
def closure():
110+
dist = self.gp(self.Xc, self.Xe)
111+
loss = -1 * mll(dist, self.y.squeeze())
112+
opt.zero_grad()
113+
loss.backward()
114+
return loss
115+
try:
116+
opt.step(closure)
117+
except:
118+
jitter *= 10
119+
cont = True
120+
print(f'jitter = {jitter}')
109121
if self.verbose and ((epoch + 1) % self.print_every == 0 or epoch == 0):
110122
print('After %d epochs, loss = %g' % (epoch + 1, closure().item()), flush = True)
111123
self.gp.eval()
@@ -114,7 +126,18 @@ def closure():
114126
def predict(self, Xc, Xe):
115127
Xc, Xe = self.xtrans(Xc, Xe)
116128
with gpytorch.settings.fast_pred_var(), gpytorch.settings.debug(False):
117-
pred = self.gp(Xc, Xe)
129+
jitter = 10 ** -8
130+
cont = True
131+
while cont:
132+
cont = False
133+
cholesky_jitter._set_value(
134+
double_value=jitter, float_value=100*jitter, half_value=10000*jitter)
135+
try:
136+
pred = self.gp(Xc, Xe)
137+
except:
138+
jitter *= 10
139+
cont = True
140+
print(f'jitter = {jitter}')
118141
if self.pred_likeli:
119142
pred = self.lik(pred)
120143
mu_ = pred.mean.reshape(-1, self.num_out)

0 commit comments

Comments
 (0)