
Fix #16 #17

Open · mcabbott wants to merge 2 commits into main
Conversation

mcabbott (Member) commented Jul 25, 2022

Unlike FluxML/Flux.jl#1959, this uses map over arrays. Some duplication, unfortunately. Possibly the new method should be restricted to AbstractGPUArrays?

Closes #16

Also tries to organise the tests just a little bit better.
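
For context, a rough sketch of the behaviour the fix targets, using JLArrays' fake GPU array (hypothetical example code, not the package's tests):

    using OneHotArrays, JLArrays

    data = jl([3, 1, 2])           # JLArray, the reference "fake GPU" array type
    oh = onehotbatch(data, 1:3)    # should build the UInt32 index array on-device,
                                   # not fall back to scalar indexing / host copies
    collect(oh)                    # 3×3 Matrix{Bool} with one true per column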

Comment on lines +37 to +39

    @test_broken collect(A * y) ≈ collect(A) * collect(y)

    @test_broken gradient(A -> sum(abs, A * y), A)[1] isa CuArray  # gather!(dst::JLArray, ...) fails
A reviewer (Member) commented:

Is the solution here to use gather (and take on a dep)?

mcabbott (Member, Author) replied:

I think this needs gather/scatter from NNlibCUDA to work on the GPU. And since there's no corresponding code for non-CuArray GPUArrays, I think it can't work with this fake JLArray.

For testing it, you could set up the whole buildkite story to run honest CUDA tests. But perhaps it's not worth it, and this package should just trust NNlib + NNlibCUDA to test things. And perhaps Flux to test the integration?
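
Concretely, the kernels in question look like this (a hypothetical sketch with made-up helper names, not OneHotArrays' actual methods):

    using NNlib, ChainRulesCore

    # A * B for one-hot B is just column selection of A, i.e. NNlib.gather;
    # NNlibCUDA supplies the CuArray kernels for it.
    onehot_mul(A::AbstractMatrix, indices::AbstractVector{<:Integer}) =
        NNlib.gather(A, indices)   # same values as A[:, indices]

    function ChainRulesCore.rrule(::typeof(onehot_mul), A, indices)
        Y = onehot_mul(A, indices)
        function onehot_mul_pullback(dY)
            # Adjoint of column selection: accumulate columns back, i.e. scatter.
            dA = NNlib.scatter(+, unthunk(dY), indices; dstsize = size(A))
            return (NoTangent(), dA, NoTangent())
        end
        return Y, onehot_mul_pullback
    end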

mcabbott (Member, Author) added:

Xref FluxML/NNlib.jl#427 too: it would be nice if forgetting to load NNlibCUDA gave friendly errors rather than scalar-indexing ones.

It would be nicer if NNlibCUDA could be loaded automatically, of course.

for x in data
-    isnothing(_findval(x, labels)) && error("Value $x not found in labels")
+    isnothing(_findval(x, labels)) && throw(ArgumentError("Value x = $x not found in labels = $labels"))
mcabbott (Member, Author) commented:

I changed these error types partly so that tests can distinguish scalar indexing errors from helpful messages.
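
For example, a test can now assert the error type; a hypothetical check (assuming the scalar-indexing guard throws a plain ErrorException, as GPUArrays' guard does):

    using Test, OneHotArrays

    # 99 is not among the labels: the friendly path throws an ArgumentError,
    # while an accidental scalar-indexing fallback would surface differently.
    @test_throws ArgumentError onehotbatch([1, 2, 99], 1:10)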

end
end
return OneHotArray(indices, length(labels))
end
function _onehotbatch(data::AbstractArray, labels) # this works for GPUArrays too
mcabbott (Member, Author) commented:

One reason to change this is to avoid ever making an MVector or something weird like that:

Suggested change:
-function _onehotbatch(data::AbstractArray, labels) # this works for GPUArrays too
+function _onehotbatch(data::AbstractGPUArray, labels)

indices = UInt32[something(_findval(x, labels), default_index) for x in data]
return OneHotArray(indices, length(labels))
end
function _onehotbatch(data::AbstractArray, labels, default)
mcabbott (Member, Author) commented:

Suggested change:
-function _onehotbatch(data::AbstractArray, labels, default)
+function _onehotbatch(data::AbstractGPUArray, labels, default)
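
Taken together, the two suggestions give a dispatch split along these lines (a condensed, hypothetical rendering of the diff, with _findval replaced by findfirst and the default handling reduced to 0):

    using GPUArraysCore: AbstractGPUArray

    # Fallback for ordinary arrays: an explicit loop, so every missing value
    # can throw a helpful ArgumentError, and the index store is always a Vector.
    function _onehotbatch_demo(data::AbstractArray, labels)
        indices = Vector{UInt32}(undef, length(data))
        for (i, x) in enumerate(data)
            j = findfirst(isequal(x), labels)
            isnothing(j) && throw(ArgumentError("Value x = $x not found in labels = $labels"))
            indices[i] = j
        end
        return indices   # the real code wraps this: OneHotArray(indices, length(labels))
    end

    # GPU method: one typed comprehension, which GPUArrays runs as a single
    # kernel (labels must be isbits on device, e.g. a range). Restricting it
    # to AbstractGPUArray keeps StaticArrays and friends on the loop above,
    # so nothing ever builds an MVector of indices.
    function _onehotbatch_demo(data::AbstractGPUArray, labels)
        return UInt32[something(findfirst(isequal(x), labels), 0) for x in data]
    end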

Development

Successfully merging this pull request may close these issues:
onehotbatch(::CuArray, ...) moves data to host