import numpy as np
from . import evaluate
from skimage import color
from matplotlib import cm, pyplot as plt
import itertools as it
from math import ceil
###########################
# VISUALIZATION FUNCTIONS #
###########################
def imshow_grey(im, axis=None):
    """Display an image with the 'gray' colormap.

    Parameters
    ----------
    im : np.ndarray of int, shape (M, N)
        The segmentation to be displayed.
    axis : matplotlib axes, optional
        The axes to draw into. When omitted, a new figure holding a single
        axes is created with ``plt.subplots()``.

    Returns
    -------
    shown
        Whatever ``axis.imshow`` returns for the displayed image.
    """
    target = plt.subplots()[1] if axis is None else axis
    return target.imshow(im, cmap='gray')
def imshow_magma(im, axis=None):
    """Display an image with the 'magma' colormap.

    Parameters
    ----------
    im : np.ndarray of int, shape (M, N)
        The segmentation to be displayed.
    axis : matplotlib axes, optional
        The axes to draw into. When omitted, a new figure holding a single
        axes is created with ``plt.subplots()``.

    Returns
    -------
    shown
        Whatever ``axis.imshow`` returns for the displayed image.
    """
    target = plt.subplots()[1] if axis is None else axis
    return target.imshow(im, cmap='magma')
def imshow_rand(im, axis=None, labrandom=True):
    """Display a segmentation with a randomly generated colormap.

    Parameters
    ----------
    im : np.ndarray of int, shape (M, N)
        The segmentation to be displayed.
    axis : matplotlib axes, optional
        The axes to draw into. When omitted, a new figure holding a single
        axes is created with ``plt.subplots()``.
    labrandom : bool, optional
        Draw the random colors in the Lab colorspace (then convert them to
        RGB) instead of drawing them directly in RGB.

    Returns
    -------
    shown
        Whatever ``axis.imshow`` returns for the displayed image.
    """
    if axis is None:
        _, axis = plt.subplots()
    n_colors = ceil(np.max(im))
    palette = np.random.random(size=(n_colors, 3))
    if labrandom:
        # Stretch the uniform [0, 1) samples, channel by channel, to
        # [39, 120), [-86, 99) and [-108, 90) before converting to RGB.
        scale = np.array([81, 185, 198])
        offset = np.array([39, -86, -108])
        palette = palette * scale + offset
        palette = color.lab2rgb(palette[np.newaxis, ...])[0]
        # Force the converted values back into the valid [0, 1] RGB range.
        palette = np.clip(palette, 0, 1)
    # Black is prepended as the first entry of the colormap.
    black = np.zeros((1, 3))
    rcmap = cm.colors.ListedColormap(np.concatenate((black, palette)))
    return axis.imshow(im, cmap=rcmap)
def show_multiple_images(*images, axes=None, image_type='raw'):
    """Show several images side by side, one per subplot.

    Parameters
    ----------
    images : np.ndarray of int, shape (M, N)
        The input images to be displayed.
    axes : sequence of matplotlib axes, optional
        Existing axes to draw into, one per image (``axes[i]`` receives
        ``images[i]``). When omitted, a new figure with a single row of
        subplots is created.
    image_type : string, optional
        Selects the colormap used for display: 'raw' (default, plain
        ``imshow``), 'grey'/'gray', 'magma' or 'rand'.

    Returns
    -------
    ax : matplotlib axes or None
        The axes holding the last image. ``None`` is returned when
        `image_type` is not recognised (a message is printed) or when no
        image was given.
    """
    valid_types = ('grey', 'gray', 'magma', 'rand', 'raw')
    # Validate before touching matplotlib so that an invalid request does
    # not leave an empty figure behind.
    if image_type not in valid_types:
        print("not a valid image type.")
        return None
    number_of_im = len(images)
    if number_of_im == 0:
        return None
    # Only open a new figure when the caller did not supply axes.
    figure = plt.figure() if axes is None else None
    ax = None
    for i, image in enumerate(images):
        if axes is None:
            ax = figure.add_subplot(1, number_of_im, i + 1)
        else:
            ax = axes[i]
        if image_type in ('grey', 'gray'):
            imshow_grey(image, axis=ax)
        elif image_type == 'magma':
            imshow_magma(image, axis=ax)
        elif image_type == 'rand':
            imshow_rand(image, axis=ax)
        else:  # 'raw'
            ax.imshow(image)
        ax.set_title(f'Image number {i+1} with a {image_type} colormap.')
    return ax
