Skip to content

Commit

Permalink
start refactoring of estimator
Browse files Browse the repository at this point in the history
  • Loading branch information
egillax committed Nov 26, 2023
1 parent 4188442 commit 92bd442
Showing 1 changed file with 42 additions and 31 deletions.
73 changes: 42 additions & 31 deletions R/Estimator.R
Original file line number Diff line number Diff line change
Expand Up @@ -319,12 +319,7 @@ gridCvDeep <- function(mappedData,

currentEstimatorSettings <- fillEstimatorSettings(modelSettings$estimatorSettings, fitParams,
paramSearch[[gridId]])

# initiate prediction
prediction <- NULL

fold <- labels$index
ParallelLogger::logInfo(paste0("Max fold: ", max(fold)))
currentEstimatorSettings$modelType <- modelSettings$modelType
currentModelParams$catFeatures <- dataset$get_cat_features()$shape[[1]]
currentModelParams$numFeatures <- dataset$get_numerical_features()$shape[[1]]
if (findLR) {
Expand All @@ -337,32 +332,13 @@ gridCvDeep <- function(mappedData,
currentEstimatorSettings$learningRate <- lr
}

learnRates <- list()
for (i in 1:max(fold)) {
ParallelLogger::logInfo(paste0("Fold ", i))
trainDataset <- torch$utils$data$Subset(dataset, indices = as.integer(which(fold != i) - 1)) # -1 for python 0-based indexing
testDataset <- torch$utils$data$Subset(dataset, indices = as.integer(which(fold == i) -1)) # -1 for python 0-based indexing
crossValidationResults <- doCrossvalidation(dataset,
labels=labels,
modelSettings = currentModelParams,
estimatorSettings = currentEstimatorSettings)
learnRates <- crossValidationResults$learnRates
prediction <- crossValidationResults$prediction

estimator <- createEstimator(modelType=modelSettings$modelType,
modelParameters=currentModelParams,
estimatorSettings=currentEstimatorSettings)
estimator$fit(trainDataset, testDataset)

ParallelLogger::logInfo("Calculating predictions on left out fold set...")

prediction <- rbind(
prediction,
predictDeepEstimator(
plpModel = estimator,
data = testDataset,
cohort = labels[fold == i, ]
)
)
learnRates[[i]] <- list(
LRs = estimator$learn_rate_schedule,
bestEpoch = estimator$best_epoch
)
}
maxIndex <- which.max(unlist(sapply(learnRates, `[`, 2)))
gridSearchPredictons[[gridId]] <- list(
prediction = prediction,
Expand Down Expand Up @@ -496,4 +472,39 @@ createEstimator <- function(modelType,
model_parameters = modelParameters,
estimator_settings = estimatorSettings)
return(estimator)
}

doCrossvalidation <- function(dataset, labels, modelSettings, estimatorSettings) {
fold <- labels$index
ParallelLogger::logInfo(paste0("Max fold: ", max(fold)))
learnRates <- list()
prediction <- NULL
for (i in 1:max(fold)) {
ParallelLogger::logInfo(paste0("Fold ", i))
trainDataset <- torch$utils$data$Subset(dataset, indices = as.integer(which(fold != i) - 1)) # -1 for python 0-based indexing
testDataset <- torch$utils$data$Subset(dataset, indices = as.integer(which(fold == i) -1)) # -1 for python 0-based indexing

estimator <- createEstimator(modelType=estimatorSettings$modelType,
modelParameters=modelSettings,
estimatorSettings=estimatorSettings)
estimator$fit(trainDataset, testDataset)

ParallelLogger::logInfo("Calculating predictions on left out fold set...")

prediction <- rbind(
prediction,
predictDeepEstimator(
plpModel = estimator,
data = testDataset,
cohort = labels[fold == i, ]
)
)
learnRates[[i]] <- list(
LRs = estimator$learn_rate_schedule,
bestEpoch = estimator$best_epoch
)
}
return (results=list(prediction=prediction,
learnRates=learnRates))

}

0 comments on commit 92bd442

Please sign in to comment.