Skip to content

Commit

Permalink
Merge pull request #653 from OceanParcels/multiple_timestamps
Browse files Browse the repository at this point in the history
Multiple timestamps
  • Loading branch information
erikvansebille authored Oct 8, 2019
2 parents d82448a + 8f42449 commit 59e4459
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 13 deletions.
2 changes: 1 addition & 1 deletion parcels/codegenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -948,7 +948,7 @@ def generate(self, funcname, field_args, const_args, kernel_ast, c_include):
body += [c.Assign("particles[p].state", "res")] # Store return code on particle
update_pdt = c.If("_next_dt_set == 1", c.Block([c.Assign("_next_dt_set", "0"), c.Assign("particles[p].dt", "_next_dt")]))
body += [c.If("res == SUCCESS || res == DELETE", c.Block([c.Statement("particles[p].time += particles[p].dt"), update_pdt,
dt_pos, dt_0_break, c.Statement("continue")]),
dt_pos, dt_0_break, c.Statement("continue")]),
c.Block([c.Statement("get_particle_backup(&particle_backup, &(particles[p]))"),
dt_pos, c.Statement("break")]))]

Expand Down
27 changes: 19 additions & 8 deletions parcels/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,14 @@ def __init__(self, name, data, lon=None, lat=None, depth=None, time=None, grid=N
self.units = unitconverters_map[self.fieldtype]
else:
raise ValueError("Unsupported mesh type. Choose either: 'spherical' or 'flat'")
self.timestamps = timestamps

if timestamps is not None:
# Check whether flattened or not
if all(isinstance(f, np.ndarray) for f in timestamps):
self.timestamps = np.array([stamp for f in timestamps for stamp in f])
if all(isinstance(stamp, np.datetime64) for stamp in timestamps):
self.timestamps = timestamps
else:
self.timestamps = timestamps
if type(interp_method) is dict:
if self.name in interp_method:
self.interp_method = interp_method[self.name]
Expand Down Expand Up @@ -228,12 +234,13 @@ def from_netcdf(cls, filenames, variable, dimensions, indices=None, grid=None,
# Ensure the timestamps array is compatible with the user-provided datafiles.
if timestamps is not None:
if isinstance(filenames, list):
assert len(filenames) == len(timestamps), 'Number of files and number of timestamps must be equal.'
assert len(filenames) == len(timestamps), 'Outer dimension of timestamps should correspond to number of files.'
elif isinstance(filenames, dict):
for k in filenames.keys():
assert(len(filenames[k]) == len(timestamps)), 'Number of files and number of timestamps must be equal.'
assert(len(filenames[k]) == len(timestamps)), 'Outer dimension of timestamps should correspond to number of files.'
else:
raise TypeError("filenames type is inconsistent with manual timestamp provision.")
raise TypeError("Filenames type is inconsistent with manual timestamp provision."
+ "Should be dict or list")

if isinstance(variable, xr.core.dataarray.DataArray):
lonlat_filename = variable
Expand Down Expand Up @@ -302,9 +309,13 @@ def from_netcdf(cls, filenames, variable, dimensions, indices=None, grid=None,
# Concatenate time variable to determine overall dimension
# across multiple files
if timestamps is not None:
dataFiles = []
for findex in range(len(data_filenames)):
for f in [data_filenames[findex]] * len(timestamps[findex]):
dataFiles.append(f)
timestamps = np.array([stamp for file in timestamps for stamp in file])
timeslices = timestamps
time = np.concatenate(timeslices)
dataFiles = np.array(data_filenames)
time = timeslices
elif netcdf_engine == 'xarray':
with NetcdfFileBuffer(data_filenames, dimensions, indices, netcdf_engine) as filebuffer:
time = filebuffer.time
Expand Down Expand Up @@ -1073,7 +1084,7 @@ def advancetime(self, field_new, advanceForward):
def computeTimeChunk(self, data, tindex):
g = self.grid
timestamp = None if self.timestamps is None else self.timestamps[tindex]
filebuffer = NetcdfFileBuffer(self.dataFiles[g.ti+tindex], self.dimensions, self.indices,
filebuffer = NetcdfFileBuffer(self.dataFiles[g.ti + tindex], self.dimensions, self.indices,
self.netcdf_engine, timestamp=timestamp,
interp_method=self.interp_method,
data_full_zdim=self.data_full_zdim,
Expand Down
12 changes: 8 additions & 4 deletions parcels/fieldset.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,9 @@ def from_netcdf(cls, filenames, variables, dimensions, indices=None, fieldtype=N
1. spherical (default): Lat and lon in degree, with a
correction for zonal velocity U near the poles.
2. flat: No conversion, lat/lon are assumed to be in m.
:param timestamps: A numpy array containing the timestamps for each of the files in filenames.
:param timestamps: list of lists or array of arrays containing the timestamps for
each of the files in filenames. Outer list/array corresponds to files, inner
array corresponds to indices within files.
Default is None if dimensions includes time.
:param allow_time_extrapolation: boolean whether to allow for extrapolation
(i.e. beyond the last available time snapshot)
Expand All @@ -255,12 +257,14 @@ def from_netcdf(cls, filenames, variables, dimensions, indices=None, fieldtype=N
logger.warning_once("Time already provided, defaulting to dimensions['time'] over timestamps.")
timestamps = None

# Typecast timestamps to numpy array & correct shape.
# Typecast timestamps to (nested) numpy array.
if timestamps is not None:
if isinstance(timestamps, list):
timestamps = np.array(timestamps)
timestamps = np.reshape(timestamps, [timestamps.size, 1])

if any(isinstance(i, list) for i in timestamps):
timestamps = np.array([np.array(sub) for sub in timestamps])
assert isinstance(timestamps, np.ndarray), "Timestamps must be nested list or array"
assert all(isinstance(file, np.ndarray) for file in timestamps), "Timestamps must be nested list or array"
fields = {}
if 'creation_log' not in kwargs.keys():
kwargs['creation_log'] = 'from_netcdf'
Expand Down
23 changes: 23 additions & 0 deletions tests/test_fieldset.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,29 @@ def test_vector_fields(mode, swapUV):
assert abs(pset[0].lat - .5) < 1e-9


# NOTE(review): function name has a typo ('timestaps' -> 'timestamps'); kept
# as-is so the test's discovered identity is unchanged.
@pytest.mark.parametrize('datetype', ['float', 'datetime64'])
def test_timestaps(datetype, tmpdir):
    """Reading files with explicit timestamps must yield the same full time
    axis as letting parcels read the time dimension from the files."""
    # Two small fieldsets covering consecutive time ranges: 10 snapshots
    # followed by 4 snapshots, either as float seconds or as datetime64 days.
    data1, dims1 = generate_fieldset(10, 10, 1, 10)
    data2, dims2 = generate_fieldset(10, 10, 1, 4)
    if datetype == 'float':
        dims1['time'] = np.arange(0, 10, 1) * 3600
        dims2['time'] = np.arange(10, 14, 1) * 3600
    else:
        dims1['time'] = np.arange('2005-02-01', '2005-02-11', dtype='datetime64[D]')
        dims2['time'] = np.arange('2005-02-11', '2005-02-15', dtype='datetime64[D]')

    # Persist both fieldsets to disk so they can be read back from files.
    FieldSet.from_data(data1, dims1).write(tmpdir.join('file1'))
    FieldSet.from_data(data2, dims2).write(tmpdir.join('file2'))

    # Read once inferring time from the files, once with manual timestamps
    # (one array per file); the resulting time axes must agree.
    fieldset_inferred = FieldSet.from_parcels(tmpdir.join('file*'))
    fieldset_manual = FieldSet.from_parcels(tmpdir.join('file*'),
                                            timestamps=[dims1['time'], dims2['time']])
    assert np.allclose(fieldset_inferred.U.grid.time_full, fieldset_manual.U.grid.time_full)


@pytest.mark.parametrize('mode', ['scipy', 'jit'])
@pytest.mark.parametrize('time_periodic', [86400., False])
@pytest.mark.parametrize('dt_sign', [-1, 1])
Expand Down

0 comments on commit 59e4459

Please sign in to comment.