def draw_seg(seg, im):
    """Return a segmentation map matching the original image color.

    Every nonzero segment is filled with the mean of the image values it
    covers. Pixels labeled 0 are left at zero in the output.

    Parameters
    ----------
    seg : np.ndarray of int, shape (M, N, ...)
        The segmentation to be displayed
    im : np.ndarray, shape (M, N, ..., C)
        The image corresponding to the segmentation.

    Returns
    -------
    out : np.ndarray, same shape and type as `im`.
        An image where each segment has uniform color. Because the output
        has the dtype of `im`, segment means are cast to that dtype.

    Examples
    --------
    >>> a = np.array([[1, 1, 2, 2],
    ...               [1, 2, 2, 3],
    ...               [2, 2, 3, 3]])
    >>> g = np.array([[0.5, 0.2, 1.0, 0.9],
    ...               [0.2, 0.8, 0.9, 0.6],
    ...               [0.9, 0.9, 0.4, 0.5]])
    >>> np.round(draw_seg(a, g), 2)
    array([[0.3, 0.3, 0.9, 0.9],
           [0.3, 0.9, 0.9, 0.5],
           [0.9, 0.9, 0.5, 0.5]])
    """
    out = np.zeros_like(im)
    labels = np.unique(seg)
    # Drop label 0 explicitly. Slicing off the first unique value would
    # remove the wrong label whenever negative labels are present.
    labels = labels[labels != 0]
    for label in labels:
        mask = seg == label
        # axis=0 averages over the segment's pixels, keeping any channels.
        mean_color = im[mask].mean(axis=0)
        out[mask] = mean_color
    return out
def display_3d_segmentations(segs, image=None, probability_map=None, axis=0,
                             z=None, fignum=None):
    """Show slices of multiple 3D segmentations.

    The panels are laid out in this order: the image (gray colormap), the
    probability map (magma colormap), then each segmentation (random
    colormap).

    Parameters
    ----------
    segs : list or tuple of np.ndarray of int, shape (M, N, P)
        The segmentations to be examined.
    image : np.ndarray, shape (M, N, P[, 3]), optional
        The image corresponding to the segmentations.
    probability_map : np.ndarray, shape (M, N, P), optional
        The segment boundary probability map.
    axis : int in {0, 1, 2}, optional
        The axis along which to show a slice of the segmentation.
    z : int in [0, `(M, N, P)[axis]`), optional
        The slice to display. Defaults to the middle slice.
    fignum : int, optional
        Which figure number to use. Uses the default (new figure) if none is
        provided.

    Returns
    -------
    fig : plt.Figure
        The figure handle.
    """
    # Collect (display function, volume) pairs in display order.
    panels = []
    if image is not None:
        panels.append((imshow_grey, image))
    if probability_map is not None:
        panels.append((imshow_magma, probability_map))
    panels.extend((imshow_rand, seg) for seg in segs)
    numplots = len(panels)

    # Take the first candidate (rows, cols) grid, with up to 4 rows and
    # columns, that has room for every panel.
    candidate_plot_arrangements = it.combinations_with_replacement(
        range(1, 5), 2)
    fitting = [(i, j) for i, j in candidate_plot_arrangements
               if i * j >= numplots]
    if fitting:
        n_rows, n_cols = fitting[0]
    else:
        # More than 16 panels: fall back to a near-square grid instead of
        # failing with an IndexError.
        n_rows = ceil(numplots ** 0.5)
        n_cols = ceil(numplots / n_rows)

    fig = plt.figure(fignum)
    for position, (show, volume) in enumerate(panels, start=1):
        # Bring the slicing axis to the front so slices[k] is the k-th slice.
        slices = np.moveaxis(volume, axis, 0)
        # Indexing with None would insert a new axis, so resolve the
        # documented "middle slice" default explicitly.
        index = slices.shape[0] // 2 if z is None else z
        ax = plt.subplot(n_rows, n_cols, position)
        # Pass the subplot explicitly; without it the helper would open a
        # brand new figure for every panel.
        show(slices[index], axis=ax)
    return fig
