Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MobileNetv2 #26

Merged
merged 12 commits into from
Apr 25, 2024
11 changes: 6 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@ examples constructed using the [Burn](https://github.com/burn-rs/burn) deep lear

## Collection of Official Models

| Model | Description | Repository Link |
|------------------------------------------------|-------------------------------------------------------|----------------------------------------------|
| [SqueezeNet](https://arxiv.org/abs/1602.07360) | A small CNN-based model for image classification. | [squeezenet-burn](squeezenet-burn/README.md) |
| [ResNet](https://arxiv.org/abs/1512.03385) | A CNN based on residual blocks with skip connections. | [resnet-burn](resnet-burn/README.md) |
| [RoBERTa](https://arxiv.org/abs/1907.11692) | A robustly optimized BERT pretraining approach. | [bert-burn](bert-burn/README.md) |
| Model | Description | Repository Link |
|-------------------------------------------------|-------------------------------------------------------|----------------------------------------------|
| [MobileNetV2](https://arxiv.org/abs/1801.04381) | A CNN model targeted at mobile devices. | [mobilenetv2-burn](mobilenetv2-burn/README.md) |
| [SqueezeNet](https://arxiv.org/abs/1602.07360) | A small CNN-based model for image classification. | [squeezenet-burn](squeezenet-burn/README.md) |
| [ResNet](https://arxiv.org/abs/1512.03385) | A CNN based on residual blocks with skip connections. | [resnet-burn](resnet-burn/README.md) |
| [RoBERTa](https://arxiv.org/abs/1907.11692) | A robustly optimized BERT pretraining approach. | [bert-burn](bert-burn/README.md) |

## Community Contributions

Expand Down
25 changes: 25 additions & 0 deletions mobilenetv2-burn/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
[package]
authors = ["Arjun31415", "guillaumelagrange <[email protected]>"]
license = "MIT OR Apache-2.0"
name = "mobilenetv2-burn"
version = "0.1.0"
edition = "2021"

[features]
default = []
std = []
pretrained = ["burn/network", "std", "dep:dirs"]

[dependencies]
# Note: default-features = false is needed to disable std
burn = { version = "0.13.0" }
burn-import = { version = "0.13.0" }
dirs = { version = "5.0.1", optional = true }
serde = { version = "1.0.192", default-features = false, features = [
"derive",
"alloc",
] } # alloc is for no_std, derive is needed

[dev-dependencies]
burn = { version = "0.13.0", features = ["ndarray"] }
image = { version = "0.24.9", features = ["png", "jpeg"] }
16 changes: 16 additions & 0 deletions mobilenetv2-burn/NOTICES.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# NOTICES AND INFORMATION

This file contains notices and information required by libraries that this repository copied or derived from. The use of the following resources complies with the licenses provided.

## Sample Image

Image Title: Standing yellow Labrador Retriever dog.
Author: Djmirko
Source: https://commons.wikimedia.org/wiki/File:YellowLabradorLooking_new.jpg
License: https://creativecommons.org/licenses/by-sa/3.0/

## Pre-trained Model

The ImageNet pre-trained model was ported from [`torchvision.models.MobileNet_V2_Weights.IMAGENET1K_V2`](https://pytorch.org/vision/main/models/generated/torchvision.models.mobilenet_v2.html#torchvision.models.MobileNet_V2_Weights).

As opposed to [other pre-trained models](https://pytorch.org/vision/stable/models/generated/torchvision.models.regnet_y_128gf.html#torchvision.models.RegNet_Y_128GF_Weights) in `torchvision`, no specific license was linked to the weights, which are assumed to be under the library's [BSD-3-Clause license](https://github.com/pytorch/vision/blob/main/LICENSE) ([ref](https://github.com/pytorch/vision/issues/160)).
40 changes: 40 additions & 0 deletions mobilenetv2-burn/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# MobileNetV2 Burn

[MobileNetV2](https://arxiv.org/abs/1801.04381) is a convolutional neural network architecture for
classification tasks which seeks to perform well on mobile devices. You can find the
[Burn](https://github.com/tracel-ai/burn) implementation for the MobileNetV2 in
[src/model/mobilenetv2.rs](src/model/mobilenetv2.rs).

The model is [no_std compatible](https://docs.rust-embedded.org/book/intro/no-std.html).

## Usage

### `Cargo.toml`

Add this to your `Cargo.toml`:

```toml
[dependencies]
mobilenetv2-burn = { git = "https://github.com/tracel-ai/models", package = "mobilenetv2-burn", default-features = false }
```

If you want to get the pre-trained ImageNet weights, enable the `pretrained` feature flag.

```toml
[dependencies]
mobilenetv2-burn = { git = "https://github.com/tracel-ai/models", package = "mobilenetv2-burn", features = ["pretrained"] }
```

**Important:** this feature requires `std`.

### Example Usage

The [inference example](examples/inference.rs) initializes a MobileNetV2 from the ImageNet
[pre-trained weights](https://pytorch.org/vision/main/models/generated/torchvision.models.mobilenet_v2.html#torchvision.models.MobileNet_V2_Weights)
with the `NdArray` backend and performs inference on the provided input image.

You can run the example with the following command:

```sh
cargo run --release --features pretrained --example inference samples/dog.jpg
```
69 changes: 69 additions & 0 deletions mobilenetv2-burn/examples/inference.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
use mobilenetv2_burn::model::{imagenet, mobilenetv2::MobileNetV2, weights};

use burn::{
backend::NdArray,
tensor::{backend::Backend, Data, Device, Element, Shape, Tensor},
};

const HEIGHT: usize = 224;
const WIDTH: usize = 224;

fn to_tensor<B: Backend, T: Element>(
data: Vec<T>,
shape: [usize; 3],
device: &Device<B>,
) -> Tensor<B, 3> {
Tensor::<B, 3>::from_data(Data::new(data, Shape::new(shape)).convert(), device)
// [H, W, C] -> [C, H, W]
.permute([2, 0, 1])
/ 255 // normalize between [0, 1]
}

pub fn main() {
// Parse arguments
let img_path = std::env::args().nth(1).expect("No image path provided");

// Create MobileNetV2
let device = Default::default();
let model: MobileNetV2<NdArray> =
MobileNetV2::pretrained(weights::MobileNetV2::ImageNet1kV2, &device)
.map_err(|err| format!("Failed to load pre-trained weights.\nError: {err}"))
.unwrap();

// Load image
let img = image::open(&img_path)
.map_err(|err| format!("Failed to load image {img_path}.\nError: {err}"))
.unwrap();

// Resize to 224x224
let resized_img = img.resize_exact(
WIDTH as u32,
HEIGHT as u32,
image::imageops::FilterType::Triangle, // also known as bilinear in 2D
);

// Create tensor from image data
let img_tensor = to_tensor(
resized_img.into_rgb8().into_raw(),
[HEIGHT, WIDTH, 3],
&device,
)
.unsqueeze::<4>(); // [B, C, H, W]

// Normalize the image
let x = imagenet::Normalizer::new(&device).normalize(img_tensor);

// Forward pass
let out = model.forward(x);

// Output class index w/ score (raw)
let (score, idx) = out.max_dim_with_indices(1);
let idx = idx.into_scalar() as usize;

println!(
"Predicted: {}\nCategory Id: {}\nScore: {:.4}",
imagenet::CLASSES[idx],
idx,
score.into_scalar()
);
}
Binary file added mobilenetv2-burn/samples/dog.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
3 changes: 3 additions & 0 deletions mobilenetv2-burn/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#![cfg_attr(not(feature = "std"), no_std)]
pub mod model;
extern crate alloc;
83 changes: 83 additions & 0 deletions mobilenetv2-burn/src/model/conv_norm.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
use burn::{
config::Config,
module::Module,
nn::{
conv::{Conv2d, Conv2dConfig},
BatchNorm, BatchNormConfig, PaddingConfig2d,
},
tensor::{self, backend::Backend, Tensor},
};

/// A rectified linear unit where the activation is limited to a maximum of 6.
#[derive(Module, Debug, Clone, Default)]
pub struct ReLU6 {}
impl ReLU6 {
pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
tensor::activation::relu(input).clamp_max(6)
}
}

/// A Conv2d -> BatchNorm -> activation block.
#[derive(Module, Debug)]
pub struct Conv2dNormActivation<B: Backend> {
conv: Conv2d<B>,
norm: BatchNorm<B, 2>,
activation: ReLU6,
}

/// [Conv2dNormActivation] configuration.
#[derive(Config, Debug)]
pub struct Conv2dNormActivationConfig {
pub in_channels: usize,
pub out_channels: usize,

#[config(default = "3")]
pub kernel_size: usize,

#[config(default = "1")]
pub stride: usize,

#[config(default = "None")]
pub padding: Option<usize>,

#[config(default = "1")]
pub groups: usize,

#[config(default = "1")]
pub dilation: usize,

#[config(default = false)]
pub bias: bool,
}

impl Conv2dNormActivationConfig {
pub fn init<B: Backend>(&self, device: &B::Device) -> Conv2dNormActivation<B> {
let padding = if let Some(padding) = self.padding {
padding
} else {
(self.kernel_size - 1) / 2 * self.dilation
};

Conv2dNormActivation {
conv: Conv2dConfig::new(
[self.in_channels, self.out_channels],
[self.kernel_size, self.kernel_size],
)
.with_padding(PaddingConfig2d::Explicit(padding, padding))
.with_stride([self.stride, self.stride])
.with_bias(self.bias)
.with_dilation([self.dilation, self.dilation])
.with_groups(self.groups)
.init(device),
norm: BatchNormConfig::new(self.out_channels).init(device),
activation: ReLU6 {},
}
}
}
impl<B: Backend> Conv2dNormActivation<B> {
pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
let x = self.conv.forward(input);
let x = self.norm.forward(x);
self.activation.forward(x)
}
}
Loading