Make outputsize work with Embedding #2088

Open · wants to merge 3 commits into master
10 changes: 7 additions & 3 deletions src/layers/basic.jl
@@ -655,13 +655,14 @@ or the corresponding [`onehot encoding`](@ref OneHotArrays.onehotbatch).

For indices `x`, the result is of size `(out, size(x)...)`, allowing several batch dimensions.
For one-hot `ohx`, the result is of size `(out, size(ohx)[2:end]...)`.
+ Note that [`outputsize`](@ref Flux.outputsize) expects `size(x)`, the size of the indices, not of the one-hot array.

# Examples
```jldoctest
julia> emb = Embedding(26 => 4, init=Flux.identity_init(gain=22))
Embedding(26 => 4) # 104 parameters

- julia> emb(2) # one column of e.weight (here not random!)
+ julia> emb(2) # one column of emb.weight (here not random!)
4-element Vector{Float32}:
0.0
22.0
@@ -680,6 +681,9 @@ true

julia> emb(rand(1:26, (10, 1, 12))) |> size # three batch dimensions
(4, 10, 1, 12)

+ julia> Flux.outputsize(emb, (10, 1, 12)) # outputsize wants indices, not OneHotArray
+ (4, 10, 1, 12)
```
"""
struct Embedding{W<:AbstractMatrix}
@@ -691,8 +695,8 @@ end
Embedding((in, out)::Pair{<:Integer, <:Integer}; init = randn32) = Embedding(init(out, in))

(m::Embedding)(x::Integer) = m.weight[:, x]
- (m::Embedding)(x::AbstractVector) = NNlib.gather(m.weight, x)
- (m::Embedding)(x::AbstractArray) = reshape(m(vec(x)), :, size(x)...)
+ (m::Embedding)(x::AbstractVector{<:Integer}) = NNlib.gather(m.weight, x)
+ (m::Embedding)(x::AbstractArray{<:Integer}) = reshape(m(vec(x)), :, size(x)...)

(m::Embedding)(x::AbstractVector{Bool}) = m.weight * x # usually OneHotVector
(m::Embedding)(x::AbstractMatrix{Bool}) = m.weight * x # usually OneHotMatrix
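The tightened signatures deserve a note: restricting these methods to `<:Integer` element types stops them from capturing arrays of `Nil`, so the `Nil` methods added in `src/outputsize.jl` can dispatch without ambiguity, while `Bool`-backed one-hot arrays still reach the more specific methods kept below. A minimal sketch of the resulting dispatch (editor's illustration, not part of the diff):

```julia
using Flux

emb = Embedding(26 => 4)

emb([2, 3])                          # eltype Int  -> NNlib.gather path
emb(Flux.onehotbatch([2, 3], 1:26))  # eltype Bool -> m.weight * x path

# Without the <:Integer restriction, a Vector{Nil} would match both the
# generic AbstractVector method and the new AbstractArray{Nil} method,
# making outputsize's nil-based dispatch ambiguous.
```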
13 changes: 12 additions & 1 deletion src/outputsize.jl
@@ -87,6 +87,12 @@ DimensionMismatch("Input channels must match! (7 vs. 3)")
julia> outputsize([Dense(10 => 4), Dense(4 => 2)], (10, 1)) # Vector of layers becomes a Chain
(2, 1)
```

+ Limitations:
+ * `Embedding` accepts either integers or one-hot arrays, and `ohx = onehotbatch(x, ...)`
+   has one more dimension than `x`. Here `outputsize` uses `size(x)`.
+ * At present `outputsize` does not work with recurrent layers;
+   `outputsize(RNN(2 => 3), (2, 1))` gives an error. This is a bug.
Review comment (Member):

Suggested change:
- `outputsize(RNN(2 => 3), (2, 1))` gives an error. This is a bug.
+ `outputsize(RNN(2 => 3), (2, 1))` gives an error. See https://github.com/FluxML/Flux.jl/pull/1755 for more.

I'm not up to speed on any Embedding changes, but a link reference here would be nice.
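To make the first limitation concrete, here is a minimal REPL sketch (editor's illustration of the behaviour this PR adds):

```julia
using Flux

emb = Embedding(26 => 4)
x   = rand(1:26, 10)             # indices, size (10,)
ohx = Flux.onehotbatch(x, 1:26)  # one-hot, size (26, 10): one extra dimension

size(emb(x))                     # (4, 10)
size(emb(ohx))                   # (4, 10), same result either way

Flux.outputsize(emb, size(x))    # (4, 10)     correct: pass the size of the indices
Flux.outputsize(emb, size(ohx))  # (4, 26, 10) wrong: the label axis is treated
                                 #             as just another batch dimension
```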

"""
function outputsize(m, inputsizes::Tuple...; padbatch=false)
x = nil_input(padbatch, inputsizes...)
@@ -157,6 +163,9 @@ end

## fixes for layers that don't work out of the box

+ (m::Embedding)(x::Nil) = similar(m.weight, Nil, size(m.weight, 1))
+ (m::Embedding)(x::AbstractArray{Nil}) = similar(m.weight, Nil, size(m.weight, 1), size(x)...)
Review comment on lines +166 to +167 (Member):

Should these warn about the limitations? outputsize can be used in sanity checking contexts, so I think a user would be willing to accept a little noise.
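One possible shape for that suggestion (a hypothetical sketch, not part of this PR; it would replace the silent array method above):

```julia
# Hypothetical: warn once when outputsize routes a Nil array through Embedding,
# since the caller may have passed the one-hot size by mistake.
function (m::Embedding)(x::AbstractArray{Nil})
    @warn "outputsize treats Embedding input as indices; pass size(x), not the one-hot size" maxlog=1
    similar(m.weight, Nil, size(m.weight, 1), size(x)...)
end
```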


for (fn, Dims) in ((:conv, DenseConvDims),)
@eval begin
function NNlib.$fn(a::AbstractArray{Nil}, b::AbstractArray{Nil}, dims::$Dims)
@@ -279,9 +288,11 @@ is needed to make `@autosize (2,3,4) Dense(_ => 5)` return
"""
autosizefor(::Type, x::AbstractArray) = size(x, max(1, ndims(x)-1))
autosizefor(::Type{<:Dense}, x::AbstractArray) = size(x, 1)
- autosizefor(::Type{<:Embedding}, x::AbstractArray) = size(x, 1)
autosizefor(::Type{<:LayerNorm}, x::AbstractArray) = size(x, 1)

+ autosizefor(::Type{<:Embedding}, x::AbstractArray) = error(
+     "@autosize Embedding(_ => n) cannot work, as this _ is the size of the vocabulary, not an array size")

_replaceunderscore(e, s) = e === :_ ? s : e
_replaceunderscore(ex::Expr, s) = Expr(ex.head, map(a -> _replaceunderscore(a, s), ex.args)...)

25 changes: 20 additions & 5 deletions test/outputsize.jl
@@ -5,6 +5,8 @@
m = Dense(10, 5)
@test_throws DimensionMismatch outputsize(m, (5, 2)) == (5, 1)
@test outputsize(m, (10,); padbatch=true) == (5, 1)
+ @test outputsize(m, (10,)) == (5,)
+ @test outputsize(m, (10, 6, 7)) == (5, 6, 7)

m = Chain(Dense(10, 8, σ), Dense(8, 5), Dense(5, 2))
@test outputsize(m, (10,); padbatch=true) == (2, 1)
@@ -41,6 +43,19 @@
@test outputsize(m, (10, 10, 3, 1)) == (10, 10, 19, 1)
end

@testset "embeddings" begin
# Here outputsize expects indices, not one-hot representation:
m = Embedding(3 => 4)
@test outputsize(m, (3, 7)) == (4, 3, 7) == size(m(rand(1:3, 3, 7)))
@test outputsize(m, (5, 6, 7)) == (4, 5, 6, 7) == size(m(rand(1:3, 5, 6, 7)))

m = Chain(x -> Flux.onehotbatch(x, 1:5), Embedding(5 => 7))
@test size(m([3,4])) == (7, 2)
@test outputsize(m, (2,)) == (7, 2)
# This works because Flux.onehotbatch([nil, nil], 1:5) makes a 5×2 OneHotMatrix
# But e.g. Flux.onehotbatch([nil, nil], 'a':'e') will not work.
end
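As a usage note, the new `Nil` methods also let `outputsize` flow through a `Chain` containing an `Embedding` (a sketch, assuming this PR is applied):

```julia
using Flux

m = Chain(Embedding(26 => 8), Dense(8 => 2))
Flux.outputsize(m, (10,))  # (2, 10): ten indices in, ten columns out
```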

@testset "multiple inputs" begin
m = Parallel(vcat, Dense(2, 4, relu), Dense(3, 6, relu))
@test outputsize(m, (2,), (3,)) == (10,)
@@ -175,11 +190,6 @@ end
m = @autosize (2, 3, 4, 5) Dense(_ => 10) # goes by first dim, not 2nd-last
@test randn(2, 3, 4, 5) |> m |> size == (10, 3, 4, 5)

- @test_broken begin # outputsize fails on Embedding
-     m = @autosize (2, 3, 4, 5) Embedding(_ => 10) # goes by first dim, not 2nd-last
-     @test randn(2, 3, 4, 5) |> m |> size == (10, 3, 4, 5)
- end

m = @autosize (9,) Dense(_ => div(_,2))
@test randn(9) |> m |> size == (4,)

@@ -234,6 +244,11 @@ end
# https://github.com/FluxML/Flux.jl/issues/2086
m = @autosize (3, 1) Chain(; c = Dense(_ => 2, sigmoid), b = BatchNorm(_, affine=false))
@test randn(Float32, 3, 32) |> m |> size == (2, 32)

+ # Embedding takes a vocab size, not an array size
+ @test_throws ErrorException @autosize (2, 3) Embedding(_ => 10)
+ m = @autosize (3,) Chain(Embedding(26 => 10), Dense(_, 4))
+ @test rand(1:26, 3) |> m |> size == (4, 3)
end

@testset "LazyLayer" begin