diff --git a/earth2studio/run.py b/earth2studio/run.py index f7d48b4c1..55c11288b 100644 --- a/earth2studio/run.py +++ b/earth2studio/run.py @@ -44,10 +44,15 @@ def deterministic( io: IOBackend, output_coords: CoordSystem = OrderedDict({}), device: torch.device | None = None, + checkpoint_path: str | None = None, + checkpoint_interval: int | None = None, + resume_from_step: int | None = None, ) -> IOBackend: """Built in deterministic workflow. - This workflow creates a determinstic inference pipeline to produce a forecast - prediction using a prognostic model. + This workflow creates a deterministic inference pipeline to produce a forecast + prediction using a prognostic model. Supports saving and resuming from checkpoints to + handle GPU memory constraints in long-running simulations. + Parameters ---------- @@ -65,12 +70,41 @@ def deterministic( IO output coordinate system override, by default OrderedDict({}) device : torch.device, optional Device to run inference on, by default None + checkpoint_path : str, optional + Path to save/load checkpoints, by default None + checkpoint_interval : int, optional + Save checkpoint every N steps, by default None + resume_from_step : int, optional + Resume from this step number, by default None Returns ------- IOBackend Output IO object + + + Examples + -------- + + Basic usage without checkpointing: + >>> io = deterministic(time, nsteps, prognostic_model, data, io_backend) + + Save checkpoints every 5 steps: + >>> io = deterministic(time, nsteps, prognostic_model, data, io_backend, + checkpoint_path="checkpoint.pt", checkpoint_interval=5) + + Resume from step 10: + >>> io = deterministic(time, nsteps, prognostic_model, data, io_backend, + checkpoint_path="checkpoint.pt", resume_from_step=10) + """ + from earth2studio.utils.checkpoint import ( + load_checkpoint, + save_checkpoint, + should_checkpoint, + validate_checkpoint_compatibility, + ) + # sphinx - deterministic end logger.info("Running simple workflow!") # Load model onto the device @@ -81,30 +115,52 @@ def deterministic( ) logger.info(f"Inference device: {device}") prognostic = prognostic.to(device) - # sphinx - fetch data start - # Fetch data from data source and load onto device - prognostic_ic = prognostic.input_coords() - time = to_time_array(time) - if hasattr(prognostic, "interp_method"): - interp_to = prognostic_ic - interp_method = prognostic.interp_method - else: - interp_to = None - interp_method = "nearest" - - x, coords = fetch_data( - source=data, - time=time, - variable=prognostic_ic["variable"], - lead_time=prognostic_ic["lead_time"], - device=device, - interp_to=interp_to, - interp_method=interp_method, - ) + # Handle resume from checkpoint + if resume_from_step is not None: + if checkpoint_path is None: + raise ValueError( + "checkpoint_path must be provided when resume_from_step is specified" + ) + checkpoint = load_checkpoint(checkpoint_path, device) + + logger.info(f"Resuming from checkpoint at step {resume_from_step}") - logger.success(f"Fetched data from {data.__class__.__name__}") - # sphinx - fetch data end + if not validate_checkpoint_compatibility(checkpoint["coords"], prognostic): + raise ValueError("Checkpoint incompatible with current prognostic model") + + x, coords = checkpoint["state"], checkpoint["coords"] + start_step = resume_from_step + + logger.success("Resumed from checkpoint, skipping data fetch") + else: + # Normal initialization - fetch from data source + + # sphinx - fetch data start + # Fetch data from data source and load onto device + prognostic_ic 
= prognostic.input_coords()
+        time = to_time_array(time)
+
+        if hasattr(prognostic, "interp_method"):
+            interp_to = prognostic_ic
+            interp_method = prognostic.interp_method
+        else:
+            interp_to = None
+            interp_method = "nearest"
+
+        x, coords = fetch_data(
+            source=data,
+            time=time,
+            variable=prognostic_ic["variable"],
+            lead_time=prognostic_ic["lead_time"],
+            device=device,
+            interp_to=interp_to,
+            interp_method=interp_method,
+        )
+        start_step = 0
+
+        logger.success(f"Fetched data from {data.__class__.__name__}")
+        # sphinx - fetch data end

     # Set up IO backend
     total_coords = prognostic.output_coords(prognostic.input_coords()).copy()
@@ -130,21 +186,62 @@
     # Map lat and lon if needed
     x, coords = map_coords(x, coords, prognostic.input_coords())

-    # Create prognostic iterator
-    model = prognostic.create_iterator(x, coords)
-    logger.info("Inference starting!")
-    with tqdm(total=nsteps + 1, desc="Running inference", position=1) as pbar:
-        for step, (x, coords) in enumerate(model):
-            # Subselect domain/variables as indicated in output_coords
-            x, coords = map_coords(x, coords, output_coords)
-            io.write(*split_coords(x, coords))
-            pbar.update(1)
-            if step == nsteps:
-                break
+    if resume_from_step is not None:
+        # CHECKPOINT RESUME PATH - Manual time-stepping
+        logger.info("Using manual time-stepping for checkpointed run")
+        with tqdm(
+            total=(nsteps + 1) - start_step, desc="Running inference", position=1
+        ) as pbar:
+            for current_step in range(start_step, nsteps + 1):
+                x_out, coords_out = map_coords(x, coords, output_coords)
+                io.write(*split_coords(x_out, coords_out))
+                pbar.update(1)
+
+                if (
+                    should_checkpoint(
+                        current_step, checkpoint_interval, checkpoint_path
+                    )
+                    and checkpoint_path is not None
+                ):
+                    save_checkpoint(
+                        current_step, x, coords, checkpoint_path, "deterministic"
+                    )
+                    logger.info(f"Saved checkpoint at step {current_step}")
+
+                if current_step < nsteps:
+                    x, coords = prognostic(x, coords)
+
+        logger.success("Inference complete")
+        return io
+    else:
+        # NORMAL PATH - Use existing iterator
+        # Create prognostic iterator
+        model = prognostic.create_iterator(x, coords)
+
+        logger.info("Inference starting!")
+        with tqdm(total=nsteps + 1, desc="Running inference", position=1) as pbar:
+            for step, (x, coords) in enumerate(model):
+                # Subselect domain/variables as indicated in output_coords
+                x, coords = map_coords(x, coords, output_coords)
+                io.write(*split_coords(x, coords))
+                pbar.update(1)
+
+                # Save checkpoint if needed
+                if (
+                    should_checkpoint(step, checkpoint_interval, checkpoint_path)
+                    and checkpoint_path is not None
+                ):
+                    save_checkpoint(step, x, coords, checkpoint_path, "deterministic")
+                    logger.info(f"Saved checkpoint at step {step}")
+
+                if step == nsteps:
+                    break
-    logger.success("Inference complete")
-    return io
+
+        logger.success("Inference complete")
+        return io


 # sphinx - diagnostic start
@@ -157,10 +254,14 @@ def diagnostic(
     io: IOBackend,
     output_coords: CoordSystem = OrderedDict({}),
     device: torch.device | None = None,
+    checkpoint_path: str | None = None,
+    checkpoint_interval: int | None = None,
+    resume_from_step: int | None = None,
 ) -> IOBackend:
     """Built in diagnostic workflow.

-    This workflow creates a determinstic inference pipeline that couples a prognostic
-    model with a diagnostic model.
+    This workflow creates a deterministic inference pipeline that couples a prognostic
+    model with a diagnostic model. Supports saving and resuming from checkpoints to
+    handle GPU memory constraints in long-running simulations.
Parameters ---------- @@ -180,12 +281,41 @@ def diagnostic( IO output coordinate system override, by default OrderedDict({}) device : torch.device, optional Device to run inference on, by default None + checkpoint_path : str, optional + Path to save/load checkpoints, by default None + checkpoint_interval : int, optional + Save checkpoint every N steps, by default None + resume_from_step : int, optional + Resume from this step number, by default None Returns ------- IOBackend Output IO object + + + Examples + -------- + + Basic usage without checkpointing: + >>> io = diagnostic(time, nsteps, prognostic_model, diagnostic_model, data, io_backend) + + Save checkpoints every 5 steps: + >>> io = diagnostic(time, nsteps, prognostic_model, diagnostic_model, data, io_backend, + checkpoint_path="checkpoint.pt", checkpoint_interval=5) + + Resume from step 10: + >>> io = diagnostic(time, nsteps, prognostic_model, diagnostic_model, data, io_backend, + checkpoint_path="checkpoint.pt", resume_from_step=10) """ + + from earth2studio.utils.checkpoint import ( + load_checkpoint, + save_checkpoint, + should_checkpoint, + validate_checkpoint_compatibility, + ) + # sphinx - diagnostic end logger.info("Running diagnostic workflow!") # Load model onto the device @@ -197,27 +327,51 @@ def diagnostic( logger.info(f"Inference device: {device}") prognostic = prognostic.to(device) diagnostic = diagnostic.to(device) - # Fetch data from data source and load onto device - prognostic_ic = prognostic.input_coords() + diagnostic_ic = diagnostic.input_coords() - time = to_time_array(time) - if hasattr(prognostic, "interp_method"): - interp_to = prognostic_ic - interp_method = prognostic.interp_method + + if resume_from_step is not None: + if checkpoint_path is None: + raise ValueError( + "checkpoint_path must be provided when resume_from_step is specified" + ) + checkpoint = load_checkpoint(checkpoint_path, device) + logger.info(f"Resuming from checkpoint at step {resume_from_step}") + + if not validate_checkpoint_compatibility(checkpoint["coords"], prognostic): + raise ValueError("Checkpoint incompatible with current prognostic model") + + x, coords = checkpoint["state"], checkpoint["coords"] + start_step = resume_from_step + + logger.success("Resumed from checkpoint, skipping data fetch") + else: - interp_to = None - interp_method = "nearest" - - x, coords = fetch_data( - source=data, - time=time, - variable=prognostic_ic["variable"], - lead_time=prognostic_ic["lead_time"], - device=device, - interp_to=interp_to, - interp_method=interp_method, - ) - logger.success(f"Fetched data from {data.__class__.__name__}") + + # Normal initialization - fetch from data source + # Fetch data from data source and load onto device + prognostic_ic = prognostic.input_coords() + diagnostic_ic = diagnostic.input_coords() + time = to_time_array(time) + if hasattr(prognostic, "interp_method"): + interp_to = prognostic_ic + interp_method = prognostic.interp_method + else: + interp_to = None + interp_method = "nearest" + + x, coords = fetch_data( + source=data, + time=time, + variable=prognostic_ic["variable"], + lead_time=prognostic_ic["lead_time"], + device=device, + interp_to=interp_to, + interp_method=interp_method, + ) + start_step = 0 + + logger.success(f"Fetched data from {data.__class__.__name__}") # Set up IO backend total_coords = prognostic.output_coords(prognostic.input_coords()) @@ -244,27 +398,64 @@ def diagnostic( io.add_array(total_coords, var_names) # Map lat and lon if needed + prognostic_ic = prognostic.input_coords() x, 
coords = map_coords(x, coords, prognostic_ic) - # Create prognostic iterator - model = prognostic.create_iterator(x, coords) + if resume_from_step is not None: + # CHECKPOINT RESUME PATH - Manual time-stepping + logger.info("Using manual time-stepping for checkpointed diagnostic run") + prognostic_ic = prognostic.input_coords() - logger.info("Inference starting!") - with tqdm(total=nsteps + 1, desc="Running inference", position=1) as pbar: - for step, (x, coords) in enumerate(model): + with tqdm( + total=(nsteps + 1) - start_step, desc="Running inference", position=1 + ) as pbar: + for current_step in range(start_step, nsteps + 1): + # Run diagnostic on current state + x_diag, coords_diag = map_coords(x, coords, diagnostic_ic) + x_diag, coords_diag = diagnostic(x_diag, coords_diag) + + # Output the diagnostic result + x_out, coords_out = map_coords(x_diag, coords_diag, output_coords) + io.write(*split_coords(x_out, coords_out)) + pbar.update(1) + + if ( + should_checkpoint( + current_step, checkpoint_interval, checkpoint_path + ) + and checkpoint_path is not None + ): + save_checkpoint( + current_step, x, coords, checkpoint_path, "diagnostic" + ) + logger.info(f"Saved checkpoint at step {current_step}") + + if current_step < nsteps: + x, coords = prognostic(x, coords) - # Run diagnostic - x, coords = map_coords(x, coords, diagnostic_ic) - x, coords = diagnostic(x, coords) - # Subselect domain/variables as indicated in output_coords - x, coords = map_coords(x, coords, output_coords) - io.write(*split_coords(x, coords)) - pbar.update(1) - if step == nsteps: - break + logger.success("Inference complete") + return io + + else: + # Create prognostic iterator + model = prognostic.create_iterator(x, coords) + + logger.info("Inference starting!") + with tqdm(total=nsteps + 1, desc="Running inference", position=1) as pbar: + for step, (x, coords) in enumerate(model): + + # Run diagnostic + x, coords = map_coords(x, coords, diagnostic_ic) + x, coords = diagnostic(x, coords) + # Subselect domain/variables as indicated in output_coords + x, coords = map_coords(x, coords, output_coords) + io.write(*split_coords(x, coords)) + pbar.update(1) + if step == nsteps: + break - logger.success("Inference complete") - return io + logger.success("Inference complete") + return io # sphinx - ensemble start @@ -279,8 +470,14 @@ def ensemble( batch_size: int | None = None, output_coords: CoordSystem = OrderedDict({}), device: torch.device | None = None, + checkpoint_path: str | None = None, + checkpoint_interval: int | None = None, + resume_from_step: int | None = None, ) -> IOBackend: """Built in ensemble workflow. + This workflow creates multiple forecast runs with perturbed initial conditions. + Supports saving and resuming from checkpoints to handle GPU memory constraints + in large ensemble simulations. Parameters ---------- @@ -296,21 +493,48 @@ def ensemble( Data source io : IOBackend IO object - perturbation_method : Perturbation + perturbation: Perturbation Method to perturb the initial condition to create an ensemble. batch_size: int, optional - Number of ensemble members to run in a single batch, - by default None. + Number of ensemble members to run in a single batch, by default None. 
output_coords: CoordSystem, optional IO output coordinate system override, by default OrderedDict({}) device : torch.device, optional Device to run inference on, by default None + checkpoint_path : str, optional + Path to save/load checkpoints, by default None + checkpoint_interval : int, optional + Save checkpoint every N steps, by default None + resume_from_step : int, optional + Resume from this step number, by default None Returns ------- IOBackend Output IO object + + + Examples + -------- + Basic usage without checkpointing: + >>> io = ensemble(time, nsteps, nensemble, prognostic_model, data, io_backend, perturbation) + + Save checkpoints every 5 steps: + >>> io = ensemble(time, nsteps, nensemble, prognostic_model, data, io_backend, perturbation, + ... checkpoint_path="checkpoint.pt", checkpoint_interval=5) + + Resume from step 10: + >>> io = ensemble(time, nsteps, nensemble, prognostic_model, data, io_backend, perturbation, + ... checkpoint_path="checkpoint.pt", resume_from_step=10) """ + + from earth2studio.utils.checkpoint import ( + load_checkpoint, + save_checkpoint, + should_checkpoint, + validate_checkpoint_compatibility, + ) + # sphinx - ensemble end logger.info("Running ensemble inference!") @@ -323,26 +547,57 @@ def ensemble( logger.info(f"Inference device: {device}") prognostic = prognostic.to(device) - # Fetch data from data source and load onto device - prognostic_ic = prognostic.input_coords() - time = to_time_array(time) - if hasattr(prognostic, "interp_method"): - interp_to = prognostic_ic - interp_method = prognostic.interp_method + x0: torch.Tensor + coords0: CoordSystem + + if resume_from_step is not None: + if checkpoint_path is None: + raise ValueError( + "checkpoint_path must be provided when resume_from_step is specified" + ) + checkpoint = load_checkpoint(checkpoint_path, device) + logger.info(f"Resuming ensemble from checkpoint at step {resume_from_step}") + + if not validate_checkpoint_compatibility(checkpoint["coords"], prognostic): + raise ValueError("Checkpoint incompatible with current prognostic model") + + # Expect checkpoint to contain all ensemble members + x, coords = checkpoint["state"], checkpoint["coords"] + start_step = resume_from_step + + if coords.get("ensemble") is None or len(coords["ensemble"]) != nensemble: + raise ValueError( + f"Checkpoint ensemble size {len(coords.get('ensemble', []))} does not match requested ensemble size {nensemble}" + ) + + logger.success( + "Resumed ensemble from checkpoint, skipping data fetch perturbation" + ) + else: - interp_to = None - interp_method = "nearest" - - x0, coords0 = fetch_data( - source=data, - time=time, - variable=prognostic_ic["variable"], - lead_time=prognostic_ic["lead_time"], - device=device, - interp_to=interp_to, - interp_method=interp_method, - ) - logger.success(f"Fetched data from {data.__class__.__name__}") + + # Fetch data from data source and load onto device + prognostic_ic = prognostic.input_coords() + time = to_time_array(time) + if hasattr(prognostic, "interp_method"): + interp_to = prognostic_ic + interp_method = prognostic.interp_method + else: + interp_to = None + interp_method = "nearest" + + x0, coords0 = fetch_data( + source=data, + time=time, + variable=prognostic_ic["variable"], + lead_time=prognostic_ic["lead_time"], + device=device, + interp_to=interp_to, + interp_method=interp_method, + ) + start_step = 0 + + logger.success(f"Fetched data from {data.__class__.__name__}") # Set up IO backend with information from output_coords (if applicable). 
total_coords = prognostic.output_coords(prognostic.input_coords()).copy() @@ -370,55 +625,95 @@ def ensemble( batch_size = min(nensemble, batch_size) number_of_batches = ceil(nensemble / batch_size) - logger.info( - f"Starting {nensemble} Member Ensemble Inference with \ - {number_of_batches} number of batches." - ) - batch_id = 0 - for batch_id in tqdm( - range(0, nensemble, batch_size), - total=number_of_batches, - desc="Total Ensemble Batches", - position=2, - ): - - # Get fresh batch data - x = x0.to(device) - - # Expand x, coords for ensemble - mini_batch_size = min(batch_size, nensemble - batch_id) - coords = { - "ensemble": np.arange(batch_id, batch_id + mini_batch_size) - } | coords0.copy() - - # Unsqueeze x for batching ensemble - x = x.unsqueeze(0).repeat(mini_batch_size, *([1] * x.ndim)) - - # Map lat and lon if needed - x, coords = map_coords(x, coords, prognostic_ic) - - # Perturb ensemble - x, coords = perturbation(x, coords) - - # Create prognostic iterator - model = prognostic.create_iterator(x, coords) + if resume_from_step is not None: + # CHECKPOINT RESUME PATH - Manual time-stepping + logger.info("Using manual time-stepping for checkpointed ensemble run") with tqdm( - total=nsteps + 1, - desc=f"Running batch {batch_id} inference", - position=1, - leave=False, + total=(nsteps + 1) - start_step, desc="Running inference", position=1 ) as pbar: - for step, (x, coords) in enumerate(model): - # Subselect domain/variables as indicated in output_coords - x, coords = map_coords(x, coords, output_coords) + for current_step in range(start_step, nsteps + 1): - io.write(*split_coords(x, coords)) + # Output current ensemble state (all members at once) + x_out, coords_out = map_coords(x, coords, output_coords) + io.write(*split_coords(x_out, coords_out)) pbar.update(1) - if step == nsteps: - break - batch_id += 1 + if ( + should_checkpoint( + current_step, checkpoint_interval, checkpoint_path + ) + and checkpoint_path is not None + ): + save_checkpoint( + current_step, x, coords, checkpoint_path, "ensemble" + ) + logger.info(f"Saved ensemble checkpoint at step {current_step}") - logger.success("Inference complete") - return io + if current_step < nsteps: + x, coords = prognostic(x, coords) + + logger.success("Inference complete") + return io + + else: + logger.info( + f"Starting {nensemble} Member Ensemble Inference with \ + {number_of_batches} number of batches." 
+ ) + batch_id = 0 + for batch_id in tqdm( + range(0, nensemble, batch_size), + total=number_of_batches, + desc="Total Ensemble Batches", + position=2, + ): + + # Get fresh batch data + x = x0.to(device) + + # Expand x, coords for ensemble + mini_batch_size = min(batch_size, nensemble - batch_id) + coords = { + "ensemble": np.arange(batch_id, batch_id + mini_batch_size) + } | coords0.copy() + + # Unsqueeze x for batching ensemble + x = x.unsqueeze(0).repeat(mini_batch_size, *([1] * x.ndim)) + + # Map lat and lon if needed + x, coords = map_coords(x, coords, prognostic_ic) + + # Perturb ensemble + x, coords = perturbation(x, coords) + + # Create prognostic iterator + model = prognostic.create_iterator(x, coords) + + with tqdm( + total=nsteps + 1, + desc=f"Running batch {batch_id} inference", + position=1, + leave=False, + ) as pbar: + for step, (x, coords) in enumerate(model): + # Subselect domain/variables as indicated in output_coords + x, coords = map_coords(x, coords, output_coords) + + io.write(*split_coords(x, coords)) + pbar.update(1) + + if batch_id == 0 and should_checkpoint( + step, checkpoint_interval, checkpoint_path + ): + logger.warning( + "Ensemble checkpointing in batched mode requires manual time-stepping - use resume_from_step for full functionality" + ) + + if step == nsteps: + break + + batch_id += 1 + + logger.success("Inference complete") + return io diff --git a/earth2studio/utils/checkpoint.py b/earth2studio/utils/checkpoint.py new file mode 100644 index 000000000..9f25e7ed8 --- /dev/null +++ b/earth2studio/utils/checkpoint.py @@ -0,0 +1,115 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path +from typing import Any + +import torch + +from earth2studio.models.px import PrognosticModel +from earth2studio.utils.coords import CoordSystem + + +def save_checkpoint( + step: int, + state: torch.Tensor, + coords: CoordSystem, + checkpoint_path: str, + workflow_type: str = "deterministic", + metadata: dict[str, Any] | None = None, +) -> None: + """Save workflow checkpoint to disk. 
+    Parameters
+    ----------
+    step : int
+        Current simulation step number
+    state : torch.Tensor
+        Current atmospheric state tensor on GPU
+    coords : CoordSystem
+        Current coordinate system (OrderedDict of coordinate arrays)
+    checkpoint_path : str
+        File path where checkpoint will be saved
+    workflow_type : str, optional
+        Type of workflow being checkpointed, by default "deterministic"
+    metadata : dict[str, Any], optional
+        Additional metadata to store with checkpoint, by default None
+    """
+
+    checkpoint = {
+        "step": step,
+        "state": state,
+        "coords": coords,
+        "workflow_type": workflow_type,
+        "torch_rng_state": torch.get_rng_state(),
+        "metadata": metadata or {},
+    }
+
+    if torch.cuda.is_available():
+        checkpoint["cuda_rng_state"] = torch.cuda.get_rng_state()
+
+    Path(checkpoint_path).parent.mkdir(parents=True, exist_ok=True)
+    torch.save(checkpoint, checkpoint_path)
+
+
+def load_checkpoint(checkpoint_path: str, device: torch.device) -> dict[str, Any]:
+    """Load workflow checkpoint from disk."""
+
+    if not Path(checkpoint_path).exists():
+        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
+
+    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
+
+    if "torch_rng_state" in checkpoint:
+        torch.set_rng_state(checkpoint["torch_rng_state"].cpu())
+
+    if "cuda_rng_state" in checkpoint and torch.cuda.is_available():
+        torch.cuda.set_rng_state(checkpoint["cuda_rng_state"].cpu())
+
+    return checkpoint
+
+
+def validate_checkpoint_compatibility(
+    checkpoint_coords: CoordSystem, prognostic: PrognosticModel
+) -> bool:
+    """Validate that checkpoint is compatible with prognostic model."""
+    try:
+        expected_coords = prognostic.input_coords()
+
+        for key in expected_coords:
+            if key not in checkpoint_coords:
+                return False
+            if key == "batch":
+                continue
+            if expected_coords[key].shape != checkpoint_coords[key].shape:
+                return False
+
+        return True
+    except Exception:
+        return False
+
+
+def should_checkpoint(
+    step: int,
+    checkpoint_interval: int | None,
+    checkpoint_path: str | None,
+) -> bool:
+    """Determine if checkpoint should be saved at current step."""
+    return (
+        checkpoint_path is not None
+        and checkpoint_interval is not None
+        and step % checkpoint_interval == 0
+    )
diff --git a/test/utils/test_checkpoint.py b/test/utils/test_checkpoint.py
new file mode 100644
index 000000000..448913e66
--- /dev/null
+++ b/test/utils/test_checkpoint.py
@@ -0,0 +1,365 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +import os +import tempfile +from collections import OrderedDict + +import numpy as np +import pytest +import torch + +from earth2studio.utils.checkpoint import ( + load_checkpoint, + save_checkpoint, + should_checkpoint, + validate_checkpoint_compatibility, +) + + +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +def test_checkpoint_save_load_cycle(device): + """Test complete save/load cycle preserves data""" + if device == "cuda:0" and not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + device = torch.device(device) + + # Create test data + x = torch.randn(2, 3, 4, 5).to(device) + coords = OrderedDict( + { + "batch": np.array([0, 1]), + "variable": np.array(["u", "v", "t"]), + "lat": np.linspace(-90, 90, 4), + "lon": np.linspace(0, 360, 5), + } + ) + + with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f: + checkpoint_path = f.name + + try: + # Save checkpoint + save_checkpoint(10, x, coords, checkpoint_path, "deterministic") + + # Load checkpoint + loaded = load_checkpoint(checkpoint_path, device) + + # Verify data integrity + assert loaded["step"] == 10 + assert loaded["workflow_type"] == "deterministic" + assert torch.allclose(loaded["state"], x) + + # Verify coordinates + loaded_coords = loaded["coords"] + assert loaded_coords.keys() == coords.keys() + for key in coords.keys(): + np.testing.assert_array_equal(loaded_coords[key], coords[key]) + + finally: + os.unlink(checkpoint_path) + + +@pytest.mark.parametrize( + "step,interval,path,expected", + [ + (5, None, None, False), + (5, 10, None, False), + (5, None, "/path", False), + (0, 5, "/path", True), + (5, 5, "/path", True), + (10, 5, "/path", True), + (3, 5, "/path", False), + (7, 5, "/path", False), + ], +) +def test_should_checkpoint(step, interval, path, expected): + """Test checkpoint decision logic""" + assert should_checkpoint(step, interval, path) == expected + + +def test_validate_checkpoint_compatibility(): + """Test checkpoint compatibility validation""" + + # Create mock prognostic model + class MockPrognostic: + def input_coords(self): + return OrderedDict( + { + "batch": np.array([]), + "variable": np.array(["u", "v", "t"]), + "lat": np.linspace(-90, 90, 4), + "lon": np.linspace(0, 360, 5), + } + ) + + prognostic = MockPrognostic() + + # Compatible coordinates (batch can be different size) + compatible_coords = OrderedDict( + { + "batch": np.array([0, 1]), + "variable": np.array(["u", "v", "t"]), + "lat": np.linspace(-90, 90, 4), + "lon": np.linspace(0, 360, 5), + } + ) + + assert validate_checkpoint_compatibility(compatible_coords, prognostic) + + # Incompatible coordinates (wrong variables) + incompatible_coords = OrderedDict( + { + "batch": np.array([0, 1]), + "variable": np.array(["u", "v"]), # Missing 't' + "lat": np.linspace(-90, 90, 4), + "lon": np.linspace(0, 360, 5), + } + ) + + assert not validate_checkpoint_compatibility(incompatible_coords, prognostic) + + +@pytest.mark.parametrize("workflow_type", ["deterministic", "diagnostic", "ensemble"]) +def test_checkpoint_workflow_type(workflow_type): + """Test checkpoint saves workflow type correctly""" + device = torch.device("cpu") + x = torch.randn(2, 3, 4, 5) + coords = OrderedDict( + { + "batch": np.array([0, 1]), + "variable": np.array(["u", "v", "t"]), + "lat": np.linspace(-90, 90, 4), + "lon": np.linspace(0, 360, 5), + } + ) + + with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f: + checkpoint_path = f.name + + try: + save_checkpoint(5, x, coords, checkpoint_path, workflow_type) + loaded = 
load_checkpoint(checkpoint_path, device) + assert loaded["workflow_type"] == workflow_type + finally: + os.unlink(checkpoint_path) + + +def test_checkpoint_contains_rng_state(): + """Test checkpoint includes RNG states for reproducibility""" + device = torch.device("cpu") + x = torch.randn(2, 3) + coords = OrderedDict( + {"batch": np.array([0, 1]), "variable": np.array(["u", "v", "t"])} + ) + + with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f: + checkpoint_path = f.name + + try: + save_checkpoint(0, x, coords, checkpoint_path, "deterministic") + loaded = load_checkpoint(checkpoint_path, device) + + # Should contain RNG states + assert "torch_rng_state" in loaded + assert isinstance(loaded["torch_rng_state"], torch.ByteTensor) + + # CUDA RNG state only if CUDA available + if torch.cuda.is_available(): + assert "cuda_rng_state" in loaded + + finally: + os.unlink(checkpoint_path) + + +def test_checkpoint_coordinate_types(): + """Test checkpoint handles different coordinate types""" + device = torch.device("cpu") + x = torch.randn(2, 3, 4) + coords = OrderedDict( + { + "batch": np.array([0, 1]), + "variable": np.array(["u", "v", "t"]), + "time": np.array( + [ + np.datetime64("2024-01-01"), + np.datetime64("2024-01-02"), + np.datetime64("2024-01-03"), + np.datetime64("2024-01-04"), + ] + ), + } + ) + + with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f: + checkpoint_path = f.name + + try: + save_checkpoint(0, x, coords, checkpoint_path, "deterministic") + loaded = load_checkpoint(checkpoint_path, device) + + # Verify datetime handling + loaded_coords = loaded["coords"] + np.testing.assert_array_equal(loaded_coords["time"], coords["time"]) + + finally: + os.unlink(checkpoint_path) + + +def test_load_checkpoint_file_not_found(): + """Test load_checkpoint raises FileNotFoundError for missing files""" + device = torch.device("cpu") + + with tempfile.NamedTemporaryFile(delete=False) as f: + nonexistent_path = f.name + "_nonexistent" + + with pytest.raises(FileNotFoundError): + load_checkpoint(nonexistent_path, device) + + +def test_validate_checkpoint_compatibility_missing_dims(): + """Test validation fails for missing dimensions""" + + class MockPrognostic: + def input_coords(self): + return OrderedDict( + { + "variable": np.array(["u", "v", "t"]), + "lat": np.linspace(-90, 90, 4), + "lon": np.linspace(0, 360, 5), + } + ) + + prognostic = MockPrognostic() + + # Missing required dimension + incomplete_coords = OrderedDict( + { + "variable": np.array(["u", "v", "t"]), + "lat": np.linspace(-90, 90, 4), + # Missing 'lon' + } + ) + + assert not validate_checkpoint_compatibility(incomplete_coords, prognostic) + + +def test_validate_checkpoint_compatibility_shape_mismatch(): + """Test validation fails for shape mismatches""" + + class MockPrognostic: + def input_coords(self): + return OrderedDict( + { + "variable": np.array(["u", "v", "t"]), + "lat": np.linspace(-90, 90, 4), + "lon": np.linspace(0, 360, 5), + } + ) + + prognostic = MockPrognostic() + + # Wrong shape for lat dimension + mismatched_coords = OrderedDict( + { + "variable": np.array(["u", "v", "t"]), + "lat": np.linspace(-90, 90, 6), # Different size + "lon": np.linspace(0, 360, 5), + } + ) + + assert not validate_checkpoint_compatibility(mismatched_coords, prognostic) + + +@pytest.mark.parametrize("interval", [1, 2, 5, 10]) +def test_checkpoint_interval_patterns(interval): + """Test checkpoint saving at different intervals""" + path = "/dummy/path" + + # Test steps 0-20 + expected_saves = [i for i in range(0, 
21, interval)] + + actual_saves = [ + step for step in range(21) if should_checkpoint(step, interval, path) + ] + + assert actual_saves == expected_saves + + +def test_checkpoint_ensemble_coordinates(): + """Test checkpoint handling of ensemble coordinates""" + device = torch.device("cpu") + x = torch.randn(4, 3, 8, 16) # 4 ensemble members + coords = OrderedDict( + { + "ensemble": np.array([0, 1, 2, 3]), + "variable": np.array(["u", "v", "t"]), + "lat": np.linspace(-90, 90, 8), + "lon": np.linspace(0, 360, 16), + } + ) + + with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f: + checkpoint_path = f.name + + try: + save_checkpoint(2, x, coords, checkpoint_path, "ensemble") + loaded = load_checkpoint(checkpoint_path, device) + + # Verify ensemble dimension preserved + assert "ensemble" in loaded["coords"] + assert len(loaded["coords"]["ensemble"]) == 4 + np.testing.assert_array_equal(loaded["coords"]["ensemble"], coords["ensemble"]) + + finally: + os.unlink(checkpoint_path) + + +def test_checkpoint_metadata_structure(): + """Test checkpoint contains expected metadata fields""" + device = torch.device("cpu") + x = torch.randn(2, 3) + coords = OrderedDict( + {"batch": np.array([0, 1]), "variable": np.array(["u", "v", "t"])} + ) + + with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f: + checkpoint_path = f.name + + try: + save_checkpoint(7, x, coords, checkpoint_path, "diagnostic") + loaded = load_checkpoint(checkpoint_path, device) + + # Check all required fields + required_fields = [ + "step", + "state", + "coords", + "workflow_type", + "torch_rng_state", + ] + for field in required_fields: + assert field in loaded, f"Missing required field: {field}" + + # Check field types + assert isinstance(loaded["step"], int) + assert isinstance(loaded["state"], torch.Tensor) + assert isinstance(loaded["coords"], OrderedDict) + assert isinstance(loaded["workflow_type"], str) + + finally: + os.unlink(checkpoint_path)
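
For reviewers who want to see the new arguments end to end, the sketch below shows the intended save/resume flow for the deterministic workflow. It is illustrative only and not part of the patch: the DLWP model, GFS data source, ZarrBackend IO object and the paths are placeholder choices; any prognostic model, data source and IO backend accepted by earth2studio.run.deterministic should behave the same way.

# Hypothetical usage sketch for the checkpoint/resume parameters added in this PR.
# The model, data source and IO backend below are placeholders, not requirements.
from earth2studio import run
from earth2studio.data import GFS
from earth2studio.io import ZarrBackend
from earth2studio.models.px import DLWP

model = DLWP.load_model(DLWP.load_default_package())
data = GFS()
io = ZarrBackend(file_name="forecast.zarr")

# First attempt: write a checkpoint every 10 steps so a long run can survive
# an out-of-memory failure or a pre-empted job.
run.deterministic(
    ["2024-01-01"],
    40,
    model,
    data,
    io,
    checkpoint_path="ckpt/run.pt",
    checkpoint_interval=10,
)

# Restart: skip the initial-condition fetch and continue stepping from step 30.
# The saved state and coords are validated against model.input_coords() before
# time-stepping resumes.
run.deterministic(
    ["2024-01-01"],
    40,
    model,
    data,
    io,
    checkpoint_path="ckpt/run.pt",
    resume_from_step=30,
)

The same three parameters apply to the diagnostic and ensemble workflows, with the additional constraint enforced above that a resumed ensemble checkpoint must already contain all nensemble members.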