Skip to content

Commit

Permalink
version 1.1.4 (#67)
Browse files Browse the repository at this point in the history
Adds device input as a function to estimator
  • Loading branch information
egillax authored Apr 19, 2023
1 parent f4c5e92 commit 935f51d
Show file tree
Hide file tree
Showing 7 changed files with 61 additions and 11 deletions.
7 changes: 3 additions & 4 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
Package: DeepPatientLevelPrediction
Type: Package
Title: Deep Learning For Patient Level Prediction Using Data In The OMOP Common Data Model
Version: 1.1.3
Date: 15-12-2022
Version: 1.1.4
Date: 18-04-2023
Authors@R: c(
person("Egill", "Fridgeirsson", email = "[email protected]", role = c("aut", "cre")),
person("Jenna", "Reps", email = "[email protected]", role = c("aut")),
Expand All @@ -24,8 +24,7 @@ Imports:
ParallelLogger (>= 2.0.0),
PatientLevelPrediction (>= 6.0.4),
rlang,
torch (>= 0.9.0),
torchopt,
torch (>= 0.10.0),
withr
Suggests:
devtools,
Expand Down
6 changes: 6 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
DeepPatientLevelPrediction 1.1.4
======================
- Remove torchopt dependancy since adamw is now in torch
- Update torch dependency to >=0.10.0
- Allow device to be a function that resolves during Estimator initialization

DeepPatientLevelPrediction 1.1.3
======================
- Fix actions after torch updated to v0.10 (#65)
Expand Down
7 changes: 6 additions & 1 deletion R/Estimator-class.R
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,12 @@ Estimator <- R6::R6Class(
modelParameters,
estimatorSettings) {
self$seed <- estimatorSettings$seed
self$device <- estimatorSettings$device
if (is.function(estimatorSettings$device)) {
device <- estimatorSettings$device()
} else {
device <- estimatorSettings$device
}
self$device <- device
torch::torch_manual_seed(seed=self$seed)
self$model <- do.call(modelType, modelParameters)
self$modelParameters <- modelParameters
Expand Down
5 changes: 3 additions & 2 deletions R/Estimator.R
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
#' @param weightDecay what weight_decay to use
#' @param batchSize batchSize to use
#' @param epochs how many epochs to train for
#' @param device what device to train on
#' @param device what device to train on, can be a string or a function to that evaluates
#' to the device during runtime
#' @param optimizer which optimizer to use
#' @param scheduler which learning rate scheduler to use
#' @param criterion loss function to use
Expand All @@ -41,7 +42,7 @@ setEstimator <- function(learningRate='auto',
batchSize = 512,
epochs = 30,
device='cpu',
optimizer = torchopt::optim_adamw,
optimizer = torch::optim_adamw,
scheduler = list(fun=torch::lr_reduce_on_plateau,
params=list(patience=1)),
criterion = torch::nn_bce_with_logits_loss,
Expand Down
5 changes: 3 additions & 2 deletions man/setEstimator.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

40 changes: 39 additions & 1 deletion tests/testthat/test-Estimator.R
Original file line number Diff line number Diff line change
Expand Up @@ -298,4 +298,42 @@ test_that("setEstimator with paramsToTune is correctly added to hyperparameters"
expect_equal(estimatorSettings2$learningRate, 1e-3)
expect_equal(as.character(estimatorSettings2$metric), "auprc")
expect_equal(estimatorSettings2$earlyStopping$params$patience, 10)
})
})

test_that("device as a function argument works", {
getDevice <- function() {
dev <- Sys.getenv("testDeepPLPDevice")
if (dev == ""){
dev = "cpu"
} else{
dev
}
}

estimatorSettings <- setEstimator(device=getDevice)

model <- setDefaultResNet(estimatorSettings = estimatorSettings)
model$param[[1]]$catFeatures <- 10

estimator <- Estimator$new(modelType="ResNet",
modelParameters = model$param[[1]],
estimatorSettings = estimatorSettings)

expect_equal(estimator$device, "cpu")

Sys.setenv("testDeepPLPDevice" = "meta")

estimatorSettings <- setEstimator(device=getDevice)

model <- setDefaultResNet(estimatorSettings = estimatorSettings)
model$param[[1]]$catFeatures <- 10

estimator <- Estimator$new(modelType="ResNet",
modelParameters = model$param[[1]],
estimatorSettings = estimatorSettings)

expect_equal(estimator$device, "meta")

Sys.unsetenv("testDeepPLPDevice")

})
2 changes: 1 addition & 1 deletion tests/testthat/test-LRFinder.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ test_that("LR scheduler that changes per batch works", {
model <- ResNet(catFeatures = 10, numFeatures = 1,
sizeEmbedding = 32, sizeHidden = 64,
numLayers = 1, hiddenFactor = 1)
optimizer <- torchopt::optim_adamw(model$parameters, lr=1e-7)
optimizer <- torch::optim_adamw(model$parameters, lr=1e-7)

scheduler <- lrPerBatch(optimizer,
startLR = 1e-7,
Expand Down

0 comments on commit 935f51d

Please sign in to comment.