Skip to content

Commit 84fc236

Browse files
authored
Merge pull request #331 from cosanlab/bugfix
Fixed many open issues.
2 parents d0974de + c6966df commit 84fc236

File tree

12 files changed

+436
-164
lines changed

12 files changed

+436
-164
lines changed

nltools/data/adjacency.py

Lines changed: 43 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,7 @@ class Adjacency(object):
6161
6262
'''
6363

64-
def __init__(self, data=None, Y=None, matrix_type=None, labels=None,
65-
**kwargs):
64+
def __init__(self, data=None, Y=None, matrix_type=None, labels=[], **kwargs):
6665
if matrix_type is not None:
6766
if matrix_type.lower() not in ['distance', 'similarity', 'directed',
6867
'distance_flat', 'similarity_flat',
@@ -126,7 +125,7 @@ def __init__(self, data=None, Y=None, matrix_type=None, labels=None,
126125
else:
127126
self.Y = pd.DataFrame()
128127

129-
if labels is not None:
128+
if labels:
130129
if not isinstance(labels, (list, np.ndarray)):
131130
raise ValueError("Make sure labels is a list or numpy array.")
132131
if self.is_single_matrix:
@@ -147,7 +146,7 @@ def __init__(self, data=None, Y=None, matrix_type=None, labels=None,
147146
raise ValueError("All lists of labels must be same length as shape of data.")
148147
self.labels = deepcopy(labels)
149148
else:
150-
self.labels = None
149+
self.labels = []
151150

152151
def __repr__(self):
153152
return ("%s.%s(shape=%s, square_shape=%s, Y=%s, is_symmetric=%s,"
@@ -162,7 +161,7 @@ def __repr__(self):
162161

163162
def __getitem__(self, index):
164163
new = self.copy()
165-
if isinstance(index, int):
164+
if isinstance(index, (int, np.integer)):
166165
new.data = np.array(self.data[index, :]).squeeze()
167166
new.is_single_matrix = True
168167
else:
@@ -184,35 +183,41 @@ def __iter__(self):
184183

185184
def __add__(self, y):
186185
new = deepcopy(self)
187-
if isinstance(y, (int, float)):
186+
if isinstance(y, (int, np.integer, float, np.floating)):
188187
new.data = new.data + y
189-
if isinstance(y, Adjacency):
188+
elif isinstance(y, Adjacency):
190189
if self.shape() != y.shape():
191190
raise ValueError('Both Adjacency() instances need to be the '
192191
'same shape.')
193192
new.data = new.data + y.data
193+
else:
194+
raise ValueError('Can only add int, float, or Adjacency')
194195
return new
195196

196197
def __sub__(self, y):
197198
new = deepcopy(self)
198-
if isinstance(y, (int, float)):
199+
if isinstance(y, (int, np.integer, float, np.floating)):
199200
new.data = new.data - y
200-
if isinstance(y, Adjacency):
201+
elif isinstance(y, Adjacency):
201202
if self.shape() != y.shape():
202203
raise ValueError('Both Adjacency() instances need to be the '
203204
'same shape.')
204205
new.data = new.data - y.data
206+
else:
207+
raise ValueError('Can only subtract int, float, or Adjacency')
205208
return new
206209

207210
def __mul__(self, y):
208211
new = deepcopy(self)
209-
if isinstance(y, (int, float)):
212+
if isinstance(y, (int, np.integer, float, np.floating)):
210213
new.data = new.data * y
211-
if isinstance(y, Adjacency):
214+
elif isinstance(y, Adjacency):
212215
if self.shape() != y.shape():
213216
raise ValueError('Both Adjacency() instances need to be the '
214217
'same shape.')
215218
new.data = np.multiply(new.data, y.data)
219+
else:
220+
raise ValueError('Can only multiply int, float, or Adjacency')
216221
return new
217222

218223
@staticmethod
@@ -330,8 +335,8 @@ def plot(self, limit=3, *args, **kwargs):
330335
''' Create Heatmap of Adjacency Matrix'''
331336

332337
if self.is_single_matrix:
333-
f, a = plt.subplots(nrows=1, figsize=(7, 5))
334-
if self.labels is None:
338+
_, a = plt.subplots(nrows=1, figsize=(7, 5))
339+
if not self.labels:
335340
sns.heatmap(self.squareform(), square=True, ax=a,
336341
*args, **kwargs)
337342
else:
@@ -341,8 +346,8 @@ def plot(self, limit=3, *args, **kwargs):
341346
*args, **kwargs)
342347
else:
343348
n_subs = np.minimum(len(self), limit)
344-
f, a = plt.subplots(nrows=n_subs, figsize=(7, len(self)*5))
345-
if self.labels is None:
349+
_, a = plt.subplots(nrows=n_subs, figsize=(7, len(self)*5))
350+
if not self.labels:
346351
for i in range(n_subs):
347352
sns.heatmap(self[i].squareform(), square=True, ax=a[i],
348353
*args, **kwargs)
@@ -352,7 +357,7 @@ def plot(self, limit=3, *args, **kwargs):
352357
xticklabels=self.labels[i],
353358
yticklabels=self.labels[i],
354359
ax=a[i], *args, **kwargs)
355-
return f
360+
return
356361

357362
def mean(self, axis=0):
358363
''' Calculate mean of Adjacency
@@ -548,18 +553,18 @@ def _convert_data_similarity(data, perm_type=None, ignore_diagonal=ignore_diagon
548553
metric=metric, n_permute=n_permute,
549554
**kwargs) for x in self]
550555

