Commit 1c76ac3: update quantile computation in quantilegbm and quantilelgbm

basaks committed Jul 8, 2024
1 parent 97b98bd commit 1c76ac3
Showing 3 changed files with 29 additions and 10 deletions.
2 changes: 1 addition & 1 deletion configs/ref_gradientboost_quantiles.yaml

@@ -56,7 +56,7 @@ learning:
     'min_weight_fraction_leaf': Real(0.0, 0.5, prior='uniform')

 prediction:
-  quantiles: 0.95
+  quantiles: 0.9
   outbands: 4
4 changes: 3 additions & 1 deletion configs/ref_quantile_lgbm.yaml

@@ -38,11 +38,13 @@ learning:
   target_transform: identity
   random_state: 1
   max_depth: 20
+  upper_alpha: 0.95
+  lower_alpha: 0.05


 prediction:
   prediction_template: configs/data/sirsam/dem_foc2.tif
-  quantiles: 0.95
+  quantiles: 0.90
   outbands: 4
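In both reference configs, quantiles now equals upper_alpha - lower_alpha (0.95 - 0.05 = 0.90), so predict_dist can return the directly fitted quantiles instead of normal-approximation ones (see the models.py diff below). A minimal sanity check one could run against these values; the helper name check_quantile_config is illustrative, not part of uncoverml, and the tolerance matters because 0.95 - 0.05 is not exactly 0.9 in floating point:

    def check_quantile_config(quantiles: float, upper_alpha: float, lower_alpha: float) -> bool:
        """True when the requested interval matches the fitted quantile spread,
        i.e. when predict_dist can skip the normal-distribution approximation."""
        return abs(quantiles - (upper_alpha - lower_alpha)) < 1e-9

    # values from configs/ref_quantile_lgbm.yaml
    assert check_quantile_config(0.90, upper_alpha=0.95, lower_alpha=0.05)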
33 changes: 25 additions & 8 deletions uncoverml/optimise/models.py

@@ -641,6 +641,7 @@ def __init__(self, target_transform='identity',
         self.alpha = alpha
         self.upper_alpha = upper_alpha
         self.lower_alpha = lower_alpha
+        self.interval = upper_alpha - lower_alpha

     @staticmethod
     def collect_prediction(regressor, X_test):

@@ -658,7 +659,7 @@ def fit(self, X, y, *args, **kwargs):
     def predict(self, X, *args, **kwargs):
         return self.predict_dist(X, *args, **kwargs)[0]

-    def predict_dist(self, X, interval=0.95, *args, ** kwargs):
+    def predict_dist(self, X, interval=0.90, *args, **kwargs):
         Ey = self.gb.predict(X)

         ql_ = self.collect_prediction(self.gb_quantile_lower, X)

@@ -667,9 +668,17 @@ def predict_dist(self, X, interval=0.95, *args, ** kwargs):
         Vy = ((qu_ - ql_) / (norm.ppf(self.upper_alpha) - norm.ppf(self.lower_alpha))) ** 2

         # to make gbm quantile model consistent with other quantile based models
-        ql, qu = norm.interval(interval, loc=Ey, scale=np.sqrt(Vy))
-
-        return Ey, Vy, ql, qu
+        if np.isclose(interval, self.interval):
+            # the requested interval matches (upper_alpha - lower_alpha), so the
+            # fitted quantile models already provide ql and qu and no normal
+            # distribution assumption is needed; np.isclose rather than ==,
+            # since e.g. 0.95 - 0.05 != 0.9 in floating point
+            return Ey, Vy, ql_, qu_
+        else:
+            log.warning("===============================================")
+            log.warning("Using a normal distribution assumption to compute quantiles."
+                        " Setting quantiles=(upper_alpha-lower_alpha) removes this assumption!")
+            log.warning("===============================================")
+            ql, qu = norm.interval(interval, loc=Ey, scale=np.sqrt(Vy))
+            return Ey, Vy, ql, qu
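The conversion in the else branch follows from the normality assumption: if y ~ N(Ey, sigma^2), then qu_ - ql_ = sigma * (norm.ppf(upper_alpha) - norm.ppf(lower_alpha)), so sigma^2 is recovered as Vy above and norm.interval can rebuild any central interval from it. A standalone sketch of that round trip on synthetic values (numpy/scipy only, not uncoverml code):

    import numpy as np
    from scipy.stats import norm

    upper_alpha, lower_alpha = 0.95, 0.05

    # pretend these came from the fitted upper/lower quantile regressors
    Ey = np.array([10.0, 12.0])                    # mean predictions
    sigma_true = np.array([2.0, 3.0])              # unknown in practice
    ql_ = Ey + sigma_true * norm.ppf(lower_alpha)  # 5th percentile
    qu_ = Ey + sigma_true * norm.ppf(upper_alpha)  # 95th percentile

    # variance recovered exactly as in predict_dist: the quantile spread divided
    # by the corresponding spread of standard-normal quantiles, squared
    Vy = ((qu_ - ql_) / (norm.ppf(upper_alpha) - norm.ppf(lower_alpha))) ** 2

    # rebuilding the 90% interval reproduces the direct quantile predictions,
    # which is why a matching interval can skip this step entirely
    ql, qu = norm.interval(0.90, loc=Ey, scale=np.sqrt(Vy))
    assert np.allclose(ql, ql_) and np.allclose(qu, qu_)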


 class GBMReg(GradientBoostingRegressor, TagsMixin):

@@ -820,7 +829,7 @@ def fit(self, X, y, *args, **kwargs):
     def predict(self, X, *args, **kwargs):
         return self.predict_dist(X, *args, **kwargs)[0]

-    def predict_dist(self, X, interval=0.95, *args, ** kwargs):
+    def predict_dist(self, X, interval=0.9, *args, **kwargs):
         Ey = self.gb.predict(X)

         ql_ = self.collect_prediction(self.gb_quantile_lower, X)

@@ -829,9 +838,17 @@ def predict_dist(self, X, interval=0.95, *args, ** kwargs):
         Vy = ((qu_ - ql_) / (norm.ppf(self.upper_alpha) - norm.ppf(self.lower_alpha))) ** 2

         # to make gbm quantile model consistent with other quantile based models
-        ql, qu = norm.interval(interval, loc=Ey, scale=np.sqrt(Vy))
-
-        return Ey, Vy, ql, qu
+        if np.isclose(interval, self.interval):
+            # the requested interval matches (upper_alpha - lower_alpha), so the
+            # fitted quantile models already provide ql and qu and no normal
+            # distribution assumption is needed; np.isclose rather than ==,
+            # since e.g. 0.95 - 0.05 != 0.9 in floating point
+            return Ey, Vy, ql_, qu_
+        else:
+            log.warning("===============================================")
+            log.warning("Using a normal distribution assumption to compute quantiles."
+                        " Setting quantiles=(upper_alpha-lower_alpha) removes this assumption!")
+            log.warning("===============================================")
+            ql, qu = norm.interval(interval, loc=Ey, scale=np.sqrt(Vy))
+            return Ey, Vy, ql, qu


 class CatBoostWrapper(CatBoostRegressor, TagsMixin):
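To see both code paths end to end without uncoverml, the three-regressor setup the wrapper trains (a mean model plus upper/lower quantile-loss models) can be rebuilt with plain scikit-learn; a self-contained sketch on synthetic data, mirroring the predict_dist logic above:

    import numpy as np
    from scipy.stats import norm
    from sklearn.ensemble import GradientBoostingRegressor

    rng = np.random.default_rng(1)
    X = rng.uniform(size=(500, 2))
    y = X @ np.array([2.0, -1.0]) + rng.normal(scale=0.5, size=500)

    upper_alpha, lower_alpha = 0.95, 0.05
    gb = GradientBoostingRegressor().fit(X, y)  # mean model
    gb_upper = GradientBoostingRegressor(loss='quantile', alpha=upper_alpha).fit(X, y)
    gb_lower = GradientBoostingRegressor(loss='quantile', alpha=lower_alpha).fit(X, y)

    Ey = gb.predict(X)
    qu_, ql_ = gb_upper.predict(X), gb_lower.predict(X)

    # direct path: requested interval matches the fitted spread, use ql_/qu_ as-is
    interval = upper_alpha - lower_alpha          # 0.90 (up to float rounding)
    print("90% interval, direct:", ql_[:2], qu_[:2])

    # fallback path for any other interval: normal approximation, as in the diff
    Vy = ((qu_ - ql_) / (norm.ppf(upper_alpha) - norm.ppf(lower_alpha))) ** 2
    ql80, qu80 = norm.interval(0.80, loc=Ey, scale=np.sqrt(Vy))
    print("80% interval, normal approx:", ql80[:2], qu80[:2])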