diff --git a/R/cnn.R b/R/cnn.R index 8cb45d4..35238e9 100644 --- a/R/cnn.R +++ b/R/cnn.R @@ -129,7 +129,7 @@ cnn<-function (X = NULL, Y = NULL, architecture, loss = c("mse", "mae", early_stopping = NULL, lr_scheduler = NULL, custom_parameters = NULL, device = c("cpu", "cuda", "mps"), plot = TRUE, verbose = TRUE, train_data_loader=NULL, valid_data_loader=NULL) { - use_custom_dl<-!is.null(dataloader) + use_custom_dl<-!is.null(train_data_loader) if(use_custom_dl){ sample_batch = train_data_loader$.iter()$.next() X<- as.array(sample_batch$X)