def plot_vi(g, history, gt, fig=None):
    """Plot the VI from segmentations based on Rag and sequence of merges.

    The merges in `history` are replayed one at a time on the segmentation
    obtained from `g`; after each merge the VI against `gt` and the number
    of distinct labels are recorded and finally plotted against each other.

    Parameters
    ----------
    g : agglo.Rag object
        The region adjacency graph. Only ``g.get_segmentation()`` is used.
    history : list of tuples
        The merge history of the RAG. For each entry ``m``, label ``m[1]``
        is relabeled as ``m[0]``.
    gt : np.ndarray
        The ground truth corresponding to the RAG.
    fig : plt.Figure, optional
        Use this figure for plotting. If not provided, a new figure is created.

    Returns
    -------
    None
    """
    seg = g.get_segmentation()
    vi_values = []
    segment_counts = []
    for merge in history:
        seg[seg == merge[1]] = merge[0]
        vi_values.append(evaluate.vi(seg, gt))
        # NOTE(review): the original read `len(np.unique(seg)-1)`, where the
        # `-1` is applied element-wise and so does not change the count. It
        # may have been meant as `len(np.unique(seg)) - 1`; the plotted
        # values are kept unchanged here pending confirmation.
        segment_counts.append(len(np.unique(seg)))
    if fig is None:
        fig = plt.figure()
    # Draw on the requested figure's axes. Passing `figure=fig` to the
    # pyplot functions does not redirect drawing away from the current axes.
    ax = fig.gca()
    ax.plot(segment_counts, vi_values)
    ax.set_xlabel('Number of segments')
    ax.set_ylabel('vi')
def plot_vi_breakdown_panel(px, h, title, xlab, ylab, hlines,
                            scatter_size=None, **kwargs):
    """Plot a single panel (over or undersegmentation) of VI breakdown plot.

    Parameters
    ----------
    px : np.ndarray of float, shape (N,)
        The probability (size) of each segment.
    h : np.ndarray of float, shape (N,)
        The conditional entropy of that segment.
    title, xlab, ylab : string
        The panel title (also used as the scatter legend label) and the
        axis labels.
    hlines : iterable of float
        Plot hyperbolic lines of same VI contribution. For each value `v` in
        `hlines`, draw the line `h = v/px`.
    scatter_size : int, optional
        Marker size passed to `matplotlib.pyplot.scatter` as ``s``. ``None``
        (the default) leaves the size to matplotlib.
    **kwargs : dict
        Additional keyword arguments, forwarded both to the isoline
        `matplotlib.pyplot.plot` calls (minus any color setting) and to
        `matplotlib.pyplot.scatter`.

    Returns
    -------
    None
    """
    px = np.asarray(px)
    h = np.asarray(h)
    px_max = px.max()
    h_max = h.max()
    # Isolines are always gray; a caller-supplied color is meant for the
    # scatter points and would clash with `color='gray'` if forwarded.
    line_kwargs = {key: value for key, value in kwargs.items()
                   if key not in ('c', 'color')}
    # Start strictly above zero so that v / x stays finite. linspace also
    # copes with all sizes being equal, where an arange step would be zero.
    x = np.linspace(max(px.min(), 1e-10), px_max, 100)
    for val in hlines:
        plt.plot(x, val / x, color='gray', ls=':', **line_kwargs)
    plt.scatter(px, h, label=title, s=scatter_size, **kwargs)
    # TODO: make points clickable to identify ID. This section needs work.
    # Leave a 5% margin on either side of the data.
    plt.xlim(-0.05 * px_max, 1.05 * px_max)
    plt.ylim(-0.05 * h_max, 1.05 * h_max)
    plt.xlabel(xlab)
    plt.ylabel(ylab)
    plt.title(title)
