Description
Describe the bug
Running a GPU FFT on arrays of length 2^21 or 2^22 fails because a Metal kernel cannot be loaded, yet larger array sizes work.
To Reproduce
A simple script to reproduce:
import mlx.core as mx
size = int(2**21)
x = mx.ones(size)
mx.eval(mx.fft.fft(x, stream=mx.gpu))
This will result in the following error:
Terminating due to uncaught exception: [metal::Device] Unable to load function four_step_mem_8192_float2_float2_0_false
Function four_step_mem_8192_float2_float2_0_false was not found in the library
Abort trap: 6
A similar thing happens for an array that is 2**22 long. However, the code succeeds for arrays that have length 2**23, 2**24, 2**25, etc., up to 2**28. (I don't have enough memory to test beyond that.) By "succeed" I mean the function runs without failure. I haven't checked that the output is actually correct.
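Because the failure aborts the whole process, checking several sizes in one script stops at the first bad one. A minimal sketch of a sweep that runs each size in its own interpreter is below. The helper names (run_isolated, sweep) are made up for illustration; the child script is the same call as in the reproduction above.

```python
import subprocess
import sys

# Child script: the same FFT call as the reproduction, with the size taken from argv.
CHILD = """
import sys
import mlx.core as mx
x = mx.ones(int(sys.argv[1]))
mx.eval(mx.fft.fft(x, stream=mx.gpu))
"""


def run_isolated(script, arg):
    """Run `script` in a fresh interpreter and report whether it exited cleanly.

    An abort in the child (such as the uncaught exception above) shows up as a
    nonzero return code instead of killing the parent process.
    """
    result = subprocess.run(
        [sys.executable, "-c", script, str(arg)],
        capture_output=True,
        text=True,
    )
    return result.returncode == 0


def sweep(exponents, script=CHILD):
    """Map each exponent k to True/False for an input of length 2**k."""
    return {k: run_isolated(script, 2**k) for k in exponents}


# Example use (requires MLX and a Metal GPU):
# for k, ok in sweep(range(19, 26)).items():
#     print(f"2**{k}: {'ok' if ok else 'failed'}")
```

This only reports whether each run exits cleanly; it says nothing about whether the result is numerically correct.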
Expected behavior
The FFT should work for 2**21 and 2**22 if larger array sizes work. At the very least, the error should be caught appropriately with a more graceful exit.
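Since I have not verified the outputs for the sizes that do run, here is a rough sketch of a check for the all-ones input used in the reproduction. The FFT of n ones is n in bin 0 and zero everywhere else, so no reference FFT is needed. The function name and the tolerance are my own choices, not anything from MLX.

```python
import numpy as np


def looks_like_fft_of_ones(result, n, rtol=1e-4):
    """Return True if `result` matches the FFT of n ones within tolerance.

    The expected spectrum is n at index 0 and 0 elsewhere. The absolute
    tolerance is scaled by n because float32 round-off grows with the size.
    """
    result = np.asarray(result)
    if result.shape != (n,):
        return False
    expected = np.zeros(n, dtype=np.complex128)
    expected[0] = n
    return bool(np.allclose(result, expected, rtol=0.0, atol=rtol * n))
```

With MLX this would be called on the FFT output after converting it to a NumPy array, for example looks_like_fft_of_ones(np.array(y), size), assuming that conversion works for the complex output.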
Desktop (please complete the following information):
- OS Version: macOS 15.1.1
- Version 0.22.0
Additional context
Digging into the code a bit, I can see why it's failing. For a size of 2**21, plan.n1 gets set to 2048. Later on, that causes threadgroup_mem_size to be set to 8192. However, I don't know why that doesn't trip the assert at line 641.
I see the comment at line 640 that says "// FFTs up to 2^20 are currently supported", so I'm not sure why the 2^23 FFTs are running. Even if the assert worked properly, why the limit of 2^20? In the research application we're trying to use this for, we will be evaluating arrays of length 2^21 to 2^25, so it would be ideal if these array sizes could be handled.
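For anyone unfamiliar with the "four_step" in the kernel name: a four-step FFT treats a length-n input as an n1 x n2 matrix, so that only transforms of length n1 and n2 are ever computed directly. Below is a generic NumPy sketch of that decomposition, only to show what n1 means. It is not MLX's Metal implementation, and the function name four_step_fft is made up for this example.

```python
import numpy as np


def four_step_fft(x, n1):
    """Length-n FFT computed as n1-point and n2-point FFTs, with n = n1 * n2."""
    x = np.asarray(x, dtype=np.complex128)
    n = x.size
    if n % n1 != 0:
        raise ValueError("n1 must divide the input length")
    n2 = n // n1

    # View the input as an n1 x n2 matrix: a[j1, j2] = x[n2 * j1 + j2].
    a = x.reshape(n1, n2)

    # Step 1: n1-point FFTs down each column.
    b = np.fft.fft(a, axis=0)

    # Step 2: multiply by the twiddle factors exp(-2*pi*i * k1 * j2 / n).
    k1 = np.arange(n1)[:, None]
    j2 = np.arange(n2)[None, :]
    b = b * np.exp(-2j * np.pi * k1 * j2 / n)

    # Step 3: n2-point FFTs along each row.
    c = np.fft.fft(b, axis=1)

    # Step 4: transpose so that output index k = k1 + n1 * k2.
    return c.T.reshape(-1)
```

With the values reported above, a 2**21 input with n1 = 2048 would leave n2 = 1024 in this scheme. How MLX actually picks n1 and maps it to threadgroup memory is something I have not confirmed.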