Skip to content

Commit

Permalink
fix: nomad model fix (#55)
Browse files Browse the repository at this point in the history
* bug : nomad model fix

* formatter
  • Loading branch information
ayushinav authored Dec 29, 2024
1 parent c9a2530 commit 098b38e
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions src/models/nomad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,10 @@ julia> size(first(nomad((u, y), ps, st)))
(8, 5)
```
"""
@concrete struct NOMAD <: AbstractLuxWrapperLayer{:model}
model <: Chain
@concrete struct NOMAD <: AbstractLuxContainerLayer{(:approximator, :decoder)}
approximator
decoder
concatenate <: Function
end

"""
Expand Down Expand Up @@ -96,8 +98,14 @@ function NOMAD(; approximator=(8, 32, 32, 16), decoder=(18, 16, 8, 8),
return NOMAD(approximator_net, decoder_net, concatenate)
end

function NOMAD(approximator, decoder, concatenate=nomad_concatenate)
return NOMAD(Chain(Parallel(concatenate, approximator, NoOpLayer()), decoder))
function (nomad::NOMAD)(x, ps, st::NamedTuple)
a, st_a = nomad.approximator(x[1], ps.approximator, st.approximator)
out, st_d = nomad.decoder(nomad.concatenate(a, x[2]), ps.decoder, st.decoder)
return out, (approximator=st_a, decoder=st_d)
end

function NOMAD(approximator_net, decoder_net; concatenate=nomad_concatenate)
NOMAD(approximator_net, decoder_net, concatenate)
end

batch_vectorize(x::AbstractArray) = reshape(x, :, size(x, ndims(x)))
Expand Down

0 comments on commit 098b38e

Please sign in to comment.