|
8 | 8 | import tempfile |
9 | 9 | from glob import glob |
10 | 10 | import os |
| 11 | +import logging |
| 12 | +from funlib.persistence import open_ds, prepare_ds |
| 13 | +from funlib.geometry import Roi, Coordinate |
| 14 | +import numpy as np |
| 15 | +from skimage.draw import line_nd |
11 | 16 |
|
12 | 17 |
|
| 18 | +logger = logging.getLogger(__name__) |
| 19 | + |
13 | 20 | def download_wk_skeleton( |
14 | 21 | annotation_ID, |
15 | 22 | save_path, |
@@ -52,6 +59,90 @@ def download_wk_skeleton( |
52 | 59 | return zip_path |
53 | 60 |
|
54 | 61 |
|
| 62 | +def parse_skeleton(zip_path) -> dict: |
| 63 | + fin = zip_path |
| 64 | + if not fin.endswith(".zip"): |
| 65 | + try: |
| 66 | + fin = get_updated_skeleton(zip_path) |
| 67 | + assert fin.endswith(".zip"), "Skeleton zip file not found." |
| 68 | + except: |
| 69 | + assert False, "CATMAID NOT IMPLEMENTED" |
| 70 | + |
| 71 | + wk_skels = wk.skeleton.Skeleton.load(fin) |
| 72 | + # return wk_skels |
| 73 | + skel_coor = {} |
| 74 | + for tree in wk_skels.trees: |
| 75 | + skel_coor[tree.id] = [] |
| 76 | + for start, end in tree.edges.keys(): |
| 77 | + start_pos = start.position.to_np() |
| 78 | + end_pos = end.position.to_np() |
| 79 | + skel_coor[tree.id].append([start_pos, end_pos]) |
| 80 | + |
| 81 | + return skel_coor |
| 82 | + |
| 83 | + |
| 84 | +def get_updated_skeleton(zip_path) -> str: |
| 85 | + if not os.path.exists(zip_path): |
| 86 | + path = os.path.dirname(os.path.realpath(zip_path)) |
| 87 | + search_path = os.path.join(path, "skeletons/*") |
| 88 | + files = glob(search_path) |
| 89 | + if len(files) == 0: |
| 90 | + skel_file = download_wk_skeleton() |
| 91 | + else: |
| 92 | + skel_file = max(files, key=os.path.getctime) |
| 93 | + skel_file = os.path.abspath(skel_file) |
| 94 | + |
| 95 | + return skel_file |
| 96 | + |
| 97 | +def rasterize_skeleton(zip_path:str, |
| 98 | + raw_file:str, |
| 99 | + raw_ds:str) -> np.ndarray: |
| 100 | + |
| 101 | + logger.info(f"Rasterizing skeleton...") |
| 102 | + |
| 103 | + skel_coor = parse_skeleton(zip_path) |
| 104 | + |
| 105 | + # Initialize rasterized skeleton image |
| 106 | + raw = open_ds(raw_file, raw_ds) |
| 107 | + |
| 108 | + dataset_shape = raw.data.shape |
| 109 | + print(dataset_shape) |
| 110 | + voxel_size = raw.voxel_size |
| 111 | + offset = raw.roi.begin # unhardcode for nonzero offset |
| 112 | + image = np.zeros(dataset_shape, dtype=np.uint8) |
| 113 | + |
| 114 | + def adjust(coor): |
| 115 | + ds_under = [x-1 for x in dataset_shape] |
| 116 | + return np.min([coor - offset, ds_under], 0) |
| 117 | + |
| 118 | + print("adjusting . . .") |
| 119 | + for id, tree in skel_coor.items(): |
| 120 | + # iterates through ever node and assigns id to {image} |
| 121 | + for start, end in tree: |
| 122 | + line = line_nd(adjust(start), adjust(end)) |
| 123 | + image[line] = id |
| 124 | + |
| 125 | + |
| 126 | + # Save GT rasterization #TODO: implement daisy blockwise option |
| 127 | + total_roi = Roi( |
| 128 | + Coordinate(offset) * Coordinate(voxel_size), |
| 129 | + Coordinate(dataset_shape) * Coordinate(voxel_size), |
| 130 | + ) |
| 131 | + |
| 132 | + print("saving . . .") |
| 133 | + out_ds = prepare_ds( |
| 134 | + raw_file, |
| 135 | + "volumes/training_rasters", |
| 136 | + total_roi, |
| 137 | + voxel_size, |
| 138 | + image.dtype, |
| 139 | + delete=True, |
| 140 | + ) |
| 141 | + out_ds[out_ds.roi] = image |
| 142 | + |
| 143 | + return image |
| 144 | + |
| 145 | + |
55 | 146 | def get_wk_mask( |
56 | 147 | annotation_ID, |
57 | 148 | save_path, # TODO: Add mkdtemp() as default |
|
0 commit comments