551-
def distance(self, method='correlation', **kwargs):
556+
def distance(self, metric='correlation', **kwargs):
552557
''' Calculate distance between images within an Adjacency() instance.
553558
554559
Args:
555-
method: (str) type of distance metric (can use any scikit learn or
560+
metric: (str) type of distance metric (can use any scikit learn or
556561
scipy metric)
557562
558563
Returns:
559564
dist: (Adjacency) Outputs a 2D distance matrix.
560565
561566
'''
562-
return Adjacency(pairwise_distances(self.data, metric=method, **kwargs),
567+
return Adjacency(pairwise_distances(self.data, metric=metric, **kwargs),
563568
matrix_type='distance')
564569

565570
def threshold(self, upper=None, lower=None, binarize=False):
@@ -611,7 +616,7 @@ def to_graph(self):
611616
G = nx.DiGraph(self.squareform())
612617
else:
613618
G = nx.Graph(self.squareform())
614-
if self.labels is not None:
619+
if self.labels:
615620
labels = {x: y for x, y in zip(G.nodes, self.labels)}
616621
nx.relabel_nodes(G, labels, copy=False)
617622
return G
@@ -687,7 +692,7 @@ def plot_label_distance(self, labels=None, ax=None):
687692
palette={"Within": "lightskyblue", "Between": "red"}, ax=ax)
688693
f.set_ylabel('Average Distance')
689694
f.set_title('Average Group Distance')
690-
return f
695+
return
691696

