Skip to content

Commit

Permalink
add dplyr specifier to functions
Browse files Browse the repository at this point in the history
  • Loading branch information
egillax committed Aug 17, 2022
1 parent 6a805c9 commit 4874f1f
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 7 deletions.
2 changes: 1 addition & 1 deletion R/Dataset.R
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ Dataset <- torch::dataset(
dt <- data.table::data.table(rows=dataCat$rowId, cols=dataCat$columnId)
maxFeatures <- max(dt[, .N, by=rows][,N])
start <- Sys.time()
tensorList <- lapply(1:max(data %>% pull(rowId)), function(x) {
tensorList <- lapply(1:max(data %>% dplyr::pull(rowId)), function(x) {
torch::torch_tensor(dt[rows==x, cols])
})
self$lengths <- lengths
Expand Down
4 changes: 2 additions & 2 deletions tests/testthat/test-DeepNNTorch.R
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ test_that("setDeepNNTorch with runPlp working checks", {
testthat::expect_true('performanceEvaluation' %in% names(res))

# check prediction same size as pop
testthat::expect_equal(nrow(res$prediction %>% filter(evaluationType %in% c('Train', 'Test'))),
testthat::expect_equal(nrow(res$prediction %>% dplyr::filter(evaluationType %in% c('Train', 'Test'))),
nrow(population))

# check prediction between 0 and 1
Expand All @@ -67,7 +67,7 @@ test_that("Triple layer-nn works", {
epochs= c(5), seed=NULL)

sink(nullfile())
results <- fitDeepNNTorch(trainData$Train, deepset$param, analysisId=1)
results <- fitDeepNNTorch(trainData$Train, deepset, analysisId=1)
sink()

expect_equal(class(results), 'plpModel')
Expand Down
3 changes: 2 additions & 1 deletion tests/testthat/test-ResNet.R
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ test_that("ResNet with runPlp working checks", {
testthat::expect_true('performanceEvaluation' %in% names(res2))

# check prediction same size as pop
testthat::expect_equal(nrow(res2$prediction %>% filter(evaluationType %in% c('Train', 'Test'))), nrow(population))
testthat::expect_equal(nrow(res2$prediction %>%
dplyr::filter(evaluationType %in% c('Train', 'Test'))), nrow(population))

# check prediction between 0 and 1
testthat::expect_gte(min(res2$prediction$value), 0)
Expand Down
6 changes: 3 additions & 3 deletions tests/testthat/test-dataset.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@ test_that("length of index correct", {

testthat::expect_equal(
length(dataset$getNumericalIndex()),
n_distinct(mappedData$covariates %>% pull(covariateId)))
dplyr::n_distinct(mappedData$covariates %>% dplyr::pull(covariateId)))

})

test_that("number of num and cat features sum correctly", {

testthat::expect_equal(
dataset$numNumFeatures()+dataset$numCatFeatures(),
n_distinct(mappedData$covariates %>% pull(covariateId))
dplyr::n_distinct(mappedData$covariates %>% dplyr::pull(covariateId))
)

})
Expand All @@ -26,7 +26,7 @@ test_that("length of dataset correct", {
expect_equal(length(dataset), dataset$num$shape[1])
expect_equal(
dataset$.length(),
n_distinct(mappedData$covariates %>% pull(rowId))
dplyr::n_distinct(mappedData$covariates %>% dplyr::pull(rowId))
)

})
Expand Down

0 comments on commit 4874f1f

Please sign in to comment.