@@ -245,6 +245,7 @@ def _generate_marginal_probs(
245245 tgt_stats : dict ,
246246 seed_columns : list [str ],
247247 device : torch .device ,
248+ n_samples : int ,
248249 ctx_data : pd .DataFrame | None = None ,
249250 ctx_stats : dict | None = None ,
250251 fixed_probs : dict | None = None ,
@@ -259,14 +260,14 @@ def _generate_marginal_probs(
259260 tgt_stats: Target statistics
260261 seed_columns: Seed column names in original format, in correct order
261262 device: Device for computation
263+ n_samples: Number of samples to generate probabilities for
262264 ctx_data: Optional context data
263265 ctx_stats: Optional context statistics (required if ctx_data provided)
264266 fixed_probs: Optional fixed probabilities for rare token handling
265267
266268 Returns:
267269 DataFrame of shape (n_samples, cardinality) with probabilities and column names
268270 """
269- n_samples = len (seed_encoded )
270271 target_stats = tgt_stats ["columns" ][target_column ]
271272
272273 # Build fixed_values dict from seed_encoded
@@ -407,6 +408,7 @@ def predict_proba(
407408 tgt_stats = tgt_stats ,
408409 seed_columns = seed_columns ,
409410 device = device ,
411+ n_samples = n_samples ,
410412 ctx_data = ctx_data ,
411413 ctx_stats = ctx_stats ,
412414 fixed_probs = fixed_probs ,
@@ -451,10 +453,10 @@ def predict_proba(
451453 # Build DataFrames for each combo with actual values, then concatenate
452454 combo_dfs = []
453455 for combo_idx , prev_combo in enumerate (prev_combos ):
454- # Copy extended_seed for this combo
455- df = extended_seed . copy ()
456+ # Build data dict starting with columns from extended_seed
457+ data = { col : extended_seed [ col ]. values for col in extended_seed . columns }
456458
457- # Add previous target columns with actual values (no dummy values)
459+ # Add previous target columns with actual values
458460 for i in range (target_idx ):
459461 prev_target_col = target_columns [i ]
460462 encoded_val = prev_combo [i ]
@@ -466,8 +468,10 @@ def predict_proba(
466468 argn_column = prev_target_stats [ARGN_COLUMN ],
467469 argn_sub_column = sub_col_key ,
468470 )
469- df [full_sub_col_name ] = encoded_val
471+ data [full_sub_col_name ] = encoded_val
470472
473+ # Create DataFrame with explicit row count
474+ df = pd .DataFrame (data , index = range (n_samples ))
471475 combo_dfs .append (df )
472476
473477 # Concatenate all combo DataFrames into single batch
@@ -490,6 +494,7 @@ def predict_proba(
490494 tgt_stats = tgt_stats ,
491495 seed_columns = extended_seed_columns ,
492496 device = device ,
497+ n_samples = n_samples * num_prev_combos ,
493498 ctx_data = batched_ctx_data ,
494499 ctx_stats = ctx_stats ,
495500 fixed_probs = fixed_probs ,
0 commit comments