Skip to content

Commit 97f6254

Browse files
authored
Merge pull request #280 from cosanlab/arith
added vector multiplication to Brain_Data
2 parents 8bc15f9 + b355d46 commit 97f6254

File tree

2 files changed

+52
-5
lines changed

2 files changed

+52
-5
lines changed

nltools/data/brain_data.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -235,26 +235,66 @@ def __add__(self, y):
235235
new.data = new.data + y.data
236236
return new
237237

238+
def __radd__(self, y):
239+
new = deepcopy(self)
240+
if isinstance(y, (int, float)):
241+
new.data = y + new.data
242+
elif isinstance(y, Brain_Data):
243+
if self.shape() != y.shape():
244+
raise ValueError("Both Brain_Data() instances need to be the "
245+
"same shape.")
246+
new.data = y.data + new.data
247+
return new
248+
238249
def __sub__(self, y):
239250
new = deepcopy(self)
240251
if isinstance(y, (int, float)):
241252
new.data = new.data - y
242-
if isinstance(y, Brain_Data):
253+
elif isinstance(y, Brain_Data):
243254
if self.shape() != y.shape():
244255
raise ValueError('Both Brain_Data() instances need to be the '
245256
'same shape.')
246257
new.data = new.data - y.data
247258
return new
248259

260+
def __rsub__(self, y):
261+
new = deepcopy(self)
262+
if isinstance(y, (int, float)):
263+
new.data = y - new.data
264+
elif isinstance(y, Brain_Data):
265+
if self.shape() != y.shape():
266+
raise ValueError('Both Brain_Data() instances need to be the '
267+
'same shape.')
268+
new.data = y.data - new.data
269+
return new
270+
249271
def __mul__(self, y):
250272
new = deepcopy(self)
251273
if isinstance(y, (int, float)):
252274
new.data = new.data * y
253-
if isinstance(y, Brain_Data):
275+
elif isinstance(y, Brain_Data):
254276
if self.shape() != y.shape():
255277
raise ValueError("Both Brain_Data() instances need to be the "
256278
"same shape.")
257279
new.data = np.multiply(new.data, y.data)
280+
elif isinstance(y, (list, np.ndarray, np.array)):
281+
if len(y) != len(self):
282+
raise ValueError('Vector multiplication requires that the '
283+
'length of the vector match the number of '
284+
'images in Brain_Data instance.')
285+
else:
286+
new.data = np.dot(new.data.T, y).T
287+
return new
288+
289+
def __rmul__(self, y):
290+
new = deepcopy(self)
291+
if isinstance(y, (int, float)):
292+
new.data = y * new.data
293+
elif isinstance(y, Brain_Data):
294+
if self.shape() != y.shape():
295+
raise ValueError("Both Brain_Data() instances need to be the "
296+
"same shape.")
297+
new.data = np.multiply(y.data, new.data)
258298
return new
259299

260300
def __iter__(self):

nltools/tests/test_brain_data.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,17 +68,24 @@ def test_sum(sim_brain_data):
6868
def test_add(sim_brain_data):
6969
new = sim_brain_data + sim_brain_data
7070
assert new.shape() == shape_2d
71-
71+
value = 10
72+
assert(value + sim_brain_data[0]).mean() == (sim_brain_data[0] + value).mean()
7273

7374
def test_subtract(sim_brain_data):
7475
new = sim_brain_data - sim_brain_data
7576
assert new.shape() == shape_2d
76-
77+
value = 10
78+
assert (-value-(-1)*sim_brain_data[0]).mean() == (sim_brain_data[0]-value).mean()
7779

7880
def test_multiply(sim_brain_data):
7981
new = sim_brain_data * sim_brain_data
8082
assert new.shape() == shape_2d
81-
83+
value = 10
84+
assert(value * sim_brain_data[0]).mean() == (sim_brain_data[0] * value).mean()
85+
c1 = [.5, .5, -.5, -.5]
86+
new = sim_brain_data[0:4]*c1
87+
new2 = sim_brain_data[0]*.5 + sim_brain_data[1]*.5 - sim_brain_data[2]*.5 - sim_brain_data[3]*.5
88+
np.testing.assert_almost_equal(0, (new-new2).sum(), decimal=5)
8289

8390
def test_indexing(sim_brain_data):
8491
index = [0, 3, 1]

0 commit comments

Comments
 (0)