692697
def stats_label_distance(self, labels=None, n_permute=5000, n_jobs=-1):
693698
''' Calculate permutation tests on within and between label distance.
@@ -745,10 +750,9 @@ def plot_silhouette(self, labels=None, ax=None, permutation_test=True,
745750
if len(labels) != distance.shape[0]:
746751
raise ValueError('Labels must be same length as distance matrix')
747752

748-
(f, outAll) = plot_silhouette(distance, labels, ax=None,
753+
return plot_silhouette(distance, pd.Series(labels), ax=None,
749754
permutation_test=True,
750755
n_permute=5000, **kwargs)
751-
return (f, outAll)
752756

753757
def bootstrap(self, function, n_samples=5000, save_weights=False,
754758
n_jobs=-1, random_state=None, *args, **kwargs):
@@ -779,20 +783,19 @@ def bootstrap(self, function, n_samples=5000, save_weights=False,
779783
bootstrapped = Adjacency(bootstrapped)
780784
return summarize_bootstrap(bootstrapped, save_weights=save_weights)
781785

782-
def plot_mds(self, n_components=2, metric=True, labels_color=None,
786+
def plot_mds(self, n_components=2, metric=True, labels=None, labels_color=None,
783787
cmap=plt.cm.hot_r, n_jobs=-1, view=(30, 20),
784788
figsize=[12, 8], ax=None, *args, **kwargs):
785789
''' Plot Multidimensional Scaling
786790
787791
Args:
788792
n_components: (int) Number of dimensions to project (can be 2 or 3)
789793
metric: (bool) Perform metric or non-metric dimensional scaling; default True
794+
labels: (list) Can override labels stored in Adjacency Class
790795
labels_color: (str) list of colors for labels, if len(1) then make all same color
791796
n_jobs: (int) Number of parallel jobs
792797
view: (tuple) view for 3-Dimensional plot; default (30,20)
793798
794-
Returns:
795-
fig: returns matplotlib figure
796799
'''
797800

798801
if self.matrix_type != 'distance':
@@ -801,10 +804,15 @@ def plot_mds(self, n_components=2, metric=True, labels_color=None,
801804
raise ValueError("MDS only works on single matrices.")
802805
if n_components not in [2, 3]:
803806
raise ValueError('Cannot plot {0}-d image'.format(n_components))
807+
if labels is not None:
808+
if len(labels) != self.square_shape()[0]:
809+
raise ValueError("Make sure labels matches the same shape as Adjaency data")
810+
else:
811+
labels = self.labels
804812
if labels_color is not None:
805-
if self.labels is None:
813+
if len(labels) == 0:
806814
raise ValueError("Make sure that Adjacency object has labels specified.")
807-
if len(self.labels) != len(labels_color):
815+
if len(labels) != len(labels_color):
808816
raise ValueError("Length of labels_color must match self.labels.")
809817

810818
# Run MDS
@@ -814,7 +822,6 @@ def plot_mds(self, n_components=2, metric=True, labels_color=None,
814822

815823
# Create Plot
816824
if ax is None: # Create axis
817-
returnFig = True
818825
fig = plt.figure(figsize=figsize)
819826
if n_components == 3:
820827
ax = fig.add_subplot(111, projection='3d')
@@ -830,21 +837,18 @@ def plot_mds(self, n_components=2, metric=True, labels_color=None,
830837

831838
# Plot labels
832839
if labels_color is None:
833-
labels_color = ['black'] * len(self.labels)
840+
labels_color = ['black'] * len(labels)
834841
if n_components == 3:
835-
for ((x, y, z), label, color) in zip(proj, self.labels, labels_color):
842+
for ((x, y, z), label, color) in zip(proj, labels, labels_color):
836843
ax.text(x, y, z, label, color='white', bbox=dict(facecolor=color, alpha=1, boxstyle="round,pad=0.3"))
837844
else:
838-
for ((x, y), label, color) in zip(proj, self.labels, labels_color):
845+
for ((x, y), label, color) in zip(proj, labels, labels_color):
839846
ax.text(x, y, label, color='white', # color,
840847
bbox=dict(facecolor=color, alpha=1, boxstyle="round,pad=0.3"))
841848

842849
ax.xaxis.set_visible(False)
843850
ax.yaxis.set_visible(False)
844851

845-
if returnFig:
846-
return fig
847-
848852
def distance_to_similarity(self, beta=1):
849853
'''Convert distance matrix to similarity matrix
850854
@@ -918,14 +922,15 @@ def regress(self, X, mode='ols', **kwargs):
918922
stats['beta'].data, stats['t'].data, stats['p'].data = b.squeeze(), t.squeeze(), p.squeeze()
919923
stats['residual'] = self.copy()
920924
stats['residual'].data = res
925+
stats['df'] = df
921926
else:
922927
raise ValueError('X must be a Design_Matrix or Adjacency Instance.')
923928

924929
return stats
925930

926931
def social_relations_model(self, summarize_results=True, nan_replace=True):
927-
'''Estimate the social relations model from a matrix for a round-robin design
928-
932+
'''Estimate the social relations model from a matrix for a round-robin design.
933+
929934
X_{ij} = m + \alpha_i + \beta_j + g_{ij} + \epsilon_{ijl}
930935
931936
where X_{ij} is the score for person i rating person j, m is the group mean,
@@ -1133,7 +1138,7 @@ def fix_missing(data):
11331138
if data.is_single_matrix:
11341139
X, coord = fix_missing(data)
11351140
else:
1136-
X = []; coord = [];
1141+
X = []; coord = []
11371142
for d in data:
11381143
m, c = fix_missing(d)
11391144
X.append(m)

0 commit comments

Comments
 (0)