torch compile and slightly more efficient conversions to torch from polars (#133)

* torch compile and slightly more efficient conversions to torch from polars

egillax authored Nov 21, 2024
1 parent 89cc563 commit 4b99d90

Showing 8 changed files with 23 additions and 10 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/R_CDM_check_hades.yaml

@@ -75,6 +75,10 @@ jobs:
           extra-packages: any::rcmdcheck
           needs: check

+      - uses: actions/setup-python@v5
+        with:
+          python-version: '3.11'
+
       - name: setup r-reticulate venv
         shell: Rscript {0}
         run: |
2 changes: 1 addition & 1 deletion DESCRIPTION

@@ -38,7 +38,7 @@ Suggests:
 Remotes:
     ohdsi/PatientLevelPrediction,
     ohdsi/ResultModelManager
-RoxygenNote: 7.3.1
+RoxygenNote: 7.3.2
 Encoding: UTF-8
 Config/testthat/edition: 3
 Config/testthat/parallel: TRUE
4 changes: 4 additions & 0 deletions NEWS.md

@@ -1,6 +1,10 @@

 DeepPatientLevelPrediction 2.1.0.999
 ======================
+- Add an option to use torch compile
+- More efficient conversions from polars to torch in dataset processing
+- Automatically detect broken links in docs using github actions
+- Model initialization made more flexible with classes

 DeepPatientLevelPrediction 2.1.0
 ======================
4 changes: 4 additions & 0 deletions R/Estimator.R

@@ -33,6 +33,7 @@
 #' @param criterion loss function to use
 #' @param earlyStopping If earlyStopping should be used, which stops the
 #' training if your metric is not improving
+#' @param compile if the model should be compiled before training, default FALSE
 #' @param metric either `auc` or `loss` or a custom metric to use. This is the
 #' metric used for scheduler and earlyStopping.
 #' Needs to be a list with function `fun`, mode either `min` or `max` and a

@@ -59,6 +60,7 @@ setEstimator <- function(
     useEarlyStopping = TRUE,
     params = list(patience = 4)
   ),
+  compile = FALSE,
   metric = "auc",
   accumulationSteps = NULL,
   seed = NULL) {

@@ -74,6 +76,7 @@
   checkIsClass(epochs, c("numeric", "integer"))
   checkHigher(epochs, 0)
   checkIsClass(earlyStopping, c("list", "NULL"))
+  checkIsClass(compile, "logical")
   checkIsClass(metric, c("character", "list"))
   checkIsClass(seed, c("numeric", "integer", "NULL"))

@@ -100,6 +103,7 @@
     epochs = epochs,
     device = device,
     earlyStopping = earlyStopping,
+    compile = compile,
     findLR = findLR,
     metric = metric,
     accumulationSteps = accumulationSteps,
11 changes: 3 additions & 8 deletions inst/python/Dataset.py

@@ -50,7 +50,7 @@ def __init__(self, data, labels=None, numerical_features=None):
             .with_columns(pl.col("rowId") - 1)
             .collect()
         )
-        cat_tensor = torch.tensor(data_cat.to_numpy())
+        cat_tensor = data_cat.to_torch()
         tensor_list = torch.split(
             cat_tensor[:, 1],
             torch.unique_consecutive(cat_tensor[:, 0], return_counts=True)[1].tolist(),
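The context lines above use a compact idiom worth spelling out: a two-column tensor sorted by `rowId` is split into one variable-length tensor of column IDs per row, by feeding the per-row counts from `torch.unique_consecutive` into `torch.split`. A minimal sketch with made-up values:

```python
import torch

# (rowId, columnId) pairs, already sorted by rowId, as in the dataset code.
cat_tensor = torch.tensor([
    [0, 10],
    [0, 11],
    [1, 12],
    [2, 13],
    [2, 14],
    [2, 15],
])

# Counts of each consecutive run of rowIds: here [2, 1, 3].
counts = torch.unique_consecutive(cat_tensor[:, 0], return_counts=True)[1]

# Split the columnId column into one ragged chunk per row.
tensor_list = torch.split(cat_tensor[:, 1], counts.tolist())
# tensor_list[0] -> tensor([10, 11]); tensor_list[2] -> tensor([13, 14, 15])
```

This avoids a Python-level group-by loop; the only host round-trip is the `.tolist()` on the counts, which `torch.split` requires.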
@@ -90,13 +90,8 @@ def __init__(self, data, labels=None, numerical_features=None):
             )
             .collect()
         )
-        indices = torch.as_tensor(
-            numerical_data.select(["rowId", "columnId"]).to_numpy(),
-            dtype=torch.long,
-        )
-        values = torch.tensor(
-            numerical_data.select("covariateValue").to_numpy(), dtype=torch.float
-        )
+        indices = numerical_data.select(["rowId", "columnId"]).to_torch(dtype=pl.Int64)
+        values = numerical_data.select("covariateValue").to_torch(dtype=pl.Float32)
         self.num = torch.sparse_coo_tensor(
             indices=indices.T,
             values=values.squeeze(),
3 changes: 3 additions & 0 deletions inst/python/Estimator.py

@@ -99,6 +99,9 @@ def __init__(self, model, model_parameters, estimator_settings):
         self.best_score = None
         self.best_epoch = None
         self.learn_rate_schedule = None
+        torch_compile = estimator_settings.get("compile", False)
+        if torch_compile:
+            self.model = torch.compile(self.model, dynamic=False)

     def fit(self, dataset, test_dataset):
         train_dataloader = DataLoader(
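The pattern added here gates `torch.compile` behind a settings flag that defaults to off. A self-contained sketch of the same wiring — `TinyNet` and the `estimator_settings` dict are illustrative stand-ins, and `backend="eager"` is used only so the sketch runs without a compiler toolchain (the commit itself uses the default inductor backend with `dynamic=False`):

```python
import torch

class TinyNet(torch.nn.Module):
    """Illustrative stand-in for the estimator's model."""
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x):
        return self.linear(x)

estimator_settings = {"compile": True}  # absent key falls back to False, as in the diff
model = TinyNet()
if estimator_settings.get("compile", False):
    # The commit calls torch.compile(self.model, dynamic=False); "eager"
    # keeps this sketch portable. Compilation is lazy either way: the
    # wrapped model is only traced on its first forward call.
    model = torch.compile(model, dynamic=False, backend="eager")

out = model(torch.randn(8, 4))
```

`dynamic=False` tells the compiler to specialize on the observed input shapes rather than emit shape-polymorphic code, which suits fixed-size batches.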
3 changes: 3 additions & 0 deletions man/setEstimator.Rd
(generated file, not rendered)

2 changes: 1 addition & 1 deletion man/setMultiLayerPerceptron.Rd
(generated file, not rendered)
