onehotbatch(::CuArray, ...) moves data to host #16

Open
mcabbott opened this issue Jul 24, 2022 · 6 comments · May be fixed by #17

mcabbott (Member) commented Jul 24, 2022

The lack of FluxML/Flux.jl#1959 causes the following error, which currently blocks FluxML/Flux.jl#2025:

julia> using CUDA, OneHotArrays, NNlibCUDA

julia> CUDA.allowscalar(false)

julia> x = [1, 3, 2];

julia> y = onehotbatch(x, (0,1,2,3))
4×3 OneHotMatrix(::Vector{UInt32}) with eltype Bool:
 ⋅  ⋅  ⋅
 1  ⋅  ⋅
 ⋅  ⋅  1
 ⋅  1  ⋅

julia> y2 = onehotbatch(x |> cu, (0,1,2,3))
ERROR: Scalar indexing is disallowed.

Edit: after #27, `onehotbatch(x |> cu, 0:3)` works, but other ways of specifying the labels (such as the tuple `(0,1,2,3)` above) do not.

mcabbott linked a pull request on Jul 25, 2022 that will close this issue

chengchingwen (Member):

I don't think that should be allowed. Taking `cu` should always happen AFTER taking `onehotbatch`. Consider a real case where the labels are non-bits types, such as an array of strings: it doesn't make sense to take `onehotbatch` on a GPU array. This just needs to be clarified in the docs.
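To make the point about non-bits labels concrete, here is a minimal Julia sketch, assuming labels of arbitrary type. The helper `host_onehot_indices` is hypothetical and is not how OneHotArrays implements `onehotbatch`; it only shows the per-element search that such labels require. Looping over a `CuArray` like this reads one element at a time, which is the scalar indexing that `CUDA.allowscalar(false)` forbids, and a `CuArray` cannot hold `String` elements in the first place.

```julia
# Hypothetical helper (not the OneHotArrays source): for each element of `x`,
# return its 1-based position in `labels`.
function host_onehot_indices(x::AbstractVector, labels)
    idx = Vector{UInt32}(undef, length(x))   # always allocated on the host
    for (i, xi) in enumerate(x)              # element-wise reads: scalar indexing on a GPU array
        j = findfirst(isequal(xi), labels)   # linear search, works for any label type
        j === nothing && throw(ArgumentError("value $(repr(xi)) is not among the labels"))
        idx[i] = j
    end
    return idx
end

# String labels only make sense on the CPU.
host_onehot_indices(["cat", "dog", "cat"], ("cat", "dog", "bird"))  # equals UInt32[1, 2, 1]
```

Under this view, the host does the label lookup and only the finished one-hot array is moved, as in `onehotbatch(x, labels) |> cu`.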

mcabbott (Member, Author) commented Aug 3, 2022

No strong opinions; I was just trying to make Flux's tests pass. #17 involves more code duplication than is ideal.

chengchingwen (Member):

Maybe we could remove that test in Flux?

ToucheSir (Member):

Also no strong feelings either way, but if we don't want to support this, we should keep the functionality as deprecated so that it isn't a breaking change on the Flux side.

CarloLucibello (Member):

As seen in #24, the current behavior is surprising when `allowscalar(true)` is set: the data is silently moved to the host. We should allow `onehotbatch(::CuArray, ...)` whenever possible and error out otherwise.

CarloLucibello changed the title from `onehotbatch(::CuArray, ...)` to `onehotbatch(::CuArray, ...) moves data to host` on Dec 28, 2022

CarloLucibello (Member) commented Dec 28, 2022

#27 fixes the case `onehotbatch(::CuVector{<:Integer}, ::UnitRange)`.
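A minimal sketch of why the `UnitRange` case can stay on the GPU, assuming integer inputs: each value's position is plain arithmetic, so the work reduces to reductions and broadcasts, which GPU array types implement without scalar indexing. The function `range_onehot_indices` is hypothetical and is not the code merged in #27. Its second method follows the "allow whenever possible and error out otherwise" policy suggested above in its narrowest form: any other label container raises an error instead of silently copying to the host.

```julia
# Hypothetical sketch of the idea behind #27 (not the code from that PR).
function range_onehot_indices(x::AbstractArray{<:Integer}, labels::AbstractUnitRange{<:Integer})
    if !isempty(x)
        lo, hi = minimum(x), maximum(x)   # reductions: run as kernels for a CuArray
        (lo < first(labels) || hi > last(labels)) &&
            throw(ArgumentError("values span $lo:$hi, outside the labels $labels"))
    end
    # Broadcast arithmetic: a GPU input yields a GPU result.
    return UInt32.(x .- first(labels) .+ 1)
end

# Every other label container: error out instead of silently copying to the host.
range_onehot_indices(x::AbstractArray, labels) =
    throw(ArgumentError("only integer inputs with unit-range labels are handled here"))

range_onehot_indices([1, 3, 2], 0:3)  # equals UInt32[2, 4, 3]
```

The result for `x = [1, 3, 2]` and labels `0:3` matches the rows of the `OneHotMatrix` printed in the original report. With a `CuVector` input, the same calls would be expected to return a `CuVector{UInt32}` that could back a one-hot matrix with `length(labels)` rows.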


4 participants