Skip to content

Commit f27d0e3

Browse files
committed
accept layer_reshape(-1)
1 parent faef4ce commit f27d0e3

File tree

2 files changed

+16
-4
lines changed

2 files changed

+16
-4
lines changed

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
- Added `str` S3 method for Keras Variables.
66

7+
- `layer_reshape()` can now accept `-1` as a sentinel for an automatically calculated axis size.
8+
79
- Updated dependencies declared by `use_backend("jax", gpu=TRUE)`
810
for compatability with `keras-hub`.
911

R/layers-reshaping.R

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -398,10 +398,20 @@ function (object, n, ...)
398398
layer_reshape <-
399399
function (object, target_shape, ...)
400400
{
401-
args <- capture_args(list(input_shape = normalize_shape,
402-
batch_size = as_integer, batch_input_shape = normalize_shape,
403-
target_shape = as_integer), ignore = "object")
404-
create_layer(keras$layers$Reshape, object, args)
401+
args <- capture_args(
402+
list(
403+
input_shape = normalize_shape,
404+
batch_input_shape = normalize_shape,
405+
batch_size = as_integer,
406+
target_shape = function(shp) {
407+
tuple(lapply(py_to_r(normalize_shape(shp)), function(d) {
408+
if (is.null(d) || is.na(d)) - 1L else d
409+
}))
410+
}
411+
),
412+
ignore = "object"
413+
)
414+
create_layer(keras$layers$Reshape, object, args)
405415
}
406416

407417

0 commit comments

Comments
 (0)