forked from brendonw1/KilosortWrapper
-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathPhyAutoClustering.m
218 lines (194 loc) · 7.98 KB
/
PhyAutoClustering.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
function PhyAutoClustering(clusteringpath,varargin)
% PhyAutoClustering  Clean the output of Kilosort and pre-label clusters for Phy.
%
% USAGE
%   PhyAutoClustering(clusteringpath)
%
% INPUT
%   clusteringpath : char, folder containing spike_clusters.npy,
%                    templates.npy, channel_shanks.npy, amplitudes.npy,
%                    pc_features.npy, spike_times.npy and rez.mat.
%   varargin       : accepted for backward compatibility; it is not read by
%                    this function.
%
% REQUIREMENTS
%   readNPY / writeNPY, nanconv, CCG and mahal must be on the MATLAB path.
%   CCGHeart has to be compiled. Go to the private folder of the wrapper and type:
%       mex -O CCGHeart.c
%
% SIDE EFFECTS
%   - Changes the current folder to clusteringpath.
%   - Copies spike_clusters.npy (and cluster_group.tsv, if present) into a
%     new sub-folder named PhyAutoClustering_<yyyy-mm-dd_HHMMSS>.
%   - Overwrites spike_clusters.npy and cluster_group.tsv.
%   - Writes autoclusta_params.mat (meanR, maxPwRatio, fractRogue, noiseIx, muaIx).
%
% PROCESSING STEPS
%   1. Removing spikes with large artifacts
%      The amplitude vector is convolved with a 20-sample boxcar; spikes at
%      positions where the result exceeds amplitude_thr are moved to a new
%      cluster per shank (id = max(cluster id) + shank id), labeled 'artifacts'.
%
%   2. Mahalanobis outlier removal
%      For every original cluster with more than min_spikes spikes, spikes
%      whose squared Mahalanobis distance in PC-feature space exceeds
%      mahal_thr^2 are moved to a new cluster, labeled 'mua'.
%
%   3. Noise / MUA classification
%      Per original cluster, three metrics are computed:
%        meanR      - mean pairwise correlation between the template's
%                     non-zero channels (high value => same waveform on all
%                     channels, i.e. putative electrical artifact),
%        maxPwRatio - max/mean of the absolute template amplitude across
%                     channels at a fixed sample,
%        fractRogue - mean autocorrelogram count inside +/-1.5 ms divided by
%                     the mean count in a baseline window (refractory
%                     period quality, cf. Hill, Mehta and Kleinfeld,
%                     J Neurosci., 2012).
%      Clusters with meanR >= rThres and maxPwRatio < mprThres, or with fewer
%      than min_spikes spikes, are labeled 'noise'. Remaining clusters with
%      fractRogue > rogThres are labeled 'mua'. All other clusters are left
%      unlabeled in cluster_group.tsv.
%
%   NOTE: no merging of clusters is performed by this function.
%
% By Adrien Peyrache, Peter Petersen & Yuta Senzai

% ------------------------------------------------------------- Parameters
% Refractory period in msec (only referenced by the disabled
% FractionRogueSpk call; kept for documentation)
tR = 1.5; %#ok<NASGU>
% Censored period in msec (specific to the spike detection process; unused)
tC = 0.85; %#ok<NASGU>
% Rogue spike threshold (for MUA); value between 0 and 1
rogThres = 0.33;
% Mean across-channel template correlation at or above which a cluster is
% considered artifact-like
rThres = 0.7;
% Max/mean amplitude ratio below which a cluster is considered artifact-like
mprThres = 2;
% Artifact removal thresholds
amplitude_thr = 50;
mahal_thr = 18;
% Minimum number of spikes for a cluster to be processed / not called noise
min_spikes = 100;

% ------------------------------------------------- Backup of current state
cd(clusteringpath)
% Resolve to an absolute path: after the cd above, a relative
% clusteringpath would no longer point at the data folder.
clusteringpath = pwd;
dirname = ['PhyAutoClustering_', datestr(clock,'yyyy-mm-dd_HHMMSS')];
backup_dir = fullfile(clusteringpath, dirname);
mkdir(backup_dir)
copyfile(fullfile(clusteringpath, 'spike_clusters.npy'), fullfile(backup_dir,'spike_clusters.npy'))
group_file = fullfile(clusteringpath, 'cluster_group.tsv');
if exist(group_file, 'file')
    copyfile(group_file, fullfile(backup_dir, 'cluster_group.tsv'))
