Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bringing Tensor::gather() behavior closer to torch.gather() #2567

Merged
merged 4 commits into from
Oct 17, 2024

Conversation

AnubhabB
Copy link
Contributor

Definition:

input and index must have the same number of dimensions. It is also required that index.size(d) <= input.size(d) for all dimensions d != dim. out will have the same shape as index. Note that input and index do not broadcast against each other.

In our case we were effectively constraining the requirement to:
index.size(d) != input.size(d); where d != dim

This PR removes the != constraint and allows the op to go through as long as:
index.dims()[d] <= input.dims()[d]; where d != dim

Added tests with data generated and validated against torch's ScatterGather tests.

Tested on arm64, x86_64, cuda and metal. I don't have access to mkl and others, so unsure if the kernels hold true!

@AnubhabB
Copy link
Contributor Author

That's odd .. All tests are passing for my test run on Ubuntu x86_64. Let me see if I can replicate this.

@LaurentMazare
Copy link
Collaborator

Looks more like an issue with the github CI infrastructure, I've relaunched the failed action.

@LaurentMazare LaurentMazare merged commit dcd8333 into huggingface:main Oct 17, 2024
9 of 10 checks passed
@LaurentMazare
Copy link
Collaborator

Thanks, I've merged the changes after running the tests locally on my linux box and it was all fine.

@AnubhabB
Copy link
Contributor Author

Thanks a ton @LaurentMazare .. this unblocks another PR :)

@AnubhabB AnubhabB deleted the gather branch October 17, 2024 14:02
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants