diff --git a/stereo/algorithm/cell_correction_fast_by_mask.py b/stereo/algorithm/cell_correction_fast_by_mask.py index 3379eb55..4313e756 100644 --- a/stereo/algorithm/cell_correction_fast_by_mask.py +++ b/stereo/algorithm/cell_correction_fast_by_mask.py @@ -205,7 +205,7 @@ def array_to_block( bpc = (width - overlap) // (col_block_size - overlap) # blocks per column if ((height - overlap) % (row_block_size - overlap)) > 0: bpr += 1 - if ((width - overlap) // (col_block_size - overlap)) > 0: + if ((width - overlap) % (col_block_size - overlap)) > 0: bpc += 1 if n_split_data_jobs in (0, 1): block_list = [create_edm_labels(block) for block in @@ -230,7 +230,9 @@ def merge_to_mask(final_result, bpr, bpc, mask, start, end, overlap=MASK_BLOCK_O for rr in range(bpr): row_img_list = [] for cc in range(bpc): - if cc == 0: + if bpc == 1: + img = final_result[rr * bpc] + elif cc == 0: img = final_result[rr * bpc][:, :-half_step] elif cc == bpc - 1: img = final_result[rr * bpc + cc][:, half_step:] @@ -239,7 +241,9 @@ def merge_to_mask(final_result, bpr, bpc, mask, start, end, overlap=MASK_BLOCK_O row_img_list.append(img) row_img = np.concatenate(row_img_list, axis=1) - if rr == 0: + if bpr == 1: + img = row_img + elif rr == 0: img = row_img[:-half_step] elif rr == bpr - 1: img = row_img[half_step:] @@ -252,6 +256,22 @@ def merge_to_mask(final_result, bpr, bpc, mask, start, end, overlap=MASK_BLOCK_O return mask +# @log_consumed_time +# def merge_to_mask(final_result, bpr, bpc, mask, start, end, overlap=MASK_BLOCK_OVERLAP_DEFAULT): +# full_img_shape = (end[0] - start[0], end[1] - start[1]) +# full_img = np.zeros(full_img_shape, dtype=mask.dtype) +# for rr in range(bpr): +# for cc in range(bpc): +# row_start = rr * (MASK_ROW_BLOCK_SIZE_DEFAULT - overlap) +# col_start = cc * (MASK_COL_BLOCK_SIZE_DEFAULT - overlap) +# row_end = min(row_start + MASK_ROW_BLOCK_SIZE_DEFAULT, full_img_shape[0]) +# col_end = min(col_start + MASK_COL_BLOCK_SIZE_DEFAULT, full_img_shape[1]) +# full_img[row_start:row_end, col_start:col_end] = final_result[rr * bpc + cc] + +# mask[start[0]:end[0], start[1]:end[1]] = full_img + +# return mask + @log_consumed_time def est_para(mask):