Skip to content

Commit 95ea0d0

Browse files
authored
Merge pull request #190 from aragilar/fix-linking
Add tests for CVODES and IDAS using existing tests
2 parents 7ac3140 + 3d336f7 commit 95ea0d0

File tree

5 files changed

+57
-44
lines changed

5 files changed

+57
-44
lines changed

packages/scikits-odes/src/scikits/odes/tests/test_get_info.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,11 @@ def rhs(x, y, ydot):
1818
#ydot[:] = (np.cos(x) * (x + 0.1) - np.sin(x)) / np.pow((x + 0.1), 2)
1919

2020

21-
class GetInfoTest(unittest.TestCase):
21+
class GetInfoTestCVODE(unittest.TestCase):
22+
solvername = 'cvode'
23+
2224
def setUp(self):
23-
self.ode = ode('cvode', rhs, old_api=False)
25+
self.ode = ode(self.solvername, rhs, old_api=False)
2426
self.solution = self.ode.solve(xs, np.array([1]))
2527

2628
def test_we_integrated_correctly(self):
@@ -45,13 +47,24 @@ def test_ode_exposes_num_rhs_evals(self):
4547
assert 'NumRhsEvals' in info
4648
assert info['NumRhsEvals'] > 0
4749

48-
class GetInfoTestSpils(unittest.TestCase):
50+
51+
class GetInfoTestCVODES(GetInfoTestCVODE):
52+
solvername = 'cvodes'
53+
54+
55+
class GetInfoTestSpilsCVODE(unittest.TestCase):
56+
solvername = 'cvode'
57+
4958
def setUp(self):
50-
self.ode = ode('cvode', rhs, linsolver="spgmr", old_api=False)
59+
self.ode = ode(self.solvername, rhs, linsolver="spgmr", old_api=False)
5160
self.solution = self.ode.solve(xs, np.array([1]))
5261

5362
def test_ode_exposes_num_njtimes_evals(self):
5463
info = self.ode.get_info()
5564
print("ode.get_info() =\n", info)
5665
assert 'NumJtimesEvals' in info
5766
assert info['NumJtimesEvals'] > 0
67+
68+
69+
class GetInfoTestSpilsCVODES(GetInfoTestSpilsCVODE):
70+
solvername = 'cvodes'

packages/scikits-odes/src/scikits/odes/tests/test_on_funcs.py renamed to packages/scikits-odes/src/scikits/odes/tests/test_on_funcs_cvode.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -125,15 +125,16 @@ def ontstop_vc(t, y, solver):
125125

126126
return 0
127127

128-
class TestOn(TestCase):
128+
class TestOnCVODE(TestCase):
129129
"""
130130
Check integrate.dae
131131
"""
132+
solvername = 'cvode'
132133

