Skip to content

Commit 4fc997a

Browse files
committed
fix tests errors
1 parent ef63034 commit 4fc997a

File tree

2 files changed

+6
-8
lines changed

2 files changed

+6
-8
lines changed

mapie/risk_control/multi_label_classification.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -372,7 +372,7 @@ def _check_compute_risks_first_call(self) -> bool:
372372
bool
373373
True if it is the first time, else False.
374374
"""
375-
return not hasattr(self, "risks")
375+
return not hasattr(self, "_risks")
376376

377377
def _check_bound(self, bound: Optional[str]):
378378
"""
@@ -496,10 +496,10 @@ def compute_best_predict_param(self) -> MultiLabelClassificationController:
496496
self.n_obs = len(self._risks)
497497
self.r_hat = self._risks.mean(axis=0)
498498
self.valid_index, _ = ltt_procedure(
499-
self.r_hat,
500-
np.tile(self._alpha, (self.r_hat.shape[0], 1)),
499+
np.expand_dims(self.r_hat, axis=0),
500+
np.expand_dims(self._alpha, axis=0),
501501
cast(float, self._delta),
502-
np.full_like(self.r_hat, self.n_obs),
502+
np.expand_dims(np.array([self.n_obs]), axis=0),
503503
)
504504
self.valid_predict_params = []
505505
for index_list in self.valid_index:

mapie/tests/risk_control/test_precision_recall_control.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -485,7 +485,7 @@ def test_reinit_new_fit():
485485
)
486486
mapie_clf.calibrate(X_toy, y_toy)
487487
mapie_clf.calibrate(X_toy, y_toy)
488-
assert len(mapie_clf.risks) == len(X_toy)
488+
assert len(mapie_clf._risks) == len(X_toy)
489489

490490

491491
@pytest.mark.parametrize("method", WRONG_METHODS)
@@ -534,9 +534,7 @@ def test_bound_error(bound: str) -> None:
534534
@pytest.mark.parametrize("metric_control", WRONG_METRICS)
535535
def test_metric_error_in_init(metric_control: str) -> None:
536536
"""Test error for wrong metrics"""
537-
with pytest.raises(
538-
ValueError, match=r".*When risk is provided as a string, it must be one of:.*"
539-
):
537+
with pytest.raises(ValueError, match=r".*risk must be one of:*"):
540538
MultiLabelClassificationController(
541539
predict_function=toy_predict_function,
542540
random_state=random_state,

0 commit comments

Comments
 (0)