diff --git a/examples/risk_control/2-advanced-analysis/plot_tutorial_risk_control.py b/examples/risk_control/2-advanced-analysis/plot_tutorial_risk_control.py index 3416f1825..5dce494cf 100644 --- a/examples/risk_control/2-advanced-analysis/plot_tutorial_risk_control.py +++ b/examples/risk_control/2-advanced-analysis/plot_tutorial_risk_control.py @@ -128,7 +128,7 @@ mapie_clf = MultiLabelClassificationController( predict_function=clf.predict_proba, method=method, - metric_control="recall", + risk="recall", target_level=1 - alpha, confidence_level=0.9, rcps_bound=bound, @@ -236,7 +236,7 @@ mapie_clf = MultiLabelClassificationController( predict_function=clf.predict_proba, method="ltt", - metric_control="precision", + risk="precision", target_level=1 - alpha, confidence_level=0.9, ) diff --git a/mapie/risk_control/multi_label_classification.py b/mapie/risk_control/multi_label_classification.py index b9d94fda0..0eb2785ac 100644 --- a/mapie/risk_control/multi_label_classification.py +++ b/mapie/risk_control/multi_label_classification.py @@ -10,12 +10,7 @@ from sklearn.utils import check_random_state from sklearn.utils.validation import _check_y, _num_samples, indexable -from mapie.utils import ( - _check_alpha, - _check_n_jobs, - _check_verbose, - check_is_fitted, -) +from mapie.utils import _check_alpha, _check_n_jobs, _check_verbose, check_is_fitted from .methods import ( find_best_predict_param, @@ -23,7 +18,7 @@ get_r_hat_plus, ltt_procedure, ) -from .risks import compute_risk_precision, compute_risk_recall +from .risks import precision, recall class MultiLabelClassificationController(BaseEstimator, ClassifierMixin): @@ -44,14 +39,16 @@ class MultiLabelClassificationController(BaseEstimator, ClassifierMixin): and each array is of shape (n_samples, 2) corresponding to the probabilities of the negative and positive class for each label. - metric_control : Optional[str] - Metric to control. Either "recall" or "precision". - By default ``recall``. + risk : str + The risk metric to control ("precision" or "recall"). + The selected risk determines which conformal prediction methods are valid: + - "precision" implies that method must be "ltt" + - "recall" implies that method can be "crc" (default) or "rcps" method : Optional[str] - Method to use for the prediction sets. If `metric_control` is + Method to use for the prediction sets. If `risk` is "recall", then the method can be either "crc" (default) or "rcps". - If `metric_control` is "precision", then the method used to control + If `risk` is "precision", then the method used to control the precision is "ltt". target_level : Optional[Union[float, Iterable[float]]] @@ -134,7 +131,7 @@ class MultiLabelClassificationController(BaseEstimator, ClassifierMixin): List of list of all index that satisfy fwer controlling. This attribute is computed when the user wants to control precision score. - Only relevant when metric_control="precision" as it uses + Only relevant when risk="precision" as it uses learn then test (ltt) procedure. Contains n_alpha lists. @@ -142,7 +139,7 @@ class MultiLabelClassificationController(BaseEstimator, ClassifierMixin): List of list of all thresholds that satisfy fwer controlling. This attribute is computed when the user wants to control precision score. - Only relevant when metric_control="precision" as it uses + Only relevant when risk="precision" as it uses learn then test (ltt) procedure. Contains n_alpha lists. 
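A minimal usage sketch of the renamed `risk` parameter (not part of the patch), assuming `MultiLabelClassificationController` is exported from `mapie.risk_control`, that `clf` is an already fitted scikit-learn multi-label classifier, and that `X_calib`/`y_calib` are a hypothetical calibration split; the `calibrate`/`predict` calls mirror the tests later in this diff.

from mapie.risk_control import MultiLabelClassificationController

# Controlling recall: valid methods are "crc" (default) and "rcps".
controller = MultiLabelClassificationController(
    predict_function=clf.predict_proba,  # assumed fitted multi-label classifier
    risk="recall",
    method="crc",
    target_level=0.9,  # converted internally to alpha = 1 - target_level
)
controller.calibrate(X_calib, y_calib)  # hypothetical calibration split
y_pred_sets = controller.predict(X_calib)

# Controlling precision: the only valid method is "ltt",
# and confidence_level must be provided.
controller_ltt = MultiLabelClassificationController(
    predict_function=clf.predict_proba,
    risk="precision",
    method="ltt",
    target_level=0.9,
    confidence_level=0.9,
)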
@@ -182,7 +179,12 @@ class MultiLabelClassificationController(BaseEstimator, ClassifierMixin): [False True False]] """ - valid_methods_by_metric_ = {"precision": ["ltt"], "recall": ["rcps", "crc"]} + risk_choice_map = { + "precision": precision, + "recall": recall, + } + + valid_methods_by_metric_ = {"precision": ["ltt"], "recall": ["crc", "rcps"]} valid_methods = list(chain(*valid_methods_by_metric_.values())) valid_metric_ = list(valid_methods_by_metric_.keys()) valid_bounds_ = ["hoeffding", "bernstein", "wsr", None] @@ -195,7 +197,7 @@ class MultiLabelClassificationController(BaseEstimator, ClassifierMixin): def __init__( self, predict_function: Callable[[ArrayLike], Union[list[NDArray], NDArray]], - metric_control: Optional[str] = "recall", + risk: str = "recall", method: Optional[str] = None, target_level: Union[float, Iterable[float]] = 0.9, confidence_level: Optional[float] = None, @@ -205,16 +207,21 @@ def __init__( verbose: int = 0, ) -> None: self._predict_function = predict_function - self.metric_control = metric_control + self._risk_name = risk + self._risk = self._check_and_convert_risk(risk) self.method = method - self._check_metric_control() self._check_method() alpha = [] for target in ( target_level if isinstance(target_level, Iterable) else [target_level] ): - alpha.append(1 - target) # higher is better for precision/recall + assert self._risk.higher_is_better, ( + "Current implemented risks (precision and recall) are defined such that " + "'higher is better'. The 'lower is better' case is not implemented." + ) + alpha.append(1 - target) # for higher is better only + self._alpha = np.array(_check_alpha(alpha)) self._check_confidence_level(confidence_level) @@ -235,6 +242,16 @@ def is_fitted(self): """Returns True if the controller is fitted""" return self._is_fitted + def _check_and_convert_risk(self, risk): + """Check and convert risk parameter.""" + + if risk not in self.risk_choice_map: + raise ValueError( + f"risk must be one of: {list(self.risk_choice_map.keys())}" + ) + + return self.risk_choice_map[risk] + def _check_parameters(self) -> None: """ Check n_jobs, verbose, and random_states. @@ -258,18 +275,15 @@ def _check_method(self) -> None: Raise error if the name of the method is not in self.valid_methods_ """ - self.method = cast(str, self.method) - self.metric_control = cast(str, self.metric_control) + valid_methods = self.valid_methods_by_metric_[self._risk_name] + if self.method is None: + self.method = valid_methods[0] + return - if self.method not in self.valid_methods_by_metric_[self.metric_control]: + if self.method not in valid_methods: raise ValueError( - "Invalid method for metric: " - + "You are controlling " - + self.metric_control - + " and you are using invalid method: " - + self.method - + ". Use instead: " - + "".join(self.valid_methods_by_metric_[self.metric_control]) + f"Invalid method '{self.method}' for risk '{self._risk_name}'. " + f"Valid methods are: {valid_methods}." ) def _check_all_labelled(self, y: NDArray) -> None: @@ -307,11 +321,12 @@ def _check_confidence_level(self, confidence_level: Optional[float]): Raises ------ ValueError - If confidence_level is ``None`` and method is RCPS or - if confidence_level is not in [0, 1] and method - is RCPS. + If confidence_level is ``None`` and method requires it + (RCPS or LTT) or if confidence_level is not in (0, 1). + Warning If confidence_level is not ``None`` and method is CRC + (because it will be ignored). 
""" if (not isinstance(confidence_level, float)) and (confidence_level is not None): raise ValueError( @@ -321,18 +336,16 @@ def _check_confidence_level(self, confidence_level: Optional[float]): if confidence_level is None: raise ValueError( "Invalid confidence_level. " - "confidence_level cannot be ``None`` when controlling " - "Recall with RCPS or Precision with LTT" + f"confidence_level cannot be ``None`` when using method '{self.method}'." ) elif (confidence_level <= 0) or (confidence_level >= 1): raise ValueError( - "Invalid confidence_level. confidence_level must be in ]0, 1[" + "Invalid confidence_level. confidence_level must be in (0, 1)" ) if (self.method == "crc") and (confidence_level is not None): warnings.warn( - "WARNING: you are using crc method, hence " - + "even if the confidence_level is not ``None``, it won't be" - + "taken into account" + "WARNING: you are using method 'crc', hence " + "even if confidence_level is not ``None``, it will be ignored." ) def _check_valid_index(self, alpha: NDArray): @@ -361,7 +374,7 @@ def _check_compute_risks_first_call(self) -> bool: bool True if it is the first time, else False. """ - return not hasattr(self, "risks") + return not hasattr(self, "_risks") def _check_bound(self, bound: Optional[str]): """ @@ -384,24 +397,6 @@ def _check_bound(self, bound: Optional[str]): + "taken into account." ) - def _check_metric_control(self): - """ - Check that the metrics to control are valid - (can be a string or list of string.) - """ - if self.metric_control not in self.valid_metric_: - raise ValueError( - "Invalid metric. " - "Allowed scores must be in the following list " - + ", ".join(self.valid_metric_) - ) - - if self.method is None: - if self.metric_control == "recall": - self.method = "crc" - else: # self.metric_control == "precision" - self.method = "ltt" - def _transform_pred_proba( self, y_pred_proba: Union[Sequence[NDArray], NDArray] ) -> NDArray: @@ -473,15 +468,25 @@ def compute_risks( y_pred_proba = self._predict_function(X) y_pred_proba_array = self._transform_pred_proba(y_pred_proba) - if self.metric_control == "recall": - risk = compute_risk_recall(self.predict_params, y_pred_proba_array, y) - else: # self.metric_control == "precision" - risk = compute_risk_precision(self.predict_params, y_pred_proba_array, y) + n_lambdas = len(self.predict_params) + n_samples = len(y_pred_proba_array) + + y_pred_proba_array_repeat = np.repeat(y_pred_proba_array, n_lambdas, axis=2) + y_pred = (y_pred_proba_array_repeat > self.predict_params).astype(int) + + risk = np.zeros((n_samples, n_lambdas)) + for index_sample in range(n_samples): + for index_lambda in range(n_lambdas): + risk[index_sample, index_lambda], _ = ( + self._risk.get_value_and_effective_sample_size( + y[index_sample, :], y_pred[index_sample, :, index_lambda] + ) + ) if first_call or _refit: - self.risks = risk + self._risks = risk else: - self.risks = np.vstack((self.risks, risk)) + self._risks = np.vstack((self._risks, risk)) return self @@ -489,9 +494,9 @@ def compute_best_predict_param(self) -> MultiLabelClassificationController: """ Compute optimal predict_params based on the computed risks. 
""" - if self.metric_control == "precision": - self.n_obs = len(self.risks) - self.r_hat = self.risks.mean(axis=0) + if self._risk == precision: + self.n_obs = len(self._risks) + self.r_hat = self._risks.mean(axis=0) self.valid_index, _ = ltt_procedure( np.expand_dims(self.r_hat, axis=0), np.expand_dims(self._alpha, axis=0), @@ -505,9 +510,9 @@ def compute_best_predict_param(self) -> MultiLabelClassificationController: self.best_predict_param, _ = find_precision_best_predict_param( self.r_hat, self.valid_index, self.predict_params ) - else: + elif self._risk == recall: self.r_hat, self.r_hat_plus = get_r_hat_plus( - self.risks, + self._risks, self.predict_params, self.method, self._rcps_bound, diff --git a/mapie/risk_control/risks.py b/mapie/risk_control/risks.py index 3cd6d2adc..27f4c4be9 100644 --- a/mapie/risk_control/risks.py +++ b/mapie/risk_control/risks.py @@ -4,125 +4,6 @@ import numpy as np from numpy.typing import NDArray -from sklearn.utils.validation import column_or_1d - - -def compute_risk_recall(lambdas: NDArray, y_pred_proba: NDArray, y: NDArray) -> NDArray: - """ - In `PrecisionRecallController` when `metric_control=recall`, - compute the recall per observation for each different - thresholds lambdas. - - Parameters - ---------- - y_pred_proba: NDArray of shape (n_samples, n_labels, 1) - Predicted probabilities for each label and each observation. - - y: NDArray of shape (n_samples, n_labels) - True labels. - - lambdas: NDArray of shape (n_lambdas, ) - Threshold that permit to compute recall. - - Returns - ------- - NDArray of shape (n_samples, n_labels, n_lambdas) - Risks for each observation and each value of lambda. - """ - if y_pred_proba.ndim != 3: - raise ValueError( - "y_pred_proba should be a 3d array, got an array of shape " - "{} instead.".format(y_pred_proba.shape) - ) - if y.ndim != 2: - raise ValueError( - "y should be a 2d array, got an array of shape {} instead.".format( - y_pred_proba.shape - ) - ) - if not np.array_equal(y_pred_proba.shape[:-1], y.shape): - raise ValueError("y and y_pred_proba could not be broadcast.") - lambdas = cast(NDArray, column_or_1d(lambdas)) - - n_lambdas = len(lambdas) - y_pred_proba_repeat = np.repeat(y_pred_proba, n_lambdas, axis=2) - y_pred_th = (y_pred_proba_repeat > lambdas).astype(int) - - y_repeat = np.repeat(y[..., np.newaxis], n_lambdas, axis=2) - risks = 1 - (_true_positive(y_pred_th, y_repeat) / y.sum(axis=1)[:, np.newaxis]) - return risks - - -def compute_risk_precision( - lambdas: NDArray, y_pred_proba: NDArray, y: NDArray -) -> NDArray: - """ - In `PrecisionRecallController` when `metric_control=precision`, - compute the precision per observation for each different - thresholds lambdas. - - Parameters - ---------- - y_pred_proba: NDArray of shape (n_samples, n_labels, 1) - Predicted probabilities for each label and each observation. - - y: NDArray of shape (n_samples, n_labels) - True labels. - - lambdas: NDArray of shape (n_lambdas, ) - Threshold that permit to compute precision score. - - Returns - ------- - NDArray of shape (n_samples, n_labels, n_lambdas) - Risks for each observation and each value of lambda. 
- """ - if y_pred_proba.ndim != 3: - raise ValueError( - "y_pred_proba should be a 3d array, got an array of shape " - "{} instead.".format(y_pred_proba.shape) - ) - if y.ndim != 2: - raise ValueError( - "y should be a 2d array, got an array of shape {} instead.".format( - y_pred_proba.shape - ) - ) - if not np.array_equal(y_pred_proba.shape[:-1], y.shape): - raise ValueError("y and y_pred_proba could not be broadcast.") - lambdas = cast(NDArray, column_or_1d(lambdas)) - - n_lambdas = len(lambdas) - y_pred_proba_repeat = np.repeat(y_pred_proba, n_lambdas, axis=2) - y_pred_th = (y_pred_proba_repeat > lambdas).astype(int) - - y_repeat = np.repeat(y[..., np.newaxis], n_lambdas, axis=2) - with np.errstate(divide="ignore", invalid="ignore"): - risks = 1 - _true_positive(y_pred_th, y_repeat) / y_pred_th.sum(axis=1) - risks[np.isnan(risks)] = 1 # nan value indicate high risks. - - return risks - - -def _true_positive(y_pred_th: NDArray, y_repeat: NDArray) -> NDArray: - """ - Compute the number of true positive. - - Parameters - ---------- - y_pred_proba : NDArray of shape (n_samples, n_labels, 1) - Predicted probabilities for each label and each observation. - - y: NDArray of shape (n_samples, n_labels) - True labels. - - Returns - ------- - tp: float - The number of true positive. - """ - tp = (y_pred_th * y_repeat).sum(axis=1) - return tp class BinaryClassificationRisk: @@ -130,8 +11,9 @@ class BinaryClassificationRisk: Define a risk (or a performance metric) to be used with the BinaryClassificationController. Predefined instances are implemented, see :data:`mapie.risk_control.precision`, :data:`mapie.risk_control.recall`, - :data:`mapie.risk_control.accuracy` and - :data:`mapie.risk_control.false_positive_rate`. + :data:`mapie.risk_control.accuracy`, + :data:`mapie.risk_control.false_positive_rate`, and + :data:`mapie.risk_control.predicted_positive_fraction`. Here, a binary classification risk (or performance) is defined by an occurrence and a condition. Let's take the example of precision. Precision is the sum of true @@ -177,8 +59,12 @@ class BinaryClassificationRisk: def __init__( self, - risk_occurrence: Callable[[int, int], bool], - risk_condition: Callable[[int, int], bool], + risk_occurrence: Callable[ + [NDArray[np.integer], NDArray[np.integer]], NDArray[np.bool_] + ], + risk_condition: Callable[ + [NDArray[np.integer], NDArray[np.integer]], NDArray[np.bool_] + ], higher_is_better: bool, ): self._risk_occurrence = risk_occurrence @@ -220,32 +106,23 @@ def get_value_and_effective_sample_size( If the risk is not defined (condition never met), the value is set to 1, and the number of effective samples is set to -1. 
""" - risk_occurrences = np.array( - [ - self._risk_occurrence(y_true_i, y_pred_i) - for y_true_i, y_pred_i in zip(y_true, y_pred) - ] - ) - risk_conditions = np.array( - [ - self._risk_condition(y_true_i, y_pred_i) - for y_true_i, y_pred_i in zip(y_true, y_pred) - ] - ) + risk_occurrences = self._risk_occurrence(y_true, y_pred) + risk_conditions = self._risk_condition(y_true, y_pred) + effective_sample_size = len(y_true) - np.sum(~risk_conditions) # Casting needed for MyPy with Python 3.9 effective_sample_size_int = cast(int, effective_sample_size) - if effective_sample_size_int != 0: + if effective_sample_size_int != 0.0: risk_sum: int = np.sum(risk_occurrences[risk_conditions]) risk_value = risk_sum / effective_sample_size_int + if self.higher_is_better: + risk_value = 1 - risk_value + return risk_value, effective_sample_size_int else: # In this case, the corresponding lambda shouldn't be considered valid. # In the current LTT implementation, providing n_obs=-1 will result # in an infinite p_value, effectively invaliding the lambda - risk_value, effective_sample_size_int = 1, -1 - if self.higher_is_better: - risk_value = 1 - risk_value - return risk_value, effective_sample_size_int + return 1, -1 precision = BinaryClassificationRisk( @@ -256,7 +133,7 @@ def get_value_and_effective_sample_size( accuracy = BinaryClassificationRisk( risk_occurrence=lambda y_true, y_pred: y_pred == y_true, - risk_condition=lambda y_true, y_pred: True, + risk_condition=lambda y_true, y_pred: np.repeat(True, len(y_true)), higher_is_better=True, ) @@ -274,6 +151,6 @@ def get_value_and_effective_sample_size( predicted_positive_fraction = BinaryClassificationRisk( risk_occurrence=lambda y_true, y_pred: y_pred == 1, - risk_condition=lambda y_true, y_pred: True, + risk_condition=lambda y_true, y_pred: np.repeat(True, len(y_true)), higher_is_better=False, ) diff --git a/mapie/tests/risk_control/test_binary_classification_control.py b/mapie/tests/risk_control/test_binary_classification_control.py index b205497c8..5060143b5 100644 --- a/mapie/tests/risk_control/test_binary_classification_control.py +++ b/mapie/tests/risk_control/test_binary_classification_control.py @@ -125,11 +125,12 @@ def test_binary_classification_risk( if effective_sample_size != 0: expected_value = metric_func(y_true, y_pred) expected_n = effective_sample_size + if risk_instance.higher_is_better: + expected_value = 1 - expected_value else: expected_value = 1 expected_n = -1 - if risk_instance.higher_is_better: - expected_value = 1 - expected_value + assert np.isclose(value, expected_value) assert n == expected_n diff --git a/mapie/tests/risk_control/test_precision_recall_control.py b/mapie/tests/risk_control/test_precision_recall_control.py index 94165f775..9b337fc5e 100644 --- a/mapie/tests/risk_control/test_precision_recall_control.py +++ b/mapie/tests/risk_control/test_precision_recall_control.py @@ -22,7 +22,7 @@ "method": str, "rcps_bound": Optional[str], "random_state": Optional[int], - "metric_control": Optional[str], + "risk": str, }, ) @@ -40,7 +40,7 @@ method="crc", rcps_bound=None, random_state=random_state, - metric_control="recall", + risk="recall", ), ), "rcps_wsr": ( @@ -48,7 +48,7 @@ method="rcps", rcps_bound="wsr", random_state=random_state, - metric_control="recall", + risk="recall", ), ), "rcps_hoeffding": ( @@ -56,7 +56,7 @@ method="rcps", rcps_bound="hoeffding", random_state=random_state, - metric_control="recall", + risk="recall", ), ), "rcps_bernstein": ( @@ -64,7 +64,7 @@ method="rcps", rcps_bound="bernstein", 
random_state=random_state, - metric_control="recall", + risk="recall", ), ), "ltt": ( @@ -72,7 +72,7 @@ method="ltt", rcps_bound=None, random_state=random_state, - metric_control="precision", + risk="precision", ), ), } @@ -216,7 +216,7 @@ def test_valid_metric_method(strategy: str) -> None: mapie_clf = MultiLabelClassificationController( predict_function=toy_predict_function, random_state=random_state, - metric_control=args["metric_control"], + risk=args["risk"], confidence_level=0.9, ) mapie_clf.calibrate(X_toy, y_toy) @@ -249,7 +249,7 @@ def test_predict_output_shape( mapie_clf = MultiLabelClassificationController( predict_function=multilabel_predict_function, method=args["method"], - metric_control=args["metric_control"], + risk=args["risk"], random_state=args["random_state"], target_level=target_level, confidence_level=confidence_level, @@ -271,7 +271,7 @@ def test_results_for_same_alpha(strategy: str) -> None: mapie_clf = MultiLabelClassificationController( predict_function=multilabel_predict_function, method=args["method"], - metric_control=args["metric_control"], + risk=args["risk"], random_state=args["random_state"], target_level=[0.9, 0.9], confidence_level=0.9, @@ -293,7 +293,7 @@ def test_results_for_partial_calibrate(strategy: str) -> None: mapie_clf = MultiLabelClassificationController( predict_function=multilabel_predict_function, method=args["method"], - metric_control=args["metric_control"], + risk=args["risk"], random_state=args["random_state"], target_level=[0.9, 0.9], confidence_level=0.9, @@ -305,7 +305,7 @@ def test_results_for_partial_calibrate(strategy: str) -> None: mapie_clf_partial = MultiLabelClassificationController( predict_function=multilabel_predict_function, method=args["method"], - metric_control=args["metric_control"], + risk=args["risk"], random_state=args["random_state"], target_level=[0.9, 0.9], confidence_level=0.9, @@ -352,7 +352,7 @@ def test_results_for_alpha_as_float_and_arraylike(strategy: str, alpha: Any) -> mapie_clf = MultiLabelClassificationController( predict_function=multilabel_predict_function, method=args["method"], - metric_control=args["metric_control"], + risk=args["risk"], random_state=args["random_state"], target_level=alpha[0], confidence_level=0.1, @@ -364,7 +364,7 @@ def test_results_for_alpha_as_float_and_arraylike(strategy: str, alpha: Any) -> mapie_clf = MultiLabelClassificationController( predict_function=multilabel_predict_function, method=args["method"], - metric_control=args["metric_control"], + risk=args["risk"], random_state=args["random_state"], target_level=alpha[1], confidence_level=0.1, @@ -376,7 +376,7 @@ def test_results_for_alpha_as_float_and_arraylike(strategy: str, alpha: Any) -> mapie_clf = MultiLabelClassificationController( predict_function=multilabel_predict_function, method=args["method"], - metric_control=args["metric_control"], + risk=args["risk"], random_state=args["random_state"], target_level=alpha, confidence_level=0.1, @@ -399,7 +399,7 @@ def test_results_single_and_multi_jobs(strategy: str) -> None: mapie_clf_single = MultiLabelClassificationController( predict_function=multilabel_predict_function, n_jobs=1, - metric_control=args["metric_control"], + risk=args["risk"], random_state=args["random_state"], target_level=0.8, confidence_level=0.1, @@ -408,7 +408,7 @@ def test_results_single_and_multi_jobs(strategy: str) -> None: mapie_clf_multi = MultiLabelClassificationController( predict_function=multilabel_predict_function, n_jobs=-1, - metric_control=args["metric_control"], + risk=args["risk"], 
random_state=args["random_state"], target_level=0.8, confidence_level=0.1, @@ -469,7 +469,7 @@ def test_array_output_model( mapie_clf = MultiLabelClassificationController( predict_function=model.predict_proba, method=args["method"], - metric_control=args["metric_control"], + risk=args["risk"], random_state=random_state, target_level=target_level, confidence_level=confidence_level, @@ -485,7 +485,7 @@ def test_reinit_new_fit(): ) mapie_clf.calibrate(X_toy, y_toy) mapie_clf.calibrate(X_toy, y_toy) - assert len(mapie_clf.risks) == len(X_toy) + assert len(mapie_clf._risks) == len(X_toy) @pytest.mark.parametrize("method", WRONG_METHODS) @@ -534,11 +534,11 @@ def test_bound_error(bound: str) -> None: @pytest.mark.parametrize("metric_control", WRONG_METRICS) def test_metric_error_in_init(metric_control: str) -> None: """Test error for wrong metrics""" - with pytest.raises(ValueError, match=r".*Invalid metric. *"): + with pytest.raises(ValueError, match=r".*risk must be one of:*"): MultiLabelClassificationController( predict_function=toy_predict_function, random_state=random_state, - metric_control=metric_control, + risk=metric_control, ) @@ -559,7 +559,7 @@ def test_error_ltt_confidence_level_null() -> None: MultiLabelClassificationController( predict_function=toy_predict_function, random_state=random_state, - metric_control="precision", + risk="precision", confidence_level=None, ) @@ -583,7 +583,7 @@ def test_error_confidence_level_wrong_value_ltt(confidence_level: Any) -> None: MultiLabelClassificationController( predict_function=toy_predict_function, random_state=random_state, - metric_control="precision", + risk="precision", confidence_level=confidence_level, ) @@ -605,7 +605,7 @@ def test_bound_none_ltt() -> None: MultiLabelClassificationController( predict_function=toy_predict_function, random_state=random_state, - metric_control="precision", + risk="precision", confidence_level=0.9, rcps_bound="wsr", ) @@ -613,7 +613,7 @@ def test_bound_none_ltt() -> None: def test_confidence_level_none_crc() -> None: """Test that a warning is raised when confidence_level is not none with CRC method.""" - with pytest.warns(UserWarning, match=r"WARNING: you are using crc*"): + with pytest.warns(UserWarning, match=r"WARNING: you are using method 'crc'*"): MultiLabelClassificationController( predict_function=toy_predict_function, random_state=random_state, @@ -646,7 +646,7 @@ def test_error_confidence_level_wrong_type_ltt(confidence_level: Any) -> None: MultiLabelClassificationController( predict_function=toy_predict_function, random_state=random_state, - metric_control="precision", + risk="precision", confidence_level=confidence_level, ) @@ -692,7 +692,7 @@ def test_pipeline_compatibility(strategy: str) -> None: mapie_clf = MultiLabelClassificationController( predict_function=pipe.predict_proba, method=args["method"], - metric_control=args["metric_control"], + risk=args["risk"], random_state=random_state, confidence_level=0.9, rcps_bound=args["rcps_bound"], @@ -727,7 +727,7 @@ def test_toy_dataset_predictions(strategy: str) -> None: mapie_clf = MultiLabelClassificationController( predict_function=toy_predict_function, method=args["method"], - metric_control=args["metric_control"], + risk=args["risk"], random_state=random_state, target_level=0.8, confidence_level=0.9, @@ -735,6 +735,8 @@ def test_toy_dataset_predictions(strategy: str) -> None: ) mapie_clf.calibrate(X_toy, y_toy) y_ps = mapie_clf.predict(X_toy) + print(y_ps) + print(y_toy_mapie[strategy]) np.testing.assert_allclose(y_ps[:, :, 0], 
y_toy_mapie[strategy], rtol=1e-6) @@ -744,11 +746,11 @@ def test_error_wrong_method_metric_precision(method: str) -> None: Test that an error is returned when using a metric with invalid method . """ - with pytest.raises(ValueError, match=r".*Invalid method for metric*"): + with pytest.raises(ValueError, match=r".*Invalid method.*"): MultiLabelClassificationController( predict_function=toy_predict_function, method=method, - metric_control="precision", + risk="precision", ) @@ -758,18 +760,18 @@ def test_check_metric_control(method: str) -> None: Test that an error is returned when using a metric with invalid method . """ - with pytest.raises(ValueError, match=r".*Invalid method for metric*"): + with pytest.raises(ValueError, match=r".*Invalid method.*"): MultiLabelClassificationController( predict_function=toy_predict_function, method=method, - metric_control="recall", + risk="recall", ) def test_method_none_precision() -> None: mapie_clf = MultiLabelClassificationController( predict_function=toy_predict_function, - metric_control="precision", + risk="precision", confidence_level=0.9, ) mapie_clf.calibrate(X_toy, y_toy) @@ -778,7 +780,7 @@ def test_method_none_precision() -> None: def test_method_none_recall() -> None: mapie_clf = MultiLabelClassificationController( - predict_function=toy_predict_function, metric_control="recall" + predict_function=toy_predict_function, risk="recall" ) mapie_clf.calibrate(X_toy, y_toy) assert mapie_clf.method == "crc" diff --git a/mapie/tests/risk_control/test_risk_control.py b/mapie/tests/risk_control/test_risk_control.py index 9423a1346..bfbbe6e38 100644 --- a/mapie/tests/risk_control/test_risk_control.py +++ b/mapie/tests/risk_control/test_risk_control.py @@ -15,7 +15,6 @@ find_precision_best_predict_param, ltt_procedure, ) -from mapie.risk_control.risks import compute_risk_precision, compute_risk_recall lambdas = np.array([0.5, 0.9]) @@ -54,67 +53,6 @@ prng = np.random.RandomState(random_state) -def test_compute_recall_equal() -> None: - """Test that compute_recall give good result""" - recall = compute_risk_recall(lambdas, y_preds_proba, y_toy) - np.testing.assert_equal(recall, test_recall) - - -def test_compute_precision() -> None: - """Test that compute_precision give good result""" - precision = compute_risk_precision(lambdas, y_preds_proba, y_toy) - np.testing.assert_equal(precision, test_precision) - - -@pytest.mark.filterwarnings("ignore:: RuntimeWarning") -def test_recall_with_zero_sum_is_equal_nan() -> None: - """Test compute_recall with nan values""" - y_toy = np.zeros((4, 3)) - y_preds_proba = prng.rand(4, 3, 1) - recall = compute_risk_recall(lambdas, y_preds_proba, y_toy) - np.testing.assert_array_equal(recall, np.full_like(recall, np.nan)) - - -def test_precision_with_zero_sum_is_equal_ones() -> None: - """Test compute_precision with nan values""" - y_toy = prng.rand(4, 3) - y_preds_proba = np.zeros((4, 3, 1)) - precision = compute_risk_precision(lambdas, y_preds_proba, y_toy) - np.testing.assert_array_equal(precision, np.ones_like(precision)) - - -def test_compute_recall_shape() -> None: - """Test shape when using _compute_recall""" - recall = compute_risk_recall(lambdas, y_preds_proba, y_toy) - np.testing.assert_equal(recall.shape, test_recall.shape) - - -def test_compute_precision_shape() -> None: - """Test shape when using _compute_precision""" - precision = compute_risk_precision(lambdas, y_preds_proba, y_toy) - np.testing.assert_equal(precision.shape, test_precision.shape) - - -def test_compute_recall_with_wrong_shape() -> None: - 
"""Test error when wrong shape in _compute_recall""" - with pytest.raises(ValueError, match=r".*y_pred_proba should be a 3d*"): - compute_risk_recall(lambdas, y_preds_proba.squeeze(), y_toy) - with pytest.raises(ValueError, match=r".*y should be a 2d*"): - compute_risk_recall(lambdas, y_preds_proba, np.expand_dims(y_toy, 2)) - with pytest.raises(ValueError, match=r".*could not be broadcast*"): - compute_risk_recall(lambdas, y_preds_proba, y_toy[:-1]) - - -def test_compute_precision_with_wrong_shape() -> None: - """Test shape when using _compute_precision""" - with pytest.raises(ValueError, match=r".*y_pred_proba should be a 3d*"): - compute_risk_precision(lambdas, y_preds_proba.squeeze(), y_toy) - with pytest.raises(ValueError, match=r".*y should be a 2d*"): - compute_risk_precision(lambdas, y_preds_proba, np.expand_dims(y_toy, 2)) - with pytest.raises(ValueError, match=r".*could not be broadcast*"): - compute_risk_precision(lambdas, y_preds_proba, y_toy[:-1]) - - @pytest.mark.parametrize("alpha", [0.5, [0.5], [0.5, 0.9]]) def test_p_values_different_alpha(alpha: Union[float, NDArray]) -> None: """Test type for different alpha for p_values"""