def plot_vi_breakdown(seg, gt, ignore_seg=None, ignore_gt=None,
                      hlines=None, subplot=False, figsize=None, **kwargs):
    """Plot conditional entropy H(Y|X) vs P(X) for both seg|gt and gt|seg.

    Parameters
    ----------
    seg : np.ndarray of int, shape (M, [N, ..., P])
        The automatic (candidate) segmentation.
    gt : np.ndarray of int, shape (M, [N, ..., P]) (same as `seg`)
        The gold standard/ground truth segmentation.
    ignore_seg : list of int, optional
        Ignore segments in this list from the automatic segmentation during
        evaluation and plotting. Defaults to an empty list.
    ignore_gt : list of int, optional
        Ignore segments in this list from the ground truth segmentation during
        evaluation and plotting. Defaults to an empty list.
    hlines : int, bool or iterable of float, optional
        If an int, plot this many isoclines between the minimum and maximum
        VI contributions. ``True`` means 10; ``None``, ``False`` or a
        non-positive int means none. Any other value is handed to the panel
        plots unchanged as the isocline levels.
    subplot : bool, optional
        If True, plot oversegmentation and undersegmentation in separate
        subplots.
    figsize : tuple of float, optional
        The figure width and height, in inches.
    **kwargs : dict
        Additional keyword arguments forwarded to `plot_vi_breakdown_panel`.
        A ``scatter_size`` entry, if present, sets the marker size.

    Returns
    -------
    None
    """
    plt.ion()
    ignore_seg = [] if ignore_seg is None else ignore_seg
    ignore_gt = [] if ignore_gt is None else ignore_gt
    _, px, py, _, _, lpygx, lpxgy = evaluate.vi_tables(seg, gt,
                                                       ignore_seg, ignore_gt)
    cu = -px * lpygx
    co = -py * lpxgy
    # Identity checks: `hlines == True` would also match the integer 1 and
    # is ambiguous for array input.
    if hlines is None or hlines is False:
        hlines = []
    elif hlines is True:
        hlines = 10
    if isinstance(hlines, (int, np.integer)):
        if hlines > 0:
            maxc = max(cu[cu != 0].max(), co[co != 0].max())
            hlines = np.arange(maxc / hlines, maxc, maxc / hlines)
        else:
            hlines = []
    # The panel function needs a marker size; take it from kwargs when
    # given, otherwise let matplotlib pick its default.
    scatter_size = kwargs.pop('scatter_size', None)
    plt.figure(figsize=figsize)
    if subplot:
        plt.subplot(1, 2, 1)
    plot_vi_breakdown_panel(px, -lpygx,
                            'False merges', 'p(S=seg)', 'H(G|S=seg)',
                            hlines, scatter_size, c='blue', **kwargs)
    if subplot:
        plt.subplot(1, 2, 2)
    plot_vi_breakdown_panel(py, -lpxgy,
                            'False splits', 'p(G=gt)', 'H(S|G=gt)',
                            hlines, scatter_size, c='orange', **kwargs)
    if not subplot:
        # Both point clouds share one axes: relabel it and widen the
        # limits so that they cover both data sets.
        plt.title('vi contributions by body.')
        plt.legend(loc='upper right', scatterpoints=1)
        plt.xlabel('Segment size (fraction of volume)', fontsize='large')
        plt.ylabel('Conditional entropy (bits)', fontsize='large')
        xmax = max(px.max(), py.max())
        plt.xlim(-0.05 * xmax, 1.05 * xmax)
        ymax = max(-lpygx.min(), -lpxgy.min())
        plt.ylim(-0.05 * ymax, 1.05 * ymax)
def add_opts_to_plot(ars, colors='k', markers='^', **kwargs):
    """In an existing active split-vi plot, add the point of optimal VI.

    For each curve, the optimum is the column whose two entries have the
    smallest sum. The default marker is ``'^'``.

    Parameters
    ----------
    ars : list of numpy arrays
        Each array has shape (2, N) and represents a split-VI curve,
        with `ars[i][0]` holding the undersegmentation and `ars[i][1]`
        holding the oversegmentation for each `i`.
    colors : string, list of string, or list of float tuple, optional
        A color specification or list of color specifications. If there
        are fewer colors than split-VI arrays, the colors are cycled.
    markers : string, or list of string, optional
        Point marker specification (as defined in matplotlib) or list
        thereof. As with colors, if there are fewer markers than VI
        arrays, the markers are cycled.
    **kwargs : dict (string keys), optional
        Keyword arguments to be passed through to
        `matplotlib.pyplot.scatter`.

    Returns
    -------
    points : list of `matplotlib.collections.PathCollection`
        The points returned by each of the calls to `scatter`.
    """
    def _per_curve(spec):
        # Wrap a lone specification in a list; cycle it when too short.
        if type(spec) not in [list, tuple]:
            spec = [spec]
        return it.cycle(spec) if len(spec) < len(ars) else spec

    color_specs = _per_curve(colors)
    marker_specs = _per_curve(markers)
    points = []
    for curve, c, m in zip(ars, color_specs, marker_specs):
        best_column = curve.sum(axis=0).argmin()
        optimum = curve[:, best_column]
        points.append(plt.scatter(optimum[0], optimum[1], c=c, marker=m,
                                  **kwargs))
    return points
