Skip to content

Commit

Permalink
Merge pull request #68 from CoffeaTeam/accumulator
Browse files Browse the repository at this point in the history
Accumulator tests and updates
  • Loading branch information
nsmith- authored Apr 12, 2019
2 parents 9d70bca + cbd814a commit 002575b
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 9 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ python:
env:
matrix:
- AWKWARD="awkward>=0.8.1"
- AWKWARD="awkward-numba"
# - AWKWARD="awkward-numba"
matrix:
addons:
apt:
Expand Down
25 changes: 17 additions & 8 deletions fnal_column_analysis_tools/processor/accumulator.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from six import with_metaclass
from abc import ABCMeta, abstractmethod
import collections
from collections import defaultdict

try:
from collections.abc import Set
from collections.abc import Set, Mapping
except ImportError:
from collections import Set
from collections import Set, Mapping


class AccumulatorABC(with_metaclass(ABCMeta)):
Expand All @@ -15,7 +15,7 @@ class AccumulatorABC(with_metaclass(ABCMeta)):
such that self + self.identity() == self
add(other): adds an object of same type as self to self
Concrete implementations are provided for __add__, __iadd__
Concrete implementations are provided for __add__, __radd__, __iadd__
'''
@abstractmethod
def identity(self):
Expand All @@ -31,6 +31,12 @@ def __add__(self, other):
ret.add(other)
return ret

def __radd__(self, other):
ret = self.identity()
ret.add(other)
ret.add(self)
return ret

def __iadd__(self, other):
self.add(other)
return self
Expand All @@ -48,7 +54,7 @@ def identity(self):
return accumulator(self._identity)

def add(self, other):
if isinstance(other, AccumulatorABC):
if isinstance(other, accumulator):
self.value += other.value
else:
self.value += other
Expand Down Expand Up @@ -80,16 +86,19 @@ def identity(self):
return ret

def add(self, other):
if isinstance(other, dict_accumulator):
if isinstance(other, Mapping):
for key, value in other.items():
if key not in self:
self[key] = value.identity()
if isinstance(value, AccumulatorABC):
self[key] = value.identity()
else:
raise ValueError
self[key] += value
else:
raise ValueError


class defaultdict_accumulator(collections.defaultdict, AccumulatorABC):
class defaultdict_accumulator(defaultdict, AccumulatorABC):
'''
Like a defaultdict but also has accumulator semantics
It is assumed that the contents of the dict have accumulator semantics
Expand Down
44 changes: 44 additions & 0 deletions tests/test_accumulators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from __future__ import print_function, division
from fnal_column_analysis_tools import processor


def test_accumulators():
a = processor.accumulator(0.)
a += 3.
a += processor.accumulator(2)
assert a.value == 5.

b = processor.set_accumulator({'apples', 'oranges'})
b += {'pears'}
b += 'grapes'
assert b == {'apples', 'oranges', 'pears', 'grapes'}

c = processor.dict_accumulator({'num': a, 'fruit': b})
c['num'] += 2.
c += processor.dict_accumulator({
'num2': processor.accumulator(0),
'fruit': processor.set_accumulator({'apples', 'cherries'}),
})
assert c['num2'].value == 0
assert c['num'].value == 7.
assert c['fruit'] == {'apples', 'oranges', 'pears', 'grapes', 'cherries'}

d = processor.defaultdict_accumulator(lambda: processor.accumulator(0.))
d['x'] = processor.accumulator(0.) + 4
d['y'] += 5.
d['z'] += d['x']
d['x'] += d['y']
assert d['x'].value == 9.
assert d['y'].value == 5.
assert d['z'].value == 4.
assert d['w'].value == 0.

e = d + c

f = processor.defaultdict_accumulator(lambda: 2.)
f['x'] += 4.
assert f['x'] == 6.

f += f
assert f['x'] == 12.
assert f['y'] == 2.

0 comments on commit 002575b

Please sign in to comment.