diff --git a/R/cnn.R b/R/cnn.R index 35238e9..85a746e 100644 --- a/R/cnn.R +++ b/R/cnn.R @@ -132,8 +132,8 @@ cnn<-function (X = NULL, Y = NULL, architecture, loss = c("mse", "mae", use_custom_dl<-!is.null(train_data_loader) if(use_custom_dl){ sample_batch = train_data_loader$.iter()$.next() - X<- as.array(sample_batch$X) - Y<- c(as.matrix(sample_batch$Y)) + X<- as.array(sample_batch[[1]]) + Y<- c(as.matrix(sample_batch[[2]])) batchsize=2L if(!xor(is.null(valid_data_loader), validation!=0)) stop("You need to provided an validation data loader if you want to use a validation split with a custom data loader") }