
Commit df8f1c0

added a cross-validation example

1 parent afbd96c

File tree

7 files changed (+112 / -18 lines)

examples/2D_leave_n_out.py

Lines changed: 59 additions & 0 deletions

@@ -0,0 +1,59 @@
+__author__ = 'cpaulson'
+import pyKriging
+from pyKriging.krige import kriging
+from pyKriging.samplingplan import samplingplan
+from pyKriging.CrossValidation import Cross_Validation
+from pyKriging.utilities import saveModel
+
+# The Kriging model starts by defining a sampling plan, we use an optimal Latin Hypercube here
+sp = samplingplan(2)
+X = sp.optimallhc(5)
+
+# Next, we define the problem we would like to solve
+testfun = pyKriging.testfunctions().branin
+
+# We generate our observed values based on our sampling plan and the test function
+y = testfun(X)
+
+print 'Setting up the Kriging Model'
+cvMSE = []
+# Now that we have our initial data, we can create an instance of a kriging model
+k = kriging(X, y, testfunction=testfun, name='simple', testPoints=300)
+k.train(optimizer='ga')
+k.snapshot()
+# cv = Cross_Validation(k)
+# cvMSE.append( cv.leave_n_out(q=5)[0] )
+
+k.plot()
+for i in range(15):
+    print i
+    newpoints = k.infill(1)
+    for point in newpoints:
+        # print 'Adding point {}'.format(point)
+        k.addPoint(point, testfun(point)[0])
+    k.train(optimizer='pso')
+    k.snapshot()
+    # cv = Cross_Validation(k)
+    # cvMSE.append( cv.leave_n_out(q=5)[0] )
+k.plot()
+
+
+
+# saveModel(k, 'crossValidation.plk')
+
+# #And plot the model
+
+print 'Now plotting final results...'
+# k.plot()
+
+
+print k.testPoints
+print k.history['points']
+print k.history['rsquared']
+print k.history['avgMSE']
+print cvMSE
+from matplotlib import pylab as plt
+plt.plot(range(len(k.history['rsquared'])), k.history['rsquared'])
+plt.plot(range(len(cvMSE)), cvMSE)
+plt.show()

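The cross-validation hooks in the example above are left commented out. A minimal sketch of enabling them, assuming the Cross_Validation(model) constructor and the leave_n_out(q) return value introduced in pyKriging/CrossValidation.py below:

    from pyKriging.CrossValidation import Cross_Validation

    cv = Cross_Validation(k)               # wrap the trained kriging model
    meanMSE, stdMSE = cv.leave_n_out(q=5)  # returns (mean MSE, std MSE) over the folds
    cvMSE.append(meanMSE)

Each call retrains one temporary kriging model per fold, so enabling this inside the infill loop adds q extra trainings per iteration.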
examples/2d_regression_Kriging.py

Lines changed: 20 additions & 7 deletions

@@ -5,9 +5,11 @@
 from pyKriging.regressionkrige import regression_kriging
 from pyKriging.samplingplan import samplingplan
 
+
+from pyKriging.krige import kriging
 # The Kriging model starts by defining a sampling plan, we use an optimal Latin Hypercube here
 sp = samplingplan(2)
-X = sp.optimallhc(25)
+X = sp.optimallhc(30)
 
 # Next, we define the problem we would like to solve
 testfun = pyKriging.testfunctions().branin_noise
@@ -16,25 +18,36 @@
 y = testfun(X)
 print X, y
 
+testfun = pyKriging.testfunctions().branin
+
+
 print 'Setting up the Kriging Model'
 
 # Now that we have our initial data, we can create an instance of a kriging model
 k = regression_kriging(X, y, testfunction=testfun, name='simple', testPoints=250)
-k.train(optimizer='ga')
+k.train(optimizer='pso')
+k1 = kriging(X, y, testfunction=testfun, name='simple', testPoints=250)
+k1.train(optimizer='pso')
 print k.Lambda
-# k.snapshot()
-#
-for i in range(5):
+k.snapshot()
+
+
+for i in range(1):
     newpoints = k.infill(5)
     for point in newpoints:
        print 'Adding point {}'.format(point)
-        k.addPoint(point, testfun(point)[0])
+        newValue = testfun(point)[0]
+        k.addPoint(point, newValue)
+        k1.addPoint(point, newValue)
     k.train()
+    k1.train()
 # k.snapshot()
 #
 # # #And plot the model
 
 print 'Now plotting final results...'
-k.plot()
+print k.Lambda
+k.plot(show=False)
+k1.plot()
 
 

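This revision of the example now trains a regression_kriging model (k) and a plain interpolating kriging model (k1) on the same noisy Branin samples, feeds both the same infill points, and plots them side by side. A quick pointwise comparison of the two fits, sketched on the assumption that kriging.predict shares the predict(X, norm=True) signature shown for regression_kriging below:

    # compare the regression and interpolation fits at one sample location
    pt = X[0]
    print 'regression kriging:   ', k.predict(pt)
    print 'interpolating kriging:', k1.predict(pt)

On noisy data the regression model's Lambda term lets it smooth through the observations, while the interpolating model is forced through every noisy sample; that contrast is what the two plot calls display.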
pyKriging/CrossValidation.py

Lines changed: 22 additions & 4 deletions

@@ -3,21 +3,24 @@
 """
 import numpy as np
 from matplotlib import pyplot as plt
-from pyKrige.krige import kriging
+import pyKriging
+from pyKriging.krige import kriging
+from pyKriging.utilities import *
 import random
 import scipy.stats as stats
 
 
 class Cross_Validation():
 
-    def __init__(self, X, y, name):
+    def __init__(self, model, name=None):
         """
         X- sampling plane
         y- Objective function evaluations
         name- the name of the model
         """
-        self.X = X
-        self.y = y
+        self.model = model
+        self.X = self.model.X
+        self.y = self.model.y
         self.n, self.k = np.shape(self.X)
         self.predict_list, self.predict_varr, self.scvr = [], [], []
         self.name = name
@@ -209,4 +212,19 @@ def QQ_plot(self):
         plt.ylabel('Standard quantile')
         plt.show()
 
+    def leave_n_out(self, q=5):
+        '''
+        :param q: the number of groups to split the model data into
+        :return: the mean and standard deviation of the MSE across the q folds
+        '''
+        mseArray = []
+        for i in splitArrays(self.model, q):
+            testk = kriging( i[0], i[1] )
+            testk.train()
+            for j in range(len(i[2])):
+                mseArray.append(mse(i[3][j], testk.predict( i[2][j] )))
+            del(testk)
+        return np.average(mseArray), np.std(mseArray)
+
+
 ## Example Use Case:

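leave_n_out delegates fold construction to splitArrays from pyKriging.utilities, which this diff does not show. Judging from how the loop indexes each tuple, a fold appears to be (train_X, train_y, test_X, test_y). A hypothetical pure-NumPy stand-in, for illustration only:

    import numpy as np

    def split_arrays(X, y, q=5):
        # hypothetical stand-in for pyKriging.utilities.splitArrays: shuffle
        # the sample indices, cut them into q folds, and yield the training
        # and held-out arrays for each fold in turn
        folds = np.array_split(np.random.permutation(len(X)), q)
        for test in folds:
            train = np.setdiff1d(np.arange(len(X)), test)
            yield X[train], y[train], X[test], y[test]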
pyKriging/krige.py

Lines changed: 7 additions & 3 deletions

@@ -18,6 +18,7 @@
 import math as m
 
 
+
 class kriging(matrixops):
     def __init__(self, X, y, testfunction=None, name='', testPoints=None, **kwargs):
         self.X = copy.deepcopy(X)
@@ -36,7 +37,7 @@ def __init__(self, X, y, testfunction=None, name='', testPoints=None, **kwargs):
         self.updateData()
         self.updateModel()
 
-        self.thetamin = 1e-4
+        self.thetamin = 1e-5
         self.thetamax = 100
         self.pmin = 1
         self.pmax = 2
@@ -51,6 +52,7 @@ def __init__(self, X, y, testfunction=None, name='', testPoints=None, **kwargs):
         self.history['adjrsquared'] = [0]
         self.history['chisquared'] = [1000]
         self.history['lastPredictedPoints'] = []
+        self.history['avgMSE'] = []
         if testPoints:
             self.history['pointData'] = []
             self.testPoints = self.sp.rlh(testPoints)
@@ -116,10 +118,10 @@ def normalizeData(self):
         for i in range(self.k):
             self.normRange.append([min(self.X[:, i]), max(self.X[:, i])])
 
-        print self.X
+        # print self.X
         for i in range(self.n):
             self.X[i] = self.normX(self.X[i])
-        print self.X
+        #print self.X
 
         self.ynormRange.append(min(self.y))
         self.ynormRange.append(max(self.y))
@@ -688,6 +690,8 @@ def snapshot(self):
         self.history['theta'].append(copy.deepcopy(self.theta))
         self.history['p'].append(copy.deepcopy(self.pl))
 
+        self.history['avgMSE'].append(self.calcuatemeanMSE(points=self.testPoints)[0])
+
         currentPredictions = []
         if self.history['pointData']!=None:
             for pointprim in self.history['pointData']:

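With these changes, every snapshot() appends the model's mean MSE over the fixed test points to history['avgMSE'], next to the existing rsquared record. A minimal sketch of reading that convergence history back after an infill run, assuming k has called snapshot() once per iteration as in the examples above:

    from matplotlib import pyplot as plt

    # one entry per snapshot(); a falling curve suggests the infill points
    # are improving accuracy at the fixed test points
    plt.plot(range(len(k.history['avgMSE'])), k.history['avgMSE'])
    plt.xlabel('snapshot')
    plt.ylabel('average MSE at test points')
    plt.show()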
pyKriging/regressionkrige.py

Lines changed: 2 additions & 2 deletions

@@ -167,9 +167,9 @@ def updateModel(self):
         try:
             self.regupdatePsi()
         except Exception, err:
-            # pass
+            pass
             # print Exception, err
-            raise Exception("bad params")
+            # raise Exception("bad params")
 
     def predict(self, X, norm=True):
         '''

pyKriging/testfunctions.py

Lines changed: 1 addition & 1 deletion

@@ -84,7 +84,7 @@ def branin_noise(self, X):
         noiseFree = ((a*( X2 - b*X1**2 + c*X1 - d )**2 + e*(1-ff)*np.cos(X1) + e)+5*x)
         withNoise=[]
         for i in noiseFree:
-            withNoise.append(i + np.random.standard_normal()*7)
+            withNoise.append(i + np.random.standard_normal()*15)
         return np.array(withNoise)
pyKriging/utilities.py

Lines changed: 1 addition & 1 deletion

@@ -9,7 +9,7 @@ def norm(x):
     return x-min(x)
 
 def saveModel(model, filePath):
-    pickle.dump(model, open(filePath, 'w'))
+    pickle.dump(model, open(filePath, 'w'), byref=True)
 
 def loadModel(filePath):
     return pickle.load(open(filePath,'r'))

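Note that the standard-library pickle.dump has no byref keyword, so this change presumably relies on utilities binding a drop-in replacement such as dill, whose dump does accept byref=True (serializing classes by reference), to the name pickle:

    import dill as pickle  # assumed import; stdlib pickle would reject byref=True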