Skip to content

Commit 675bf7d

Browse files
committed
Cleanup and documentation
1 parent e2fb707 commit 675bf7d

File tree

5 files changed

+210
-32
lines changed

5 files changed

+210
-32
lines changed

gnn_model/evaluations.py

Lines changed: 60 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,49 @@
1+
"""
2+
OCELOT Model Evaluation and Visualization Suite
3+
4+
Author: Azadeh Gholoubi
5+
Organization: NOAA/NCEP/EMC
6+
7+
Description:
8+
This module provides comprehensive evaluation and visualization tools for the OCELOT
9+
(Observation-Centered Earth Learning Observation Transformer) graph neural network model.
10+
It includes specialized diagnostics for multiple observation types including satellite
11+
radiances, conventional surface observations, and atmospheric soundings (radiosondes).
12+
13+
Key Features:
14+
- Multi-panel geospatial visualization of predictions vs. truth with difference maps
15+
- Radiosonde-specific diagnostics including vertical profile analysis by atmospheric layer
16+
- Pressure-stratified error analysis (surface, mid-troposphere, upper atmosphere)
17+
- Statistical metrics: RMSE, bias, R², sMAPE, MAE with percentile-based robust estimates
18+
- Quality control integration with observation masking
19+
- Channel-wise evaluation for multi-channel instruments (ATMS, AMSU-A, SSMIS, etc.)
20+
21+
Main Functions:
22+
- plot_ocelot_target_diff: 3-panel maps showing prediction, truth, and difference
23+
- plot_radiosonde_by_layer: Scatter plots stratified by atmospheric layers
24+
- plot_radiosonde_pressure_distribution: Vertical distribution of observations
25+
- plot_radiosonde_error_vs_pressure: Error profiles as function of pressure
26+
- print_radiosonde_layer_stats: Tabular statistics for each atmospheric layer
27+
- plot_instrument_maps: Generic instrument evaluation with geographic context
28+
29+
Technical Notes:
30+
- Handles missing data and QC-failed observations through mask columns
31+
- Uses robust statistics (percentiles) to handle outliers
32+
- Supports multiple projection systems via Cartopy
33+
- Generates publication-quality figures with proper units and annotations
34+
- Implements symmetric difference scaling for intuitive visualization
35+
36+
Usage:
37+
Called during model validation/testing phases to generate diagnostic plots
38+
and statistical summaries. Reads CSV files exported during validation and
39+
produces figures in the specified output directory.
40+
41+
Dependencies:
42+
- numpy, pandas: Data manipulation
43+
- matplotlib: Plotting and visualization
44+
- cartopy: Geospatial projections and map features
45+
"""
46+
147
import os
248
import numpy as np
349
import pandas as pd
@@ -237,9 +283,6 @@ def plot_radiosonde_by_layer(
237283
fig_dir: str = PLOT_DIR,
238284
):
239285
"""
240-
Radiosonde-specific evaluation with layer stratification.
241-
Validates FIX 1, 2, 3 are working correctly.
242-
243286
Creates scatter plots of predicted vs true values for each atmospheric layer:
244287
- Surface (850-1200 hPa)
245288
- Mid-troposphere (400-850 hPa)
@@ -255,12 +298,12 @@ def plot_radiosonde_by_layer(
255298

256299
os.makedirs(fig_dir, exist_ok=True)
257300

258-
# Check FIX 2: pressure_normalized should be present
301+
# pressure_normalized should be present
259302
if 'pressure_normalized' in df.columns:
260-
print("✅ FIX 2 ACTIVE: pressure_normalized metadata found")
303+
print("pressure_normalized metadata found")
261304
else:
262-
print("⚠️ WARNING: pressure_normalized not found in CSV!")
263-
print(" Note: FIX 2 may still be active during training, but metadata")
305+
print("WARNING: pressure_normalized not found in CSV!")
306+
print(" Note: may still be active during training, but metadata")
264307
print(" is not exported to validation CSVs. This is expected.")
265308

266309
# Define layers matching FIX 3 configuration
@@ -363,7 +406,7 @@ def plot_radiosonde_pressure_distribution(
363406
):
364407
"""
365408
Visualize distribution of radiosonde observations across pressure levels.
366-
Validates FIX 1 (nearest matching) is working - should see ~80% retention.
409+
Validates (nearest matching) is working - should see ~80% retention.
367410
"""
368411
filepath = f"{data_dir}/val_radiosonde_target_epoch{epoch}_batch{batch_idx}_step0.csv"
369412
try:
@@ -387,7 +430,7 @@ def plot_radiosonde_pressure_distribution(
387430

388431
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
389432
fig.suptitle(f'Radiosonde Pressure Distribution - Epoch {epoch}\n'
390-
f'Validates FIX 1 (Nearest-Level Matching)', fontsize=14)
433+
f'Validates Nearest-Level Matching', fontsize=14)
391434

392435
# Histogram of pressure distribution
393436
axes[0].hist(pressure_valid, bins=50, edgecolor='black', alpha=0.7, color='steelblue')
@@ -451,9 +494,9 @@ def plot_radiosonde_pressure_distribution(
451494
retention_pct = 100 * total / total_rows
452495
print(f" Data retention: {retention_pct:.1f}% ({total:,}/{total_rows:,})")
453496
if retention_pct > 70:
454-
print(f" FIX 1 WORKING: High retention (~{retention_pct:.0f}% vs expected ~80%)")
497+
print(f" FIX 1 WORKING: High retention (~{retention_pct:.0f}% vs expected ~80%)")
455498
else:
456-
print(f" ⚠️ WARNING: Low retention ({retention_pct:.1f}%), FIX 1 may not be active")
499+
print(f" WARNING: Low retention ({retention_pct:.1f}%), FIX 1 may not be active")
457500

458501

459502
def print_radiosonde_layer_stats(
@@ -485,7 +528,7 @@ def print_radiosonde_layer_stats(
485528

486529
print(f"\n{'='*80}")
487530
print(f"RADIOSONDE VALIDATION BY LAYER - Epoch {epoch}")
488-
print(f"Validates FIX 3: Level-Specific Normalization")
531+
print(f"Validates: Level-Specific Normalization")
489532
print(f"{'='*80}")
490533

491534
for fname in ['airTemperature', 'dewPointTemperature', 'airPressure']:
@@ -535,19 +578,19 @@ def print_radiosonde_layer_stats(
535578

536579
# WARNING: Check for catastrophically bad results
537580
if r2 < 0:
538-
print(f" ⚠️ WARNING: Negative R² indicates predictions worse than mean!")
581+
print(f" WARNING: Negative R² indicates predictions worse than mean!")
539582
if rmse > 10 and fname in ['airTemperature', 'dewPointTemperature']:
540-
print(f" ⚠️ WARNING: Very high RMSE (>{rmse:.1f}K) - check denormalization!")
583+
print(f" WARNING: Very high RMSE (>{rmse:.1f}K) - check denormalization!")
541584

542585
print(f"\n{'='*80}")
543-
print("If FIX 3 is working correctly, RMSE should be similar across all layers")
586+
print("If FIX 3 is working correctly, RMSE should be similar across all layers")
544587
print(" (uniform error means layer-specific normalization is effective)")
545-
print(f"\n📊 EXPECTED PERFORMANCE (with all 3 fixes working):")
588+
print(f"\n EXPECTED PERFORMANCE (with all 3 fixes working):")
546589
print(f" • airTemperature RMSE: ~2.0-2.5 K per layer")
547590
print(f" • dewPointTemperature RMSE: ~2.5-3.0 K per layer")
548591
print(f" • airPressure RMSE: ~12-20 hPa per layer")
549592
print(f" • R² values: >0.98 for all layers")
550-
print(f"\n⚠️ If seeing very high RMSE or negative R² values:")
593+
print(f"\n If seeing very high RMSE or negative R² values:")
551594
print(f" 1. Check if model was trained WITH the fixes enabled")
552595
print(f" 2. Verify observation_config.yaml has all 3 fixes active")
553596
print(f" 3. Check process_timeseries.py has FIX 1, 2, 3 implemented")

gnn_model/gnn_model.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,59 @@
1+
"""
2+
OCELOT Graph Neural Network Model Architecture
3+
4+
Author: Azadeh Gholoubi
5+
Organization: NOAA/NCEP/EMC
6+
7+
Description:
8+
Core GNN architecture for the OCELOT model. Implements a heterogeneous graph neural network with PyTorch Lightning
9+
for processing multi-instrument observational data on an icosahedral mesh structure.
10+
11+
Architecture Components:
12+
- Encoder: Maps heterogeneous observations to a common latent space on mesh nodes
13+
* Supports InteractionNet or BipartiteGAT (Graph Attention Network)
14+
* Configurable layers, attention heads, and dropout
15+
16+
- Processor: Propagates information across the mesh graph
17+
* InteractionNet: Message passing with edge features
18+
* SlidingWindowTransformerProcessor: Temporal attention mechanism
19+
* Configurable depth, attention heads, and dropout
20+
21+
- Decoder: Projects mesh representations back to observation space for predictions
22+
* Supports InteractionNet or BipartiteGAT
23+
* Inverse-distance weighted aggregation for multi-connectivity
24+
* Configurable layers and attention heads
25+
26+
Key Features:
27+
- Multi-instrument support (satellites, surface stations, radiosondes)
28+
- Weighted loss function with per-instrument and per-channel weights
29+
- Latent rollout for multi-step predictions
30+
- Gradient checkpointing for memory efficiency
31+
- Distributed training with PyTorch Lightning DDP
32+
- Automatic mixed precision (FP16) support
33+
- Comprehensive validation and CSV export for analysis
34+
35+
Model Pipeline:
36+
1. Encode: Observation features → Mesh latent representations
37+
2. Process: Message passing on mesh graph (multiple steps)
38+
3. Decode: Mesh → Target predictions at observation locations
39+
4. Loss: Weighted Huber loss with instrument/channel priorities
40+
5. Rollout: Optional multi-step autoregressive predictions
41+
42+
Training Details:
43+
- Optimizer: Adam with configurable learning rate
44+
- Loss: Weighted Huber loss (robust to outliers)
45+
- Regularization: LayerNorm, Dropout, gradient clipping
46+
- Monitoring: Training/validation losses, per-instrument metrics
47+
- Checkpointing: Model state, optimizer state, epoch tracking
48+
- CSV Export: Predictions, targets, masks for offline evaluation
49+
50+
Performance Optimizations:
51+
- Gradient checkpointing to reduce memory usage
52+
- Mixed precision training (FP16)
53+
- Efficient scatter operations for aggregation
54+
- Distributed data parallel (DDP) for multi-GPU training
55+
"""
56+
157
import lightning.pytorch as pl
258
import os
359
import time

gnn_model/process_timeseries.py

Lines changed: 47 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,40 @@
1+
"""
2+
Observation Data Processing and Feature Extraction Pipeline
3+
4+
Author: Azadeh Gholoubi
5+
Organization: NOAA/NCEP/EMC
6+
7+
Description:
8+
Core data processing module for the OCELOT (Observation-Centered Earth Learning
9+
Observation Transformer) GNN model. Handles time-series binning, feature extraction,
10+
normalization, and quality control for multi-instrument observational datasets including
11+
satellites (ATMS, AMSU-A, SSMIS, SEVIRI, AVHRR, ASCAT) and conventional observations
12+
(surface stations, radiosondes).
13+
14+
Key Functions:
15+
- organize_bins_times: Temporal binning of observations into input-target pairs with
16+
support for latent rollout (multiple target sub-windows)
17+
- extract_features: Extracts and normalizes features from zarr files with instrument-
18+
specific QC, metadata handling, and level-specific normalization for radiosondes
19+
- _normalize_by_level_groups: Pressure-stratified normalization for atmospheric soundings
20+
21+
Special Features:
22+
- Radiosonde pressure-level matching with configurable tolerance
23+
- Pressure metadata augmentation for vertical context
24+
- Layer-specific normalization (surface/mid/upper atmosphere)
25+
- Chunked scanning for memory-efficient processing of large zarr datasets
26+
- Reproducible subsampling with stable seeding
27+
- Multi-channel support for satellite instruments
28+
- Quality control integration (QC flags, mask propagation)
29+
30+
Technical Details:
31+
- Supports both single-year and multi-year zarr files
32+
- Handles missing data, fill values, and sentinel values
33+
- Implements cosine transformation for cyclic metadata (wind direction)
34+
- Feature normalization using pre-computed statistics
35+
- Efficient indexing with numpy advanced indexing
36+
"""
37+
138
import hashlib
239
import numpy as np
340
import pandas as pd
@@ -293,7 +330,7 @@ def _stats_from_cfg(feature_stats, inst_name, feat_keys):
293330

294331
def _normalize_by_level_groups(features, pressures, feature_stats, inst_name, feat_keys):
295332
"""
296-
FIX 3: Apply level-specific normalization for radiosondes.
333+
Apply level-specific normalization for radiosondes.
297334
Normalizes features separately for different atmospheric layers (surface, mid, upper).
298335
299336
Args:
@@ -371,8 +408,6 @@ def extract_features(z_dict, data_summary, bin_name, observation_config, feature
371408
Adds per-channel masks for inputs and targets so features can be missing independently.
372409
Inputs: keep a row if ANY feature channel is valid; metadata can be missing (imputed later).
373410
Targets: require metadata row to be valid; features may be missing per-channel.
374-
375-
## MODIFIED to support latent rollout (multiple target windows).
376411
"""
377412
print(f"\nProcessing {bin_name}...")
378413
for obs_type in list(data_summary[bin_name].keys()):
@@ -401,7 +436,7 @@ def extract_features(z_dict, data_summary, bin_name, observation_config, feature
401436
matching_mode = level_selection.get("matching_mode", "exact")
402437

403438
if matching_mode == "nearest":
404-
# FIX 1: Nearest-level matching with tolerance
439+
# Nearest-level matching with tolerance
405440
tolerance = level_selection.get("tolerance_hpa", 50)
406441
if input_idx.size:
407442
p_vals = z[col][input_idx]
@@ -420,7 +455,7 @@ def extract_features(z_dict, data_summary, bin_name, observation_config, feature
420455
keep_mask_tg |= (np.abs(p_vals_tg - level) <= tolerance)
421456
target_indices_list[i] = idx[keep_mask_tg]
422457
else:
423-
# Original exact matching (fallback)
458+
# Exact matching (fallback)
424459
if input_idx.size:
425460
input_idx = input_idx[np.isin(z[col][input_idx], levels)]
426461
for i, idx in enumerate(target_indices_list):
@@ -573,7 +608,7 @@ def _get_feature(arrs, name, idx):
573608
# Extract input features
574609
input_features_raw = np.column_stack([_get_feature(z, k, input_idx) for k in feat_keys]).astype(np.float32)
575610

576-
# FIX 2: Separate computed metadata from zarr-based metadata
611+
# Separate computed metadata from zarr-based metadata
577612
# Computed metadata keys that we compute on-the-fly
578613
computed_meta_keys = {'pressure_normalized', 'log_pressure_height'}
579614
zarr_meta_keys = [k for k in meta_keys if k not in computed_meta_keys]
@@ -584,7 +619,7 @@ def _get_feature(arrs, name, idx):
584619
input_lon_raw = z["longitude"][input_idx]
585620
input_times_raw = z["time"][input_idx]
586621

587-
# FIX 2: Add pressure-based metadata for radiosondes/conventional obs
622+
# Add pressure-based metadata for radiosondes/conventional obs
588623
if inst_name == 'radiosonde' and 'airPressure' in feat_keys:
589624
# Check if pressure metadata features are requested
590625
pressure_meta_features = []
@@ -627,13 +662,13 @@ def _get_feature(arrs, name, idx):
627662
target_times_raw_list.append(np.array([], dtype=np.float32))
628663
else:
629664
target_features_raw_list.append(np.column_stack([_get_feature(z, k, target_idx) for k in feat_keys]).astype(np.float32))
630-
# FIX 2: Use only zarr-based metadata keys (filter out computed ones)
665+
# Use only zarr-based metadata keys (filter out computed ones)
631666
target_metadata_raw_list.append(_stack_or_empty(z, zarr_meta_keys, target_idx))
632667
target_lat_raw_list.append(z["latitude"][target_idx])
633668
target_lon_raw_list.append(z["longitude"][target_idx])
634669
target_times_raw_list.append(z["time"][target_idx])
635670

636-
# FIX 2: Add pressure-based metadata for targets (radiosondes)
671+
# Add pressure-based metadata for targets (radiosondes)
637672
if inst_name == 'radiosonde' and 'airPressure' in feat_keys:
638673
pressure_meta_requested = ('pressure_normalized' in meta_keys) or ('log_pressure_height' in meta_keys)
639674

@@ -972,7 +1007,7 @@ def _apply_relational_qc():
9721007
else:
9731008
# Conventional processing (surface_obs, radiosonde)
9741009

975-
# FIX 3: Try level-specific normalization first (for radiosondes)
1010+
# Try level-specific normalization first (for radiosondes)
9761011
input_features_norm = None
9771012
if inst_name == 'radiosonde' and 'airPressure' in feat_keys:
9781013
# Extract pressure values for level grouping
@@ -984,7 +1019,7 @@ def _apply_relational_qc():
9841019
)
9851020

9861021
if input_features_norm is not None:
987-
print(f" [{inst_name}] Using level-specific normalization (FIX 3)")
1022+
print(f" [{inst_name}] Using level-specific normalization")
9881023

9891024
# Fall back to global normalization if level-specific not available
9901025
if input_features_norm is None:
@@ -1034,7 +1069,7 @@ def _apply_relational_qc():
10341069
continue
10351070

10361071
# Target normalization with clipping (conventional style)
1037-
# FIX 3: Try level-specific normalization for radiosondes
1072+
# Try level-specific normalization for radiosondes
10381073
target_features_norm = None
10391074
if inst_name == 'radiosonde' and 'airPressure' in feat_keys:
10401075
# Extract pressure values for level grouping

gnn_model/run_gnn.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ echo "Visible GPUs on this node:"
5757
nvidia-smi
5858

5959
# Launch training (env is propagated to ranks)
60-
srun --export=ALL --kill-on-bad-exit=1 --cpu-bind=cores python train_gnn.py
60+
# srun --export=ALL --kill-on-bad-exit=1 --cpu-bind=cores python train_gnn.py
6161

6262
# Resume training from the latest checkpoint
63-
# srun --export=ALL --kill-on-bad-exit=1 --cpu-bind=cores python train_gnn.py --resume_from_latest
63+
srun --export=ALL --kill-on-bad-exit=1 --cpu-bind=cores python train_gnn.py --resume_from_latest
6464
# srun --export=ALL --kill-on-bad-exit=1 --cpu-bind=cores python train_gnn.py --resume_from_checkpoint checkpoints/last.ckpt

0 commit comments

Comments
 (0)