Skip to content

Commit

Permalink
merge and refactor gridCvDeep
Browse files Browse the repository at this point in the history
  • Loading branch information
egillax committed Nov 27, 2023
2 parents 92bd442 + 49942a8 commit e4c84a4
Show file tree
Hide file tree
Showing 26 changed files with 879 additions and 698 deletions.
1 change: 1 addition & 0 deletions .Rbuildignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@
^extras$
^deploy.sh$
^compare_versions$
^.mypy_cache$
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ renv.lock
extras/
.Renviron
inst/python/__pycache__
.mypy_cache
2 changes: 1 addition & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# Generated by roxygen2: do not edit by hand

export(TrainingCache)
export(fitEstimator)
export(gridCvDeep)
export(predictDeepEstimator)
Expand All @@ -10,6 +9,7 @@ export(setEstimator)
export(setMultiLayerPerceptron)
export(setResNet)
export(setTransformer)
export(trainingCache)
importFrom(dplyr,"%>%")
importFrom(reticulate,py_to_r)
importFrom(reticulate,r_to_py)
Expand Down
23 changes: 12 additions & 11 deletions R/Dataset.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
createDataset <- function(data, labels, plpModel=NULL) {
createDataset <- function(data, labels, plpModel = NULL) {
path <- system.file("python", package = "DeepPatientLevelPrediction")
Dataset <- reticulate::import_from_path("Dataset", path = path)$Data
dataset <- reticulate::import_from_path("Dataset", path = path)$Data
if (is.null(attributes(data)$path)) {
# sqlite object
attributes(data)$path <- attributes(data)$dbname
}
if (is.null(plpModel)) {
data <- Dataset(r_to_py(normalizePath(attributes(data)$path)),
r_to_py(labels$outcomeCount))
data <- dataset(r_to_py(normalizePath(attributes(data)$path)),
r_to_py(labels$outcomeCount))
} else {
numericalFeatures <-
r_to_py(as.array(which(plpModel$covariateImportance$isNumeric)))
data <- dataset(r_to_py(normalizePath(attributes(data)$path)),
numerical_features = numericalFeatures)
}
else {
data <- Dataset(r_to_py(normalizePath(attributes(data)$path)),
numerical_features = r_to_py(as.array(which(plpModel$covariateImportance$isNumeric))) )
}

return(data)
}

return(data)
}
7 changes: 3 additions & 4 deletions R/DeepPatientLevelPrediction.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@

#' DeepPatientLevelPrediction
#'
#' @description A package containing deep learning extensions for developing prediction models using data in the OMOP CDM
#' @description A package containing deep learning extensions for developing
#' prediction models using data in the OMOP CDM
#'
#' @docType package
#' @name DeepPatientLevelPrediction
Expand All @@ -28,9 +29,7 @@
NULL

.onLoad <- function(libname, pkgname) {
# use superassignment to update global reference
# use superassignment to update global reference
reticulate::configure_environment(pkgname)
torch <<- reticulate::import("torch", delay_load = TRUE)
}


Loading

0 comments on commit e4c84a4

Please sign in to comment.