Skip to content

Commit

Permalink
Added additional input checks
Browse files — browse the repository at this point in the history
  • Loading branch information
Alexander März committed Aug 8, 2023
1 parent 51ac620 commit 149f402
Show file tree
Hide file tree
Showing 20 changed files with 280 additions and 164 deletions.
17 changes: 12 additions & 5 deletions lightgbmlss/distributions/Beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,20 @@ def __init__(self,
response_fn: str = "exp",
loss_fn: str = "nll"
):

# Input Checks
if stabilization not in ["None", "MAD", "L2"]:
raise ValueError("Invalid stabilization method. Please choose from 'None', 'MAD' or 'L2'.")
if loss_fn not in ["nll", "crps"]:
raise ValueError("Invalid loss function. Please choose from 'nll' or 'crps'.")

# Specify Response Functions
if response_fn == "exp":
response_fn = exp_fn
elif response_fn == "softplus":
response_fn = softplus_fn
response_functions = {"exp": exp_fn, "softplus": softplus_fn}
if response_fn in response_functions:
response_fn = response_functions[response_fn]
else:
raise ValueError("Invalid response function. Please choose from 'exp' or 'softplus'.")
raise ValueError(
"Invalid response function. Please choose from 'exp' or 'softplus'.")

# Set the parameters specific to the distribution
distribution = Beta_Torch
Expand Down
17 changes: 12 additions & 5 deletions lightgbmlss/distributions/Cauchy.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,20 @@ def __init__(self,
response_fn: str = "exp",
loss_fn: str = "nll"
):

# Input Checks
if stabilization not in ["None", "MAD", "L2"]:
raise ValueError("Invalid stabilization method. Please choose from 'None', 'MAD' or 'L2'.")
if loss_fn not in ["nll", "crps"]:
raise ValueError("Invalid loss function. Please choose from 'nll' or 'crps'.")

# Specify Response Functions
if response_fn == "exp":
response_fn = exp_fn
elif response_fn == "softplus":
response_fn = softplus_fn
response_functions = {"exp": exp_fn, "softplus": softplus_fn}
if response_fn in response_functions:
response_fn = response_functions[response_fn]
else:
raise ValueError("Invalid response function. Please choose from 'exp' or 'softplus'.")
raise ValueError(
"Invalid response function. Please choose from 'exp' or 'softplus'.")

# Set the parameters specific to the distribution
distribution = Cauchy_Torch
Expand Down
11 changes: 11 additions & 0 deletions lightgbmlss/distributions/Expectile.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,17 @@ def __init__(self,
expectiles: List = [0.1, 0.5, 0.9],
penalize_crossing: bool = False,
):

# Input Checks
if stabilization not in ["None", "MAD", "L2"]:
raise ValueError("Invalid stabilization method. Please choose from 'None', 'MAD' or 'L2'.")
if not isinstance(expectiles, list):
raise ValueError("Expectiles must be a list.")
if not all([0 < expectile < 1 for expectile in expectiles]):
raise ValueError("Expectiles must be between 0 and 1.")
if not isinstance(penalize_crossing, bool):
raise ValueError("penalize_crossing must be a boolean. Please choose from True or False.")

# Set the parameters specific to the distribution
distribution = Expectile_Torch
torch.distributions.Distribution.set_default_validate_args(False)
Expand Down
17 changes: 12 additions & 5 deletions lightgbmlss/distributions/Gamma.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,20 @@ def __init__(self,
response_fn: str = "exp",
loss_fn: str = "nll"
):

# Input Checks
if stabilization not in ["None", "MAD", "L2"]:
raise ValueError("Invalid stabilization method. Please choose from 'None', 'MAD' or 'L2'.")
if loss_fn not in ["nll", "crps"]:
raise ValueError("Invalid loss function. Please choose from 'nll' or 'crps'.")

