Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions mostlyai/engine/_tabular/probability.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ def _generate_marginal_probs(
tgt_stats: dict,
seed_columns: list[str],
device: torch.device,
n_samples: int,
ctx_data: pd.DataFrame | None = None,
ctx_stats: dict | None = None,
fixed_probs: dict | None = None,
Expand All @@ -259,14 +260,14 @@ def _generate_marginal_probs(
tgt_stats: Target statistics
seed_columns: Seed column names in original format, in correct order
device: Device for computation
n_samples: Number of samples to generate probabilities for
ctx_data: Optional context data
ctx_stats: Optional context statistics (required if ctx_data provided)
fixed_probs: Optional fixed probabilities for rare token handling

Returns:
DataFrame of shape (n_samples, cardinality) with probabilities and column names
"""
n_samples = len(seed_encoded)
target_stats = tgt_stats["columns"][target_column]

# Build fixed_values dict from seed_encoded
Expand Down Expand Up @@ -407,6 +408,7 @@ def predict_proba(
tgt_stats=tgt_stats,
seed_columns=seed_columns,
device=device,
n_samples=n_samples,
ctx_data=ctx_data,
ctx_stats=ctx_stats,
fixed_probs=fixed_probs,
Expand Down Expand Up @@ -451,10 +453,10 @@ def predict_proba(
# Build DataFrames for each combo with actual values, then concatenate
combo_dfs = []
for combo_idx, prev_combo in enumerate(prev_combos):
# Copy extended_seed for this combo
df = extended_seed.copy()
# Build data dict starting with columns from extended_seed
data = {col: extended_seed[col].values for col in extended_seed.columns}

# Add previous target columns with actual values (no dummy values)
# Add previous target columns with actual values
for i in range(target_idx):
prev_target_col = target_columns[i]
encoded_val = prev_combo[i]
Expand All @@ -466,8 +468,10 @@ def predict_proba(
argn_column=prev_target_stats[ARGN_COLUMN],
argn_sub_column=sub_col_key,
)
df[full_sub_col_name] = encoded_val
data[full_sub_col_name] = encoded_val

# Create DataFrame with explicit row count
df = pd.DataFrame(data, index=range(n_samples))
combo_dfs.append(df)

# Concatenate all combo DataFrames into single batch
Expand All @@ -490,6 +494,7 @@ def predict_proba(
tgt_stats=tgt_stats,
seed_columns=extended_seed_columns,
device=device,
n_samples=n_samples * num_prev_combos,
ctx_data=batched_ctx_data,
ctx_stats=ctx_stats,
fixed_probs=fixed_probs,
Expand Down
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.