@@ -91,7 +91,7 @@ def _multivariate_grid(
91
91
df : pd .DataFrame ,
92
92
scatter_alpha : float = 0.5 ,
93
93
) -> None :
94
- def colored_scatter (x : Any , y : Any , c : str | None = None ) -> Callable [[ Any ] , None ]:
94
+ def colored_scatter (x : pd . Series , y : pd . Series , c : str | None = None ) -> Callable [... , None ]:
95
95
def scatter (* args : Any , ** kwargs : Any ) -> None :
96
96
args = (x , y )
97
97
if c is not None :
@@ -109,15 +109,25 @@ def scatter(*args: Any, **kwargs: Any) -> None:
109
109
for name , df_group in df .groupby ([sens_col , outcome_col ]):
110
110
legends .append (f"S={ name [0 ]} , Y={ name [1 ]} " )
111
111
g .plot_joint (colored_scatter (df_group [col_x ], df_group [col_y ], color ))
112
- sns .distplot ( # type: ignore[attr-defined]
113
- df_group [col_x ].to_numpy (), ax = g .ax_marg_x , color = color
112
+ sns .histplot ( # type: ignore[attr-defined]
113
+ df_group [col_x ].to_numpy (),
114
+ ax = g .ax_marg_x ,
115
+ color = color ,
116
+ kde = True ,
117
+ stat = "density" ,
118
+ kde_kws = dict (cut = 3 ),
114
119
)
115
- sns .distplot ( # type: ignore[attr-defined]
116
- df_group [col_y ].to_numpy (), ax = g .ax_marg_y , vertical = True
120
+ sns .histplot ( # type: ignore[attr-defined]
121
+ df_group [col_y ].to_numpy (),
122
+ ax = g .ax_marg_y ,
123
+ vertical = True ,
124
+ kde = True ,
125
+ stat = "density" ,
126
+ kde_kws = dict (cut = 3 ),
117
127
)
118
128
# Do also global Hist:
119
- # sns.distplot (df[col_x].values, ax=g.ax_marg_x, color='grey')
120
- # sns.distplot (df[col_y].values.ravel(), ax=g.ax_marg_y, color='grey', vertical=True)
129
+ # sns.histplot (df[col_x].values, ax=g.ax_marg_x, color='grey')
130
+ # sns.histplot (df[col_y].values.ravel(), ax=g.ax_marg_y, color='grey', vertical=True)
121
131
plt .legend (legends )
122
132
123
133
0 commit comments