# Specify Response Functions
if response_fn == "exp":
response_fn = exp_fn
elif response_fn == "softplus":
response_fn = softplus_fn
response_functions = {"exp": exp_fn, "softplus": softplus_fn}
if response_fn in response_functions:
response_fn = response_functions[response_fn]
else:
raise ValueError("Invalid response function. Please choose from 'exp' or 'softplus'.")
raise ValueError(
"Invalid response function. Please choose from 'exp' or 'softplus'.")

# Set the parameters specific to the distribution
distribution = Gamma_Torch
Expand Down
17 changes: 12 additions & 5 deletions lightgbmlss/distributions/Gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,20 @@ def __init__(self,
response_fn: str = "exp",
loss_fn: str = "nll"
):

# Input Checks
if stabilization not in ["None", "MAD", "L2"]:
raise ValueError("Invalid stabilization method. Please choose from 'None', 'MAD' or 'L2'.")
if loss_fn not in ["nll", "crps"]:
raise ValueError("Invalid loss function. Please choose from 'nll' or 'crps'.")

# Specify Response Functions
if response_fn == "exp":
response_fn = exp_fn
elif response_fn == "softplus":
response_fn = softplus_fn
response_functions = {"exp": exp_fn, "softplus": softplus_fn}
if response_fn in response_functions:
response_fn = response_functions[response_fn]
else:
raise ValueError("Invalid response function. Please choose from 'exp' or 'softplus'.")
raise ValueError(
"Invalid response function. Please choose from 'exp' or 'softplus'.")

# Set the parameters specific to the distribution
distribution = Gaussian_Torch
Expand Down
17 changes: 12 additions & 5 deletions lightgbmlss/distributions/Gumbel.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,20 @@ def __init__(self,
response_fn: str = "exp",
loss_fn: str = "nll"
):

# Input Checks
if stabilization not in ["None", "MAD", "L2"]:
raise ValueError("Invalid stabilization method. Please choose from 'None', 'MAD' or 'L2'.")
if loss_fn not in ["nll", "crps"]:
raise ValueError("Invalid loss function. Please choose from 'nll' or 'crps'.")

# Specify Response Functions
if response_fn == "exp":
response_fn = exp_fn
elif response_fn == "softplus":
response_fn = softplus_fn
response_functions = {"exp": exp_fn, "softplus": softplus_fn}
if response_fn in response_functions:
response_fn = response_functions[response_fn]
else:
raise ValueError("Invalid response function. Please choose from 'exp' or 'softplus'.")
raise ValueError(
"Invalid response function. Please choose from 'exp' or 'softplus'.")

# Set the parameters specific to the distribution
distribution = Gumbel_Torch
Expand Down
17 changes: 12 additions & 5 deletions lightgbmlss/distributions/Laplace.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,20 @@ def __init__(self,
response_fn: str = "exp",
loss_fn: str = "nll"
):

# Input Checks
if stabilization not in ["None", "MAD", "L2"]:
raise ValueError("Invalid stabilization method. Please choose from 'None', 'MAD' or 'L2'.")
if loss_fn not in ["nll", "crps"]:
raise ValueError("Invalid loss function. Please choose from 'nll' or 'crps'.")

# Specify Response Functions
if response_fn == "exp":
response_fn = exp_fn
elif response_fn == "softplus":
response_fn = softplus_fn
response_functions = {"exp": exp_fn, "softplus": softplus_fn}
if response_fn in response_functions:
response_fn = response_functions[response_fn]
else:
raise ValueError("Invalid response function. Please choose from 'exp' or 'softplus'.")
raise ValueError(
"Invalid response function. Please choose from 'exp' or 'softplus'.")

# Set the parameters specific to the distribution
distribution = Laplace_Torch
Expand Down
17 changes: 12 additions & 5 deletions lightgbmlss/distributions/LogNormal.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,20 @@ def __init__(self,
response_fn: str = "exp",
loss_fn: str = "nll"
):

# Input Checks
if stabilization not in ["None", "MAD", "L2"]:
raise ValueError("Invalid stabilization method. Please choose from 'None', 'MAD' or 'L2'.")
if loss_fn not in ["nll", "crps"]:
raise ValueError("Invalid loss function. Please choose from 'nll' or 'crps'.")

