Skip to content

Commit ee8c21e

Browse files
authored
AIFS-ENS Initial time-step fix (#615)
* Fix AIFS ENS
1 parent a56ab29 commit ee8c21e

File tree

2 files changed

+8
-7
lines changed

2 files changed

+8
-7
lines changed

earth2studio/models/px/aifsens.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -846,13 +846,15 @@ def _fill_input(self, x: torch.Tensor, coords: CoordSystem) -> torch.Tensor:
846846
out[:, :, 0, indices[valid_mask]] = x[0, 0, 0, ...]
847847
out[:, :, 1, indices[valid_mask]] = x[0, 0, 1, ...]
848848

849-
# Drop generated forcing dimension range from output
850-
out = torch.cat([out[:, :, :, :92, ...], out[:, :, :, 101:, ...]], dim=3)
849+
# Drop generated forcing / invariants from output
850+
all_indices = torch.arange(len(VARIABLES))
851+
keep = torch.isin(
852+
all_indices, self.model.data_indices.data.output.forcing, invert=True
853+
)
854+
out = out[:, :, :, keep, ...]
851855

852856
# Update coordinates with remaining variable names
853-
all_indices = torch.arange(len(VARIABLES))
854-
variable_mask = ~torch.isin(all_indices, self.forcing_ids)
855-
selected_variables = [VARIABLES[i] for i in all_indices[variable_mask].tolist()]
857+
selected_variables = [VARIABLES[i] for i in all_indices[keep].tolist()]
856858

857859
out_coords = coords.copy()
858860
out_coords["variable"] = np.array(selected_variables)

test/models/px/test_aifsens.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,6 @@ def test_aifsens_iter(ensemble, device):
235235
if not isinstance(time, Iterable):
236236
time = [time]
237237

238-
next(p_iter)
239238
for i, (out, out_coords) in enumerate(p_iter):
240239
assert len(out.shape) == 6
241240
assert out.shape == torch.Size(
@@ -245,7 +244,7 @@ def test_aifsens_iter(ensemble, device):
245244
out_coords["variable"] == p.output_coords(p.input_coords())["variable"]
246245
).all()
247246
assert (out_coords["ensemble"] == np.arange(ensemble)).all()
248-
assert out_coords["lead_time"][0] == np.timedelta64(6 * (i + 1), "h")
247+
assert out_coords["lead_time"][0] == np.timedelta64(6 * (i), "h")
249248

250249
if i > 5:
251250
break

0 commit comments

Comments
 (0)