Skip to content

Commit 6f05a87

Browse files
torfjeldeyebai
andauthored
Fix for bug in forward of Stacked (#192)
* fixed a bug with the default forward for stacked * bump patch version * Update Project.toml Co-authored-by: Hong Ge <[email protected]>
1 parent 95020ff commit 6f05a87

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

Project.toml

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Bijectors"
22
uuid = "76274a88-744f-5084-9051-94815aaf08c4"
3-
version = "0.10.5"
3+
version = "0.10.6"
44

55
[deps]
66
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
@@ -35,4 +35,4 @@ MappedArrays = "0.2.2, 0.3, 0.4"
3535
Reexport = "0.2, 1"
3636
Requires = "0.5, 1"
3737
Roots = "1.3.4, 2"
38-
julia = "1.3"
38+
julia = "1.6"

src/bijectors/stacked.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -164,10 +164,10 @@ function with_logabsdet_jacobian(sb::Stacked, x::AbstractVector)
164164
N = length(sb.bs)
165165
yinit, linit = with_logabsdet_jacobian(sb.bs[1], x[sb.ranges[1]])
166166
logjac = sum(linit)
167-
ys = mapvcat(drop(sb.bs, 1), drop(sb.ranges, 1)) do b, r
167+
ys = mapreduce(vcat, sb.bs[2:end], sb.ranges[2:end]; init=yinit) do b, r
168168
y, l = with_logabsdet_jacobian(b, x[r])
169169
logjac += sum(l)
170170
y
171171
end
172-
return (vcat(yinit, ys), logjac)
172+
return (ys, logjac)
173173
end

0 commit comments

Comments
 (0)