Skip to content

Commit

Permalink
fixed index to i32
Browse files Browse the repository at this point in the history
  • Loading branch information
Baiyuetribe committed Jan 7, 2025
1 parent 2440a7f commit 82bb4ee
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 76 deletions.
96 changes: 73 additions & 23 deletions src/layer/topk.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,56 @@

namespace ncnn {

// auto comp = [this](const std::pair<float, int> &a, const std::pair<float, int> &b)
// {
// if (a.first == b.first)
// return a.second < b.second; // on equal values, order by ascending index
// return this->largest ? (a.first > b.first) : (a.first < b.first);
// };

// simplestl-compatible functor form of the comparator sketched above.
// Orders (value, index) pairs by value -- greater-first when `largest` is
// true, smaller-first otherwise -- and by ascending index when values are equal.
struct TopK::CompareFunc
{
    bool largest;

    CompareFunc(bool take_largest)
        : largest(take_largest)
    {
    }

    bool operator()(const std::pair<float, int>& lhs, const std::pair<float, int>& rhs) const
    {
        if (lhs.first != rhs.first)
        {
            // values differ: the direction is chosen by `largest`
            if (largest)
                return lhs.first > rhs.first;
            return lhs.first < rhs.first;
        }

        // equal values: the lower index comes first
        return lhs.second < rhs.second;
    }
};

// Reorders `vec` in place so that its first `k` entries are the top-k elements
// under CompareFunc(largest), themselves arranged in comparator order.
//
// vec    : (value, index) pairs to reorder.
// k      : number of leading entries to select; values above vec.size() are
//          clamped, and k <= 0 leaves `vec` untouched.
// sorted : true  -> std::partial_sort
//          false -> std::nth_element followed by std::sort of the first k
//                   entries (a bubble pass replaces both under NCNN_SIMPLESTL).
void TopK::do_sort(std::vector<std::pair<float, int> >& vec, int k, bool sorted) const
{
    // The iterator arithmetic below (begin() + k, begin() + k - 1) is only
    // valid for 1 <= k <= vec.size(), so normalise k before using it.
    const int count = static_cast<int>(vec.size());
    if (k > count)
        k = count;
    if (k <= 0)
        return;

    CompareFunc comp(largest);
    if (sorted)
    {
        std::partial_sort(vec.begin(), vec.begin() + k, vec.end(), comp);
    }
    else
    {
#if !NCNN_SIMPLESTL
        std::nth_element(vec.begin(), vec.begin() + k - 1, vec.end(), comp);
        std::sort(vec.begin(), vec.begin() + k, comp);
#else
        // Bubble pass: iteration i carries the best remaining element of
        // vec[i..] down to position i.
        for (int i = 0; i < k; i++)
        {
            for (int j = count - 1; j > i; j--)
            {
                if (comp(vec[j], vec[j - 1]))
                {
                    std::swap(vec[j], vec[j - 1]);
                }
            }
        }
#endif
    }
}

TopK::TopK()
{
// one_blob_only = true; // 仅有1个输入和1个输出
Expand Down Expand Up @@ -64,11 +114,11 @@ int TopK::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_bl
{
// 创建输出blob
top_blob_values.create(k, elemsize, opt.blob_allocator);
top_blob_indices.create(k, elemsize, opt.blob_allocator);
top_blob_indices.create(k, sizeof(int), opt.blob_allocator);

const float* ptr = bottom_blob;
float* outptr = top_blob_values;
float* indices = top_blob_indices;
int* indices = top_blob_indices;
// 创建pair数组用于排序
std::vector<std::pair<float, int> > vec(w);
for (int i = 0; i < w; i++)
Expand All @@ -92,7 +142,7 @@ int TopK::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_bl
if (axis == 0)
{
top_blob_values.create(w, k, elemsize, opt.blob_allocator);
top_blob_indices.create(w, k, elemsize, opt.blob_allocator);
top_blob_indices.create(w, k, sizeof(int), opt.blob_allocator);

// #pragma omp parallel for
for (int j = 0; j < w; j++) // 对每列进行处理
Expand All @@ -110,21 +160,21 @@ int TopK::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_bl
for (int i = 0; i < k; i++)
{
top_blob_values.row(i)[j] = vec[i].first;
top_blob_indices.row(i)[j] = static_cast<float>(vec[i].second);
top_blob_indices.row<int>(i)[j] = vec[i].second;
}
}
}
// 在每一列上进行TopK ,axis=-1等价于axis=1
else
{
top_blob_values.create(k, h, elemsize, opt.blob_allocator);
top_blob_indices.create(k, h, elemsize, opt.blob_allocator);
top_blob_indices.create(k, h, sizeof(int), opt.blob_allocator);

for (int i = 0; i < h; i++)
{
const float* ptr = bottom_blob.row(i);
float* outptr = top_blob_values.row(i);
float* indices = top_blob_indices.row<float>(i);
int* indices = top_blob_indices.row<int>(i);

std::vector<std::pair<float, int> > vec(w);
for (int j = 0; j < w; j++)
Expand All @@ -148,7 +198,7 @@ int TopK::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_bl
{
// 深度方向上;w不变,高度h变为k
top_blob_values.create(w, h, k, elemsize, opt.blob_allocator);
top_blob_indices.create(w, h, k, elemsize, opt.blob_allocator);
top_blob_indices.create(w, h, k, sizeof(int), opt.blob_allocator);
// #pragma omp parallel for collapse(2)
for (int i = 0; i < h; i++)
{
Expand All @@ -169,7 +219,7 @@ int TopK::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_bl
for (int c = 0; c < k; c++)
{
float* outptr = top_blob_values.channel(c);
float* indices = top_blob_indices.channel(c);
int* indices = (int*)top_blob_indices.channel(c);
outptr[i * w + j] = channel_values[c].first;
indices[i * w + j] = channel_values[c].second;
}
Expand All @@ -180,7 +230,7 @@ int TopK::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_bl
{
// 子元素内部进行TopK;w不变,高度变为k
top_blob_values.create(w, k, channels, elemsize, opt.blob_allocator);
top_blob_indices.create(w, k, channels, elemsize, opt.blob_allocator);
top_blob_indices.create(w, k, channels, sizeof(int), opt.blob_allocator);
for (int q = 0; q < channels; q++)
{
// 获取每个channel的行
Expand All @@ -200,7 +250,7 @@ int TopK::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_bl
for (int i = 0; i < k; i++)
{
float* outptr = top_blob_values.channel(q).row(i);
float* indices = top_blob_indices.channel(q).row(i);
int* indices = (int*)top_blob_indices.channel(q).row(i);
outptr[j] = row_scores[i].first;
indices[j] = row_scores[i].second;
}
Expand All @@ -211,14 +261,14 @@ int TopK::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_bl
{
// 输出为k长度的向量,高度不变
top_blob_values.create(k, h, channels, elemsize, opt.blob_allocator);
top_blob_indices.create(k, h, channels, elemsize, opt.blob_allocator);
top_blob_indices.create(k, h, channels, sizeof(int), opt.blob_allocator);
for (int q = 0; q < channels; q++)
{
for (int j = 0; j < h; j++)
{
const float* ptr = bottom_blob.channel(q).row(j);
float* outptr = top_blob_values.channel(q).row(j);
float* indices = top_blob_indices.channel(q).row<float>(j);
int* indices = top_blob_indices.channel(q).row<int>(j);

std::vector<std::pair<float, int> > vec(w);
for (int i = 0; i < w; i++)
Expand All @@ -245,7 +295,7 @@ int TopK::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_bl
// PyTorch:batch -> channel -> height -> width
// ncnn:channels -> depth -> height -> width
top_blob_values.create(w, h, k, channels, elemsize, opt.blob_allocator);
top_blob_indices.create(w, h, k, channels, elemsize, opt.blob_allocator);
top_blob_indices.create(w, h, k, channels, sizeof(int), opt.blob_allocator);

// 在pytorch中,假设x为torch.Size([3, 2, 6, 7]),按N维度,也就是x[0]、x[1]、x[2],对比排序,最后直接输出x[i]
// 但在ncnn中,从channels遍历后,维度d再遍历会获得2*3=6种数据。这里就卡主了,不知道怎么处理
Expand All @@ -255,21 +305,21 @@ int TopK::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_bl
{
// 在channel维度上进行TopK
top_blob_values.create(w, h, d, k, elemsize, opt.blob_allocator);
top_blob_indices.create(w, h, d, k, elemsize, opt.blob_allocator);
top_blob_indices.create(w, h, d, k, sizeof(int), opt.blob_allocator);

// need help !!!
}
else if (axis == 20)
{
// 在h维度上进行TopK
top_blob_values.create(w, k, d, channels, elemsize, opt.blob_allocator);
top_blob_indices.create(w, k, d, channels, elemsize, opt.blob_allocator);
top_blob_indices.create(w, k, d, channels, sizeof(int), opt.blob_allocator);

for (int q = 0; q < channels; q++)
{
const float* ptr = bottom_blob.channel(q);
float* outptr = top_blob_values.channel(q);
float* indices = top_blob_indices.channel(q);
int* indices = top_blob_indices.channel(q);

for (int z = 0; z < d; z++)
{
Expand All @@ -288,7 +338,7 @@ int TopK::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_bl
for (int kk = 0; kk < k; kk++)
{
outptr[(z * k + kk) * w + i] = row_scores[kk].first;
indices[(z * k + kk) * w + i] = static_cast<float>(row_scores[kk].second);
indices[(z * k + kk) * w + i] = row_scores[kk].second;
}
}
}
Expand All @@ -298,13 +348,13 @@ int TopK::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_bl
{
// 在h维度上进行TopK
top_blob_values.create(w, k, d, channels, elemsize, opt.blob_allocator);
top_blob_indices.create(w, k, d, channels, elemsize, opt.blob_allocator);
top_blob_indices.create(w, k, d, channels, sizeof(int), opt.blob_allocator);

for (int q = 0; q < channels; q++)
{
const float* ptr = bottom_blob.channel(q);
float* outptr = top_blob_values.channel(q);
float* indices = top_blob_indices.channel(q);
int* indices = top_blob_indices.channel(q);

for (int z = 0; z < d; z++)
{
Expand All @@ -323,7 +373,7 @@ int TopK::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_bl
for (int kk = 0; kk < k; kk++)
{
outptr[(z * k + kk) * w + i] = row_scores[kk].first;
indices[(z * k + kk) * w + i] = static_cast<float>(row_scores[kk].second);
indices[(z * k + kk) * w + i] = row_scores[kk].second;
}
}
}
Expand All @@ -333,7 +383,7 @@ int TopK::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_bl
{
// 在w维度上进行TopK
top_blob_values.create(k, h, d, channels, elemsize, opt.blob_allocator);
top_blob_indices.create(k, h, d, channels, elemsize, opt.blob_allocator);
top_blob_indices.create(k, h, d, channels, sizeof(int), opt.blob_allocator);

for (int q = 0; q < channels; q++)
{
Expand All @@ -355,9 +405,9 @@ int TopK::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_bl
for (int j = 0; j < k; j++)
{
float* outptr = top_blob_values.channel(q).row(i * d + z);
float* indices = top_blob_indices.channel(q).row(i * d + z);
int* indices = top_blob_indices.channel(q).row<int>(i * d + z);
outptr[j] = row_values[j].first;
indices[j] = static_cast<float>(row_values[j].second);
indices[j] = row_values[j].second;
}
}
}
Expand Down
52 changes: 2 additions & 50 deletions src/layer/topk.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,56 +35,8 @@ class TopK : public Layer
int sorted;

private:
// auto comp = [this](const std::pair<float, int> &a, const std::pair<float, int> &b)
// {
// if (a.first == b.first)
// return a.second < b.second; // on equal values, order by ascending index
// return this->largest ? (a.first > b.first) : (a.first < b.first);
// };

// simplestl-compatible functor form of the comparator sketched above.
// Orders (value, index) pairs by value (direction chosen by `largest`) and
// breaks ties by index.
struct CompareFunc
{
// true: greater values compare first; false: smaller values compare first
bool largest;
CompareFunc(bool l)
: largest(l)
{
}
// Returns true when `a` must be placed before `b`.
bool operator()(const std::pair<float, int>& a, const std::pair<float, int>& b) const
{
if (a.first == b.first)
return a.second < b.second; // on equal values, order by ascending index
return largest ? (a.first > b.first) : (a.first < b.first);
}
};
// Reorders `vec` in place so that its first `k` entries are the top-k
// elements under CompareFunc(largest), arranged in comparator order.
// `sorted` selects std::partial_sort; otherwise std::nth_element + std::sort
// is used (or a bubble pass when NCNN_SIMPLESTL is set).
// NOTE(review): k is used in iterator arithmetic without a bounds check here --
// presumably callers guarantee 1 <= k <= vec.size(); confirm.
void do_sort(std::vector<std::pair<float, int> >& vec, int k, bool sorted) const
{
CompareFunc comp(largest); // functor rather than a lambda, for C++03 compatibility
if (sorted)
{
std::partial_sort(vec.begin(), vec.begin() + k, vec.end(), comp);
}
else
{
#if !NCNN_SIMPLESTL
std::nth_element(vec.begin(), vec.begin() + k - 1, vec.end(), comp);
std::sort(vec.begin(), vec.begin() + k, comp);
#else
// Replacement for the nth_element + sort combination:
// a bubble sort gives the same result and works with simplestl.
for (int i = 0; i < k; i++)
{
// carry the best remaining element of vec[i..] down to position i
for (int j = vec.size() - 1; j > i; j--)
{
if (comp(vec[j], vec[j - 1]))
{
std::swap(vec[j], vec[j - 1]);
}
}
}
#endif
}
}
struct CompareFunc; // 前向声明
void do_sort(std::vector<std::pair<float, int> >& vec, int k, bool sorted) const;
};

} // namespace ncnn
Expand Down
8 changes: 5 additions & 3 deletions tools/pnnx/tests/ncnn/test_torch_topk.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ def forward(self, x, y, z, d):
k=2,
dim=2,
)
d3, i9 = torch.topk(d, k=2, dim=3, sorted=False)
d3, i9 = torch.topk(d, k=2, dim=3, sorted=True)
# return x0, y1, y2, z1, i3, i4, i5, d0, d1, d2, d3, i6, i7, i8, i9
return x0, y1, y2, z1, i3, i4, i5, d2, d3, i8, i9
return x0, y1, y2, i0, i1, i2, z1, i3, i4, i5, d2, d3, i8, i9


def test():
Expand Down Expand Up @@ -79,7 +79,9 @@ def test():
b = test_torch_topk_ncnn.test_inference()

for a0, b0 in zip(a, b):
a0 = a0.float()
if a0.dtype != torch.float:
a0 = a0.to(torch.int32) # i64 --> i32
b0 = b0.view(torch.int32) # f32 --> i32
if not torch.allclose(a0, b0, 1e-3, 1e-3):
return False
return True
Expand Down

0 comments on commit 82bb4ee

Please sign in to comment.