diff --git a/trackml/dataset.py b/trackml/dataset.py index b19cb1a..a0c9559 100644 --- a/trackml/dataset.py +++ b/trackml/dataset.py @@ -3,10 +3,19 @@ __authors__ = ['Moritz Kiehn', 'Sabrina Amrouche'] import glob +import os import os.path as op +import re +import zipfile import pandas +CELLS_DTYPES = dict([ + ('hit_id', 'i4'), + ('ch0', 'i4'), + ('ch1', 'i4'), + ('value', 'f4'), +]) HITS_DTYPES = dict([ ('hit_id', 'i4'), ('x', 'f4'), @@ -16,12 +25,6 @@ ('layer_id', 'i4'), ('module_id', 'i4'), ]) -CELLS_DTYPES = dict([ - ('hit_id', 'i4'), - ('ch0', 'i4'), - ('ch1', 'i4'), - ('value', 'f4'), -]) PARTICLES_DTYPES = dict([ ('particle_id', 'i8'), ('vx', 'f4'), @@ -44,12 +47,20 @@ ('tpz', 'f4'), ('weight', 'f4'), ]) - -def _load_event_data(prefix, name, dtype): +DTYPES = { + 'cells': CELLS_DTYPES, + 'hits': HITS_DTYPES, + 'particles': PARTICLES_DTYPES, + 'truth': TRUTH_DTYPES, +} +DEFAULT_PARTS = ['hits', 'cells', 'particles', 'truth'] + +def _load_event_data(prefix, name): """Load per-event data for one single type, e.g. hits, or particles. """ expr = '{!s}-{}.csv*'.format(prefix, name) files = glob.glob(expr) + dtype = DTYPES[name] if len(files) == 1: return pandas.read_csv(files[0], header=0, index_col=False, dtype=dtype) elif len(files) == 0: @@ -60,30 +71,24 @@ def _load_event_data(prefix, name, dtype): def load_event_hits(prefix): """Load the hits information for a single event with the given prefix. """ - return _load_event_data(prefix, 'hits', HITS_DTYPES) + return _load_event_data(prefix, 'hits') def load_event_cells(prefix): """Load the hit cells information for a single event with the given prefix. """ - return _load_event_data(prefix, 'cells', CELLS_DTYPES) + return _load_event_data(prefix, 'cells') def load_event_particles(prefix): """Load the particles information for a single event with the given prefix. """ - return _load_event_data(prefix, 'particles', PARTICLES_DTYPES) + return _load_event_data(prefix, 'particles') def load_event_truth(prefix): """Load only the truth information for a single event with the given prefix. """ - return _load_event_data(prefix, 'truth', TRUTH_DTYPES) - -_LOAD_FUNCTIONS = { - 'hits': load_event_hits, - 'cells': load_event_cells, - 'particles': load_event_particles, - 'truth': load_event_truth, } + return _load_event_data(prefix, 'truth') -def load_event(prefix, parts=['hits', 'cells', 'particles', 'truth']): +def load_event(prefix, parts=DEFAULT_PARTS): """Load data for a single event with the given prefix. Parameters @@ -100,15 +105,15 @@ def load_event(prefix, parts=['hits', 'cells', 'particles', 'truth']): element has field names identical to the CSV column names with appropriate types. """ - return tuple(_LOAD_FUNCTIONS[_](prefix) for _ in parts) + return tuple(_load_event_data(prefix, name) for name in parts) -def load_dataset(path, skip=None, nevents=None, **kw): - """Provide an iterator over (all) events in a dataset directory. +def load_dataset(path, skip=None, nevents=None, parts=DEFAULT_PARTS): + """Provide an iterator over (all) events in a dataset. Parameters ---------- path : str or pathlib.Path - Path to the dataset directory. + Path to a directory or a zip file containing event files. skip : int, optional Skip the first `skip` events. nevents : int, optional @@ -123,13 +128,44 @@ def load_dataset(path, skip=None, nevents=None, **kw): *data Event data element as specified in `parts`. """ - files = glob.glob(op.join(path, 'event*-*')) - names = set(op.basename(_).split('-', 1)[0] for _ in files) - names = sorted(names) - if skip is not None: - names = names[skip:] - if nevents is not None: - names = names[:nevents] - for name in names: - event_id = int(name[5:]) - yield (event_id,) + load_event(op.join(path, name), **kw) + # extract a sorted list of event file prefixes. + def list_prefixes(files): + regex = re.compile('^event\d{9}-[a-zA-Z]+.csv') + files = filter(regex.match, files) + prefixes = set(op.basename(_).split('-', 1)[0] for _ in files) + prefixes = sorted(prefixes) + if skip is not None: + prefixes = prefixes[skip:] + if nevents is not None: + prefixes = prefixes[:nevents] + return prefixes + + # TODO use yield from when we increase the python requirement + if op.isdir(path): + for x in _iter_dataset_dir(path, list_prefixes(os.listdir(path)), parts): + yield x + else: + with zipfile.ZipFile(path, mode='r') as z: + for x in _iter_dataset_zip(z, list_prefixes(z.namelist()), parts): + yield x + +def _extract_event_id(prefix): + """Extract event_id from prefix, e.g. event_id=1 from `event000000001`. + """ + return int(prefix[5:]) + +def _iter_dataset_dir(directory, prefixes, parts): + """Iterate over selected events files inside a directory. + """ + for p in prefixes: + yield (_extract_event_id(p),) + load_event(op.join(directory, p), parts) + +def _iter_dataset_zip(zipfile, prefixes, parts): + """"Iterate over selected event files inside a zip archive. + """ + for p in prefixes: + files = [zipfile.open('{}-{}.csv'.format(p, _), mode='r') for _ in parts] + dtypes = [DTYPES[_] for _ in parts] + data = tuple(pandas.read_csv(f, header=0, index_col=False, dtype=d) + for f, d in zip(files, dtypes)) + yield (_extract_event_id(p),) + data