
Simplify Embedding #2084

Merged: 3 commits merged into FluxML:master on Oct 17, 2022
Conversation

mcabbott (Member)

Embedding has some special code for OneHotMatrix which (1) will break with the latest changes, and (2) doesn't allow higher-rank arrays the way that integer "index" input does:

julia> Embedding(26 => 10)(rand(1:26)) |> size
(10,)

julia> Embedding(26 => 10)(rand(1:26, 2,3,4,5)) |> size
(10, 2, 3, 4, 5)

julia> Embedding(26 => 10)(onehot(rand(1:26), 1:26)) |> size
(10,)

julia> Embedding(26 => 10)(onehotbatch(rand(1:26, 2), 1:26)) |> size  # can't go further
(10, 2)

So this PR simplifies the code & adds a reshape to allow higher-rank input.

I did this after forgetting that #1656 exists, so there is some overlap. This PR does not attempt to fix outputsize; some of the other changes there have already happened elsewhere.
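
For reference, here is a minimal sketch of the shape of the simplified integer-index path (illustrative only, not the exact merged code; the real layer also handles initialisation, trainability, printing, and so on):

using NNlib  # for gather

# Illustrative mini-Embedding; weight is out × in, one embedding vector per column.
struct Embedding{W<:AbstractMatrix}
  weight::W
end
Embedding((in, out)::Pair{<:Integer,<:Integer}) = Embedding(randn(Float32, out, in))

(m::Embedding)(x::Integer) = m.weight[:, x]                       # one index, one column
(m::Embedding)(x::AbstractVector{<:Integer}) = NNlib.gather(m.weight, x)
# Higher-rank index arrays: flatten, look up columns, then reshape back:
(m::Embedding)(x::AbstractArray{<:Integer}) = reshape(m(vec(x)), :, size(x)...)

With these methods, Embedding(26 => 10)(rand(1:26, 2, 3, 4, 5)) returns a (10, 2, 3, 4, 5) array, matching the REPL output above.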

Comment on lines +697 to +698
(m::Embedding)(x::AbstractVector{Bool}) = m.weight * x # usually OneHotVector
(m::Embedding)(x::AbstractMatrix{Bool}) = m.weight * x # usually OneHotMatrix
mcabbott (Member Author)

These could instead call Flux.onecold. The result will differ on e.g. [true, true, false]; not sure we care too much either way?

Member

For performance in the one-hot case? If it's onecold-compatible, then folks should use OneHotArray for performance. At least with *, we do the mathematically expected operation.

mcabbott (Member Author)

For OneHotArray these should be identical, right? Result and performance.

For a one-hot BitArray, the results will agree. I would guess that onecold is faster but haven't checked.

For a generic BitArray, I'm not sure which is mathematically expected really. I think you're saying that * is.

Member

Yes, what you wrote is what I meant re: performance. I was adding that in the one-hot bit array case, we can direct people to OneHotArray if their concern is performance.

Yeah whenever I've come across this type of operation in papers, I see it written as *. There's an implicit assumption that x is one-hot, so maybe onecold could be better here if it were made to error for [true, true, false], etc. But I think silently choosing the first "hot" index is wrong.
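
For concreteness, here is a small made-up example (not from the PR) of how the two choices diverge on an invalid one-hot input:

julia> W = Float32[1 10 100;
                   2 20 200];    # 2×3 weights, one embedding vector per column

julia> x = [true, true, false];  # two hot entries, so not a valid one-hot vector

julia> W * x                     # * mixes (sums) both hot columns
2-element Vector{Float32}:
 11.0
 22.0

julia> W[:, Flux.onecold(x)]     # onecold silently picks the first hot index
2-element Vector{Float32}:
 1.0
 2.0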

mcabbott (Member Author)

Yes. Mixing two embedding vectors seems less wrong. But probably nobody ever hits this & it's just a way to decouple from OneHotArray types. I don't think we should document that boolean indexing is an option.

Member

So I think we are happy with the current implementation in the PR?

mcabbott (Member Author)

Yes, I think so.

BTW, I see we had a very similar discussion in #1656 (comment), which I'd forgotten... but it reached the same conclusion.

src/layers/basic.jl (outdated review thread, resolved)
Commit: "…t without 5 named variables, and show that the point of onehot is variables which aren't 1:n already. Also show result of higher-rank input."
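
In the spirit of that commit message, a docstring-style example might look like this (hypothetical; the letter vocabulary stands in for labels that aren't already 1:n):

julia> vocab = 'a':'z';

julia> m = Embedding(length(vocab) => 10);

julia> m(Flux.onehotbatch("flux", vocab)) |> size   # onehot maps labels that aren't 1:26
(10, 4)

julia> m(rand(1:26, 2, 3)) |> size                  # higher-rank integer input also works
(10, 2, 3)
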
darsnack (Member) left a comment

Looks ready to me. Is there more you wanted to do here?


mcabbott (Member Author)

The "more" is #2088, really. Will merge when green.

@mcabbott mcabbott merged commit dfd4549 into FluxML:master Oct 17, 2022
@mcabbott mcabbott deleted the embedding branch January 10, 2023 11:28