This repository has been archived by the owner on Feb 15, 2021. It is now read-only.
forked from mlcommons/inference
-
Notifications
You must be signed in to change notification settings - Fork 2
/
QSL.py
68 lines (55 loc) · 2.48 KB
/
QSL.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import sys
import os
sys.path.insert(0, os.path.join(os.getcwd(), "pytorch"))
from parts.manifest import Manifest
from parts.segment import AudioSegment
import numpy as np
import mlperf_loadgen as lg
class AudioQSL:
    """Query Sample Library that serves decoded audio waveforms to LoadGen.

    The samples are the entries of a manifest file. LoadGen chooses which ones
    are resident in memory through the ``load_query_samples`` /
    ``unload_query_samples`` callbacks registered in ``__init__``. Loaded
    waveforms are read back with ``qsl[sample_id]`` (``__getitem__``).
    """

    def __init__(self, dataset_dir, manifest_filepath, labels,
                 sample_rate=16000, perf_count=None):
        """Parse the manifest and register this QSL with LoadGen.

        Args:
            dataset_dir: Directory handed to ``Manifest``; presumably the
                audio paths in the manifest are resolved against it (TODO
                confirm in parts.manifest).
            manifest_filepath: Path of the single manifest file to read.
            labels: Label vocabulary handed to ``Manifest``; its length is
                passed as the label count.
            sample_rate: Target sample rate passed to
                ``AudioSegment.from_file`` when decoding each sample.
            perf_count: Performance sample count passed to
                ``lg.ConstructQSL``; defaults to the total number of samples.
        """
        # NOTE(review): max_duration=15.0 presumably drops longer utterances;
        # the summary below reports manifest.filtered_duration -- confirm.
        self.manifest = Manifest(dataset_dir, [manifest_filepath], labels,
                                 len(labels), normalize=True,
                                 max_duration=15.0)
        self.sample_rate = sample_rate
        self.count = len(self.manifest)
        perf_count = self.count if perf_count is None else perf_count
        # sample_id -> decoded float32 waveform, for currently loaded samples.
        self.sample_id_to_sample = {}
        self.qsl = lg.ConstructQSL(self.count, perf_count,
                                   self.load_query_samples,
                                   self.unload_query_samples)
        print(
            "Dataset loaded with {0:.2f} hours. Filtered {1:.2f} hours. Number of samples: {2}".format(
                self.manifest.duration / 3600,
                self.manifest.filtered_duration / 3600,
                self.count))

    def load_query_samples(self, sample_list):
        """Decode every sample id in ``sample_list`` and keep it in memory."""
        for sample_id in sample_list:
            self.sample_id_to_sample[sample_id] = self._load_sample(sample_id)

    def unload_query_samples(self, sample_list):
        """Drop the cached waveforms for ``sample_list``.

        Raises:
            KeyError: if a sample id is not currently loaded.
        """
        for sample_id in sample_list:
            del self.sample_id_to_sample[sample_id]

    def _load_sample(self, index):
        """Decode manifest entry ``index`` into a float32 numpy waveform."""
        sample = self.manifest[index]
        # Only the first entry of 'audio_filepath' is decoded; presumably each
        # manifest entry holds a list of paths (verify against parts.manifest).
        segment = AudioSegment.from_file(sample['audio_filepath'][0],
                                         target_sr=self.sample_rate)
        waveform = segment.samples
        # Sanity check on the decoder's output (note: stripped under `python -O`).
        assert isinstance(waveform, np.ndarray) and waveform.dtype == np.float32
        return waveform

    def __getitem__(self, index):
        """Return the waveform for a loaded sample id (KeyError if not loaded)."""
        return self.sample_id_to_sample[index]

    def __del__(self):
        """Destroy the LoadGen QSL handle, if one was ever created.

        If ``__init__`` raises before ``self.qsl`` is assigned, Python still
        finalizes the partially-built object. The guard avoids an
        ``AttributeError`` in that case. The handle is cleared before it is
        destroyed, so a repeated call never destroys it twice.
        """
        qsl = getattr(self, "qsl", None)
        if qsl is None:
            return
        self.qsl = None
        lg.DestroyQSL(qsl)
        print("Finished destroying QSL.")
class AudioQSLInMemory(AudioQSL):
    """AudioQSL variant that decodes the entire dataset up front.

    The whole dataset fits in memory, so every sample is loaded once during
    construction to speed up the benchmark. LoadGen's load/unload callbacks
    are then no-ops.
    """

    def __init__(self, dataset_dir, manifest_filepath, labels,
                 sample_rate=16000, perf_count=None):
        AudioQSL.__init__(self, dataset_dir, manifest_filepath, labels,
                          sample_rate, perf_count)
        # Use the base-class loader explicitly: the override below does
        # nothing, and every index 0..count-1 must end up cached.
        all_ids = range(self.count)
        AudioQSL.load_query_samples(self, all_ids)

    def load_query_samples(self, sample_list):
        """Do nothing: every sample was already loaded in ``__init__``."""

    def unload_query_samples(self, sample_list):
        """Do nothing: loaded samples are kept for the object's lifetime."""