diff --git a/python/dalex/NEWS.md b/python/dalex/NEWS.md index 23c03e310..8f714acda 100644 --- a/python/dalex/NEWS.md +++ b/python/dalex/NEWS.md @@ -3,6 +3,7 @@ ### development * added a way to pass `sample_weight` to loss functions in `model_parts()` (variable importance) using `weights` from `dx.Explainer` ([#563](https://github.com/ModelOriented/DALEX/issues/563)) +* fixed the visualization of `shap_wrapper` for `shap==0.45.0` ### v1.7.0 (2024-02-28) diff --git a/python/dalex/dalex/wrappers/_shap/object.py b/python/dalex/dalex/wrappers/_shap/object.py index e57901565..c75ec7be5 100644 --- a/python/dalex/dalex/wrappers/_shap/object.py +++ b/python/dalex/dalex/wrappers/_shap/object.py @@ -122,7 +122,14 @@ def plot(self, **kwargs): else: base_value = self.shap_explainer.expected_value - shap_values = self.result[1] if isinstance(self.result, list) else self.result + if isinstance(self.result, list): + shap_values = self.result[1] + elif isinstance(self.result, np.ndarray): + if len(self.result.shape) == 3: + shap_values = self.result[:, :, 1] + else: + shap_values = self.result + force_plot(base_value=base_value, shap_values=shap_values, features=self.new_observation.values,