diff --git a/R/Estimator.R b/R/Estimator.R index 19f414e..850eab1 100644 --- a/R/Estimator.R +++ b/R/Estimator.R @@ -218,6 +218,7 @@ fitEstimator <- function(trainData, included = incs, covariateValue = 0, isNumeric = .data$columnId %in% cvResult$numericalIndex + # get mapping maybe here ) comp <- start - Sys.time() @@ -268,6 +269,7 @@ fitEstimator <- function(trainData, hyperParamSearch = hyperSummary ), covariateImportance = covariateRef + # also return mapping as part of covariateRef above, not necessary to do separately ) class(result) <- "plpModel" @@ -301,6 +303,7 @@ predictDeepEstimator <- function(plpModel, cohort = cohort, mapping = plpModel$covariateImportance %>% dplyr::select("columnId", "covariateId") + # check this if it is correctly passing the mapped data rather than creating a new mapping ) data <- createDataset(mappedData, plpModel = plpModel) } @@ -421,6 +424,7 @@ gridCvDeep <- function(mappedData, prediction$cohortStartDate <- as.Date(prediction$cohortStartDate, origin = "1970-01-01") numericalIndex <- dataset$get_numerical_features() + # get mapping as above # save torch code here if (!dir.exists(file.path(modelLocation))) { @@ -434,6 +438,7 @@ gridCvDeep <- function(mappedData, finalParam = finalParam, paramGridSearch = paramGridSearch, numericalIndex = numericalIndex$to_list() + # add mapping here, two columns [covariateId, columnId] ) ) } @@ -577,8 +582,9 @@ doCrossValidationImpl <- function(dataset, fillEstimatorSettings(modelSettings$estimatorSettings, fitParams, parameters) - currentModelParams$catFeatures <- dataset$get_cat_features()$max() + currentModelParams$catFeatures <- dataset$get_cat_features()$len() currentModelParams$numFeatures <- dataset$get_numerical_features()$len() + currentModelParams$cat2Features <- dataset$get_cat_2_features()$len() if (currentEstimatorSettings$findLR) { lr <- getLR(currentModelParams, currentEstimatorSettings, dataset) ParallelLogger::logInfo(paste0("Auto learning rate selected as: ", lr)) @@ -659,7 
+665,8 @@ trainFinalModel <- function(dataset, finalParam, modelSettings, labels) { fitParams <- names(finalParam)[grepl("^estimator", names(finalParam))] - modelParams$catFeatures <- dataset$get_cat_features()$max() + modelParams$catFeatures <- dataset$get_cat_features()$len() + modelParams$cat2Features <- dataset$get_cat_2_features()$len() modelParams$numFeatures <- dataset$get_numerical_features()$len() modelParams$modelType <- modelSettings$modelType