Skip to content

Commit 106ca36

Browse files
authored [author name missing from page extraction]
Merge pull request #221 from ljchang/fix_perm_tail
Fix perm tail Former-commit-id: 22df5e5
2 parents 7253dc8 + f29a9a1 commit 106ca36

File tree

4 files changed

+64
-36
lines changed

4 files changed

+64
-36
lines changed

.travis.yml

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@ sudo: false
55
python:
66
- "2.7"
77
- "3.6"
8-
8+
99
before_script:
1010
- "export DISPLAY=:99.0"
1111
- "sh -e /etc/init.d/xvfb start"
@@ -25,6 +25,7 @@ install:
2525
- pip install -r requirements.txt
2626
- pip install -r optional-dependencies.txt
2727
- python setup.py install
28+
- pip install git+https://github.com/nilearn/nilearn --upgrade
2829

2930
script: coverage run --source nltools -m py.test
3031

nltools/stats.py

Lines changed: 28 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -395,12 +395,32 @@ def _permute_group(data, random_state=None):
395395
return (np.mean(data.loc[perm_label==1, 'Values']) -
396396
np.mean(data.loc[perm_label==0, 'Values']))
397397

398-
def one_sample_permutation(data, n_permute=5000, n_jobs=-1, random_state=None):
398+
399+
def _calc_pvalue(all_p, stat, tail):
400+
"""Calculates p value based on distribution of correlations
401+
This function is called by the permutation functions
402+
all_p: list of correlation values from permutation
403+
stat: actual value being tested, i.e., stats['correlation'] or stats['mean']
404+
tail: (int) either 2 or 1 for two-tailed p-value or one-tailed
405+
"""
406+
if tail==2:
407+
p= np.mean( np.abs(all_p) >= np.abs(stat))
408+
elif tail==1:
409+
if stat >= 0:
410+
p = np.mean(all_p >= stat)
411+
else:
412+
p = np.mean(all_p <= stat)
413+
else:
414+
raise ValueError('tail must be either 1 or 2')
415+
return p
416+
417+
def one_sample_permutation(data, n_permute=5000, tail=2, n_jobs=-1, random_state=None):
399418
''' One sample permutation test using randomization.
400419
401420
Args:
402421
data: Pandas DataFrame or Series or numpy array
403422
n_permute: (int) number of permutations
423+
tail: (int) either 1 for one-tail or 2 for two-tailed test (default: 2)
404424
n_jobs: (int) The number of CPUs to use to do the computation.
405425
-1 means all CPUs.
406426
@@ -418,20 +438,18 @@ def one_sample_permutation(data, n_permute=5000, n_jobs=-1, random_state=None):
418438

419439
all_p = Parallel(n_jobs=n_jobs)(delayed(_permute_sign)(data,
420440
random_state=seeds[i]) for i in range(n_permute))
421-
if stats['mean'] >= 0:
422-
stats['p'] = np.mean(all_p >= stats['mean'])
423-
else:
424-
stats['p'] = np.mean(all_p <= stats['mean'])
441+
stats['p'] = _calc_pvalue(all_p,stats['mean'],tail)
425442
return stats
426443

427444
def two_sample_permutation(data1, data2, n_permute=5000,
428-
n_jobs=-1, random_state=None):
445+
tail=2, n_jobs=-1, random_state=None):
429446
''' Independent sample permutation test.
430447
431448
Args:
432449
data1: Pandas DataFrame or Series or numpy array
433450
data2: Pandas DataFrame or Series or numpy array
434451
n_permute: (int) number of permutations
452+
tail: (int) either 1 for one-tail or 2 for two-tailed test (default: 2)
435453
n_jobs: (int) The number of CPUs to use to do the computation.
436454
-1 means all CPUs.
437455
Returns:
@@ -451,14 +469,11 @@ def two_sample_permutation(data1, data2, n_permute=5000,
451469
all_p = Parallel(n_jobs=n_jobs)(delayed(_permute_group)(data,
452470
random_state=seeds[i]) for i in range(n_permute))
453471

454-
if stats['mean']>=0:
455-
stats['p'] = np.mean(all_p >= stats['mean'])
456-
else:
457-
stats['p'] = np.mean(all_p <= stats['mean'])
472+
stats['p'] = _calc_pvalue(all_p,stats['mean'],tail)
458473
return stats
459474

460475
def correlation_permutation(data1, data2, n_permute=5000, metric='spearman',
461-
n_jobs=-1, random_state=None):
476+
tail=2, n_jobs=-1, random_state=None):
462477
''' Permute correlation.
463478
464479
Args:
@@ -467,6 +482,7 @@ def correlation_permutation(data1, data2, n_permute=5000, metric='spearman',
467482
n_permute: (int) number of permutations
468483
metric: (str) type of association metric ['spearman','pearson',
469484
'kendall']
485+
tail: (int) either 1 for one-tail or 2 for two-tailed test (default: 2)
470486
n_jobs: (int) The number of CPUs to use to do the computation.
471487
-1 means all CPUs.
472488
@@ -504,10 +520,7 @@ def correlation_permutation(data1, data2, n_permute=5000, metric='spearman',
504520
for i in range(n_permute))
505521
all_p = [x[0] for x in all_p]
506522

507-
if stats['correlation'] >= 0:
508-
stats['p'] = np.mean(all_p >= stats['correlation'])
509-
else:
510-
stats['p'] = np.mean(all_p <= stats['correlation'])
523+
stats['p'] = _calc_pvalue(all_p,stats['correlation'],tail)
511524
return stats
512525

513526
def make_cosine_basis(nsamples, sampling_rate, filter_length, drop=0):

nltools/tests/test_stats.py

Lines changed: 32 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -7,38 +7,48 @@
77
upsample,
88
winsorize,
99
align,
10-
transform_pairwise)
10+
transform_pairwise, _calc_pvalue)
1111
from nltools.simulator import Simulator
1212
from nltools.mask import create_sphere
13+
# import pytest
1314

1415
def test_permutation():
1516
dat = np.random.multivariate_normal([2, 6], [[.5, 2], [.5, 3]], 1000)
1617
x = dat[:, 0]
1718
y = dat[:, 1]
18-
stats = two_sample_permutation(x, y)
19-
assert (stats['mean'] < -2) & (stats['mean'] > -6)
20-
assert stats['p'] < .001
21-
print(stats)
22-
stats = one_sample_permutation(x-y)
23-
assert (stats['mean'] < -2) & (stats['mean'] > -6)
24-
assert stats['p'] < .001
25-
print(stats)
26-
stats = correlation_permutation(x, y, metric='pearson')
27-
assert (stats['correlation'] > .4) & (stats['correlation']<.85)
28-
assert stats['p'] < .001
29-
stats = correlation_permutation(x, y, metric='spearman')
30-
assert (stats['correlation'] > .4) & (stats['correlation']<.85)
31-
assert stats['p'] < .001
32-
stats = correlation_permutation(x, y, metric='kendall')
33-
assert (stats['correlation'] > .4) & (stats['correlation']<.85)
34-
assert stats['p'] < .001
19+
stats = two_sample_permutation(x, y,tail=1)
20+
assert (stats['mean'] < -2) & (stats['mean'] > -6) & (stats['p'] < .001)
21+
stats = one_sample_permutation(x-y,tail=1)
22+
assert (stats['mean'] < -2) & (stats['mean'] > -6) & (stats['p'] < .001)
23+
stats = correlation_permutation(x, y, metric='pearson',tail=1)
24+
assert (stats['correlation'] > .4) & (stats['correlation']<.85) & (stats['p'] < .001)
25+
stats = correlation_permutation(x, y, metric='spearman',tail=1)
26+
assert (stats['correlation'] > .4) & (stats['correlation']<.85) & (stats['p'] < .001)
27+
stats = correlation_permutation(x, y, metric='kendall',tail=2)
28+
assert (stats['correlation'] > .4) & (stats['correlation']<.85) & (stats['p'] < .001)
29+
# with pytest.raises(ValueError):
30+
# correlation_permutation(x, y, metric='kendall',tail=3)
31+
# with pytest.raises(ValueError):
32+
# correlation_permutation(x, y, metric='doesntwork',tail=3)
33+
s = np.random.normal(0,1,10000)
34+
two_sided = _calc_pvalue(all_p = s, stat= 1.96, tail = 2)
35+
upper_p = _calc_pvalue(all_p = s, stat= 1.96, tail = 1)
36+
lower_p = _calc_pvalue(all_p = s, stat= -1.96, tail = 1)
37+
sum_p = upper_p + lower_p
38+
np.testing.assert_almost_equal(two_sided, sum_p)
3539

3640
def test_downsample():
3741
dat = pd.DataFrame()
3842
dat['x'] = range(0,100)
3943
dat['y'] = np.repeat(range(1,11),10)
4044
assert((dat.groupby('y').mean().values.ravel() == downsample(data=dat['x'],sampling_freq=10,target=1,target_type='hz',method='mean').values).all)
4145
assert((dat.groupby('y').median().values.ravel() == downsample(data=dat['x'],sampling_freq=10,target=1,target_type='hz',method='median').values).all)
46+
# with pytest.raises(ValueError):
47+
# downsample(data=list(dat['x']),sampling_freq=10,target=1,target_type='hz',method='median')
48+
# with pytest.raises(ValueError):
49+
# downsample(data=dat['x'],sampling_freq=10,target=1,target_type='hz',method='doesnotwork')
50+
# with pytest.raises(ValueError):
51+
# downsample(data=dat['x'],sampling_freq=10,target=1,target_type='doesnotwork',method='median')
4252

4353
def test_upsample():
4454
dat = pd.DataFrame()
@@ -50,6 +60,10 @@ def test_upsample():
5060
fs = 3
5161
us = upsample(dat,sampling_freq=1,target=fs,target_type='hz')
5262
assert(dat.shape[0]*fs-fs == us.shape[0])
63+
# with pytest.raises(ValueError):
64+
# upsample(dat,sampling_freq=1,target=fs,target_type='hz',method='doesnotwork')
65+
# with pytest.raises(ValueError):
66+
# upsample(dat,sampling_freq=1,target=fs,target_type='doesnotwork',method='linear')
5367

5468
def test_winsorize():
5569
outlier_test = pd.DataFrame([92, 19, 101, 58, 1053, 91, 26, 78, 10, 13,

requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,10 @@
11
nibabel>=2.0.1
2-
scikit-learn>=0.18.1
2+
scikit-learn>=0.19.1
33
nilearn>=0.4
44
pandas>=0.20
55
numpy>=1.9
66
seaborn>=0.7.0
7-
matplotlib>=2.1
7+
matplotlib>=2.2.0
88
scipy
99
six
1010
pynv

0 commit comments

Comments (0)