Skip to content

Commit

Permalink
added possibility to use transfer learning with non RGB images
Browse files Browse the repository at this point in the history
  • Loading branch information
D-Maar authored Oct 24, 2024
1 parent a69d6d5 commit 292c0ed
Showing 1 changed file with 52 additions and 15 deletions.
67 changes: 52 additions & 15 deletions R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -410,33 +410,70 @@ build_cnn<-function (input_shape, output_shape, architecture)
transfer <- FALSE
for (layer in architecture) {
if (inherits(layer, "transfer")) {
if (!(input_dim == 2 && input_shape[1] == 3))
stop("The pretrained models only work on RGB images: [n, 3, x, y]")
if (!(input_dim == 2))
stop("The pretrained models only work on images: [n, channels, x, y]")

transfer_model <- get_pretrained_model(layer$name,
layer$pretrained)
if (layer$freeze)
transfer_model <- freeze_weights(transfer_model)
if(input_shape[1]!=3){
transfer_model<-tryCatch(
{
cur<-transfer_model
call_string<-"transfer_model"
while(length(cur$children)>0){
cur_name<-names(cur$children)[1]
call_string<-paste0(call_string, "[['", cur_name, "']]")
cur<-eval(rlang::parse_expr(call_string))
}
first_layer <- torch::nn_conv2d(
in_channels = as.integer(input_shape[1])
, out_channels = as.integer(cur$out_channels)
, kernel_size = as.integer(cur$kernel_size)
, stride = as.integer(cur$stride)
, padding = as.integer(cur$padding)
, dilation = as.integer(cur$dilation)
, groups = as.integer(cur$groups)
, bias = !is.null(cur$bias)
, padding_mode = cur$padding_mode
)
call_string<-paste0(call_string, "<-first_layer")
eval(rlang::parse_expr(call_string))
transfer_model
},error = function(x)stop(paste0("automatic input layer adjustment to ", input_shape[1], " layers failed with error message:\n", x))


)
}
if (!layer$replace_classifier) {
transfer_model <- replace_output_layer(transfer_model,
output_shape)

return(transfer_model)
}
transfer <- TRUE
input_shape <- get_transfer_output_shape(layer$name)
}
else if(inherits(layer, "inceptionBlock")){
if(layer$type=="2D"){
net_layers[[counter]] <- inceptionBlock_A_2D(input_shape[1], layer$channel_mult, layer$dropout)
input_shape[1]<-18L*layer$channel_mult
counter<-counter+1
} else if(layer$type=="red"){
net_layers[[counter]] <- inceptionBlock_A_2D_reduction(input_shape[1], layer$channel_mult, layer$dropout)
input_shape<-c(18L*layer$channel_mult, input_shape[2:3]-c(2,0))
counter<-counter+1
} else if(layer$type=="1D"){
net_layers[[counter]] <- inceptionBlock_A_1D(input_shape[1], layer$channel_mult, layer$dropout)
input_shape[1]<-18L*layer$channel_mult
counter<-counter+1
else if (inherits(layer, "inceptionBlock")) {
if (layer$type == "2D") {
net_layers[[counter]] <- inceptionBlock_A_2D(input_shape[1],
layer$channel_mult, layer$dropout)
input_shape[1] <- 18L * layer$channel_mult
counter <- counter + 1
}
else if (layer$type == "red") {
net_layers[[counter]] <- inceptionBlock_A_2D_reduction(input_shape[1],
layer$channel_mult, layer$dropout)
input_shape <- c(18L * layer$channel_mult, input_shape[2:3] -
c(2, 0))
counter <- counter + 1
}
else if (layer$type == "1D") {
net_layers[[counter]] <- inceptionBlock_A_1D(input_shape[1],
layer$channel_mult, layer$dropout)
input_shape[1] <- 18L * layer$channel_mult
counter <- counter + 1
}
}
else if (inherits(layer, "conv")) {
Expand Down

0 comments on commit 292c0ed

Please sign in to comment.