Skip to content

Commit

Permalink
Update training.R
Browse files Browse the repository at this point in the history
  • Loading branch information
D-Maar authored Nov 22, 2024
1 parent 949e886 commit 785e1f6
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions R/training.R
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,12 @@ train_model <- function(model, epochs, device, train_dl, valid_dl=NULL, verbose
best_train_loss = Inf
best_val_loss = Inf
counter = 0

model$losses$auc <- NA
for (epoch in min(which(is.na(model$losses$train_l))):(epochs+ min(which(is.na(model$losses$train_l))) - 1)) {
train_l <- c()
model$training_properties$epoch <- epoch
preds <- c()
ys <- c()

### Batch evaluation ###
coro::loop(for (b in train_dl) {
Expand Down Expand Up @@ -70,7 +72,10 @@ train_model <- function(model, epochs, device, train_dl, valid_dl=NULL, verbose
optimizer$step()

train_l <- c(train_l, loss$item())
preds <- c(preds, as.numeric(output$to(device = "cpu")))
ys <- c(ys, as.numeric(b[[2]]))
})
model$losses$auc[epoch] <- Metrics::auc(ys, preds)

if(is.na(loss$item())) {
if(verbose) cat("Loss is NA. Bad training, please hyperparameters. See vignette('B-Training_neural_networks') for help.\n")
Expand Down Expand Up @@ -124,7 +129,8 @@ train_model <- function(model, epochs, device, train_dl, valid_dl=NULL, verbose
if(verbose) cat(sprintf("Loss at epoch %d: training: %3.3f, validation: %3.3f, lr: %3.5f\n",
epoch, model$losses$train_l[epoch], model$losses$valid_l[epoch],optimizer$param_groups[[1]]$lr))
}else{
if (verbose) cat(sprintf("Loss at epoch %d: %3f, lr: %3.5f\n", epoch, model$losses$train_l[epoch],optimizer$param_groups[[1]]$lr))
if (verbose) cat(sprintf("Loss at epoch %d: training: %3.3f, auc: %3.3f, lr: %3.5f\n",
epoch, model$losses$train_l[epoch], model$losses$auc[epoch], optimizer$param_groups[[1]]$lr))
}

### create plot ###
Expand Down

0 comments on commit 785e1f6

Please sign in to comment.