Commit

update description
egillax committed Oct 9, 2022
1 parent 4376997 commit 5a92835
Showing 2 changed files with 12 additions and 12 deletions.
DESCRIPTION (6 changes: 3 additions & 3 deletions)
@@ -2,7 +2,7 @@ Package: DeepPatientLevelPrediction
Type: Package
Title: Deep Learning For Patient Level Prediction Using Data In The OMOP Common Data Model
Version: 1.0.0
- Date: 29-08-2022
+ Date: 09-10-2022
Authors@R: c(
person("Jenna", "Reps", email = "[email protected]", role = c("aut")),
person("Egill", "Fridgeirsson", email = "[email protected]", role = c("aut", "cre")),
@@ -23,7 +23,7 @@ Imports:
data.table,
FeatureExtraction (>= 3.0.0),
ParallelLogger (>= 2.0.0),
- PatientLevelPrediction,
+ PatientLevelPrediction (>= 6.0.4),
rlang,
torch (>= 0.8.0)
Suggests:
@@ -34,7 +34,7 @@ Suggests:
plyr,
testthat
Remotes:
- ohdsi/PatientLevelPrediction@develop,
+ ohdsi/PatientLevelPrediction,
ohdsi/FeatureExtraction,
ohdsi/Eunomia
RoxygenNote: 7.2.1
R/Estimator.R (18 changes: 9 additions & 9 deletions)
@@ -27,7 +27,6 @@
#' @param optimizer which optimizer to use
#' @param scheduler which learning rate scheduler to use
#' @param criterion loss function to use
- #' @param posWeight If more weight should be added to positive labels during training - will result in miscalibrated models
#' @param earlyStopping If earlyStopping should be used, which stops the training if your metric is not improving
#' @param earlyStoppingMetric Which parameter to use for early stopping
#' @param patience patience for earlyStopper
@@ -247,7 +246,6 @@ gridCvDeep <- function(mappedData,
ParallelLogger::logInfo(paste0("Fold ", i))
trainDataset <- torch::dataset_subset(dataset, indices = which(fold != i))
testDataset <- torch::dataset_subset(dataset, indices = which(fold == i))
- # fitParams$posWeight <- trainDataset$dataset$posWeight
estimator <- Estimator$new(
baseModel = baseModel,
modelParameters = modelParams,
@@ -301,7 +299,6 @@
fitParams <- finalParam[fitParamNames]
fitParams$epochs <- finalParam$learnSchedule$bestEpoch
fitParams$batchSize <- batchSize
- fitParams$posWeight <- dataset$posWeight
# create the dir
if (!dir.exists(file.path(modelLocation))) {
dir.create(file.path(modelLocation), recursive = T)
@@ -388,7 +385,6 @@ Estimator <- R6::R6Class(
self$learningRate <- self$itemOrDefaults(fitParameters, "learningRate", 1e-3)
self$l2Norm <- self$itemOrDefaults(fitParameters, "weightDecay", 1e-5)
self$batchSize <- self$itemOrDefaults(fitParameters, "batchSize", 1024)
- self$posWeight <- self$itemOrDefaults(fitParameters, "posWeight", 1)
self$prefix <- self$itemOrDefaults(fitParameters, "prefix", self$model$name)

self$previousEpochs <- self$itemOrDefaults(fitParameters, "previousEpochs", 0)
@@ -399,9 +395,7 @@
lr = self$learningRate,
weight_decay = self$l2Norm
)
- self$criterion <- criterion(torch::torch_tensor(self$posWeight,
- device = self$device
- ))
+ self$criterion <- criterion()

self$scheduler <- scheduler(self$optimizer,
patience = 1,
@@ -615,14 +609,20 @@
batchIndex <- 1:length(dataset)
batchIndex <- split(batchIndex, ceiling(seq_along(batchIndex) / self$batchSize))
torch::with_no_grad({
- predictions <- c()
+ predictions <- torch::torch_empty(length(dataset), device=self$device)
self$model$eval()
+ progressBar <- utils::txtProgressBar(style = 3)
+ ix <- 1
coro::loop(for (b in batchIndex) {
batch <- self$batchToDevice(dataset[b])
target <- batch$target
pred <- self$model(batch$batch)
- predictions <- c(predictions, as.array(torch::torch_sigmoid(pred$cpu())))
+ predictions[b] <- torch::torch_sigmoid(pred)
+ utils::setTxtProgressBar(progressBar, ix / length(batchIndex))
+ ix <- ix + 1
})
+ predictions <- as.array(predictions$cpu())
+ close(progressBar)
})
return(predictions)
},