def add_nats_to_plot(ars, tss, stops=0.5, colors='k', markers='o', **kwargs):
    """In an existing active split-vi plot, add the natural stopping point.

    For each curve, the point plotted is the last column whose threshold is
    strictly below the stopping value. The default marker is ``'o'``.

    Parameters
    ----------
    ars : list of numpy arrays
        Each array has shape (2, N) and represents a split-VI curve,
        with `ars[i][0]` holding the undersegmentation and `ars[i][1]`
        holding the oversegmentation for each `i`.
    tss : list of numpy arrays
        Each array has shape (N,) and represents the algorithm
        threshold that gave rise to the VI measurements in `ars`.
    stops : float or list of float, optional
        The natural stopping point for the algorithm. For example, if
        an algorithm merges segments according to a merge probability,
        the natural stopping point is at $p=0.5$, when there are even
        odds of the merge being a true merge. If there are fewer stops
        than VI arrays, the stops are cycled.
    colors : string, list of string, or list of float tuple, optional
        A color specification or list of color specifications. If there
        are fewer colors than split-VI arrays, the colors are cycled.
    markers : string, or list of string, optional
        Point marker specification (as defined in matplotlib) or list
        thereof. As with colors, if there are fewer markers than VI
        arrays, the markers are cycled.
    **kwargs : dict (string keys), optional
        Keyword arguments to be passed through to
        `matplotlib.pyplot.scatter`.

    Returns
    -------
    points : list of `matplotlib.collections.PathCollection`
        The points returned by each of the calls to `scatter`.

    Raises
    ------
    IndexError
        If a threshold array has no entry below its stopping value.
    """
    def _per_curve(spec):
        # Wrap a lone specification in a list; cycle it when too short.
        if type(spec) not in [list, tuple]:
            spec = [spec]
        return it.cycle(spec) if len(spec) < len(ars) else spec

    color_specs = _per_curve(colors)
    marker_specs = _per_curve(markers)
    stop_values = _per_curve(stops)
    points = []
    for curve, thresholds, stop, c, m in zip(ars, tss, stop_values,
                                             color_specs, marker_specs):
        last_below = np.flatnonzero(thresholds < stop)[-1]
        natural = curve[:, last_below]
        points.append(plt.scatter(natural[0], natural[1], c=c, marker=m,
                                  **kwargs))
    return points
def plot_split_vi(ars, best=None, colors='k', linespecs='-', **kwargs):
    """Make a split-VI plot.

    The split-VI plot was introduced in Nunez-Iglesias et al, 2013.

    Parameters
    ----------
    ars : array or list of arrays of float, shape (2, N)
        The input VI arrays. `ars[i][0]` should contain the
        undersegmentation and `ars[i][1]` the oversegmentation.
    best : array-like of float, len=2, optional
        Agglomerative segmentations can't get to (0, 0) VI if the
        starting superpixels are not perfectly aligned with the gold
        standard segmentation. Therefore, there is a point of best
        achievable VI. `best` should contain the coordinates of this
        point.
    colors : matplotlib color specification or list thereof, optional
        The color of each line being plotted. If there are fewer colors
        than arrays, they are cycled.
    linespecs : matplotlib line type spec, or list thereof, optional
        The line type to plot with ('-', '--', '-.', etc).
    kwargs : dict, string keys, optional
        Additional keyword arguments to pass through to plt.plot (and to
        plt.scatter for the `best` point). The special key 'best-color'
        sets the color of the `best` point (default 'k') and is not
        forwarded to matplotlib.

    Returns
    -------
    lines : list
        One entry per curve holding the return value of ``plt.plot``,
        followed by the ``plt.scatter`` result when `best` is given.
    """
    # 'best-color' is our own option; remove it before the remaining
    # kwargs are handed to matplotlib, which does not know it.
    best_color = kwargs.pop('best-color', 'k')
    if type(ars) not in [list, tuple]:
        ars = [ars]
    if type(colors) not in [list, tuple]:
        colors = [colors]
    if len(colors) < len(ars):
        colors = it.cycle(colors)
    if type(linespecs) not in [list, tuple]:
        linespecs = [linespecs]
    if len(linespecs) < len(ars):
        linespecs = it.cycle(linespecs)
    lines = []
    for ar, line_color, linespec in zip(ars, colors, linespecs):
        lines.append(plt.plot(ar[0], ar[1], c=line_color, ls=linespec,
                              **kwargs))
    if best is not None:
        # The original marker was the tuple (5, 3, 0); its style code 3
        # used to mean "circle" and is no longer accepted by matplotlib,
        # so the equivalent 'o' marker is used instead.
        lines.append(plt.scatter(best[0], best[1], c=best_color, marker='o',
                                 **kwargs))
    return lines
