Skip to content

Commit f1a884f

Browse files
authored
fix: resolve issue withpredict_proba when used with ctx_data
1 parent 67767b3 commit f1a884f

File tree

2 files changed

+11
-6
lines changed

2 files changed

+11
-6
lines changed

mostlyai/engine/_tabular/probability.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,7 @@ def _generate_marginal_probs(
245245
tgt_stats: dict,
246246
seed_columns: list[str],
247247
device: torch.device,
248+
n_samples: int,
248249
ctx_data: pd.DataFrame | None = None,
249250
ctx_stats: dict | None = None,
250251
fixed_probs: dict | None = None,
@@ -259,14 +260,14 @@ def _generate_marginal_probs(
259260
tgt_stats: Target statistics
260261
seed_columns: Seed column names in original format, in correct order
261262
device: Device for computation
263+
n_samples: Number of samples to generate probabilities for
262264
ctx_data: Optional context data
263265
ctx_stats: Optional context statistics (required if ctx_data provided)
264266
fixed_probs: Optional fixed probabilities for rare token handling
265267
266268
Returns:
267269
DataFrame of shape (n_samples, cardinality) with probabilities and column names
268270
"""
269-
n_samples = len(seed_encoded)
270271
target_stats = tgt_stats["columns"][target_column]
271272

272273
# Build fixed_values dict from seed_encoded
@@ -407,6 +408,7 @@ def predict_proba(
407408
tgt_stats=tgt_stats,
408409
seed_columns=seed_columns,
409410
device=device,
411+
n_samples=n_samples,
410412
ctx_data=ctx_data,
411413
ctx_stats=ctx_stats,
412414
fixed_probs=fixed_probs,
@@ -451,10 +453,10 @@ def predict_proba(
451453
# Build DataFrames for each combo with actual values, then concatenate
452454
combo_dfs = []
453455
for combo_idx, prev_combo in enumerate(prev_combos):
454-
# Copy extended_seed for this combo
455-
df = extended_seed.copy()
456+
# Build data dict starting with columns from extended_seed
457+
data = {col: extended_seed[col].values for col in extended_seed.columns}
456458

457-
# Add previous target columns with actual values (no dummy values)
459+
# Add previous target columns with actual values
458460
for i in range(target_idx):
459461
prev_target_col = target_columns[i]
460462
encoded_val = prev_combo[i]
@@ -466,8 +468,10 @@ def predict_proba(
466468
argn_column=prev_target_stats[ARGN_COLUMN],
467469
argn_sub_column=sub_col_key,
468470
)
469-
df[full_sub_col_name] = encoded_val
471+
data[full_sub_col_name] = encoded_val
470472

473+
# Create DataFrame with explicit row count
474+
df = pd.DataFrame(data, index=range(n_samples))
471475
combo_dfs.append(df)
472476

473477
# Concatenate all combo DataFrames into single batch
@@ -490,6 +494,7 @@ def predict_proba(
490494
tgt_stats=tgt_stats,
491495
seed_columns=extended_seed_columns,
492496
device=device,
497+
n_samples=n_samples * num_prev_combos,
493498
ctx_data=batched_ctx_data,
494499
ctx_stats=ctx_stats,
495500
fixed_probs=fixed_probs,

uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)