Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions basicts/metrics/corr.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ def masked_corr(prediction: torch.Tensor, target: torch.Tensor, null_val: float

"""

if len(prediction.shape) == 4: # (Bs, L, N, 1) else (Bs, N, 1)
prediction = torch.mean(prediction, dim=1)
target = torch.mean(target, dim=1)

if np.isnan(null_val):
mask = ~torch.isnan(target)
else:
Expand All @@ -32,18 +36,17 @@ def masked_corr(prediction: torch.Tensor, target: torch.Tensor, null_val: float
mask /= torch.mean(mask) # Normalize mask to avoid bias in the loss due to the number of valid entries
mask = torch.nan_to_num(mask) # Replace any NaNs in the mask with zero

prediction_mean = torch.mean(prediction, dim=1, keepdim=True)
target_mean = torch.mean(target, dim=1, keepdim=True)
prediction_mean = torch.mean(prediction, dim=0, keepdim=True)
target_mean = torch.mean(target, dim=0, keepdim=True)

# 计算偏差 (X - mean_X) 和 (Y - mean_Y)
prediction_dev = prediction - prediction_mean
target_dev = target - target_mean

# 计算皮尔逊相关系数
numerator = torch.sum(prediction_dev * target_dev, dim=1, keepdim=True) # 分子
denominator = torch.sqrt(torch.sum(prediction_dev ** 2, dim=1, keepdim=True) * torch.sum(target_dev ** 2, dim=1, keepdim=True)) # 分母
numerator = torch.sum(prediction_dev * target_dev, dim=0, keepdim=True) # 分子
denominator = torch.sqrt(torch.sum(prediction_dev ** 2, dim=0, keepdim=True) * torch.sum(target_dev ** 2, dim=0, keepdim=True)) # 分母
loss = numerator / denominator

loss = loss * mask # Apply the mask to the loss
loss = torch.nan_to_num(loss) # Replace any NaNs in the loss with zero

Expand Down
9 changes: 6 additions & 3 deletions basicts/metrics/r_square.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ def masked_r2(prediction: torch.Tensor, target: torch.Tensor, null_val: float =

"""

if len(prediction.shape) == 4: # (Bs, L, N, 1) else (Bs, N, 1)
prediction = torch.mean(prediction, dim=1)
target = torch.mean(target, dim=1)

eps = 5e-5
if np.isnan(null_val):
mask = ~torch.isnan(target)
Expand All @@ -34,11 +38,10 @@ def masked_r2(prediction: torch.Tensor, target: torch.Tensor, null_val: float =
prediction = torch.nan_to_num(prediction)
target = torch.nan_to_num(target)

ss_res = torch.sum(torch.pow((target - prediction), 2), dim=1) # 残差平方和
ss_tot = torch.sum(torch.pow(target - torch.mean(target, dim=1, keepdim=True), 2), dim=1) # 总平方和
ss_res = torch.sum(torch.pow((target - prediction), 2), dim=0) # 残差平方和
ss_tot = torch.sum(torch.pow(target - torch.mean(target, dim=0, keepdim=True), 2), dim=0) # 总平方和

# 计算 R^2
loss = 1 - (ss_res / (ss_tot + eps))

loss = torch.nan_to_num(loss) # Replace any NaNs in the loss with zero
return torch.mean(loss)
Loading