
Commit 424acf6

merging conflicts

2 parents 872bb53 + 0d4ac46

File tree: 15 files changed, +624 −45 lines

config/training_config.yaml — 8 additions & 4 deletions

@@ -10,7 +10,7 @@ data_sources:
   conus_hydrofabric: /projects/mhpi/data/hydrofabric/v2.2/conus_nextgen.gpkg
   local_hydrofabric: /projects/mhpi/data/hydrofabric/v2.2/jrb_2.gpkg
   network: /projects/mhpi/tbindas/ddr/data/network.zarr
-  transition_matrix: /projects/mhpi/data/hydrofabric/v2.2/jrb_transition_matrix.csv
+  transition_matrix: /projects/mhpi/data/hydrofabric/v2.2/conus_transition_matrices.zarr
   statistics: /projects/mhpi/tbindas/ddr/data/statistics
   streamflow: /projects/mhpi/data/MERIT/streamflow/zarr/${forcings}
   observations: /projects/mhpi/data/observations/gages_9000.zarr
@@ -21,10 +21,10 @@ train:
   start_time: 1981/10/01
   end_time: 1995/09/30
   checkpoint: null
-  spatial_checkpoint: null
+  spatial_checkpoint: /projects/mhpi/tbindas/ddr/runs/0.1.0-ddr_jrb-merit_conus_v6.18_snow/2025-02-19_09-26-46/saved_models/_0.1.0-ddr_jrb-merit_conus_v6.18_snow_epoch_2_mb_0.pt
   leakance_checkpoint: null
   dropout_threshold: null
-  epochs: 3
+  epochs: 100
   learning_rate:
     '0': 0.005
     '3': 0.001
@@ -37,6 +37,8 @@ train:
     - 1.0
   rho: 365
   shuffle: true
+  warmup: 3
+
 params:
   attributes:
     - mean.impervious
@@ -60,10 +62,12 @@ params:
     - 3.0
   defaults:
     p: 21
+    tau: 3
+  save_path: ./

 np_seed: 1
 seed: 0
-device: cpu # mps:0
+device: 0 # mps:0

 kan:
   hidden_size: 11
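
The learning_rate map reads as an epoch-indexed schedule: the rate under key '0' applies from the start, and it switches when training reaches a later key (scripts/train.py below updates the optimizer on exactly that condition). A minimal sketch of the lookup, assuming string epoch keys as in the YAML; lr_for_epoch is a hypothetical helper, not code from this commit:

# Hypothetical helper illustrating the epoch-keyed schedule above.
def lr_for_epoch(schedule: dict[str, float], epoch: int) -> float:
    # Use the rate whose epoch key is the largest one not exceeding `epoch`.
    keys = sorted(int(k) for k in schedule)
    active = [k for k in keys if k <= epoch]
    return schedule[str(active[-1] if active else keys[0])]

schedule = {"0": 0.005, "3": 0.001}
assert lr_for_epoch(schedule, 2) == 0.005
assert lr_for_epoch(schedule, 5) == 0.001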

engine/weighted_transfer.py — 89 additions & 0 deletions (new file)

@@ -0,0 +1,89 @@
+#!/usr/bin/env python
+
+"""
+@author Tadd Bindas
+
+@date February 17, 2025
+@version 0.2
+
+A script to find the weighted intersection of MERIT basins with CONUS catchments
+"""
+
+from pathlib import Path
+
+import geopandas as gpd
+import numpy as np
+import pandas as pd
+from scipy import sparse
+import zarr
+
+zone = "73"
+path_1 = f"/projects/mhpi/data/MERIT/raw/basins/cat_pfaf_{zone}_MERIT_Hydro_v07_Basins_v01_bugfix1.shp"
+# path_2 = "/projects/mhpi/data/hydrofabric/v2.2/jrb_2.gpkg"
+path_2 = "/projects/mhpi/data/hydrofabric/v2.2/conus_nextgen.gpkg"
+out_path = Path("/projects/mhpi/data/hydrofabric/v2.2/conus_transition_matrices.zarr")
+
+print("Reading shp files")
+gdf1 = gpd.read_file(path_1).set_crs(epsg=4326).to_crs(epsg=5070)
+gdf2 = gpd.read_file(path_2, layer="divides").to_crs(epsg=5070)
+
+gdf1['gdf1_orig_area'] = gdf1.geometry.area
+gdf2['gdf2_orig_area'] = gdf2.geometry.area
+
+print("Running gdf intersection")
+intersection = gpd.overlay(gdf1, gdf2, how='intersection')
+intersection['intersection_area'] = intersection.geometry.area
+intersection['gdf1_pct'] = intersection['intersection_area'] / intersection['gdf1_orig_area']
+
+print("Generating weighted transfer matrix")
+weight_matrix = pd.pivot_table(intersection,
+                               values='gdf1_pct',
+                               index='COMID',  # MERIT basin id column (from gdf1)
+                               columns='divide_id',  # hydrofabric divide id column (from gdf2)
+                               fill_value=0)
+
+print("Saving to sparse zarr store")
+store = zarr.storage.LocalStore(root=out_path)
+if out_path.exists():
+    root = zarr.open_group(store=store)
+else:
+    root = zarr.create_group(store=store)
+
+coo = sparse.coo_matrix(weight_matrix.to_numpy())
+
+comid_order = np.array([int(float(_id.split("-")[1])) for _id in weight_matrix.columns.to_numpy()], dtype=np.int32)
+merit_basin_order = weight_matrix.index.to_numpy().astype(np.int32)
+
+gauge_root = root.create_group(name=zone)
+indices_0 = gauge_root.create_array(
+    name='indices_0', shape=coo.row.shape, dtype=coo.row.dtype
+)
+indices_1 = gauge_root.create_array(
+    name='indices_1', shape=coo.col.shape, dtype=coo.col.dtype
+)
+values = gauge_root.create_array(
+    name='values', shape=coo.data.shape, dtype=coo.data.dtype
+)
+comid_zarr_order = gauge_root.create_array(
+    name='comid_order', shape=comid_order.shape, dtype=comid_order.dtype
+)
+merit_basins_zarr_order = gauge_root.create_array(
+    name='merit_basins_order', shape=merit_basin_order.shape, dtype=merit_basin_order.dtype
+)
+indices_0[:] = coo.row
+indices_1[:] = coo.col
+values[:] = coo.data
+comid_zarr_order[:] = comid_order
+merit_basins_zarr_order[:] = merit_basin_order
+
+gauge_root.attrs["format"] = "COO"
+gauge_root.attrs["shape"] = list(coo.shape)
+gauge_root.attrs["data_types"] = {
+    "indices_0": coo.row.dtype.__str__(),
+    "indices_1": coo.col.dtype.__str__(),
+    "values": coo.data.dtype.__str__(),
+}
+print(f"{out_path} written to zarr")
+
+# weight_matrix.to_csv("/projects/mhpi/data/hydrofabric/v2.2/73_conus_transition_matrix.csv")
+# print("Created transition matrix @ /projects/mhpi/data/hydrofabric/v2.2/73_conus_transition_matrix.csv")

pyproject.toml — 3 additions & 4 deletions

@@ -24,7 +24,7 @@ maintainers = [
 ]
 
 dependencies = [
-    "numpy==2.2.2",
+    "numpy==2.2.3",
     "pandas==2.2.3",
     "geopandas==1.0.1",
     "pydantic==2.10.6",
@@ -34,12 +34,11 @@ dependencies = [
     "hydra-core==1.3.2",
    "tqdm==4.67.1",
     "polars==1.21.0",
-    "zarr==3.0.2",
+    "zarr==3.0.3",
     "sympy==1.13.1",
     "pykan==0.2.8",
     "scikit-learn==1.6.1",
     "matplotlib==3.10.0",
-    "binsparse @ git+https://github.com/ivirshup/binsparse-python.git@main",
 ]
 
 [project.optional-dependencies]
@@ -86,4 +85,4 @@ explicit = true
 [[tool.uv.index]]
 name = "pytorch-cu124"
 url = "https://download.pytorch.org/whl/cu124"
-explicit = true
+explicit = true

scripts/train.py — 77 additions & 10 deletions

@@ -1,10 +1,12 @@
 import logging
 import random
 import time
+from pathlib import Path
 
 import hydra
 import numpy as np
 import torch
+from hydra.core.hydra_config import HydraConfig
 from omegaconf import DictConfig
 from torch.utils.data import DataLoader
 from torch.nn.functional import mse_loss
@@ -15,6 +17,9 @@
 from ddr.dataset.utils import downsample
 from ddr.dataset.streamflow import StreamflowReader as streamflow
 from ddr.dataset.train_dataset import train_dataset
+from ddr.analysis.metrics import Metrics
+from ddr.analysis.plots import plot_time_series
+from ddr.analysis.utils import save_state
 
 log = logging.getLogger(__name__)
 
@@ -39,22 +44,46 @@ def train(cfg, flow, routing_model, nn):
         drop_last=True,
     )
 
-    optimizer = torch.optim.Adam(params=nn.parameters(), lr=cfg.train.learning_rate[str(0)])
+    if cfg.train.spatial_checkpoint:
+        file_path = Path(cfg.train.spatial_checkpoint)
+        log.info(f"Loading spatial_nn from checkpoint: {file_path.stem}")
+        state = torch.load(file_path)
+        state_dict = state["model_state_dict"]
+        for key in state_dict.keys():
+            state_dict[key] = state_dict[key].to(cfg.device)
+        nn.load_state_dict(state["model_state_dict"])
+        torch.set_rng_state(state["rng_state"])
+        start_epoch = state["epoch"]
+        # start_mini_batch = 0 if state["mini_batch"] == 0 else state["mini_batch"] + 1  # Start from the next mini-batch
+        if torch.cuda.is_available() and "cuda_rng_state" in state:
+            torch.cuda.set_rng_state(state["cuda_rng_state"])
+        if start_epoch in cfg.train.learning_rate.keys():
+            lr = cfg.train.learning_rate[start_epoch]
+        else:
+            key_list = list(cfg.train.learning_rate.keys())
+            lr = cfg.train.learning_rate[key_list[-1]]
+    else:
+        log.info("Creating new spatial model")
+        start_epoch = 1
+        # start_mini_batch = 0
+        lr = cfg.train.learning_rate[str(0)]
 
-    for epoch in range(0, cfg.train.epochs + 1):
+    optimizer = torch.optim.Adam(params=nn.parameters(), lr=lr)
+
+    for epoch in range(start_epoch, cfg.train.epochs + 1):
         routing_model.epoch = epoch
         for i, hydrofabric in enumerate(dataloader, start=0):
             routing_model.mini_batch = i
 
             streamflow_predictions = flow(cfg=cfg, hydrofabric=hydrofabric)
-            q_prime = streamflow_predictions["streamflow"] @ torch.tensor(hydrofabric.transition_matrix.to_numpy(), dtype=torch.float32, device=cfg.device)
+            q_prime = streamflow_predictions["streamflow"] @ hydrofabric.transition_matrix
             spatial_params = nn(
                 inputs=hydrofabric.normalized_spatial_attributes.to(cfg.device)
             )
             dmc_kwargs = {
                 "hydrofabric": hydrofabric,
                 "spatial_parameters": spatial_params,
-                "streamflow": q_prime,
+                "streamflow": torch.tensor(q_prime, device=cfg.device, dtype=torch.float32)
             }
             dmc_output = routing_model(**dmc_kwargs)
 
@@ -68,24 +97,58 @@ def train(cfg, flow, routing_model, nn):
             np_nan_mask = nan_mask.streamflow.values
 
             filtered_ds = hydrofabric.observations.where(~nan_mask, drop=True)
-            filtered_observations = torch.tensor(filtered_ds.streamflow.values, device=cfg.device)[
+            filtered_observations = torch.tensor(filtered_ds.streamflow.values, device=cfg.device, dtype=torch.float32)[
                 :, 1:-1
             ]  # Cutting off days to match with realigned timesteps
 
             filtered_predictions = daily_runoff[~np_nan_mask]
 
             loss = mse_loss(
-                input=filtered_predictions.transpose(0, 1)[cfg.warmup:].unsqueeze(2),
-                target=filtered_observations.transpose(0, 1)[cfg.warmup:].unsqueeze(2),
+                input=filtered_predictions.transpose(0, 1)[cfg.train.warmup:].unsqueeze(2),
+                target=filtered_observations.transpose(0, 1)[cfg.train.warmup:].unsqueeze(2),
             )
 
-            log.info("Running gradient-averaged backpropagation")
+            log.info("Running backpropagation")
 
             loss.backward()
             optimizer.step()
             optimizer.zero_grad()
 
-            print(f"Loss: {loss.item}")
+            np_pred = filtered_predictions.detach().cpu().numpy()
+            np_target = filtered_observations.detach().cpu().numpy()
+            plotted_dates = dataset.dates.batch_daily_time_range[
+                1:-1
+            ]
+            metrics = Metrics(pred=np_pred, target=np_target)
+            pred_nse = metrics.nse
+            pred_nse_filtered = pred_nse[~np.isinf(pred_nse) & ~np.isnan(pred_nse)]
+            median_nse = torch.tensor(pred_nse_filtered).median()
+
+            # TODO: scale out when we have more gauges
+            # random_index = np.random.randint(low=0, high=filtered_observations.shape[0], size=(1,))[0]
+            random_gage = -1
+            plot_time_series(
+                filtered_predictions[-1].detach().cpu().numpy(),
+                filtered_observations[-1].cpu().numpy(),
+                plotted_dates,
+                dataset.obs_reader.gage_dict["STAID"][random_gage],
+                dataset.obs_reader.gage_dict["STANAME"][random_gage],
+                metrics={"nse": pred_nse[-1]},
+                path=cfg.params.save_path / f"plots/epoch_{epoch}_mb_{i}_validation_plot.png",
+                warmup=cfg.train.warmup,
+            )
+
+            save_state(
+                epoch=epoch,
+                mini_batch=i,
+                mlp=nn,
+                optimizer=optimizer,
+                name=cfg.name,
+                saved_model_path=cfg.params.save_path / "saved_models",
+            )
+
+            print(f"Loss: {loss.item()}")
+            print(f"Median NSE: {median_nse}")
 
         if epoch in cfg.train.learning_rate.keys():
             log.info(f"Updating learning rate: {cfg.train.learning_rate[epoch]}")
@@ -101,6 +164,9 @@ def train(cfg, flow, routing_model, nn):
 )
 def main(cfg: DictConfig) -> None:
     _set_seed(cfg=cfg)
+    cfg.params.save_path = Path(HydraConfig.get().run.dir)
+    (cfg.params.save_path / "plots").mkdir(exist_ok=True)
+    (cfg.params.save_path / "saved_models").mkdir(exist_ok=True)
     try:
         start_time = time.perf_counter()
         nn = kan(
@@ -111,7 +177,8 @@ def main(cfg: DictConfig) -> None:
             num_hidden_layers=cfg.kan.num_hidden_layers,
             grid=cfg.kan.grid,
             k=cfg.kan.k,
-            seed=cfg.seed
+            seed=cfg.seed,
+            device=cfg.device
         )
         routing_model = dmc(
             cfg=cfg,
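
The resume branch above expects a checkpoint dict with model_state_dict, rng_state, and epoch keys (plus optional cuda_rng_state and mini_batch). The save_state writer it pairs with is not shown in this commit; below is a minimal sketch of one that would satisfy the loader, with the file name pattern inferred from the spatial_checkpoint path in training_config.yaml. This is an assumption, not the repository's implementation:

# Hypothetical save_state sketch; writes the keys train() reads on resume.
from pathlib import Path

import torch

def save_state(epoch, mini_batch, mlp, optimizer, name, saved_model_path: Path):
    state = {
        "model_state_dict": mlp.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "rng_state": torch.get_rng_state(),
        "epoch": epoch,
        "mini_batch": mini_batch,
    }
    if torch.cuda.is_available():
        state["cuda_rng_state"] = torch.cuda.get_rng_state()
    # e.g. _<name>_epoch_2_mb_0.pt, matching the checkpoint path in the config
    torch.save(state, saved_model_path / f"_{name}_epoch_{epoch}_mb_{mini_batch}.pt")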

src/ddr/__init__.py — 2 additions & 1 deletion

@@ -1,5 +1,6 @@
 from ddr.nn.kan import kan
 from ddr.routing.dmc import dmc
 from ddr.dataset.streamflow import StreamflowReader
+from ddr.analysis.metrics import Metrics
 
-__all__ = ["dmc", "kan", "StreamflowReader"]
+__all__ = ["dmc", "kan", "StreamflowReader", "Metrics"]

src/ddr/analysis/__init__.py — whitespace-only changes
