diff --git a/R/Estimator.R b/R/Estimator.R
index 1023a92..3d7590c 100644
--- a/R/Estimator.R
+++ b/R/Estimator.R
@@ -54,8 +54,7 @@ setEstimator <- function(learningRate = "auto",
   earlyStopping = list(useEarlyStopping = TRUE,
                        params = list(patience = 4)),
   metric = "auc",
-  seed = NULL, 
-  modelType = NULL
+  seed = NULL
 ) {
 
   checkIsClass(learningRate, c("numeric", "character"))
@@ -90,8 +89,7 @@ setEstimator <- function(learningRate = "auto",
                             earlyStopping = earlyStopping,
                             findLR = findLR,
                             metric = metric,
-                            seed = seed[1],
-                            modelType = modelType)
+                            seed = seed[1])
 
   optimizer <- rlang::enquo(optimizer)
   estimatorSettings$optimizer <- function() rlang::eval_tidy(optimizer)
@@ -144,11 +142,11 @@ fitEstimator <- function(trainData,
   if (!is.null(trainData$folds)) {
     trainData$labels <- merge(trainData$labels, trainData$fold, by = "rowId")
   }
-  
-  if (modelSettings$estimatorSettings$modelType == "Finetuner") {
+
+  if (modelSettings$modelType == "Finetuner") {
     # make sure to use same mapping from covariateIds to columns if finetuning
     path <- modelSettings$param[[1]]$modelPath
-    oldCovImportance <- utils::read.csv(file.path(path, 
+    oldCovImportance <- utils::read.csv(file.path(path,
                                                   "covariateImportance.csv"))
     mapping <- oldCovImportance %>% dplyr::select("columnId", "covariateId")
     numericalIndex <- which(oldCovImportance %>% dplyr::pull("isNumeric"))
@@ -229,7 +227,7 @@ fitEstimator <- function(trainData,
       attrition = attr(trainData, "metaData")$attrition,
       trainingTime = paste(as.character(abs(comp)), attr(comp, "units")),
       trainingDate = Sys.Date(),
-      modelName = modelSettings$estimatorSettings$modelType,
+      modelName = modelSettings$modelType,
       finalModelParameters = cvResult$finalParam,
       hyperParamSearch = hyperSummary
     ),
@@ -278,9 +276,9 @@ predictDeepEstimator <- function(plpModel,
                         map_location = "cpu")
     estimator <-
       createEstimator(modelParameters =
-                        snakeCaseToCamelCaseNames(model$model_parameters),
+                      snakeCaseToCamelCaseNames(model$model_parameters),
                       estimatorSettings =
-                        snakeCaseToCamelCaseNames(model$estimator_settings))
+                      snakeCaseToCamelCaseNames(model$estimator_settings))
     estimator$model$load_state_dict(model$model_state_dict)
     prediction$value <- estimator$predict_proba(data)
   } else {
@@ -311,7 +309,7 @@ gridCvDeep <- function(mappedData,
                        modelLocation,
                        analysisPath) {
   ParallelLogger::logInfo(paste0("Running hyperparameter search for ",
-                                 modelSettings$estimatorSettings$modelType,
+                                 modelSettings$modelType,
                                  " model"))
 
   ###########################################################################
@@ -342,7 +340,10 @@ gridCvDeep <- function(mappedData,
       ParallelLogger::logInfo(paste(names(paramSearch[[gridId]]),
                                     paramSearch[[gridId]], collapse = " | "))
       currentModelParams <- paramSearch[[gridId]][modelSettings$modelParamNames]
-      attr(currentModelParams, "metaData")$names <- modelSettings$modelParamNames
+      attr(currentModelParams, "metaData")$names <-
+        modelSettings$modelParamNames
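+      # createEstimator looks up the python class to use from modelType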
+      currentModelParams$modelType <- modelSettings$modelType
       currentEstimatorSettings <-
         fillEstimatorSettings(modelSettings$estimatorSettings,
                               fitParams,
@@ -420,7 +421,7 @@ gridCvDeep <- function(mappedData,
 
   modelParams$catFeatures <- dataset$get_cat_features()$max()
   modelParams$numFeatures <- dataset$get_numerical_features()$len()
-
+  modelParams$modelType <- modelSettings$modelType
 
   estimatorSettings <- fillEstimatorSettings(modelSettings$estimatorSettings,
                                              fitParams,
@@ -495,19 +496,21 @@ evalEstimatorSettings <- function(estimatorSettings) {
 createEstimator <- function(modelParameters,
                             estimatorSettings) {
   path <- system.file("python", package = "DeepPatientLevelPrediction")
-  
-  if (estimatorSettings$modelType == "Finetuner") {
+
+  if (modelParameters$modelType == "Finetuner") {
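+    # load the saved base model and adopt its modelType for class lookup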
     estimatorSettings$finetune <- TRUE
     plpModel <- PatientLevelPrediction::loadPlpModel(modelParameters$modelPath)
-    estimatorSettings$finetuneModelPath <- 
+    estimatorSettings$finetuneModelPath <-
       file.path(normalizePath(plpModel$model), "DeepEstimatorModel.pt")
-    estimatorSettings$modelType <-
-      plpModel$modelDesign$modelSettings$estimatorSettings$modelType
-    }
-  
+    modelParameters$modelType <-
+      plpModel$modelDesign$modelSettings$modelType
+  }
+
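+  # the python module and the class it exports are both named after modelType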
   model <-
-    reticulate::import_from_path(estimatorSettings$modelType,
-                                 path = path)[[estimatorSettings$modelType]]
+    reticulate::import_from_path(modelParameters$modelType,
+                                 path = path)[[modelParameters$modelType]]
   estimator <- reticulate::import_from_path("Estimator", path = path)$Estimator
 
   modelParameters <- camelCaseToSnakeCaseNames(modelParameters)
diff --git a/R/MLP.R b/R/MLP.R
index 5c973bd..d798912 100644
--- a/R/MLP.R
+++ b/R/MLP.R
@@ -93,8 +93,6 @@ setMultiLayerPerceptron <- function(numLayers = c(1:8),
                                       {param <- param[sample(length(param),
                                                              randomSample)]}))
   }
-  estimatorSettings$modelType <- "MLP"
-  attr(param, "settings")$modelType <- estimatorSettings$modelType
   results <- list(
     fitFunction = "fitEstimator",
     param = param,
@@ -103,8 +101,10 @@ setMultiLayerPerceptron <- function(numLayers = c(1:8),
     modelParamNames = c(
       "numLayers", "sizeHidden",
       "dropout", "sizeEmbedding"
-    )
+    ),
+    modelType = "MLP"
   )
+  attr(results$param, "settings")$modelType <- results$modelType
 
   class(results) <- "modelSettings"
 
diff --git a/R/ResNet.R b/R/ResNet.R
index 2c4a2b6..88f10d3 100644
--- a/R/ResNet.R
+++ b/R/ResNet.R
@@ -137,17 +137,17 @@ setResNet <- function(numLayers = c(1:8),
                                       {param <- param[sample(length(param),
                                                              randomSample)]}))
   }
-  estimatorSettings$modelType <- "ResNet"
-  attr(param, "settings")$modelType <- estimatorSettings$modelType
   results <- list(
     fitFunction = "fitEstimator",
     param = param,
     estimatorSettings = estimatorSettings,
     saveType = "file",
     modelParamNames = c("numLayers", "sizeHidden", "hiddenFactor",
-                        "residualDropout", "hiddenDropout", "sizeEmbedding")
+                        "residualDropout", "hiddenDropout", "sizeEmbedding"),
+    modelType = "ResNet"
   )
+  attr(results$param, "settings")$modelType <- results$modelType
 
   class(results) <- "modelSettings"
 
   return(results)
diff --git a/R/TransferLearning.R b/R/TransferLearning.R
index b99be24..2ffc516 100644
--- a/R/TransferLearning.R
+++ b/R/TransferLearning.R
@@ -38,17 +38,17 @@ setFinetuner <- function(modelPath,
   
   param <- list()
   param[[1]] <- list(modelPath = modelPath)
-    
-  estimatorSettings$modelType <- "Finetuner"
-  attr(param, "settings")$modelType <- estimatorSettings$modelType
+
   results <- list(
   fitFunction = "fitEstimator",
   param = param,
   estimatorSettings = estimatorSettings,
   saveType = "file",
-  modelParamNames = c("modelPath")
+  modelParamNames = c("modelPath"),
+  modelType = "Finetuner"
   )
- 
+  attr(results$param, "settings")$modelType <- results$modelType
+
  class(results) <- "modelSettings"
  
  return(results)
diff --git a/R/Transformer.R b/R/Transformer.R
index f8d4212..cae92c3 100644
--- a/R/Transformer.R
+++ b/R/Transformer.R
@@ -180,8 +180,6 @@ setTransformer <- function(numBlocks = 3,
                                       {param <- param[sample(length(param),
                                                              randomSample)]}))
   }
-  estimatorSettings$modelType <- "Transformer"
-  attr(param, "settings")$modelType <- estimatorSettings$modelType
   results <- list(
     fitFunction = "fitEstimator",
     param = param,
@@ -190,9 +188,11 @@ setTransformer <- function(numBlocks = 3,
     modelParamNames = c(
       "numBlocks", "dimToken", "dimOut", "numHeads",
       "attDropout", "ffnDropout", "resDropout", "dimHidden"
-    )
+    ),
+    modelType = "Transformer"
   )
+  attr(results$param, "settings")$modelType <- results$modelType
 
   class(results) <- "modelSettings"
   return(results)
 }
diff --git a/inst/python/MLP.py b/inst/python/MLP.py
index 7e91f36..511adc8 100644
--- a/inst/python/MLP.py
+++ b/inst/python/MLP.py
@@ -15,9 +15,10 @@ def __init__(
         normalization=nn.BatchNorm1d,
         dropout=None,
         dim_out: int = 1,
+        model_type="MLP"
     ):
         super(MLP, self).__init__()
-        self.name = "MLP"
+        self.name = model_type
         cat_features = int(cat_features)
         num_features = int(num_features)
         size_embedding = int(size_embedding)
diff --git a/inst/python/ResNet.py b/inst/python/ResNet.py
index 453e584..7d60410 100644
--- a/inst/python/ResNet.py
+++ b/inst/python/ResNet.py
@@ -19,9 +19,10 @@ def __init__(
         residual_dropout=0,
         dim_out: int = 1,
         concat_num=True,
+        model_type="ResNet"
     ):
         super(ResNet, self).__init__()
-        self.name = "ResNet"
+        self.name = model_type
         cat_features = int(cat_features)
         num_features = int(num_features)
         size_embedding = int(size_embedding)
diff --git a/inst/python/Transformer.py b/inst/python/Transformer.py
index 58625a9..ec3707a 100644
--- a/inst/python/Transformer.py
+++ b/inst/python/Transformer.py
@@ -35,9 +35,10 @@ def __init__(
         ffn_norm=nn.LayerNorm,
         head_norm=nn.LayerNorm,
         att_norm=nn.LayerNorm,
+        model_type="Transformer"
     ):
         super(Transformer, self).__init__()
-        self.name = "Transformer"
+        self.name = model_type
         cat_features = int(cat_features)
         num_features = int(num_features)
         num_blocks = int(num_blocks)
diff --git a/man/setEstimator.Rd b/man/setEstimator.Rd
index 2abb3ae..b8424a3 100644
--- a/man/setEstimator.Rd
+++ b/man/setEstimator.Rd
@@ -16,8 +16,7 @@ setEstimator(
   criterion = torch$nn$BCEWithLogitsLoss,
   earlyStopping = list(useEarlyStopping = TRUE, params = list(patience = 4)),
   metric = "auc",
-  seed = NULL,
-  modelType = NULL
+  seed = NULL
 )
 }
 \arguments{
diff --git a/tests/testthat/test-Estimator.R b/tests/testthat/test-Estimator.R
index 98f335f..0e317c8 100644
--- a/tests/testthat/test-Estimator.R
+++ b/tests/testthat/test-Estimator.R
@@ -2,14 +2,15 @@ catFeatures <- smallDataset$dataset$get_cat_features()$max()
 numFeatures <- smallDataset$dataset$get_numerical_features()$len()
 
 modelParameters <- list(
-  cat_features = catFeatures,
-  num_features = numFeatures,
-  size_embedding = 16,
-  size_hidden = 16,
-  num_layers = 2,
-  hidden_factor = 2
+  catFeatures = catFeatures,
+  numFeatures = numFeatures,
+  sizeEmbedding = 16,
+  sizeHidden = 16,
+  numLayers = 2,
+  hiddenFactor = 2,
+  modelType = "ResNet"
 )
-modelType = "ResNet"
+
 estimatorSettings <-
   setEstimator(learningRate = 3e-4,
                weightDecay = 0.0,
@@ -23,8 +24,7 @@ estimatorSettings <-
                scheduler =
                list(fun = torch$optim$lr_scheduler$ReduceLROnPlateau,
                     params = list(patience = 1)),
-               earlyStopping = NULL,
-               modelType = modelType)
+               earlyStopping = NULL)
 
 estimator <- createEstimator(modelParameters = modelParameters,
                              estimatorSettings = estimatorSettings)
@@ -34,19 +34,20 @@ test_that("Estimator initialization works", {
   # count parameters in both instances
   path <- system.file("python", package = "DeepPatientLevelPrediction")
   resNet <-
-    reticulate::import_from_path(estimatorSettings$modelType, 
-                                 path = path)[[estimatorSettings$modelType]]
+    reticulate::import_from_path(modelParameters$modelType,
+                                 path = path)[[modelParameters$modelType]]
 
   testthat::expect_equal(
     sum(reticulate::iterate(estimator$model$parameters(),
                             function(x) x$numel())),
-    sum(reticulate::iterate(do.call(resNet, modelParameters)$parameters(),
+    sum(reticulate::iterate(
+      do.call(resNet, camelCaseToSnakeCaseNames(modelParameters))$parameters(),
                             function(x) x$numel()))
   )
 
   testthat::expect_equal(
     estimator$model_parameters,
-    modelParameters
+    camelCaseToSnakeCaseNames(modelParameters)
   )
 })
 
@@ -114,8 +115,7 @@ test_that("estimator fitting works", {
                                     batchSize = 128,
                                     epochs = 5,
                                     device = "cpu",
-                                    metric = "loss",
-                                    modelType = modelType)
+                                    metric = "loss")
   estimator <- createEstimator(modelParameters = modelParameters,
                                estimatorSettings = estimatorSettings)
 
@@ -216,8 +216,7 @@ test_that("Estimator without earlyStopping works", {
                                     batchSize = 128,
                                     epochs = 1,
                                     device = "cpu",
-                                    earlyStopping = NULL,
-                                    modelType = modelType)
+                                    earlyStopping = NULL)
   estimator2 <- createEstimator(modelParameters = modelParameters,
                                 estimatorSettings = estimatorSettings)
   sink(nullfile())
@@ -240,8 +239,7 @@ test_that("Early stopper can use loss and stops early", {
                                            params = list(mode = c("min"),
                                                          patience = 1)),
                                     metric = "loss",
-                                    seed = 42,
-                                    modelType = modelType)
+                                    seed = 42)
 
   estimator <- createEstimator(modelParameters = modelParameters,
                                estimatorSettings = estimatorSettings)
@@ -269,8 +267,7 @@ test_that("Custom metric in estimator works", {
                                     epochs = 1,
                                     metric = list(fun = metricFun,
                                                   name = "auprc",
-                                                  mode = "max"),
-                                    modelType = modelType)
+                                                  mode = "max"))
   estimator <- createEstimator(modelParameters = modelParameters,
                                estimatorSettings = estimatorSettings)
   expect_true(is.function(estimator$metric$fun))
@@ -333,12 +330,12 @@ test_that("device as a function argument works", {
   }
 
   estimatorSettings <- setEstimator(device = getDevice,
-                                    learningRate = 3e-4,
-                                    modelType = modelType)
+                                    learningRate = 3e-4)
 
   model <- setDefaultResNet(estimatorSettings = estimatorSettings)
   model$param[[1]]$catFeatures <- 10
+  model$param[[1]]$modelType <- "ResNet"
 
   estimator <- createEstimator(modelParameters = model$param[[1]],
                                estimatorSettings = estimatorSettings)
 
@@ -347,11 +344,11 @@ test_that("device as a function argument works", {
   Sys.setenv("testDeepPLPDevice" = "meta")
 
   estimatorSettings <- setEstimator(device = getDevice,
-                                    learningRate = 3e-4,
-                                    modelType = modelType)
+                                    learningRate = 3e-4)
 
   model <- setDefaultResNet(estimatorSettings = estimatorSettings)
   model$param[[1]]$catFeatures <- 10
+  model$param[[1]]$modelType <- "ResNet"
 
   estimator <- createEstimator(modelParameters = model$param[[1]],
                                estimatorSettings = estimatorSettings)
diff --git a/tests/testthat/test-LRFinder.R b/tests/testthat/test-LRFinder.R
index 8509586..1fb71f9 100644
--- a/tests/testthat/test-LRFinder.R
+++ b/tests/testthat/test-LRFinder.R
@@ -31,8 +31,7 @@ test_that("LR scheduler that changes per batch works", {
 
 test_that("LR finder works", {
   estimatorSettings <- setEstimator(batchSize = 32L,
-                                    seed = 42,
-                                    modelType = "ResNet")
+                                    seed = 42)
   lrFinder <-
     createLRFinder(modelParameters =
                      list(cat_features =
@@ -42,7 +41,8 @@ test_that("LR finder works", {
                           size_embedding = 32L,
                           size_hidden = 64L,
                           num_layers = 1L,
-                          hidden_factor = 1L),
+                          hidden_factor = 1L,
+                          modelType = "ResNet"),
                    estimatorSettings = estimatorSettings,
                    lrSettings = list(minLr = 3e-4,
                                      maxLr = 10.0,
@@ -68,11 +68,11 @@ test_that("LR finder works with device specified by a function", {
            size_embedding = 8L,
            size_hidden = 16L,
            num_layers = 1L,
-           hidden_factor = 1L),
+           hidden_factor = 1L,
+           modelType = "ResNet"),
     estimatorSettings = setEstimator(batchSize = 32L,
                                      seed = 42,
-                                     device = deviceFun,
-                                     modelType = "ResNet"),
+                                     device = deviceFun),
     lrSettings = list(minLr = 3e-4,
                       maxLr = 10.0,
                       numLr = 20L,