def plot_decision_function(clf, data_range=None,
                           features=None, labels=None, feature_columns=(0, 1),
                           n_gridpoints=201):
    """Plot the decision function of a classifier in 2D.

    The classifier is evaluated on a regular square grid; the probability
    in column 1 of ``clf.predict_proba`` is shown as an image, with the
    training points (if given) scattered on top.

    Parameters
    ----------
    clf : scikit-learn classifier
        The classifier to be evaluated. Must provide ``predict_proba`` and
        accept two-column input.
    data_range : tuple of int, optional
        The range of values to be evaluated, used for both dimensions.
        Defaults to the range of `features` padded by 5% on each side, or
        to (0, 1) when `features` is not given.
    features : 2D array of float, optional
        The features of the training data.
    labels : 1D array of int, optional
        The labels of the training data.
    feature_columns : tuple of int, optional
        Which feature columns to plot, if there are more than two.
    n_gridpoints : int, optional
        The number of points to place on each dimension of the 2D grid.
    """
    if features is not None:
        features = np.asarray(features)[:, list(feature_columns)]
        minfeat, maxfeat = np.min(features), np.max(features)
        featrange = maxfeat - minfeat
    if data_range is None:
        if features is None:
            data_range = (0, 1)
        else:
            data_range = (minfeat - 0.05 * featrange,
                          maxfeat + 0.05 * featrange)
    data_range = np.array(data_range)
    grid = np.linspace(*data_range, num=n_gridpoints, endpoint=True)
    rr, cc = np.meshgrid(grid, grid, sparse=False)
    feature_space = np.hstack((np.reshape(rr, (-1, 1)),
                               np.reshape(cc, (-1, 1))))
    prediction = clf.predict_proba(feature_space)[:, 1]  # Pr(class(X)=1)
    prediction = np.reshape(prediction, (n_gridpoints, n_gridpoints))
    fig, ax = plt.subplots()
    ax.imshow(prediction, cmap='RdBu')
    ax.set_xticks([])
    ax.set_yticks([])
    if features is not None:
        # Rescale to [0, 1] over the data range. This must stay inside the
        # None check: the arithmetic fails when no features are given.
        scaled = (features - data_range[0]) / (data_range[1] - data_range[0])
        if labels is not None:
            labels = np.asarray(labels)
            max_label = np.max(labels)
            if max_label == 0:
                # Avoid 0 / 0 when every label is zero.
                label_colors = cm.viridis(np.zeros(labels.shape[0]))
            else:
                label_colors = cm.viridis(labels.astype(float) / max_label)
        else:
            label_colors = cm.viridis(np.zeros(scaled.shape[0]))
        # The grid includes both endpoints, so the top of the data range
        # sits at pixel index n_gridpoints - 1.
        ax.scatter(*(scaled.T * (n_gridpoints - 1)), c=label_colors)
    plt.show()
def plot_seeds(raw_image, seed_image, ax=None):
    """Overlay seed locations as red dots on a grayscale image.

    Parameters
    ----------
    raw_image : np.ndarray
        The image shown in the background with the 'gray' colormap.
    seed_image : np.ndarray
        Array whose nonzero entries mark the seed positions.
    ax : matplotlib axes, optional
        The axes to draw into. When omitted, a new figure holding a single
        axes is created.

    Returns
    -------
    None
    """
    if ax is None:
        fig, ax = plt.subplots(1, 1)
    ax.imshow(raw_image, cmap='gray', interpolation='nearest')
    # Freeze the view set by the image so that plotting the seeds does not
    # rescale it. (Assigning to `plt.autoscale` would only overwrite the
    # pyplot function instead of changing any axes setting.)
    ax.autoscale(False)
    # np.nonzero yields (rows, cols); plot expects x (cols) before y (rows).
    ax.plot(*np.nonzero(seed_image)[::-1], 'r.')