Commit

replace sequential pt2
Arjun31415 committed Apr 18, 2024
1 parent 7150be0 commit fabaef8
Showing 1 changed file with 21 additions and 27 deletions.
48 changes: 21 additions & 27 deletions mobilenet-burn/src/model/mobilenet.rs
@@ -4,7 +4,8 @@ use burn::{
     config::Config,
     module::Module,
     nn::{
-        pool::{AdaptiveAvgPool2d, AdaptiveAvgPool2dConfig}, BatchNormConfig, Dropout, DropoutConfig, Linear, LinearConfig
+        pool::{AdaptiveAvgPool2d, AdaptiveAvgPool2dConfig},
+        BatchNormConfig, Dropout, DropoutConfig, Linear, LinearConfig,
     },
     tensor::{backend::Backend, Tensor},
 };
@@ -23,15 +24,13 @@ use {
     burn_import::pytorch::{LoadArgs, PyTorchFileRecorder},
 };
 
-use super::{
-    conv_norm::Conv2dNormActivation,
-    inverted_residual::InvertedResidual,
-};
+use super::{conv_norm::Conv2dNormActivation, inverted_residual::InvertedResidual};
 
 #[derive(Debug, Module)]
 pub struct MobileNetV2<B: Backend> {
     features: Vec<ConvBlock<B>>,
-    classifier: Vec<ClassifierLayersType<B>>,
+    // classifier: Vec<ClassifierLayersType<B>>,
+    classifier: Classifier<B>,
     avg_pool: AdaptiveAvgPool2d,
 }
 

@@ -49,7 +48,7 @@ pub struct MobileNetV2Config {
     #[config(default = "8")]
     round_nearest: usize,
 
-    norm_layer:BatchNormConfig,
+    norm_layer: BatchNormConfig,
 
     #[config(default = "0.2")]
     dropout: f64,
@@ -61,10 +60,17 @@ enum ConvBlock<B: Backend> {
     Conv(Conv2dNormActivation<B>),
 }
 
-
 #[derive(Module, Debug)]
-enum ClassifierLayersType<B: Backend> {
-    Dropout(Dropout),
-    Linear(Linear<B>),
+struct Classifier<B: Backend> {
+    dropout: Dropout,
+    linear: Linear<B>,
+}
+impl<B: Backend> Classifier<B> {
+    fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
+        let x = self.dropout.forward(input);
+        self.linear.forward(x)
+    }
 }
 
 impl MobileNetV2Config {
@@ -127,12 +133,10 @@ impl MobileNetV2Config {
                 .init(device),
         ));
 
-        let classifier = vec![
-            ClassifierLayersType::Dropout(DropoutConfig::new(self.dropout).init()),
-            ClassifierLayersType::Linear(
-                LinearConfig::new(last_channel, self.num_classes).init(device),
-            ),
-        ];
+        let classifier = Classifier {
+            dropout: DropoutConfig::new(self.dropout).init(),
+            linear: LinearConfig::new(last_channel, self.num_classes).init(device),
+        };
 
         MobileNetV2 {
             features,
@@ -154,19 +158,9 @@ impl<B: Backend> MobileNetV2<B> {
                 }
             }
         }
 
         x = self.avg_pool.forward(x);
         x = x.flatten(1, 1);
-        for layer in &self.classifier {
-            match layer {
-                ClassifierLayersType::Dropout(dropout) => {
-                    x = dropout.forward(x);
-                }
-                ClassifierLayersType::Linear(linear) => {
-                    x = linear.forward(x);
-                }
-            }
-        }
+        x = self.classifier.forward(x);
         x
     }
 }
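
For context, here is a minimal, dependency-free Rust sketch of the refactor pattern this commit applies. The `Dropout` and `Linear` types below are trivial stand-ins, not Burn's actual modules: when the layer order is fixed, a plain struct with one field per layer can replace the `Vec<enum>` container and its per-layer `match` dispatch.

// Stand-in layer types; real Burn modules carry parameters and a backend.
struct Dropout;
struct Linear;

impl Dropout {
    fn forward(&self, x: Vec<f32>) -> Vec<f32> {
        x // identity here; inference-mode dropout passes activations through
    }
}

impl Linear {
    fn forward(&self, x: Vec<f32>) -> Vec<f32> {
        x.iter().map(|v| v * 2.0).collect() // placeholder for an affine map
    }
}

// Before: heterogeneous layers behind an enum, dispatched with `match`.
enum LayerKind {
    Dropout(Dropout),
    Linear(Linear),
}

fn forward_sequential(layers: &[LayerKind], mut x: Vec<f32>) -> Vec<f32> {
    for layer in layers {
        x = match layer {
            LayerKind::Dropout(d) => d.forward(x),
            LayerKind::Linear(l) => l.forward(x),
        };
    }
    x
}

// After: the order is fixed (dropout, then linear), so encode it in a struct.
struct Classifier {
    dropout: Dropout,
    linear: Linear,
}

impl Classifier {
    fn forward(&self, x: Vec<f32>) -> Vec<f32> {
        self.linear.forward(self.dropout.forward(x))
    }
}

fn main() {
    let x = vec![1.0_f32, 2.0];
    let seq = forward_sequential(
        &[LayerKind::Dropout(Dropout), LayerKind::Linear(Linear)],
        x.clone(),
    );
    let fixed = Classifier { dropout: Dropout, linear: Linear }.forward(x);
    assert_eq!(seq, fixed); // both paths produce the same output
}

The two versions compute the same result; the struct version simply moves layer dispatch from a runtime `match` into the type system.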