Skip to content

PLS regression with mixOmics engine - predictions not extracted from predict object  #132

@marioem

Description

@marioem

Hi,

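`vi_permute()` fails on a PLS model fitted with the mixOmics engine, because `predict()` returns a mixOmics `predict` object (a list) rather than a numeric vector of predictions.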

# Packages for the reprex. The startup messages below show that mixOmics
# masks parsnip's pls() and tune(); tidymodels_prefer() at the end of this
# chunk is presumably what routes those names back to the tidymodels
# versions used later (confirm against the tidymodels documentation).
library(plsmod) # parsnip helper for pls
#> Loading required package: parsnip
library(mixOmics) # for pls regression
#> Loading required package: MASS
#> Loading required package: lattice
#> Loading required package: ggplot2
#> 
#> Loaded mixOmics 6.21.0
#> Thank you for using mixOmics!
#> Tutorials: http://mixomics.org
#> Bookdown vignette: https://mixomicsteam.github.io/Bookdown
#> Questions, issues: Follow the prompts at http://mixomics.org/contact-us
#> Cite us:  citation('mixOmics')
#> 
#> Attaching package: 'mixOmics'
#> The following objects are masked from 'package:parsnip':
#> 
#>     pls, tune
library(tidymodels)

library(doMC)
#> Loading required package: foreach
#> 
#> Attaching package: 'foreach'
#> The following objects are masked from 'package:purrr':
#> 
#>     accumulate, when
#> Loading required package: iterators
#> Loading required package: parallel

# Register a doMC parallel backend that uses all cores but one.
registerDoMC(cores = parallel::detectCores() - 1) # Mac and Linux only :-)
tidymodels_prefer()
# Load the concrete data set from modeldata.
data(concrete, package = "modeldata")

# Collapse rows that share identical predictor values into one row whose
# outcome is the mean compressive strength of those rows.
concrete <- 
  concrete %>% 
  group_by(across(-compressive_strength)) %>% 
  summarize(compressive_strength = mean(compressive_strength),
            .groups = "drop")
nrow(concrete)
#> [1] 992

# Train/test split, stratified on the outcome.
set.seed(1501)
concrete_split <- initial_split(concrete, strata = compressive_strength)
concrete_train <- training(concrete_split)
concrete_test  <- testing(concrete_split)

# Resamples used for tuning: 5 bootstrap resamples of the training set.
# A stratified v-fold CV alternative is left commented out.
set.seed(1502)
# concrete_folds <- 
#   vfold_cv(concrete_train, strata = compressive_strength) # , repeats = 5)

concrete_folds <- 
  bootstraps(concrete_train, times = 5) # 5 times for the sake of time for this reprex

# PLS regression via the mixOmics engine; the number of components is tuned.
pls_spec <- pls(num_comp = tune()) %>% 
  set_mode("regression") %>% 
  set_engine("mixOmics")

# Preprocessing: centre and scale every predictor.
normalized_rec <- 
  recipe(compressive_strength ~ ., data = concrete_train) %>% 
  step_normalize(all_predictors()) 

# A workflow set with a single entry ("normalized_PLS"): this recipe with
# the PLS spec.
normalized <-
  workflow_set(
    preproc = list(normalized = normalized_rec),
    models = list(PLS = pls_spec)
  )

all_workflows <- normalized

# Bayesian-optimisation control: keep out-of-sample predictions and the
# workflow in the results, parallelise over everything, and log progress.
bayes_ctrl <-
  control_bayes(
    seed = 828, 
    save_pred = TRUE,
    parallel_over = "everything",
    verbose = T,
    save_workflow = TRUE
  )

# Tune num_comp by Bayesian optimisation over the bootstrap resamples.
# As the log below shows, the search stops after the 4 initial candidates
# ("No remaining candidate models").
grid_results <-
  all_workflows %>%
  workflow_map("tune_bayes",
               resamples = concrete_folds,
               metrics = metric_set(rmse, mape, smape, mae, yardstick::ccc, huber_loss),
               iter = 5,
               control = bayes_ctrl)
#> 
#> ❯  Generating a set of 4 initial parameter results
#> ✓ Initialization complete
#> 
#> Optimizing rmse using the expected improvement
#> 
#> ── Iteration 1 ─────────────────────────────────────────────────────────────────
#> 
#> i Current best:      rmse=10.79 (@iter 0)
#> i Gaussian process model
#> ✓ Gaussian process model
#> ! No remaining candidate models
#> x Halting search
#> ✖ Optimization stopped prematurely; returning current results.

# Rank the tuned candidates by resampled RMSE.
grid_results %>% 
  rank_results(rank_metric = "rmse") %>% 
  filter(.metric == "rmse") %>% 
  select(model, wflow_id, .config, rmse = mean, rank)
#> # A tibble: 4 × 5
#>   model wflow_id       .config               rmse  rank
#>   <chr> <chr>          <chr>                <dbl> <int>
#> 1 pls   normalized_PLS Preprocessor1_Model4  10.8     1
#> 2 pls   normalized_PLS Preprocessor1_Model3  10.9     2
#> 3 pls   normalized_PLS Preprocessor1_Model2  11.0     3
#> 4 pls   normalized_PLS Preprocessor1_Model1  11.7     4

# Tuning-parameter values of the lowest-RMSE candidate.
best_tuneRmse <- 
  grid_results %>% 
  extract_workflow_set_result("normalized_PLS") %>% 
  select_best(metric = "rmse")

# Plug the best num_comp into the workflow and run last_fit() on the
# original train/test split, with the same metric set as during tuning.
Best_test_resultsRmse <- 
  grid_results %>% 
  extract_workflow("normalized_PLS") %>% 
  finalize_workflow(best_tuneRmse) %>% 
  last_fit(split = concrete_split, metrics = metric_set(rmse, mape, smape, mae, yardstick::ccc, huber_loss))

# Permutation importance on the finalized parsnip fit, with `predict` used
# directly as the prediction wrapper. This fails: for this model the
# prediction step returns a mixOmics "predict" object (a list; see the
# str() output of the next call) instead of a numeric vector, so vip cannot
# compute `predicted - actual`.
# `train` is the training data processed by the same (prepped) recipe.
Best_test_resultsRmse %>% 
extract_fit_parsnip() %>%
  vip::vip(method = "permute",
           num_features = 30,
           train = normalized_rec %>% prep() %>% bake(new_data = NULL), 
           target = "compressive_strength", 
           metric = "rmse", 
           nsim = 500,
           pred_wrapper = predict, 
           geom = "col", 
           all_permutations = F,
           aesthetics = list(color = "grey35"),
           include_type = T
  ) +
  ggtitle("Predictor importance - PLS")
#> Error in predicted - actual: non-numeric argument to binary operator

# Same call, with a wrapper that prints the structure of what predict()
# returns before handing it back to vip. The output shows a list of class
# "predict" built by predict.mixo_spls(): `$predict` is an n x 1 x ncomp
# array (dim1..dim4), i.e. one column of predictions per number of
# components, rather than a numeric vector. The "NULL" line is print()
# echoing str()'s own (NULL) return value.
Best_test_resultsRmse %>% 
  extract_fit_parsnip() %>%
  vip::vip(method = "permute",
           num_features = 30,
           train = normalized_rec %>% prep() %>% bake(new_data = NULL), 
           target = "compressive_strength", 
           metric = "rmse", 
           nsim = 500,
           pred_wrapper = function(object, newdata) { pred = predict(object, newdata); print(str(pred)); pred}, 
           geom = "col", 
           all_permutations = F,
           aesthetics = list(color = "grey35"),
           include_type = T
  ) +
  ggtitle("Predictor importance - PLS")
#> List of 4
#>  $ predict : num [1:743, 1, 1:4] 17.4 14.9 17.7 19.3 16.9 ...
#>   ..- attr(*, "dimnames")=List of 3
#>   .. ..$ : chr [1:743] "1" "2" "3" "4" ...
#>   .. ..$ : chr "Y"
#>   .. ..$ : chr [1:4] "dim1" "dim2" "dim3" "dim4"
#>  $ variates: num [1:743, 1:4] -1.65 -1.88 -1.62 -1.47 -1.7 ...
#>   ..- attr(*, "dimnames")=List of 2
#>   .. ..$ : chr [1:743] "1" "2" "3" "4" ...
#>   .. ..$ : chr [1:4] "dim1" "dim2" "dim3" "dim4"
#>  $ B.hat   : num [1:8, 1, 1:4] 0.3927 0.1143 -0.0773 -0.2285 0.3045 ...
#>   ..- attr(*, "dimnames")=List of 3
#>   .. ..$ : chr [1:8] "cement" "blast_furnace_slag" "fly_ash" "water" ...
#>   .. ..$ : chr "Y"
#>   .. ..$ : chr [1:4] "dim1" "dim2" "dim3" "dim4"
#>  $ call    : language predict.mixo_spls(object = object, newdata = newdata)
#>  - attr(*, "class")= chr "predict"
#> NULL
#> Error in predicted - actual: non-numeric argument to binary operator

Created on 2022-10-18 with reprex v2.0.2

# Session details for the environment that produced the output above.
sessionInfo()
R version 4.2.1 (2022-06-23)
Platform: aarch64-apple-darwin20 (64-bit)
Running under: macOS Monterey 12.6

Matrix products: default
LAPACK: /Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/lib/libRlapack.dylib

locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8

attached base packages:
[1] parallel  stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
 [1] mixOmics_6.21.0    MASS_7.3-58.1      plsmod_1.0.0       doMC_1.3.8         iterators_1.0.14   foreach_1.5.2      ranger_0.14.1      Cubist_0.4.0       lattice_0.20-45    xgboost_1.6.0.1    rpart_4.1.16       earth_5.3.1       
[13] plotmo_3.6.2       TeachingDemos_2.12 plotrix_3.8-2      Formula_1.2-4      baguette_1.0.0     rules_1.0.0        forcats_0.5.2      stringr_1.4.1      readr_2.1.3        tidyverse_1.3.2    yardstick_1.1.0    workflowsets_1.0.0
[25] workflows_1.1.0    tune_1.0.0.9000    tidyr_1.2.1        tibble_3.1.8       rsample_1.1.0      recipes_1.0.1      purrr_0.3.4        parsnip_1.0.2      modeldata_1.0.1    infer_1.0.3        ggplot2_3.3.6      dplyr_1.0.10      
[37] dials_1.0.0        scales_1.2.1       broom_1.0.1        tidymodels_1.0.0  

loaded via a namespace (and not attached):
  [1] readxl_1.4.1        backports_1.4.1     igraph_1.3.5        plyr_1.8.7          splines_4.2.1       BiocParallel_1.30.3 listenv_0.8.0       digest_0.6.29       htmltools_0.5.3     fansi_1.0.3         magrittr_2.0.3      memoise_2.0.1      
 [13] googlesheets4_1.0.1 tzdb_0.3.0          globals_0.16.1      modelr_0.1.9        gower_1.0.0         matrixStats_0.62.0  rARPACK_0.11-0      R.utils_2.12.0      hardhat_1.2.0       colorspace_2.0-3    vip_0.3.2           ggrepel_0.9.1      
 [25] rvest_1.0.3         warp_0.2.0          haven_2.5.1         xfun_0.33           callr_3.7.2         crayon_1.5.2        jsonlite_1.8.2      libcoin_1.0-9       survival_3.4-0      glue_1.6.2          gtable_0.3.1        gargle_1.2.1       
 [37] ipred_0.9-13        R.cache_0.16.0      clipr_0.8.0         future.apply_1.9.1  mvtnorm_1.1-3       DBI_1.1.3           Rcpp_1.0.9          GPfit_1.0-8         lava_1.6.10         prodlim_2019.11.13  httr_1.4.4          RColorBrewer_1.1-3 
 [49] ellipsis_0.3.2      pkgconfig_2.0.3     R.methodsS3_1.8.2   nnet_7.3-18         dbplyr_2.2.1        utf8_1.2.2          tidyselect_1.1.2    rlang_1.0.6         DiceDesign_1.9      reshape2_1.4.4      munsell_0.5.0       cellranger_1.1.0   
 [61] tools_4.2.1         cachem_1.0.6        cli_3.4.1           generics_0.1.3      evaluate_0.16       fastmap_1.1.0       yaml_2.3.5          processx_3.7.0      knitr_1.40          fs_1.5.2            future_1.28.0       nlme_3.1-159       
 [73] R.oo_1.25.0         xml2_1.3.3          compiler_4.2.1      rstudioapi_0.14     slider_0.2.2        reprex_2.0.2        lhs_1.1.5           stringi_1.7.8       ps_1.7.1            highr_0.9           RSpectra_0.16-1     Matrix_1.5-1       
 [85] styler_1.7.0        conflicted_1.1.0    vctrs_0.4.2         pillar_1.8.1        lifecycle_1.0.2     furrr_0.3.1         corpcor_1.6.10      data.table_1.14.2   R6_2.5.1            gridExtra_2.3       C50_0.1.6           parallelly_1.32.1  
 [97] codetools_0.2-18    assertthat_0.2.1    withr_2.5.0         hms_1.1.2           grid_4.2.1          timeDate_4021.106   class_7.3-20        rmarkdown_2.16      inum_1.0-4          googledrive_2.0.0   partykit_1.2-16     lubridate_1.8.0    
[109] ellipse_0.4.3 

This can be worked around by setting

pred_wrapper = function(object, newdata) {
  # predict() on a mixOmics PLS fit returns a list whose `$predict` element
  # is an n x (responses) x ncomp array: one slice of predictions for each
  # number of components. Returning that whole array makes vip compare
  # n * ncomp predictions against n observed values; R recycles the target
  # silently, so the importances would mix 1..ncomp-component predictions.
  # Keep only the slice for the final (fitted) number of components, as a
  # plain numeric vector of length n.
  pred <- predict(object, newdata = newdata)$predict
  as.numeric(pred[, 1L, dim(pred)[3L]])
}

but it would be nice if the `vi_*()` functions were aware of this PLS idiosyncrasy and handled it out of the box.

Cheers,

Mariusz

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions