diff --git a/R/model.R b/R/model.R index c17dc34..280d95d 100644 --- a/R/model.R +++ b/R/model.R @@ -95,9 +95,8 @@ build_dnn = function(input, output, hidden, activation, bias, dropout, embedding } return(net) } - convBlock = torch::nn_module( - initialize = function(in_channels, out_channels, kernel_size, stride, padding){ + initialize = function(in_channels, out_channels, kernel_size, stride, padding, dropout){ self$conv=torch::nn_conv2d(in_channels = in_channels , out_channels = out_channels , kernel_size = kernel_size @@ -105,18 +104,31 @@ convBlock = torch::nn_module( , padding = padding) self$batchNorm = torch::nn_batch_norm2d(out_channels) self$activation = torch::nn_relu() + self$dropout_val = dropout + if(dropout>0){ + self$dropout = torch::nn_dropout2d(p=dropout) + } }, forward = function(x){ - x |> - self$conv() |> - self$batchNorm() |> - self$activation() + if(self$dropout_val>0){ + x |> + self$conv() |> + self$batchNorm() |> + self$activation()|> + self$dropout() + } else { + x |> + self$conv() |> + self$batchNorm() |> + self$activation() + } + } ) inceptionBlock_A_2D = torch::nn_module( - initialize = function(in_channels, channel_mult=16L){ + initialize = function(in_channels, channel_mult=16L,dropout=0){ self$branchA = torch::nn_sequential( convBlock( @@ -124,19 +136,22 @@ inceptionBlock_A_2D = torch::nn_module( , out_channels = 4L*channel_mult , kernel_size = 1L , stride = 1L - , padding = 0L), + , padding = 0L + , dropout = dropout), convBlock( in_channels = 4L*channel_mult , out_channels = 6L*channel_mult , kernel_size = 3L , stride = 1L - , padding = 1L), + , padding = 1L + , dropout = dropout), convBlock( in_channels = 6L*channel_mult , out_channels = 6L*channel_mult , kernel_size = 3L , stride = 1L - , padding = 1L) + , padding = 1L + , dropout = dropout) ) self$branchB = torch::nn_sequential( @@ -146,6 +161,7 @@ inceptionBlock_A_2D = torch::nn_module( , kernel_size = 1L , stride = 1L , padding = 0L + , dropout = dropout ), convBlock( in_channels = 3L*channel_mult @@ -153,6 +169,7 @@ inceptionBlock_A_2D = torch::nn_module( , kernel_size = 3L , stride = 1L , padding = 1L + , dropout = dropout ) ) @@ -167,6 +184,7 @@ inceptionBlock_A_2D = torch::nn_module( , out_channels = 4L*channel_mult , kernel_size = 1L , stride = 1L + , dropout = dropout , padding = 0L ) ) @@ -177,6 +195,7 @@ inceptionBlock_A_2D = torch::nn_module( , kernel_size = 1L , stride = 1L , padding = 0L + , dropout = dropout ) }, @@ -189,14 +208,14 @@ inceptionBlock_A_2D = torch::nn_module( # print(branchBRes$size()) # print(branchCRes$size()) # print(branchDRes$size()) - res = torch::torch_cat(list(branchARes, branchBRes, branchCRes, branchDRes),2L) + res = torch_cat(list(branchARes, branchBRes, branchCRes, branchDRes),2L) res } ) inceptionBlock_A_2D_reduction = torch::nn_module( - initialize = function(in_channels, channel_mult=16L){ + initialize = function(in_channels, channel_mult=16L, dropout=0){ self$branchA = torch::nn_sequential( convBlock( @@ -204,12 +223,14 @@ inceptionBlock_A_2D_reduction = torch::nn_module( , out_channels = 4L*channel_mult , kernel_size = 1L , stride = 1L + , dropout = dropout , padding = c(0L, 1L) ), convBlock( in_channels = 4L*channel_mult , out_channels = 6L*channel_mult , kernel_size = 3L + , dropout = dropout , stride = 1L , padding = 0L ), @@ -219,6 +240,7 @@ inceptionBlock_A_2D_reduction = torch::nn_module( , kernel_size = 3L , stride = 1L , padding = 1L + , dropout = dropout ) ) @@ -228,12 +250,14 @@ inceptionBlock_A_2D_reduction = torch::nn_module( , out_channels = 3L*channel_mult , kernel_size = 1L , stride = 1L + , dropout = dropout , padding = c(0L, 1L) ), convBlock( in_channels = 3L*channel_mult , out_channels = 4L*channel_mult , kernel_size = 3L + , dropout = dropout , stride = 1L , padding = 0L ) @@ -250,6 +274,7 @@ inceptionBlock_A_2D_reduction = torch::nn_module( , out_channels = 4L*channel_mult , kernel_size = 1L , stride = 1L + , dropout = dropout , padding = c(0L,1L) ) ) @@ -259,6 +284,7 @@ inceptionBlock_A_2D_reduction = torch::nn_module( , out_channels = 4L*channel_mult , kernel_size = c(3L,1L) , stride = 1L + , dropout = dropout , padding = 0L ) @@ -272,7 +298,7 @@ inceptionBlock_A_2D_reduction = torch::nn_module( # print(branchBRes$size()) # print(branchCRes$size()) # print(branchDRes$size()) - res = torch::torch_cat(list(branchARes, branchBRes, branchCRes, branchDRes),2L) + res = torch_cat(list(branchARes, branchBRes, branchCRes, branchDRes),2L) res } ) @@ -283,7 +309,7 @@ inceptionBlock_A_2D_reduction = torch::nn_module( inceptionBlock_A_1D = torch::nn_module( - initialize = function(in_channels, channel_mult=16L){ + initialize = function(in_channels, channel_mult=16L, dropout=0){ self$branchA = torch::nn_sequential( convBlock( @@ -291,6 +317,7 @@ inceptionBlock_A_1D = torch::nn_module( , out_channels = 4L*channel_mult , kernel_size = 1L , stride = 1L + , dropout = dropout , padding = 0L ), convBlock( @@ -298,12 +325,14 @@ inceptionBlock_A_1D = torch::nn_module( , out_channels = 6L*channel_mult , kernel_size = c(1L,3L) , stride = 1L + , dropout = dropout , padding = c(0L,1L) ), convBlock( in_channels = 6L*channel_mult , out_channels = 6L*channel_mult , kernel_size = c(1L,3L) + , dropout = dropout , stride = 1L , padding = c(0L,1L) ) @@ -315,6 +344,7 @@ inceptionBlock_A_1D = torch::nn_module( , out_channels = 3L*channel_mult , kernel_size = 1L , stride = 1L + , dropout = dropout , padding = 0L ), convBlock( @@ -322,6 +352,7 @@ inceptionBlock_A_1D = torch::nn_module( , out_channels = 4L*channel_mult , kernel_size = c(1L,3L) , stride = 1L + , dropout = dropout , padding = c(0L,1L) ) ) @@ -337,6 +368,7 @@ inceptionBlock_A_1D = torch::nn_module( , out_channels = 4L*channel_mult , kernel_size = 1L , stride = 1L + , dropout = dropout , padding = c(0L,1L) ) ) @@ -347,6 +379,7 @@ inceptionBlock_A_1D = torch::nn_module( , kernel_size = 1L , stride = 1L , padding = 0L + , dropout = dropout ) }, @@ -359,13 +392,12 @@ inceptionBlock_A_1D = torch::nn_module( # print(branchBRes$size()) # print(branchCRes$size()) # print(branchDRes$size()) - res = torch::torch_cat(list(branchARes, branchBRes, branchCRes, branchDRes),2L) + res = torch_cat(list(branchARes, branchBRes, branchCRes, branchDRes),2L) res } ) - -inceptionBlock <- function(type, channel_mult){ - layer <- list(channel_mult=channel_mult, type=type) +inceptionBlock <- function(type, channel_mult, dropout){ + layer <- list(channel_mult=channel_mult, type=type, dropout = dropout) class(layer) <- c("inceptionBlock", "citolayer") return(layer) } @@ -394,15 +426,15 @@ build_cnn<-function (input_shape, output_shape, architecture) } else if(inherits(layer, "inceptionBlock")){ if(layer$type=="2D"){ - net_layers[[counter]] <- inceptionBlock_A_2D(input_shape[1], layer$channel_mult) + net_layers[[counter]] <- inceptionBlock_A_2D(input_shape[1], layer$channel_mult, layer$dropout) input_shape[1]<-18L*layer$channel_mult counter<-counter+1 } else if(layer$type=="red"){ - net_layers[[counter]] <- inceptionBlock_A_2D_reduction(input_shape[1], layer$channel_mult) + net_layers[[counter]] <- inceptionBlock_A_2D_reduction(input_shape[1], layer$channel_mult, layer$dropout) input_shape<-c(18L*layer$channel_mult, input_shape[2:3]-c(2,0)) counter<-counter+1 } else if(layer$type=="1D"){ - net_layers[[counter]] <- inceptionBlock_A_1D(input_shape[1], layer$channel_mult) + net_layers[[counter]] <- inceptionBlock_A_1D(input_shape[1], layer$channel_mult, layer$dropout) input_shape[1]<-18L*layer$channel_mult counter<-counter+1 }