1-
21# Benchmark script for LightGlue on real images
32from pathlib import Path
43import argparse
1514torch .set_grad_enabled (False )
1615
1716
18- def measure (matcher , data , device = ' cuda' , r = 100 ):
17+ def measure (matcher , data , device = " cuda" , r = 100 ):
1918 timings = np .zeros ((r , 1 ))
20- if device .type == ' cuda' :
19+ if device .type == " cuda" :
2120 starter = torch .cuda .Event (enable_timing = True )
2221 ender = torch .cuda .Event (enable_timing = True )
2322 # warmup
@@ -26,7 +25,7 @@ def measure(matcher, data, device='cuda', r=100):
2625 # measurements
2726 with torch .no_grad ():
2827 for rep in range (r ):
29- if device .type == ' cuda' :
28+ if device .type == " cuda" :
3029 starter .record ()
3130 _ = matcher (data )
3231 ender .record ()
@@ -40,77 +39,99 @@ def measure(matcher, data, device='cuda', r=100):
4039 timings [rep ] = curr_time
4140 mean_syn = np .sum (timings ) / r
4241 std_syn = np .std (timings )
43- return {' mean' : mean_syn , ' std' : std_syn }
42+ return {" mean" : mean_syn , " std" : std_syn }
4443
4544
4645def print_as_table (d , title , cnames ):
4746 print ()
48- header = f' { title :30} ' + ' ' .join ([f' { x :>7} ' for x in cnames ])
47+ header = f" { title :30} " + " " .join ([f" { x :>7} " for x in cnames ])
4948 print (header )
50- print ('-' * len (header ))
49+ print ("-" * len (header ))
5150 for k , l in d .items ():
52- print (f'{ k :30} ' , ' ' .join ([f'{ x :>7.1f} ' for x in l ]))
53-
54-
55- if __name__ == '__main__' :
56- parser = argparse .ArgumentParser (description = 'Benchmark script for LightGlue' )
57- parser .add_argument ('--device' , choices = ['auto' , 'cuda' , 'cpu' , 'mps' ],
58- default = 'auto' , help = 'device to benchmark on' )
59- parser .add_argument ('--compile' , action = 'store_true' ,
60- help = 'Compile LightGlue runs' )
61- parser .add_argument ('--no_flash' , action = 'store_true' ,
62- help = 'disable FlashAttention' )
63- parser .add_argument ('--no_prune_thresholds' , action = 'store_true' ,
64- help = 'disable pruning thresholds (i.e. always do pruning)' )
65- parser .add_argument ('--add_superglue' , action = 'store_true' ,
66- help = 'add SuperGlue to the benchmark (requires hloc)' )
67- parser .add_argument ('--measure' , default = 'time' ,
68- choices = ['time' , 'log-time' , 'throughput' ])
69- parser .add_argument ('--repeat' , '--r' , type = int , default = 100 ,
70- help = 'repetitions of measurements' )
71- parser .add_argument ('--num_keypoints' , nargs = "+" , type = int ,
72- default = [256 , 512 , 1024 , 2048 , 4096 ],
73- help = 'number of keypoints (list separated by spaces)' )
74- parser .add_argument ('--matmul_precision' , default = 'highest' ,
75- choices = ['highest' , 'high' , 'medium' ])
76- parser .add_argument ('--save' , default = None , type = str ,
77- help = 'path where figure should be saved' )
51+ print (f"{ k :30} " , " " .join ([f"{ x :>7.1f} " for x in l ]))
52+
53+
54+ if __name__ == "__main__" :
55+ parser = argparse .ArgumentParser (description = "Benchmark script for LightGlue" )
56+ parser .add_argument (
57+ "--device" ,
58+ choices = ["auto" , "cuda" , "cpu" , "mps" ],
59+ default = "auto" ,
60+ help = "device to benchmark on" ,
61+ )
62+ parser .add_argument ("--compile" , action = "store_true" , help = "Compile LightGlue runs" )
63+ parser .add_argument (
64+ "--no_flash" , action = "store_true" , help = "disable FlashAttention"
65+ )
66+ parser .add_argument (
67+ "--no_prune_thresholds" ,
68+ action = "store_true" ,
69+ help = "disable pruning thresholds (i.e. always do pruning)" ,
70+ )
71+ parser .add_argument (
72+ "--add_superglue" ,
73+ action = "store_true" ,
74+ help = "add SuperGlue to the benchmark (requires hloc)" ,
75+ )
76+ parser .add_argument (
77+ "--measure" , default = "time" , choices = ["time" , "log-time" , "throughput" ]
78+ )
79+ parser .add_argument (
80+ "--repeat" , "--r" , type = int , default = 100 , help = "repetitions of measurements"
81+ )
82+ parser .add_argument (
83+ "--num_keypoints" ,
84+ nargs = "+" ,
85+ type = int ,
86+ default = [256 , 512 , 1024 , 2048 , 4096 ],
87+ help = "number of keypoints (list separated by spaces)" ,
88+ )
89+ parser .add_argument (
90+ "--matmul_precision" , default = "highest" , choices = ["highest" , "high" , "medium" ]
91+ )
92+ parser .add_argument (
93+ "--save" , default = None , type = str , help = "path where figure should be saved"
94+ )
7895 args = parser .parse_intermixed_args ()
7996
80- device = torch .device (' cuda' if torch .cuda .is_available () else ' cpu' )
81- if args .device != ' auto' :
97+ device = torch .device (" cuda" if torch .cuda .is_available () else " cpu" )
98+ if args .device != " auto" :
8299 device = torch .device (args .device )
83100
84- print (' Running benchmark on device:' , device )
101+ print (" Running benchmark on device:" , device )
85102
86- images = Path (' assets' )
103+ images = Path (" assets" )
87104 inputs = {
88- 'easy' : (load_image (images / 'DSC_0411.JPG' ),
89- load_image (images / 'DSC_0410.JPG' )),
90- 'difficult' : (load_image (images / 'sacre_coeur1.jpg' ),
91- load_image (images / 'sacre_coeur2.jpg' )),
105+ "easy" : (
106+ load_image (images / "DSC_0411.JPG" ),
107+ load_image (images / "DSC_0410.JPG" ),
108+ ),
109+ "difficult" : (
110+ load_image (images / "sacre_coeur1.jpg" ),
111+ load_image (images / "sacre_coeur2.jpg" ),
112+ ),
92113 }
93114
94115 configs = {
95- ' LightGlue-full' : {
96- ' depth_confidence' : - 1 ,
97- ' width_confidence' : - 1 ,
116+ " LightGlue-full" : {
117+ " depth_confidence" : - 1 ,
118+ " width_confidence" : - 1 ,
98119 },
99120 # 'LG-prune': {
100121 # 'width_confidence': -1,
101122 # },
102123 # 'LG-depth': {
103124 # 'depth_confidence': -1,
104125 # },
105- ' LightGlue-adaptive' : {}
126+ " LightGlue-adaptive" : {},
106127 }
107128
108129 if args .compile :
109- configs = {** configs , ** {k + ' -compile' : v for k , v in configs .items ()}}
130+ configs = {** configs , ** {k + " -compile" : v for k , v in configs .items ()}}
110131
111132 sg_configs = {
112133 # 'SuperGlue': {},
113- ' SuperGlue-fast' : {' sinkhorn_iterations' : 5 }
134+ " SuperGlue-fast" : {" sinkhorn_iterations" : 5 }
114135 }
115136
116137 torch .set_float32_matmul_precision (args .matmul_precision )
@@ -119,89 +140,108 @@ def print_as_table(d, title, cnames):
119140
120141 extractor = SuperPoint (max_num_keypoints = None , detection_threshold = - 1 )
121142 extractor = extractor .eval ().to (device )
122- figsize = (len (inputs )* 4.5 , 4.5 )
143+ figsize = (len (inputs ) * 4.5 , 4.5 )
123144 fig , axes = plt .subplots (1 , len (inputs ), sharey = True , figsize = figsize )
124145 axes = axes if len (inputs ) > 1 else [axes ]
125- fig .canvas .manager .set_window_title (f' LightGlue benchmark ({ device .type } )' )
146+ fig .canvas .manager .set_window_title (f" LightGlue benchmark ({ device .type } )" )
126147
127148 for title , ax in zip (inputs .keys (), axes ):
128- ax .set_xscale (' log' , base = 2 )
149+ ax .set_xscale (" log" , base = 2 )
129150 bases = [2 ** x for x in range (7 , 16 )]
130151 ax .set_xticks (bases , bases )
131- ax .grid (which = ' major' )
132- if args .measure == ' log-time' :
133- ax .set_yscale (' log' )
152+ ax .grid (which = " major" )
153+ if args .measure == " log-time" :
154+ ax .set_yscale (" log" )
134155 yticks = [10 ** x for x in range (6 )]
135156 ax .set_yticks (yticks , yticks )
136157 mpos = [10 ** x * i for x in range (6 ) for i in range (2 , 10 )]
137- mlabel = [10 ** x * i if i in [2 , 5 ] else None for x in range (6 ) for i in range (2 , 10 )]
158+ mlabel = [
159+ 10 ** x * i if i in [2 , 5 ] else None
160+ for x in range (6 )
161+ for i in range (2 , 10 )
162+ ]
138163 ax .set_yticks (mpos , mlabel , minor = True )
139- ax .grid (which = ' minor' , linewidth = 0.2 )
164+ ax .grid (which = " minor" , linewidth = 0.2 )
140165 ax .set_title (title )
141166
142167 ax .set_xlabel ("# keypoints" )
143- if args .measure == ' throughput' :
144- ax .set_ylabel ("Throughput [pairs/s]" )
168+ if args .measure == " throughput" :
169+ ax .set_ylabel ("Throughput [pairs/s]" )
145170 else :
146171 ax .set_ylabel ("Latency [ms]" )
147172
148173 for name , conf in configs .items ():
149- print (' Run benchmark for:' , name )
174+ print (" Run benchmark for:" , name )
150175 torch .cuda .empty_cache ()
151- matcher = LightGlue (
152- features = 'superpoint' , flash = not args .no_flash , ** conf )
176+ matcher = LightGlue (features = "superpoint" , flash = not args .no_flash , ** conf )
153177 if args .no_prune_thresholds :
154178 matcher .pruning_keypoint_thresholds = {
155- k : - 1 for k in matcher .pruning_keypoint_thresholds }
179+ k : - 1 for k in matcher .pruning_keypoint_thresholds
180+ }
156181 matcher = matcher .eval ().to (device )
157- if name .endswith (' compile' ):
182+ if name .endswith (" compile" ):
158183 import torch ._dynamo
184+
159185 torch ._dynamo .reset () # avoid buffer overflow
160186 matcher .compile ()
161- for ( pair_name , ax ) in zip (inputs .keys (), axes ):
187+ for pair_name , ax in zip (inputs .keys (), axes ):
162188 image0 , image1 = [x .to (device ) for x in inputs [pair_name ]]
163189 runtimes = []
164190 for num_kpts in args .num_keypoints :
165- extractor .conf [' max_num_keypoints' ] = num_kpts
191+ extractor .conf [" max_num_keypoints" ] = num_kpts
166192 feats0 = extractor .extract (image0 )
167193 feats1 = extractor .extract (image1 )
168- runtime = measure (matcher ,
169- {'image0' : feats0 , 'image1' : feats1 },
170- device = device , r = args .repeat )['mean' ]
194+ runtime = measure (
195+ matcher ,
196+ {"image0" : feats0 , "image1" : feats1 },
197+ device = device ,
198+ r = args .repeat ,
199+ )["mean" ]
171200 results [pair_name ][name ].append (
172- 1000 / runtime if args .measure == 'throughput' else runtime )
173- ax .plot (args .num_keypoints , results [pair_name ][name ], label = name ,
174- marker = 'o' )
201+ 1000 / runtime if args .measure == "throughput" else runtime
202+ )
203+ ax .plot (
204+ args .num_keypoints , results [pair_name ][name ], label = name , marker = "o"
205+ )
175206 del matcher , feats0 , feats1
176207
177208 if args .add_superglue :
178209 from hloc .matchers .superglue import SuperGlue
210+
179211 for name , conf in sg_configs .items ():
180- print (' Run benchmark for:' , name )
212+ print (" Run benchmark for:" , name )
181213 matcher = SuperGlue (conf )
182214 matcher = matcher .eval ().to (device )
183- for ( pair_name , ax ) in zip (inputs .keys (), axes ):
215+ for pair_name , ax in zip (inputs .keys (), axes ):
184216 image0 , image1 = [x .to (device ) for x in inputs [pair_name ]]
185217 runtimes = []
186218 for num_kpts in args .num_keypoints :
187- extractor .conf [' max_num_keypoints' ] = num_kpts
219+ extractor .conf [" max_num_keypoints" ] = num_kpts
188220 feats0 = extractor .extract (image0 )
189221 feats1 = extractor .extract (image1 )
190222 data = {
191- ' image0' : image0 [None ],
192- ' image1' : image1 [None ],
193- ** {k + '0' : v for k , v in feats0 .items ()},
194- ** {k + '1' : v for k , v in feats1 .items ()}
223+ " image0" : image0 [None ],
224+ " image1" : image1 [None ],
225+ ** {k + "0" : v for k , v in feats0 .items ()},
226+ ** {k + "1" : v for k , v in feats1 .items ()},
195227 }
196- data ['scores0' ] = data ['keypoint_scores0' ]
197- data ['scores1' ] = data ['keypoint_scores1' ]
198- data ['descriptors0' ] = data ['descriptors0' ].transpose (- 1 , - 2 ).contiguous ()
199- data ['descriptors1' ] = data ['descriptors1' ].transpose (- 1 , - 2 ).contiguous ()
200- runtime = measure (matcher , data , device = device , r = args .repeat )['mean' ]
228+ data ["scores0" ] = data ["keypoint_scores0" ]
229+ data ["scores1" ] = data ["keypoint_scores1" ]
230+ data ["descriptors0" ] = (
231+ data ["descriptors0" ].transpose (- 1 , - 2 ).contiguous ()
232+ )
233+ data ["descriptors1" ] = (
234+ data ["descriptors1" ].transpose (- 1 , - 2 ).contiguous ()
235+ )
236+ runtime = measure (matcher , data , device = device , r = args .repeat )[
237+ "mean"
238+ ]
201239 results [pair_name ][name ].append (
202- 1000 / runtime if args .measure == 'throughput' else runtime )
203- ax .plot (args .num_keypoints , results [pair_name ][name ], label = name ,
204- marker = 'o' )
240+ 1000 / runtime if args .measure == "throughput" else runtime
241+ )
242+ ax .plot (
243+ args .num_keypoints , results [pair_name ][name ], label = name , marker = "o"
244+ )
205245 del matcher , data , image0 , image1 , feats0 , feats1
206246
207247 for name , runtimes in results .items ():
0 commit comments