Skip to content

Commit ee9d7f3

Browse files
committed
Merge branch 'main' into release
2 parents 4961665 + 4cccff4 commit ee9d7f3

File tree

3 files changed

+231
-207
lines changed

3 files changed

+231
-207
lines changed

ethicml/plot/plotting.py

+17-7
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def _multivariate_grid(
9191
df: pd.DataFrame,
9292
scatter_alpha: float = 0.5,
9393
) -> None:
94-
def colored_scatter(x: Any, y: Any, c: str | None = None) -> Callable[[Any], None]:
94+
def colored_scatter(x: pd.Series, y: pd.Series, c: str | None = None) -> Callable[..., None]:
9595
def scatter(*args: Any, **kwargs: Any) -> None:
9696
args = (x, y)
9797
if c is not None:
@@ -109,15 +109,25 @@ def scatter(*args: Any, **kwargs: Any) -> None:
109109
for name, df_group in df.groupby([sens_col, outcome_col]):
110110
legends.append(f"S={name[0]}, Y={name[1]}")
111111
g.plot_joint(colored_scatter(df_group[col_x], df_group[col_y], color))
112-
sns.distplot( # type: ignore[attr-defined]
113-
df_group[col_x].to_numpy(), ax=g.ax_marg_x, color=color
112+
sns.histplot( # type: ignore[attr-defined]
113+
df_group[col_x].to_numpy(),
114+
ax=g.ax_marg_x,
115+
color=color,
116+
kde=True,
117+
stat="density",
118+
kde_kws=dict(cut=3),
114119
)
115-
sns.distplot( # type: ignore[attr-defined]
116-
df_group[col_y].to_numpy(), ax=g.ax_marg_y, vertical=True
120+
sns.histplot( # type: ignore[attr-defined]
121+
df_group[col_y].to_numpy(),
122+
ax=g.ax_marg_y,
123+
vertical=True,
124+
kde=True,
125+
stat="density",
126+
kde_kws=dict(cut=3),
117127
)
118128
# Do also global Hist:
119-
# sns.distplot(df[col_x].values, ax=g.ax_marg_x, color='grey')
120-
# sns.distplot(df[col_y].values.ravel(), ax=g.ax_marg_y, color='grey', vertical=True)
129+
# sns.histplot(df[col_x].values, ax=g.ax_marg_x, color='grey')
130+
# sns.histplot(df[col_y].values.ravel(), ax=g.ax_marg_y, color='grey', vertical=True)
121131
plt.legend(legends)
122132

123133

0 commit comments

Comments
 (0)