From 40675f0afb97972ecb26c830f1ade30268db1fc4 Mon Sep 17 00:00:00 2001 From: chris-langfield Date: Mon, 25 Sep 2023 18:54:27 +0100 Subject: [PATCH 1/2] filter units --- atlaselectrophysiology/load_data.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/atlaselectrophysiology/load_data.py b/atlaselectrophysiology/load_data.py index 05e7e5d..1b015c6 100644 --- a/atlaselectrophysiology/load_data.py +++ b/atlaselectrophysiology/load_data.py @@ -263,6 +263,15 @@ def get_data(self): data['clusters'] = self.one.load_object(self.eid, 'clusters', collection=self.probe_collection, attribute=['metrics', 'peakToTrough', 'waveforms', 'channels']) + + # filter out low spike units + min_firing_rate = 50. / 3600. + include_cond = data['clusters'].metrics["firing_rate"] > min_firing_rate + data['clusters'].metrics = data['clusters'].metrics[include_cond] + include_idx = include_cond.to_numpy() + for k, v in data['clusters'].items(): + data['clusters'][k] = v[include_idx] + data['clusters']['exists'] = True data['channels'] = self.one.load_object(self.eid, 'channels', collection=self.probe_collection, From 3b53981d2c9122d15aa1d45d960ab8c1691adf2c Mon Sep 17 00:00:00 2001 From: Mayo Faulkner Date: Tue, 26 Sep 2023 09:42:50 +0100 Subject: [PATCH 2/2] remove clusters from spikes object --- atlaselectrophysiology/load_data.py | 24 +++++++++++++++--------- atlaselectrophysiology/plot_data.py | 4 ++-- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/atlaselectrophysiology/load_data.py b/atlaselectrophysiology/load_data.py index 1b015c6..c8906c1 100644 --- a/atlaselectrophysiology/load_data.py +++ b/atlaselectrophysiology/load_data.py @@ -5,6 +5,8 @@ from neuropixel import trace_header import ibllib.atlas as atlas from ibllib.qc.alignment_qc import AlignmentQC +from iblutil.numerical import ismember +from iblutil.util import Bunch from one.api import ONE from one.remote import aws from pathlib import Path @@ -259,25 +261,29 @@ def get_data(self): try: data['spikes'] = self.one.load_object(self.eid, 'spikes', collection=self.probe_collection, attribute=['depths', 'amps', 'times', 'clusters']) - data['spikes']['exists'] = True data['clusters'] = self.one.load_object(self.eid, 'clusters', collection=self.probe_collection, attribute=['metrics', 'peakToTrough', 'waveforms', 'channels']) - - # filter out low spike units + + # Remove low firing rate clusters min_firing_rate = 50. / 3600. - include_cond = data['clusters'].metrics["firing_rate"] > min_firing_rate - data['clusters'].metrics = data['clusters'].metrics[include_cond] - include_idx = include_cond.to_numpy() - for k, v in data['clusters'].items(): - data['clusters'][k] = v[include_idx] - + clu_idx = data['clusters'].metrics.firing_rate > min_firing_rate + data['clusters'] = Bunch({k: v[clu_idx] for k, v in data['clusters'].items()}) + spike_idx, ib = ismember(data['spikes'].clusters, data['clusters'].metrics.index) + data['clusters'].metrics.reset_index(drop=True, inplace=True) + data['spikes'] = Bunch({k: v[spike_idx] for k, v in data['spikes'].items()}) + data['spikes'].clusters = data['clusters'].metrics.index[ib].astype(np.int32) + + data['spikes']['exists'] = True data['clusters']['exists'] = True data['channels'] = self.one.load_object(self.eid, 'channels', collection=self.probe_collection, attribute=['rawInd', 'localCoordinates']) data['channels']['exists'] = True + # Set low firing rate clusters to bad + + except alf.exceptions.ALFObjectNotFound: logger.error(f'Could not load spike sorting for probe insertion {self.probe_id}, GUI' f' will not work') diff --git a/atlaselectrophysiology/plot_data.py b/atlaselectrophysiology/plot_data.py index 97043dc..e3dd517 100644 --- a/atlaselectrophysiology/plot_data.py +++ b/atlaselectrophysiology/plot_data.py @@ -1,6 +1,6 @@ from matplotlib import cm import numpy as np -from brainbox.processing import bincount2D +from iblutil.numerical import bincount2D from brainbox.io.spikeglx import Streamer from brainbox.population.decode import xcorr from brainbox.task import passive @@ -596,7 +596,7 @@ def get_autocorr(self, clust_idx): autocorr = xcorr(self.data['spikes']['times'][idx], self.data['spikes']['clusters'][idx], AUTOCORR_BIN_SIZE, AUTOCORR_WIN_SIZE) - return autocorr[0, 0, :], self.clust_id[clust_idx] + return autocorr[0, 0, :], self.data['clusters'].metrics.cluster_id[self.clust_id[clust_idx]] def get_template_wf(self, clust_idx): template_wf = (self.data['clusters']['waveforms'][self.clust_id[clust_idx], :, 0])