Skip to content

Conversation

@loliverhennigh
Copy link
Collaborator

@loliverhennigh loliverhennigh commented Jan 6, 2026

Earth2Studio Pull Request

Description

Fixes an AIFS bug with channels and brings it in sync with AIFSENS
Also updates the AIFS checkpoint to 1.1 which has some fixes from ECMWF

Closes: #600
Closes: #578

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Jan 6, 2026

Disclaimer: This is AI-generated, please review response for accuracy

Greptile Summary

This PR fixes a critical bug in the AIFS model's _update_input method and synchronizes its implementation with AIFSENS. The main changes include:

  • Fixed tensor indexing bug in _update_input method (line 716-720): changed x[:, 1:2, :, ...] to x[:, :, :, ...] to correctly populate time-dependent forcings across both time steps instead of only the second time step
  • Added invariants support by introducing a new invariants parameter to store static fields (lsm, sdor, slor, z)
  • Implemented _add_invariants method to properly inject invariant fields into input tensors
  • Updated input_variables and output_variables properties to use dynamic data_indices instead of hardcoded channel ranges (92-101)
  • Added fcstep parameter to predict_step and _forward methods for proper multi-step forecasting
  • Improved _update_input to compute time-dependent forcings for both time steps (t0 and t1) and concatenate them
  • Added imports for IFS data source to fetch invariant fields from ECMWF
  • Updated test fixtures to include data.output.forcing indices and invariants parameter

Confidence Score: 5/5

  • This PR is safe to merge with minimal risk
  • The changes fix a critical indexing bug, add proper invariants support, and bring AIFS in sync with the already-tested AIFSENS implementation. All modifications are well-tested with comprehensive test coverage including the new invariants parameter and dynamic shape assertions.
  • No files require special attention

Important Files Changed

Filename Overview
earth2studio/models/px/aifs.py Adds invariants support, fixes channel indexing bug, and syncs with AIFSENS implementation
test/models/px/test_aifs.py Updates tests to include invariants parameter and fixes hardcoded shape assertions

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Jan 6, 2026

Greptile's behavior is changing!

From now on, if a review finishes with no comments, we will not post an additional "statistics" comment to confirm that our review found nothing to comment on. However, you can confirm that we reviewed your changes in the status check section.

This feature can be toggled off in your Code Review Settings by deselecting "Create a status check for each PR".

Y.Takano and others added 7 commits January 8, 2026 09:38
- Add optional dependency group "aifs11" for the newer anemoi stack
- Load checkpoint variable ordering from ai-models.json to handle v1.0 vs v1.1 differences
- Make generated forcing indices checkpoint-driven (remove hard-coded 92..100 assumptions)
- Add version switching to AIFS.load_default_package() and AIFS.load_model()
- Update install docs and add regression tests for variable-name mapping and version selection
Ensure AIFS.load_model() runs the v1.1 optional-dependency check when checkpoint
autodetection resolves to 1.1, preventing cryptic runtime errors when the aifs11
(anemoi) stack is not installed.
@NickGeneva NickGeneva changed the title aifs fix / aifsens sync AIFS checkpoint update and fixes Jan 12, 2026
@NickGeneva
Copy link
Collaborator

NickGeneva commented Jan 12, 2026

Okay I updated the checkpoint to version 1.1 as well in this PR.
I set up inference scripts to compare the implemation with ECMWF's scripts from the HF repo of the checkpoint:

ECMWF inference script
import datetime
import os
import torch
from collections import defaultdict

import numpy as np
import earthkit.data as ekd
import earthkit.regrid as ekr

from anemoi.inference.runners.simple import SimpleRunner
from anemoi.inference.outputs.printer import print_state

from ecmwf.opendata import Client as OpendataClient

# https://huggingface.co/ecmwf/aifs-single-1.1/blob/main/run_AIFS_v1.1.ipynb

PARAM_SFC = ["10u", "10v", "2d", "2t", "msl", "skt", "sp", "tcw", "lsm", "z", "slor", "sdor"]
PARAM_SOIL =["vsw","sot"]
PARAM_PL = ["gh", "t", "u", "v", "w", "q"]
LEVELS = [1000, 925, 850, 700, 600, 500, 400, 300, 250, 200, 150, 100, 50]
SOIL_LEVELS = [1,2]


