-
Notifications
You must be signed in to change notification settings - Fork 90
AIFS checkpoint update and fixes #606
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
AIFS checkpoint update and fixes #606
Conversation
|
Disclaimer: This is AI-generated, please review response for accuracy Greptile SummaryThis PR fixes a critical bug in the AIFS model's
|
| Filename | Overview |
|---|---|
| earth2studio/models/px/aifs.py | Adds invariants support, fixes channel indexing bug, and syncs with AIFSENS implementation |
| test/models/px/test_aifs.py | Updates tests to include invariants parameter and fixes hardcoded shape assertions |
Greptile's behavior is changing! From now on, if a review finishes with no comments, we will not post an additional "statistics" comment to confirm that our review found nothing to comment on. However, you can confirm that we reviewed your changes in the status check section. This feature can be toggled off in your Code Review Settings by deselecting "Create a status check for each PR". |
- Add optional dependency group "aifs11" for the newer anemoi stack
- Load checkpoint variable ordering from ai-models.json to handle v1.0 vs v1.1 differences
- Make generated forcing indices checkpoint-driven (remove hard-coded 92..100 assumptions)
- Add version switching to AIFS.load_default_package() and AIFS.load_model()
- Update install docs and add regression tests for variable-name mapping and version selection
Ensure AIFS.load_model() runs the v1.1 optional-dependency check when checkpoint autodetection resolves to 1.1, preventing cryptic runtime errors when the aifs11 (anemoi) stack is not installed.
|
Okay, I updated the checkpoint to version 1.1 as well in this PR.
ECMWF inference script
import datetime
import os
import torch
from collections import defaultdict
import numpy as np
import earthkit.data as ekd
import earthkit.regrid as ekr
from anemoi.inference.runners.simple import SimpleRunner
from anemoi.inference.outputs.printer import print_state
from ecmwf.opendata import Client as OpendataClient
# https://huggingface.co/ecmwf/aifs-single-1.1/blob/main/run_AIFS_v1.1.ipynb
# Parameter names requested without a level list.
PARAM_SFC = ["10u", "10v", "2d", "2t", "msl", "skt", "sp", "tcw", "lsm", "z", "slor", "sdor"]
# Parameter names requested together with SOIL_LEVELS.
PARAM_SOIL = ["vsw", "sot"]
# Parameter names requested together with LEVELS.
PARAM_PL = ["gh", "t", "u", "v", "w", "q"]
LEVELS = [1000, 925, 850, 700, 600, 500, 400, 300, 250, 200, 150, 100, 50]
SOIL_LEVELS = [1, 2]

# The initial date is pinned so the run can be compared against other runs.
# The earlier `OpendataClient().latest()` lookup was overwritten on the very
# next line, so it has been dropped (presumably it was a remote query).
DATE = datetime.datetime(2026, 1, 1, 0)
print("Initial date is", DATE)

