Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/master' into develop
Browse files Browse the repository at this point in the history
  • Loading branch information
mayofaulkner committed Sep 3, 2024
2 parents 49d6829 + a84e54c commit fb7f809
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 19 deletions.
8 changes: 4 additions & 4 deletions atlasview/atlasview.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,10 +261,10 @@ class ControllerTopView(PgImageController):
"""
TopView ControllerTopView
"""
def __init__(self, qmain: TopView, res: int = 25, volume='image', **kwargs):
def __init__(self, qmain: TopView, res: int = 25, volume='image', atlas=None, **kwargs):
super(ControllerTopView, self).__init__(qmain)
self.volume = volume
self.atlas = AllenAtlas(res)
self.atlas = AllenAtlas(res) if atlas is None else atlas
self.fig_top = self.qwidget = qmain
# Setup Coronal slice: width: ml, height: dv, depth: ap
self.fig_coronal = SliceView(qmain, waxis=0, haxis=2, daxis=1)
Expand Down Expand Up @@ -362,10 +362,10 @@ class ImageLayer:
slice_kwargs: dict = field(default_factory=lambda: {'volume': 'image', 'mode': 'clip'})


def view(res=25, title=None, brainmap='Allen'):
def view(res=25, title=None, atlas=None):
""" application entry point """
qt.create_app()
av = TopView._get_or_create(title=title, res=res, brainmap=brainmap)
av = TopView._get_or_create(title=title, res=res, atlas=atlas)
av.show()
return av

Expand Down
65 changes: 50 additions & 15 deletions viewspikes/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,35 @@
(0.7372549019607844, 0.7411764705882353, 0.13333333333333333),
(0.09019607843137255, 0.7450980392156863, 0.8117647058823529)]

YMAX = 4000
YMIN, YMAX = (-1, 4000)


def get_trial_events_to_display(trials):
errors = trials['feedback_times'][trials['feedbackType'] == -1].values
errors = np.sort(np.r_[errors, errors + .5])
gocue = trials['goCue_times'].values
gocue = np.sort(np.r_[gocue, gocue + .11])
trial_events = dict(
goCue_times=gocue,
error_times=errors,
reward_times=trials['feedback_times'][trials['feedbackType'] == 1].values)
return trial_events


def view_raster(pid, one, stream=True):
from qt import create_app
app = create_app()
ssl = SpikeSortingLoader(one=one, pid=pid)
sl = EphysSessionLoader(one=one, eid=ssl.eid)
sl.load_trials()
spikes, clusters, channels = ssl.load_spike_sorting(dataset_types=['spikes.samples'])

clusters = ssl.merge_clusters(spikes, clusters, channels)
return RasterView(ssl, spikes, clusters, channels, trials=sl.trials, stream=stream)


class RasterView(QtWidgets.QMainWindow):
plotItem_raster: pg.PlotWidget = None

def __init__(self, ssl, spikes, clusters, channels=None, trials=None, stream=True, *args, **kwargs):
self.ssl = ssl
self.spikes = spikes
Expand Down Expand Up @@ -81,18 +97,14 @@ def __init__(self, ssl, spikes, clusters, channels=None, trials=None, stream=Tru
self.plotItem_raster.addItem(self.line_eqc)
################################################## plot trials
if self.trials is not None:
trial_times = dict(
goCue_times=trials['goCue_times'].values,
error_times=trials['feedback_times'][trials['feedbackType'] == -1].values,
reward_times=trials['feedback_times'][trials['feedbackType'] == 1].values)
trial_times = get_trial_events_to_display(trials)
self.trial_lines = {}
for i, k in enumerate(trial_times):
self.trial_lines[k] = pg.PlotCurveItem()
self.plotItem_raster.addItem(self.trial_lines[k])
x = np.tile(trial_times[k][:, np.newaxis], (1, 2)).flatten()
y = np.tile(np.array([0, 1, 1, 0]), int(trial_times[k].shape[0] / 2 + 1))[
:trial_times[k].shape[0] * 2] * YMAX
self.trial_lines[k].setData(x=x.flatten(), y=y.flatten(), pen=pg.mkPen(np.array(SNS_PALETTE[i]) * 256))
y = np.tile(np.array([YMIN, YMAX, YMAX, YMIN]), int(trial_times[k].shape[0] / 2 + 1))[:trial_times[k].shape[0] * 2]
self.trial_lines[k].setData(x=x.flatten(), y=y.flatten(), pen=pg.mkPen(np.array(SNS_PALETTE[i]) * 255, width=2))
self.show()

def mouseClick(self, event):
Expand All @@ -101,7 +113,7 @@ def mouseClick(self, event):
return
qxy = self.imageItem_raster.mapFromScene(event.scenePos())
x = qxy.x()
self.show_ephys(t0=self.rtimes[int(x - .5)])
self.show_ephys(t0=self.rtimes[int(x - 1)])
ymax = np.max(self.depths) + 50
self.line_eqc.setData(x=x + np.array([-.5, -.5, .5, .5]),
y=np.array([0, ymax, ymax, 0]),
Expand All @@ -124,8 +136,13 @@ def keyPressEvent(self, e):
m == QtCore.Qt.ControlModifier and k == QtCore.Qt.Key_Z):
self.imageItem_raster.setLevels([0, self.imageItem_raster.levels[1] * 1.4])

def show_ephys(self, t0, tlen=1):

def show_ephys(self, t0, tlen=1.8):
"""
:param t0: behaviour time in seconds at which to start the view
:param tlen:
:return:
"""
print(t0)
s0 = int(self.ssl.samples2times(t0, direction='reverse'))
s1 = s0 + int(self.sr.fs * tlen)
raw = self.sr[s0:s1, : - self.sr.nsync].T
Expand All @@ -146,6 +163,24 @@ def show_ephys(self, t0, tlen=1):
# we slice the spikes using the samples according to ephys time, but display in session times
slice_spikes = slice(np.searchsorted(self.spikes['samples'], s0), np.searchsorted(self.spikes['samples'], s1))
t = self.spikes['times'][slice_spikes]
c = self.clusters.channels[self.spikes.clusters[slice_spikes]]
self.eqc_raw.ctrl.add_scatter(t, c)
self.eqc_des.ctrl.add_scatter(t, c)
ic = self.spikes.clusters[slice_spikes]

iok = self.clusters['label'][ic] == 1

for eqc in [self.eqc_des, self.eqc_raw]:
eqc.ctrl.add_scatter(t[~iok], self.clusters.channels[ic[~iok]], (255, 0, 0, 100), label='bad units')
eqc.ctrl.add_scatter(t[iok], self.clusters.channels[ic[iok]], rgb=(0, 255, 0, 100), label='good units')

if self.trials is not None:
trial_events = get_trial_events_to_display(self.trials)
for i, k in enumerate(trial_events):
ie = np.logical_and(trial_events[k] >= t0, trial_events[k] <= (t0 + tlen))
if np.sum(ie) == 0:
continue
te = trial_events[k][ie]
x = np.tile(te[:, np.newaxis], (1, 2)).flatten()
y = np.tile(np.array([YMIN, YMAX, YMAX, YMIN]), int(te.shape[0] / 2 + 1))[:te.shape[0] * 2]
for eqc in [self.eqc_des, self.eqc_raw]:
eqc.ctrl.add_curve(x, y, rgb=(np.array(SNS_PALETTE[i]) * 255).astype(int), label=k)
print(te)

0 comments on commit fb7f809

Please sign in to comment.