Skip to content

Commit

Permalink
Clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
lhjohn committed Sep 9, 2024
1 parent 34be21f commit c5ecf29
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 11 deletions.
5 changes: 2 additions & 3 deletions R/Estimator.R
Original file line number Diff line number Diff line change
Expand Up @@ -218,12 +218,12 @@ fitEstimator <- function(trainData,
included = incs,
covariateValue = 0,
isNumeric = .data$columnId %in% cvResult$numericalIndex
# get mapping maybe here
) %>%
left_join(cvResult$cat1Mapping %>% rename(cat1Idx = index), by = "covariateId") %>%
left_join(cvResult$cat2Mapping %>% rename(cat2Idx = index), by = "covariateId")

comp <- start - Sys.time()
modelSettings$estimatorSettings$initStrategy <- NULL
result <- list(
model = cvResult$estimator,
preprocessing = list(
Expand Down Expand Up @@ -271,7 +271,6 @@ fitEstimator <- function(trainData,
hyperParamSearch = hyperSummary
),
covariateImportance = covariateRef
# also return mapping as part of covariateRef above, not necessary to do separately
)

class(result) <- "plpModel"
Expand Down Expand Up @@ -318,7 +317,7 @@ predictDeepEstimator <- function(plpModel,
)
data <- createDataset(mappedData, plpModel = plpModel)
}

# get predictions
prediction <- cohort
if (is.character(plpModel$model)) {
Expand Down
12 changes: 4 additions & 8 deletions inst/python/Dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
import json
import os

# this one needs to have a parameter to not do the mapping again

class Data(Dataset):
def __init__(self, data, labels=None, numerical_features=None,
in_cat_1_mapping=None, in_cat_2_mapping=None):
Expand Down Expand Up @@ -50,12 +48,6 @@ def __init__(self, data, labels=None, numerical_features=None,
).lazy()
else:
data = pl.scan_ipc(pathlib.Path(data).joinpath("covariates/*.arrow"))

# # Fetch only the first few rows
# data_head = data.limit(100).collect()
# print("Head of the data:")
# print(data_head)

observations = data.select(pl.col("rowId").max()).collect()[0, 0]
# detect features are numeric
if numerical_features is None:
Expand Down Expand Up @@ -248,3 +240,7 @@ def __getitem__(self, item):
batch["num"] = batch["num"].unsqueeze(0)
return [batch, self.target[item].squeeze()]





39 changes: 39 additions & 0 deletions tests/testthat/test-PoincareTransformer.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@

test_that("Poincare Transformer works", {
settings <- setCustomEmbeddingTransformer("/Users/henrikjohn/Desktop/poincare_model_dim_3.pt")

results <- PatientLevelPrediction::runPlp(
plpData = plpData,
outcomeId = 3,
modelSettings = settings,
analysisId = "Analysis_Poincare",
analysisName = "Testing Deep Learning",
populationSettings = populationSet,
splitSettings = PatientLevelPrediction::createDefaultSplitSetting(),
sampleSettings = PatientLevelPrediction::createSampleSettings(),
featureEngineeringSettings = PatientLevelPrediction::createFeatureEngineeringSettings(),
preprocessSettings = PatientLevelPrediction::createPreprocessSettings(),
executeSettings = PatientLevelPrediction::createExecuteSettings(
runSplitData = TRUE,
runSampleData = FALSE,
runfeatureEngineering = FALSE,
runPreprocessData = FALSE,
runModelDevelopment = TRUE,
runCovariateSummary = FALSE
),
saveDirectory = file.path(testLoc, "Poincare")
)


params <- defaultTransformer$param[[1]]

expect_equal(params$numBlocks, 3)
expect_equal(params$dimToken, 192)
expect_equal(params$numHeads, 8)
expect_equal(params$resDropout, 0.0)
expect_equal(params$attDropout, 0.2)

settings <- attr(defaultTransformer, "settings")

expect_equal(settings$name, "defaultTransformer")
})

0 comments on commit c5ecf29

Please sign in to comment.