DATE = OpendataClient().latest()
DATE = datetime.datetime(2026, 1, 1, 0)
print("Initial date is", DATE)


seed = 0
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


# Get data

def get_open_data(param, levelist=[]):
    fields = defaultdict(list)
    # Get the data for the current date and the previous date
    for date in [DATE - datetime.timedelta(hours=6), DATE]:
        data = ekd.from_source("ecmwf-open-data", source="aws", date=date, param=param, levelist=levelist)
        for f in data:
            # Open data is between -180 and 180, we need to shift it to 0-360
            assert f.to_numpy().shape == (721,1440)
            values = np.roll(f.to_numpy(), -f.shape[1] // 2, axis=1)
            # Interpolate the data to from 0.25 to N320
            values = ekr.interpolate(values, {"grid": (0.25, 0.25)}, {"grid": "N320"})
            # Add the values to the list
            name = f"{f.metadata('param')}_{f.metadata('levelist')}" if levelist else f.metadata("param")
            fields[name].append(values)

    # Create a single matrix for each parameter
    for param, values in fields.items():
        fields[param] = np.stack(values)

    return fields

fields = {}
fields.update(get_open_data(param=PARAM_SFC))
soil=get_open_data(param=PARAM_SOIL,levelist=SOIL_LEVELS)
mapping = {'sot_1': 'stl1', 'sot_2': 'stl2',
           'vsw_1': 'swvl1','vsw_2': 'swvl2'}
for k,v in soil.items():
    fields[mapping[k]]=v

fields.update(get_open_data(param=PARAM_PL, levelist=LEVELS))

# Transform GH to Z
for level in LEVELS:
    gh = fields.pop(f"gh_{level}")
    fields[f"z_{level}"] = gh * 9.80665

input_state = dict(date=DATE, fields=fields)


# Run model
step = 0
checkpoint = {"huggingface":"ecmwf/aifs-single-1.1"}
runner = SimpleRunner(checkpoint, device="cuda")
for state in runner.run(input_state=input_state, lead_time=24):
    print_state(state)

    variables = []
    arrays = []
    for name, field in state['fields'].items():
        variables.append(name)
        arrays.append(field)

    if not os.path.exists("outputs"):
        os.makedirs("outputs")

    np.save(f"outputs/data_{step}.npy", np.stack(arrays, axis=-1))
    np.save(f"outputs/vars_{step}.npy", np.array(variables))
    np.save(f"outputs/lat.npy", state['latitudes'])
    np.save(f"outputs/lon.npy", state['latitudes'])
    step+=1
Earth2Studio script
from earth2studio.models.px import AIFS
from earth2studio.data import IFS, fetch_data

import torch
from datetime import datetime, timedelta
import numpy as np

time = np.array([datetime(2026, 1, 1, 0)], dtype=np.datetime64)

seed = 0
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


model = AIFS.load_model(AIFS.load_default_package()).to("cuda:0")
data = IFS()

x, coords = fetch_data(
        source=data,
        time=time,
        lead_time=model.input_coords()["lead_time"],
        variable=model.input_coords()["variable"],
        device="cuda:0",
)
model_iter = model.create_iterator(x, coords)

outputs = []
for step, (x, coords) in enumerate(model_iter):
    print(x.shape)
    if step == 0:
        continue # Vanilla first step has different outputs
    outputs.append(x.cpu())
    if step > 4:
        np.save("aifs_variable_oliver.npy", coords['variable'])
        break

torch.save(torch.cat(outputs, dim=1), "aifs_reference_oliver.pt")
Comparison script
import os
import numpy as np
import torch
import matplotlib.pyplot as plt

# --------------------
# Config
# --------------------
NEW_PT  = "aifs_reference_oliver.pt"
NEW_VARS  = "aifs_variable_oliver.npy"

# Choose indices to visualize
TIME_IDX = 0
LEAD_IDX = 3

# Plotting options
CMAP = "RdBu_r"
DIFF_CMAP = "PiYG"

# Read in data producted from E2S
new_vars  = np.load(NEW_VARS, allow_pickle=True)
new_tensor  = torch.load(NEW_PT, map_location="cpu")
new_data  = new_tensor.detach().cpu().numpy()[0, LEAD_IDX]

# Read in data from ECMWF model, and then interpolate back to lat lon
from earth2studio.models.px import AIFS
model = AIFS.load_model(AIFS.load_default_package())
org_data = torch.Tensor(np.load(f"../aifs_test/outputs/data_{LEAD_IDX}.npy"))

org_vars = model._ckpt_var_to_e2s(np.load(f"../aifs_test/outputs/vars_{LEAD_IDX}.npy", allow_pickle=True))
org_data = org_data.to(dtype=torch.float64)
org_data = model.inverse_interpolation_matrix @ org_data
org_data = org_data.to(dtype=torch.float32)
org_data = torch.reshape(org_data, [721, 1440, -1])
org_data = torch.permute(org_data, (2, 0, 1))
org_data = org_data.cpu().numpy()

# Ensure they are lists of strings
common_vars = [v for v in new_vars if v in org_vars]
org_index = {v: i for i, v in enumerate(org_vars)}
new_index  = {v: i for i, v in enumerate(new_vars)}

n = len(common_vars)
ncols = 3
nrows = n
fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(12, 3.5 * nrows))

