Skip to content

Commit 639f6f7

Browse files
authored
Merge pull request #83 from RapidAI/optim_wired_logic_decode
fix: optim logic points decode
2 parents 574391d + 6e5f7c4 commit 639f6f7

File tree

1 file changed

+40
-48
lines changed

1 file changed

+40
-48
lines changed

wired_table_rec/table_recover.py

+40-48
Original file line numberDiff line numberDiff line change
@@ -66,57 +66,49 @@ def get_benchmark_cols(
6666
) -> Tuple[np.ndarray, List[float], int]:
6767
longest_col = max(rows.values(), key=lambda x: len(x))
6868
longest_col_points = polygons[longest_col]
69-
longest_x = longest_col_points[:, 0, 0]
70-
69+
longest_x_start = list(longest_col_points[:, 0, 0])
70+
longest_x_end = list(longest_col_points[:, 2, 0])
71+
min_x = longest_x_start[0]
72+
max_x = longest_x_end[-1]
7173
theta = 15
72-
for row_value in rows.values():
73-
cur_row = polygons[row_value][:, 0, 0]
74-
75-
range_res = {}
76-
for idx, cur_v in enumerate(cur_row):
77-
start_idx, end_idx = None, None
78-
for i, v in enumerate(longest_x):
79-
if cur_v - theta <= v <= cur_v + theta:
80-
break
8174

82-
if cur_v > v:
83-
start_idx = i
84-
continue
75+
# 根据当前col的起始x坐标,更新col的边界
76+
def update_longest_col(col_x_list, cur_v, min_x_, max_x_):
77+
for i, v in enumerate(col_x_list):
78+
if cur_v - theta <= v <= cur_v + theta:
79+
break
80+
if cur_v > v:
81+
continue
82+
if cur_v < min_x_:
83+
col_x_list.insert(0, cur_v)
84+
min_x_ = cur_v
85+
break
86+
if cur_v > max_x_:
87+
col_x_list.append(max_x_)
88+
max_x_ = cur_v
89+
if cur_v < v:
90+
col_x_list.insert(i, cur_v)
91+
break
92+
return min_x_, max_x_
8593

86-
if cur_v < v:
87-
end_idx = i
88-
break
94+
for row_value in rows.values():
95+
cur_row_start = list(polygons[row_value][:, 0, 0])
96+
cur_row_end = list(polygons[row_value][:, 2, 0])
97+
for idx, (cur_v_start, cur_v_end) in enumerate(
98+
zip(cur_row_start, cur_row_end)
99+
):
100+
min_x, max_x = update_longest_col(
101+
longest_x_start, cur_v_start, min_x, max_x
102+
)
103+
min_x, max_x = update_longest_col(
104+
longest_x_start, cur_v_end, min_x, max_x
105+
)
89106

90-
range_res[idx] = [start_idx, end_idx]
91-
92-
sorted_res = dict(
93-
sorted(range_res.items(), key=lambda x: x[0], reverse=True)
94-
)
95-
for k, v in sorted_res.items():
96-
# bugfix: https://github.com/RapidAI/TableStructureRec/discussions/55
97-
# 最长列不包含第一列和最后一列的场景需要兼容
98-
if all(v) or v[1] == 0:
99-
longest_x = np.insert(longest_x, v[1], cur_row[k])
100-
longest_col_points = np.insert(
101-
longest_col_points, v[1], polygons[row_value[k]], axis=0
102-
)
103-
elif v[0] and v[0] + 1 == len(longest_x):
104-
longest_x = np.append(longest_x, cur_row[k])
105-
longest_col_points = np.append(
106-
longest_col_points,
107-
polygons[row_value[k]][np.newaxis, :, :],
108-
axis=0,
109-
)
110-
# 求出最右侧所有cell的宽,其中最小的作为最后一列宽度
111-
rightmost_idxs = [v[-1] for v in rows.values()]
112-
rightmost_boxes = polygons[rightmost_idxs]
113-
min_width = min([self.compute_L2(v[3, :], v[0, :]) for v in rightmost_boxes])
114-
115-
each_col_widths = (longest_x[1:] - longest_x[:-1]).tolist()
116-
each_col_widths.append(min_width)
117-
118-
col_nums = longest_x.shape[0]
119-
return longest_col_points, each_col_widths, col_nums
107+
longest_x_start = np.array(longest_x_start)
108+
each_col_widths = (longest_x_start[1:] - longest_x_start[:-1]).tolist()
109+
each_col_widths.append(max_x - longest_x_start[-1])
110+
col_nums = longest_x_start.shape[0]
111+
return longest_x_start, each_col_widths, col_nums
120112

121113
def get_benchmark_rows(
122114
self, rows: Dict[int, List], polygons: np.ndarray
@@ -160,7 +152,7 @@ def get_merge_cells(
160152
box_width = self.compute_L2(box[3, :], box[0, :])
161153

162154
# 不一定是从0开始的,应该综合已有值和x坐标位置来确定起始位置
163-
loc_col_idx = np.argmin(np.abs(longest_col[:, 0, 0] - box[0, 0]))
155+
loc_col_idx = np.argmin(np.abs(longest_col - box[0, 0]))
164156
col_start = max(sum(one_col_result.values()), loc_col_idx)
165157

166158
# 计算合并多少个列方向单元格

0 commit comments

Comments
 (0)