Skip to content

Commit

Permalink
Fix DeepNNTorch
Browse files Browse the repository at this point in the history
  • Loading branch information
egillax committed Aug 18, 2022
1 parent 3b5aea5 commit 38c88d6
Show file tree
Hide file tree
Showing 2 changed files with 163 additions and 79 deletions.
46 changes: 30 additions & 16 deletions R/Dataset.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ Dataset <- torch::dataset(
self$numericalIndex <- NULL
}


# add labels if training (make 0 vector for prediction)
if (!is.null(labels)) {
self$target <- torch::torch_tensor(labels)
Expand All @@ -35,6 +34,7 @@ Dataset <- torch::dataset(
# Weight to add in loss function to positive class
self$posWeight <- (self$target == 0)$sum() / self$target$sum()
# for DeepNNTorch
self$useAll <- all
if (all) {
self$all <- torch::torch_tensor(as.matrix(data), dtype = torch::torch_float32())
self$cat <- NULL
Expand Down Expand Up @@ -103,23 +103,37 @@ Dataset <- torch::dataset(
},
.getbatch = function(item) {
if (length(item) == 1) {
# add leading singleton dimension since models expects 2d tensors
return(list(
batch = list(
cat = self$cat[item]$unsqueeze(1),
num = self$num[item]$unsqueeze(1)
),
target = self$target[item]$unsqueeze(1)
))
return(self$.getBatchSingle(item))
} else {
return(self$.getBatchRegular(item))
}
},
.getBatchSingle = function(item) {
# add leading singleton dimension since models expects 2d tensors
if (self$useAll) {
batch <- list(all = self$all[item]$unsqueeze(1))
} else {
batch <- list(cat = self$cat[item]$unsqueeze(1),
num = self$num[item]$unsqueeze(1))
}
return(list(
batch = batch,
target = self$target[item]$unsqueeze(1)
))
},
.getBatchRegular = function(item) {
if (self$useAll) {
batch <- list(all = self$all[item])
} else {
return(list(
batch = list(
cat = self$cat[item],
num = self$num[item]
),
target = self$target[item]
))
batch = list(
cat = self$cat[item],
num = self$num[item]
)
}
return(list(
batch = batch,
target = self$target[item]
))
},
.length = function() {
self$target$size()[[1]] # shape[1]
Expand Down
Loading

0 comments on commit 38c88d6

Please sign in to comment.