
TorchLearner: support / example for custom batch_sampler #420

@tdhock

Description

Related to #417

I expected to be able to use the batch_sampler argument in mlr3torch, similar to how it can be used in torch.
However, I observe errors in mlr3torch, even with a batch_sampler that works in pure torch.

Here is an example of a batch_sampler that works in pure torch.

# 1. Define a simple dataset
my_dataset <- torch::dataset(
  name = "MyDataset",
  initialize = function() {
    self$data <- torch::torch_tensor(1:5)
  },
  .getitem = function(index) {
    self$data[index]
  },
  .length = function() {
    length(self$data)
  }
)
# 2. Define a custom batch sampler 
batch_sampler_class <- torch::sampler(
  "BatchSampler",
  initialize = function(data_source) {
    self$data_source <- data_source
  },
  .iter = function() {
    batch_size <- 2
    indices <- 1:self$.length()
    batch_vec <- (indices-1) %/% batch_size
    batch_list <- rev(split(indices, batch_vec))
    count <- 0L
    function() {
      if (count < length(batch_list)) {
        count <<- count + 1L
        return(batch_list[[count]])
      }
      coro::exhausted()
    }
  },
  .length = function() {
    length(self$data_source)
  }
)
# 3. Instantiate dataset and sampler
ds <- my_dataset()
batch_sampler_instance <- batch_sampler_class(ds)
# 4. Create dataloader with custom sampler
dl <- torch::dataloader(ds, batch_sampler = batch_sampler_instance)
# 5. Iterate through the dataloader
batches <- list()
coro::loop(for (batch in dl) {
  batches[[length(batches) + 1L]] = batch
})
batches

Running this code gives me the output below:

> batches
[[1]]
torch_tensor
 5
[ CPULongType{1} ]

[[2]]
torch_tensor
 3
 4
[ CPULongType{2} ]

[[3]]
torch_tensor
 1
 2
[ CPULongType{2} ]

The output above shows three batches: the first with one sample, and the next two with two samples each. The batches appear in reversed order because of the rev() call in .iter.
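For reference, the batch order follows directly from the index arithmetic in .iter: split() groups the five indices into {1, 2}, {3, 4}, {5}, and rev() reverses those groups, so the single-element batch comes first. A minimal standalone check of that arithmetic:

batch_size <- 2
indices <- 1:5
batch_vec <- (indices - 1) %/% batch_size  # 0 0 1 1 2
rev(split(indices, batch_vec))
# $`2`
# [1] 5
#
# $`1`
# [1] 3 4
#
# $`0`
# [1] 1 2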

Below I try something similar with the current main branch of mlr3torch.

sonar_task <- mlr3::tsk("sonar")
sonar_ingress <- list(feat=mlr3torch::TorchIngressToken(
  features=sonar_task$col_roles$feature,
  batchgetter=mlr3torch::batchgetter_num))
target_batchgetter <- function(data) {
  # placeholder target batchgetter, left with browser() for interactive debugging
  browser()
  torch::torch_tensor()
}
sonar_dataset <- mlr3torch::task_dataset(sonar_task, sonar_ingress, target_batchgetter)
batch_sampler_instance <- batch_sampler_class(sonar_dataset)
inst_learner <- mlr3torch::LearnerTorchMLP$new(task_type="classif")
inst_learner$param_set$set_values(
  epochs=10,
  batch_size=20,
  batch_sampler=batch_sampler_instance)
inst_learner$train(sonar_task)

Note that in the code above I am using batch_sampler=batch_sampler_instance, but for benchmarks it would be more convenient to provide batch_sampler=batch_sampler_class, as proposed in #419 (see the sketch below).
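For illustration only, a hypothetical sketch of the kind of call #419 proposes: passing the sampler generator (class) rather than a pre-built instance, so that mlr3torch could construct an instance for each dataset it builds internally. This is not supported on current main.

# Hypothetical sketch of the #419 proposal; not supported on current main.
inst_learner$param_set$set_values(
  epochs = 10,
  batch_size = 20,
  batch_sampler = batch_sampler_class)  # pass the class, not an instance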

I get the error below:

> inst_learner$train(sonar_task)
Error in ctx$batch$y$to(device = ctx$device) : 
  attempt to apply non-function
> traceback()
15: train_loop(ctx, callbacks)
14: learner_torch_train(self, private, super, task, param_vals)
13: force(expr)
12: with_torch_settings(seed = param_vals$seed, num_threads = param_vals$num_threads, 
        num_interop_threads = param_vals$num_interop_threads, expr = {
            learner_torch_train(self, private, super, task, param_vals)
        })
11: .__LearnerTorch__.train(self = self, private = private, super = super, 
        task = task)
10: get_private(learner)$.train(task)
9: .f(learner = <environment>, task = <environment>)
8: eval(expr, p)
7: eval(expr, p)
6: eval.parent(expr, n = 1L)
5: invoke(.f, .args = .args, .opts = .opts, .seed = .seed, .timeout = .timeout)
4: encapsulate(learner$encapsulation["train"], .f = train_wrapper, 
       .args = list(learner = learner, task = task), .pkgs = learner$packages, 
       .seed = NA_integer_, .timeout = learner$timeout["train"])
3: learner_train(learner, task, train_row_ids = train_row_ids, mode = mode)
2: .__Learner__train(self = self, private = private, super = super, 
       task = task, row_ids = row_ids)
1: inst_learner$train(sonar_task)

It seems that using a sampler instead of a batch_sampler could be a work-around, but it would be inefficient (a for loop over individual observations is much slower than a for loop over batches). A sketch of that work-around is below.
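For reference, a minimal sketch of what that per-observation work-around could look like in pure torch: a sampler that yields one index at a time (here simply in reverse order), with batching then left to batch_size. Whether mlr3torch accepts such a sampler via a sampler parameter is an assumption here.

# Sketch only: a per-observation sampler that yields single indices (here in
# reverse order); the dataloader then assembles batches of size batch_size.
obs_sampler_class <- torch::sampler(
  "ReverseSampler",
  initialize = function(data_source) {
    self$data_source <- data_source
  },
  .iter = function() {
    indices <- rev(seq_len(self$.length()))
    count <- 0L
    function() {
      if (count < length(indices)) {
        count <<- count + 1L
        return(indices[[count]])
      }
      coro::exhausted()
    }
  },
  .length = function() {
    length(self$data_source)
  }
)
# Reusing ds from the pure-torch example above:
dl_obs <- torch::dataloader(ds, batch_size = 2, sampler = obs_sampler_class(ds))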

@sebffischer Do you have any idea where this error is coming from, or how to fix it?
