Skip to content

Commit be7adcd

Browse files
committed
added warm up and fix seed functions
1 parent b55f92a commit be7adcd

File tree

4 files changed

+144
-2
lines changed

4 files changed

+144
-2
lines changed

example/tutorial_quick_start.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@
8282
c_val = val_csv.load_constant(var_constant, convert_time_series=False)
8383
y_val = val_csv.load_time_series(target, remove_nan=False)
8484

85-
val_epoch = 100 # Select the epoch for testing
85+
val_epoch = EPOCH # Select the epoch for testing
8686

8787
# load the model
8888
test_model = loadModel(output_s, epoch=val_epoch)
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
import os
2+
import torch
3+
import numpy as np
4+
5+
import sys
6+
sys.path.append('..')
7+
8+
from hydroDL.master.master import loadModel
9+
from hydroDL.model.crit import RmseLoss
10+
from hydroDL.model.rnn import CudnnLstmModel as LSTM
11+
from hydroDL.model.rnn import CpuLstmModel as LSTM_CPU
12+
from hydroDL.model.train import trainModel
13+
from hydroDL.model.test import testModel
14+
from hydroDL.post.stat import statError as cal_metric
15+
from hydroDL.data.load_csv import LoadCSV
16+
from hydroDL.utils.norm import re_folder, trans_norm
17+
from hydroDL.utils.norm import fix_seed
18+
19+
# set configuration
20+
fix_seed(42)
21+
output_s = "./output/quick_start/" # output path
22+
csv_path_s = "./demo_data/" # demo data path
23+
all_date_list = ["2015-04-01", "2017-03-31"] # demo data time period
24+
train_date_list = ["2015-04-01", "2016-03-31"] # training period
25+
# time series variables list
26+
var_time_series = ["VGRD_10_FORA", "DLWRF_FORA", "UGRD_10_FORA", "DSWRF_FORA", "TMP_2_FORA", "SPFH_2_FORA", "APCP_FORA", ]
27+
# constant variables list
28+
var_constant = ["flag_extraOrd", "Clay", "Bulk", "Sand", "flag_roughness", "flag_landcover", "flag_vegDense", "Silt", "NDVI",
29+
"flag_albedo", "flag_waterbody", "Capa", ]
30+
# target variable list
31+
target = ["SMAP_AM"]
32+
33+
# generate output folder
34+
re_folder(output_s)
35+
36+
# hyperparameter
37+
EPOCH = 20
38+
BATCH_SIZE = 50
39+
RHO = 30
40+
HIDDEN_SIZE = 256
41+
WARM_UP_DAY = 10
42+
# WARM_UP_DAY = None
43+
44+
# load your datasets
45+
"""
46+
You can change it with your data. The data structure is as follows:
47+
x_train (forcing data, e.g. precipitation, temperature ...): [pixels, time, features]
48+
c_train (constant data, e.g. soil properties, land cover ...): [pixels, features]
49+
target (e.g. soil moisture, streamflow ...): [pixels, time, 1]
50+
51+
Data type: numpy.float
52+
"""
53+
train_csv = LoadCSV(csv_path_s, train_date_list, all_date_list)
54+
x_train = train_csv.load_time_series(var_time_series) # data size: [pixels, time, features]
55+
c_train = train_csv.load_constant(var_constant, convert_time_series=False) # [pixels, features]
56+
y_train = train_csv.load_time_series(target, remove_nan=False) # [pixels, time, 1]
57+
58+
# define model and loss function
59+
loss_fn = RmseLoss() # loss function
60+
# select model: GPU or CPU
61+
if torch.cuda.is_available():
62+
LSTM = LSTM
63+
else:
64+
LSTM = LSTM_CPU
65+
model = LSTM(nx=len(var_time_series) + len(var_constant), ny=len(target), hiddenSize=HIDDEN_SIZE, warmUpDay=WARM_UP_DAY)
66+
67+
# training the model
68+
last_model = trainModel(
69+
model,
70+
x_train,
71+
y_train,
72+
c_train,
73+
loss_fn,
74+
nEpoch=EPOCH,
75+
miniBatch=[BATCH_SIZE, RHO],
76+
saveEpoch=1,
77+
saveFolder=output_s,
78+
)
79+
80+
# validation the result
81+
# load validation datasets
82+
val_date_list = ["2016-04-01", "2017-03-31"] # validation period
83+
# load your data. same as training data
84+
val_csv = LoadCSV(csv_path_s, val_date_list, all_date_list)
85+
x_val = val_csv.load_time_series(var_time_series)
86+
c_val = val_csv.load_constant(var_constant, convert_time_series=False)
87+
y_val = val_csv.load_time_series(target, remove_nan=False)
88+
89+
val_epoch = EPOCH # Select the epoch for testing
90+
91+
# load the model
92+
test_model = loadModel(output_s, epoch=val_epoch)
93+
94+
# set the path to save result
95+
save_csv = os.path.join(output_s, "predict.csv")
96+
97+
# validation
98+
pred_val = testModel(test_model, x_val, c_val, batchSize=len(x_train), filePathLst=[save_csv],)
99+
100+
# select the metrics
101+
metrics_list = ["Bias", "RMSE", "ubRMSE", "Corr"]
102+
pred_val = pred_val.numpy()
103+
# denormalization
104+
pred_val = trans_norm(pred_val, csv_path_s, var_s=target[0], from_raw=False)
105+
y_val = trans_norm(y_val, csv_path_s, var_s=target[0], from_raw=False)
106+
pred_val, y_val = np.squeeze(pred_val), np.squeeze(y_val)
107+
metrics_dict = cal_metric(pred_val, y_val) # calculate the metrics
108+
metrics = ["Median {}: {:.4f}".format(x, np.nanmedian(metrics_dict[x])) for x in metrics_list]
109+
print("Epoch {}: {}".format(val_epoch, metrics))

