
Commit

Merge pull request #1488 from OceanParcels/fix_repeated_release_traj_chunks_padding

Fixing bug in zarr output when setting chunking and repeatdt
erikvansebille authored Dec 22, 2023
2 parents 3cf8b2a + 07360a0 commit a3d90da
Showing 2 changed files with 22 additions and 3 deletions.
6 changes: 3 additions & 3 deletions parcels/particlefile.py
@@ -142,7 +142,7 @@ def _create_variables_attribute_dict(self):
 
         if self.time_origin.calendar is not None:
             attrs['time']['units'] = "seconds since " + str(self.time_origin)
-            attrs['time']['calendar'] = 'standard' if self.time_origin.calendar == 'np_datetime64' else self.time_origin.calendar
+            attrs['time']['calendar'] = _set_calendar(self.time_origin.calendar)
 
         for vname in self.vars_to_write:
             if vname not in ['time', 'lat', 'lon', 'depth', 'id']:
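The inline conditional is replaced by a module-level helper. A minimal sketch of what such a helper could look like, based only on the logic of the removed line (the actual _set_calendar in parcels/particlefile.py may cover more cases):

def _set_calendar(origin_calendar):
    # Map the numpy datetime64 origin onto the CF 'standard' calendar;
    # pass any other calendar name through unchanged.
    if origin_calendar == 'np_datetime64':
        return 'standard'
    else:
        return origin_calendar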
@@ -183,7 +183,7 @@ def _extend_zarr_dims(self, Z, store, dtype, axis):
             if len(obs) == Z.shape[1]:
                 obs.append(np.arange(self.chunks[1])+obs[-1]+1)
         else:
-            extra_trajs = max(self.maxids - Z.shape[0], self.chunks[0])
+            extra_trajs = self.maxids - Z.shape[0]
             if len(Z.shape) == 2:
                 a = np.full((extra_trajs, Z.shape[1]), self.fill_value_map[dtype], dtype=dtype)
             else:
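Before this change, the trajectory dimension was always grown by at least self.chunks[0] rows, so repeated releases that add only a few particles per call left blocks of fill-value-only trajectories in the output; the array is now extended by exactly the number of missing trajectories. A self-contained sketch of the padding step this hunk modifies, using plain NumPy in place of the zarr array (fill_value and the shapes are illustrative):

import numpy as np

fill_value = np.nan            # illustrative; parcels keeps a per-dtype fill_value_map
Z = np.zeros((5, 10))          # existing (trajectory, obs) data in the store
maxids = 7                     # total number of trajectories now required

extra_trajs = maxids - Z.shape[0]                    # grow by exactly what is missing
pad = np.full((extra_trajs, Z.shape[1]), fill_value, dtype=Z.dtype)
Z = np.concatenate([Z, pad], axis=0)                 # the zarr array grows in place via Z.append(pad, axis=0)
assert Z.shape == (7, 10)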
@@ -233,7 +233,7 @@ def write(self, pset, time, indices=None):
                 if (self.maxids > len(ids)) or (self.maxids > self.chunks[0]):
                     arrsize = (self.maxids, self.chunks[1])
                 else:
-                    arrsize = (len(ids), 1)
+                    arrsize = (len(ids), self.chunks[1])
                 ds = xr.Dataset(attrs=self.metadata, coords={"trajectory": ("trajectory", pids),
                                                              "obs": ("obs", np.arange(arrsize[1], dtype=np.int32))})
                 attrs = self._create_variables_attribute_dict()
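With the fix, a release that adds new trajectories allocates the full observation chunk up front instead of a single obs column, so later writes land inside the existing zarr chunks. A schematic sketch of the Dataset skeleton this branch builds (chunks, pids and the variable data are illustrative placeholders, not the real values used in write()):

import numpy as np
import xarray as xr

chunks = (20, 10)                       # (trajectory, obs) chunk sizes, illustrative
pids = np.arange(3, dtype=np.int64)     # ids of the newly released particles
arrsize = (len(pids), chunks[1])        # full obs chunk instead of a single column

ds = xr.Dataset(coords={"trajectory": ("trajectory", pids),
                        "obs": ("obs", np.arange(arrsize[1], dtype=np.int32))})
# each output variable becomes a (trajectory, obs) array of shape arrsize
ds["time"] = (("trajectory", "obs"), np.full(arrsize, np.nan))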
19 changes: 19 additions & 0 deletions tests/test_particlefile.py
@@ -244,6 +244,25 @@ def IncrLon(particle, fieldset, time):
     ds.close()
 
 
+@pytest.mark.parametrize('mode', ['scipy', 'jit'])
+@pytest.mark.parametrize('repeatdt', [1, 2])
+@pytest.mark.parametrize('nump', [1, 10])
+def test_pfile_chunks_repeatedrelease(fieldset, mode, repeatdt, nump, tmpdir):
+    runtime = 8
+    pset = ParticleSet(fieldset, pclass=ptype[mode], lon=np.zeros((nump, 1)),
+                       lat=np.zeros((nump, 1)), repeatdt=repeatdt)
+    outfilepath = tmpdir.join("pfile_chunks_repeatedrelease.zarr")
+    chunks = (20, 10)
+    pfile = pset.ParticleFile(outfilepath, outputdt=1, chunks=chunks)
+
+    def DoNothing(particle, fieldset, time):
+        pass
+
+    pset.execute(DoNothing, dt=1, runtime=runtime, output_file=pfile)
+    ds = xr.open_zarr(outfilepath)
+    assert ds['time'].shape == (int(nump*runtime/repeatdt), chunks[1])
+
+
 @pytest.mark.parametrize('mode', ['scipy', 'jit'])
 def test_write_timebackward(fieldset, mode, tmpdir):
     outfilepath = tmpdir.join("pfile_write_timebackward.zarr")
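The assertion in the new test encodes the expected output shape: nump particles are released every repeatdt time units, so after runtime the file should hold int(nump*runtime/repeatdt) trajectories, each with chunks[1] observation columns. For example, with nump=10, repeatdt=2 and runtime=8, that is 10*8/2 = 40 trajectories and an obs dimension of 10; the obs axis is padded out to the chunk length even when individual particles record fewer observations.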
