From 3fb855a9d00b78c9eb023ad732d7386788049060 Mon Sep 17 00:00:00 2001 From: egillax Date: Thu, 21 Nov 2024 15:28:28 +0100 Subject: [PATCH] torch compile and slightly more efficient conversions to torch from polars --- R/Estimator.R | 4 ++++ inst/python/Dataset.py | 11 +++-------- inst/python/Estimator.py | 2 ++ 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/R/Estimator.R b/R/Estimator.R index 19f414e..d4b8bc5 100644 --- a/R/Estimator.R +++ b/R/Estimator.R @@ -33,6 +33,7 @@ #' @param criterion loss function to use #' @param earlyStopping If earlyStopping should be used which stops the #' training of your metric is not improving +#' @param compile if the model should be compiled before training, default FALSE #' @param metric either `auc` or `loss` or a custom metric to use. This is the #' metric used for scheduler and earlyStopping. #' Needs to be a list with function `fun`, mode either `min` or `max` and a @@ -59,6 +60,7 @@ setEstimator <- function( useEarlyStopping = TRUE, params = list(patience = 4) ), + compile = FALSE, metric = "auc", accumulationSteps = NULL, seed = NULL) { @@ -74,6 +76,7 @@ setEstimator <- function( checkIsClass(epochs, c("numeric", "integer")) checkHigher(epochs, 0) checkIsClass(earlyStopping, c("list", "NULL")) + checkIsClass(compile, "logical") checkIsClass(metric, c("character", "list")) checkIsClass(seed, c("numeric", "integer", "NULL")) @@ -100,6 +103,7 @@ setEstimator <- function( epochs = epochs, device = device, earlyStopping = earlyStopping, + compile = compile, findLR = findLR, metric = metric, accumulationSteps = accumulationSteps, diff --git a/inst/python/Dataset.py b/inst/python/Dataset.py index ed3c3bd..dfe48a6 100644 --- a/inst/python/Dataset.py +++ b/inst/python/Dataset.py @@ -50,7 +50,7 @@ def __init__(self, data, labels=None, numerical_features=None): .with_columns(pl.col("rowId") - 1) .collect() ) - cat_tensor = torch.tensor(data_cat.to_numpy()) + cat_tensor = data_cat.to_torch() tensor_list = torch.split( 
cat_tensor[:, 1], torch.unique_consecutive(cat_tensor[:, 0], return_counts=True)[1].tolist(), @@ -90,13 +90,8 @@ def __init__(self, data, labels=None, numerical_features=None): ) .collect() ) - indices = torch.as_tensor( - numerical_data.select(["rowId", "columnId"]).to_numpy(), - dtype=torch.long, - ) - values = torch.tensor( - numerical_data.select("covariateValue").to_numpy(), dtype=torch.float - ) + indices = numerical_data.select(["rowId", "columnId"]).to_torch(dtype=pl.Int64) + values = numerical_data.select("covariateValue").to_torch(dtype=pl.Float32) self.num = torch.sparse_coo_tensor( indices=indices.T, values=values.squeeze(), diff --git a/inst/python/Estimator.py b/inst/python/Estimator.py index 1b6ac18..d9a92ee 100644 --- a/inst/python/Estimator.py +++ b/inst/python/Estimator.py @@ -99,6 +99,8 @@ def __init__(self, model, model_parameters, estimator_settings): self.best_score = None self.best_epoch = None self.learn_rate_schedule = None + if estimator_settings["compile"]: + self.model = torch.compile(self.model, dynamic=False) def fit(self, dataset, test_dataset): train_dataloader = DataLoader(