Skip to content

Commit

Permalink
refactor: add some magic
Browse files Browse the repository at this point in the history
  • Loading branch information
nullptr authored Apr 28, 2024
1 parent f4f3ddb commit 6cfb029
Showing 1 changed file with 46 additions and 48 deletions.
94 changes: 46 additions & 48 deletions core/algorithm/el_algorithm_yolo_world.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,71 +149,69 @@ el_err_code_t AlgorithmYOLOWorld::postprocess() {
_results.clear();

// get outputs
auto* data_bboxes{static_cast<int8_t*>(this->__p_engine->get_output(0))};
auto* data_scores{static_cast<int8_t*>(this->__p_engine->get_output(1))};
const auto* data_bboxes{static_cast<int8_t*>(this->__p_engine->get_output(0))};
const auto* data_scores{static_cast<int8_t*>(this->__p_engine->get_output(1))};

auto width{this->__input_shape.dims[1]};
auto height{this->__input_shape.dims[2]};
const auto width{this->__input_shape.dims[1]};
const auto height{this->__input_shape.dims[2]};

float scale_scores{_output_quant_params[1].scale};
scale_scores = scale_scores < 0.1f ? scale_scores * 100.f : scale_scores; // rescale
int32_t zero_point_scores{_output_quant_params[1].zero_point};
const int32_t zero_point_scores{_output_quant_params[1].zero_point};

float scale_bboxes{_output_quant_params[0].scale};
int32_t zero_point_bboxes{_output_quant_params[0].zero_point};
const float scale_bboxes{_output_quant_params[0].scale};
const int32_t zero_point_bboxes{_output_quant_params[0].zero_point};

auto num_bboxes{this->__output_shape.dims[1]};
auto num_element{this->__output_shape.dims[2]};
auto num_classes{static_cast<uint8_t>(_output_shapes[1].dims[2])};
const auto num_bboxes{this->__output_shape.dims[1]};
const auto num_elements{this->__output_shape.dims[2]};
const auto num_classes{static_cast<uint8_t>(_output_shapes[1].dims[2])};

ScoreType score_threshold{get_score_threshold()};
IoUType iou_threshold{get_iou_threshold()};
const ScoreType score_threshold{get_score_threshold()};
const IoUType iou_threshold{get_iou_threshold()};

// parse output
for (size_t bbox_i = 0; bbox_i < num_bboxes; ++bbox_i) {
size_t idx_s = bbox_i * num_classes;
for (size_t target_i = 0; target_i < num_classes; ++target_i) {
uint8_t bbox_i_score =
static_cast<decltype(scale_scores)>(data_scores[idx_s + target_i] - zero_point_scores) * scale_scores;
if (bbox_i_score < score_threshold) {
for (size_t class_i = 0; class_i < num_classes; ++class_i) {

for (size_t score_i = class_i, bbox_i = 0 ; bbox_i < num_bboxes; score_i += num_classes, bbox_i += num_elements) {

const auto score = static_cast<decltype(scale_scores)>(data_scores[score_i] - zero_point_scores) * scale_scores;
if (score < score_threshold) {
continue;
}

{
BoxType box{
.x = 0,
.y = 0,
.w = 0,
.h = 0,
.score = bbox_i_score,
.target = static_cast<decltype(BoxType::target)>(target_i),
};

size_t idx_b = bbox_i * num_element;
auto tl_x{((data_bboxes[idx_b + INDEX_TL_X] - zero_point_bboxes) * scale_bboxes)};
auto tl_y{((data_bboxes[idx_b + INDEX_TL_Y] - zero_point_bboxes) * scale_bboxes)};
auto br_x{((data_bboxes[idx_b + INDEX_BR_X] - zero_point_bboxes) * scale_bboxes)};
auto br_y{((data_bboxes[idx_b + INDEX_BR_Y] - zero_point_bboxes) * scale_bboxes)};

box.w = br_x - tl_x;
box.h = br_y - tl_y;
box.x = tl_x + box.w / 2;
box.y = tl_y + box.h / 2;

box.x = EL_CLIP(box.x, 0, width) * _w_scale;
box.y = EL_CLIP(box.y, 0, height) * _h_scale;
box.w = EL_CLIP(box.w, 0, width) * _w_scale;
box.h = EL_CLIP(box.h, 0, height) * _h_scale;

_results.emplace_front(std::move(box));
}
BoxType box;

box.score = static_cast<decltype(BoxType::score)>(score);
box.target = static_cast<decltype(BoxType::target)>(class_i);

auto tl_x{((data_bboxes[bbox_i + INDEX_TL_X] - zero_point_bboxes) * scale_bboxes)};
auto tl_y{((data_bboxes[bbox_i + INDEX_TL_Y] - zero_point_bboxes) * scale_bboxes)};
auto br_x{((data_bboxes[bbox_i + INDEX_BR_X] - zero_point_bboxes) * scale_bboxes)};
auto br_y{((data_bboxes[bbox_i + INDEX_BR_Y] - zero_point_bboxes) * scale_bboxes)};

box.w = static_cast<decltype(BoxType::w)>(br_x - tl_x);
box.h = static_cast<decltype(BoxType::h)>(br_y - tl_y);

// if constexpr would be better (C++17)
static_assert(std::is_integral<decltype(box.w)>::value);
static_assert(std::is_integral<decltype(box.h)>::value);

box.x = tl_x + (box.w >> 1);
box.y = tl_y + (box.h >> 1);

_results.emplace_front(std::move(box));
}
}

el_nms(_results, iou_threshold, score_threshold, false, true);

_results.sort([](const BoxType& a, const BoxType& b) { return a.x < b.x; });

for (auto& box : _results) {
box.x = EL_CLIP(box.x, 0, width) * _w_scale;
box.y = EL_CLIP(box.y, 0, height) * _h_scale;
box.w = EL_CLIP(box.w, 0, width) * _w_scale;
box.h = EL_CLIP(box.h, 0, height) * _h_scale;
}

return EL_OK;
}

Expand Down

0 comments on commit 6cfb029

Please sign in to comment.