# Specify Response Functions
if response_fn == "exp":
response_fn = exp_fn
elif response_fn == "softplus":
response_fn = softplus_fn
response_functions = {"exp": exp_fn, "softplus": softplus_fn}
if response_fn in response_functions:
response_fn = response_functions[response_fn]
else:
raise ValueError("Invalid response function. Please choose from 'exp' or 'softplus'.")
raise ValueError(
"Invalid response function. Please choose from 'exp' or 'softplus'.")

# Set the parameters specific to the distribution
distribution = LogNormal_Torch
Expand Down
31 changes: 18 additions & 13 deletions lightgbmlss/distributions/NegativeBinomial.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,31 +31,36 @@ class NegativeBinomial(DistributionClass):
Response function for transforming the distributional parameters to the correct support. Options are
"sigmoid" (sigmoid).
loss_fn: str
Loss function. Options are "nll" (negative log-likelihood) or "crps" (continuous ranked probability score).
Note that if "crps" is used, the Hessian is set to 1, as the current CRPS version is not twice differentiable.
Hence, using the CRPS disregards any variation in the curvature of the loss function.
Loss function. Options are "nll" (negative log-likelihood).
"""
def __init__(self,
stabilization: str = "None",
response_fn_total_count: str = "relu",
response_fn_probs: str = "sigmoid",
loss_fn: str = "nll"
):

# Input Checks
if stabilization not in ["None", "MAD", "L2"]:
raise ValueError("Invalid stabilization method. Please choose from 'None', 'MAD' or 'L2'.")
if loss_fn not in ["nll"]:
raise ValueError("Invalid loss function. Please select 'nll'.")

# Specify Response Functions for total_count
if response_fn_total_count == "exp":
response_fn_total_count = exp_fn
elif response_fn_total_count == "softplus":
response_fn_total_count = softplus_fn
elif response_fn_total_count == "relu":
response_fn_total_count = relu_fn
response_functions_total_count = {"exp": exp_fn, "softplus": softplus_fn, "relu": relu_fn}
if response_fn_total_count in response_functions_total_count:
response_fn_total_count = response_functions_total_count[response_fn_total_count]
else:
raise ValueError("Invalid response function for total_count. Please choose from 'exp', 'softplus' or relu.")
raise ValueError(
"Invalid response function for total_count. Please choose from 'exp', 'softplus' or 'relu'.")

# Specify Response Functions for probs
if response_fn_probs == "sigmoid":
response_fn_probs = sigmoid_fn
response_functions_probs = {"sigmoid": sigmoid_fn}
if response_fn_probs in response_functions_probs:
response_fn_probs = response_functions_probs[response_fn_probs]
else:
raise ValueError("Invalid response function for probs. Please select 'sigmoid'.")
raise ValueError(
"Invalid response function for probs. Please select 'sigmoid'.")

# Set the parameters specific to the distribution
distribution = NegativeBinomial_Torch
Expand Down
25 changes: 14 additions & 11 deletions lightgbmlss/distributions/Poisson.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,24 +24,27 @@ class Poisson(DistributionClass):
Response function for transforming the distributional parameters to the correct support. Options are
"exp" (exponential), "softplus" (softplus) or "relu" (rectified linear unit).
loss_fn: str
Loss function. Options are "nll" (negative log-likelihood) or "crps" (continuous ranked probability score).
Note that if "crps" is used, the Hessian is set to 1, as the current CRPS version is not twice differentiable.
Hence, using the CRPS disregards any variation in the curvature of the loss function.
Loss function. Options are "nll" (negative log-likelihood).
"""
def __init__(self,
stabilization: str = "None",
response_fn: str = "relu",
loss_fn: str = "nll"
):
# # Specify Response Functions
if response_fn == "exp":
response_fn = exp_fn
elif response_fn == "softplus":
response_fn = softplus_fn
elif response_fn == "relu":
response_fn = relu_fn

