Source code for imc.graphics

"""
Plotting functions and utilities to handle images.
"""

from __future__ import annotations
import typing as tp
from functools import wraps
import colorsys
from functools import partial

import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import seaborn as sns
from tqdm import tqdm

from skimage.exposure import equalize_hist as eq

from imc.types import Figure, Axis, Array, Series, ColorMap, Patch, AnnData, Path
from imc.utils import minmax_scale

# Default unit label for scale bars drawn by ``add_scale``; ``$\mu$`` is
# matplotlib mathtext for the Greek letter mu.
DEFAULT_PIXEL_UNIT_NAME = r"$\mu$m"

# Ten RGBA colors sampled at evenly spaced points of matplotlib's "tab10"
# colormap; ``merge_channels`` uses them as per-channel defaults.
DEFAULT_CHANNEL_COLORS = plt.get_cmap("tab10")(np.linspace(0.0, 1.0, num=10))


class InteractiveViewer:
    """
    An interactive image viewer for multiplexed images.

    One channel is displayed at a time. Keyboard shortcuts:

    - ``up_key`` (default "w"): show the previous channel;
    - ``down_key`` (default "s"): show the next channel;
    - ``log_key`` (default "l"): toggle a ``log1p`` transform of the displayed channel.

    Channel navigation wraps around at both ends.

    Parameters
    ----------
    obj: ROI | Array
        An ROI object (its ``stack``, ``channel_labels`` and ``name`` are used)
        or a numpy array with channels along the first axis.
    show: bool
        Whether to call ``plt.show(block=False)`` once the viewer is set up.
    up_key, down_key, log_key: str
        Keys bound to the previous-channel, next-channel and log-toggle actions.
        These keys are removed from matplotlib's ``keymap.*`` rcParams.
    **kwargs: dict
        Additional keyword arguments to pass to matplotlib.pyplot.imshow.
    """

    def __init__(
        self,
        obj: tp.Union[_roi.ROI, Array],
        show: bool = False,
        up_key: str = "w",
        down_key: str = "s",
        log_key: str = "l",
        **kwargs,
    ):
        # Note: closes every open figure before creating the viewer's own one.
        plt.close("all")
        is_array = isinstance(obj, np.ndarray)
        self.array = obj if is_array else obj.stack
        self.labels = ([""] * len(self.array)) if is_array else obj.channel_labels.tolist()
        self.suptitle = "" if is_array else obj.name
        self.up_key = up_key
        self.down_key = down_key
        self.log_key = log_key
        self.kwargs = kwargs

        # internal state
        self.index = 0
        self.n_channels = self.array.shape[0]
        self.transforms: tp.Set[str] = set()
        self.fig, self.ax = plt.subplots(num=self.suptitle)

        self.multi_slice_viewer()
        if show:
            plt.show(block=False)

    def multi_slice_viewer(self) -> Figure:
        """Draw the current channel, bind the keyboard handler and return the figure."""
        self.remove_keymap_conflicts({self.up_key, self.down_key, self.log_key})
        self.ax.imshow(self.array[self.index], **self.kwargs)
        self.img = self.ax.images[0]
        self._update_titles()
        if self.suptitle is not None:
            self.fig.suptitle(self.suptitle)
        self.ax.set(xlabel="X", ylabel="Y")
        # TODO: add colorbar scale
        self.fig.canvas.mpl_connect("key_press_event", self.process_key)
        return self.fig

    def _update_titles(self) -> None:
        """Show the channel index as the left title and its label as the default title."""
        self.ax.set_title(str(self.index), loc="left")
        if self.labels is not None:
            self.ax.set_title(self.labels[self.index])

    def remove_keymap_conflicts(self, new_keys_set: tp.Set) -> None:
        """
        Remove conflicts between viewer keyboard shortcuts and existing shortcuts.

        Each ``keymap.*`` list in ``plt.rcParams`` is edited in place so that the
        keys in ``new_keys_set`` are no longer bound to matplotlib's own actions.
        """
        for prop in plt.rcParams:
            if not prop.startswith("keymap."):
                continue
            keys = plt.rcParams[prop]
            for key in set(keys) & new_keys_set:
                keys.remove(key)

    def process_key(self, event) -> None:
        """Handle a key press: change channel or toggle the log transform, then redraw."""
        if event.key == self.up_key:
            self.previous_slice()
        elif event.key == self.down_key:
            self.next_slice()
        elif event.key == self.log_key:
            self.log_slice()
        self._update_titles()
        # Report active transformations under the x label.
        if self.transforms:
            trans = ", ".join(self.transforms)
            self.ax.set_xlabel("X" + f"\nTransformations: '{trans}'")
        else:
            self.ax.set_xlabel("X")
        self.fig.canvas.draw()

    def get_slice(self) -> Array:
        """Return the current channel with the active transformations applied."""
        a = self.array[self.index]
        if "log" in self.transforms:
            a = np.log1p(a)
        return a

    def set_image(self) -> None:
        """Update the displayed image and its color limits to the current slice."""
        a = self.get_slice()
        self.img.set_array(a)
        self.img.set_clim(a.min(), a.max())

    def previous_slice(self) -> None:
        """Go to the previous channel (wrapping around to the last one)."""
        self.index = (self.index - 1) % self.n_channels
        self.set_image()

    def next_slice(self) -> None:
        """Go to the next channel (wrapping around to the first one)."""
        self.index = (self.index + 1) % self.n_channels
        self.set_image()

    def log_slice(self) -> None:
        """Toggle the ``log1p`` transform and refresh the displayed image."""
        if "log" not in self.transforms:
            self.transforms.add("log")
        else:
            self.transforms.remove("log")
        self.set_image()