# Seed NumPy and torch (CPU and, when present, CUDA) with the same value.
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
# Get data
def get_open_data(param, levelist=None):
    """Fetch ECMWF open-data fields for ``DATE`` and for six hours before it.

    Each retrieved field is rolled along its second axis, interpolated from
    the 0.25 degree grid to N320, and appended to a list stored under the
    field's name. Once both dates are processed, every list is stacked into
    a single array.

    Args:
        param: Parameter names, forwarded unchanged to ``ekd.from_source``.
        levelist: Levels, forwarded to ``ekd.from_source``. ``None`` (the
            default) is treated as an empty list. When non-empty, field
            names get a ``_<level>`` suffix.

    Returns:
        A ``defaultdict`` mapping each field name to ``np.stack`` of the
        arrays collected for it, in retrieval order (the earlier date is
        requested first).

    Raises:
        ValueError: If a retrieved field is not shaped ``(721, 1440)``.
    """
    # Build a fresh list per call instead of sharing one mutable default.
    levelist = [] if levelist is None else levelist
    fields = defaultdict(list)
    # Get the data for the previous date and the current date
    for date in (DATE - datetime.timedelta(hours=6), DATE):
        data = ekd.from_source(
            "ecmwf-open-data", source="aws", date=date, param=param, levelist=levelist
        )
        for f in data:
            raw = f.to_numpy()
            # Explicit check rather than `assert`, which is stripped under -O.
            if raw.shape != (721, 1440):
                raise ValueError(
                    f"Expected a (721, 1440) field, got shape {raw.shape}"
                )
            # Open data is between -180 and 180, we need to shift it to 0-360
            values = np.roll(raw, -raw.shape[1] // 2, axis=1)
            # Interpolate the data from 0.25 to N320
            values = ekr.interpolate(values, {"grid": (0.25, 0.25)}, {"grid": "N320"})
            # Add the values to the list
            if levelist:
                name = f"{f.metadata('param')}_{f.metadata('levelist')}"
            else:
                name = f.metadata("param")
            fields[name].append(values)
    # Create a single matrix for each parameter. Only values are rebound, so
    # iterating over a snapshot of the keys keeps the dict size unchanged.
    for name in list(fields):
        fields[name] = np.stack(fields[name])
    return fields
# Start the input fields from the parameters requested without levels.
fields = dict(get_open_data(param=PARAM_SFC))

# Soil fields are returned as "<param>_<level>"; store them under the
# names listed in `mapping` instead.
soil = get_open_data(param=PARAM_SOIL, levelist=SOIL_LEVELS)
mapping = {
    'sot_1': 'stl1',
    'sot_2': 'stl2',
    'vsw_1': 'swvl1',
    'vsw_2': 'swvl2',
}
fields.update({mapping[key]: values for key, values in soil.items()})

fields.update(get_open_data(param=PARAM_PL, levelist=LEVELS))

# Transform GH to Z: each gh_<level> entry is replaced by z_<level>,
# scaled by 9.80665.
for level in LEVELS:
    fields[f"z_{level}"] = fields.pop(f"gh_{level}") * 9.80665

input_state = {"date": DATE, "fields": fields}
# Run model
step = 0
checkpoint = {"huggingface": "ecmwf/aifs-single-1.1"}
runner = SimpleRunner(checkpoint, device="cuda")

# Create the output directory once, rather than checking on every step.
os.makedirs("outputs", exist_ok=True)

for state in runner.run(input_state=input_state, lead_time=24):
    print_state(state)
    # Keys and values of the same dict iterate in the same order, so
    # variables[i] names the i-th slice along the last axis of the stack.
    variables = list(state["fields"])
    arrays = list(state["fields"].values())
    np.save(f"outputs/data_{step}.npy", np.stack(arrays, axis=-1))
    np.save(f"outputs/vars_{step}.npy", np.array(variables))
    np.save("outputs/lat.npy", state["latitudes"])
    # Bug fix: lon.npy used to be written from state['latitudes'].
    np.save("outputs/lon.npy", state["longitudes"])
    step += 1


# Earth2Studio script
from earth2studio.models.px import AIFS
from earth2studio.data import IFS, fetch_data
import torch
from datetime import datetime, timedelta
import numpy as np
# One initial time, stored as a datetime64 array.
time = np.array([datetime(2026, 1, 1, 0)], dtype=np.datetime64)

# Seed NumPy and torch (CPU, then CUDA when available) with the same value.
seed = 0
for seed_fn in (np.random.seed, torch.manual_seed):
    seed_fn(seed)
if torch.cuda.is_available():
    for seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

# Load the default AIFS package and move the model to the first GPU.
package = AIFS.load_default_package()
model = AIFS.load_model(package).to("cuda:0")

# Fetch the lead times and variables the model declares as its input.
data = IFS()
input_coords = model.input_coords()
x, coords = fetch_data(
    source=data,
    time=time,
    lead_time=input_coords["lead_time"],
    variable=input_coords["variable"],
    device="cuda:0",
)
# Step the model forward, keeping each step's output on the CPU.
model_iter = model.create_iterator(x, coords)
outputs = []
for step, (x, coords) in enumerate(model_iter):
    print(x.shape)
    if step > 0:  # Vanilla first step has different outputs, so skip it
        outputs.append(x.cpu())
        if step >= 5:
            # Save the variable names of the last collected step, then stop.
            np.save("aifs_variable_oliver.npy", coords['variable'])
            break
torch.save(torch.cat(outputs, dim=1), "aifs_reference_oliver.pt")
Comparison script
import os
import numpy as np
import torch
import matplotlib.pyplot as plt
# --------------------
# Config
# --------------------
NEW_PT = "aifs_reference_oliver.pt"
NEW_VARS = "aifs_variable_oliver.npy"
# Choose indices to visualize
TIME_IDX = 0
LEAD_IDX = 3
# Plotting options
CMAP = "RdBu_r"
DIFF_CMAP = "PiYG"
# Read in data produced from E2S
new_vars = np.load(NEW_VARS, allow_pickle=True)
new_tensor = torch.load(NEW_PT, map_location="cpu")
# Index with TIME_IDX (previously a hard-coded 0) so the slice always
# matches the time index printed in the plot titles.
new_data = new_tensor.detach().cpu().numpy()[TIME_IDX, LEAD_IDX]
# Read in data from ECMWF model, and then interpolate back to lat lon
from earth2studio.models.px import AIFS

model = AIFS.load_model(AIFS.load_default_package())
org_tensor = torch.Tensor(np.load(f"../aifs_test/outputs/data_{LEAD_IDX}.npy"))
# Translate the saved variable names with the model's own mapping helper.
org_vars = model._ckpt_var_to_e2s(
    np.load(f"../aifs_test/outputs/vars_{LEAD_IDX}.npy", allow_pickle=True)
)
# Apply the inverse interpolation matrix in float64, return to float32, then
# reshape to (721, 1440, -1) and move the last axis to the front.
org_data = (
    (model.inverse_interpolation_matrix @ org_tensor.to(dtype=torch.float64))
    .to(dtype=torch.float32)
    .reshape(721, 1440, -1)
    .permute(2, 0, 1)
    .cpu()
    .numpy()
)
# Variables present in both outputs, kept in the order of the E2S output.
# Membership is tested against the index dict rather than rescanning org_vars.
org_index = {v: i for i, v in enumerate(org_vars)}
new_index = {v: i for i, v in enumerate(new_vars)}
common_vars = [v for v in new_vars if v in org_index]
if not common_vars:
    # plt.subplots cannot build a zero-row grid; fail with a clear message.
    raise ValueError("No variables are shared between the ECMWF and E2S outputs")

n = len(common_vars)
ncols = 3
nrows = n
# squeeze=False makes axs 2D even when there is only one row.
fig, axs = plt.subplots(
    nrows=nrows, ncols=ncols, figsize=(12, 3.5 * nrows), squeeze=False
)

for i, v in enumerate(common_vars):
    ov = org_data[org_index[v], :, :]
    mv = new_data[new_index[v], :, :]

    # Robust color scale for the two fields together
    stack = np.concatenate([ov.ravel(), mv.ravel()])
    stack = stack[np.isfinite(stack)]
    if stack.size > 0:
        vmin, vmax = np.nanpercentile(stack, [2, 98])
    else:
        vmin, vmax = np.nanmin(ov), np.nanmax(ov)

    # Diff symmetric scale
    diff = mv - ov
    dstack = diff[np.isfinite(diff)]
    if dstack.size > 0:
        dmin, dmax = np.nanpercentile(dstack, [2, 98])
        dabs = max(abs(dmin), abs(dmax))
    else:
        dabs = max(abs(np.nanmin(diff)), abs(np.nanmax(diff)))

    # One row per variable: ECMWF reference, E2S output, and their difference.
    # The difference title now names the operands actually subtracted
    # (mv - ov), and the difference panel uses the configured DIFF_CMAP.
    panels = (
        (ov, f"ECMWF {v} [t={TIME_IDX}, l={LEAD_IDX}]", CMAP, vmin, vmax),
        (mv, f"NEW {v} [t={TIME_IDX}, l={LEAD_IDX}]", CMAP, vmin, vmax),
        (diff, f"Δ (NEW - ECMWF) {v}", DIFF_CMAP, -dabs, dabs),
    )
    for ax, (field, title, cmap, lo, hi) in zip(axs[i], panels):
        im = ax.imshow(field, cmap=cmap, vmin=lo, vmax=hi, origin="upper")
        ax.set_title(title)
        ax.set_xticks([])
        ax.set_yticks([])
        fig.colorbar(im, ax=ax, orientation="horizontal", fraction=0.05, pad=0.06)

plt.tight_layout()
plt.savefig("aifs_ens_compare_test.jpg", dpi=150) |
|
/blossom-ci |
|
/blossom-ci |
|
/blossom-ci |
|
/blossom-ci |
|
/blossom-ci |
|
/blossom-ci |
|
Merged changes in: #609 |


Earth2Studio Pull Request
Description
Fixes an AIFS bug with channels and brings it in sync with AIFSENS
Also updates the AIFS checkpoint to 1.1 which has some fixes from ECMWF
Closes: #600
Closes: #578