Skip to content

Commit

Permalink
Update model.R
Browse files Browse the repository at this point in the history
  • Loading branch information
D-Maar authored Dec 4, 2024
1 parent 0d21f88 commit c615589
Showing 1 changed file with 96 additions and 0 deletions.
96 changes: 96 additions & 0 deletions R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -303,8 +303,97 @@ inceptionBlock_A_2D_reduction = torch::nn_module(
}
)

inceptionBlock_A_100_reduction = torch::nn_module(


initialize = function(in_channels, channel_mult=16L, dropout=0){

self$branchA = torch::nn_sequential(
convBlock(
in_channels = in_channels
, out_channels = 4L*channel_mult
, kernel_size = c(1L, 100L)
, stride = 1L
, dropout = dropout
, padding = c(0L, 50L)
),
convBlock(
in_channels = 4L*channel_mult
, out_channels = 6L*channel_mult
, kernel_size = c(3L, 100L)
, dropout = dropout
, stride = 1L
, padding = c(1L, 50L)
),
convBlock(
in_channels = 6L*channel_mult
, out_channels = 6L*channel_mult
, kernel_size = c(3L, 100L)
, stride = 1L
, padding = 0L
, dropout = dropout
)
)

self$branchB = torch::nn_sequential(
convBlock(
in_channels = in_channels
, out_channels = 3L*channel_mult
, kernel_size = c(1L, 100L)
, stride = 1L
, dropout = dropout
, padding = c(0L, 50L)
),
convBlock(
in_channels = 3L*channel_mult
, out_channels = 4L*channel_mult
, kernel_size = c(3L, 100L)
, dropout = dropout
, stride = 1L
, padding = 0L
)
)

self$branchC = torch::nn_sequential(
torch::nn_avg_pool2d(
kernel_size = c(3L,100L)
, stride = 1L
, padding = c(0L, 50L)
),
convBlock(
in_channels = in_channels
, out_channels = 4L*channel_mult
, kernel_size = c(1L, 100L)
, stride = 1L
, dropout = dropout
, padding = 0L
)
)

self$branchD = convBlock(
in_channels = in_channels
, out_channels = 4L*channel_mult
, kernel_size = c(3L,100L)
, stride = 1L
, dropout = dropout
, padding = 0L
)

},
forward = function(x){
branchARes = self$branchA(x)
branchBRes = self$branchB(x)
branchCRes = self$branchC(x)
branchDRes = self$branchD(x)
# print(branchARes$size())
# print(branchBRes$size())
# print(branchCRes$size())
# print(branchDRes$size())
res = torch::torch_cat(list(branchARes, branchBRes, branchCRes, branchDRes),2L)
res
}
)


inceptionBlock_A_1D = torch::nn_module(

Expand Down Expand Up @@ -528,6 +617,13 @@ build_cnn<-function (input_shape, output_shape, architecture)
input_shape[1] <- 18L * layer$channel_mult
counter <- counter + 1
}
else if (layer$type == "red100") {
net_layers[[counter]] <- inceptionBlock_A_100_reduction(input_shape[1],
layer$channel_mult, layer$dropout)
input_shape <- c(18L * layer$channel_mult, input_shape[2:3] -
c(2, 99))
counter <- counter + 1
}
}
else if (inherits(layer, "conv")) {
net_layers[[counter]] <- switch(input_dim, torch::nn_conv1d(input_shape[1],
Expand Down

0 comments on commit c615589

Please sign in to comment.