def get_volume() -> Array:
    """
    Download an example volumetric image.

    Fetches 137 consecutive JPEG slices from radiopaedia.org and stacks them
    into one array along a new first axis. Requires network access and the
    optional ``urlpath`` and ``imageio`` packages.
    """
    from urlpath import URL
    import imageio

    base_url = URL("https://prod-images-static.radiopaedia.org/images/")
    first_image_id = 53734044
    n_slices = 137

    def _fetch_slice(offset: int):
        # Image ids are consecutive, and file names use a 1-based slice number.
        response = (base_url / f"{first_image_id + offset}/{offset + 1}_gallery.jpeg").get()
        return imageio.read(response.content, format="jpeg").get_data(0)

    return np.asarray([_fetch_slice(offset) for offset in tqdm(range(n_slices))])
def close_plots(func) -> tp.Callable:
    """
    Decorator to close all plots on function exit.

    ``plt.close("all")`` runs whenever the decorated function exits, whether it
    returns normally or raises. The decorated function's return value is
    passed through to the caller.
    """

    @wraps(func)
    def close(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        finally:
            plt.close("all")

    return close
def add_scale(
    _ax: tp.Optional[Axis] = None,
    width: int = 100,
    unit: str = DEFAULT_PIXEL_UNIT_NAME,
) -> None:
    """
    Add a black scale bar with a "<width><unit>" label to an axis.

    Should be called after plotting (usually with matplotlib.pyplot.imshow).
    The bar starts at x = 3 and sits at negative y in data coordinates.
    ``clip_on=False`` lets it be drawn outside the axes area.

    Parameters
    ----------
    _ax: Axis, optional
        Axis to draw on. Defaults to the current axis.
    width: int
        Bar length in data units; also used as the label's number.
    unit: str
        Unit appended to the label.
    """
    ax = plt.gca() if _ax is None else _ax
    # Proportions were tuned on a 1000 x 1000 reference figure. The bar height
    # is 1/40 of (sum of the y-limits + 1); for imshow output this is
    # presumably the image height in pixels — TODO confirm for other plots.
    bar_height = (1 / 40) * (sum(ax.get_ylim()) + 1)
    x0 = 3
    y0 = -bar_height - 3
    label_gap = (40 / 1000) * bar_height

    bar = mpatches.Rectangle((x0, y0), width, bar_height, color="black", clip_on=False)
    ax.add_patch(bar)
    ax.text(
        x0 + width + 10,
        y0 + bar_height + label_gap,
        s=f"{width}{unit}",
        color="black",
        ha="left",
        fontsize=4,
    )
def add_minmax(minmax: tp.Tuple[float, float], _ax: tp.Optional[Axis] = None) -> None:
    """
    Annotate an axis with the value range of the displayed array.

    Writes "Range: <min> -> <max>" (two decimals) right-aligned at the right
    x-limit, at y = -3 in data coordinates.

    Parameters
    ----------
    minmax: tuple of float
        Minimum and maximum values to report.
    _ax: Axis, optional
        Axis to annotate. Defaults to the current axis.
    """
    # Placement was tuned on a 1000 x 1000 reference figure.
    target = _ax if _ax is not None else plt.gca()
    right_edge = target.get_xlim()[1]
    label = f"Range: {minmax[0]:.2f} -> {minmax[1]:.2f}"
    target.text(right_edge, -3, s=label, color="black", ha="right", fontsize=4)
def add_legend(
    patches: tp.Sequence[Patch], ax: tp.Optional[Axis] = None, **kwargs
) -> None:
    """
    Add a legend to an existing axis.

    Patches that repeat an earlier (facecolor, label) pair are dropped; the
    first occurrence and the original order are kept. The legend is placed
    outside the upper-right corner by default. Any ``kwargs`` override these
    defaults and are forwarded to ``Axis.legend``.
    """
    target = plt.gca() if ax is None else ax

    seen = set()
    unique_patches = []
    for patch in patches:
        signature = (patch.get_facecolor(), patch.get_label())
        if signature in seen:
            continue
        seen.add(signature)
        unique_patches.append(patch)

    legend_kwargs = {"bbox_to_anchor": (1.05, 1), "loc": 2, "borderaxespad": 0.0, **kwargs}
    target.legend(handles=unique_patches, **legend_kwargs)
def saturize(arr: Array) -> Array:
    """
    Saturize an image channel-wise by min-max scaling each channel.

    The channel axis is taken to be the shortest axis of ``arr``, and it must
    be either the first or the third axis.

    The array is modified in place and also returned.

    NOTE(review): assigning the scaled values back into an integer array will
    presumably truncate them; pass a float array.

    Raises
    ------
    ValueError
        If the shortest axis is neither axis 0 nor axis 2.
    """
    channel_axis = np.argmin(arr.shape)
    if channel_axis not in (0, 2):
        raise ValueError("Do not understand order of array axis.")

    whole = slice(None)
    for channel in range(arr.shape[channel_axis]):
        if channel_axis == 0:
            plane = (channel, whole, whole)
        else:
            plane = (whole, whole, channel)
        arr[plane] = minmax_scale(arr[plane])
    return arr
def merge_channels(
    arr: Array,
    target_colors: tp.Optional[tp.Sequence[tp.Tuple[float, float, float]]] = None,
    return_colors: bool = False,
) -> tp.Union[Array, tp.Tuple[Array, tp.Sequence[tp.Tuple[float, float, float]]]]:
    """
    Blend a multi-channel image into a single RGB image.

    Each channel is multiplied by its RGB color and the weighted channels are
    summed.

    Parameters
    ----------
    arr: Array
        Image stack with channels along the first axis, e.g. (channels, Y, X).
        Values already within [0, 1] are used as-is. Otherwise every channel
        is min-max scaled independently with ``saturize``. The input array is
        never modified.
    target_colors: sequence of colors, optional
        One matplotlib color specification per channel. Defaults to
        ``DEFAULT_CHANNEL_COLORS``, cycled if there are more channels than
        colors.
    return_colors: bool
        Whether to also return the RGB tuples used for each channel.

    Returns
    -------
    Array of shape ``arr.shape[1:] + (3,)``, or a tuple of that array and the
    list of RGB tuples if ``return_colors`` is True. Values are not clipped:
    they are channel intensities scaled by 256 and summed over channels.

    Raises
    ------
    ValueError
        If the number of colors differs from the number of channels.
    """
    # Work on a float copy: the scaling steps below must not mutate the
    # caller's array, and integer inputs must not overflow or be truncated.
    arr = np.array(arr, dtype=float)
    n_channels = arr.shape[0]

    if target_colors is None:
        n_default = len(DEFAULT_CHANNEL_COLORS)
        target_colors = [DEFAULT_CHANNEL_COLORS[i % n_default] for i in range(n_channels)]
    # Normalize every color specification (names, hex, RGBA) to an RGB tuple.
    colors = [matplotlib.colors.to_rgb(col) for col in target_colors]
    if len(colors) != n_channels:
        raise ValueError(f"Got {len(colors)} colors for {n_channels} channels.")

    # Scale intensities up by 256 before blending.
    if arr.min() >= 0 and arr.max() <= 1:
        arr = arr * 256
    else:
        arr = saturize(arr) * 256

    res = np.zeros(arr.shape[1:] + (3,))
    for channel, color in zip(arr, colors):
        res += channel[..., np.newaxis] * np.asarray(color)
    return (res, colors) if return_colors else res
def rainbow_text(x, y, strings, colors, orientation="horizontal", ax=None, **kwargs):
    """
    Draw several strings next to each other, each in its own color.

    ``strings[i]`` is drawn in ``colors[i]``, each followed by a space.

    Parameters
    ----------
    x, y : float
        Text position in data coordinates.
    strings : list of str
        The strings to draw.
    colors : list of color
        The colors to use.
    orientation : {'horizontal', 'vertical'}
    ax : Axes, optional
        The Axes to draw into. If None, the current axes will be used.
    **kwargs
        All other keyword arguments are passed to plt.text(), so you can
        set the font size, family, etc.

    Adapted from:
    https://matplotlib.org/3.2.1/gallery/text_labels_and_annotations/rainbow_text.html
    """
    from matplotlib.transforms import Affine2D

    target_ax = plt.gca() if ax is None else ax
    transform = target_ax.transData
    canvas = target_ax.figure.canvas

    assert orientation in ["horizontal", "vertical"]
    is_vertical = orientation == "vertical"
    if is_vertical:
        kwargs.update(rotation=90, verticalalignment="bottom")

    for piece, color in zip(strings, colors):
        text = target_ax.text(x, y, piece + " ", color=color, transform=transform, **kwargs)
        # The text must be drawn once so that its on-screen extent is known.
        text.draw(canvas.get_renderer())
        extent = text.get_window_extent()
        # NOTE(review): the upstream example shifts by the full extent; the
        # 0.5 factor here may make pieces overlap — confirm it is intended.
        if is_vertical:
            shift = (0, extent.height * 0.5)
        else:
            shift = (extent.width * 0.5, 0)
        transform = text.get_transform() + Affine2D().translate(*shift)
[docs]def get_n_colors(n: int, max_value: float = 1.0) -> Array: """ With modifications from https://stackoverflow.com/a/13781114/1469535 """ import itertools from fractions import Fraction def zenos_dichotomy(): """ http://en.wikipedia.org/wiki/1/2_%2B_1/4_%2B_1/8_%2B_1/16_%2B_%C2%B7_%C2%B7_%C2%B7 """ for k in itertools.count(): yield Fraction(1, 2 ** k) def fracs(): """ [Fraction(0, 1), Fraction(1, 2), Fraction(1, 4), Fraction(3, 4), Fraction(1, 8), Fraction(3, 8), Fraction(5, 8), Fraction(7, 8), Fraction(1, 16), Fraction(3, 16), ...] [0.0, 0.5, 0.25, 0.75, 0.125, 0.375, 0.625, 0.875, 0.0625, 0.1875, ...] """ yield Fraction(0) for k in zenos_dichotomy(): i = k.denominator # [1,2,4,8,16,...] for j in range(1, i, 2): yield Fraction(j, i) # can be used for the v in hsv to map linear values 0..1 to something that looks equidistant # bias = lambda x: (math.sqrt(x/3)/Fraction(2,3)+Fraction(1,3))/Fraction(6,5) def hue_to_tones(h): for s in [Fraction(6, 10)]: # optionally use range for v in [Fraction(8, 10), Fraction(5, 10)]: # could use range too yield (h, s, v) # use bias for v here if you use range def hsv_to_rgb(x): return colorsys.hsv_to_rgb(*map(float, x)) flatten = itertools.chain.from_iterable def hsvs(): return flatten(map(hue_to_tones, fracs())) def rgbs(): return map(hsv_to_rgb, hsvs()) return np.asarray(list(itertools.islice(rgbs(), n))) * max_value
def get_rgb_cmaps() -> tp.Tuple[ColorMap, ColorMap, ColorMap]:
    """Return three 100-step colormaps ramping from black to pure red, green and blue."""
    ramp = np.linspace(0, 1, 100)[:, np.newaxis]
    return tuple(  # type: ignore
        matplotlib.colors.LinearSegmentedColormap.from_list("", primary * ramp)
        for primary in np.eye(3)
    )


def _warn_if_palette_too_small(n: int, from_palette: tp.Optional[str]) -> None:
    """Print the 'Will reuse!' notice when ``from_palette`` has fewer than ``n`` colors."""
    if len(sns.color_palette(from_palette)) < n:
        print(
            "Chosen palette has less than the requested number of colors. "
            "Will reuse!"
        )


def get_dark_cmaps(n: int = 3, from_palette: str = "colorblind") -> tp.List[ColorMap]:
    """Return ``n`` 100-step colormaps ramping from black to each color of a seaborn palette."""
    ramp = np.linspace(0, 1, 100)[:, np.newaxis]
    _warn_if_palette_too_small(n, from_palette)
    return [
        matplotlib.colors.LinearSegmentedColormap.from_list("", np.array(color) * ramp)
        for color in sns.color_palette(from_palette, n)
    ]


def get_transparent_cmaps(
    n: int = 3, from_palette: tp.Optional[str] = "colorblind"
) -> tp.List[ColorMap]:
    """
    Return ``n`` colormaps with a fixed palette color and alpha ramping from 0 to 1.

    Each colormap keeps the RGB of one seaborn palette color and goes from
    fully transparent to fully opaque in 100 steps.
    """
    alphas = np.linspace(0, 1, 100)
    _warn_if_palette_too_small(n, from_palette)
    return [
        matplotlib.colors.LinearSegmentedColormap.from_list(
            "", [color + (alpha,) for alpha in alphas]
        )
        for color in sns.color_palette(from_palette, n)
    ]


def get_random_label_cmap(n=2 ** 16, h=(0, 1), l=(0.4, 1), s=(0.2, 0.8)):
    """
    Return a ListedColormap of ``n`` random HLS colors whose first entry is black.

    Hue, lightness and saturation are drawn uniformly from the ``h``, ``l``
    and ``s`` ranges, in that order, using NumPy's global random state.
    Entry 0 is set to black, presumably so that background label 0 renders
    black.
    """
    hues = np.random.uniform(*h, n)
    lightness = np.random.uniform(*l, n)
    saturation = np.random.uniform(*s, n)
    rgb = np.stack(
        [colorsys.hls_to_rgb(*hls) for hls in zip(hues, lightness, saturation)], axis=0
    )
    rgb[0] = 0
    return matplotlib.colors.ListedColormap(rgb)


random_label_cmap = get_random_label_cmap


# TODO: see if function can be sped up e.g. with Numba
def cell_labels_to_mask(mask: Array, labels: tp.Union[Series, tp.Dict]) -> Array:
    """
    Replace integers in ``mask`` with values from the mapping in ``labels``.

    Elements equal to a key of ``labels`` take the corresponding value.
    Elements matching no key become 0. The result is an integer array with
    the shape of ``mask``, so non-integer values are cast to int. If a key
    occurs more than once (possible for a Series), the last value wins.
    """
    relabeled = np.zeros(mask.shape, dtype=int)
    for cell_id, value in labels.items():
        selection = mask == cell_id
        relabeled[selection] = value
    return relabeled
def values_to_rgb_colors(
    mask: Array, from_palette: str = None, remove_zero: bool = True
) -> tp.Tuple[Array, tp.Dict[tp.Any, tp.Tuple[float, float, float]]]:
    """
    Color each distinct value of ``mask`` with its own RGB color.

    Parameters
    ----------
    mask: Array
        Array of identities, typically a 2D integer label image.
    from_palette: str, optional
        Name of a seaborn palette. If None, ``get_n_colors`` provides one
        distinct color per identity.
    remove_zero: bool
        If True, identity 0 is left uncolored (black).

    Returns
    -------
    A float array of shape ``mask.shape + (3,)``, and a dict mapping each
    colored identity to its color.
    """
    ident = np.unique(mask)  # already sorted
    if remove_zero:
        ident = ident[ident != 0]
    n_colors = len(ident)

    if from_palette is None:
        palette = list(get_n_colors(n_colors))
    else:
        if n_colors > len(sns.color_palette(from_palette)):
            print(
                "Chosen palette has less than the requested number of colors. "
                "Will reuse!"
            )
        # Request exactly one color per identity so palette and ``ident`` align.
        palette = list(sns.color_palette(from_palette, n_colors))
    colors = pd.Series(palette, index=ident)

    res = np.zeros(mask.shape + (3,))
    if n_colors:
        color_table = np.asarray(palette, dtype=float)
        # Position of each mask value within the sorted identities. Values not
        # in ``ident`` (e.g. a removed 0) fail the equality check and stay black.
        pos = np.minimum(np.searchsorted(ident, mask), n_colors - 1)
        hit = ident[pos] == mask
        res[hit] = color_table[pos[hit]]
    return res, colors.to_dict()
@tp.overload
def get_grid_dims(
    dims: tp.Union[int, tp.Collection],
    return_fig: tp.Literal[True],
    nstart: tp.Optional[int] = None,
    **kwargs,
) -> Figure:
    ...


@tp.overload
def get_grid_dims(
    dims: tp.Union[int, tp.Collection],
    return_fig: tp.Literal[False] = False,
    nstart: tp.Optional[int] = None,
    **kwargs,
) -> tp.Tuple[int, int]:
    ...


def get_grid_dims(
    dims: tp.Union[int, tp.Collection],
    return_fig: bool = False,
    nstart: tp.Optional[int] = None,
    **kwargs,
) -> tp.Union[tp.Tuple[int, int], Figure]:
    """
    Choose a grid of ``n`` rows by ``m`` columns able to hold ``dims`` subplots.

    The search starts with ``n = min(dims, 1 + ceil(sqrt(dims)))`` rows (or
    ``nstart``) and ``m = ceil(dims / n)`` columns. While ``(n * m) % dims``
    exceeds 1, it retries with one row fewer, as long as that grid is not
    wider than tall. The result is therefore close to square and otherwise
    has more rows than columns.

    Parameters
    ----------
    dims: int | Collection
        Number of subplots, or a collection whose length is that number.
    return_fig: bool
        If True, create and return a Figure with that grid instead of its shape.
    nstart: int, optional
        Number of rows to start the search from.
    **kwargs
        Passed to ``matplotlib.pyplot.subplots`` when ``return_fig`` is True.
        ``figsize`` defaults to 4 inches per row and per column.

    Returns
    -------
    ``(n_rows, n_columns)``, or a Figure if ``return_fig`` is True.

    Raises
    ------
    ValueError
        If ``dims`` is smaller than 1.
    IndexError
        If ``nstart`` rows cannot hold ``dims`` subplots with at most
        ``nstart`` columns.
    """
    if not isinstance(dims, int):
        dims = len(dims)
    if dims < 1:
        raise ValueError("`dims` must be at least 1.")

    if nstart is None:
        n = min(dims, 1 + int(np.ceil(np.sqrt(dims))))
    else:
        n = nstart

    if n < 1:
        raise IndexError(f"Cannot lay out {dims} subplots in {n} rows.")
    if n * n == dims:
        m = n
    else:
        # Fewest columns such that n rows hold every subplot. A grid wider than
        # tall is rejected with IndexError; the recursion below relies on this
        # as its stop signal.
        m = -(-dims // n)
        if m > n:
            raise IndexError(f"{n} rows cannot hold {dims} subplots with at most {n} columns.")

    if (n * m) % dims > 1:
        try:
            n, m = get_grid_dims(dims=dims, return_fig=False, nstart=n - 1)
        except IndexError:
            pass

    if not return_fig:
        return n, m
    kwargs.setdefault("figsize", (m * 4, n * 4))
    fig, _ = plt.subplots(n, m, **kwargs)
    return fig
def _share_xy(master, other) -> None:
    """Share both the x and the y axis of ``other`` with ``master``."""
    if other is master:
        return
    for axis_name in ("x", "y"):
        grouper = getattr(master, f"get_shared_{axis_name}_axes")()
        if hasattr(grouper, "join"):
            # Older matplotlib: join through the shared-axes grouper. This was
            # deprecated in 3.6 and removed in 3.8.
            grouper.join(master, other)
        else:
            # Newer matplotlib: Axes.sharex / Axes.sharey. These raise
            # ValueError if ``other`` already shares that axis with another Axes.
            getattr(other, f"share{axis_name}")(master)


def share_axes_by(axes: Axis, by: str) -> None:
    """
    Share given axes after figure creation.

    Useful when not all subplots of a figure should be shared. Both the x and
    the y axis are shared in every mode.

    Parameters
    ----------
    axes: 2D array of Axes
        Typically the array returned by ``plt.subplots(n, m)``.
    by: str
        - "row": each axis shares with the first axis of its row, and y tick
          labels are removed from all but the first column.
        - "col": each axis shares with the bottom axis of its column, and x tick
          labels are removed from all but the bottom row.
        - "both": all axes share with the upper-left axis, and tick labels are
          kept only along the left column and the bottom row.
        Any other value does nothing.
    """
    axes = np.asarray(axes)
    if by == "row":
        for row in axes:
            for other in row[1:]:
                _share_xy(row[0], other)
                other.set_yticklabels([])
    elif by == "col":
        # Iterate over columns: iterating ``axes`` directly would yield rows.
        for col in axes.T:
            for other in col[:-1]:
                _share_xy(col[-1], other)
                other.set_xticklabels([])
    elif by == "both":
        # Attach all axes to the upper-left one.
        master = axes[0, 0]
        for other in axes.flatten():
            _share_xy(master, other)
        # Inner axes (not in the first column nor the last row) lose both labels.
        for other in axes[:-1, 1:].flatten():
            other.set_xticklabels([])
            other.set_yticklabels([])
        # First column, except the last row: remove x tick labels.
        for other in axes[:-1, 0]:
            other.set_xticklabels([])
        # Last row, except the first column: remove y tick labels.
        for other in axes[-1, 1:]:
            other.set_yticklabels([])
def plot_single_channel(
    arr: Array, axis: tp.Optional[Axis] = None, cmap: tp.Optional[ColorMap] = None
) -> tp.Union[Figure, Axis]:
    """
    Plot a single image channel either in a new figure or in an existing axis.

    The image is drawn with bilinear interpolation and rasterized, and the
    axis frame is turned off.

    Parameters
    ----------
    arr: Array
        Image to draw with ``imshow``.
    axis: Axis, optional
        Axis to draw into. If None, a new 6 x 6 inch figure is created.
    cmap: ColorMap, optional
        Colormap forwarded to ``imshow``.

    Returns
    -------
    The new Figure if ``axis`` is None, otherwise ``axis`` itself.
    """
    if axis is None:
        fig, axs = plt.subplots(1, 1, figsize=(6, 6), sharex=True, sharey=True)
    else:
        axs = axis
    axs.imshow(arr, cmap=cmap, interpolation="bilinear", rasterized=True)
    axs.axis("off")
    return fig if axis is None else axs
def plot_overlayied_channels(
    arr: Array,
    channel_labels: tp.Sequence[str],
    axis: tp.Optional[Axis] = None,
    palette: tp.Optional[str] = None,
) -> tp.Union[Figure, Axis]:
    """
    Overlay several image channels on one axis, each with its own color.

    Each channel ``arr[i]`` is drawn with a transparent-to-opaque colormap from
    ``get_transparent_cmaps``. A legend with one patch per channel is placed
    outside the upper-right corner.

    Parameters
    ----------
    arr: Array
        Image stack with channels along the first axis.
    channel_labels: sequence of str
        One label per channel. Only as many channels as labels are drawn.
    axis: Axis, optional
        Axis to draw into. If None, a new 6 x 6 inch figure is created.
    palette: str, optional
        Seaborn palette name passed to ``get_transparent_cmaps``.

    Returns
    -------
    The new Figure if ``axis`` is None, otherwise ``axis`` itself.
    """
    if axis is None:
        fig, ax = plt.subplots(1, 1, figsize=(6, 6), sharex=True, sharey=True)
    else:
        ax = axis

    cmaps = get_transparent_cmaps(arr.shape[0], from_palette=palette)
    patches = []
    for i, (label, cmap) in enumerate(zip(channel_labels, cmaps)):
        ax.imshow(
            arr[i].squeeze(),
            cmap=cmap,
            label=label,
            interpolation="bilinear",
            rasterized=True,
            alpha=0.9,
        )
        # cmap(1.0) is the colormap's top color, i.e. the fully opaque channel color.
        patches.append(mpatches.Patch(color=cmap(1.0), label=label))
    ax.axis("off")
    ax.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.0)
    return fig if axis is None else ax
def rasterize_scanpy(fig: Figure) -> None:
    """
    Rasterize the scatter/line collections of a figure, e.g. Scanpy PCA/UMAP plots.

    For every axis of ``fig``, its direct children and their direct children
    are inspected. ``PathCollection`` and ``LineCollection`` artists are
    rasterized, and text and axis artists are skipped. A direct child is only
    rasterized when it has no children of its own. Warnings raised meanwhile
    are silenced.
    """
    import warnings

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        rasterizable = (
            matplotlib.collections.PathCollection,
            matplotlib.collections.LineCollection,
        )
        skipped = (
            matplotlib.text.Text,
            matplotlib.axis.XAxis,
            matplotlib.axis.YAxis,
        )
        for axs in fig.axes:
            for child in axs.get_children():
                if isinstance(child, skipped):
                    continue
                grandchildren = child.get_children()
                if not grandchildren and isinstance(child, rasterizable):
                    child.set_rasterized(True)
                for grandchild in grandchildren:
                    if isinstance(grandchild, rasterizable) and not isinstance(grandchild, skipped):
                        grandchild.set_rasterized(True)
def add_centroids(
    a: AnnData,
    ax: tp.Union[tp.Sequence[Axis], Axis] = None,
    res: float = None,
    column: str = None,
    algo: str = "umap",
):
    """
    Write each cluster's name at the mean position of its cells in an embedding.

    Parameters
    ----------
    a: AnnData
        Cluster labels are read from ``a.obs`` and coordinates from
        ``a.obsm[f"X_{algo}"]``.
    ax: matplotlib.Axes.axes or a sequence of them
        Axis (or axes) to annotate. Defaults to the current axis.
    res: float, optional
        Resolution of clusters to label; uses column ``f"cluster_{res}"``.
    column: str, optional
        Column to be used. Has precedence over ``res``. If both are None, the
        first column whose name contains "cluster_", or else "leiden", is used.
    algo: str
        Embedding name. For "diffmap" the first component is skipped and
        components 2-3 are used; otherwise components 1-2.
    """
    if ax is None:
        ax = plt.gca()
    if isinstance(ax, (list, tuple, np.ndarray)):
        axes = list(np.asarray(ax, dtype=object).ravel())
    else:
        axes = [ax]

    if column is not None:
        lab = column
    elif res is not None:
        lab = f"cluster_{res}"
    else:
        # Guess the clustering column.
        cols = a.obs.columns
        try:
            lab = cols[cols.str.contains("cluster_")][0]
        except IndexError:
            lab = cols[cols.str.contains("leiden")][0]

    offset = 1 if algo == "diffmap" else 0
    coords = pd.DataFrame(np.asarray(a.obsm[f"X_{algo}"])[:, offset : offset + 2])
    # Group by the cluster value itself (not by position): unused categories
    # and missing labels are dropped, so each centroid keeps its own label.
    groups = pd.Categorical(a.obs[lab])
    centroids = coords.groupby(groups, observed=True).mean()
    for clust, x, y in centroids.itertuples(name=None):
        for axis in axes:
            axis.text(x, y, s=clust)
def legend_without_duplicate_labels(ax: Axis, **kwargs) -> None:
    """
    Draw a legend on ``ax`` that keeps only the first handle for each label.

    Labels keep the order of their first appearance. ``kwargs`` are forwarded
    to ``ax.legend``. If the axis has no labelled artists, ``ax.legend`` is
    called with ``kwargs`` only.
    """
    handles, labels = ax.get_legend_handles_labels()
    first_handle_by_label = {}
    for handle, label in zip(handles, labels):
        first_handle_by_label.setdefault(label, handle)

    if first_handle_by_label:
        ax.legend(
            tuple(first_handle_by_label.values()),
            tuple(first_handle_by_label.keys()),
            **kwargs,
        )
    else:
        ax.legend(**kwargs)


# Imported at the bottom of the module, presumably to avoid a circular import
# with ``imc.data_models.roi``. ``_roi.ROI`` appears in the InteractiveViewer
# type annotations.
import imc.data_models.roi as _roi