# Ensure axs is 2D
axs = np.atleast_2d(axs)

for i, v in enumerate(common_vars):
    ov = org_data[org_index[v], :, :]
    mv = new_data[new_index[v],  :, :]

    # Robust color scale for the two fields together
    stack = np.concatenate([ov.ravel(), mv.ravel()])
    stack = stack[np.isfinite(stack)]
    if stack.size > 0:
        vmin, vmax = np.nanpercentile(stack, [2, 98])
    else:
        vmin, vmax = np.nanmin(ov), np.nanmax(ov)

    # Diff symmetric scale
    diff = mv - ov
    dstack = diff[np.isfinite(diff)]
    if dstack.size > 0:
        dmin, dmax = np.nanpercentile(dstack, [2, 98])
        dabs = max(abs(dmin), abs(dmax))
    else:
        dabs = max(abs(np.nanmin(diff)), abs(np.nanmax(diff)))

    # OLD
    ax = axs[i, 0]
    im0 = ax.imshow(ov, cmap=CMAP, vmin=vmin, vmax=vmax, origin="upper")
    ax.set_title(f"ECMWF {v} [t={TIME_IDX}, l={LEAD_IDX}]")
    ax.set_xticks([]); ax.set_yticks([])
    fig.colorbar(im0, ax=ax, orientation="horizontal", fraction=0.05, pad=0.06)

    # NEW
    ax = axs[i, 1]
    im1 = ax.imshow(mv, cmap=CMAP, vmin=vmin, vmax=vmax, origin="upper")
    ax.set_title(f"NEW {v} [t={TIME_IDX}, l={LEAD_IDX}]")
    ax.set_xticks([]); ax.set_yticks([])
    fig.colorbar(im1, ax=ax, orientation="horizontal", fraction=0.05, pad=0.06)

    # DIFF
    ax = axs[i, 2]
    im2 = ax.imshow(diff, cmap="PiYG", vmin=-dabs, vmax=dabs, origin="upper")
    ax.set_title(f"Δ (MK - OLD) {v}")
    ax.set_xticks([]); ax.set_yticks([])
    fig.colorbar(im2, ax=ax, orientation="horizontal", fraction=0.05, pad=0.06)

plt.tight_layout()
plt.savefig("aifs_ens_compare_test.jpg", dpi=150)
Plot of prediction 24 hours out

aifs_ens_compare_test
aifs_ens_compare_test

@NickGeneva NickGeneva added the external An awesome external contributor PR label Jan 12, 2026
@NickGeneva
Copy link
Collaborator

/blossom-ci

@NickGeneva
Copy link
Collaborator

/blossom-ci

@NickGeneva
Copy link
Collaborator

/blossom-ci

@NickGeneva
Copy link
Collaborator

/blossom-ci

@NickGeneva
Copy link
Collaborator

/blossom-ci

@NickGeneva
Copy link
Collaborator

/blossom-ci

@NickGeneva NickGeneva merged commit ac91d83 into NVIDIA:main Jan 13, 2026
7 checks passed
@NickGeneva
Copy link
Collaborator

Merged changes in: #609

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

external An awesome external contributor PR

Projects

None yet

Development

Successfully merging this pull request may close these issues.

🐛[BUG]: AIFS _fill_input bug 🐛[BUG]: AIFS cannot be used with CDS datasource out-of-the-box

3 participants