From 6adfde6b98d63deda9c4012d020e6bc12492b2b1 Mon Sep 17 00:00:00 2001 From: Henrik John Date: Tue, 27 Aug 2024 17:22:56 +0200 Subject: [PATCH] Pass embedding dimensions to model --- R/Estimator.R | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/R/Estimator.R b/R/Estimator.R index 19f414e..850eab1 100644 --- a/R/Estimator.R +++ b/R/Estimator.R @@ -218,6 +218,7 @@ fitEstimator <- function(trainData, included = incs, covariateValue = 0, isNumeric = .data$columnId %in% cvResult$numericalIndex + # get mapping maybe here ) comp <- start - Sys.time() @@ -268,6 +269,7 @@ fitEstimator <- function(trainData, hyperParamSearch = hyperSummary ), covariateImportance = covariateRef + # also return mapping as part of covariateRef above, not necessary to do separately ) class(result) <- "plpModel" @@ -301,6 +303,7 @@ predictDeepEstimator <- function(plpModel, cohort = cohort, mapping = plpModel$covariateImportance %>% dplyr::select("columnId", "covariateId") + # check whether this is correctly passing the mapped data rather than creating a new mapping ) data <- createDataset(mappedData, plpModel = plpModel) } @@ -421,6 +424,7 @@ gridCvDeep <- function(mappedData, prediction$cohortStartDate <- as.Date(prediction$cohortStartDate, origin = "1970-01-01") numericalIndex <- dataset$get_numerical_features() + # get mapping as above # save torch code here if (!dir.exists(file.path(modelLocation))) { @@ -434,6 +438,7 @@ gridCvDeep <- function(mappedData, finalParam = finalParam, paramGridSearch = paramGridSearch, numericalIndex = numericalIndex$to_list() + # add mapping here, two columns [covariateId, columnId] ) ) } @@ -577,8 +582,9 @@ doCrossValidationImpl <- function(dataset, fillEstimatorSettings(modelSettings$estimatorSettings, fitParams, parameters) - currentModelParams$catFeatures <- dataset$get_cat_features()$max() + currentModelParams$catFeatures <- dataset$get_cat_features()$len() currentModelParams$numFeatures <- dataset$get_numerical_features()$len() 
+ currentModelParams$cat2Features <- dataset$get_cat_2_features()$len() if (currentEstimatorSettings$findLR) { lr <- getLR(currentModelParams, currentEstimatorSettings, dataset) ParallelLogger::logInfo(paste0("Auto learning rate selected as: ", lr)) @@ -659,7 +665,8 @@ trainFinalModel <- function(dataset, finalParam, modelSettings, labels) { fitParams <- names(finalParam)[grepl("^estimator", names(finalParam))] - modelParams$catFeatures <- dataset$get_cat_features()$max() + modelParams$catFeatures <- dataset$get_cat_features()$len() + modelParams$cat2Features <- dataset$get_cat_2_features()$len() modelParams$numFeatures <- dataset$get_numerical_features()$len() modelParams$modelType <- modelSettings$modelType