[BUG] argsort metal kernel yields incorrect output with > 1024 elements #2570

Open
AnubhabB opened this issue Oct 20, 2024 · 2 comments

Comments

@AnubhabB
Contributor

AnubhabB commented Oct 20, 2024

Reproduction:

// Correct

// The kernel call @ candle-metal-kernels/src/lib.rs:2151 receives the following args:
//  nrows: 1  ncols: 1024 ncols_pad: 1024
let d = Tensor::rand(-256_f32, 255., (1, 1024), &candle_core::Device::new_metal(0)?)?;
println!("{d}");
// [[ 137.8366,  -72.5639, -186.1103, ..., -225.0789, -141.2470,  -12.9232]]
// Tensor[[1, 1024], f32, metal:4294969852]
let i = d.arg_sort_last_dim(true)?;
println!("{i}");
// [[132, 932, 801, ..., 556, 518, 683]]
// Tensor[[1, 1024], u32, metal:4294969852]

// Incorrect: the output indices are all zeros; with a larger shape, e.g. (1, 128650), it returns very large numbers instead.

// The kernel call @ candle-metal-kernels/src/lib.rs:2151 receives the following args:
//  nrows: 1  ncols: 2048 ncols_pad: 2048
let d = Tensor::rand(-256_f32, 255., (1, 2048), &candle_core::Device::new_metal(0)?)?;
println!("{d}");
// [[ 137.8366,  -72.5639, -186.1103, ..., -225.0789, -141.2470,  -12.9232]]
// Tensor[[1, 2048], f32, metal:4294969852]
let i = d.arg_sort_last_dim(true)?;
println!("{i}");
// [[0, 0, 0, ..., 0, 0, 0]]
// Tensor[[1, 2048], u32, metal:4294969852]

Edit: removed incorrect diagnosis.

@LaurentMazare
Collaborator

That's actually expected, though we should have a proper error message for it. The candle sort operator uses a bitonic sort, which requires the whole row being sorted to fit in a single thread-group (Metal) or CUDA block; the same approach is used by llama.cpp. The idea is to use this operator for things like mixture of experts, where the number of elements to sort is very small, but it cannot be applied to larger sets of elements.
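
To make the constraint concrete, here is a minimal CPU sketch in Rust of a bitonic argsort over one row. It assumes (as the `ncols_pad` argument in the reproduction suggests) that the row is padded to the next power of two. Each pass of the inner `j` loop corresponds to one barrier-separated step in which every "thread" `i` compares its element with element `i ^ j`; on the GPU all `ncols_pad` threads must sync between passes, which is only possible inside one thread-group. The names `bitonic_argsort` and `MAX_THREADS_PER_GROUP` and the value 1024 are illustrative assumptions, not candle's code.

```rust
/// Hypothetical per-thread-group limit (assumed 1024 here); a real backend
/// should query the limit from the device or compute pipeline.
const MAX_THREADS_PER_GROUP: usize = 1024;

/// CPU sketch of a bitonic argsort over a single row.
/// Returns None when the padded row would not fit in one thread-group.
fn bitonic_argsort(row: &[f32], ascending: bool) -> Option<Vec<u32>> {
    let ncols = row.len();
    let ncols_pad = ncols.next_power_of_two();
    if ncols_pad > MAX_THREADS_PER_GROUP {
        // The GPU version would need synchronization across thread-groups.
        return None;
    }
    // Padding slots hold indices >= ncols and always sort to the end.
    let mut idx: Vec<usize> = (0..ncols_pad).collect();
    // True if the element at index a should come before the one at index b.
    let before = |a: usize, b: usize| -> bool {
        match (a >= ncols, b >= ncols) {
            (true, _) => false,
            (false, true) => true,
            (false, false) => {
                if ascending { row[a] < row[b] } else { row[a] > row[b] }
            }
        }
    };
    let mut k = 2;
    while k <= ncols_pad {
        let mut j = k / 2;
        while j > 0 {
            // One barrier-separated pass: each i plays the role of a GPU thread.
            for i in 0..ncols_pad {
                let l = i ^ j;
                if l > i {
                    let up = (i & k) == 0;
                    let (a, b) = (idx[i], idx[l]);
                    // Ascending block: swap if b belongs first; descending: if a does.
                    if (up && before(b, a)) || (!up && before(a, b)) {
                        idx.swap(i, l);
                    }
                }
            }
            j /= 2;
        }
        k *= 2;
    }
    // Drop the padding indices, which ended up at the tail.
    Some(idx.into_iter().take(ncols).map(|i| i as u32).collect())
}
```

For example, `bitonic_argsort(&[3.0, 1.0, 2.0], true)` pads the row to 4 slots and returns `Some(vec![1, 2, 0])`, while a 2048-column row is rejected up front, which is the kind of error message suggested above instead of silently returning zeros.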

@AnubhabB
Contributor Author

Yes, I realized it's a bitonic sort once I went through the code, but I didn't realize the size limit was by design.

A generic implementation would be helpful (in my case, to speed up token sampling for autoregressive language models), so I did some digging around this.

Torch delegates CUDA sorting to Thrust; the current versions of Thrust and CUB reside in NVIDIA/cccl. NVIDIA/cccl is not supported by cudarc yet, and my low-key effort to generate bindings with bindgen was a spectacular failure.

On Metal, from what I could gather, Torch relies on MPSGraph.argsort() to do the sorting. Again, MPSGraph is not yet part of metal-rs.

According to this implementation, CUB uses a radix sort.
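
For comparison, here is a minimal CPU sketch of an LSD (least-significant-digit) radix argsort, the general technique a device-wide radix sort parallelizes. `f32` keys are mapped to `u32` bit patterns whose unsigned order matches the float order, then stably scattered by 8-bit digits in four passes. Each pass is a histogram, a prefix sum and a scatter, steps a GPU can split across many thread-groups, so unlike the bitonic network there is no per-group size limit. This is an illustrative sketch, not CUB's implementation; `sortable_key` and `radix_argsort` are hypothetical names, and NaN handling is ignored.

```rust
/// Map an f32 to a u32 whose unsigned order matches the float order:
/// negatives get all bits flipped, non-negatives get the sign bit set.
fn sortable_key(x: f32) -> u32 {
    let bits = x.to_bits();
    if bits & 0x8000_0000 != 0 { !bits } else { bits | 0x8000_0000 }
}

/// LSD radix argsort with 8-bit digits (4 passes over 32-bit keys).
fn radix_argsort(row: &[f32], ascending: bool) -> Vec<u32> {
    let mut keys: Vec<u32> = row
        .iter()
        .map(|&x| {
            let k = sortable_key(x);
            // Inverting every key reverses the order for a descending sort.
            if ascending { k } else { !k }
        })
        .collect();
    let mut idx: Vec<u32> = (0..row.len() as u32).collect();
    let mut keys_tmp = vec![0u32; row.len()];
    let mut idx_tmp = vec![0u32; row.len()];
    for pass in 0..4u32 {
        let shift = pass * 8;
        // 1. Histogram of the current 8-bit digit.
        let mut counts = [0usize; 256];
        for &k in &keys {
            counts[((k >> shift) & 0xFF) as usize] += 1;
        }
        // 2. Exclusive prefix sum gives each digit's first output slot.
        let mut offsets = [0usize; 256];
        let mut sum = 0;
        for (offset, &count) in offsets.iter_mut().zip(counts.iter()) {
            *offset = sum;
            sum += count;
        }
        // 3. Stable scatter: keys with equal digits keep their relative order.
        for (&k, &i) in keys.iter().zip(idx.iter()) {
            let d = ((k >> shift) & 0xFF) as usize;
            keys_tmp[offsets[d]] = k;
            idx_tmp[offsets[d]] = i;
            offsets[d] += 1;
        }
        // The scattered buffers become the input of the next pass.
        std::mem::swap(&mut keys, &mut keys_tmp);
        std::mem::swap(&mut idx, &mut idx_tmp);
    }
    idx
}
```

For example, `radix_argsort(&[3.0, -1.0, 2.0, -5.5], true)` returns `vec![3, 1, 2, 0]`. The cost is a fixed number of passes over the data regardless of row length, which suits rows such as the (1, 128650) case above.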

I'm working on an implementation of it. If things go well and the port to Metal works, I'll probably open a PR that calls the bitonic sort kernel if ncols_pad < MaxThreadsPerGroup and a DeviceRadixSort kernel otherwise.
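
The proposed dispatch could look roughly like the sketch below. `SortKernel`, `choose_sort_kernel` and the `max_threads_per_group` parameter are hypothetical names; in a Metal backend the limit would come from the compute pipeline (Metal exposes `maxTotalThreadsPerThreadgroup`), and candle's actual kernel selection may differ. The comparison is written as inclusive because, in the reproduction, a row with `ncols_pad: 1024` sorts correctly, assuming the device limit there is 1024.

```rust
/// Hypothetical kernel choice for a row-wise argsort.
#[derive(Debug, PartialEq)]
enum SortKernel {
    /// The whole padded row fits in one thread-group.
    Bitonic { ncols_pad: usize },
    /// Multi-pass sort that can span many thread-groups.
    DeviceRadix,
}

/// Pick a kernel from the row length and the per-group thread limit.
fn choose_sort_kernel(ncols: usize, max_threads_per_group: usize) -> SortKernel {
    // The bitonic network needs a power-of-two length, so pad up.
    let ncols_pad = ncols.next_power_of_two();
    // Inclusive: a padded row exactly at the limit still fits in one group.
    if ncols_pad <= max_threads_per_group {
        SortKernel::Bitonic { ncols_pad }
    } else {
        SortKernel::DeviceRadix
    }
}
```

With a limit of 1024, the (1, 1024) tensor from the reproduction stays on the bitonic path, while the (1, 2048) and (1, 128650) tensors fall through to the radix kernel.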

Lots of ifs in the note above, sorry about that!
