diff --git a/inst/python/Estimator.py b/inst/python/Estimator.py index 69ee153..c9ad8b3 100644 --- a/inst/python/Estimator.py +++ b/inst/python/Estimator.py @@ -99,7 +99,8 @@ def __init__(self, model, model_parameters, estimator_settings): self.best_score = None self.best_epoch = None self.learn_rate_schedule = None - if estimator_settings["compile"]: + torch_compile = estimator_settings.get("compile", False) + if torch_compile: self.model = torch.compile(self.model, dynamic=False) def fit(self, dataset, test_dataset):