
Deeponet multi-output fix #11

Closed

ayushinav wants to merge 17 commits into main

Conversation

@ayushinav (Contributor) commented Jul 4, 2024

Closes #9

src/display.jl (Outdated)
Comment on lines 1 to 8
# function Base.show(io::IO, model::conv) where {conv <: OperatorConv}
# # print(io, model.name*"() # "*string(Lux.parameterlength(model))*" parameters")
# print(io, model.name)
# end

# function Base.show(io::IO, ::MIME"text/plain", model::conv) where {conv <: OperatorConv}
# show(io, model.name)
# end
Member
Remove these, printing was fixed upstream

src/deeponet.jl Outdated

julia> trunk_net = Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16));

julia> additional = Chain(Dense(1 => 4));
Member

The input size of the additional layer should be the inner embedding size.

Member

That way it does not need a reduction/sum/dropdims before the additional layer. It should be additional = Chain(Dense(16 => 4)); here. Otherwise it creates a bottleneck and we lose information.
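Concretely (a sketch of the shapes involved, using the sizes from the snippet above): with the embedding size p = 16, the branch output is 16 x nb and the trunk output is 16 x N x nb, so their elementwise product keeps the full 16-dimensional embedding, which additional = Chain(Dense(16 => 4)) then maps down to out_dims = 4. Summing over the embedding dimension first would collapse those 16 values to 1 before the learned layer ever sees them.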

Contributor (Author)

Should be fixed now. Using a linear layer as the additional layer in the cases where no additional layer is given did not seem ideal to me, because it would imply a weighted sum whose weights are learnt during training, whereas DeepONets by default take the dot product, i.e. a non-weighted sum, which many users may require.
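To make the distinction concrete, here is a sketch with illustrative names (dot_combine, weighted_combine, and w are not in the package; w stands for the weights a linear additional layer would learn):

# b : p x nb (branch output), t : p x N x nb (trunk output)
# Default DeepONet reduction: plain, non-weighted dot product over the embedding dim p.
dot_combine(b, t) = dropdims(sum(reshape(b, size(b, 1), 1, size(b, 2)) .* t; dims=1); dims=1) # N x nb

# What a linear layer would amount to: a weighted sum over p, with w learnt in training.
weighted_combine(b, t, w) = dropdims(sum(w .* reshape(b, size(b, 1), 1, size(b, 2)) .* t; dims=1); dims=1) # N x nb

dot_combine matches the default projection shown later in this thread; defaulting to weighted_combine would silently change the operator users expect.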

@ayushinav (Contributor, Author)

Other than the doctests, the failing test cases were because the actual return type Tuple{Array{Float32, 4}, ...} does not match the inferred return type Tuple{Union{Array{Float32, 3}, Array{Float32, 4}}, ...}:

return type 
Tuple{Array{Float32, 4}, @NamedTuple{branch::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, trunk::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, additional::@NamedTuple{}}} 
does not match inferred return type 
Tuple{Union{Array{Float32, 3}, Array{Float32, 4}}, @NamedTuple{branch::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, trunk::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, additional::@NamedTuple{}}}

The tests pass if I comment out the @inferred and @jet cases.

@avik-pal (Member) commented Jul 9, 2024

Tuple{Union{Array{Float32, 3}, Array{Float32, 4}}, @NamedTuple{branch::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, trunk::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, additional::@NamedTuple{}}}

This is bad. You are returning a 3D array / 4D array based on the input sizes (which won't be type-inferred). Avoid doing the dropdims.
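As a minimal sketch of the problem (rank_dependent is a hypothetical toy function, not the package code): when the returned rank depends on a runtime size check, Julia can only infer a Union of array types, which is exactly what Test.@inferred rejects.

using Test

function rank_dependent(x::AbstractArray{T, 3}) where {T}
    if size(x, 2) == 1
        return dropdims(x; dims=2) # 2D result on this path
    else
        return x # 3D result on this path
    end
end

x = rand(Float32, 4, 2, 3)
rank_dependent(x)           # fine at runtime
@inferred rank_dependent(x) # errors: inferred Union{Matrix{Float32}, Array{Float32, 3}}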

@ayushinav (Contributor, Author) commented Jul 10, 2024

You are returning a 3D array / 4D array based on the input sizes (which won't be type inferred). Avoid doing the dropdims

Not sure how this might be an issue, because the scalar tests, which call the same dropdims,

@inline function __project(b::AbstractArray{T1, 2}, t::AbstractArray{T2, 3},
        additional::Nothing) where {T1, T2}
    # b : p x nb
    # t : p x N x nb
    b_ = reshape(b, size(b, 1), 1, size(b, 2)) # p x 1 x nb
    return dropdims(sum(b_ .* t; dims=1); dims=1) # N x nb
end

still pass, while the Scalar II and Vector Additional layer tests, which call

@inline function __project(
        b::AbstractArray{T1, 3}, t::AbstractArray{T2, 3}, additional::T) where {T1, T2, T}
    # b : p x u x nb
    # t : p x N x nb

    if size(b, 2) == 1 || size(t, 2) == 1
        return additional(b .* t) # p x N x nb => out_dims x N x nb
    else
        b_ = reshape(b, size(b)[1:2]..., 1, size(b, 3)) # p x u x 1 x nb
        t_ = reshape(t, size(t, 1), 1, size(t)[2:end]...) # p x 1 x N x nb

        return additional(b_ .* t_) # p x u x N x nb => out_size x N x nb
    end
end

fail.
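One way to make the 3D method inferable (a sketch, not the fix adopted in this PR; __project_stable is a hypothetical name, and it assumes additional can always accept the full 4D p x u x N x nb array, including when u == 1): remove the runtime size check so every path produces the same rank.

@inline function __project_stable(
        b::AbstractArray{T1, 3}, t::AbstractArray{T2, 3}, additional::T) where {T1, T2, T}
    # b : p x u x nb
    # t : p x N x nb
    b_ = reshape(b, size(b)[1:2]..., 1, size(b, 3)) # p x u x 1 x nb
    t_ = reshape(t, size(t, 1), 1, size(t)[2:end]...) # p x 1 x N x nb
    return additional(b_ .* t_) # always p x u x N x nb => out_size x N x nb
end

With no value-dependent branch, the inferred return type is whatever additional produces for a 4D input, so @inferred and @jet can pass.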

@ayushinav (Contributor, Author)

Only the doctests fail for now.

ayushinav requested a review from avik-pal on Jul 10, 2024, 05:08
@avik-pal (Member)

Rebase with the latest changes to main.

@avik-pal (Member)

Set your git config to rebase on pull instead of merge; otherwise the commit history gets royally messed up.
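For reference, the relevant setting is git config --global pull.rebase true, which makes git pull rebase onto the fetched branch instead of creating merge commits.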

codecov bot commented Jul 12, 2024

Codecov Report

Attention: Patch coverage is 83.33333% with 5 lines in your changes missing coverage. Please review.

Project coverage is 93.20%. Comparing base (aaf7d45) to head (a8149e2).

Files          Patch %   Lines
src/utils.jl   80.76%    5 Missing ⚠️

❗ There is a different number of reports uploaded between BASE (aaf7d45) and HEAD (a8149e2): HEAD has 1 upload less than BASE (5 uploads vs. 4).
Additional details and impacted files
@@             Coverage Diff             @@
##              main      #11      +/-   ##
===========================================
- Coverage   100.00%   93.20%   -6.80%     
===========================================
  Files            7        7              
  Lines           77      103      +26     
===========================================
+ Hits            77       96      +19     
- Misses           0        7       +7     


@ayushinav closed this Jul 17, 2024
Successfully merging this pull request may close these issues: DeepOnet Multiple output.