21
21
from gpytorch .means import ConstantMean
22
22
from gpytorch .distributions import MultivariateNormal
23
23
from gpytorch .constraints import GreaterThan
24
+ from gpytorch .settings import cholesky_jitter
24
25
25
26
from ..util import filter_nan
26
27
from ..base_model import BaseModel
@@ -99,13 +100,24 @@ def fit(self, Xc : Tensor, Xe : Tensor, y : Tensor):
99
100
opt = torch .optim .Adam (self .gp .parameters (), lr = self .lr )
100
101
mll = gpytorch .mlls .ExactMarginalLogLikelihood (self .lik , self .gp )
101
102
for epoch in range (self .num_epochs ):
102
- def closure ():
103
- dist = self .gp (self .Xc , self .Xe )
104
- loss = - 1 * mll (dist , self .y .squeeze ())
105
- opt .zero_grad ()
106
- loss .backward ()
107
- return loss
108
- opt .step (closure )
103
+ jitter = 10 ** - 8
104
+ cont = True
105
+ while cont :
106
+ cont = False
107
+ cholesky_jitter ._set_value (
108
+ double_value = jitter , float_value = 100 * jitter , half_value = 10000 * jitter )
109
+ def closure ():
110
+ dist = self .gp (self .Xc , self .Xe )
111
+ loss = - 1 * mll (dist , self .y .squeeze ())
112
+ opt .zero_grad ()
113
+ loss .backward ()
114
+ return loss
115
+ try :
116
+ opt .step (closure )
117
+ except :
118
+ jitter *= 10
119
+ cont = True
120
+ print (f'jitter = { jitter } ' )
109
121
if self .verbose and ((epoch + 1 ) % self .print_every == 0 or epoch == 0 ):
110
122
print ('After %d epochs, loss = %g' % (epoch + 1 , closure ().item ()), flush = True )
111
123
self .gp .eval ()
@@ -114,7 +126,18 @@ def closure():
114
126
def predict (self , Xc , Xe ):
115
127
Xc , Xe = self .xtrans (Xc , Xe )
116
128
with gpytorch .settings .fast_pred_var (), gpytorch .settings .debug (False ):
117
- pred = self .gp (Xc , Xe )
129
+ jitter = 10 ** - 8
130
+ cont = True
131
+ while cont :
132
+ cont = False
133
+ cholesky_jitter ._set_value (
134
+ double_value = jitter , float_value = 100 * jitter , half_value = 10000 * jitter )
135
+ try :
136
+ pred = self .gp (Xc , Xe )
137
+ except :
138
+ jitter *= 10
139
+ cont = True
140
+ print (f'jitter = { jitter } ' )
118
141
if self .pred_likeli :
119
142
pred = self .lik (pred )
120
143
mu_ = pred .mean .reshape (- 1 , self .num_out )
0 commit comments