Skip to content

Commit

Permalink
Update model.R
Browse files Browse the repository at this point in the history
  • Loading branch information
D-Maar authored Oct 22, 2024
1 parent a19da80 commit f5ef138
Showing 1 changed file with 53 additions and 21 deletions.
74 changes: 53 additions & 21 deletions R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -95,48 +95,63 @@ build_dnn = function(input, output, hidden, activation, bias, dropout, embedding
}
return(net)
}

convBlock = torch::nn_module(
  # Basic 2D convolution block: conv2d -> batch norm -> ReLU, with an
  # optional spatial dropout layer appended when dropout > 0.
  #
  # Args:
  #   in_channels:  number of input feature channels
  #   out_channels: number of output feature channels
  #   kernel_size:  convolution kernel size (scalar or length-2 vector)
  #   stride:       convolution stride
  #   padding:      convolution padding (scalar or length-2 vector)
  #   dropout:      dropout probability; defaults to 0 (no dropout layer) so
  #                 callers written before this argument existed keep working
  initialize = function(in_channels, out_channels, kernel_size, stride, padding, dropout = 0){
    self$conv = torch::nn_conv2d(in_channels = in_channels
                                 , out_channels = out_channels
                                 , kernel_size = kernel_size
                                 , stride = stride
                                 , padding = padding)
    self$batchNorm = torch::nn_batch_norm2d(out_channels)
    self$activation = torch::nn_relu()
    # Remember the probability: forward() uses it to decide whether the
    # (conditionally created) dropout submodule may be called.
    self$dropout_val = dropout
    if(dropout > 0){
      self$dropout = torch::nn_dropout2d(p = dropout)
    }
  },
  forward = function(x){
    # self$dropout only exists when dropout_val > 0, so it must only be
    # invoked on that branch; the if/else expression is the return value.
    if(self$dropout_val > 0){
      x |>
        self$conv() |>
        self$batchNorm() |>
        self$activation() |>
        self$dropout()
    } else {
      x |>
        self$conv() |>
        self$batchNorm() |>
        self$activation()
    }
  }
)
inceptionBlock_A_2D = torch::nn_module(


initialize = function(in_channels, channel_mult=16L){
initialize = function(in_channels, channel_mult=16L,dropout=0){

self$branchA = torch::nn_sequential(
convBlock(
in_channels = in_channels
, out_channels = 4L*channel_mult
, kernel_size = 1L
, stride = 1L
, padding = 0L),
, padding = 0L
, dropout = dropout),
convBlock(
in_channels = 4L*channel_mult
, out_channels = 6L*channel_mult
, kernel_size = 3L
, stride = 1L
, padding = 1L),
, padding = 1L
, dropout = dropout),
convBlock(
in_channels = 6L*channel_mult
, out_channels = 6L*channel_mult
, kernel_size = 3L
, stride = 1L
, padding = 1L)
, padding = 1L
, dropout = dropout)
)

self$branchB = torch::nn_sequential(
Expand All @@ -146,13 +161,15 @@ inceptionBlock_A_2D = torch::nn_module(
, kernel_size = 1L
, stride = 1L
, padding = 0L
, dropout = dropout
),
convBlock(
in_channels = 3L*channel_mult
, out_channels = 4L*channel_mult
, kernel_size = 3L
, stride = 1L
, padding = 1L
, dropout = dropout
)
)