133134
def test_cvode_rootfn_noroot(self):
134135
#test calling sequence. End is reached before root is found
135136
tspan = np.arange(0, t_end1 + 1, 1.0, DTYPE)
136-
solver = ode('cvode', rhs_fn, nr_rootfns=1, rootfn=root_fn,
137+
solver = ode(self.solvername, rhs_fn, nr_rootfns=1, rootfn=root_fn,
137138
old_api=False)
138139
soln = solver.solve(tspan, y0)
139140
assert soln.flag==StatusEnum.SUCCESS, "ERROR: Error occurred"
@@ -144,7 +145,7 @@ def test_cvode_rootfn_noroot(self):
144145
def test_cvode_rootfn(self):
145146
#test root finding and stopping: End is reached at a root
146147
tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE)
147-
solver = ode('cvode', rhs_fn, nr_rootfns=1, rootfn=root_fn,
148+
solver = ode(self.solvername, rhs_fn, nr_rootfns=1, rootfn=root_fn,
148149
old_api=False)
149150
soln = solver.solve(tspan, y0)
150151
assert soln.flag==StatusEnum.ROOT_RETURN, "ERROR: Root not found!"
@@ -155,7 +156,7 @@ def test_cvode_rootfn(self):
155156
def test_cvode_rootfnacc(self):
156157
#test root finding and accumilating: End is reached normally, roots stored
157158
tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE)
158-
solver = ode('cvode', rhs_fn, nr_rootfns=1, rootfn=root_fn,
159+
solver = ode(self.solvername, rhs_fn, nr_rootfns=1, rootfn=root_fn,
159160
onroot=onroot_va,
160161
old_api=False)
161162
soln = solver.solve(tspan, y0)
@@ -171,7 +172,7 @@ def test_cvode_rootfnacc(self):
171172
def test_cvode_rootfn_stop(self):
172173
#test root finding and stopping: End is reached at a root with a function
173174
tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE)
174-
solver = ode('cvode', rhs_fn, nr_rootfns=1, rootfn=root_fn,
175+
solver = ode(self.solvername, rhs_fn, nr_rootfns=1, rootfn=root_fn,
175176
onroot=onroot_vb,
176177
old_api=False)
177178
soln = solver.solve(tspan, y0)
@@ -183,7 +184,7 @@ def test_cvode_rootfn_stop(self):
183184
def test_cvode_rootfn_test(self):
184185
#test root finding and accumilating: End is reached after a number of root
185186
tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE)
186-
solver = ode('cvode', rhs_fn, nr_rootfns=1, rootfn=root_fn,
187+
solver = ode(self.solvername, rhs_fn, nr_rootfns=1, rootfn=root_fn,
187188
onroot=onroot_vc,
188189
old_api=False)
189190
soln = solver.solve(tspan, y0)
@@ -199,7 +200,7 @@ def test_cvode_rootfn_test(self):
199200
def test_cvode_rootfn_two(self):
200201
#test two root finding
201202
tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE)
202-
solver = ode('cvode', rhs_fn, nr_rootfns=2, rootfn=root_fn2,
203+
solver = ode(self.solvername, rhs_fn, nr_rootfns=2, rootfn=root_fn2,
203204
onroot=onroot_vc,
204205
old_api=False)
205206
soln = solver.solve(tspan, y0)
@@ -215,7 +216,7 @@ def test_cvode_rootfn_two(self):
215216
def test_cvode_rootfn_end(self):
216217
#test root finding with root at endtime
217218
tspan = np.arange(0, 30 + 1, 1.0, DTYPE)
218-
solver = ode('cvode', rhs_fn, nr_rootfns=1, rootfn=root_fn3,
219+
solver = ode(self.solvername, rhs_fn, nr_rootfns=1, rootfn=root_fn3,
219220
onroot=onroot_vc,
220221
old_api=False)
221222
soln = solver.solve(tspan, y0)
@@ -233,7 +234,7 @@ def test_cvode_tstopfn_notstop(self):
233234
global n
234235
n = 0
235236
tspan = np.arange(0, t_end1 + 1, 1.0, DTYPE)
236-
solver = ode('cvode', rhs_fn, tstop=T1+1, ontstop=ontstop_va,
237+
solver = ode(self.solvername, rhs_fn, tstop=T1+1, ontstop=ontstop_va,
237238
old_api=False)
238239

239240
soln = solver.solve(tspan, y0)
@@ -247,7 +248,7 @@ def test_cvode_tstopfn(self):
247248
global n
248249
n = 0
249250
tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE)
250-
solver = ode('cvode', rhs_fn, tstop=T1,
251+
solver = ode(self.solvername, rhs_fn, tstop=T1,
251252
old_api=False)
252253
soln = solver.solve(tspan, y0)
253254
assert soln.flag==StatusEnum.TSTOP_RETURN, "ERROR: Tstop not found!"
@@ -264,7 +265,7 @@ def test_cvode_tstopfnacc(self):
264265
global n
265266
n = 0
266267
tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE)
267-
solver = ode('cvode', rhs_fn, tstop=T1, ontstop=ontstop_va,
268+
solver = ode(self.solvername, rhs_fn, tstop=T1, ontstop=ontstop_va,
268269
old_api=False)
269270
soln = solver.solve(tspan, y0)
270271
assert len(soln.tstop.t) == 9, "ERROR: Did not find all tstop"
@@ -282,7 +283,7 @@ def test_cvode_tstopfn_stop(self):
282283
global n
283284
n = 0
284285
tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE)
285-
solver = ode('cvode', rhs_fn, tstop=T1, ontstop=ontstop_vb,
286+
solver = ode(self.solvername, rhs_fn, tstop=T1, ontstop=ontstop_vb,
286287
old_api=False)
287288

288289
soln = solver.solve(tspan, y0)
@@ -302,7 +303,7 @@ def test_cvode_tstopfn_test(self):
302303
global n
303304
n = 0
304305
tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE)
305-
solver = ode('cvode', rhs_fn, tstop=T1, ontstop=ontstop_vc,
306+
solver = ode(self.solvername, rhs_fn, tstop=T1, ontstop=ontstop_vc,
306307
old_api=False)
307308

308309
soln = solver.solve(tspan, y0)
@@ -315,3 +316,7 @@ def test_cvode_tstopfn_test(self):
315316
assert allclose([soln.tstop.t[-1], soln.tstop.y[-1,0], soln.tstop.y[-1,1]],
316317
[30.0, -1452.5024, -294.30],
317318
atol=atol, rtol=rtol)
319+
320+
321+
class TestOnCVODES(TestOnCVODE):
322+
solvername = 'cvodes'

packages/scikits-odes/src/scikits/odes/tests/test_on_funcs_ida.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -127,15 +127,16 @@ def ontstop_vc(t, y, ydot, solver):
127127

128128
return 0
129129

130-
class TestOn(TestCase):
130+
class TestOnIDA(TestCase):
131131
"""
132132
Check integrate.dae
133133
"""
134+
solvername = 'ida'
134135

135136
def test_ida_rootfn_noroot(self):
136137
#test calling sequence. End is reached before root is found
137138
tspan = np.arange(0, t_end1 + 1, 1.0, DTYPE)
138-
solver = dae('ida', rhs_fn, nr_rootfns=1, rootfn=root_fn,
139+
solver = dae(self.solvername, rhs_fn, nr_rootfns=1, rootfn=root_fn,
139140
old_api=False)
140141
soln = solver.solve(tspan, y0, yp0)
141142
assert soln.flag==StatusEnumIDA.SUCCESS, "ERROR: Error occurred"
@@ -146,7 +147,7 @@ def test_ida_rootfn_noroot(self):
146147
def test_ida_rootfn(self):
147148
#test root finding and stopping: End is reached at a root
148149
tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE)
149-
solver = dae('ida', rhs_fn, nr_rootfns=1, rootfn=root_fn,
150+
solver = dae(self.solvername, rhs_fn, nr_rootfns=1, rootfn=root_fn,
150151
old_api=False)
151152
soln = solver.solve(tspan, y0, yp0)
152153
assert soln.flag==StatusEnumIDA.ROOT_RETURN, "ERROR: Root not found!"
@@ -157,7 +158,7 @@ def test_ida_rootfn(self):
157158
def test_ida_rootfnacc(self):
158159
#test root finding and accumilating: End is reached normally, roots stored
159160
tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE)
160-
solver = dae('ida', rhs_fn, nr_rootfns=1, rootfn=root_fn,
161+
solver = dae(self.solvername, rhs_fn, nr_rootfns=1, rootfn=root_fn,
161162
onroot=onroot_va,
162163
old_api=False)
163164
soln = solver.solve(tspan, y0, yp0)
@@ -173,7 +174,7 @@ def test_ida_rootfnacc(self):
173174
def test_ida_rootfn_stop(self):
174175
#test root finding and stopping: End is reached at a root with a function
175176
tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE)
176-
solver = dae('ida', rhs_fn, nr_rootfns=1, rootfn=root_fn,
177+
solver = dae(self.solvername, rhs_fn, nr_rootfns=1, rootfn=root_fn,
177178
onroot=onroot_vb,
178179
old_api=False)
179180
soln = solver.solve(tspan, y0, yp0)
@@ -185,7 +186,7 @@ def test_ida_rootfn_stop(self):
185186
def test_ida_rootfn_test(self):
186187
#test root finding and accumilating: End is reached after a number of root
187188
tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE)
188-
solver = dae('ida', rhs_fn, nr_rootfns=1, rootfn=root_fn,
189+
solver = dae(self.solvername, rhs_fn, nr_rootfns=1, rootfn=root_fn,
189190
onroot=onroot_vc,
190191
old_api=False)
191192
soln = solver.solve(tspan, y0, yp0)
@@ -201,7 +202,7 @@ def test_ida_rootfn_test(self):
201202
def test_ida_rootfn_two(self):
202203
#test two root finding
203204
tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE)
204-
solver = dae('ida', rhs_fn, nr_rootfns=2, rootfn=root_fn2,
205+
solver = dae(self.solvername, rhs_fn, nr_rootfns=2, rootfn=root_fn2,
205206
onroot=onroot_vc,
206207
old_api=False)
207208
soln = solver.solve(tspan, y0, yp0)
@@ -217,7 +218,7 @@ def test_ida_rootfn_two(self):
217218
def test_ida_rootfn_end(self):
218219
#test root finding with root at endtime
219220
tspan = np.arange(0, 30 + 1, 1.0, DTYPE)
220-
solver = dae('ida', rhs_fn, nr_rootfns=1, rootfn=root_fn3,
221+
solver = dae(self.solvername, rhs_fn, nr_rootfns=1, rootfn=root_fn3,
221222
onroot=onroot_vc,
222223
old_api=False)
223224
soln = solver.solve(tspan, y0, yp0)
@@ -235,7 +236,7 @@ def test_ida_tstopfn_notstop(self):
235236
global n
236237
n = 0
237238
tspan = np.arange(0, t_end1 + 1, 1.0, DTYPE)
238-
solver = dae('ida', rhs_fn, tstop=T1+1, ontstop=ontstop_va,
239+
solver = dae(self.solvername, rhs_fn, tstop=T1+1, ontstop=ontstop_va,
239240
old_api=False)
240241
soln = solver.solve(tspan, y0, yp0)
241242
assert soln.flag==StatusEnumIDA.SUCCESS, "ERROR: Error occurred"
@@ -248,7 +249,7 @@ def test_ida_tstopfn(self):
248249
global n
249250
n = 0
250251
tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE)
251-
solver = dae('ida', rhs_fn, tstop=T1,
252+
solver = dae(self.solvername, rhs_fn, tstop=T1,
252253
old_api=False)
253254
soln = solver.solve(tspan, y0, yp0)
254255
assert soln.flag==StatusEnumIDA.TSTOP_RETURN, "ERROR: Tstop not found!"
@@ -265,7 +266,7 @@ def test_ida_tstopfnacc(self):
265266
global n
266267
n = 0
267268
tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE)
268-
solver = dae('ida', rhs_fn, tstop=T1, ontstop=ontstop_va,
269+
solver = dae(self.solvername, rhs_fn, tstop=T1, ontstop=ontstop_va,
269270
old_api=False)
270271
soln = solver.solve(tspan, y0, yp0)
271272
assert len(soln.tstop.t) == 9, "ERROR: Did not find all tstop"
@@ -283,7 +284,7 @@ def test_ida_tstopfn_stop(self):
283284
global n
284285
n = 0
285286
tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE)
286-
solver = dae('ida', rhs_fn, tstop=T1, ontstop=ontstop_vb,
287+
solver = dae(self.solvername, rhs_fn, tstop=T1, ontstop=ontstop_vb,
287288
old_api=False)
288289

289290
soln = solver.solve(tspan, y0, yp0)
@@ -303,7 +304,7 @@ def test_ida_tstopfn_test(self):
303304
global n
304305
n = 0
305306
tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE)
306-
solver = dae('ida', rhs_fn, tstop=T1, ontstop=ontstop_vc,
307+
solver = dae(self.solvername, rhs_fn, tstop=T1, ontstop=ontstop_vc,
307308
old_api=False)
308309

309310
soln = solver.solve(tspan, y0, yp0)
@@ -316,3 +317,7 @@ def test_ida_tstopfn_test(self):
316317
assert allclose([soln.tstop.t[-1], soln.tstop.y[-1,0], soln.tstop.y[-1,1]],
317318
[30.0, -1452.5024, -294.30],
318319
atol=atol, rtol=rtol)
320+
321+
322+
class TestOnIDAS(TestOnIDA):
323+
solvername = 'idas'

packages/scikits-odes/src/scikits/odes/tests/test_user_return_vals_cvode.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -101,10 +101,7 @@ def jac_vec_error_immediate(v, Jv, t, y):
101101
return -1
102102

103103
class TestCVodeReturn(TestCase):
104-
105-
def __init__(self, *args, **kwargs):
106-
super(TestCVodeReturn, self).__init__(*args, **kwargs)
107-
self.solvername = "cvode"
104+
solvername = "cvode"
108105

109106
def test_normal_rhs(self):
110107
solver = ode(self.solvername, normal_rhs, old_api=False)
@@ -312,6 +309,4 @@ def test_jac_vec_error_immediate(self):
312309
)
313310

314311
class TestCVodesReturn(TestCVodeReturn):
315-
def __init__(self, *args, **kwargs):
316-
super(TestCVodesReturn, self).__init__(*args, **kwargs)
317-
self.solvername = "cvodes"
312+
solvername = "cvodes"

packages/scikits-odes/src/scikits/odes/tests/test_user_return_vals_ida.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,7 @@ def jac_error_immediate(t, y, ydot, residual, cj, J):
7777
return -1
7878

7979
class TestIdaReturn(TestCase):
80-
81-
def __init__(self, *args, **kwargs):
82-
super(TestIdaReturn, self).__init__(*args, **kwargs)
83-
self.solvername = "ida"
80+
solvername = "ida"
8481

8582
def test_normal_rhs(self):
8683
solver = dae(self.solvername, normal_rhs, old_api=False)
@@ -235,6 +232,4 @@ def test_jac_error_immediate(self):
235232

236233

237234
class TestIdasReturn(TestIdaReturn):
238-
def __init__(self, *args, **kwargs):
239-
super(TestIdasReturn, self).__init__(*args, **kwargs)
240-
self.solvername = "idas"
235+
solvername = "idas"

0 commit comments

Comments
 (0)