end

% ------------------------------------------------------------ Load inputs
clu = double(readNPY(fullfile(clusteringpath, 'spike_clusters.npy')));
cids = unique(clu);   % original cluster ids; all later loops run over these

% Templates are permuted so that the template index is the third dimension.
wav_all = permute(readNPY(fullfile(clusteringpath,'templates.npy')),[2,3,1]);
n_templates = size(wav_all,3);

% Channel with the largest peak-to-peak amplitude for every template
ch_indx = zeros(1,n_templates);
for i = 1:n_templates
    [~,ch_indx(i)] = max(max(wav_all(:,:,i))-min(wav_all(:,:,i)));
end

channel_shanks = readNPY(fullfile(clusteringpath, 'channel_shanks.npy'));
% Force a row of doubles: a for-loop iterates over COLUMNS, so a column
% vector would otherwise give a single iteration with a vector-valued index.
shank_ids = double(unique(channel_shanks(:)))';

% ------------------------------------ Removing spikes with large artifacts
disp('Removing spikes with large artifacts')
spike_amplitudes = readNPY(fullfile(clusteringpath, 'amplitudes.npy'));
spike_amplitudes = nanconv(spike_amplitudes',ones(1,20),'edge');
indx = find(spike_amplitudes>amplitude_thr);
disp([num2str(length(indx)),' artifacts detected'])

clu2 = clu;
artifact_clusters = [];
base_id = max(clu);
% Peak channel of the template behind every artifact spike (cluster id is
% 0-based, template index is 1-based)
peak_channel = ch_indx(clu(indx)+1);
for j = shank_ids
    shank_channels = find(channel_shanks == j);
    indx22 = find(ismember(peak_channel,shank_channels));
    % Only register an artifact cluster that actually receives spikes, so
    % no non-existent cluster id ends up in cluster_group.tsv.
    if ~isempty(indx22)
        clu2(indx(indx22)) = base_id+j;
        artifact_clusters = [artifact_clusters,base_id+j]; %#ok<AGROW>
    end
end
clu = clu2;

% -------------------------------------------------- Mahal artifact removal
spike_PCAs = double(readNPY(fullfile(clusteringpath, 'pc_features.npy')));
disp('Removing outliers by Mahalanobis theshold...')
mahal_outlier_clusters = [];
spikes_removed = 0;
next_id = max(clu)+1;
for i = 1:length(cids)
    indexes = find(clu==cids(i));
    if length(indexes)>min_spikes
        % Flatten (spikes x d2 x d3) features to one row per spike
        feats = spike_PCAs(indexes,:,:);
        feats = reshape(feats,size(feats,1),[]);
        % mahal needs more observations than dimensions; skip otherwise
        if size(feats,1) > size(feats,2)
            outliers = find(mahal(feats,feats)>mahal_thr^2);
            if ~isempty(outliers)
                clu(indexes(outliers)) = next_id;
                mahal_outlier_clusters = [mahal_outlier_clusters,next_id]; %#ok<AGROW>
                spikes_removed = spikes_removed+length(outliers);
                next_id = next_id+1;
            end
        end
    end
end
disp([num2str(length(mahal_outlier_clusters)),' units cleaned by Mahal outlier detection. Spikes removed: ',num2str(spikes_removed)])
writeNPY(uint32(clu), fullfile(clusteringpath, 'spike_clusters.npy'));

% ------------------------------------ Loading rez.mat for sampling rate
disp('Loading rez.mat')
% Load only 'rez' into a struct so that variables stored in the file cannot
% overwrite local variables of this function.
rez_file = load(fullfile(clusteringpath,'rez.mat'),'rez');
sr = rez_file.rez.ops.fs;
res = double(readNPY(fullfile(clusteringpath,'spike_times.npy')))/sr;

% ------------------------------------------------ Classifying noise / mua
disp('Classifying noise/mua')
n_clusters = length(cids);
meanR = zeros(n_clusters,1);
fractRogue = zeros(n_clusters,1);
maxPwRatio = zeros(n_clusters,1);
% Spike count per ORIGINAL cluster after artifact/outlier removal. Counting
% per cids(ii) keeps h aligned with the metric vectors even when a cluster
% lost all of its spikes.
h = zeros(n_clusters,1);
for ii = 1:n_clusters
    in_cluster = clu==cids(ii);
    h(ii) = sum(in_cluster);
    spktime = res(in_cluster);
    if isempty(spktime)
        % Empty cluster: all metrics stay 0
        continue
    end

    % Waveform metrics from the template belonging to this cluster id
    % (0-based id -> 1-based template index, consistent with ch_indx above).
    template_ix = cids(ii)+1;
    if template_ix <= n_templates
        wav = double(squeeze(wav_all(13:end,:,template_ix)));
        wav = wav(:,any(wav));          % keep channels with non-zero template
        dim = size(wav,2);
        R = corrcoef(wav);
        % Mean of the off-diagonal correlation coefficients
        meanR(ii) = (sum(sum(R)) - dim) /(dim*(dim-1));
        maxPwRatio_cur = max(abs(wav(11,:)))/mean(abs(wav(11,:)));
        if ~isempty(maxPwRatio_cur)
            maxPwRatio(ii) = maxPwRatio_cur;
        end
    else
        warning('PhyAutoClustering:noTemplate', ...
            'No template for cluster %d; waveform metrics set to 0.', cids(ii));
    end

    % Refractory period quality from the autocorrelogram
    [ccgR,t] = CCG(spktime,ones(size(spktime)),'binsize',.0005,'duration',.06);
    indx3 = find(t > -0.0015 & t < 0.0015);
    spkRef = mean(ccgR(indx3)); % refractory period: -1.5ms to 1.5ms
    spkMean = mean(ccgR(round(indx3(1)/2):indx3(1)-1)); % baseline bins before the refractory window
    fractRogue(ii) = spkRef/spkMean;
end

% Definition of noise (noiseIx), multi-unit (muaIx) and remaining (goodIx)
% clusters; indices refer to positions in cids.
is_artifact_like = meanR >= rThres & maxPwRatio < mprThres;
enough_spikes = h >= min_spikes;
noiseIx = find(is_artifact_like | ~enough_spikes);
muaIx = find(fractRogue>rogThres & ~is_artifact_like & enough_spikes);
goodIx = find(fractRogue<=rogThres & ~is_artifact_like & enough_spikes);

% --------------------------- Saving clusters to cluster_group.tsv (Phy)
fid = fopen(group_file,'w');
if fid == -1
    error('PhyAutoClustering:cannotWrite', 'Could not open %s for writing.', group_file);
end
file_closer = onCleanup(@() fclose(fid)); % closes the file even if a write fails
fwrite(fid, sprintf('cluster_id\t%s\r\n', 'group'));
for ii = 1:n_clusters
    if h(ii) == 0
        continue % cluster no longer holds any spikes
    end
    if any(goodIx==ii)
        % Deliberately left unlabeled
        % fwrite(fid, sprintf('%d\t%s\r\n', cids(ii), 'good'));
    elseif any(muaIx==ii)
        fwrite(fid, sprintf('%d\t%s\r\n', cids(ii), 'mua'));
    elseif any(noiseIx==ii)
        fwrite(fid, sprintf('%d\t%s\r\n', cids(ii), 'noise'));
    end
end
for jj = 1:length(artifact_clusters)
    fwrite(fid, sprintf('%d\t%s\r\n', artifact_clusters(jj), 'artifacts'));
end
for jj = 1:length(mahal_outlier_clusters)
    fwrite(fid, sprintf('%d\t%s\r\n', mahal_outlier_clusters(jj), 'mua'));
end
clear file_closer % triggers fclose(fid)

save(fullfile(clusteringpath,'autoclusta_params.mat'),'meanR','maxPwRatio','fractRogue','noiseIx','muaIx');
disp('AutoClustering complete.')
end