@@ -285,47 +285,40 @@ def sample_coordinates(
285285 ds = inputs .sample_patch_coordinates (
286286 boxes_cfg , volume_names , rng_seed = rng_seed
287287 )
288- elif config .sampling .bag_coords :
289- raise NotImplementedError ('bag file reading not supported yet.' )
290- elif config .sampling .arrayrecord_coords :
291-
292- def _make_source (pattern ):
293- return array_record .ArrayRecordDataSource (
294- sorted (tf .io .gfile .glob (pattern ))
295- )
296-
297- if isinstance (config .sampling .arrayrecord_coords , str ):
298- sources = _make_source (config .sampling .arrayrecord_coords )
299- weights = [1.0 ]
300- else :
301- sources , weights = [], []
302- for pattern , weight in config .sampling .arrayrecord_coords .items ():
303- sources .append (_make_source (pattern ))
304- weights .append (weight )
305-
306- def _tf_load (idx , source ):
307- return tf .numpy_function (
308- lambda x , src = source : src [x ], [idx ], [tf .string ], stateful = False
309- )[0 ]
310-
311- def _sample_indices (source , seed ):
312- rng = np .random .default_rng (seed )
313- ds = tf .data .Dataset .from_tensor_slices (rng .permutation (len (source )))
314- ds = ds .map (lambda x , src = source : _tf_load (x , source = src ))
315- return ds
316-
317- all_ds = [_sample_indices (s , rng_seed ) for s in sources ]
318-
319- weights = np .array (weights )
320- weights = weights .astype (float ) / weights .sum ()
321- ds = tf .data .Dataset .sample_from_datasets (all_ds , weights , seed = rng_seed )
322- ds = ds .map (inputs .parse_tf_coords , deterministic = True )
323288 else :
324289 raise ValueError ('No sampling scheme specified.' )
325290
326291 return ds
327292
328293
294+ def _coord_in_bboxes_np (
295+ coord : np .ndarray ,
296+ volname : bytes ,
297+ bboxes : dict [str , Sequence [bounding_box .BoundingBox ]],
298+ ) -> bool :
299+ """Checks if coordinates are in bounding boxes."""
300+ volname = volname .decode ('utf-8' )
301+ if volname not in bboxes :
302+ return True
303+ for bbox in bboxes [volname ]:
304+ if np .all (coord >= bbox .start ) and np .all (coord < bbox .end ):
305+ return True
306+ return False
307+
308+
309+ def _filter_coordinates_by_bbox (
310+ item : dict [str , tf .Tensor ],
311+ bboxes : dict [str , Sequence [bounding_box .BoundingBox ]],
312+ ) -> tf .Tensor :
313+ ret = tf .numpy_function (
314+ lambda c , v : _coord_in_bboxes_np (c , v , bboxes ),
315+ [item ['coord' ][0 ], item ['volname' ][0 ]],
316+ tf .bool ,
317+ )
318+ ret .set_shape ([])
319+ return ret
320+
321+
329322def load_and_augment_subvolumes (
330323 config : InputConfig ,
331324 rng_seed : int | None = None ,
@@ -358,6 +351,14 @@ def load_and_augment_subvolumes(
358351 if transform_locations is not None :
359352 ds = transform_locations (ds , config )
360353
354+ if config .sampling .bounding_boxes :
355+ ds = ds .filter (
356+ ft .partial (
357+ _filter_coordinates_by_bbox ,
358+ bboxes = config .sampling .bounding_boxes ,
359+ )
360+ )
361+
361362 for vol in config .volumes .values ():
362363 if vol .filter_shape is not None :
363364 ds = ds .filter (
0 commit comments