Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make outputsize work with Embedding #2088

Open
wants to merge 3 commits into
base: master
Choose a base branch
from

Conversation

mcabbott
Copy link
Member

Like #1656 this wants to make outputsize(Embedding(3 => 4), (5,)) == (4, 5). That is, it thinks the size referred to by outputsize should be the size of the array of vocabulary indices, not the size of the one-hot representation.

But rather than overload indexing or gather (as here https://github.com/FluxML/Flux.jl/pull/1656/files#diff-0dfa3b94337acdaa714025f5198f6907e6a50a59aac03ba1230fbcb681126da2R172) this just adds methods to (::Embedding). I think that's least likely to cause surprises. If indexing shows up elsewhere, we can decide then whether to extend.

Restricting (m::Embedding)(x::AbstractArray{<:Integer}) also seems like the right thing to do, error right away on non-integer input.

Comment on lines +166 to +167
(m::Embedding)(x::Nil) = similar(m.weight, Nil, size(m.weight, 1))
(m::Embedding)(x::AbstractArray{Nil}) = similar(m.weight, Nil, size(m.weight, 1), size(x)...)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should these warn about the limitations? outputsize can be used in sanity checking contexts, so I think a user would be willing to accept a little noise.

* `Embedding` accepts either integers or one-hot arrays, and `ohx = onehotbatch(x, ...)`
has one more dimension than `x`. Here `outputsize` uses `size(x)`.
* At present `outputsize` does not work with recurrent layers,
`outputsize(RNN(2 => 3), (2, 1))` gives an error. This is a bug.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
`outputsize(RNN(2 => 3), (2, 1))` gives an error. This is a bug.
`outputsize(RNN(2 => 3), (2, 1))` gives an error. See https://github.com/FluxML/Flux.jl/pull/1755 for more.

I'm not up to speed on any Embedding changes, but a link reference here would be nice.

@ToucheSir
Copy link
Member

Would this landing obsolete #1656?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants