Skip to content

Commit 1ea27b9

Browse files
authored
Update YOLOX to burn 0.14 (#41)
* Update to burn 0.14 * Remove duplicate release flag
1 parent 00dfeac commit 1ea27b9

File tree

5 files changed

+16
-16
lines changed

5 files changed

+16
-16
lines changed

resnet-burn/README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ with the `NdArray` backend and performs inference on the provided input image.
3838
You can run the example with the following command:
3939

4040
```sh
41-
cargo run --release --example inference samples/dog.jpg --release
41+
cargo run --release --example inference samples/dog.jpg
4242
```
4343

4444
#### Fine-tuning

yolox-burn/Cargo.toml

+3-3
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ pretrained = ["burn/network", "std", "dep:dirs"]
1212

1313
[dependencies]
1414
# Note: default-features = false is needed to disable std
15-
burn = { version = "0.13.0", default-features = false }
16-
burn-import = { version = "0.13.0" }
15+
burn = { version = "0.14.0", default-features = false }
16+
burn-import = { version = "0.14.0" }
1717
itertools = { version = "0.12.1", default-features = false, features = [
1818
"use_alloc",
1919
] }
@@ -24,5 +24,5 @@ serde = { version = "1.0.192", default-features = false, features = [
2424
] } # alloc is for no_std, derive is needed
2525

2626
[dev-dependencies]
27-
burn = { version = "0.13.0", features = ["ndarray"] }
27+
burn = { version = "0.14.0", features = ["ndarray"] }
2828
image = { version = "0.24.9", features = ["png", "jpeg"] }

yolox-burn/examples/inference.rs

+7-4
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use yolox_burn::model::{boxes::nms, weights, yolox::Yolox, BoundingBox};
55

66
use burn::{
77
backend::NdArray,
8-
tensor::{backend::Backend, Data, Device, Element, Shape, Tensor},
8+
tensor::{backend::Backend, Device, Element, Tensor, TensorData},
99
};
1010

1111
const HEIGHT: usize = 640;
@@ -16,9 +16,12 @@ fn to_tensor<B: Backend, T: Element>(
1616
shape: [usize; 3],
1717
device: &Device<B>,
1818
) -> Tensor<B, 3> {
19-
Tensor::<B, 3>::from_data(Data::new(data, Shape::new(shape)).convert(), device)
20-
// [H, W, C] -> [C, H, W]
21-
.permute([2, 0, 1])
19+
Tensor::<B, 3>::from_data(
20+
TensorData::new(data, shape).convert::<B::FloatElem>(),
21+
device,
22+
)
23+
// [H, W, C] -> [C, H, W]
24+
.permute([2, 0, 1])
2225
}
2326

2427
/// Draws bounding boxes on the given image.

yolox-burn/src/model/boxes.rs

+3-6
Original file line numberDiff line numberDiff line change
@@ -45,22 +45,19 @@ pub fn nms<B: Backend>(
4545
let (cls_score, cls_idx) = candidate_scores.squeeze::<2>(0).max_dim_with_indices(1);
4646
let cls_score: Vec<_> = cls_score
4747
.into_data()
48-
.value
49-
.iter()
48+
.iter::<B::FloatElem>()
5049
.map(|v| v.elem::<f32>())
5150
.collect();
5251
let cls_idx: Vec<_> = cls_idx
5352
.into_data()
54-
.value
55-
.iter()
53+
.iter::<B::IntElem>()
5654
.map(|v| v.elem::<i64>() as usize)
5755
.collect();
5856

5957
// [num_boxes, 4]
6058
let candidate_boxes: Vec<_> = candidate_boxes
6159
.into_data()
62-
.value
63-
.iter()
60+
.iter::<B::FloatElem>()
6461
.map(|v| v.elem::<f32>())
6562
.collect();
6663

yolox-burn/src/model/head.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@ const PRIOR_PROB: f64 = 1e-2;
2424
fn create_2d_grid<B: Backend>(x: usize, y: usize, device: &Device<B>) -> Tensor<B, 3, Int> {
2525
let y_idx = Tensor::arange(0..y as i64, device)
2626
.reshape(Shape::new([y, 1]))
27-
.repeat(1, x)
27+
.repeat_dim(1, x)
2828
.reshape(Shape::new([y, x]));
2929
let x_idx = Tensor::arange(0..x as i64, device)
3030
.reshape(Shape::new([1, x])) // can only repeat with dim=1
31-
.repeat(0, y)
31+
.repeat_dim(0, y)
3232
.reshape(Shape::new([y, x]));
3333

3434
Tensor::stack(vec![x_idx, y_idx], 2)

0 commit comments

Comments
 (0)