Skip to content

Commit 837e5ab

Browse files
mjanuszcopybara-github
authored andcommitted
Add support for filtering precomputed coordinates with bounding boxes.
PiperOrigin-RevId: 844172271
1 parent 9af8e87 commit 837e5ab

File tree

2 files changed

+37
-35
lines changed

2 files changed

+37
-35
lines changed

ffn/input/volume.py

Lines changed: 36 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -285,47 +285,40 @@ def sample_coordinates(
285285
ds = inputs.sample_patch_coordinates(
286286
boxes_cfg, volume_names, rng_seed=rng_seed
287287
)
288-
elif config.sampling.bag_coords:
289-
raise NotImplementedError('bag file reading not supported yet.')
290-
elif config.sampling.arrayrecord_coords:
291-
292-
def _make_source(pattern):
293-
return array_record.ArrayRecordDataSource(
294-
sorted(tf.io.gfile.glob(pattern))
295-
)
296-
297-
if isinstance(config.sampling.arrayrecord_coords, str):
298-
sources = _make_source(config.sampling.arrayrecord_coords)
299-
weights = [1.0]
300-
else:
301-
sources, weights = [], []
302-
for pattern, weight in config.sampling.arrayrecord_coords.items():
303-
sources.append(_make_source(pattern))
304-
weights.append(weight)
305-
306-
def _tf_load(idx, source):
307-
return tf.numpy_function(
308-
lambda x, src=source: src[x], [idx], [tf.string], stateful=False
309-
)[0]
310-
311-
def _sample_indices(source, seed):
312-
rng = np.random.default_rng(seed)
313-
ds = tf.data.Dataset.from_tensor_slices(rng.permutation(len(source)))
314-
ds = ds.map(lambda x, src=source: _tf_load(x, source=src))
315-
return ds
316-
317-
all_ds = [_sample_indices(s, rng_seed) for s in sources]
318-
319-
weights = np.array(weights)
320-
weights = weights.astype(float) / weights.sum()
321-
ds = tf.data.Dataset.sample_from_datasets(all_ds, weights, seed=rng_seed)
322-
ds = ds.map(inputs.parse_tf_coords, deterministic=True)
323288
else:
324289
raise ValueError('No sampling scheme specified.')
325290

326291
return ds
327292

328293

294+
def _coord_in_bboxes_np(
295+
coord: np.ndarray,
296+
volname: bytes,
297+
bboxes: dict[str, Sequence[bounding_box.BoundingBox]],
298+
) -> bool:
299+
"""Checks if coordinates are in bounding boxes."""
300+
volname = volname.decode('utf-8')
301+
if volname not in bboxes:
302+
return True
303+
for bbox in bboxes[volname]:
304+
if np.all(coord >= bbox.start) and np.all(coord < bbox.end):
305+
return True
306+
return False
307+
308+
309+
def _filter_coordinates_by_bbox(
310+
item: dict[str, tf.Tensor],
311+
bboxes: dict[str, Sequence[bounding_box.BoundingBox]],
312+
) -> tf.Tensor:
313+
ret = tf.numpy_function(
314+
lambda c, v: _coord_in_bboxes_np(c, v, bboxes),
315+
[item['coord'][0], item['volname'][0]],
316+
tf.bool,
317+
)
318+
ret.set_shape([])
319+
return ret
320+
321+
329322
def load_and_augment_subvolumes(
330323
config: InputConfig,
331324
rng_seed: int | None = None,
@@ -358,6 +351,14 @@ def load_and_augment_subvolumes(
358351
if transform_locations is not None:
359352
ds = transform_locations(ds, config)
360353

354+
if config.sampling.bounding_boxes:
355+
ds = ds.filter(
356+
ft.partial(
357+
_filter_coordinates_by_bbox,
358+
bboxes=config.sampling.bounding_boxes,
359+
)
360+
)
361+
361362
for vol in config.volumes.values():
362363
if vol.filter_shape is not None:
363364
ds = ds.filter(

ffn/jax/input_pipeline.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from typing import Any, Callable
2222

2323
from absl import logging
24+
from connectomics.common import bounding_box
2425
from connectomics.common import utils
2526
from ffn.input import volume
2627
from ffn.jax import tracker

0 commit comments

Comments
 (0)