diff --git a/R/utils.R b/R/utils.R index d58f09c..86b0dfe 100644 --- a/R/utils.R +++ b/R/utils.R @@ -84,13 +84,13 @@ format_targets <- function(Y, loss_obj, ylvls=NULL) { return(list(Y=Y, Y_base=Y_base, y_dim=y_dim, ylvls=ylvls, responses=responses)) } - -get_data_loader = function(..., batch_size=25L, shuffle=TRUE) { - +get_data_loader<-function (..., batch_size = 25L, shuffle = TRUE, sampler=NULL) +{ ds <- torch::tensor_dataset(...) - - dl <- torch::dataloader(ds, batch_size = batch_size, shuffle = shuffle, pin_memory = TRUE) - + if(is.null(sampler)) dl <- torch::dataloader(ds, batch_size = batch_size, shuffle = shuffle, + pin_memory = TRUE) + else dl <- torch::dataloader(ds, batch_sampler = sampler, + pin_memory = TRUE) return(dl) }