Expand All @@ -167,6 +184,7 @@ inceptionBlock_A_2D = torch::nn_module(
, out_channels = 4L*channel_mult
, kernel_size = 1L
, stride = 1L
, dropout = dropout
, padding = 0L
)
)
Expand All @@ -177,6 +195,7 @@ inceptionBlock_A_2D = torch::nn_module(
, kernel_size = 1L
, stride = 1L
, padding = 0L
, dropout = dropout
)

},
Expand All @@ -189,27 +208,29 @@ inceptionBlock_A_2D = torch::nn_module(
# print(branchBRes$size())
# print(branchCRes$size())
# print(branchDRes$size())
res = torch::torch_cat(list(branchARes, branchBRes, branchCRes, branchDRes),2L)
res = torch_cat(list(branchARes, branchBRes, branchCRes, branchDRes),2L)
res
}
)
inceptionBlock_A_2D_reduction = torch::nn_module(


initialize = function(in_channels, channel_mult=16L){
initialize = function(in_channels, channel_mult=16L, dropout=0){

self$branchA = torch::nn_sequential(
convBlock(
in_channels = in_channels
, out_channels = 4L*channel_mult
, kernel_size = 1L
, stride = 1L
, dropout = dropout
, padding = c(0L, 1L)
),
convBlock(
in_channels = 4L*channel_mult
, out_channels = 6L*channel_mult
, kernel_size = 3L
, dropout = dropout
, stride = 1L
, padding = 0L
),
Expand All @@ -219,6 +240,7 @@ inceptionBlock_A_2D_reduction = torch::nn_module(
, kernel_size = 3L
, stride = 1L
, padding = 1L
, dropout = dropout
)
)

Expand All @@ -228,12 +250,14 @@ inceptionBlock_A_2D_reduction = torch::nn_module(
, out_channels = 3L*channel_mult
, kernel_size = 1L
, stride = 1L
, dropout = dropout
, padding = c(0L, 1L)
),
convBlock(
in_channels = 3L*channel_mult
, out_channels = 4L*channel_mult
, kernel_size = 3L
, dropout = dropout
, stride = 1L
, padding = 0L
)
Expand All @@ -250,6 +274,7 @@ inceptionBlock_A_2D_reduction = torch::nn_module(
, out_channels = 4L*channel_mult
, kernel_size = 1L
, stride = 1L
, dropout = dropout
, padding = c(0L,1L)
)
)
Expand All @@ -259,6 +284,7 @@ inceptionBlock_A_2D_reduction = torch::nn_module(
, out_channels = 4L*channel_mult
, kernel_size = c(3L,1L)
, stride = 1L
, dropout = dropout
, padding = 0L
)

Expand All @@ -272,7 +298,7 @@ inceptionBlock_A_2D_reduction = torch::nn_module(
# print(branchBRes$size())
# print(branchCRes$size())
# print(branchDRes$size())
res = torch::torch_cat(list(branchARes, branchBRes, branchCRes, branchDRes),2L)
res = torch_cat(list(branchARes, branchBRes, branchCRes, branchDRes),2L)
res
}
)
Expand All @@ -283,27 +309,30 @@ inceptionBlock_A_2D_reduction = torch::nn_module(
inceptionBlock_A_1D = torch::nn_module(


initialize = function(in_channels, channel_mult=16L){
initialize = function(in_channels, channel_mult=16L, dropout=0){

self$branchA = torch::nn_sequential(
convBlock(
in_channels = in_channels
, out_channels = 4L*channel_mult
, kernel_size = 1L
, stride = 1L
, dropout = dropout
, padding = 0L
),
convBlock(
in_channels = 4L*channel_mult
, out_channels = 6L*channel_mult
, kernel_size = c(1L,3L)
, stride = 1L
, dropout = dropout
, padding = c(0L,1L)
),
convBlock(
in_channels = 6L*channel_mult
, out_channels = 6L*channel_mult
, kernel_size = c(1L,3L)
, dropout = dropout
, stride = 1L
, padding = c(0L,1L)
)
Expand All @@ -315,13 +344,15 @@ inceptionBlock_A_1D = torch::nn_module(
, out_channels = 3L*channel_mult
, kernel_size = 1L
, stride = 1L
, dropout = dropout
, padding = 0L
),
convBlock(
in_channels = 3L*channel_mult
, out_channels = 4L*channel_mult
, kernel_size = c(1L,3L)
, stride = 1L
, dropout = dropout
, padding = c(0L,1L)
)
)
Expand All @@ -337,6 +368,7 @@ inceptionBlock_A_1D = torch::nn_module(
, out_channels = 4L*channel_mult
, kernel_size = 1L
, stride = 1L
, dropout = dropout
, padding = c(0L,1L)
)
)
Expand All @@ -347,6 +379,7 @@ inceptionBlock_A_1D = torch::nn_module(
, kernel_size = 1L
, stride = 1L
, padding = 0L
, dropout = dropout
)

},
Expand All @@ -359,13 +392,12 @@ inceptionBlock_A_1D = torch::nn_module(
# print(branchBRes$size())
# print(branchCRes$size())
# print(branchDRes$size())
res = torch::torch_cat(list(branchARes, branchBRes, branchCRes, branchDRes),2L)
res = torch_cat(list(branchARes, branchBRes, branchCRes, branchDRes),2L)
res
}
)

# Constructor for an inception-block layer specification consumed by
# build_cnn().
#
# Args:
#   type:         block variant dispatched on in build_cnn():
#                 "2D", "red" (reduction), or "1D"
#   channel_mult: channel multiplier forwarded to the torch module
#   dropout:      dropout probability forwarded to every convBlock inside the
#                 module; defaults to 0 so callers written before this
#                 argument was added keep working unchanged
#
# Returns:
#   A list classed c("inceptionBlock", "citolayer") holding the settings.
inceptionBlock <- function(type, channel_mult, dropout = 0){
  layer <- list(channel_mult = channel_mult, type = type, dropout = dropout)
  class(layer) <- c("inceptionBlock", "citolayer")
  return(layer)
}
Expand Down Expand Up @@ -394,15 +426,15 @@ build_cnn<-function (input_shape, output_shape, architecture)
}
else if(inherits(layer, "inceptionBlock")){
if(layer$type=="2D"){
net_layers[[counter]] <- inceptionBlock_A_2D(input_shape[1], layer$channel_mult)
net_layers[[counter]] <- inceptionBlock_A_2D(input_shape[1], layer$channel_mult, layer$dropout)
input_shape[1]<-18L*layer$channel_mult
counter<-counter+1
} else if(layer$type=="red"){
net_layers[[counter]] <- inceptionBlock_A_2D_reduction(input_shape[1], layer$channel_mult)
net_layers[[counter]] <- inceptionBlock_A_2D_reduction(input_shape[1], layer$channel_mult, layer$dropout)
input_shape<-c(18L*layer$channel_mult, input_shape[2:3]-c(2,0))
counter<-counter+1
} else if(layer$type=="1D"){
net_layers[[counter]] <- inceptionBlock_A_1D(input_shape[1], layer$channel_mult)
net_layers[[counter]] <- inceptionBlock_A_1D(input_shape[1], layer$channel_mult, layer$dropout)
input_shape[1]<-18L*layer$channel_mult
counter<-counter+1
}
Expand Down

0 comments on commit f5ef138

Please sign in to comment.