diff --git a/mostlyai/engine/_tabular/probability.py b/mostlyai/engine/_tabular/probability.py index 74a1a094..a97a9f3e 100644 --- a/mostlyai/engine/_tabular/probability.py +++ b/mostlyai/engine/_tabular/probability.py @@ -245,6 +245,7 @@ def _generate_marginal_probs( tgt_stats: dict, seed_columns: list[str], device: torch.device, + n_samples: int, ctx_data: pd.DataFrame | None = None, ctx_stats: dict | None = None, fixed_probs: dict | None = None, @@ -259,6 +260,7 @@ def _generate_marginal_probs( tgt_stats: Target statistics seed_columns: Seed column names in original format, in correct order device: Device for computation + n_samples: Number of samples to generate probabilities for ctx_data: Optional context data ctx_stats: Optional context statistics (required if ctx_data provided) fixed_probs: Optional fixed probabilities for rare token handling @@ -266,7 +268,6 @@ def _generate_marginal_probs( Returns: DataFrame of shape (n_samples, cardinality) with probabilities and column names """ - n_samples = len(seed_encoded) target_stats = tgt_stats["columns"][target_column] # Build fixed_values dict from seed_encoded @@ -407,6 +408,7 @@ def predict_proba( tgt_stats=tgt_stats, seed_columns=seed_columns, device=device, + n_samples=n_samples, ctx_data=ctx_data, ctx_stats=ctx_stats, fixed_probs=fixed_probs, @@ -451,10 +453,10 @@ def predict_proba( # Build DataFrames for each combo with actual values, then concatenate combo_dfs = [] for combo_idx, prev_combo in enumerate(prev_combos): - # Copy extended_seed for this combo - df = extended_seed.copy() + # Build data dict starting with columns from extended_seed + data = {col: extended_seed[col].values for col in extended_seed.columns} - # Add previous target columns with actual values (no dummy values) + # Add previous target columns with actual values for i in range(target_idx): prev_target_col = target_columns[i] encoded_val = prev_combo[i] @@ -466,8 +468,10 @@ def predict_proba( argn_column=prev_target_stats[ARGN_COLUMN], argn_sub_column=sub_col_key, ) - df[full_sub_col_name] = encoded_val + data[full_sub_col_name] = encoded_val + # Create DataFrame with explicit row count + df = pd.DataFrame(data, index=range(n_samples)) combo_dfs.append(df) # Concatenate all combo DataFrames into single batch @@ -490,6 +494,7 @@ def predict_proba( tgt_stats=tgt_stats, seed_columns=extended_seed_columns, device=device, + n_samples=n_samples * num_prev_combos, ctx_data=batched_ctx_data, ctx_stats=ctx_stats, fixed_probs=fixed_probs, diff --git a/uv.lock b/uv.lock index 5c767336..b322e48c 100644 --- a/uv.lock +++ b/uv.lock @@ -2282,7 +2282,7 @@ wheels = [ [[package]] name = "mostlyai-engine" -version = "2.3.1" +version = "2.3.3" source = { editable = "." } dependencies = [ { name = "accelerate" },