-
-
Notifications
You must be signed in to change notification settings - Fork 610
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Make outputsize
work with Embedding
#2088
base: master
Are you sure you want to change the base?
Conversation
(m::Embedding)(x::Nil) = similar(m.weight, Nil, size(m.weight, 1)) | ||
(m::Embedding)(x::AbstractArray{Nil}) = similar(m.weight, Nil, size(m.weight, 1), size(x)...) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should these warn about the limitations? outputsize
can be used in sanity checking contexts, so I think a user would be willing to accept a little noise.
* `Embedding` accepts either integers or one-hot arrays, and `ohx = onehotbatch(x, ...)` | ||
has one more dimension than `x`. Here `outputsize` uses `size(x)`. | ||
* At present `outputsize` does not work with recurrent layers, | ||
`outputsize(RNN(2 => 3), (2, 1))` gives an error. This is a bug. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`outputsize(RNN(2 => 3), (2, 1))` gives an error. This is a bug. | |
`outputsize(RNN(2 => 3), (2, 1))` gives an error. See https://github.com/FluxML/Flux.jl/pull/1755 for more. |
I'm not up to speed on any Embedding
changes, but a link reference here would be nice.
Would this landing obsolete #1656? |
Like #1656 this wants to make
outputsize(Embedding(3 => 4), (5,)) == (4, 5)
. That is, it thinks the size referred to byoutputsize
should be the size of the array of vocabulary indices, not the size of the one-hot representation.But rather than overload indexing or
gather
(as here https://github.com/FluxML/Flux.jl/pull/1656/files#diff-0dfa3b94337acdaa714025f5198f6907e6a50a59aac03ba1230fbcb681126da2R172) this just adds methods to(::Embedding)
. I think that's least likely to cause surprises. If indexing shows up elsewhere, we can decide then whether to extend.Restricting
(m::Embedding)(x::AbstractArray{<:Integer})
also seems like the right thing to do, error right away on non-integer input.