Skip to content

Commit

Permalink
remove DeepNNTorch models - add MLP using estimator - Update docs and…
Browse files Browse the repository at this point in the history
… tests
  • Loading branch information
egillax committed Aug 29, 2022
1 parent 38c88d6 commit 0cc79e9
Show file tree
Hide file tree
Showing 29 changed files with 503 additions and 1,482 deletions.
6 changes: 3 additions & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
Package: DeepPatientLevelPrediction
Type: Package
Title: Deep Learning For Patient Level Prediction Using Data In The OMOP Common Data Model
Version: 0.0.1
Date: 2021-06-07
Version: 1.0.0
Date: 29-08-2022
Authors@R: c(
person("Jenna", "Reps", email = "[email protected]", role = c("aut")),
person("Egill", "Fridgeirsson", email = "[email protected]", role = c("aut", "cre")),
Expand Down Expand Up @@ -37,6 +37,6 @@ Remotes:
ohdsi/PatientLevelPrediction@develop,
ohdsi/FeatureExtraction,
ohdsi/Eunomia
RoxygenNote: 7.2.0
RoxygenNote: 7.2.1
Encoding: UTF-8
Config/testthat/edition: 3
7 changes: 1 addition & 6 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,12 @@

export(Dataset)
export(Estimator)
export(doubleLayerNN)
export(fitDeepNNTorch)
export(fitEstimator)
export(gridCvDeep)
export(predictDeepEstimator)
export(predictDeepNN)
export(setDeepNNTorch)
export(setMultiLayerPerceptron)
export(setResNet)
export(setTransformer)
export(singleLayerNN)
export(tripleLayerNN)
import(data.table)
importFrom(data.table,":=")
importFrom(dplyr,"%>%")
Expand Down
47 changes: 13 additions & 34 deletions R/Dataset.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@ Dataset <- torch::dataset(
#' @param data a dataframe like object with the covariates
#' @param labels a dataframe with the labels
#' @param numericalIndex in what column numeric data is in (if any)
#' @param all if True then returns all features instead of splitting num/cat
initialize = function(data, labels = NULL, numericalIndex = NULL, all = FALSE) {
initialize = function(data, labels = NULL, numericalIndex = NULL) {
# determine numeric
if (is.null(numericalIndex) && all == FALSE) {
if (is.null(numericalIndex)) {
numericalIndex <- data %>%
dplyr::group_by(columnId) %>%
dplyr::collect() %>%
Expand All @@ -24,23 +23,12 @@ Dataset <- torch::dataset(
if (!is.null(labels)) {
self$target <- torch::torch_tensor(labels)
} else {
if (all == FALSE) {
self$target <- torch::torch_tensor(rep(0, data %>% dplyr::distinct(rowId)
%>% dplyr::collect() %>% nrow()))
} else {
self$target <- torch::torch_tensor(rep(0, dim(data)[[1]]))
}
self$target <- torch::torch_tensor(rep(0, data %>% dplyr::distinct(rowId)
%>% dplyr::collect() %>% nrow()))
}
# Weight to add in loss function to positive class
self$posWeight <- (self$target == 0)$sum() / self$target$sum()
# for DeepNNTorch
self$useAll <- all
if (all) {
self$all <- torch::torch_tensor(as.matrix(data), dtype = torch::torch_float32())
self$cat <- NULL
self$num <- NULL
return()
}

# add features
catColumns <- which(!numericalIndex)
dataCat <- dplyr::filter(data, columnId %in% catColumns) %>%
Expand Down Expand Up @@ -80,9 +68,6 @@ Dataset <- torch::dataset(
size = c(self$target$shape, sum(numericalIndex))
)$to_dense()
}
if (self$cat$shape[1] != self$num$shape[1]) {
browser()
}
},
getNumericalIndex = function() {
return(
Expand Down Expand Up @@ -110,26 +95,20 @@ Dataset <- torch::dataset(
},
.getBatchSingle = function(item) {
# add leading singleton dimension since models expects 2d tensors
if (self$useAll) {
batch <- list(all = self$all[item]$unsqueeze(1))
} else {
batch <- list(cat = self$cat[item]$unsqueeze(1),
num = self$num[item]$unsqueeze(1))
}
batch <- list(
cat = self$cat[item]$unsqueeze(1),
num = self$num[item]$unsqueeze(1)
)
return(list(
batch = batch,
target = self$target[item]$unsqueeze(1)
))
},
.getBatchRegular = function(item) {
if (self$useAll) {
batch <- list(all = self$all[item])
} else {
batch = list(
cat = self$cat[item],
num = self$num[item]
)
}
batch <- list(
cat = self$cat[item],
num = self$num[item]
)
return(list(
batch = batch,
target = self$target[item]
Expand Down
Loading

0 comments on commit 0cc79e9

Please sign in to comment.