-
Notifications
You must be signed in to change notification settings - Fork 28
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Alexander März
committed
Aug 8, 2023
1 parent
3583fc9
commit 8fb028b
Showing
27 changed files
with
1,439 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
"""LightGBMLSS - An extension of LightGBM to probabilistic forecasting""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
"""LightGBMLSS - An extension of LightGBM to probabilistic forecasting""" |
24 changes: 24 additions & 0 deletions
24
tests/test_distribution_utils/test_calculate_start_values.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
from ..utils import BaseTestClass, gen_test_data | ||
import numpy as np | ||
|
||
|
||
class TestClass(BaseTestClass): | ||
def test_calculate_start_values(self, dist_class, loss_fn): | ||
# Create data for testing | ||
_, target, _ = gen_test_data(dist_class) | ||
|
||
# Set the loss function for testing | ||
dist_class.dist.loss_fn = loss_fn | ||
|
||
# Call the objective_fn method | ||
loss, start_values = dist_class.dist.calculate_start_values(target) | ||
|
||
# Assertions | ||
assert isinstance(loss, np.ndarray) | ||
assert not np.isnan(loss).any() | ||
assert not np.isinf(loss).any() | ||
|
||
assert isinstance(start_values, np.ndarray) | ||
assert start_values.shape[0] == dist_class.dist.n_dist_param | ||
assert not np.isnan(start_values).any() | ||
assert not np.isinf(start_values).any() |
115 changes: 115 additions & 0 deletions
115
tests/test_distribution_utils/test_compute_gradients_and_hessians.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
from ..utils import BaseTestClass, gen_test_data | ||
from typing import List | ||
import numpy as np | ||
import torch | ||
|
||
|
||
class TestClass(BaseTestClass): | ||
def test_compute_gradients_and_hessians(self, dist_class, loss_fn, stabilization): | ||
# Create data for testing | ||
params, target, weights, _ = gen_test_data(dist_class, weights=True) | ||
if dist_class.dist.univariate: | ||
target = torch.tensor(target) | ||
else: | ||
target = torch.tensor(target)[:, :dist_class.dist.n_targets] | ||
start_values = np.array([0.5 for _ in range(dist_class.dist.n_dist_param)]) | ||
|
||
# Set the loss function for testing | ||
dist_class.dist.loss_fn = loss_fn | ||
|
||
# Set the stabilization for testing | ||
dist_class.dist.stabilization = stabilization | ||
|
||
# Call the function | ||
predt, loss = dist_class.dist.get_params_loss(params, target, start_values, requires_grad=True) | ||
grad, hess = dist_class.dist.compute_gradients_and_hessians(loss, predt, weights) | ||
|
||
# Assertions | ||
assert isinstance(predt, List) | ||
for i in range(len(predt)): | ||
assert isinstance(predt[i], torch.Tensor) | ||
assert not torch.isnan(predt[i]).any() | ||
assert not torch.isinf(predt[i]).any() | ||
assert isinstance(loss, torch.Tensor) | ||
assert not torch.isnan(loss).any() | ||
assert not torch.isinf(loss).any() | ||
|
||
assert isinstance(grad, np.ndarray) | ||
assert isinstance(hess, np.ndarray) | ||
assert grad.shape == params.flatten().shape | ||
assert hess.shape == params.flatten().shape | ||
assert not np.isnan(grad).any() | ||
assert not np.isnan(hess).any() | ||
|
||
def test_compute_gradients_and_hessians_crps(self, dist_class_crps, stabilization): | ||
# Create data for testing | ||
params, target, weights, _ = gen_test_data(dist_class_crps, weights=True) | ||
if dist_class_crps.dist.univariate: | ||
target = torch.tensor(target) | ||
else: | ||
target = torch.tensor(target)[:, :dist_class_crps.dist.n_targets] | ||
start_values = np.array([0.5 for _ in range(dist_class_crps.dist.n_dist_param)]) | ||
|
||
# Set the loss function for testing | ||
dist_class_crps.dist.loss_fn = "crps" | ||
|
||
# Set the stabilization for testing | ||
dist_class_crps.dist.stabilization = stabilization | ||
|
||
# Call the function | ||
predt, loss = dist_class_crps.dist.get_params_loss(params, target, start_values, requires_grad=True) | ||
grad, hess = dist_class_crps.dist.compute_gradients_and_hessians(loss, predt, weights) | ||
|
||
# Assertions | ||
assert isinstance(predt, List) | ||
for i in range(len(predt)): | ||
assert isinstance(predt[i], torch.Tensor) | ||
assert not torch.isnan(predt[i]).any() | ||
assert not torch.isinf(predt[i]).any() | ||
assert isinstance(loss, torch.Tensor) | ||
assert not torch.isnan(loss).any() | ||
assert not torch.isinf(loss).any() | ||
|
||
assert isinstance(grad, np.ndarray) | ||
assert isinstance(hess, np.ndarray) | ||
assert grad.shape == params.flatten().shape | ||
assert hess.shape == params.flatten().shape | ||
assert not np.isnan(grad).any() | ||
assert not np.isnan(hess).any() | ||
|
||
def test_compute_gradients_and_hessians_nans(self, dist_class, loss_fn, stabilization): | ||
# Create data for testing | ||
params, target, weights, _ = gen_test_data(dist_class, weights=True) | ||
params[0, 0] = np.nan | ||
if dist_class.dist.univariate: | ||
target = torch.tensor(target) | ||
else: | ||
target = torch.tensor(target)[:, :dist_class.dist.n_targets] | ||
start_values = np.array([0.5 for _ in range(dist_class.dist.n_dist_param)]) | ||
|
||
# Set the loss function for testing | ||
dist_class.dist.loss_fn = loss_fn | ||
|
||
# Set the stabilization for testing | ||
dist_class.dist.stabilization = stabilization | ||
|
||
# Call the function | ||
predt, loss = dist_class.dist.get_params_loss(params, target, start_values, requires_grad=True) | ||
grad, hess = dist_class.dist.compute_gradients_and_hessians(loss, predt, weights) | ||
|
||
# Assertions | ||
assert isinstance(predt, List) | ||
for i in range(len(predt)): | ||
assert isinstance(predt[i], torch.Tensor) | ||
assert not torch.isnan(predt[i]).any() | ||
assert not torch.isinf(predt[i]).any() | ||
assert isinstance(loss, torch.Tensor) | ||
assert not torch.isnan(loss).any() | ||
assert not torch.isinf(loss).any() | ||
|
||
assert isinstance(grad, np.ndarray) | ||
assert isinstance(hess, np.ndarray) | ||
assert grad.shape == params.flatten().shape | ||
assert hess.shape == params.flatten().shape | ||
assert not np.isnan(grad).any() | ||
assert not np.isnan(hess).any() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
from ..utils import BaseTestClass | ||
import torch | ||
|
||
|
||
class TestClass(BaseTestClass): | ||
def test_crps_score(self, dist_class_crps): | ||
# Create data for testing | ||
torch.manual_seed(123) | ||
n_obs = 10 | ||
n_samples = 20 | ||
y = torch.rand(n_obs, 1) | ||
yhat_dist = torch.rand(n_samples, n_obs) | ||
|
||
# Call the function | ||
loss = dist_class_crps.dist.crps_score(y, yhat_dist) | ||
|
||
# Assertions | ||
assert isinstance(loss, torch.Tensor) | ||
assert not torch.isnan(loss).any() | ||
assert not torch.isinf(loss).any() | ||
assert loss.shape == y.shape |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
from ..utils import BaseTestClass | ||
|
||
from lightgbmlss.distributions import Beta, Gaussian, StudentT, Gamma, Cauchy, LogNormal, Weibull, Gumbel, Laplace | ||
from lightgbmlss.distributions.SplineFlow import * | ||
from lightgbmlss.distributions.distribution_utils import DistributionClass as univariate_dist_class | ||
from lightgbmlss.distributions.flow_utils import NormalizingFlowClass as flow_dist_class | ||
|
||
import numpy as np | ||
import pandas as pd | ||
|
||
|
||
class TestClass(BaseTestClass): | ||
#################################################################################################################### | ||
# Univariate Distribution | ||
#################################################################################################################### | ||
def test_univar_dist_select(self): | ||
# Create data for testing | ||
target = np.array([0.2, 0.4, 0.6, 0.8]).reshape(-1, 1) | ||
candidate_distributions = [Beta, Gaussian, StudentT, Gamma, Cauchy, LogNormal, Weibull, Gumbel, Laplace] | ||
|
||
# Call the function | ||
dist_df = univariate_dist_class().dist_select( | ||
target, candidate_distributions, n_samples=10, plot=False | ||
).reset_index(drop=True) | ||
|
||
# Assertions | ||
assert isinstance(dist_df, pd.DataFrame) | ||
assert not dist_df.isna().any().any() | ||
assert isinstance(dist_df["distribution"].values[0], str) | ||
assert np.issubdtype(dist_df["nll"].dtype, np.float64) | ||
assert not np.isnan(dist_df["nll"].values).any() | ||
assert not np.isinf(dist_df["nll"].values).any() | ||
|
||
def test_univar_dist_select_plot(self): | ||
# Create data for testing | ||
target = np.array([0.2, 0.4, 0.6, 0.8]).reshape(-1, 1) | ||
candidate_distributions = [Beta, Gaussian, StudentT, Gamma, Cauchy, LogNormal, Weibull, Gumbel, Laplace] | ||
|
||
# Call the function | ||
dist_df = univariate_dist_class().dist_select( | ||
target, candidate_distributions, n_samples=10, plot=True | ||
).reset_index(drop=True) | ||
|
||
# Assertions | ||
assert isinstance(dist_df, pd.DataFrame) | ||
assert not dist_df.isna().any().any() | ||
assert isinstance(dist_df["distribution"].values[0], str) | ||
assert np.issubdtype(dist_df["nll"].dtype, np.float64) | ||
assert not np.isnan(dist_df["nll"].values).any() | ||
assert not np.isinf(dist_df["nll"].values).any() | ||
|
||
#################################################################################################################### | ||
# Normalizing Flows | ||
#################################################################################################################### | ||
def test_flow_select(self): | ||
# Create data for testing | ||
target = np.array([0.2, 0.4, 0.6, 0.8]).reshape(-1, 1) | ||
bound = np.max([np.abs(target.min()), target.max()]) | ||
target_support = "real" | ||
|
||
candidate_flows = [ | ||
SplineFlow(target_support=target_support, count_bins=2, bound=bound, order="linear"), | ||
SplineFlow(target_support=target_support, count_bins=2, bound=bound, order="quadratic") | ||
] | ||
|
||
# Call the function | ||
dist_df = flow_dist_class().flow_select( | ||
target, candidate_flows, n_samples=10, plot=False | ||
).reset_index(drop=True) | ||
|
||
# Assertions | ||
assert isinstance(dist_df, pd.DataFrame) | ||
assert not dist_df.isna().any().any() | ||
assert isinstance(dist_df["NormFlow"].values[0], str) | ||
assert np.issubdtype(dist_df["nll"].dtype, np.float64) | ||
assert not np.isnan(dist_df["nll"].values).any() | ||
assert not np.isinf(dist_df["nll"].values).any() | ||
|
||
def test_flow_select_plot(self): | ||
# Create data for testing | ||
target = np.array([0.2, 0.4, 0.6, 0.8]).reshape(-1, 1) | ||
bound = np.max([np.abs(target.min()), target.max()]) | ||
target_support = "real" | ||
|
||
candidate_flows = [ | ||
SplineFlow(target_support=target_support, count_bins=2, bound=bound, order="linear"), | ||
SplineFlow(target_support=target_support, count_bins=2, bound=bound, order="quadratic") | ||
] | ||
|
||
# Call the function | ||
dist_df = flow_dist_class().flow_select( | ||
target, candidate_flows, n_samples=10, plot=True | ||
).reset_index(drop=True) | ||
|
||
# Assertions | ||
assert isinstance(dist_df, pd.DataFrame) | ||
assert not dist_df.isna().any().any() | ||
assert isinstance(dist_df["NormFlow"].values[0], str) | ||
assert np.issubdtype(dist_df["nll"].dtype, np.float64) | ||
assert not np.isnan(dist_df["nll"].values).any() | ||
assert not np.isinf(dist_df["nll"].values).any() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
from ..utils import BaseTestClass | ||
import pandas as pd | ||
import numpy as np | ||
|
||
|
||
class TestClass(BaseTestClass): | ||
def test_draw_samples(self, dist_class): | ||
# Create data for testing | ||
predt_params = pd.DataFrame(np.array([0.5 for _ in range(dist_class.dist.n_dist_param)], dtype="float32")).T | ||
|
||
# Call the function | ||
dist_samples = dist_class.dist.draw_samples(predt_params) | ||
|
||
# Assertions | ||
if str(dist_class.dist).split(".")[2] != "Expectile": | ||
assert isinstance(dist_samples, (pd.DataFrame, type(None))) | ||
assert not dist_samples.isna().any().any() | ||
assert not np.isinf(dist_samples).any().any() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
from ..utils import BaseTestClass, gen_test_data | ||
from typing import List | ||
import numpy as np | ||
import torch | ||
|
||
|
||
class TestClass(BaseTestClass): | ||
def test_get_params_loss(self, dist_class, loss_fn, requires_grad): | ||
# Create data for testing | ||
predt, target, _ = gen_test_data(dist_class) | ||
target = torch.tensor(target) | ||
start_values = np.array([0.5 for _ in range(dist_class.dist.n_dist_param)]) | ||
|
||
# Set the loss function for testing | ||
dist_class.dist.loss_fn = loss_fn | ||
|
||
# Call the function | ||
predt, loss = dist_class.dist.get_params_loss(predt, target, start_values, requires_grad) | ||
|
||
# Assertions | ||
assert isinstance(predt, List) | ||
for i in range(len(predt)): | ||
assert isinstance(predt[i], torch.Tensor) | ||
assert not torch.isnan(predt[i]).any() | ||
assert not torch.isinf(predt[i]).any() | ||
assert isinstance(loss, torch.Tensor) | ||
assert not torch.isnan(loss).any() | ||
assert not torch.isinf(loss).any() | ||
|
||
def test_get_params_loss_nans(self, dist_class, loss_fn, requires_grad): | ||
# Create data for testing | ||
predt, target, _ = gen_test_data(dist_class) | ||
predt[0, 0] = np.nan | ||
target = torch.tensor(target) | ||
start_values = np.array([0.5 for _ in range(dist_class.dist.n_dist_param)]) | ||
|
||
# Set the loss function for testing | ||
dist_class.dist.loss_fn = loss_fn | ||
|
||
# Call the function | ||
predt, loss = dist_class.dist.get_params_loss(predt, target, start_values, requires_grad) | ||
|
||
# Assertions | ||
assert isinstance(predt, List) | ||
for i in range(len(predt)): | ||
assert isinstance(predt[i], torch.Tensor) | ||
assert not torch.isnan(predt[i]).any() | ||
assert not torch.isinf(predt[i]).any() | ||
assert isinstance(loss, torch.Tensor) | ||
assert not torch.isnan(loss).any() | ||
assert not torch.isinf(loss).any() | ||
|
||
def test_get_params_loss_crps(self, dist_class_crps, requires_grad): | ||
# Create data for testing | ||
predt, target, _ = gen_test_data(dist_class_crps) | ||
target = torch.tensor(target) | ||
start_values = np.array([0.5 for _ in range(dist_class_crps.dist.n_dist_param)]) | ||
|
||
# Set the loss function for testing | ||
dist_class_crps.dist.loss_fn = "crps" | ||
|
||
# Call the function | ||
predt, loss = dist_class_crps.dist.get_params_loss(predt, target, start_values, requires_grad) | ||
|
||
# Assertions | ||
assert isinstance(predt, List) | ||
for i in range(len(predt)): | ||
assert isinstance(predt[i], torch.Tensor) | ||
assert not torch.isnan(predt[i]).any() | ||
assert not torch.isinf(predt[i]).any() | ||
assert isinstance(loss, torch.Tensor) | ||
assert not torch.isnan(loss).any() | ||
assert not torch.isinf(loss).any() |
29 changes: 29 additions & 0 deletions
29
tests/test_distribution_utils/test_loss_fn_start_values.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
from ..utils import BaseTestClass, gen_test_data | ||
import torch | ||
|
||
|
||
class TestClass(BaseTestClass): | ||
def test_loss_fn_start_values(self, dist_class, loss_fn): | ||
# Create data for testing | ||
_, target, _ = gen_test_data(dist_class) | ||
predt = [ | ||
torch.tensor(0.5, dtype=torch.float64).reshape(-1, 1).requires_grad_(True) for _ in | ||
range(dist_class.dist.n_dist_param) | ||
] | ||
if dist_class.dist.univariate: | ||
target = torch.tensor(target) | ||
else: | ||
target = torch.tensor(target)[:, :dist_class.dist.n_targets] | ||
|
||
# Set the loss function for testing | ||
dist_class.dist.loss_fn = loss_fn | ||
|
||
# Call the function | ||
if hasattr(dist_class.dist, "base_dist"): | ||
pass | ||
else: | ||
loss = dist_class.dist.loss_fn_start_values(predt, target) | ||
# Assertions | ||
assert isinstance(loss, torch.Tensor) | ||
assert not torch.isnan(loss).any() | ||
assert not torch.isinf(loss).any() |
Oops, something went wrong.