Skip to content

Commit

Permalink
add unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
egillax committed Oct 12, 2023
1 parent 3fba0d6 commit a45c78f
Showing 1 changed file with 29 additions and 0 deletions.
29 changes: 29 additions & 0 deletions tests/testthat/test-Transformer.R
Original file line number Diff line number Diff line change
Expand Up @@ -126,3 +126,32 @@ test_that("dimHidden ratio works as expected", {
dimHiddenRatio = 4/3))

})

test_that("numerical embedding works as expected", {
embeddings <- 32L # size of embeddings
features <- 2L # number of numerical features
patients <- 9L

numTensor <- torch$randn(c(patients, features))

numericalEmbeddingClass <- reticulate::import_from_path("ResNet", path=path)$NumericalEmbedding
numericalEmbedding <- numericalEmbeddingClass(num_embeddings = features,
embedding_dim = embeddings,
bias = TRUE)
out <- numericalEmbedding(numTensor)

# should be patients x features x embedding size
expect_equal(out$shape[[0]], patients)
expect_equal(out$shape[[1]], features)
expect_equal(out$shape[[2]], embeddings)

numericalEmbedding <- numericalEmbeddingClass(num_embeddings = features,
embedding_dim = embeddings,
bias = FALSE)

out <- numericalEmbedding(numTensor)
expect_equal(out$shape[[0]], patients)
expect_equal(out$shape[[1]], features)
expect_equal(out$shape[[2]], embeddings)

})

0 comments on commit a45c78f

Please sign in to comment.