Skip to content

Commit

Permalink
torch compile and slighly more efficient conversions to torch from po…
Browse files Browse the repository at this point in the history
…lars
  • Loading branch information
egillax committed Nov 21, 2024
1 parent 279f274 commit 3fb855a
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 8 deletions.
4 changes: 4 additions & 0 deletions R/Estimator.R
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#' @param criterion loss function to use
#' @param earlyStopping If earlyStopping should be used which stops the
#' training of your metric is not improving
#' @param compile if the model should be compiled before training, default FALSE
#' @param metric either `auc` or `loss` or a custom metric to use. This is the
#' metric used for scheduler and earlyStopping.
#' Needs to be a list with function `fun`, mode either `min` or `max` and a
Expand All @@ -59,6 +60,7 @@ setEstimator <- function(
useEarlyStopping = TRUE,
params = list(patience = 4)
),
compile = FALSE,
metric = "auc",
accumulationSteps = NULL,
seed = NULL) {
Expand All @@ -74,6 +76,7 @@ setEstimator <- function(
checkIsClass(epochs, c("numeric", "integer"))
checkHigher(epochs, 0)
checkIsClass(earlyStopping, c("list", "NULL"))
checkIsClass(compile, "logical")
checkIsClass(metric, c("character", "list"))
checkIsClass(seed, c("numeric", "integer", "NULL"))

Expand All @@ -100,6 +103,7 @@ setEstimator <- function(
epochs = epochs,
device = device,
earlyStopping = earlyStopping,
compile = compile,
findLR = findLR,
metric = metric,
accumulationSteps = accumulationSteps,
Expand Down
11 changes: 3 additions & 8 deletions inst/python/Dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __init__(self, data, labels=None, numerical_features=None):
.with_columns(pl.col("rowId") - 1)
.collect()
)
cat_tensor = torch.tensor(data_cat.to_numpy())
cat_tensor = data_cat.to_torch()
tensor_list = torch.split(
cat_tensor[:, 1],
torch.unique_consecutive(cat_tensor[:, 0], return_counts=True)[1].tolist(),
Expand Down Expand Up @@ -90,13 +90,8 @@ def __init__(self, data, labels=None, numerical_features=None):
)
.collect()
)
indices = torch.as_tensor(
numerical_data.select(["rowId", "columnId"]).to_numpy(),
dtype=torch.long,
)
values = torch.tensor(
numerical_data.select("covariateValue").to_numpy(), dtype=torch.float
)
indices = numerical_data.select(["rowId", "columnId"]).to_torch(dtype=pl.Int64)
values = numerical_data.select("covariateValue").to_torch(dtype=pl.Float32)
self.num = torch.sparse_coo_tensor(
indices=indices.T,
values=values.squeeze(),
Expand Down
2 changes: 2 additions & 0 deletions inst/python/Estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ def __init__(self, model, model_parameters, estimator_settings):
self.best_score = None
self.best_epoch = None
self.learn_rate_schedule = None
if parameters["estimator_settings"]["compile"]:
self.model = torch.compile(self.model, dynamic=False)

def fit(self, dataset, test_dataset):
train_dataloader = DataLoader(
Expand Down

0 comments on commit 3fb855a

Please sign in to comment.