diff --git a/.github/workflows/ibllib_ci.yml b/.github/workflows/ibllib_ci.yml index 138fcead5..7a8fdcabc 100644 --- a/.github/workflows/ibllib_ci.yml +++ b/.github/workflows/ibllib_ci.yml @@ -35,12 +35,12 @@ jobs: - name: Install deps run: | python -m pip install --upgrade pip - python -m pip install flake8 pytest flake8-docstrings + python -m pip install flake8 ruff flake8-docstrings pip install -r requirements.txt pip install -e . - - name: Flake8 + - name: Flake8 & ruff run: | - python -m flake8 + python -m ruff check . python -m flake8 --select D --ignore E ibllib/qc/camera.py ibllib/qc/task_metrics.py - name: Brainbox tests run: | diff --git a/brainbox/metrics/single_units.py b/brainbox/metrics/single_units.py index f1f0b46a3..cb805d756 100644 --- a/brainbox/metrics/single_units.py +++ b/brainbox/metrics/single_units.py @@ -105,11 +105,11 @@ def unit_stability(units_b, units=None, feat_names=['amps'], dist='norm', test=' """ # Get units. - if not (units is None): # we're using a subset of all units + if units is not None: # we're using a subset of all units unit_list = list(units_b[feat_names[0]].keys()) # for each `feat` and unit in `unit_list`, remove unit from `units_b` if not in `units` for feat in feat_names: - [units_b[feat].pop(unit) for unit in unit_list if not (int(unit) in units)] + [units_b[feat].pop(unit) for unit in unit_list if int(unit) not in units] unit_list = list(units_b[feat_names[0]].keys()) # get new `unit_list` after removing units # Initialize `p_vals` and `variances`. diff --git a/brainbox/plot.py b/brainbox/plot.py index 808c88d09..0e5edef3b 100644 --- a/brainbox/plot.py +++ b/brainbox/plot.py @@ -77,10 +77,10 @@ def feat_vars(units_b, units=None, feat_name='amps', dist='norm', test='ks', cma ''' # Get units. - if not (units is None): # we're using a subset of all units + if units is not None: # we're using a subset of all units unit_list = list(units_b['depths'].keys()) # For each unit in `unit_list`, remove unit from `units_b` if not in `units`. - [units_b['depths'].pop(unit) for unit in unit_list if not (int(unit) in units)] + [units_b['depths'].pop(unit) for unit in unit_list if int(unit) not in units] unit_list = list(units_b['depths'].keys()) # get new `unit_list` after removing unit # Calculate coefficients of variation for all units diff --git a/ibllib/__init__.py b/ibllib/__init__.py index 2626bdf02..0a7f52d6a 100644 --- a/ibllib/__init__.py +++ b/ibllib/__init__.py @@ -1,6 +1,7 @@ """Library implementing the International Brain Laboratory data pipeline.""" import logging import warnings +import os __version__ = '3.4.1' warnings.filterwarnings('always', category=DeprecationWarning, module='ibllib') @@ -8,11 +9,6 @@ # if this becomes a full-blown library we should let the logging configuration to the discretion of the dev # who uses the library. However since it can also be provided as an app, the end-users should be provided # with a useful default logging in standard output without messing with the complex python logging system -USE_LOGGING = True -#%(asctime)s,%(msecs)d -if USE_LOGGING: +if os.environ.get('IBLLIB_USE_LOGGING', '1').casefold() in ('1', 'true', 'yes'): from iblutil.util import setup_logger setup_logger(name='ibllib', level=logging.INFO) -else: - # deactivate all log calls for use as a library - logging.getLogger('ibllib').addHandler(logging.NullHandler()) diff --git a/ibllib/io/extractors/camera.py b/ibllib/io/extractors/camera.py index 76ae83279..b8cbe2d4b 100644 --- a/ibllib/io/extractors/camera.py +++ b/ibllib/io/extractors/camera.py @@ -683,13 +683,13 @@ def groom_pin_state(gpio, ttl, ts, tolerance=2., display=False, take='first', mi _logger.warning('Some onsets but not offsets (or vice versa) were not assigned; ' 'this may be a sign of faulty wiring or clock drift') # Find indices of GPIO upticks where only the downtick was marked for removal - orphaned_onsets, = np.where(~to_remove.reshape(-1, 2)[:, 0] & orphaned) + orphaned_onsets, = np.where(~to_remove.reshape(-1, 2)[:, 0] & orphaned) # The onsets_ array already has the other TTLs removed (same size as to_remove == # False) so subtract the number of removed elements from index. for i, v in enumerate(orphaned_onsets): orphaned_onsets[i] -= to_remove.reshape(-1, 2)[:v, 0].sum() # Same for offsets... - orphaned_offsets, = np.where(~to_remove.reshape(-1, 2)[:, 1] & orphaned) + orphaned_offsets, = np.where(~to_remove.reshape(-1, 2)[:, 1] & orphaned) for i, v in enumerate(orphaned_offsets): orphaned_offsets[i] -= to_remove.reshape(-1, 2)[:v, 1].sum() # Remove orphaned ttl onsets and offsets diff --git a/ibllib/plots/__init__.py b/ibllib/plots/__init__.py index 66b53c203..473e296b2 100644 --- a/ibllib/plots/__init__.py +++ b/ibllib/plots/__init__.py @@ -1 +1,326 @@ -from ibllib.plots.misc import * +from math import pi + +import numpy as np +import matplotlib.pyplot as plt +import scipy + +import ibldsp as dsp + + +def wiggle(w, fs=1, gain=0.71, color='k', ax=None, fill=True, linewidth=0.5, t0=0, clip=2, sf=None, + **kwargs): + """ + Matplotlib display of wiggle traces + + :param w: 2D array (numpy array dimension nsamples, ntraces) + :param fs: sampling frequency + :param gain: display gain ; Note that if sf is given, gain is not used + :param color: ('k') color of traces + :param ax: (None) matplotlib axes object + :param fill: (True) fill variable area above 0 + :param t0: (0) timestamp of the first sample + :param sf: scaling factor ; if None, uses the gain / SQRT of waveform RMS + :return: None + """ + nech, ntr = w.shape + tscale = np.arange(nech) / fs + if sf is None: + sf = gain / np.sqrt(dsp.utils.rms(w.flatten())) + + def insert_zeros(trace): + # Insert zero locations in data trace and tt vector based on linear fit + # Find zeros + zc_idx = np.where(np.diff(np.signbit(trace)))[0] + x1 = tscale[zc_idx] + x2 = tscale[zc_idx + 1] + y1 = trace[zc_idx] + y2 = trace[zc_idx + 1] + a = (y2 - y1) / (x2 - x1) + tt_zero = x1 - y1 / a + # split tt and trace + tt_split = np.split(tscale, zc_idx + 1) + trace_split = np.split(trace, zc_idx + 1) + tt_zi = tt_split[0] + trace_zi = trace_split[0] + # insert zeros in tt and trace + for i in range(len(tt_zero)): + tt_zi = np.hstack( + (tt_zi, np.array([tt_zero[i]]), tt_split[i + 1])) + trace_zi = np.hstack( + (trace_zi, np.zeros(1), trace_split[i + 1])) + return trace_zi, tt_zi + + if not ax: + ax = plt.gca() + for ntr in range(ntr): + if fill: + trace, t_trace = insert_zeros(w[:, ntr] * sf) + if clip: + trace = np.maximum(np.minimum(trace, clip), -clip) + ax.fill_betweenx(t_trace + t0, ntr, trace + ntr, + where=trace >= 0, + facecolor=color, + linewidth=linewidth) + wplot = np.minimum(np.maximum(w[:, ntr] * sf, -clip), clip) + ax.plot(wplot + ntr, tscale + t0, color, linewidth=linewidth, **kwargs) + + ax.set_xlim(-1, ntr + 1) + ax.set_ylim(tscale[0] + t0, tscale[-1] + t0) + ax.set_ylabel('Time (s)') + ax.set_xlabel('Trace') + ax.invert_yaxis() + + return ax + + +class Density: + def __init__(self, w, fs=30_000, cmap='Greys_r', ax=None, taxis=0, title=None, gain=None, t0=0, unit='ms', **kwargs): + """ + Matplotlib display of traces as a density display using `imshow()`. + + :param w: 2D array (numpy array dimension nsamples, ntraces) + :param fs: sampling frequency (Hz). [default: 30000] + :param cmap: Name of MPL colormap to use in `imshow()`. [default: 'Greys_r'] + :param ax: Axis to plot in. If `None`, a new one is created. [default: `None`] + :param taxis: Time axis of input array (w). [default: 0] + :param title: Title to display on plot. [default: `None`] + :param gain: Gain in dB to display. Note: overrides `vmin` and `vmax` kwargs to `imshow()`. + Default: [`None` (auto)] + :param t0: Time offset to display in seconds. [default: 0] + :param kwargs: Key word arguments passed to `imshow()` + :param t_scalar: 1e3 for ms (default), 1 for s + :return: None + """ + w = w.reshape(w.shape[0], -1) + t_scalar = 1e3 if unit == 'ms' else 1 + if taxis == 0: + nech, ntr = w.shape + tscale = np.array([0, nech - 1]) / fs * t_scalar + extent = [-0.5, ntr - 0.5, tscale[1] + t0 * t_scalar, tscale[0] + t0 * t_scalar] + xlabel, ylabel, origin = ('Trace', f'Time ({unit})', 'upper') + elif taxis == 1: + ntr, nech = w.shape + tscale = np.array([0, nech - 1]) / fs * t_scalar + extent = [tscale[0] + t0 * t_scalar, tscale[1] + t0 * t_scalar, -0.5, ntr - 0.5] + ylabel, xlabel, origin = ('Trace', f'Time ({unit})', 'lower') + if ax is None: + self.figure, ax = plt.subplots() + else: + self.figure = ax.get_figure() + if gain: + kwargs["vmin"] = - 4 * (10 ** (gain / 20)) + kwargs["vmax"] = -kwargs["vmin"] + self.im = ax.imshow(w, aspect='auto', cmap=cmap, extent=extent, origin=origin, **kwargs) + ax.set_ylabel(ylabel) + ax.set_xlabel(xlabel) + self.cid_key = self.figure.canvas.mpl_connect('key_press_event', self.on_key_press) + self.ax = ax + self.title = title or None + + def on_key_press(self, event): + if event.key == 'ctrl+a': + self.im.set_data(self.im.get_array() * np.sqrt(2)) + elif event.key == 'ctrl+z': + self.im.set_data(self.im.get_array() / np.sqrt(2)) + else: + return + self.figure.canvas.draw() + + +class Traces: + def __init__(self, w, fs=1, gain=0.71, color='k', ax=None, linewidth=0.5, t0=0, **kwargs): + """ + Matplotlib display of traces as a density display + + :param w: 2D array (numpy array dimension nsamples, ntraces) + :param fs: sampling frequency (Hz) + :param ax: axis to plot in + :return: None + """ + w = w.reshape(w.shape[0], -1) + nech, ntr = w.shape + tscale = np.arange(nech) / fs * 1e3 + sf = gain / dsp.utils.rms(w.flatten()) / 2 + if ax is None: + self.figure, ax = plt.subplots() + else: + self.figure = ax.get_figure() + self.plot = ax.plot(w * sf + np.arange(ntr), tscale + t0, color, + linewidth=linewidth, **kwargs) + ax.set_xlim(-1, ntr + 1) + ax.set_ylim(tscale[0] + t0, tscale[-1] + t0) + ax.set_ylabel('Time (ms)') + ax.set_xlabel('Trace') + ax.invert_yaxis() + self.cid_key = self.figure.canvas.mpl_connect('key_press_event', self.on_key_press) + self.ax = ax + + def on_key_press(self, event): + if event.key == 'ctrl+a': + for i, ln in enumerate(self.plot): + ln.set_xdata((ln.get_xdata() - i) * np.sqrt(2) + i) + elif event.key == 'ctrl+z': + for i, ln in enumerate(self.plot): + ln.set_xdata((ln.get_xdata() - i) / np.sqrt(2) + i) + else: + return + self.figure.canvas.draw() + + +def squares(tscale, polarity, ax=None, yrange=[-1, 1], **kwargs): + """ + Matplotlib display of rising and falling fronts in a square-wave pattern + + :param tscale: time of indices of fronts + :param polarity: polarity of front (1: rising, -1:falling) + :param ax: matplotlib axes object + :return: None + """ + if not ax: + ax = plt.gca() + isort = np.argsort(tscale) + tscale = tscale[isort] + polarity = polarity[isort] + f = np.tile(polarity, (2, 1)) + t = np.concatenate((tscale, np.r_[tscale[1:], tscale[-1]])).reshape(2, f.shape[1]) + ydata = f.transpose().ravel() + ydata = (ydata + 1) / 2 * (yrange[1] - yrange[0]) + yrange[0] + ax.plot(t.transpose().ravel(), ydata, **kwargs) + + +def vertical_lines(x, ymin=0, ymax=1, ax=None, **kwargs): + """ + From an x vector, draw separate vertical lines at each x location ranging from ymin to ymax + + :param x: numpy array vector of x values where to display lines + :param ymin: lower end of the lines (scalar) + :param ymax: higher end of the lines (scalar) + :param ax: (optional) matplotlib axis instance + :return: None + """ + x = np.tile(x, (3, 1)) + x[2, :] = np.nan + y = np.zeros_like(x) + y[0, :] = ymin + y[1, :] = ymax + y[2, :] = np.nan + if not ax: + ax = plt.gca() + ax.plot(x.T.flatten(), y.T.flatten(), **kwargs) + + +def spectrum(w, fs, smooth=None, unwrap=True, axis=0, **kwargs): + """ + Display spectral density of a signal along a given dimension + spectrum(w, fs) + :param w: signal + :param fs: sampling frequency (Hz) + :param smooth: (None) frequency samples to smooth over + :param unwrap: (True) unwraps the phase specrum + :param axis: axis on which to compute the FFT + :param kwargs: plot arguments to be passed to matplotlib + :return: matplotlib axes + """ + axis = 0 + smooth = None + unwrap = True + + ns = w.shape[axis] + fscale = dsp.fourier.fscale(ns, 1 / fs, one_sided=True) + W = scipy.fft.rfft(w, axis=axis) + amp = 20 * np.log10(np.abs(W)) + phi = np.angle(W) + + if unwrap: + phi = np.unwrap(phi) + + if smooth: + nf = np.round(smooth / fscale[1] / 2) * 2 + 1 + amp = scipy.signal.medfilt(amp, nf) + phi = scipy.signal.medfilt(phi, nf) + + fig, ax = plt.subplots(2, 1, sharex=True) + ax[0].plot(fscale, amp, **kwargs) + ax[1].plot(fscale, phi, **kwargs) + + ax[0].set_title('Spectral Density (dB rel to amplitude.Hz^-0.5)') + ax[0].set_ylabel('Amp (dB)') + ax[1].set_ylabel('Phase (rad)') + ax[1].set_xlabel('Frequency (Hz)') + return ax + + +def color_cycle(ind=None): + """ + Gets the matplotlib color-cycle as RGB numpy array of floats between 0 and 1 + :return: + """ + # import matplotlib as mpl + # c = np.uint32(np.array([int(c['color'][1:], 16) for c in mpl.rcParams['axes.prop_cycle']])) + # c = np.double(np.flip(np.reshape(c.view(np.uint8), (c.size, 4))[:, :3], 1)) / 255 + c = np.array([[0.12156863, 0.46666667, 0.70588235], + [1., 0.49803922, 0.05490196], + [0.17254902, 0.62745098, 0.17254902], + [0.83921569, 0.15294118, 0.15686275], + [0.58039216, 0.40392157, 0.74117647], + [0.54901961, 0.3372549, 0.29411765], + [0.89019608, 0.46666667, 0.76078431], + [0.49803922, 0.49803922, 0.49803922], + [0.7372549, 0.74117647, 0.13333333], + [0.09019608, 0.74509804, 0.81176471]]) + if ind is None: + return c + else: + return tuple(c[ind % c.shape[0], :]) + + +def starplot(labels, radii, ticks=None, ax=None, ylim=None, color=None, title=None): + """ + Function to create a star plot (also known as a spider plot, polar plot, or radar chart). + + Parameters: + labels (list): A list of labels for the variables to be plotted along the axes. + radii (numpy array): The values to be plotted for each variable. + ticks (numpy array, optional): A list of values to be used for the radial ticks. + If None, 5 ticks will be created between the minimum and maximum values of radii. + ax (matplotlib.axes._subplots.PolarAxesSubplot, optional): A polar axis object to plot on. + If None, a new figure and axis will be created. + ylim (tuple, optional): A tuple specifying the upper and lower limits of the y-axis. + If None, the limits will be set to the minimum and maximum values of radii. + color (str, optional): A string specifying the color of the plot. + If None, the color will be determined by the current matplotlib color cycle. + title (str, optional): A string specifying the title of the plot. + If None, no title will be displayed. + + Returns: + ax (matplotlib.axes._subplots.PolarAxesSubplot): The polar axis object containing the plot. + """ + + # What will be the angle of each axis in the plot? (we divide the plot / number of variable) + angles = [n / float(radii.size) * 2 * pi for n in range(radii.size)] + angles += angles[:1] + + if ax is None: + # Initialise the spider plot + fig = plt.figure(figsize=(8, 8)) + ax = fig.add_subplot(111, polar=True) + # If you want the first axis to be on top: + ax.set_theta_offset(pi / 2) + ax.set_theta_direction(-1) + # Draw one axe per variable + add labels + plt.xticks(angles[:-1], labels) + # Draw ylabels + ax.set_rlabel_position(0) + if ylim is None: + ylim = (0, np.max(radii)) + if ticks is None: + ticks = np.linspace(ylim[0], ylim[1], 5) + plt.yticks(ticks, [f'{t:2.2f}' for t in ticks], color="grey", size=7) + plt.ylim(ylim) + + r = np.r_[radii, radii[0]] + p = ax.plot(angles, r, linewidth=1, linestyle='solid', label="group A", color=color) + ax.fill(angles, r, alpha=0.1, color=p[0].get_color()) + if title is not None: + ax.set_title(title) + return ax diff --git a/ibllib/plots/misc.py b/ibllib/plots/misc.py index 288d110d5..801a512b9 100644 --- a/ibllib/plots/misc.py +++ b/ibllib/plots/misc.py @@ -1,328 +1,5 @@ #!/usr/bin/env python # -*- coding:utf-8 -*- -from math import pi - -import numpy as np -import matplotlib.pyplot as plt -import scipy - -import ibldsp as dsp - - -def wiggle(w, fs=1, gain=0.71, color='k', ax=None, fill=True, linewidth=0.5, t0=0, clip=2, sf=None, - **kwargs): - """ - Matplotlib display of wiggle traces - - :param w: 2D array (numpy array dimension nsamples, ntraces) - :param fs: sampling frequency - :param gain: display gain ; Note that if sf is given, gain is not used - :param color: ('k') color of traces - :param ax: (None) matplotlib axes object - :param fill: (True) fill variable area above 0 - :param t0: (0) timestamp of the first sample - :param sf: scaling factor ; if None, uses the gain / SQRT of waveform RMS - :return: None - """ - nech, ntr = w.shape - tscale = np.arange(nech) / fs - if sf is None: - sf = gain / np.sqrt(dsp.utils.rms(w.flatten())) - - def insert_zeros(trace): - # Insert zero locations in data trace and tt vector based on linear fit - # Find zeros - zc_idx = np.where(np.diff(np.signbit(trace)))[0] - x1 = tscale[zc_idx] - x2 = tscale[zc_idx + 1] - y1 = trace[zc_idx] - y2 = trace[zc_idx + 1] - a = (y2 - y1) / (x2 - x1) - tt_zero = x1 - y1 / a - # split tt and trace - tt_split = np.split(tscale, zc_idx + 1) - trace_split = np.split(trace, zc_idx + 1) - tt_zi = tt_split[0] - trace_zi = trace_split[0] - # insert zeros in tt and trace - for i in range(len(tt_zero)): - tt_zi = np.hstack( - (tt_zi, np.array([tt_zero[i]]), tt_split[i + 1])) - trace_zi = np.hstack( - (trace_zi, np.zeros(1), trace_split[i + 1])) - return trace_zi, tt_zi - - if not ax: - ax = plt.gca() - for ntr in range(ntr): - if fill: - trace, t_trace = insert_zeros(w[:, ntr] * sf) - if clip: - trace = np.maximum(np.minimum(trace, clip), -clip) - ax.fill_betweenx(t_trace + t0, ntr, trace + ntr, - where=trace >= 0, - facecolor=color, - linewidth=linewidth) - wplot = np.minimum(np.maximum(w[:, ntr] * sf, -clip), clip) - ax.plot(wplot + ntr, tscale + t0, color, linewidth=linewidth, **kwargs) - - ax.set_xlim(-1, ntr + 1) - ax.set_ylim(tscale[0] + t0, tscale[-1] + t0) - ax.set_ylabel('Time (s)') - ax.set_xlabel('Trace') - ax.invert_yaxis() - - return ax - - -class Density: - def __init__(self, w, fs=30_000, cmap='Greys_r', ax=None, taxis=0, title=None, gain=None, t0=0, unit='ms', **kwargs): - """ - Matplotlib display of traces as a density display using `imshow()`. - - :param w: 2D array (numpy array dimension nsamples, ntraces) - :param fs: sampling frequency (Hz). [default: 30000] - :param cmap: Name of MPL colormap to use in `imshow()`. [default: 'Greys_r'] - :param ax: Axis to plot in. If `None`, a new one is created. [default: `None`] - :param taxis: Time axis of input array (w). [default: 0] - :param title: Title to display on plot. [default: `None`] - :param gain: Gain in dB to display. Note: overrides `vmin` and `vmax` kwargs to `imshow()`. - Default: [`None` (auto)] - :param t0: Time offset to display in seconds. [default: 0] - :param kwargs: Key word arguments passed to `imshow()` - :param t_scalar: 1e3 for ms (default), 1 for s - :return: None - """ - w = w.reshape(w.shape[0], -1) - t_scalar = 1e3 if unit == 'ms' else 1 - if taxis == 0: - nech, ntr = w.shape - tscale = np.array([0, nech - 1]) / fs * t_scalar - extent = [-0.5, ntr - 0.5, tscale[1] + t0 * t_scalar, tscale[0] + t0 * t_scalar] - xlabel, ylabel, origin = ('Trace', f'Time ({unit})', 'upper') - elif taxis == 1: - ntr, nech = w.shape - tscale = np.array([0, nech - 1]) / fs * t_scalar - extent = [tscale[0] + t0 * t_scalar, tscale[1] + t0 * t_scalar, -0.5, ntr - 0.5] - ylabel, xlabel, origin = ('Trace', f'Time ({unit})', 'lower') - if ax is None: - self.figure, ax = plt.subplots() - else: - self.figure = ax.get_figure() - if gain: - kwargs["vmin"] = - 4 * (10 ** (gain / 20)) - kwargs["vmax"] = -kwargs["vmin"] - self.im = ax.imshow(w, aspect='auto', cmap=cmap, extent=extent, origin=origin, **kwargs) - ax.set_ylabel(ylabel) - ax.set_xlabel(xlabel) - self.cid_key = self.figure.canvas.mpl_connect('key_press_event', self.on_key_press) - self.ax = ax - self.title = title or None - - def on_key_press(self, event): - if event.key == 'ctrl+a': - self.im.set_data(self.im.get_array() * np.sqrt(2)) - elif event.key == 'ctrl+z': - self.im.set_data(self.im.get_array() / np.sqrt(2)) - else: - return - self.figure.canvas.draw() - - -class Traces: - def __init__(self, w, fs=1, gain=0.71, color='k', ax=None, linewidth=0.5, t0=0, **kwargs): - """ - Matplotlib display of traces as a density display - - :param w: 2D array (numpy array dimension nsamples, ntraces) - :param fs: sampling frequency (Hz) - :param ax: axis to plot in - :return: None - """ - w = w.reshape(w.shape[0], -1) - nech, ntr = w.shape - tscale = np.arange(nech) / fs * 1e3 - sf = gain / dsp.utils.rms(w.flatten()) / 2 - if ax is None: - self.figure, ax = plt.subplots() - else: - self.figure = ax.get_figure() - self.plot = ax.plot(w * sf + np.arange(ntr), tscale + t0, color, - linewidth=linewidth, **kwargs) - ax.set_xlim(-1, ntr + 1) - ax.set_ylim(tscale[0] + t0, tscale[-1] + t0) - ax.set_ylabel('Time (ms)') - ax.set_xlabel('Trace') - ax.invert_yaxis() - self.cid_key = self.figure.canvas.mpl_connect('key_press_event', self.on_key_press) - self.ax = ax - - def on_key_press(self, event): - if event.key == 'ctrl+a': - for i, l in enumerate(self.plot): - l.set_xdata((l.get_xdata() - i) * np.sqrt(2) + i) - elif event.key == 'ctrl+z': - for i, l in enumerate(self.plot): - l.set_xdata((l.get_xdata() - i) / np.sqrt(2) + i) - else: - return - self.figure.canvas.draw() - - -def squares(tscale, polarity, ax=None, yrange=[-1, 1], **kwargs): - """ - Matplotlib display of rising and falling fronts in a square-wave pattern - - :param tscale: time of indices of fronts - :param polarity: polarity of front (1: rising, -1:falling) - :param ax: matplotlib axes object - :return: None - """ - if not ax: - ax = plt.gca() - isort = np.argsort(tscale) - tscale = tscale[isort] - polarity = polarity[isort] - f = np.tile(polarity, (2, 1)) - t = np.concatenate((tscale, np.r_[tscale[1:], tscale[-1]])).reshape(2, f.shape[1]) - ydata = f.transpose().ravel() - ydata = (ydata + 1) / 2 * (yrange[1] - yrange[0]) + yrange[0] - ax.plot(t.transpose().ravel(), ydata, **kwargs) - - -def vertical_lines(x, ymin=0, ymax=1, ax=None, **kwargs): - """ - From an x vector, draw separate vertical lines at each x location ranging from ymin to ymax - - :param x: numpy array vector of x values where to display lines - :param ymin: lower end of the lines (scalar) - :param ymax: higher end of the lines (scalar) - :param ax: (optional) matplotlib axis instance - :return: None - """ - x = np.tile(x, (3, 1)) - x[2, :] = np.nan - y = np.zeros_like(x) - y[0, :] = ymin - y[1, :] = ymax - y[2, :] = np.nan - if not ax: - ax = plt.gca() - ax.plot(x.T.flatten(), y.T.flatten(), **kwargs) - - -def spectrum(w, fs, smooth=None, unwrap=True, axis=0, **kwargs): - """ - Display spectral density of a signal along a given dimension - spectrum(w, fs) - :param w: signal - :param fs: sampling frequency (Hz) - :param smooth: (None) frequency samples to smooth over - :param unwrap: (True) unwraps the phase specrum - :param axis: axis on which to compute the FFT - :param kwargs: plot arguments to be passed to matplotlib - :return: matplotlib axes - """ - axis = 0 - smooth = None - unwrap = True - - ns = w.shape[axis] - fscale = dsp.fourier.fscale(ns, 1 / fs, one_sided=True) - W = scipy.fft.rfft(w, axis=axis) - amp = 20 * np.log10(np.abs(W)) - phi = np.angle(W) - - if unwrap: - phi = np.unwrap(phi) - - if smooth: - nf = np.round(smooth / fscale[1] / 2) * 2 + 1 - amp = scipy.signal.medfilt(amp, nf) - phi = scipy.signal.medfilt(phi, nf) - - fig, ax = plt.subplots(2, 1, sharex=True) - ax[0].plot(fscale, amp, **kwargs) - ax[1].plot(fscale, phi, **kwargs) - - ax[0].set_title('Spectral Density (dB rel to amplitude.Hz^-0.5)') - ax[0].set_ylabel('Amp (dB)') - ax[1].set_ylabel('Phase (rad)') - ax[1].set_xlabel('Frequency (Hz)') - return ax - - -def color_cycle(ind=None): - """ - Gets the matplotlib color-cycle as RGB numpy array of floats between 0 and 1 - :return: - """ - # import matplotlib as mpl - # c = np.uint32(np.array([int(c['color'][1:], 16) for c in mpl.rcParams['axes.prop_cycle']])) - # c = np.double(np.flip(np.reshape(c.view(np.uint8), (c.size, 4))[:, :3], 1)) / 255 - c = np.array([[0.12156863, 0.46666667, 0.70588235], - [1., 0.49803922, 0.05490196], - [0.17254902, 0.62745098, 0.17254902], - [0.83921569, 0.15294118, 0.15686275], - [0.58039216, 0.40392157, 0.74117647], - [0.54901961, 0.3372549, 0.29411765], - [0.89019608, 0.46666667, 0.76078431], - [0.49803922, 0.49803922, 0.49803922], - [0.7372549, 0.74117647, 0.13333333], - [0.09019608, 0.74509804, 0.81176471]]) - if ind is None: - return c - else: - return tuple(c[ind % c.shape[0], :]) - - -def starplot(labels, radii, ticks=None, ax=None, ylim=None, color=None, title=None): - """ - Function to create a star plot (also known as a spider plot, polar plot, or radar chart). - - Parameters: - labels (list): A list of labels for the variables to be plotted along the axes. - radii (numpy array): The values to be plotted for each variable. - ticks (numpy array, optional): A list of values to be used for the radial ticks. - If None, 5 ticks will be created between the minimum and maximum values of radii. - ax (matplotlib.axes._subplots.PolarAxesSubplot, optional): A polar axis object to plot on. - If None, a new figure and axis will be created. - ylim (tuple, optional): A tuple specifying the upper and lower limits of the y-axis. - If None, the limits will be set to the minimum and maximum values of radii. - color (str, optional): A string specifying the color of the plot. - If None, the color will be determined by the current matplotlib color cycle. - title (str, optional): A string specifying the title of the plot. - If None, no title will be displayed. - - Returns: - ax (matplotlib.axes._subplots.PolarAxesSubplot): The polar axis object containing the plot. - """ - - # What will be the angle of each axis in the plot? (we divide the plot / number of variable) - angles = [n / float(radii.size) * 2 * pi for n in range(radii.size)] - angles += angles[:1] - - if ax is None: - # Initialise the spider plot - fig = plt.figure(figsize=(8, 8)) - ax = fig.add_subplot(111, polar=True) - # If you want the first axis to be on top: - ax.set_theta_offset(pi / 2) - ax.set_theta_direction(-1) - # Draw one axe per variable + add labels - plt.xticks(angles[:-1], labels) - # Draw ylabels - ax.set_rlabel_position(0) - if ylim is None: - ylim = (0, np.max(radii)) - if ticks is None: - ticks = np.linspace(ylim[0], ylim[1], 5) - plt.yticks(ticks, [f'{t:2.2f}' for t in ticks], color="grey", size=7) - plt.ylim(ylim) - - r = np.r_[radii, radii[0]] - p = ax.plot(angles, r, linewidth=1, linestyle='solid', label="group A", color=color) - ax.fill(angles, r, alpha=0.1, color=p[0].get_color()) - if title is not None: - ax.set_title(title) - return ax +import warnings +from ibllib.plots import * # noqa: F403,F401 +warnings.warn('This module is deprecated and will be removed in future versions', DeprecationWarning) diff --git a/ibllib/qc/camera.py b/ibllib/qc/camera.py index b157c48e6..127dbf7c0 100644 --- a/ibllib/qc/camera.py +++ b/ibllib/qc/camera.py @@ -914,7 +914,7 @@ def check_focus(self, n=20, threshold=(100, 6), # Plot variance over frames ax = f.add_subplot(gs[len(roi), :]) ln = plt.plot(lpc_var) - [l.set_label(f'ROI #{i + 1}') for i, l in enumerate(ln)] + [_l.set_label(f'ROI #{i + 1}') for i, _l in enumerate(ln)] ax.axhline(threshold[0], 0, n, linestyle=':', color='r', label='lower threshold') ax.set(xlabel='Frame sample', ylabel='Variance of the Laplacian') plt.tight_layout() diff --git a/ibllib/qc/task_qc_viewer/task_qc.py b/ibllib/qc/task_qc_viewer/task_qc.py index b26060cde..6fbda3584 100644 --- a/ibllib/qc/task_qc_viewer/task_qc.py +++ b/ibllib/qc/task_qc_viewer/task_qc.py @@ -163,9 +163,9 @@ def create_plots(self, axes, plot_args['ymax'] = 3 ylabels = ['', 'frame2ttl', 'sound', ''] - for event, c, l in zip(trial_events, cycle(color_map), linestyle): + for event, c, ln in zip(trial_events, cycle(color_map), linestyle): if event in trial_data: - plots.vertical_lines(trial_data[event], label=event, color=c, linestyle=l, **plot_args) + plots.vertical_lines(trial_data[event], label=event, color=c, linestyle=ln, **plot_args) axes.legend(loc='upper left', fontsize='xx-small', bbox_to_anchor=(1, 0.5)) axes.set_yticks(list(range(plot_args['ymax'] + 1))) diff --git a/ruff.toml b/ruff.toml index 253516e9f..99ff8bc23 100644 --- a/ruff.toml +++ b/ruff.toml @@ -1,4 +1,23 @@ +preview = true line-length = 130 +exclude = [ + "brainbox/examples/", + "examples/" +] + + +[lint] +# Enable the same rule sets as flake8 +select = ["E", "F", "W", "W391"] + +# Ignore the same rules as in .flake8 config +ignore = [ + "E266", # too many leading '#' for block comment + "F841", # local variable assigned but never used +] + +[lint.extend-per-file-ignores] +"__init__.py" = ["F401"] # imported but unused [format] quote-style = "single"