Skip to content

Commit ca40df7

Browse files
authored
Auto formatting (#34)
- Use black: `python -m black .` - [Update flake8 rules](https://black.readthedocs.io/en/stable/the_black_code_style/current_style.html#flake8)
1 parent 86d8f67 commit ca40df7

File tree

11 files changed

+529
-406
lines changed

11 files changed

+529
-406
lines changed

.flake8

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
[flake8]
2+
max-line-length = 88
3+
extend-ignore = E203

.github/workflows/black.yml

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
name: Black Formatting Check
2+
on:
3+
push:
4+
branches:
5+
- main
6+
pull_request:
7+
types: [ assigned, opened, synchronize, reopened ]
8+
jobs:
9+
formatting-check:
10+
name: Formatting Check
11+
runs-on: ubuntu-latest
12+
steps:
13+
- uses: actions/checkout@v3
14+
- uses: psf/black@stable
15+
with:
16+
jupyter: true

benchmark.py

Lines changed: 125 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
# Benchmark script for LightGlue on real images
32
from pathlib import Path
43
import argparse
@@ -15,9 +14,9 @@
1514
torch.set_grad_enabled(False)
1615

1716

18-
def measure(matcher, data, device='cuda', r=100):
17+
def measure(matcher, data, device="cuda", r=100):
1918
timings = np.zeros((r, 1))
20-
if device.type == 'cuda':
19+
if device.type == "cuda":
2120
starter = torch.cuda.Event(enable_timing=True)
2221
ender = torch.cuda.Event(enable_timing=True)
2322
# warmup
@@ -26,7 +25,7 @@ def measure(matcher, data, device='cuda', r=100):
2625
# measurements
2726
with torch.no_grad():
2827
for rep in range(r):
29-
if device.type == 'cuda':
28+
if device.type == "cuda":
3029
starter.record()
3130
_ = matcher(data)
3231
ender.record()
@@ -40,77 +39,99 @@ def measure(matcher, data, device='cuda', r=100):
4039
timings[rep] = curr_time
4140
mean_syn = np.sum(timings) / r
4241
std_syn = np.std(timings)
43-
return {'mean': mean_syn, 'std': std_syn}
42+
return {"mean": mean_syn, "std": std_syn}
4443

4544

4645
def print_as_table(d, title, cnames):
4746
print()
48-
header = f'{title:30} '+' '.join([f'{x:>7}' for x in cnames])
47+
header = f"{title:30} " + " ".join([f"{x:>7}" for x in cnames])
4948
print(header)
50-
print('-'*len(header))
49+
print("-" * len(header))
5150
for k, l in d.items():
52-
print(f'{k:30}', ' '.join([f'{x:>7.1f}' for x in l]))
53-
54-
55-
if __name__ == '__main__':
56-
parser = argparse.ArgumentParser(description='Benchmark script for LightGlue')
57-
parser.add_argument('--device', choices=['auto', 'cuda', 'cpu', 'mps'],
58-
default='auto', help='device to benchmark on')
59-
parser.add_argument('--compile', action='store_true',
60-
help='Compile LightGlue runs')
61-
parser.add_argument('--no_flash', action='store_true',
62-
help='disable FlashAttention')
63-
parser.add_argument('--no_prune_thresholds', action='store_true',
64-
help='disable pruning thresholds (i.e. always do pruning)')
65-
parser.add_argument('--add_superglue', action='store_true',
66-
help='add SuperGlue to the benchmark (requires hloc)')
67-
parser.add_argument('--measure', default='time',
68-
choices=['time', 'log-time', 'throughput'])
69-
parser.add_argument('--repeat', '--r', type=int, default=100,
70-
help='repetitions of measurements')
71-
parser.add_argument('--num_keypoints', nargs="+", type=int,
72-
default=[256, 512, 1024, 2048, 4096],
73-
help='number of keypoints (list separated by spaces)')
74-
parser.add_argument('--matmul_precision', default='highest',
75-
choices=['highest', 'high', 'medium'])
76-
parser.add_argument('--save', default=None, type=str,
77-
help='path where figure should be saved')
51+
print(f"{k:30}", " ".join([f"{x:>7.1f}" for x in l]))
52+
53+
54+
if __name__ == "__main__":
55+
parser = argparse.ArgumentParser(description="Benchmark script for LightGlue")
56+
parser.add_argument(
57+
"--device",
58+
choices=["auto", "cuda", "cpu", "mps"],
59+
default="auto",
60+
help="device to benchmark on",
61+
)
62+
parser.add_argument("--compile", action="store_true", help="Compile LightGlue runs")
63+
parser.add_argument(
64+
"--no_flash", action="store_true", help="disable FlashAttention"
65+
)
66+
parser.add_argument(
67+
"--no_prune_thresholds",
68+
action="store_true",
69+
help="disable pruning thresholds (i.e. always do pruning)",
70+
)
71+
parser.add_argument(
72+
"--add_superglue",
73+
action="store_true",
74+
help="add SuperGlue to the benchmark (requires hloc)",
75+
)
76+
parser.add_argument(
77+
"--measure", default="time", choices=["time", "log-time", "throughput"]
78+
)
79+
parser.add_argument(
80+
"--repeat", "--r", type=int, default=100, help="repetitions of measurements"
81+
)
82+
parser.add_argument(
83+
"--num_keypoints",
84+
nargs="+",
85+
type=int,
86+
default=[256, 512, 1024, 2048, 4096],
87+
help="number of keypoints (list separated by spaces)",
88+
)
89+
parser.add_argument(
90+
"--matmul_precision", default="highest", choices=["highest", "high", "medium"]
91+
)
92+
parser.add_argument(
93+
"--save", default=None, type=str, help="path where figure should be saved"
94+
)
7895
args = parser.parse_intermixed_args()
7996

80-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
81-
if args.device != 'auto':
97+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
98+
if args.device != "auto":
8299
device = torch.device(args.device)
83100

84-
print('Running benchmark on device:', device)
101+
print("Running benchmark on device:", device)
85102

86-
images = Path('assets')
103+
images = Path("assets")
87104
inputs = {
88-
'easy': (load_image(images / 'DSC_0411.JPG'),
89-
load_image(images / 'DSC_0410.JPG')),
90-
'difficult': (load_image(images / 'sacre_coeur1.jpg'),
91-
load_image(images / 'sacre_coeur2.jpg')),
105+
"easy": (
106+
load_image(images / "DSC_0411.JPG"),
107+
load_image(images / "DSC_0410.JPG"),
108+
),
109+
"difficult": (
110+
load_image(images / "sacre_coeur1.jpg"),
111+
load_image(images / "sacre_coeur2.jpg"),
112+
),
92113
}
93114

94115
configs = {
95-
'LightGlue-full': {
96-
'depth_confidence': -1,
97-
'width_confidence': -1,
116+
"LightGlue-full": {
117+
"depth_confidence": -1,
118+
"width_confidence": -1,
98119
},
99120
# 'LG-prune': {
100121
# 'width_confidence': -1,
101122
# },
102123
# 'LG-depth': {
103124
# 'depth_confidence': -1,
104125
# },
105-
'LightGlue-adaptive': {}
126+
"LightGlue-adaptive": {},
106127
}
107128

108129
if args.compile:
109-
configs = {**configs, **{k+'-compile': v for k, v in configs.items()}}
130+
configs = {**configs, **{k + "-compile": v for k, v in configs.items()}}
110131

111132
sg_configs = {
112133
# 'SuperGlue': {},
113-
'SuperGlue-fast': {'sinkhorn_iterations': 5}
134+
"SuperGlue-fast": {"sinkhorn_iterations": 5}
114135
}
115136

116137
torch.set_float32_matmul_precision(args.matmul_precision)
@@ -119,89 +140,108 @@ def print_as_table(d, title, cnames):
119140

120141
extractor = SuperPoint(max_num_keypoints=None, detection_threshold=-1)
121142
extractor = extractor.eval().to(device)
122-
figsize = (len(inputs)*4.5, 4.5)
143+
figsize = (len(inputs) * 4.5, 4.5)
123144
fig, axes = plt.subplots(1, len(inputs), sharey=True, figsize=figsize)
124145
axes = axes if len(inputs) > 1 else [axes]
125-
fig.canvas.manager.set_window_title(f'LightGlue benchmark ({device.type})')
146+
fig.canvas.manager.set_window_title(f"LightGlue benchmark ({device.type})")
126147

127148
for title, ax in zip(inputs.keys(), axes):
128-
ax.set_xscale('log', base=2)
149+
ax.set_xscale("log", base=2)
129150
bases = [2**x for x in range(7, 16)]
130151
ax.set_xticks(bases, bases)
131-
ax.grid(which='major')
132-
if args.measure == 'log-time':
133-
ax.set_yscale('log')
152+
ax.grid(which="major")
153+
if args.measure == "log-time":
154+
ax.set_yscale("log")
134155
yticks = [10**x for x in range(6)]
135156
ax.set_yticks(yticks, yticks)
136157
mpos = [10**x * i for x in range(6) for i in range(2, 10)]
137-
mlabel = [10**x * i if i in [2, 5] else None for x in range(6) for i in range(2, 10)]
158+
mlabel = [
159+
10**x * i if i in [2, 5] else None
160+
for x in range(6)
161+
for i in range(2, 10)
162+
]
138163
ax.set_yticks(mpos, mlabel, minor=True)
139-
ax.grid(which='minor', linewidth=0.2)
164+
ax.grid(which="minor", linewidth=0.2)
140165
ax.set_title(title)
141166

142167
ax.set_xlabel("# keypoints")
143-
if args.measure == 'throughput':
144-
ax.set_ylabel("Throughput [pairs/s]")
168+
if args.measure == "throughput":
169+
ax.set_ylabel("Throughput [pairs/s]")
145170
else:
146171
ax.set_ylabel("Latency [ms]")
147172

148173
for name, conf in configs.items():
149-
print('Run benchmark for:', name)
174+
print("Run benchmark for:", name)
150175
torch.cuda.empty_cache()
151-
matcher = LightGlue(
152-
features='superpoint', flash=not args.no_flash, **conf)
176+
matcher = LightGlue(features="superpoint", flash=not args.no_flash, **conf)
153177
if args.no_prune_thresholds:
154178
matcher.pruning_keypoint_thresholds = {
155-
k: -1 for k in matcher.pruning_keypoint_thresholds}
179+
k: -1 for k in matcher.pruning_keypoint_thresholds
180+
}
156181
matcher = matcher.eval().to(device)
157-
if name.endswith('compile'):
182+
if name.endswith("compile"):
158183
import torch._dynamo
184+
159185
torch._dynamo.reset() # avoid buffer overflow
160186
matcher.compile()
161-
for (pair_name, ax) in zip(inputs.keys(), axes):
187+
for pair_name, ax in zip(inputs.keys(), axes):
162188
image0, image1 = [x.to(device) for x in inputs[pair_name]]
163189
runtimes = []
164190
for num_kpts in args.num_keypoints:
165-
extractor.conf['max_num_keypoints'] = num_kpts
191+
extractor.conf["max_num_keypoints"] = num_kpts
166192
feats0 = extractor.extract(image0)
167193
feats1 = extractor.extract(image1)
168-
runtime = measure(matcher,
169-
{'image0': feats0, 'image1': feats1},
170-
device=device, r=args.repeat)['mean']
194+
runtime = measure(
195+
matcher,
196+
{"image0": feats0, "image1": feats1},
197+
device=device,
198+
r=args.repeat,
199+
)["mean"]
171200
results[pair_name][name].append(
172-
1000/runtime if args.measure == 'throughput' else runtime)
173-
ax.plot(args.num_keypoints, results[pair_name][name], label=name,
174-
marker='o')
201+
1000 / runtime if args.measure == "throughput" else runtime
202+
)
203+
ax.plot(
204+
args.num_keypoints, results[pair_name][name], label=name, marker="o"
205+
)
175206
del matcher, feats0, feats1
176207

177208
if args.add_superglue:
178209
from hloc.matchers.superglue import SuperGlue
210+
179211
for name, conf in sg_configs.items():
180-
print('Run benchmark for:', name)
212+
print("Run benchmark for:", name)
181213
matcher = SuperGlue(conf)
182214
matcher = matcher.eval().to(device)
183-
for (pair_name, ax) in zip(inputs.keys(), axes):
215+
for pair_name, ax in zip(inputs.keys(), axes):
184216
image0, image1 = [x.to(device) for x in inputs[pair_name]]
185217
runtimes = []
186218
for num_kpts in args.num_keypoints:
187-
extractor.conf['max_num_keypoints'] = num_kpts
219+
extractor.conf["max_num_keypoints"] = num_kpts
188220
feats0 = extractor.extract(image0)
189221
feats1 = extractor.extract(image1)
190222
data = {
191-
'image0': image0[None],
192-
'image1': image1[None],
193-
**{k+'0': v for k, v in feats0.items()},
194-
**{k+'1': v for k, v in feats1.items()}
223+
"image0": image0[None],
224+
"image1": image1[None],
225+
**{k + "0": v for k, v in feats0.items()},
226+
**{k + "1": v for k, v in feats1.items()},
195227
}
196-
data['scores0'] = data['keypoint_scores0']
197-
data['scores1'] = data['keypoint_scores1']
198-
data['descriptors0'] = data['descriptors0'].transpose(-1, -2).contiguous()
199-
data['descriptors1'] = data['descriptors1'].transpose(-1, -2).contiguous()
200-
runtime = measure(matcher, data, device=device, r=args.repeat)['mean']
228+
data["scores0"] = data["keypoint_scores0"]
229+
data["scores1"] = data["keypoint_scores1"]
230+
data["descriptors0"] = (
231+
data["descriptors0"].transpose(-1, -2).contiguous()
232+
)
233+
data["descriptors1"] = (
234+
data["descriptors1"].transpose(-1, -2).contiguous()
235+
)
236+
runtime = measure(matcher, data, device=device, r=args.repeat)[
237+
"mean"
238+
]
201239
results[pair_name][name].append(
202-
1000/runtime if args.measure == 'throughput' else runtime)
203-
ax.plot(args.num_keypoints, results[pair_name][name], label=name,
204-
marker='o')
240+
1000 / runtime if args.measure == "throughput" else runtime
241+
)
242+
ax.plot(
243+
args.num_keypoints, results[pair_name][name], label=name, marker="o"
244+
)
205245
del matcher, data, image0, image1, feats0, feats1
206246

207247
for name, runtimes in results.items():

0 commit comments

Comments
 (0)