@@ -66,57 +66,49 @@ def get_benchmark_cols(
66
66
) -> Tuple [np .ndarray , List [float ], int ]:
67
67
longest_col = max (rows .values (), key = lambda x : len (x ))
68
68
longest_col_points = polygons [longest_col ]
69
- longest_x = longest_col_points [:, 0 , 0 ]
70
-
69
+ longest_x_start = list (longest_col_points [:, 0 , 0 ])
70
+ longest_x_end = list (longest_col_points [:, 2 , 0 ])
71
+ min_x = longest_x_start [0 ]
72
+ max_x = longest_x_end [- 1 ]
71
73
theta = 15
72
- for row_value in rows .values ():
73
- cur_row = polygons [row_value ][:, 0 , 0 ]
74
-
75
- range_res = {}
76
- for idx , cur_v in enumerate (cur_row ):
77
- start_idx , end_idx = None , None
78
- for i , v in enumerate (longest_x ):
79
- if cur_v - theta <= v <= cur_v + theta :
80
- break
81
74
82
- if cur_v > v :
83
- start_idx = i
84
- continue
75
+ # 根据当前col的起始x坐标,更新col的边界
76
+ def update_longest_col (col_x_list , cur_v , min_x_ , max_x_ ):
77
+ for i , v in enumerate (col_x_list ):
78
+ if cur_v - theta <= v <= cur_v + theta :
79
+ break
80
+ if cur_v > v :
81
+ continue
82
+ if cur_v < min_x_ :
83
+ col_x_list .insert (0 , cur_v )
84
+ min_x_ = cur_v
85
+ break
86
+ if cur_v > max_x_ :
87
+ col_x_list .append (max_x_ )
88
+ max_x_ = cur_v
89
+ if cur_v < v :
90
+ col_x_list .insert (i , cur_v )
91
+ break
92
+ return min_x_ , max_x_
85
93
86
- if cur_v < v :
87
- end_idx = i
88
- break
94
+ for row_value in rows .values ():
95
+ cur_row_start = list (polygons [row_value ][:, 0 , 0 ])
96
+ cur_row_end = list (polygons [row_value ][:, 2 , 0 ])
97
+ for idx , (cur_v_start , cur_v_end ) in enumerate (
98
+ zip (cur_row_start , cur_row_end )
99
+ ):
100
+ min_x , max_x = update_longest_col (
101
+ longest_x_start , cur_v_start , min_x , max_x
102
+ )
103
+ min_x , max_x = update_longest_col (
104
+ longest_x_start , cur_v_end , min_x , max_x
105
+ )
89
106
90
- range_res [idx ] = [start_idx , end_idx ]
91
-
92
- sorted_res = dict (
93
- sorted (range_res .items (), key = lambda x : x [0 ], reverse = True )
94
- )
95
- for k , v in sorted_res .items ():
96
- # bugfix: https://github.com/RapidAI/TableStructureRec/discussions/55
97
- # 最长列不包含第一列和最后一列的场景需要兼容
98
- if all (v ) or v [1 ] == 0 :
99
- longest_x = np .insert (longest_x , v [1 ], cur_row [k ])
100
- longest_col_points = np .insert (
101
- longest_col_points , v [1 ], polygons [row_value [k ]], axis = 0
102
- )
103
- elif v [0 ] and v [0 ] + 1 == len (longest_x ):
104
- longest_x = np .append (longest_x , cur_row [k ])
105
- longest_col_points = np .append (
106
- longest_col_points ,
107
- polygons [row_value [k ]][np .newaxis , :, :],
108
- axis = 0 ,
109
- )
110
- # 求出最右侧所有cell的宽,其中最小的作为最后一列宽度
111
- rightmost_idxs = [v [- 1 ] for v in rows .values ()]
112
- rightmost_boxes = polygons [rightmost_idxs ]
113
- min_width = min ([self .compute_L2 (v [3 , :], v [0 , :]) for v in rightmost_boxes ])
114
-
115
- each_col_widths = (longest_x [1 :] - longest_x [:- 1 ]).tolist ()
116
- each_col_widths .append (min_width )
117
-
118
- col_nums = longest_x .shape [0 ]
119
- return longest_col_points , each_col_widths , col_nums
107
+ longest_x_start = np .array (longest_x_start )
108
+ each_col_widths = (longest_x_start [1 :] - longest_x_start [:- 1 ]).tolist ()
109
+ each_col_widths .append (max_x - longest_x_start [- 1 ])
110
+ col_nums = longest_x_start .shape [0 ]
111
+ return longest_x_start , each_col_widths , col_nums
120
112
121
113
def get_benchmark_rows (
122
114
self , rows : Dict [int , List ], polygons : np .ndarray
@@ -160,7 +152,7 @@ def get_merge_cells(
160
152
box_width = self .compute_L2 (box [3 , :], box [0 , :])
161
153
162
154
# 不一定是从0开始的,应该综合已有值和x坐标位置来确定起始位置
163
- loc_col_idx = np .argmin (np .abs (longest_col [:, 0 , 0 ] - box [0 , 0 ]))
155
+ loc_col_idx = np .argmin (np .abs (longest_col - box [0 , 0 ]))
164
156
col_start = max (sum (one_col_result .values ()), loc_col_idx )
165
157
166
158
# 计算合并多少个列方向单元格
0 commit comments