# Input Checks
if stabilization not in ["None", "MAD", "L2"]:
raise ValueError("Invalid stabilization method. Please choose from 'None', 'MAD' or 'L2'.")
if loss_fn not in ["nll"]:
raise ValueError("Invalid loss function. Please select 'nll'.")

# Specify Response Functions
response_functions = {"exp": exp_fn, "softplus": softplus_fn, "relu": relu_fn}
if response_fn in response_functions:
response_fn = response_functions[response_fn]
else:
raise ValueError("Invalid response function for total_count. Please choose from 'exp', 'softplus' or relu.")
raise ValueError(
"Invalid response function for total_count. Please choose from 'exp', 'softplus' or 'relu'.")

# Set the parameters specific to the distribution
distribution = Poisson_Torch
Expand Down
73 changes: 38 additions & 35 deletions lightgbmlss/distributions/SplineFlow.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,45 +59,22 @@ def __init__(self,
loss_fn: str = "nll"
):

# Check if stabilization method is valid.
if not isinstance(stabilization, str):
raise ValueError("stabilization must be a string.")
if stabilization not in ["None", "MAD", "L2"]:
raise ValueError("Invalid stabilization method. Options are 'None', 'MAD' or 'L2'.")

# Check if loss function is valid.
if not isinstance(loss_fn, str):
raise ValueError("loss_fn must be a string.")
if loss_fn not in ["nll", "crps"]:
raise ValueError("Invalid loss_fn. Options are 'nll' or 'crps'.")

# Number of parameters
if not isinstance(order, str):
raise ValueError("order must be a string.")
if order == "quadratic":
n_params = 2*count_bins + (count_bins-1)
elif order == "linear":
n_params = 3*count_bins + (count_bins-1)
else:
raise ValueError("Invalid order specification. Options are 'linear' or 'quadratic'.")

# Specify Target Transform
if not isinstance(target_support, str):
raise ValueError("target_support must be a string.")
if target_support == "real":
target_transform = identity_transform
discrete = False
elif target_support == "positive":
target_transform = SoftplusTransform()
discrete = False
elif target_support == "positive_integer":
target_transform = SoftplusTransform()
discrete = True
elif target_support == "unit_interval":
target_transform = SigmoidTransform()
discrete = False

transforms = {
"real": (identity_transform, False),
"positive": (SoftplusTransform(), False),
"positive_integer": (SoftplusTransform(), True),
"unit_interval": (SigmoidTransform(), False)
}

if target_support in transforms:
target_transform, discrete = transforms[target_support]
else:
raise ValueError("Invalid target_support. Options are 'real', 'positive', 'positive_integer' or 'unit_interval'.")
raise ValueError(
"Invalid target_support. Options are 'real', 'positive', 'positive_integer', or 'unit_interval'.")

# Check if count_bins is valid
if not isinstance(count_bins, int):
Expand All @@ -109,6 +86,32 @@ def __init__(self,
if not isinstance(bound, float):
raise ValueError("bound must be a float.")

# Number of parameters
if not isinstance(order, str):
raise ValueError("order must be a string.")

order_params = {
"quadratic": 2 * count_bins + (count_bins - 1),
"linear": 3 * count_bins + (count_bins - 1)
}

if order in order_params:
n_params = order_params[order]
else:
raise ValueError("Invalid order specification. Options are 'linear' or 'quadratic'.")

# Check if stabilization method is valid.
if not isinstance(stabilization, str):
raise ValueError("stabilization must be a string.")
if stabilization not in ["None", "MAD", "L2"]:
raise ValueError("Invalid stabilization method. Options are 'None', 'MAD' or 'L2'.")

# Check if loss function is valid.
if not isinstance(loss_fn, str):
raise ValueError("loss_fn must be a string.")
if loss_fn not in ["nll", "crps"]:
raise ValueError("Invalid loss_fn. Options are 'nll' or 'crps'.")

# Specify parameter dictionary
param_dict = {f"param_{i + 1}": identity_fn for i in range(n_params)}
torch.distributions.Distribution.set_default_validate_args(False)
Expand Down
Loading

0 comments on commit 149f402

Please sign in to comment.