dataset: support for directly loading dataset zip files
msmk0 committed Apr 23, 2018
1 parent 75a3c03 commit efaf163
Showing 1 changed file with 49 additions and 14 deletions.
trackml/dataset.py: 63 changes (49 additions & 14 deletions)
@@ -3,7 +3,10 @@
__authors__ = ['Moritz Kiehn', 'Sabrina Amrouche']

import glob
import os
import os.path as op
import re
import zipfile

import pandas

@@ -50,6 +53,7 @@
    'particles': PARTICLES_DTYPES,
    'truth': TRUTH_DTYPES,
}
DEFAULT_PARTS = ['hits', 'cells', 'particles', 'truth']

def _load_event_data(prefix, name):
"""Load per-event data for one single type, e.g. hits, or particles.
@@ -84,7 +88,7 @@ def load_event_truth(prefix):
"""
return _load_event_data(prefix, 'truth')

def load_event(prefix, parts=['hits', 'cells', 'particles', 'truth']):
def load_event(prefix, parts=DEFAULT_PARTS):
"""Load data for a single event with the given prefix.
Parameters
@@ -103,13 +107,13 @@ def load_event(prefix, parts=['hits', 'cells', 'particles', 'truth']):
"""
return tuple(_load_event_data(prefix, name) for name in parts)

def load_dataset(path, skip=None, nevents=None, **kw):
"""Provide an iterator over (all) events in a dataset directory.
def load_dataset(path, skip=None, nevents=None, parts=DEFAULT_PARTS):
"""Provide an iterator over (all) events in a dataset.
Parameters
----------
path : str or pathlib.Path
Path to the dataset directory.
Path to a directory or a zip file containing event files.
skip : int, optional
Skip the first `skip` events.
nevents : int, optional
@@ -124,13 +128,44 @@ def load_dataset(path, skip=None, nevents=None, **kw):
    *data
        Event data element as specified in `parts`.
    """
    files = glob.glob(op.join(path, 'event*-*'))
    names = set(op.basename(_).split('-', 1)[0] for _ in files)
    names = sorted(names)
    if skip is not None:
        names = names[skip:]
    if nevents is not None:
        names = names[:nevents]
    for name in names:
        event_id = int(name[5:])
        yield (event_id,) + load_event(op.join(path, name), **kw)
    # extract a sorted list of event file prefixes.
    def list_prefixes(files):
        regex = re.compile(r'^event\d{9}-[a-zA-Z]+\.csv')
        files = filter(regex.match, files)
        prefixes = set(op.basename(_).split('-', 1)[0] for _ in files)
        prefixes = sorted(prefixes)
        if skip is not None:
            prefixes = prefixes[skip:]
        if nevents is not None:
            prefixes = prefixes[:nevents]
        return prefixes

    # TODO use yield from when we increase the python requirement
    if op.isdir(path):
        for x in _iter_dataset_dir(path, list_prefixes(os.listdir(path)), parts):
            yield x
    else:
        with zipfile.ZipFile(path, mode='r') as z:
            for x in _iter_dataset_zip(z, list_prefixes(z.namelist()), parts):
                yield x

def _extract_event_id(prefix):
"""Extract event_id from prefix, e.g. event_id=1 from `event000000001`.
"""
return int(prefix[5:])

def _iter_dataset_dir(directory, prefixes, parts):
"""Iterate over selected events files inside a directory.
"""
for p in prefixes:
yield (_extract_event_id(p),) + load_event(op.join(directory, p), parts)

def _iter_dataset_zip(zipfile, prefixes, parts):
""""Iterate over selected event files inside a zip archive.
"""
for p in prefixes:
files = [zipfile.open('{}-{}.csv'.format(p, _), mode='r') for _ in parts]
dtypes = [DTYPES[_] for _ in parts]
data = tuple(pandas.read_csv(f, header=0, index_col=False, dtype=d)
for f, d in zip(files, dtypes))
yield (_extract_event_id(p),) + data
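
For reference, a minimal usage sketch of the single-event API defined in this file; the prefix `train_1/event000000123` is a hypothetical example, not taken from the repository:

    from trackml.dataset import load_event

    # load all four default parts for one event
    hits, cells, particles, truth = load_event('train_1/event000000123')
    # or load only a subset of the per-event csv files
    hits, truth = load_event('train_1/event000000123', parts=['hits', 'truth'])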

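With this change, `load_dataset` accepts either a dataset directory or the corresponding zip archive. A minimal sketch, assuming a hypothetical dataset available both as a directory `train_1/` and as an archive `train_1.zip`:

    from trackml.dataset import load_dataset

    # iterate over events stored as plain csv files in a directory
    for event_id, hits, cells, particles, truth in load_dataset('train_1/'):
        print(event_id, len(hits))

    # the same call now also accepts the zip archive directly,
    # without unpacking it first
    for event_id, hits, truth in load_dataset('train_1.zip', nevents=2,
                                               parts=['hits', 'truth']):
        print(event_id, len(truth))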