diff --git a/src/main.py b/src/main.py index 887a7d9..4c57250 100644 --- a/src/main.py +++ b/src/main.py @@ -708,6 +708,11 @@ def short_context_recovery(model, data, base_length, lambda_factors_base, n_hat_ # Fine-tune for short context recovery model = fine_tune(model, data, length, lambda_factors, n_hat, steps=100) + # Update model attributes + key = "4k" if length == 4096 else "8k" + model.lambda_factors[key] = lambda_factors + model.n_hat[key] = n_hat + # Store base factors for use during inference model.lambda_factors_base = lambda_factors_base model.n_hat_base = n_hat_base