Related to #417
I expected to be able to use the batch_sampler argument in mlr3torch, similar to how it can be used in torch.
However, I observe errors in mlr3torch even when using a batch_sampler that works with torch.
Here is an example of a batch_sampler that works in pure torch.
# 1. Define a simple dataset
my_dataset <- torch::dataset(
  name = "MyDataset",
  initialize = function() {
    self$data <- torch::torch_tensor(1:5)
  },
  .getitem = function(index) {
    self$data[index]
  },
  .length = function() {
    length(self$data)
  }
)
# 2. Define a custom batch sampler: batches of size 2, yielded in reverse order
batch_sampler_class <- torch::sampler(
  "BatchSampler",
  initialize = function(data_source) {
    self$data_source <- data_source
  },
  .iter = function() {
    batch_size <- 2
    indices <- 1:self$.length()
    # group consecutive indices into batches of size 2, then reverse the batch order
    batch_vec <- (indices - 1) %/% batch_size
    batch_list <- rev(split(indices, batch_vec))
    count <- 0L
    function() {
      if (count < length(batch_list)) {
        count <<- count + 1L
        return(batch_list[[count]])
      }
      coro::exhausted()
    }
  },
  .length = function() {
    length(self$data_source)
  }
)
# 3. Instantiate dataset and sampler
ds <- my_dataset()
batch_sampler_instance <- batch_sampler_class(ds)
# 4. Create dataloader with custom batch sampler
dl <- torch::dataloader(ds, batch_sampler = batch_sampler_instance)
# 5. Iterate through the dataloader
batches <- list()
coro::loop(for (batch in dl) {
  batches[[length(batches) + 1L]] <- batch
})
batches
Running this code gives me the output below:
> batches
[[1]]
torch_tensor
5
[ CPULongType{1} ]
[[2]]
torch_tensor
3
4
[ CPULongType{2} ]
[[3]]
torch_tensor
1
2
[ CPULongType{2} ]
The output above shows three batches: the first with one sample and the next two with two samples each, because the five observations do not split evenly into batches of two and the sampler yields the batches in reverse order.
Below I try something similar with the current main branch of mlr3torch.
sonar_task <- mlr3::tsk("sonar")
sonar_ingress <- list(feat = mlr3torch::TorchIngressToken(
  features = sonar_task$col_roles$feature,
  batchgetter = mlr3torch::batchgetter_num))
# placeholder target batchgetter, with browser() for debugging
target_batchgetter <- function(data) {
  browser()
  torch::torch_tensor()
}
sonar_dataset <- mlr3torch::task_dataset(sonar_task, sonar_ingress, target_batchgetter)
batch_sampler_instance <- batch_sampler_class(sonar_dataset)
inst_learner <- mlr3torch::LearnerTorchMLP$new(task_type = "classif")
inst_learner$param_set$set_values(
  epochs = 10,
  batch_size = 20,
  batch_sampler = batch_sampler_instance)
inst_learner$train(sonar_task)
Note that in the code above I am passing batch_sampler = batch_sampler_instance, but for benchmarks it would be more convenient to provide batch_sampler = batch_sampler_class, as proposed in #419.
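For illustration only, a minimal sketch of what the class-based variant might look like; accepting a sampler generator instead of an instance is the proposal in #419, not something mlr3torch currently supports:
# Hypothetical usage if #419 were implemented: pass the generator itself and let
# the learner instantiate it on the dataset it constructs internally from the task.
inst_learner$param_set$set_values(
  epochs = 10,
  batch_size = 20,
  batch_sampler = batch_sampler_class)  # generator, not a pre-built instance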
I get the error below:
> inst_learner$train(sonar_task)
Error in ctx$batch$y$to(device = ctx$device) :
attempt to apply non-function
> traceback()
15: train_loop(ctx, callbacks)
14: learner_torch_train(self, private, super, task, param_vals)
13: force(expr)
12: with_torch_settings(seed = param_vals$seed, num_threads = param_vals$num_threads,
num_interop_threads = param_vals$num_interop_threads, expr = {
learner_torch_train(self, private, super, task, param_vals)
})
11: .__LearnerTorch__.train(self = self, private = private, super = super,
task = task)
10: get_private(learner)$.train(task)
9: .f(learner = <environment>, task = <environment>)
8: eval(expr, p)
7: eval(expr, p)
6: eval.parent(expr, n = 1L)
5: invoke(.f, .args = .args, .opts = .opts, .seed = .seed, .timeout = .timeout)
4: encapsulate(learner$encapsulation["train"], .f = train_wrapper,
.args = list(learner = learner, task = task), .pkgs = learner$packages,
.seed = NA_integer_, .timeout = learner$timeout["train"])
3: learner_train(learner, task, train_row_ids = train_row_ids, mode = mode)
2: .__Learner__train(self = self, private = private, super = super,
task = task, row_ids = row_ids)
1: inst_learner$train(sonar_task)
It seems that using a sampler instead of batch_sampler could be a work-around, but it would be inefficient: a for loop over individual observations is much slower than a for loop over batches. A sketch of that work-around is below.
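For reference, a minimal sketch of the per-observation approach in pure torch (whether and how mlr3torch exposes a corresponding sampler parameter is an assumption I have not verified); the sampler yields one index per call, so .getitem runs once per observation rather than once per batch:
# Sketch of a per-observation sampler (pure torch): yields single indices, so the
# dataloader forms the batches and .getitem is called once per observation.
obs_sampler_class <- torch::sampler(
  "ObsSampler",
  initialize = function(data_source) {
    self$data_source <- data_source
  },
  .iter = function() {
    indices <- seq_len(self$.length())
    count <- 0L
    function() {
      if (count < length(indices)) {
        count <<- count + 1L
        return(indices[[count]])
      }
      coro::exhausted()
    }
  },
  .length = function() {
    length(self$data_source)
  }
)
# In pure torch this plugs into the sampler argument of the dataloader:
dl_obs <- torch::dataloader(ds, sampler = obs_sampler_class(ds), batch_size = 2)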
@sebffischer Do you have any idea where this error is coming from, or how to fix it?