Skip to content

Commit

Permalink
Update utils.R
Browse files Browse the repository at this point in the history
  • Loading branch information
D-Maar authored Nov 27, 2024
1 parent 612497d commit 0d21f88
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,13 @@ format_targets <- function(Y, loss_obj, ylvls=NULL) {
return(list(Y=Y, Y_base=Y_base, y_dim=y_dim, ylvls=ylvls, responses=responses))
}


get_data_loader = function(..., batch_size=25L, shuffle=TRUE) {

get_data_loader<-function (..., batch_size = 25L, shuffle = TRUE, sampler=NULL)
{
ds <- torch::tensor_dataset(...)

dl <- torch::dataloader(ds, batch_size = batch_size, shuffle = shuffle, pin_memory = TRUE)

if(is.null(sampler)) dl <- torch::dataloader(ds, batch_size = batch_size, shuffle = shuffle,
pin_memory = TRUE)
else dl <- torch::dataloader(ds, batch_sampler = sampler,
pin_memory = TRUE)
return(dl)
}

Expand Down

0 comments on commit 0d21f88

Please sign in to comment.