Skip to content
Draft
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@
mapie_clf = MultiLabelClassificationController(
predict_function=clf.predict_proba,
method=method,
metric_control="recall",
risk="recall",
target_level=1 - alpha,
confidence_level=0.9,
rcps_bound=bound,
Expand Down Expand Up @@ -236,7 +236,7 @@
mapie_clf = MultiLabelClassificationController(
predict_function=clf.predict_proba,
method="ltt",
metric_control="precision",
risk="precision",
target_level=1 - alpha,
confidence_level=0.9,
)
Expand Down
141 changes: 73 additions & 68 deletions mapie/risk_control/multi_label_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,15 @@
from sklearn.utils import check_random_state
from sklearn.utils.validation import _check_y, _num_samples, indexable

from mapie.utils import (
_check_alpha,
_check_n_jobs,
_check_verbose,
check_is_fitted,
)
from mapie.utils import _check_alpha, _check_n_jobs, _check_verbose, check_is_fitted

from .methods import (
find_best_predict_param,
find_precision_best_predict_param,
get_r_hat_plus,
ltt_procedure,
)
from .risks import compute_risk_precision, compute_risk_recall
from .risks import precision, recall


class MultiLabelClassificationController(BaseEstimator, ClassifierMixin):
Expand All @@ -44,14 +39,16 @@ class MultiLabelClassificationController(BaseEstimator, ClassifierMixin):
and each array is of shape (n_samples, 2) corresponding to the
probabilities of the negative and positive class for each label.

metric_control : Optional[str]
Metric to control. Either "recall" or "precision".
By default ``recall``.
risk : str
The risk metric to control ("precision" or "recall").
The selected risk determines which conformal prediction methods are valid:
- "precision" implies that method must be "ltt"
- "recall" implies that method can be "crc" (default) or "rcps"

method : Optional[str]
Method to use for the prediction sets. If `metric_control` is
Method to use for the prediction sets. If `risk` is
"recall", then the method can be either "crc" (default) or "rcps".
If `metric_control` is "precision", then the method used to control
If `risk` is "precision", then the method used to control
the precision is "ltt".

target_level : Optional[Union[float, Iterable[float]]]
Expand Down Expand Up @@ -134,15 +131,15 @@ class MultiLabelClassificationController(BaseEstimator, ClassifierMixin):
List of list of all index that satisfy fwer controlling.
This attribute is computed when the user wants to
control precision score.
Only relevant when metric_control="precision" as it uses
Only relevant when risk="precision" as it uses
learn then test (ltt) procedure.
Contains n_alpha lists.

valid_predict_params: List[List[Any]]
List of list of all thresholds that satisfy fwer controlling.
This attribute is computed when the user wants to
control precision score.
Only relevant when metric_control="precision" as it uses
Only relevant when risk="precision" as it uses
learn then test (ltt) procedure.
Contains n_alpha lists.

Expand Down Expand Up @@ -182,7 +179,12 @@ class MultiLabelClassificationController(BaseEstimator, ClassifierMixin):
[False True False]]
"""

valid_methods_by_metric_ = {"precision": ["ltt"], "recall": ["rcps", "crc"]}
risk_choice_map = {
"precision": precision,
"recall": recall,
}

valid_methods_by_metric_ = {"precision": ["ltt"], "recall": ["crc", "rcps"]}
valid_methods = list(chain(*valid_methods_by_metric_.values()))
valid_metric_ = list(valid_methods_by_metric_.keys())
valid_bounds_ = ["hoeffding", "bernstein", "wsr", None]
Expand All @@ -195,7 +197,7 @@ class MultiLabelClassificationController(BaseEstimator, ClassifierMixin):
def __init__(
self,
predict_function: Callable[[ArrayLike], Union[list[NDArray], NDArray]],
metric_control: Optional[str] = "recall",
risk: str = "recall",
method: Optional[str] = None,
target_level: Union[float, Iterable[float]] = 0.9,
confidence_level: Optional[float] = None,
Expand All @@ -205,16 +207,21 @@ def __init__(
verbose: int = 0,
) -> None:
self._predict_function = predict_function
self.metric_control = metric_control
self._risk_name = risk
self._risk = self._check_and_convert_risk(risk)
self.method = method
self._check_metric_control()
self._check_method()

alpha = []
for target in (
target_level if isinstance(target_level, Iterable) else [target_level]
):
alpha.append(1 - target) # higher is better for precision/recall
assert self._risk.higher_is_better, (
"Current implemented risks (precision and recall) are defined such that "
"'higher is better'. The 'lower is better' case is not implemented."
)
alpha.append(1 - target) # for higher is better only

self._alpha = np.array(_check_alpha(alpha))

self._check_confidence_level(confidence_level)
Expand All @@ -235,6 +242,16 @@ def is_fitted(self):
"""Returns True if the controller is fitted"""
return self._is_fitted

def _check_and_convert_risk(self, risk):
"""Check and convert risk parameter."""

if risk not in self.risk_choice_map:
raise ValueError(
f"risk must be one of: {list(self.risk_choice_map.keys())}"
)

return self.risk_choice_map[risk]

def _check_parameters(self) -> None:
"""
Check n_jobs, verbose, and random_states.
Expand All @@ -258,18 +275,15 @@ def _check_method(self) -> None:
Raise error if the name of the method is not
in self.valid_methods_
"""
self.method = cast(str, self.method)
self.metric_control = cast(str, self.metric_control)
valid_methods = self.valid_methods_by_metric_[self._risk_name]
if self.method is None:
self.method = valid_methods[0]
return

if self.method not in self.valid_methods_by_metric_[self.metric_control]:
if self.method not in valid_methods:
raise ValueError(
"Invalid method for metric: "
+ "You are controlling "
+ self.metric_control
+ " and you are using invalid method: "
+ self.method
+ ". Use instead: "
+ "".join(self.valid_methods_by_metric_[self.metric_control])
f"Invalid method '{self.method}' for risk '{self._risk_name}'. "
f"Valid methods are: {valid_methods}."
)

def _check_all_labelled(self, y: NDArray) -> None:
Expand Down Expand Up @@ -307,11 +321,12 @@ def _check_confidence_level(self, confidence_level: Optional[float]):
Raises
------
ValueError
If confidence_level is ``None`` and method is RCPS or
if confidence_level is not in [0, 1] and method
is RCPS.
If confidence_level is ``None`` and method requires it
(RCPS or LTT) or if confidence_level is not in (0, 1).

Warning
If confidence_level is not ``None`` and method is CRC
(because it will be ignored).
"""
if (not isinstance(confidence_level, float)) and (confidence_level is not None):
raise ValueError(
Expand All @@ -321,18 +336,16 @@ def _check_confidence_level(self, confidence_level: Optional[float]):
if confidence_level is None:
raise ValueError(
"Invalid confidence_level. "
"confidence_level cannot be ``None`` when controlling "
"Recall with RCPS or Precision with LTT"
f"confidence_level cannot be ``None`` when using method '{self.method}'."
)
elif (confidence_level <= 0) or (confidence_level >= 1):
raise ValueError(
"Invalid confidence_level. confidence_level must be in ]0, 1["
"Invalid confidence_level. confidence_level must be in (0, 1)"
)
if (self.method == "crc") and (confidence_level is not None):
warnings.warn(
"WARNING: you are using crc method, hence "
+ "even if the confidence_level is not ``None``, it won't be"
+ "taken into account"
"WARNING: you are using method 'crc', hence "
"even if confidence_level is not ``None``, it will be ignored."
)

def _check_valid_index(self, alpha: NDArray):
Expand Down Expand Up @@ -361,7 +374,7 @@ def _check_compute_risks_first_call(self) -> bool:
bool
True if it is the first time, else False.
"""
return not hasattr(self, "risks")
return not hasattr(self, "_risks")

def _check_bound(self, bound: Optional[str]):
"""
Expand All @@ -384,24 +397,6 @@ def _check_bound(self, bound: Optional[str]):
+ "taken into account."
)

