 import logging
 import random
 import time
+from pathlib import Path

 import hydra
 import numpy as np
 import torch
+from hydra.core.hydra_config import HydraConfig
 from omegaconf import DictConfig
 from torch.utils.data import DataLoader
 from torch.nn.functional import mse_loss
 from ddr.dataset.utils import downsample
 from ddr.dataset.streamflow import StreamflowReader as streamflow
 from ddr.dataset.train_dataset import train_dataset
+from ddr.analysis.metrics import Metrics
+from ddr.analysis.plots import plot_time_series
+from ddr.analysis.utils import save_state

 log = logging.getLogger(__name__)

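`_set_seed` is called in `main` below, but its body is outside this diff. A minimal sketch of what such a helper typically does, assuming only that `cfg.seed` holds an integer:

    import random

    import numpy as np
    import torch


    def _set_seed(cfg) -> None:
        # Seed every RNG the training loop touches so runs are reproducible.
        random.seed(cfg.seed)
        np.random.seed(cfg.seed)
        torch.manual_seed(cfg.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(cfg.seed)
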
@@ -39,22 +44,46 @@ def train(cfg, flow, routing_model, nn):
         drop_last=True,
     )

-    optimizer = torch.optim.Adam(params=nn.parameters(), lr=cfg.train.learning_rate[str(0)])
+    if cfg.train.spatial_checkpoint:
+        file_path = Path(cfg.train.spatial_checkpoint)
+        log.info(f"Loading spatial_nn from checkpoint: {file_path.stem}")
+        state = torch.load(file_path)
+        state_dict = state["model_state_dict"]
+        for key in state_dict.keys():
+            state_dict[key] = state_dict[key].to(cfg.device)
+        nn.load_state_dict(state["model_state_dict"])
+        torch.set_rng_state(state["rng_state"])
+        start_epoch = state["epoch"]
+        # start_mini_batch = 0 if state["mini_batch"] == 0 else state["mini_batch"] + 1  # Start from the next mini-batch
+        if torch.cuda.is_available() and "cuda_rng_state" in state:
+            torch.cuda.set_rng_state(state["cuda_rng_state"])
+        if start_epoch in cfg.train.learning_rate.keys():
+            lr = cfg.train.learning_rate[start_epoch]
+        else:
+            key_list = list(cfg.train.learning_rate.keys())
+            lr = cfg.train.learning_rate[key_list[-1]]
+    else:
+        log.info("Creating new spatial model")
+        start_epoch = 1
+        # start_mini_batch = 0
+        lr = cfg.train.learning_rate[str(0)]

-    for epoch in range(0, cfg.train.epochs + 1):
+    optimizer = torch.optim.Adam(params=nn.parameters(), lr=lr)
+
+    for epoch in range(start_epoch, cfg.train.epochs + 1):
         routing_model.epoch = epoch
         for i, hydrofabric in enumerate(dataloader, start=0):
             routing_model.mini_batch = i

             streamflow_predictions = flow(cfg=cfg, hydrofabric=hydrofabric)
-            q_prime = streamflow_predictions["streamflow"] @ torch.tensor(hydrofabric.transition_matrix.to_numpy(), dtype=torch.float32, device=cfg.device)
+            q_prime = streamflow_predictions["streamflow"] @ hydrofabric.transition_matrix
             spatial_params = nn(
                 inputs=hydrofabric.normalized_spatial_attributes.to(cfg.device)
             )
             dmc_kwargs = {
                 "hydrofabric": hydrofabric,
                 "spatial_parameters": spatial_params,
-                "streamflow": q_prime,
+                "streamflow": torch.tensor(q_prime, device=cfg.device, dtype=torch.float32)
             }
             dmc_output = routing_model(**dmc_kwargs)

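The resume branch above resolves the learning rate from an epoch-keyed schedule, falling back to the schedule's last entry when there is no exact key match. The same lookup as a standalone sketch, assuming a string-keyed mapping like the `str(0)` access in the new-model branch; note that a checkpointed `start_epoch` is an int, so the exact-match branch only fires when the key types agree:

    def resolve_lr(schedule: dict, start_epoch) -> float:
        # Exact key match first; otherwise fall back to the last entry
        # (dicts preserve insertion order, so this is the latest stage).
        if start_epoch in schedule:
            return schedule[start_epoch]
        return schedule[list(schedule)[-1]]


    print(resolve_lr({"0": 1e-3, "5": 1e-4}, start_epoch=7))  # no "7" key -> 1e-4
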
@@ -68,24 +97,58 @@ def train(cfg, flow, routing_model, nn):
             np_nan_mask = nan_mask.streamflow.values

             filtered_ds = hydrofabric.observations.where(~nan_mask, drop=True)
-            filtered_observations = torch.tensor(filtered_ds.streamflow.values, device=cfg.device)[
+            filtered_observations = torch.tensor(filtered_ds.streamflow.values, device=cfg.device, dtype=torch.float32)[
                 :, 1:-1
             ]  # Cutting off days to match with realigned timesteps

             filtered_predictions = daily_runoff[~np_nan_mask]

             loss = mse_loss(
-                input=filtered_predictions.transpose(0, 1)[cfg.warmup:].unsqueeze(2),
-                target=filtered_observations.transpose(0, 1)[cfg.warmup:].unsqueeze(2),
+                input=filtered_predictions.transpose(0, 1)[cfg.train.warmup:].unsqueeze(2),
+                target=filtered_observations.transpose(0, 1)[cfg.train.warmup:].unsqueeze(2),
             )

-            log.info("Running gradient-averaged backpropagation")
+            log.info("Running backpropagation")

             loss.backward()
             optimizer.step()
             optimizer.zero_grad()

-            print(f"Loss: {loss.item}")
+            np_pred = filtered_predictions.detach().cpu().numpy()
+            np_target = filtered_observations.detach().cpu().numpy()
+            plotted_dates = dataset.dates.batch_daily_time_range[
+                1:-1
+            ]
+            metrics = Metrics(pred=np_pred, target=np_target)
+            pred_nse = metrics.nse
+            pred_nse_filtered = pred_nse[~np.isinf(pred_nse) & ~np.isnan(pred_nse)]
+            median_nse = torch.tensor(pred_nse_filtered).median()
+
+            # TODO: scale out when we have more gauges
+            # random_index = np.random.randint(low=0, high=filtered_observations.shape[0], size=(1,))[0]
+            random_gage = -1
+            plot_time_series(
+                filtered_predictions[-1].detach().cpu().numpy(),
+                filtered_observations[-1].cpu().numpy(),
+                plotted_dates,
+                dataset.obs_reader.gage_dict["STAID"][random_gage],
+                dataset.obs_reader.gage_dict["STANAME"][random_gage],
+                metrics={"nse": pred_nse[-1]},
+                path=cfg.params.save_path / f"plots/epoch_{epoch}_mb_{i}_validation_plot.png",
+                warmup=cfg.train.warmup,
+            )
+
+            save_state(
+                epoch=epoch,
+                mini_batch=i,
+                mlp=nn,
+                optimizer=optimizer,
+                name=cfg.name,
+                saved_model_path=cfg.params.save_path / "saved_models",
+            )
+
+            print(f"Loss: {loss.item()}")
+            print(f"Median NSE: {median_nse}")

         if epoch in cfg.train.learning_rate.keys():
             log.info(f"Updating learning rate: {cfg.train.learning_rate[epoch]}")
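`Metrics` comes from `ddr.analysis.metrics`, which this diff does not show. For reference, a per-gauge Nash-Sutcliffe efficiency of the kind the `nse` property presumably exposes; the function below is an illustrative stand-in, not the library's API:

    import numpy as np


    def nse(pred: np.ndarray, target: np.ndarray) -> np.ndarray:
        # NSE = 1 - SSE / SST per gauge; rows are gauges, columns are timesteps.
        sse = np.sum((target - pred) ** 2, axis=1)
        sst = np.sum((target - target.mean(axis=1, keepdims=True)) ** 2, axis=1)
        return 1.0 - sse / sst

A gauge with constant observations drives `sst` to zero and the ratio to `inf`/`nan`, which is why the loop filters those values before taking the median.
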
@@ -101,6 +164,9 @@ def train(cfg, flow, routing_model, nn):
 )
 def main(cfg: DictConfig) -> None:
     _set_seed(cfg=cfg)
+    cfg.params.save_path = Path(HydraConfig.get().run.dir)
+    (cfg.params.save_path / "plots").mkdir(exist_ok=True)
+    (cfg.params.save_path / "saved_models").mkdir(exist_ok=True)
     try:
         start_time = time.perf_counter()
         nn = kan(
@@ -111,7 +177,8 @@ def main(cfg: DictConfig) -> None:
             num_hidden_layers=cfg.kan.num_hidden_layers,
             grid=cfg.kan.grid,
             k=cfg.kan.k,
-            seed=cfg.seed
+            seed=cfg.seed,
+            device=cfg.device
         )
         routing_model = dmc(
             cfg=cfg,
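`save_state` is imported from `ddr.analysis.utils` and is also outside this diff. For it to round-trip with the checkpoint load in `train`, it would need to persist at least the keys read there. A sketch under that assumption (the filename pattern is invented for illustration):

    from pathlib import Path

    import torch


    def save_state(epoch, mini_batch, mlp, optimizer, name, saved_model_path: Path) -> None:
        # Persist everything the resume branch in train() reads back.
        state = {
            "epoch": epoch,
            "mini_batch": mini_batch,
            "model_state_dict": mlp.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "rng_state": torch.get_rng_state(),
        }
        if torch.cuda.is_available():
            state["cuda_rng_state"] = torch.cuda.get_rng_state()
        torch.save(state, saved_model_path / f"{name}_epoch_{epoch}_mb_{mini_batch}.pt")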