Skip to content

Commit

Permalink
Merge pull request #676 from OceanParcels/bugTimestamps
Browse files Browse the repository at this point in the history
Bug timestamps
  • Loading branch information
erikvansebille authored Oct 17, 2019
2 parents d047ba7 + b975fee commit 35c84f1
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 19 deletions.
18 changes: 7 additions & 11 deletions parcels/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,14 +98,7 @@ def __init__(self, name, data, lon=None, lat=None, depth=None, time=None, grid=N
self.units = unitconverters_map[self.fieldtype]
else:
raise ValueError("Unsupported mesh type. Choose either: 'spherical' or 'flat'")
if timestamps is not None:
# Check whether flattened or not
if all(isinstance(f, np.ndarray) for f in timestamps):
self.timestamps = np.array([stamp for f in timestamps for stamp in f])
if all(isinstance(stamp, np.datetime64) for stamp in timestamps):
self.timestamps = timestamps
else:
self.timestamps = timestamps
self.timestamps = timestamps
if type(interp_method) is dict:
if self.name in interp_method:
self.interp_method = interp_method[self.name]
Expand Down Expand Up @@ -313,8 +306,7 @@ def from_netcdf(cls, filenames, variable, dimensions, indices=None, grid=None,
for findex in range(len(data_filenames)):
for f in [data_filenames[findex]] * len(timestamps[findex]):
dataFiles.append(f)
timestamps = np.array([stamp for file in timestamps for stamp in file])
timeslices = timestamps
timeslices = np.array([stamp for file in timestamps for stamp in file])
time = timeslices
elif netcdf_engine == 'xarray':
with NetcdfFileBuffer(data_filenames, dimensions, indices, netcdf_engine) as filebuffer:
Expand Down Expand Up @@ -1083,7 +1075,11 @@ def advancetime(self, field_new, advanceForward):

def computeTimeChunk(self, data, tindex):
g = self.grid
timestamp = None if self.timestamps is None else self.timestamps[tindex]
timestamp = self.timestamps
if timestamp is not None:
summedlen = np.cumsum([len(ls) for ls in self.timestamps])
timestamp = self.timestamps[np.where(g.ti + tindex < summedlen)[0][0]]

filebuffer = NetcdfFileBuffer(self.dataFiles[g.ti + tindex], self.dimensions, self.indices,
self.netcdf_engine, timestamp=timestamp,
interp_method=self.interp_method,
Expand Down
8 changes: 0 additions & 8 deletions parcels/fieldset.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,14 +257,6 @@ def from_netcdf(cls, filenames, variables, dimensions, indices=None, fieldtype=N
logger.warning_once("Time already provided, defaulting to dimensions['time'] over timestamps.")
timestamps = None

# Typecast timestamps to (nested) numpy array.
if timestamps is not None:
if isinstance(timestamps, list):
timestamps = np.array(timestamps)
if any(isinstance(i, list) for i in timestamps):
timestamps = np.array([np.array(sub) for sub in timestamps])
assert isinstance(timestamps, np.ndarray), "Timestamps must be nested list or array"
assert all(isinstance(file, np.ndarray) for file in timestamps), "Timestamps must be nested list or array"
fields = {}
if 'creation_log' not in kwargs.keys():
kwargs['creation_log'] = 'from_netcdf'
Expand Down
7 changes: 7 additions & 0 deletions tests/test_fieldset.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,16 +346,23 @@ def test_timestaps(datetype, tmpdir):
dims2['time'] = np.arange('2005-02-11', '2005-02-15', dtype='datetime64[D]')

fieldset1 = FieldSet.from_data(data1, dims1)
fieldset1.U.data[0, :, :] = 2.
fieldset1.write(tmpdir.join('file1'))

fieldset2 = FieldSet.from_data(data2, dims2)
fieldset2.U.data[0, :, :] = 0.
fieldset2.write(tmpdir.join('file2'))

fieldset3 = FieldSet.from_parcels(tmpdir.join('file*'))
timestamps = [dims1['time'], dims2['time']]
fieldset4 = FieldSet.from_parcels(tmpdir.join('file*'), timestamps=timestamps)
assert np.allclose(fieldset3.U.grid.time_full, fieldset4.U.grid.time_full)

for d in [0, 8, 10]:
fieldset3.computeTimeChunk(d*3600, 1)
fieldset4.computeTimeChunk(d*3600, 1)
assert np.allclose(fieldset3.U.data, fieldset4.U.data)


@pytest.mark.parametrize('mode', ['scipy', 'jit'])
@pytest.mark.parametrize('time_periodic', [86400., False])
Expand Down

0 comments on commit 35c84f1

Please sign in to comment.