def _check_metric_control(self):
"""
Check that the metrics to control are valid
(can be a string or list of string.)
"""
if self.metric_control not in self.valid_metric_:
raise ValueError(
"Invalid metric. "
"Allowed scores must be in the following list "
+ ", ".join(self.valid_metric_)
)

if self.method is None:
if self.metric_control == "recall":
self.method = "crc"
else: # self.metric_control == "precision"
self.method = "ltt"

def _transform_pred_proba(
self, y_pred_proba: Union[Sequence[NDArray], NDArray]
) -> NDArray:
Expand Down Expand Up @@ -473,25 +468,35 @@ def compute_risks(
y_pred_proba = self._predict_function(X)
y_pred_proba_array = self._transform_pred_proba(y_pred_proba)

if self.metric_control == "recall":
risk = compute_risk_recall(self.predict_params, y_pred_proba_array, y)
else: # self.metric_control == "precision"
risk = compute_risk_precision(self.predict_params, y_pred_proba_array, y)
n_lambdas = len(self.predict_params)
n_samples = len(y_pred_proba_array)

y_pred_proba_array_repeat = np.repeat(y_pred_proba_array, n_lambdas, axis=2)
y_pred = (y_pred_proba_array_repeat > self.predict_params).astype(int)

risk = np.zeros((n_samples, n_lambdas))
for index_sample in range(n_samples):
for index_lambda in range(n_lambdas):
risk[index_sample, index_lambda], _ = (
self._risk.get_value_and_effective_sample_size(
y[index_sample, :], y_pred[index_sample, :, index_lambda]
)
)

if first_call or _refit:
self.risks = risk
self._risks = risk
else:
self.risks = np.vstack((self.risks, risk))
self._risks = np.vstack((self._risks, risk))

return self

def compute_best_predict_param(self) -> MultiLabelClassificationController:
"""
Compute optimal predict_params based on the computed risks.
"""
if self.metric_control == "precision":
self.n_obs = len(self.risks)
self.r_hat = self.risks.mean(axis=0)
if self._risk == precision:
self.n_obs = len(self._risks)
self.r_hat = self._risks.mean(axis=0)
self.valid_index, _ = ltt_procedure(
np.expand_dims(self.r_hat, axis=0),
np.expand_dims(self._alpha, axis=0),
Expand All @@ -505,9 +510,9 @@ def compute_best_predict_param(self) -> MultiLabelClassificationController:
self.best_predict_param, _ = find_precision_best_predict_param(
self.r_hat, self.valid_index, self.predict_params
)
else:
elif self._risk == recall:
self.r_hat, self.r_hat_plus = get_r_hat_plus(
self.risks,
self._risks,
self.predict_params,
self.method,
self._rcps_bound,
Expand Down
Loading
Loading