101 changes: 54 additions & 47 deletions mapie/risk_control/multi_label_classification.py
@@ -10,12 +10,7 @@
from sklearn.utils import check_random_state
from sklearn.utils.validation import _check_y, _num_samples, indexable

from mapie.utils import (
_check_alpha,
_check_n_jobs,
_check_verbose,
check_is_fitted,
)
from mapie.utils import _check_alpha, _check_n_jobs, _check_verbose, check_is_fitted

from .methods import (
find_best_predict_param,
@@ -38,40 +33,45 @@ class MultiLabelClassificationController(BaseEstimator, ClassifierMixin):

Parameters
----------
predict_function : Callable[[ArrayLike], NDArray]
predict_function : Callable[[ArrayLike], Union[list[NDArray], NDArray]]
predict_proba method of a fitted multi-label classifier.
It should return a list of arrays where the length of the list is n_classes
and each array is of shape (n_samples, 2) corresponding to the
probabilities of the negative and positive class for each label.
It can return either:
- a list of length n_classes where each array is of shape
(n_samples, 2) with the probabilities of the negative and positive
class for each label (as output by ``MultiOutputClassifier``), or
- an ndarray of shape (n_samples, n_classes) containing positive-class
probabilities, or of shape (n_samples, n_classes, 2) containing both,
with the last dimension ordered as [neg, pos].

metric_control : Optional[str]
Metric to control. Either "recall" or "precision".
By default ``recall``.

method : Optional[str]
Method to use for the prediction sets. If `metric_control` is
"recall", then the method can be either "crc" (default) or "rcps".
If `metric_control` is "precision", then the method used to control
the precision is "ltt".
"recall", the method can be either "crc" (default) or "rcps".
If `metric_control` is "precision", the method used is "ltt".
If ``None``, the default is "crc" for recall and "ltt" for precision.

target_level : Optional[Union[float, Iterable[float]]]
The minimum performance level for the metric. Must be between 0 and 1.
Can be a float or a list of floats.
Can be a float or any iterable of floats.
By default ``0.9``.

confidence_level : Optional[float]
Can be a float, or ``None``. If using method="rcps", then it
can not be set to ``None``.
Between 0 and 1, the level of certainty at which we compute
the Upper Confidence Bound of the average risk.
Higher ``confidence_level`` produce larger (more conservative) prediction
sets.
By default ``None``.
Can be a float, or ``None``. If using method="rcps" or method="ltt"
(precision control), it cannot be set to ``None`` and must lie in (0, 1).
It is the level of certainty at which we compute the Upper Confidence
Bound of the average risk. A higher ``confidence_level`` produces larger
(more conservative) prediction sets. By default ``None``.

rcps_bound : Optional[Union[str, ``None``]]
Method used to compute the Upper Confidence Bound of the
average risk. Only necessary with the RCPS method.
By default ``None``.
average risk. Only necessary with the RCPS method. If provided when
using CRC or LTT, it is ignored and a warning is raised.
By default ``None``.

predict_params : ArrayLike
Array of parameters (thresholds λ) to consider for controlling the risk.
Defaults to ``np.arange(0, 1, 0.01)``. Its length sets ``n_predict_params``.

n_jobs: Optional[int]
@@ -186,8 +186,6 @@ class MultiLabelClassificationController(BaseEstimator, ClassifierMixin):
valid_methods = list(chain(*valid_methods_by_metric_.values()))
valid_metric_ = list(valid_methods_by_metric_.keys())
valid_bounds_ = ["hoeffding", "bernstein", "wsr", None]
predict_params = np.arange(0, 1, 0.01)
n_predict_params = len(predict_params)
fit_attributes = ["risks"]
sigma_init = 0.25 # Value given in the paper [1]
cal_size = 0.3
@@ -200,6 +198,7 @@ def __init__(
target_level: Union[float, Iterable[float]] = 0.9,
confidence_level: Optional[float] = None,
rcps_bound: Optional[Union[str, None]] = None,
predict_params: ArrayLike = np.arange(0, 1, 0.01),
n_jobs: Optional[int] = None,
random_state: Optional[Union[int, np.random.RandomState]] = None,
verbose: int = 0,
@@ -223,6 +222,9 @@
self._check_bound(rcps_bound)
self._rcps_bound = rcps_bound

self.predict_params = np.asarray(predict_params)
self.n_predict_params = len(self.predict_params)

self.n_jobs = n_jobs
self.random_state = random_state
self.verbose = verbose
@@ -296,7 +298,7 @@ def _check_all_labelled(self, y: NDArray) -> None:
def _check_confidence_level(self, confidence_level: Optional[float]):
"""
Check that confidence_level is not ``None`` when the
method is RCPS and that it is between 0 and 1.
method is RCPS or LTT and that it is between 0 and 1.

Parameters
----------
@@ -307,9 +309,9 @@ def _check_confidence_level(self, confidence_level: Optional[float]):
Raises
------
ValueError
If confidence_level is ``None`` and method is RCPS or
if confidence_level is not in [0, 1] and method
is RCPS.
If confidence_level is ``None`` and method is RCPS or LTT, or
if confidence_level is not in (0, 1) and method is RCPS or LTT.
Warning
If confidence_level is not ``None`` and method is CRC
"""
@@ -405,28 +407,33 @@ def _check_metric_control(self):
def _transform_pred_proba(
self, y_pred_proba: Union[Sequence[NDArray], NDArray]
) -> NDArray:
"""If the output of the predict_proba is a list of arrays (output of
the ``predict_proba`` of ``MultiOutputClassifier``) we transform it
into an array of shape (n_samples, n_classes, 1), otherwise, we add
one dimension at the end.

Parameters
----------
y_pred_proba : Union[List, NDArray]
Output of the multi-label classifier.

Returns
-------
NDArray of shape (n_samples, n_classes, 1)
Output of the model ready for risk computation.
"""Transform predict_function outputs to shape (n_samples, n_classes, 1)
containing positive-class probabilities.

- If a list of arrays is provided (e.g., MultiOutputClassifier), each
array is expected to be of shape (n_samples, 2); we take the positive
class column.
- If an ndarray is provided, it can be of shape (n_samples, n_classes)
containing positive-class probabilities, or
(n_samples, n_classes, 2) containing both class probabilities.
"""
if isinstance(y_pred_proba, np.ndarray):
y_pred_proba_array = y_pred_proba
if y_pred_proba.ndim == 3:
# assume last dim is [neg, pos], keep positive class
y_pred_pos = y_pred_proba[..., 1]
elif y_pred_proba.ndim == 2:
# already positive-class probabilities
y_pred_pos = y_pred_proba
else:
raise ValueError(
"When predict_proba returns an ndarray, it must have 2 or 3 "
"dimensions: (n_samples, n_classes) or (n_samples, n_classes, 2)."
)
else:
y_pred_proba_stacked = np.stack(y_pred_proba, axis=0)[:, :, 1]
y_pred_proba_array = np.moveaxis(y_pred_proba_stacked, 0, -1)
# list of length n_classes with (n_samples, 2) arrays
y_pred_pos = np.stack([proba[:, 1] for proba in y_pred_proba], axis=1)

return np.expand_dims(y_pred_proba_array, axis=2)
return np.expand_dims(y_pred_pos, axis=2)

def compute_risks(
self,
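A minimal usage sketch of the new ndarray support and ``predict_params`` argument. This is a sketch under assumptions: the import path below is assumed from the file layout, the ``calibrate``/``predict`` entry points are those exercised by the tests further down, and the model and data are hypothetical stand-ins.

```python
# Sketch only, not part of the diff. Import path assumed from the file layout.
import numpy as np
from mapie.risk_control.multi_label_classification import (
    MultiLabelClassificationController,
)

rng = np.random.default_rng(0)
X = rng.normal(size=(30, 3))
y = (X > 0).astype(int)            # 3 binary labels derived from the features
y[y.sum(axis=1) == 0, 0] = 1       # ensure every sample has at least one label

def predict_proba_ndarray(X):
    # Hypothetical model: 2D output (n_samples, n_classes) of positive-class
    # probabilities -- one of the formats now accepted directly.
    return 1.0 / (1.0 + np.exp(-np.asarray(X)))

clf = MultiLabelClassificationController(
    predict_function=predict_proba_ndarray,
    metric_control="recall",
    method="crc",
    predict_params=np.linspace(0.0, 0.99, 50),  # custom threshold grid
)
clf.calibrate(X, y)
y_sets = clf.predict(X)  # 3D; y_sets[:, :, 0] is the set for the first target level
```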
@@ -151,6 +151,22 @@ def predict_proba(self, X: ArrayLike) -> NDArray:
return proba_out


class ArrayOutputModel3D:
"""
Dummy model returning ndarray of shape (n_samples, n_classes, 2)
to test ndarray handling in _transform_pred_proba.
"""

def __init__(self):
self.trained_ = True

def predict_proba(self, X: ArrayLike) -> NDArray:
X = np.asarray(X)
# 3 labels; positive class probabilities: 0.6, 0.7, 0.8
base = np.array([[0.4, 0.6], [0.3, 0.7], [0.2, 0.8]])
return np.repeat(base[np.newaxis, ...], len(X), axis=0)


X_toy = np.arange(9).reshape(-1, 1)
y_toy = np.stack(
[
@@ -738,6 +754,64 @@ def test_toy_dataset_predictions(strategy: str) -> None:
np.testing.assert_allclose(y_ps[:, :, 0], y_toy_mapie[strategy], rtol=1e-6)


def test_transform_pred_proba_ndarray_2d() -> None:
"""Ensure 2D ndarray predict_proba is accepted and reshaped."""
y_pred = np.array([[0.6, 0.7, 0.8], [0.4, 0.5, 0.6]])
clf = MultiLabelClassificationController(predict_function=toy_predict_function)
y_out = clf._transform_pred_proba(y_pred)
assert y_out.shape == (2, 3, 1)
np.testing.assert_allclose(y_out[..., 0], y_pred)


def test_transform_pred_proba_ndarray_3d() -> None:
"""Ensure 3D ndarray predict_proba keeps positive class column."""
model = ArrayOutputModel3D()
clf = MultiLabelClassificationController(predict_function=model.predict_proba)
proba = model.predict_proba(X_toy)
y_out = clf._transform_pred_proba(proba)
assert y_out.shape == (len(X_toy), 3, 1)
np.testing.assert_allclose(y_out[..., 0], proba[..., 1])


def test_transform_pred_proba_list_of_arrays() -> None:
"""Ensure list-of-arrays predict_proba (MultiOutputClassifier style) works."""
clf = MultiLabelClassificationController(predict_function=toy_predict_function)
proba_list = toy_predict_function(X_toy) # MultiOutputClassifier returns list
y_out = clf._transform_pred_proba(proba_list)
assert y_out.shape == (len(X_toy), y_toy.shape[1], 1)
expected = np.stack([p[:, 1] for p in proba_list], axis=1)
np.testing.assert_allclose(y_out[..., 0], expected)


def test_transform_pred_proba_ndarray_invalid_dims() -> None:
"""Ensure ndarray with invalid dimensionality raises a ValueError."""
clf = MultiLabelClassificationController(predict_function=toy_predict_function)
wrong_shape = np.array([0.6, 0.4, 0.8]) # 1D array instead of 2D/3D
with pytest.raises(
ValueError,
match=r"When predict_proba returns an ndarray, it must have 2 or 3 dimensions.*",
):
clf._transform_pred_proba(wrong_shape)


@pytest.mark.parametrize(
"metric_control,method", [("recall", "crc"), ("precision", "ltt")]
)
def test_calibrate_with_ndarray_predict_proba(metric_control: str, method: str) -> None:
"""End-to-end check that ndarray predict_proba works for both metrics."""
model = ArrayOutputModel3D()
mapie_clf = MultiLabelClassificationController(
predict_function=model.predict_proba,
metric_control=metric_control,
method=method,
confidence_level=0.9 if method != "crc" else None,
)
mapie_clf.calibrate(X_toy, y_toy)
y_ps = mapie_clf.predict(X_toy)
assert y_ps.shape[0] == len(X_toy)
assert y_ps.shape[1] == y_toy.shape[1]


@pytest.mark.parametrize("method", ["rcps", "crc"])
def test_error_wrong_method_metric_precision(method: str) -> None:
"""
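As a companion to these tests, here is a self-contained numpy sketch of the normalization that ``_transform_pred_proba`` performs, checking that the three accepted input formats collapse to the same (n_samples, n_classes, 1) array of positive-class probabilities. All names are illustrative only.

```python
# Illustrative check mirroring _transform_pred_proba's logic.
import numpy as np

n_samples, n_classes = 4, 3
pos = np.random.default_rng(1).uniform(size=(n_samples, n_classes))

as_2d = pos                                # (n_samples, n_classes), pos class only
as_3d = np.stack([1 - pos, pos], axis=-1)  # (n_samples, n_classes, 2), [neg, pos]
as_list = [                                # MultiOutputClassifier-style list
    np.column_stack([1 - pos[:, k], pos[:, k]]) for k in range(n_classes)
]

out_2d = np.expand_dims(as_2d, axis=2)
out_3d = np.expand_dims(as_3d[..., 1], axis=2)
out_list = np.expand_dims(np.stack([p[:, 1] for p in as_list], axis=1), axis=2)

assert out_2d.shape == (n_samples, n_classes, 1)
np.testing.assert_allclose(out_2d, out_3d)
np.testing.assert_allclose(out_2d, out_list)
```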