hydroDL/model/rnn/CudnnLstmModel.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313

1414
class CudnnLstmModel(torch.nn.Module):
15-
def __init__(self, *, nx, ny, hiddenSize, dr=0.5):
15+
def __init__(self, *, nx, ny, hiddenSize, dr=0.5, warmUpDay=None):
1616
super(CudnnLstmModel, self).__init__()
1717
self.nx = nx
1818
self.ny = ny
@@ -32,6 +32,7 @@ def __init__(self, *, nx, ny, hiddenSize, dr=0.5):
3232
self.name = "CudnnLstmModel"
3333
self.is_legacy = True
3434
# self.drtest = torch.nn.Dropout(p=0.4)
35+
self.warmUpDay = warmUpDay
3536

3637
def forward(self, x, doDropMC=False, dropoutFalse=False):
3738
"""
@@ -41,6 +42,9 @@ def forward(self, x, doDropMC=False, dropoutFalse=False):
4142
:param dropoutFalse:
4243
:return:
4344
"""
45+
if not self.warmUpDay is None:
46+
x, warmUpDay = self.extend_day(x, warmUpDay=self.warmUpDay)
47+
4448
x0 = F.relu(self.linearIn(x))
4549
if torch.__version__ > "1.9":
4650
outLSTM, (hn, cn) = self.lstm(x0)
@@ -50,4 +54,19 @@ def forward(self, x, doDropMC=False, dropoutFalse=False):
5054
)
5155
# outLSTMdr = self.drtest(outLSTM)
5256
out = self.linearOut(outLSTM)
57+
58+
if not self.warmUpDay is None:
59+
out = self.reduce_day(out, warmUpDay=warmUpDay)
60+
5361
return out
62+
63+
def extend_day(self, x, warm_up_day):
64+
x_num_day = x.shape[0]
65+
warm_up_day = min(x_num_day, warm_up_day)
66+
x_select = x[:warm_up_day, :, :]
67+
x = torch.cat([x_select, x], dim=0)
68+
return x, warm_up_day
69+
70+
def reduce_day(self, x, warm_up_day):
71+
x = x[warm_up_day:,:,:]
72+
return x

hydroDL/utils/norm.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,3 +69,17 @@ def re_folder(path_s, del_old_path=False):
6969
pass
7070
else:
7171
re_folder_rec(path_s)
72+
73+
def fix_seed(SEED):
74+
import os
75+
import numpy as np
76+
import random
77+
import torch
78+
np.random.seed(SEED)
79+
random.seed(SEED)
80+
torch.backends.cudnn.deterministic = True
81+
torch.backends.cudnn.benchmark = False
82+
torch.manual_seed(SEED)
83+
torch.cuda.manual_seed(SEED)
84+
torch.cuda.manual_seed_all(SEED)
85+
os.environ["PYTHONHASHSEED"] = str(SEED)

0 commit comments

Comments
 (0)