from __future__ import absolute_import
# built-ins
import argparse
import itertools as it
import json
import logging
import random
import sys
from copy import deepcopy
from itertools import combinations, repeat, product
from math import isnan
# libraries
import numpy as np
from numpy import (array, mean, zeros, zeros_like, uint8, where, unique,
    double, newaxis, nonzero, median, exp, log2, float, ones, arange, inf,
    flatnonzero, sign, unravel_index, bincount)
from scipy.stats import sem
from scipy.sparse import lil_matrix
from scipy.misc import comb as nchoosek
from scipy.ndimage.measurements import label
from networkx import Graph, biconnected_components
from networkx.algorithms.traversal.depth_first_search import dfs_preorder_nodes
from skimage.segmentation import relabel_sequential
from viridis import tree
from six.moves import map
from six.moves import range
from six.moves import zip
# local modules
from . import morpho
from . import iterprogress as ip
from . import optimized as opt
from .ncut import ncutW
from .mergequeue import MergeQueue
from .evaluate import contingency_table as ev_contingency_table, split_vi, xlogx
from . import features
from . import classify
from .classify import get_classifier, \
    unique_learning_data_elements, concatenate_data_elements
def contingency_table(a, b, **kwargs):
    """Return a zero-padded dense contingency table of segmentations a and b.

    Parameters
    ----------
    a, b : array of int
        Two segmentations of the same volume.
    **kwargs : keyword arguments
        Passed through to ``gala.evaluate.contingency_table`` (e.g.
        ``ignore_seg``, ``ignore_gt``, as used by ``Rag.set_ground_truth``).

    Returns
    -------
    ctout : array of float, shape (2 * nx + 1, ny)
        The dense contingency table, padded with extra rows so that
        nodes created by merges during agglomeration can be written
        in-place without reallocating.
    """
    ct = ev_contingency_table(a, b, **kwargs)
    nx, ny = ct.shape
    # rows [nx:] stay zero: reserved for nodes created by future merges
    ctout = np.zeros((2*nx + 1, ny), ct.dtype)
    ct.todense(out=ctout[:nx, :])
    return ctout
# Shared command-line options for agglomeration scripts.  `add_help=False`
# lets callers use this parser as a `parents=[...]` entry without a
# duplicate -h/--help option.
arguments = argparse.ArgumentParser(add_help=False)
arggroup = arguments.add_argument_group('Agglomeration options')
# NOTE(review): `type=float` converts command-line values only; the default
# remains the int list [128] — confirm downstream code tolerates ints.
arggroup.add_argument('-t', '--thresholds', nargs='+', default=[128],
type=float, metavar='FLOAT',
help='''The agglomeration thresholds. One output file will be written
for each threshold.'''
)
# Minimum segment size for the "ladder" pass (merge-small-bodies heuristic).
arggroup.add_argument('-l', '--ladder', type=int, metavar='SIZE',
help='Merge any bodies smaller than SIZE.'
)
arggroup.add_argument('-p', '--pre-ladder', action='store_true', default=True,
help='Run ladder before normal agglomeration (default).'
)
# -L flips the same `pre_ladder` flag off (store_false on the shared dest).
arggroup.add_argument('-L', '--post-ladder',
action='store_false', dest='pre_ladder',
help='Run ladder after normal agglomeration instead of before (SLOW).'
)
# See `make_ladder` / `Rag.agglomerate_ladder` for the strictness levels.
arggroup.add_argument('-s', '--strict-ladder', type=int, metavar='INT',
default=1,
help='''Specify the strictness of the ladder agglomeration. Level 1
(default): merge anything smaller than the ladder threshold as
long as it's not on the volume border. Level 2: only merge smaller
bodies to larger ones. Level 3: only merge when the border is
larger than or equal to 2 pixels.'''
)
arggroup.add_argument('-M', '--low-memory', action='store_true',
help='''Use less memory at a slight speed cost. Note that the phrase
'low memory' is relative.'''
)
# The next two options share dest `allow_shared_boundaries` (on by default).
arggroup.add_argument('--disallow-shared-boundaries', action='store_false',
dest='allow_shared_boundaries',
help='''Watershed pixels that are shared between more than 2 labels are
not counted as edges.'''
)
arggroup.add_argument('--allow-shared-boundaries', action='store_true',
default=True,
help='''Count every watershed pixel in every edge in which it participates
(default: True).'''
)
def conditional_countdown(seq, start=1, pred=bool):
    """Count down from 'start' each time pred(elem) is true for elem in seq.

    Used to know how many elements of a sequence remain that satisfy a
    predicate.

    Parameters
    ----------
    seq : iterable
        Any sequence.
    start : int, optional
        The starting value of the counter.
    pred : function, type(next(seq)) -> bool
        A predicate acting on the elements of `seq`.

    Yields
    ------
    remaining : int
        The counter value after seeing each element; it can go negative
        if more than `start` elements satisfy `pred`.

    Examples
    --------
    >>> seq = range(10)
    >>> cc = conditional_countdown(seq, start=5, pred=lambda x: x % 2 == 1)
    >>> next(cc)
    5
    >>> next(cc)
    4
    >>> next(cc)
    4
    >>> next(cc)
    3
    """
    remaining = start
    for elem in seq:
        if pred(elem):
            remaining -= 1
        yield remaining
############################
# Merge priority functions #
############################
def oriented_boundary_mean(g, n1, n2):
    """Return the mean oriented probability along the n1-n2 boundary."""
    boundary_values = g.oriented_probabilities_r[g[n1][n2]['boundary']]
    return mean(boundary_values)
def boundary_mean(g, n1, n2):
    """Return the mean boundary probability between nodes n1 and n2."""
    boundary_values = g.probabilities_r[g[n1][n2]['boundary']]
    return mean(boundary_values)
def boundary_median(g, n1, n2):
    """Return the median boundary probability between nodes n1 and n2."""
    boundary_values = g.probabilities_r[g[n1][n2]['boundary']]
    return median(boundary_values)
def approximate_boundary_mean(g, n1, n2):
    """Return the boundary mean as computed by a MomentsFeatureManager.

    The feature manager is assumed to have been set up for g at construction.
    """
    # index 1 of the moments edge-feature vector holds the mean
    return g.feature_manager.compute_edge_features(g, n1, n2)[1]
def make_ladder(priority_function, threshold, strictness=1):
    """Wrap a priority function so it only fires for small segments.

    The returned function gives `priority_function`'s value when the
    "ladder" condition holds (at least one node is smaller than
    `threshold` and not on the volume boundary; stricter levels add
    further requirements), and infinity otherwise.
    """
    def ladder_function(g, n1, n2):
        small1 = g.node[n1]['size'] < threshold
        small2 = g.node[n2]['size'] < threshold
        # level 1: a small, non-border node on either side
        mergeable = ((small1 and not g.at_volume_boundary(n1)) or
                     (small2 and not g.at_volume_boundary(n2)))
        # level 2: exactly one of the two nodes is small
        if strictness >= 2 and not (small1 != small2):
            mergeable = False
        # level 3: the shared boundary must exceed 2 pixels
        if strictness >= 3 and len(g[n1][n2]['boundary']) <= 2:
            mergeable = False
        return priority_function(g, n1, n2) if mergeable else inf
    return ladder_function
def no_mito_merge(priority_function):
    """Wrap `priority_function` to refuse merges touching frozen items.

    Any edge incident to a frozen node, or itself frozen, gets infinite
    priority; everything else is delegated to `priority_function`.
    """
    def predict(g, n1, n2):
        if n1 in g.frozen_nodes or n2 in g.frozen_nodes:
            return np.inf
        if (n1, n2) in g.frozen_edges:
            return np.inf
        return priority_function(g, n1, n2)
    return predict
def mito_merge():
    """Return a priority function that merges frozen ("mito") nodes into
    their unfrozen ("cyto") neighbors.

    Only mixed frozen/unfrozen pairs are mergeable, the frozen node must
    not be larger than its partner, and priority grows as the shared
    boundary's fraction of the frozen node's total boundary shrinks.
    """
    def predict(g, n1, n2):
        frozen1 = n1 in g.frozen_nodes
        frozen2 = n2 in g.frozen_nodes
        # both frozen, edge frozen, or neither frozen: never merge
        if frozen1 and frozen2:
            return np.inf
        if (n1, n2) in g.frozen_edges:
            return np.inf
        if not frozen1 and not frozen2:
            return np.inf
        mito, cyto = (n1, n2) if frozen1 else (n2, n1)
        if g.node[mito]['size'] > g.node[cyto]['size']:
            return np.inf
        shared = len(g[mito][cyto]['boundary'])
        total = sum(len(g[mito][x]['boundary']) for x in g.neighbors(mito))
        return 1.0 - float(shared) / total
    return predict
def classifier_probability(feature_extractor, classifier):
    """Return a merge priority function backed by a trained classifier.

    The returned function extracts features for the edge (n1, n2) and
    returns the classifier's probability of the "should not merge"
    class; edges touching the virtual boundary body get infinity.
    Classifiers without `predict_proba` fall back to `predict`.
    """
    def predict(g, n1, n2):
        if g.boundary_body in (n1, n2):
            return inf
        feats = np.atleast_2d(feature_extractor(g, n1, n2))
        try:
            raw = classifier.predict_proba(feats)
        except AttributeError:
            # hard-label classifier: use the predicted class directly
            return classifier.predict(feats)[0]
        arr = np.array(raw, copy=False)
        if arr.ndim > 2:
            arr = arr[0]
        try:
            # probability of class 1 ("do not merge") for the first sample
            return arr[0][1]
        except (TypeError, IndexError):
            return arr[0]
    return predict
def ordered_priority(edges):
    """Return a priority function that ranks edges by position in `edges`.

    Earlier edges get lower (better) scores in [0, 1); both orderings of
    each pair are recognized, and unknown pairs get infinity.
    """
    total = len(edges)
    rank = {}
    for position, (u, v) in enumerate(edges):
        score = float(position) / total
        rank[(u, v)] = score
        rank[(v, u)] = score
    def priority(g, n1, n2):
        return rank.get((n1, n2), inf)
    return priority
def expected_change_vi(feature_extractor, classifier, alpha=1.0, beta=1.0):
    """Return a priority function scoring the expected VI change of a merge.

    The classifier supplies p = P(should not merge); the local VI change
    of merging is weighted by `alpha` against the (negated) change of
    correctly merging, weighted by `beta`.
    """
    prob_func = classifier_probability(feature_extractor, classifier)
    def predict(g, n1, n2):
        # probability from the classifier that this is a false merge
        p = prob_func(g, n1, n2)
        # VI change incurred if n1 and n2 should *not* be merged
        v = compute_local_vi_change(
            g.node[n1]['size'], g.node[n2]['size'], g.volume_size
        )
        # expectation over the two outcomes
        return (p*alpha*v + (1.0-p)*(-beta*v))
    return predict
def compute_local_vi_change(s1, s2, n):
    """Compute change in VI if we merge disjoint sizes s1,s2 in a volume n.

    Parameters
    ----------
    s1, s2 : int
        Sizes (in voxels) of the two segments.
    n : int
        Total volume size.

    Returns
    -------
    float
        The (positive) entropy cost of merging the two segments when
        they belong to different true bodies.
    """
    py1 = float(s1)/n
    py2 = float(s2)/n
    py = py1+py2
    return -(py1*log2(py1) + py2*log2(py2) - py*log2(py))
def compute_true_delta_vi(ctable, n1, n2):
    """Compute the true VI change of merging rows n1 and n2 of `ctable`.

    Uses the contingency table against ground truth, so the result is
    exact rather than a local approximation.
    """
    row1 = ctable[n1]
    row2 = ctable[n2]
    p1 = row1.sum()
    p2 = row2.sum()
    p3 = p1+p2
    h1 = xlogx(row1).sum()
    h2 = xlogx(row2).sum()
    h3 = xlogx(row1+row2).sum()
    return p3*log2(p3) - p1*log2(p1) - p2*log2(p2) - \
                                2*(h3 - h1 - h2)
def expected_change_rand(feature_extractor, classifier, alpha=1.0, beta=1.0):
    """Return a priority function scoring the expected Rand-index change.

    Analogous to `expected_change_vi`, but the local change measure is
    the Rand index rather than variation of information.
    """
    prob_func = classifier_probability(feature_extractor, classifier)
    def predict(g, n1, n2):
        # classifier's probability that this merge is wrong
        p = float(prob_func(g, n1, n2))
        v = compute_local_rand_change(
            g.node[n1]['size'], g.node[n2]['size'], g.volume_size
        )
        return p*v*alpha + (1.0-p)*(-beta*v)
    return predict
def compute_local_rand_change(s1, s2, n):
    """Compute change in rand if we merge disjoint sizes s1,s2 in volume n.

    The change is the number of newly-agreeing pairs, s1*s2, as a
    fraction of all n-choose-2 voxel pairs.
    """
    return float(s1*s2)/nchoosek(n,2)
def compute_true_delta_rand(ctable, n1, n2, n):
    """Compute change in RI obtained by merging rows n1 and n2.

    This function assumes ctable is normalized to sum to 1.

    Parameters
    ----------
    ctable : array of float
        Normalized contingency table (rows: segmentation, cols: truth).
    n1, n2 : int
        The rows (segments) being merged.
    n : int
        Total volume size, used to un-normalize the table.
    """
    # recover raw voxel counts for just the two affected rows
    localct = n*ctable[(n1,n2),]
    delta_sxy = 1.0/2*((localct.sum(axis=0)**2).sum()-(localct**2).sum())
    delta_sx = 1.0/2*(localct.sum()**2 - (localct.sum(axis=1)**2).sum())
    return (2*delta_sxy - delta_sx) / nchoosek(n,2)
def boundary_mean_ladder(g, n1, n2, threshold, strictness=1):
    """Boundary mean priority gated by the size ladder (see `make_ladder`)."""
    ladder = make_ladder(boundary_mean, threshold, strictness)
    return ladder(g, n1, n2)
def boundary_mean_plus_sem(g, n1, n2, alpha=-6):
    """Return boundary mean offset by `alpha` standard errors of the mean."""
    values = g.probabilities_r[g[n1][n2]['boundary']]
    m = mean(values)
    s = sem(values)
    return m + alpha*s
def random_priority(g, n1, n2):
    """Return a uniform random priority; boundary-body edges never merge."""
    if g.boundary_body in (n1, n2):
        return inf
    return random.random()
[docs]class Rag(Graph):
"""Region adjacency graph for segmentation of nD volumes.
Parameters
----------
watershed : array of int, shape (M, N, ..., P)
The labeled regions of the image. Note: this is called
`watershed` for historical reasons, but could refer to a
superpixel map of any origin.
probabilities : array of float, shape (M, N, ..., P[, Q])
The probability of each pixel of belonging to a particular
class. Typically, this has the same shape as `watershed`
and represents the probability that the pixel is part of a
region boundary, but it can also have an additional
dimension for probabilities of belonging to other classes,
such as mitochondria (in biological images) or specific
textures (in natural images).
merge_priority_function : callable function, optional
This function must take exactly three arguments as input
(a Rag object and two node IDs) and return a single float.
feature_manager : ``features.base.Null`` object, optional
A feature manager object that controls feature computation
and feature caching.
mask : array of bool, shape (M, N, ..., P)
A mask of the same shape as `watershed`, `True` in the
positions to be processed when making a RAG, `False` in the
positions to ignore.
show_progress : bool, optional
Whether to display an ASCII progress bar during long-
-running graph operations.
connectivity : int in {1, ..., `watershed.ndim`}
When determining adjacency, allow neighbors along
`connectivity` dimensions.
channel_is_oriented : array-like of bool, shape (Q,), optional
For multi-channel images, some channels, for example some
edge detectors, have a specific orientation. In conjunction
with the `orientation_map` argument, specify which channels
have an orientation associated with them.
orientation_map : array-like of float, shape (Q,)
Specify the orientation of the corresponding channel. (2D
images only)
normalize_probabilities : bool, optional
Divide the input `probabilities` by their maximum to ensure
a range in [0, 1].
exclusions : array-like of int, shape (M, N, ..., P), optional
Volume of same shape as `watershed`. Mark points in the
volume with the same label (>0) to prevent them from being
merged during agglomeration. For example, if
`exclusions[45, 92] == exclusions[51, 105] == 1`, then
segments `watershed[45, 92]` and `watershed[51, 105]` will
never be merged, regardless of the merge priority function.
isfrozennode : function, optional
Function taking in a Rag object and a node id and returning
a bool. If the function returns ``True``, the node will not
be merged, regardless of the merge priority function.
isfrozenedge : function, optional
As `isfrozennode`, but the function should take the graph
and *two* nodes, to specify an edge that cannot be merged.
"""
def __init__(self, watershed=array([], int), probabilities=array([]),
             merge_priority_function=boundary_mean, gt_vol=None,
             feature_manager=features.base.Null(), mask=None,
             show_progress=False, connectivity=1,
             channel_is_oriented=None, orientation_map=array([]),
             normalize_probabilities=False, exclusions=array([]),
             isfrozennode=None, isfrozenedge=None):
    """Build the RAG from a superpixel map; see the class docstring
    for a description of each parameter.
    """
    # NOTE(review): `weighted=False` is networkx 1.x API — confirm the
    # pinned networkx version before upgrading.
    super(Rag, self).__init__(weighted=False)
    self.show_progress = show_progress
    self.connectivity = connectivity
    # progress bar object used by long-running graph operations
    self.pbar = (ip.StandardProgressBar() if self.show_progress
                 else ip.NoProgressBar())
    # order matters: watershed first (defines shapes), then probabilities
    # and orientations which are padded/raveled to match it
    self.set_watershed(watershed, connectivity)
    self.set_probabilities(probabilities, normalize_probabilities)
    self.set_orientations(orientation_map, channel_is_oriented)
    self.merge_priority_function = merge_priority_function
    self.max_merge_score = -inf
    if mask is None:
        # default: every (already padded, raveled) voxel is in play
        self.mask = np.ones(self.watershed_r.shape, dtype=bool)
    else:
        # pad with True so the virtual boundary voxels are traversable
        self.mask = morpho.pad(mask, True).ravel()
    self.build_graph_from_watershed()
    self.set_feature_manager(feature_manager)
    self.set_ground_truth(gt_vol)
    self.set_exclusions(exclusions)
    self.merge_queue = MergeQueue()
    # ultrametric tree recording the merge hierarchy
    self.tree = tree.Ultrametric(self.nodes())
    # nodes/edges excluded from merging by the user-supplied predicates
    self.frozen_nodes = set()
    if isfrozennode is not None:
        for node in self.nodes():
            if isfrozennode(self, node):
                self.frozen_nodes.add(node)
    self.frozen_edges = set()
    if isfrozenedge is not None:
        for n1, n2 in self.edges():
            if isfrozenedge(self, n1, n2):
                self.frozen_edges.add((n1, n2))
    # drop per-node extent arrays to save memory; `self.extent()`
    # recomputes them on demand via flood fill — see `extent`
    for nodeid in self.nodes():
        del self.node[nodeid]["extent"]
def __copy__(self):
    """Return a copy of the graph, restoring the raveled array views.

    The superclass copy duplicates nodes/edges and attributes; the
    `*_r` raveled views must then be rebuilt so they alias the copy's
    own arrays rather than the original's.
    """
    prob_shape = self.probabilities_r.shape
    clone = super(Rag, self).copy()
    clone.watershed_r = clone.watershed.ravel()
    clone.probabilities_r = clone.probabilities.reshape(prob_shape)
    return clone
def copy(self):
    """Return a copy of the object and attributes.

    See Also
    --------
    __copy__
    """
    return self.__copy__()
def extent(self, nodeid):
    """Return the set of linear indices of voxels belonging to `nodeid`.

    If the node still carries a cached 'extent', that is returned
    directly; otherwise the extent is recomputed by flood-filling the
    watershed from the node's recorded entry point.
    """
    if 'extent' in self.node[nodeid]:
        return self.node[nodeid]['extent']
    extent_array = opt.flood_fill(self.watershed,
                                  np.array(self.node[nodeid]['entrypoint']),
                                  np.array(self.node[nodeid]['watershed_ids']))
    if len(extent_array) != self.node[nodeid]['size']:
        # sanity check only; the (possibly wrong) fill is still returned.
        # (fixed: adjacent literals previously concatenated without a space)
        sys.stderr.write('Flood fill fail - found %d voxels but size '
                         'expected %d\n' %
                         (len(extent_array), self.node[nodeid]['size']))
    raveled_indices = np.ravel_multi_index(extent_array.T,
                                           self.watershed.shape)
    return set(raveled_indices)
def real_edges(self, *args, **kwargs):
    """Return edges internal to the volume.

    The RAG actually includes edges to a "virtual" region that
    envelops the entire volume. This function returns the list of
    edges that are internal to the volume.

    Parameters
    ----------
    *args, **kwargs : arbitrary types
        Arguments and keyword arguments are passed through to the
        ``edges()`` function of the ``networkx.Graph`` class.

    Returns
    -------
    edge_list : list of tuples
        A list of pairs of node IDs, which are typically integers.

    See Also
    --------
    real_edges_iter, networkx.Graph.edges
    """
    # e[:2] tolerates (u, v, data) tuples when data=True is requested
    return [e for e in super(Rag, self).edges(*args, **kwargs) if
            self.boundary_body not in e[:2]]
def real_edges_iter(self, *args, **kwargs):
    """Return iterator of edges internal to the volume.

    The RAG actually includes edges to a "virtual" region that
    envelops the entire volume. This function returns an iterator
    over the edges that are internal to the volume.

    Parameters
    ----------
    *args, **kwargs : arbitrary types
        Arguments and keyword arguments are passed through to the
        ``edges()`` function of the ``networkx.Graph`` class.

    Returns
    -------
    edges_iter : iterator of tuples
        An iterator over pairs of node IDs, which are typically
        integers.

    See Also
    --------
    real_edges
    """
    # NOTE(review): `edges_iter` is networkx 1.x API (removed in 2.x)
    return (e for e in super(Rag, self).edges_iter(*args, **kwargs) if
            self.boundary_body not in e[:2])
def build_graph_from_watershed(self, idxs=None):
    """Build the graph object from the region labels.

    The region labels should have been set ahead of time using
    ``set_watershed()``.

    Parameters
    ----------
    idxs : array-like of int, optional
        Linear indices into raveled volume array. If provided, the
        graph is built only for these indices.
    """
    if self.watershed.size == 0:
        return  # stop processing for empty graphs
    if idxs is None:
        idxs = arange(self.watershed.size, dtype=self.steps.dtype)
        # full build: add the virtual boundary region with its extent
        self.add_node(self.boundary_body,
                      extent=flatnonzero(self.watershed==self.boundary_body))
    inner_idxs = idxs[self.watershed_r[idxs] != self.boundary_body]
    inner_idxs = inner_idxs[self.mask[inner_idxs]]  # use only masked idxs
    labels = np.unique(self.watershed_r[inner_idxs])
    sizes = np.bincount(self.watershed_r[inner_idxs])
    # initialize per-node bookkeeping before the voxel sweep
    for nodeid in labels:
        self.add_node(nodeid)
        node = self.node[nodeid]
        node['size'] = sizes[nodeid]
        node['extent'] = np.zeros(sizes[nodeid], dtype=inner_idxs.dtype)
        node['visited'] = 0  # number of idxs seen so far
        node['watershed_ids'] = [nodeid]
    if self.show_progress:
        inner_idxs = ip.with_progress(inner_idxs, title='Graph ',
                                      pbar=self.pbar)
    # single pass over voxels: record extents and accumulate boundary
    # voxel lists on the edges between adjacent labels
    for idx in inner_idxs:
        nodeid = self.watershed_r[idx]
        node = self.node[nodeid]
        if 'entrypoint' not in node:  # node not initialised
            # first voxel seen becomes the flood-fill entry point
            node['entrypoint'] = np.array(
                np.unravel_index(idx, self.watershed.shape))
        node['extent'][node['visited']] = idx
        node['visited'] += 1
        ns = idx + self.steps
        ns = ns[self.mask[ns]]
        adj = self.watershed_r[ns]
        adj = set(adj)
        for v in adj:
            if v == nodeid:
                continue
            if self.has_edge(nodeid, v):
                self[nodeid][v]['boundary'].append(idx)
            else:
                self.add_edge(nodeid, v, boundary=[idx])
def set_feature_manager(self, feature_manager):
    """Set the feature manager and ensure feature caches are computed.

    Parameters
    ----------
    feature_manager : ``features.base.Null`` object
        The feature manager to be used by this RAG.

    Returns
    -------
    None
    """
    self.feature_manager = feature_manager
    self.compute_feature_caches()
def compute_feature_caches(self):
    """Use the feature manager to compute node and edge feature caches.

    Parameters
    ----------
    None

    Returns
    -------
    None
    """
    for n in ip.with_progress(
            self.nodes(), title='Node caches ', pbar=self.pbar):
        self.node[n]['feature-cache'] = \
            self.feature_manager.create_node_cache(self, n)
    for n1, n2 in ip.with_progress(
            self.edges(), title='Edge caches ', pbar=self.pbar):
        self[n1][n2]['feature-cache'] = \
            self.feature_manager.create_edge_cache(self, n1, n2)
def set_probabilities(self, probs=array([]), normalize=False):
    """Set the `probabilities` attributes of the RAG.

    For various reasons, including removing the need for bounds
    checking when looking for neighboring pixels, the volume of
    pixel-level probabilities is padded on all faces. In addition,
    this function adds an attribute `probabilities_r`, a raveled
    view of the padded probabilities array for quick access to
    individual voxels using linear indices.

    Parameters
    ----------
    probs : array
        The input probabilities array.
    normalize : bool, optional
        If ``True``, the values in the array are scaled to be in
        [0, 1].

    Returns
    -------
    None
    """
    if len(probs) == 0:
        # fallback for missing probabilities; NOTE(review): no early
        # return here — the code below runs on the empty array and is
        # presumably a no-op because neither ndim branch matches.
        self.probabilities = zeros_like(self.watershed)
        self.probabilities_r = self.probabilities.ravel()
    probs = probs.astype(double)
    if normalize and len(probs) > 1:
        probs -= probs.min()  # ensure probs.min() == 0
        probs /= probs.max()  # ensure probs.max() == 1
    # NOTE(review): sp/sw are computed but unused here — possibly a
    # shape-consistency check was lost upstream; confirm before removing.
    sp = probs.shape
    sw = tuple(array(self.watershed.shape, dtype=int)-\
            2*self.pad_thickness*ones(self.watershed.ndim, dtype=int))
    p_ndim = probs.ndim
    w_ndim = self.watershed.ndim
    padding = [inf]+(self.pad_thickness-1)*[0]
    if p_ndim == w_ndim:
        # single-channel probability map: add a trailing channel axis
        self.probabilities = morpho.pad(probs, padding)
        self.probabilities_r = self.probabilities.ravel()[:,newaxis]
    elif p_ndim == w_ndim+1:
        # multi-channel: pad only the spatial axes
        axes = list(range(p_ndim-1))
        self.probabilities = morpho.pad(probs, padding, axes)
        self.probabilities_r = self.probabilities.reshape(
                                            (self.watershed.size, -1))
def set_orientations(self, orientation_map, channel_is_oriented):
    """Set the orientation map of the probability image.

    Parameters
    ----------
    orientation_map : array of float
        A map of angles of the same shape as the superpixel map.
    channel_is_oriented : 1D array-like of bool
        A vector having length the number of channels in the
        probability map.

    Returns
    -------
    None
    """
    if len(orientation_map) == 0:
        # NOTE(review): this fallback is immediately overwritten by the
        # unconditional pad below (on the empty input) — confirm against
        # upstream whether these lines belong in an else-branch.
        self.orientation_map = zeros_like(self.watershed)
        self.orientation_map_r = self.orientation_map.ravel()
    padding = [0]+(self.pad_thickness-1)*[0]
    self.orientation_map = morpho.pad(orientation_map, padding).astype(int)
    self.orientation_map_r = self.orientation_map.ravel()
    if channel_is_oriented is None:
        # no oriented channels: treat every channel as unoriented
        nchannels = 1 if self.probabilities.ndim==self.watershed.ndim \
            else self.probabilities.shape[-1]
        self.channel_is_oriented = array([False]*nchannels)
        self.max_probabilities_r = zeros_like(self.probabilities_r)
        self.oriented_probabilities_r = zeros_like(self.probabilities_r)
        self.non_oriented_probabilities_r = self.probabilities_r
    else:
        self.channel_is_oriented = channel_is_oriented
        self.max_probabilities_r = \
            self.probabilities_r[:, self.channel_is_oriented].max(axis=1)
        # pick, per voxel, the oriented channel selected by the
        # orientation map
        self.oriented_probabilities_r = \
            self.probabilities_r[:, self.channel_is_oriented]
        self.oriented_probabilities_r = \
            self.oriented_probabilities_r[
                list(range(len(self.oriented_probabilities_r))),
                self.orientation_map_r]
        self.non_oriented_probabilities_r = \
            self.probabilities_r[:, ~self.channel_is_oriented]
def set_watershed(self, ws=array([], int), connectivity=1):
    """Set the initial segmentation volume (watershed).

    The initial segmentation is called `watershed` for historical
    reasons only.

    Parameters
    ----------
    ws : array of int
        The initial segmentation.
    connectivity : int in {1, ..., `ws.ndim`}, optional
        The pixel neighborhood.

    Returns
    -------
    None
    """
    if not np.issubdtype(ws.dtype, np.integer):
        # coerce to the smallest integer dtype that can hold the labels
        ws = ws.astype(morpho.smallest_int_dtype(np.max(ws),
                                                 signed=np.min(ws) < 0))
    try:
        # the virtual region enveloping the volume gets a label above
        # all existing ones
        self.boundary_body = ws.max()+1
    except ValueError:  # empty watershed given
        self.boundary_body = -1
    self.volume_size = ws.size
    if ws.size > 0:
        ws, _, inv = relabel_sequential(ws)
        self.inverse_watershed_map = inv  # translates to original labels
    self.watershed = morpho.pad(ws, self.boundary_body)
    self.watershed_r = self.watershed.ravel()
    self.pad_thickness = 1
    # precomputed linear-index offsets to each neighbor voxel
    self.steps = morpho.raveled_steps_to_neighbors(self.watershed.shape,
                                                   connectivity)
def set_ground_truth(self, gt=None):
    """Set the ground truth volume.

    This is useful for tracking segmentation accuracy over time.

    Parameters
    ----------
    gt : array of int
        A ground truth segmentation of the same volume passed to
        ``set_watershed``.

    Returns
    -------
    None
    """
    if gt is not None:
        gtm = gt.max()+1
        # ignore the 0 label (if present) and the padding label in both
        # the ground truth and the segmentation
        gt_ignore = [0, gtm] if (gt==0).any() else [gtm]
        seg_ignore = [0, self.boundary_body] if \
                    (self.watershed==0).any() else [self.boundary_body]
        self.gt = morpho.pad(gt, gt_ignore)
        self.rig = contingency_table(self.watershed, self.gt,
                                     ignore_seg=seg_ignore,
                                     ignore_gt=gt_ignore)
    else:
        self.gt = None
        # null pattern to transparently allow merging of nodes.
        # Bonus feature: counts how many sp's went into a single node.
        try:
            self.rig = ones(2 * self.watershed.max() + 1)
        except ValueError:
            self.rig = ones(2 * self.number_of_nodes() + 1)
def set_exclusions(self, excl):
    """Set an exclusion volume, forbidding certain merges.

    Parameters
    ----------
    excl : array of int
        Exclusions work as follows: the volume `excl` is the same
        shape as the initial segmentation (see ``set_watershed``),
        and consists of mostly 0s. Any voxels with *the same*
        non-zero label will not be allowed to merge during
        agglomeration (provided they were not merged in the initial
        segmentation).

        This allows manual separation *a priori* of difficult-to-
        -segment regions.

    Returns
    -------
    None
    """
    # The emptiness test is loop-invariant, so branch once up front
    # instead of re-testing inside the per-node loop.
    if excl.size != 0:
        excl = morpho.pad(excl, [0]*self.pad_thickness)
        for n in self.nodes():
            # exclusion labels overlapping this node's voxels (0 dropped)
            eids = unique(excl.ravel()[self.extent(n)])
            eids = eids[flatnonzero(eids)]
            self.node[n]['exclusions'] = set(list(eids))
    else:
        for n in self.nodes():
            self.node[n]['exclusions'] = set()
def build_merge_queue(self):
    """Build a queue of node pairs to be merged in a specific priority.

    Parameters
    ----------
    None

    Returns
    -------
    mq : MergeQueue object
        A MergeQueue is a Python ``deque`` with a specific element
        structure: a list of length 4 containing:
            - the merge priority (any ordered type)
            - a 'valid' flag
            - and the two nodes in arbitrary order
        The valid flag allows one to "remove" elements from the
        queue in O(1) time by setting the flag to ``False``. Then,
        one checks the flag when popping elements and ignores those
        marked as invalid.

        One other specific feature is that there are back-links from
        edges to their corresponding queue items so that when nodes
        are merged, affected edges can be invalidated and reinserted
        in the queue with a new priority.
    """
    queue_items = []
    for l1, l2 in self.real_edges_iter():
        w = self.merge_priority_function(self, l1, l2)
        qitem = [w, True, l1, l2]
        queue_items.append(qitem)
        # back-link from the edge to its queue item for O(1) invalidation
        self[l1][l2]['qlink'] = qitem
        self[l1][l2]['weight'] = w
    return MergeQueue(queue_items, with_progress=self.show_progress)
def rebuild_merge_queue(self):
    """Build a merge queue from scratch and assign to self.merge_queue.

    See Also
    --------
    build_merge_queue
    """
    self.merge_queue = self.build_merge_queue()
def agglomerate(self, threshold=0.5, save_history=False):
    """Merge nodes hierarchically until given edge confidence threshold.

    This is the main workhorse of the ``agglo`` module!

    Parameters
    ----------
    threshold : float, optional
        The edge priority at which to stop merging.
    save_history : bool, optional
        Whether to save and return a history of all the merges made.

    Returns
    -------
    history : list of tuple of int, optional
        The ordered history of node pairs merged.
    scores : list of float, optional
        The list of merge scores corresponding to the `history`.
    evaluation : list of tuple, optional
        The split VI after each merge. This is only meaningful if
        a ground truth volume was provided at build time.

    Notes
    -----
    This function returns ``None`` when `save_history` is
    ``False``.
    """
    if self.merge_queue.is_empty():
        self.merge_queue = self.build_merge_queue()
    history, scores, evaluation = [], [], []
    # pop best-priority edges until the threshold is reached
    while len(self.merge_queue) > 0 and \
            self.merge_queue.peek()[0] < threshold:
        merge_priority, _, n1, n2 = self.merge_queue.pop()
        self.update_frozen_sets(n1, n2)
        self.merge_nodes(n1, n2, merge_priority)
        if save_history:
            history.append((n1, n2))
            scores.append(merge_priority)
            evaluation.append(
                (self.number_of_nodes()-1, self.split_vi())
            )
    if save_history:
        return history, scores, evaluation
def agglomerate_count(self, stepsize=100, save_history=False):
    """Agglomerate until 'stepsize' merges have been made.

    This function is like ``agglomerate``, but rather than to a
    certain threshold, a certain number of merges are made,
    regardless of threshold.

    Parameters
    ----------
    stepsize : int, optional
        The number of merges to make.
    save_history : bool, optional
        Whether to save and return a history of all the merges made.

    Returns
    -------
    history : list of tuple of int, optional
        The ordered history of node pairs merged.
    evaluation : list of tuple, optional
        The split VI after each merge. This is only meaningful if
        a ground truth volume was provided at build time.

    Notes
    -----
    This function returns ``None`` when `save_history` is
    ``False``.

    See Also
    --------
    agglomerate
    """
    if self.merge_queue.is_empty():
        self.merge_queue = self.build_merge_queue()
    history, evaluation = [], []
    # (a redundant manual merge counter was removed: the loop variable
    # already bounds the number of merges)
    for _step in range(stepsize):
        if len(self.merge_queue) == 0:
            break
        merge_priority, _, n1, n2 = self.merge_queue.pop()
        self.merge_nodes(n1, n2, merge_priority)
        if save_history:
            history.append((n1, n2))
            evaluation.append(
                (self.number_of_nodes()-1, self.split_vi())
            )
    if save_history:
        return history, evaluation
def agglomerate_ladder(self, min_size=1000, strictness=2):
    """Merge sequentially all nodes smaller than `min_size`.

    Parameters
    ----------
    min_size : int, optional
        The smallest allowable segment after ladder completion.
    strictness : {1, 2, 3}, optional
        `strictness == 1`: all nodes smaller than `min_size` are
        merged according to the merge priority function.
        `strictness == 2`: in addition to `1`, small nodes can only
        be merged to big nodes.
        `strictness == 3`: in addition to `2`, nodes sharing less
        than one pixel of boundary are not agglomerated.

    Returns
    -------
    None

    Notes
    -----
    Nodes that are on the volume boundary are not agglomerated.
    """
    # temporarily gate the priority function behind the size ladder,
    # merge everything it allows, then restore the original function
    original_merge_priority_function = self.merge_priority_function
    self.merge_priority_function = make_ladder(
        self.merge_priority_function, min_size, strictness
    )
    self.rebuild_merge_queue()
    self.agglomerate(inf)
    self.merge_priority_function = original_merge_priority_function
    self.merge_queue.finish()
    self.rebuild_merge_queue()
    # shift merge-tree weights so ladder merges sort below later ones
    # NOTE(review): max() raises ValueError if the rebuilt queue is
    # empty (fully merged volume) — confirm intended behavior.
    max_score = max([qitem[0] for qitem in self.merge_queue.q])
    for n in self.tree.nodes():
        self.tree.node[n]['w'] -= max_score
def learn_agglomerate(self, gts, feature_map,
                      min_num_samples=1,
                      learn_flat=True,
                      learning_mode='strict',
                      labeling_mode='assignment',
                      priority_mode='active',
                      memory=True,
                      unique=True,
                      random_state=None,
                      max_num_epochs=10,
                      min_num_epochs=2,
                      max_num_samples=np.inf,
                      classifier='random forest',
                      active_function=classifier_probability,
                      mpf=boundary_mean):
    """Agglomerate while comparing to ground truth & classifying merges.

    Parameters
    ----------
    gts : array of int or list thereof
        The ground truth volume(s) corresponding to the current
        probability map.
    feature_map : function (Rag, node, node) -> array of float
        The map from node pairs to a feature vector. This must
        consist either of uncached features or of the cache used
        when building the graph.
    min_num_samples : int, optional
        Continue training until this many training examples have
        been collected.
    learn_flat : bool, optional
        Do a flat learning on the static graph with no
        agglomeration.
    learning_mode : {'strict', 'loose'}, optional
        In 'strict' mode, if a "don't merge" edge is encountered,
        it is added to the training set but the merge is not
        executed. In 'loose' mode, the merge is allowed to proceed.
    labeling_mode : {'assignment', 'vi-sign', 'rand-sign'}, optional
        How to decide whether two nodes should be merged based on
        the ground truth segmentations. ``'assignment'`` means the
        nodes are assigned to the ground truth node with which they
        share the highest overlap. ``'vi-sign'`` means the the VI
        change of the switch is used (negative is better).
        ``'rand-sign'`` means the change in Rand index is used
        (positive is better).
    priority_mode : string, optional
        One of:
            ``'active'``: Train a priority function with the data
                          from previous epochs to obtain the next.
            ``'random'``: Merge edges at random.
            ``'mixed'``: Alternate between epochs of ``'active'``
                         and ``'random'``.
            ``'mean'``: Use the mean boundary value. (In this case,
                        training is limited to 1 or 2 epochs.)
            ``'custom'``: Use the function provided by `mpf`.
    memory : bool, optional
        Keep the training data from all epochs (rather than just
        the most recent one).
    unique : bool, optional
        Remove duplicate feature vectors.
    random_state : int, optional
        If provided, this parameter is passed to `get_classifier`
        to set the random state and allow consistent results across
        tests.
    max_num_epochs : int, optional
        Do not train for longer than this (this argument *may*
        override the `min_num_samples` argument).
    min_num_epochs : int, optional
        Train for no fewer than this number of epochs.
    max_num_samples : int, optional
        Train for no more than this number of samples.
    classifier : string, optional
        Any valid classifier descriptor. See
        ``gala.classify.get_classifier()``
    active_function : function (feat. map, classifier) -> function, optional
        Use this to create the next priority function after an
        epoch.
    mpf : function (Rag, node, node) -> float
        A merge priority function to use when ``priority_mode`` is
        ``'custom'``.

    Returns
    -------
    data : list of array
        Four arrays containing:
            - the feature vectors, shape ``(n_samples, n_features)``.
            - the labels, shape ``(n_samples, 3)``. A value of `-1`
              means "should merge", while `1` means "should
              not merge". The columns correspond to the three
              labeling methods: assignment, VI sign, or RI sign.
            - the VI and RI change of each merge, ``(n_edges, 2)``.
            - the list of merged edges ``(n_edges, 2)``.
    alldata : list of list of array
        A list of lists like `data` above: one list for each epoch.

    Notes
    -----
    The gala algorithm [1] uses the default parameters. For the
    LASH algorithm [2], use:
        - `learning_mode`: ``'loose'``
        - `labeling_mode`: ``'rand-sign'``
        - `memory`: ``False``

    References
    ----------
    .. [1] Nunez-Iglesias et al, Machine learning of hierarchical
           clustering to segment 2D and 3D images, PLOS ONE, 2013.
    .. [2] Jain et al, Learning to agglomerate superpixel
           hierarchies, NIPS, 2011.

    See Also
    --------
    Rag
    """
    learning_mode = learning_mode.lower()
    labeling_mode = labeling_mode.lower()
    priority_mode = priority_mode.lower()
    # 'mean' and 'random' priorities don't improve with more epochs, so
    # the epoch count is capped in those modes.
    if priority_mode == 'mean' and unique:
        max_num_epochs = 2 if learn_flat else 1
    if priority_mode in ['random', 'mean'] and not memory:
        max_num_epochs = 1
    # Column index into the label array for each labeling strategy.
    label_type_keys = {'assignment':0, 'vi-sign':1, 'rand-sign':2}
    if type(gts) != list:
        gts = [gts] # allow using single ground truth as input
    master_ctables = \
            [contingency_table(self.get_segmentation(), gt) for gt in gts]
    alldata = []
    data = [[],[],[],[]]
    for num_epochs in range(max_num_epochs):
        # Each epoch works on a fresh copy of the contingency tables,
        # since learn_epoch mutates them as it merges.
        ctables = deepcopy(master_ctables)
        if len(data[0]) > min_num_samples and num_epochs >= min_num_epochs:
            break
        if learn_flat and num_epochs == 0:
            # Epoch 0 (optional): learn every edge of the static graph.
            alldata.append(self.learn_flat(gts, feature_map))
            data = unique_learning_data_elements(alldata) if memory \
                   else alldata[-1]
            continue
        # Agglomerate a *copy* so `self` is preserved across epochs.
        g = self.copy()
        if priority_mode == 'mean':
            g.merge_priority_function = boundary_mean
        # NOTE: relies on `and` binding tighter than `or`:
        # (active and epoch>0) or (mixed and odd epoch).
        elif num_epochs > 0 and priority_mode == 'active' or \
             num_epochs % 2 == 1 and priority_mode == 'mixed':
            # Train a classifier on data from previous epochs, and use its
            # probability output as the next merge priority function.
            cl = get_classifier(classifier, random_state=random_state)
            feat, lab = classify.sample_training_data(
                data[0], data[1][:, label_type_keys[labeling_mode]],
                max_num_samples)
            cl = cl.fit(feat, lab)
            g.merge_priority_function = active_function(feature_map, cl)
        elif priority_mode == 'random' or \
             (priority_mode == 'active' and num_epochs == 0):
            # 'active' bootstraps its first epoch with random merges.
            g.merge_priority_function = random_priority
        elif priority_mode == 'custom':
            g.merge_priority_function = mpf
        g.show_progress = False # bug in MergeQueue usage causes
                                # progressbar crash.
        g.rebuild_merge_queue()
        alldata.append(g.learn_epoch(ctables, feature_map,
                                     learning_mode=learning_mode,
                                     labeling_mode=labeling_mode))
        if memory:
            if unique:
                data = unique_learning_data_elements(alldata)
            else:
                data = concatenate_data_elements(alldata)
        else:
            data = alldata[-1]
        logging.debug('data size %d at epoch %d'%(len(data[0]), num_epochs))
    return data, alldata
def learn_flat(self, gts, feature_map):
    """Collect a training sample from every edge of the static graph.

    No agglomeration takes place: each currently-existing edge is
    compared against the ground truth(s) exactly once.

    Parameters
    ----------
    gts : array of int or list thereof
        The ground truth volume(s) corresponding to the current
        probability map.
    feature_map : function (Rag, node, node) -> array of float
        The map from node pairs to a feature vector. This must
        consist either of uncached features or of the cache used
        when building the graph.

    Returns
    -------
    data : list of array
        Four arrays: feature vectors ``(n_samples, n_features)``;
        labels ``(n_samples, 3)`` (-1 = "should merge", 1 = "should
        not merge", columns = assignment / VI sign / RI sign); the
        VI and RI change of each merge ``(n_edges, 2)``; and the
        list of edges ``(n_edges, 2)``.

    See Also
    --------
    learn_agglomerate
    """
    if type(gts) != list:
        gts = [gts]  # a single ground truth is promoted to a list
    ctables = [contingency_table(self.get_segmentation(), gt) for gt in gts]
    # Threshold each table row so every segment maps to its best-overlap
    # ground truth segment.
    assignments = [ct == ct.max(axis=1)[:, newaxis] for ct in ctables]
    samples = [self.learn_edge(edge, ctables, assignments, feature_map)
               for edge in self.real_edges()]
    return list(map(array, zip(*samples)))
def learn_edge(self, edge, ctables, assignments, feature_map):
    """Determine whether an edge should be merged based on ground truth.

    Parameters
    ----------
    edge : (int, int) tuple
        An edge in the graph.
    ctables : list of array
        A list of contingency tables determining overlap between the
        current segmentation and the ground truth.
    assignments : list of array
        Similar to the contingency tables, but each row is thresholded
        so each segment corresponds to exactly one ground truth segment.
    feature_map : function (Rag, node, node) -> array of float
        The map from node pairs to a feature vector.

    Returns
    -------
    features : 1D array of float
        The feature vector for that edge.
    labels : 1D array of float, length 3
        The labels determining whether the edge should be merged.
        A value of `-1` means "should merge", while `1` means "should
        not merge". The columns correspond to the three labeling
        methods: assignment, VI sign, or RI sign.
    weights : 1D array of float, length 2
        The VI and RI change of the merge.
    nodes : tuple of int
        The given edge.
    """
    n1, n2 = edge
    features = feature_map(self, n1, n2).ravel()
    # Calculate weights for weighting data points
    s1, s2 = [self.node[n]['size'] for n in [n1, n2]]
    weights = \
        compute_local_vi_change(s1, s2, self.volume_size), \
        compute_local_rand_change(s1, s2, self.volume_size)
    # One label list per labeling strategy: assignment agreement
    # ((-1)**True == -1, i.e. "merge"), VI change, and negated Rand
    # change (so that negative always means "merge").
    cont_labels = [
        [(-1)**(a[n1,:]==a[n2,:]).all() for a in assignments],
        [compute_true_delta_vi(ctable, n1, n2) for ctable in ctables],
        [-compute_true_delta_rand(ctable, n1, n2, self.volume_size)
         for ctable in ctables]
    ]
    labels = [sign(mean(cont_label)) for cont_label in cont_labels]
    # BUG FIX: the original condition read `any([label == 0 for l in
    # labels])`, comparing the module-level `label` function (imported
    # from scipy.ndimage.measurements) to 0 instead of the loop variable
    # `l`, so zero labels were never detected here.
    if any(map(isnan, labels)) or any(l == 0 for l in labels):
        logging.debug('NaN or 0 labels found. ' +
                      ' '.join(map(str, [labels, (n1, n2)])))
    # Degenerate or frozen edges default to "don't merge" (label 1).
    labels = [1 if i==0 or isnan(i) or n1 in self.frozen_nodes or
              n2 in self.frozen_nodes or (n1, n2) in self.frozen_edges else
              i for i in labels]
    return features, labels, weights, (n1, n2)
def learn_epoch(self, ctables, feature_map,
                learning_mode='permissive', labeling_mode='assignment'):
    """Learn the agglomeration process using various strategies.

    Parameters
    ----------
    ctables : array of float or list thereof
        One or more contingency tables between own segments and gold
        standard segmentations. These are mutated in place as merges
        proceed.
    feature_map : function (Rag, node, node) -> array of float
        The map from node pairs to a feature vector. This must
        consist either of uncached features or of the cache used
        when building the graph.
    learning_mode : {'strict', 'permissive'}, optional
        If ``'strict'``, don't proceed with a merge when it goes against
        the ground truth. For historical reasons, 'loose' is allowed as
        a synonym for 'strict'.
    labeling_mode : {'assignment', 'vi-sign', 'rand-sign'}, optional
        Which label to use for `learning_mode`. Note that all labels
        are saved in the end.

    Returns
    -------
    data : list of array
        Four arrays containing:
            - the feature vectors, shape ``(n_samples, n_features)``.
            - the labels, shape ``(n_samples, 3)``. A value of `-1`
              means "should merge", while `1` means "should
              not merge". The columns correspond to the three
              labeling methods: assignment, VI sign, or RI sign.
            - the VI and RI change of each merge, ``(n_edges, 2)``.
            - the list of merged edges ``(n_edges, 2)``.
    """
    # Column index of each labeling strategy in the label vector.
    label_type_keys = {'assignment':0, 'vi-sign':1, 'rand-sign':2}
    assignments = [(ct == ct.max(axis=1)[:,newaxis]) for ct in ctables]
    g = self
    data = []
    # Drain the merge queue, recording a training sample for every
    # candidate edge before (possibly) merging it.
    while len(g.merge_queue) > 0:
        merge_priority, valid, n1, n2 = g.merge_queue.pop()
        dat = g.learn_edge((n1,n2), ctables, assignments, feature_map)
        data.append(dat)
        label = dat[1][label_type_keys[labeling_mode]]
        # label < 0 means "should merge"; in non-strict mode merge
        # regardless of the label.
        if learning_mode != 'strict' or label < 0:
            node_id = g.merge_nodes(n1, n2, merge_priority)
            # Keep the contingency tables and assignments in sync with
            # the merge: the new node's row is the sum of the old rows,
            # which are then zeroed out.
            for ctable, assignment in zip(ctables, assignments):
                ctable[node_id] = ctable[n1] + ctable[n2]
                ctable[n1] = 0
                ctable[n2] = 0
                assignment[node_id] = (ctable[node_id] ==
                                       ctable[node_id].max())
                assignment[n1] = 0
                assignment[n2] = 0
    return list(map(array, zip(*data)))
def replay_merge_history(self, merge_seq, labels=None, num_errors=1):
    """Agglomerate according to a merge sequence, optionally labeled.

    Parameters
    ----------
    merge_seq : iterable of pair of int
        The sequence of node IDs to be merged.
    labels : iterable of int in {-1, 0, 1}, optional
        A sequence matching `merge_seq` specifying whether a merge
        should take place or not. -1 or 0 mean "should merge", 1
        otherwise.
    num_errors : int, optional
        Stop after this many false merges have been encountered.

    Returns
    -------
    n : int
        Number of elements consumed from `merge_seq` (index of the
        last one; -1 if the sequence was empty).
    e : (int, int)
        Last merge pair observed (None if the sequence was empty).

    Notes
    -----
    The merge sequence and labels *must* be generators if you don't want
    to manually keep track of how much has been consumed. The merging
    continues until `num_errors` false merges have been encountered, or
    until the sequence is fully consumed.
    """
    if labels is None:
        labels1 = it.repeat(False)
        labels2 = it.repeat(False)
    else:
        # BUG FIX: the original built two generator expressions over the
        # *same* `labels` iterable. When `labels` is a generator (as the
        # docstring requires), the two consumers interleave and each
        # sees only every other element. `tee` provides two independent
        # views of the sequence.
        labels_a, labels_b = it.tee(labels)
        labels1 = (label > 0 for label in labels_a)
        labels2 = (label > 0 for label in labels_b)
    counter = it.count()
    errors_remaining = conditional_countdown(labels2, num_errors)
    nodes = None
    count = -1  # BUG FIX: was unbound (NameError) for an empty merge_seq
    for nodes, label, errs, count in \
            zip(merge_seq, labels1, errors_remaining, counter):
        n1, n2 = nodes
        if not label:
            self.merge_nodes(n1, n2)
        elif errs == 0:
            break
    return count, nodes
def rename_node(self, old, new):
    """Rename node `old` to `new`, rewiring edges and queue items.

    Parameters
    ----------
    old : int
        The node being renamed.
    new : int
        The new node id.
    """
    # Copy node attributes to the new id, then re-attach every edge
    # incident on `old` (with its attribute dict) to `new`.
    self.add_node(new, attr_dict=self.node[old])
    relinked = [(new, nbr, self[old][nbr]) for nbr in self.neighbors(old)]
    self.add_edges_from(relinked)
    # Live merge-queue items store their endpoints at positions 2 and 3;
    # any item still naming `old` must be pointed at `new`.
    for nbr in self.neighbors(new):
        qitem = self[new][nbr].get('qlink', None)
        if qitem is None:
            continue
        slot = 2 if qitem[2] == old else 3
        qitem[slot] = new
    self.remove_node(old)
def merge_nodes(self, n1, n2, merge_priority=0.0):
    """Merge two nodes, while updating the necessary edges.

    Parameters
    ----------
    n1, n2 : int
        Nodes determining the edge for which to update the UCM.
    merge_priority : float, optional
        The merge priority of the merge.

    Returns
    -------
    node_id : int
        The id of the node resulting from the merge, or None if the
        merge was refused because the nodes share an exclusion label.

    Notes
    -----
    Additionally, the RIG (region intersection graph), the
    contingency matrix to the ground truth (if provided) is
    updated.
    """
    # Refuse the merge when the nodes carry a common exclusion label.
    if len(self.node[n1]['exclusions'] & self.node[n2]['exclusions']) > 0:
        return
    else:
        self.node[n1]['exclusions'].update(self.node[n2]['exclusions'])
    # Use the edge's stored weight when present; fall back to the
    # caller-supplied priority.
    w = self[n1][n2].get('weight', merge_priority)
    # n1 absorbs n2's size, watershed fragments, and feature cache.
    self.node[n1]['size'] += self.node[n2]['size']
    self.node[n1]['watershed_ids'] += self.node[n2]['watershed_ids']
    self.feature_manager.update_node_cache(self, n1, n2,
            self.node[n1]['feature-cache'], self.node[n2]['feature-cache'])
    # Re-attach n2's other edges (excluding the edge to n1 and to the
    # boundary body) onto n1, merging edge properties where needed.
    new_neighbors = [n for n in self.neighbors(n2)
                     if n not in [n1, self.boundary_body]]
    for n in new_neighbors:
        self.merge_edge_properties((n2, n), (n1, n))
    # Invalidate the queue entry for the (n1, n2) edge itself, if any.
    try:
        self.merge_queue.invalidate(self[n1][n2]['qlink'])
    except KeyError:
        pass
    # Record the merge in the hierarchy; the merged node gets a fresh id,
    # so rename n1 to it and move the RIG mass onto the new row.
    node_id = self.tree.merge(n1, n2, w)
    self.remove_node(n2)
    self.rename_node(n1, node_id)
    self.rig[node_id] = self.rig[n1] + self.rig[n2]
    self.rig[n1] = 0
    self.rig[n2] = 0
    return node_id
def merge_subgraph(self, subgraph=None, source=None):
    """Merge a (typically) connected set of nodes into one.

    Parameters
    ----------
    subgraph : agglo.Rag, networkx.Graph, or list of int (node id)
        A subgraph to merge.
    source : int (node id), optional
        Merge the subgraph onto this node.

    Returns
    -------
    None
    """
    if type(subgraph) not in [Rag, Graph]:  # input is a node list
        subgraph = self.subgraph(subgraph)
    if len(subgraph) == 0:
        return
    # DFS preorder gives a traversal where each node is adjacent to the
    # already-merged region; merge everything onto the first node.
    ordering = list(dfs_preorder_nodes(subgraph, source))
    anchor = ordering[0]
    for node in ordering[1:]:
        self.merge_nodes(anchor, node)
def split_node(self, u, n=2, **kwargs):
    """Use normalized cuts [1] to split a node/segment.

    Parameters
    ----------
    u : int (node id)
        Which node to split.
    n : int, optional
        How many segments to split it into.

    Returns
    -------
    None

    References
    ----------
    .. [1] Shi, J., and Malik, J. (2000). Normalized cuts and image
           segmentation. Pattern Analysis and Machine Intelligence.
    """
    extent = self.extent(u)
    # Recover the original watershed fragments covered by this node,
    # rebuild that part of the graph, then cluster the fragments.
    fragment_labels = unique(self.watershed_r[extent])
    self.remove_node(u)
    self.build_graph_from_watershed(idxs=extent)
    self.ncut(num_clusters=n, nodes=fragment_labels, **kwargs)
def merge_edge_properties(self, src, dst):
    """Merge the properties of edge src into edge dst.

    Parameters
    ----------
    src, dst : (int, int)
        Edges being merged.

    Returns
    -------
    None
    """
    u, v = dst
    w, x = src
    if not self.has_edge(u,v):
        # dst doesn't exist yet: adopt src's attributes wholesale.
        self.add_edge(u, v, attr_dict=self[w][x])
    else:
        # Both edges exist: concatenate boundaries and combine the
        # feature caches.
        self[u][v]['boundary'].extend(self[w][x]['boundary'])
        self.feature_manager.update_edge_cache(self, (u, v), (w, x),
                self[u][v]['feature-cache'], self[w][x]['feature-cache'])
    # src's queue item (if any) is now stale.
    try:
        self.merge_queue.invalidate(self[w][x]['qlink'])
    except KeyError:
        pass
    # Recompute dst's priority and requeue it.
    self.update_merge_queue(u, v)
def update_merge_queue(self, u, v):
    """Refresh (or create) the merge queue item for edge (u, v).

    Parameters
    ----------
    u, v : int (node id)
        Edge being updated.

    Returns
    -------
    None
    """
    # Edges to the boundary body are never merge candidates.
    if self.boundary_body in [u, v]:
        return
    props = self[u][v]
    if 'qlink' in props:
        self.merge_queue.invalidate(props['qlink'])
    if self.merge_queue.is_null_queue:
        return
    weight = self.merge_priority_function(self, u, v)
    # Queue item layout: [priority, valid, node1, node2].
    qitem = [weight, True, u, v]
    props['qlink'] = qitem
    props['weight'] = weight
    self.merge_queue.push(qitem)
def get_segmentation(self, threshold=None):
    """Return the unpadded segmentation represented by the graph.

    Remember that the segmentation volume is padded with an
    "artificial" segment that envelops the volume. This function
    simply removes the wrapping and returns a segmented volume.

    Parameters
    ----------
    threshold : float, optional
        Get the segmentation at the given threshold. If no
        threshold is given, return the segmentation at the current
        level of agglomeration.

    Returns
    -------
    seg : array of int
        The segmentation of the volume presently represented by the
        graph.
    """
    if threshold is None:
        # An infinite cutoff on the merge tree reproduces the
        # segmentation implied by the current state of agglomeration.
        threshold = np.inf
    elif threshold > self.max_merge_score:
        # A threshold beyond what has been merged so far: continue
        # agglomerating up to it first.
        self.agglomerate(threshold)
    label_map = self.tree.get_map(threshold)
    seg = label_map[self.watershed]
    if self.pad_thickness > 1:  # volume has zero-boundaries
        seg = morpho.remove_merged_boundaries(seg, self.connectivity)
    return morpho.juicy_center(seg, self.pad_thickness)
def build_volume(self, nbunch=None):
    """Return the segmentation induced by the graph.

    Parameters
    ----------
    nbunch : iterable of int (node id), optional
        A list of nodes for which to build the volume. All nodes
        are used if this is not provided.

    Returns
    -------
    seg : array of int
        The segmentation implied by the graph.

    Notes
    -----
    This function is very similar to ``get_segmentation``, but it
    builds the segmentation from the bottom up, rather than using
    the currently-stored segmentation.
    """
    vol = zeros_like(self.watershed)
    flat = vol.ravel()
    nodes = self.nodes() if nbunch is None else nbunch
    # Paint each node's id over its extent in the (raveled) volume.
    for node in nodes:
        flat[self.extent(node)] = node
    return morpho.juicy_center(vol, self.pad_thickness)
def build_boundary_map(self, ebunch=None):
    """Return a map of the current merge priority.

    Parameters
    ----------
    ebunch : iterable of (int, int), optional
        The list of edges for which to build a map. Use all edges
        if not provided.

    Returns
    -------
    bm : array of float
        The image of the edge weights.
    """
    # Edge weights are maintained by the merge queue; make sure it's live.
    if len(self.merge_queue) == 0:
        self.rebuild_merge_queue()
    bm = zeros(self.watershed.shape, double)
    bm_flat = bm.ravel()
    if ebunch is None:
        ebunch = self.real_edges_iter()
    # Paint boundaries in ascending weight order so higher weights win
    # where boundaries overlap.
    for weight, u, v in sorted((self[u][v]['weight'], u, v)
                               for u, v in ebunch):
        bm_flat[self[u][v]['boundary']] = weight
    if hasattr(self, 'ignored_boundary'):
        bm[self.ignored_boundary] = inf
    return morpho.juicy_center(bm, self.pad_thickness)
def remove_obvious_inclusions(self):
    """Merge any node with exactly one neighbor into that neighbor."""
    for node in self.nodes():
        if self.degree(node) == 1:
            # The lone neighbor absorbs the included node.
            self.merge_nodes(self.neighbors(node)[0], node)
def remove_inclusions(self):
    """Merge any segments fully contained within other segments.

    In 3D EM images, inclusions are not biologically plausible, so
    this function can be used to remove them.

    Parameters
    ----------
    None

    Returns
    -------
    None
    """
    # An inclusion forms its own biconnected component, separate from
    # the main graph (the component containing the boundary body).
    components = list(biconnected_components(self))
    if len(components) <= 1:
        return
    main = [i for i, comp in enumerate(components)
            if self.boundary_body in comp][0]
    del components[main]  # remove the main graph
    components = list(map(list, components))
    # Within each component, the largest node becomes the merge anchor;
    # components are processed smallest-anchor-first.
    for comp in components:
        comp.sort(key=lambda x: self.node[x]['size'], reverse=True)
    components.sort(key=lambda comp: self.node[comp[0]]['size'])
    for comp in components:
        self.merge_subgraph(comp, comp[0])
def orphans(self):
    """List all the nodes that do not touch the volume boundary.

    Parameters
    ----------
    None

    Returns
    -------
    orphans : list of int (node id)
        A list of node ids.

    Notes
    -----
    "Orphans" are not biologically plausible in EM data, so we can
    flag them with this function for further scrutiny.
    """
    return [node for node in self.nodes()
            if not self.at_volume_boundary(node)]
def compute_orphans(self):
    """Find all the segments that do not touch the volume boundary.

    Parameters
    ----------
    None

    Returns
    -------
    orphans : list of int (node id)
        A list of node ids.

    Notes
    -----
    This function differs from ``orphans`` in that it does not use
    the graph, but rather computes orphans directly from the
    segmentation.
    """
    seg = self.get_segmentation()
    return morpho.orphans(seg)
def is_traversed_by_node(self, n):
    """Determine whether a body traverses the volume.

    This is defined as touching the volume boundary at two distinct
    locations.

    Parameters
    ----------
    n : int (node id)
        The node being inspected.

    Returns
    -------
    tr : bool
        Whether the segment "traverses" the volume being segmented.
    """
    if n == self.boundary_body or not self.at_volume_boundary(n):
        return False
    # Paint the contact surface with the boundary body into a binary
    # volume and count its connected components (full connectivity).
    touch = zeros(self.watershed.shape, uint8)
    touch.ravel()[self[n][self.boundary_body]['boundary']] = 1
    _, num_patches = label(touch, ones([3] * touch.ndim))
    return num_patches > 1
def traversing_bodies(self):
    """Return the ids of all bodies that traverse the volume."""
    return [node for node in self.nodes()
            if self.is_traversed_by_node(node)]
def non_traversing_bodies(self):
    """List bodies that are not orphans and do not traverse the volume."""
    def _keep(node):
        # On the volume boundary, but only touching it in one place,
        # and not the artificial boundary body itself.
        return (node != self.boundary_body
                and self.at_volume_boundary(node)
                and not self.is_traversed_by_node(node))
    return [node for node in self.nodes() if _keep(node)]
def raveler_body_annotations(self, traverse=False):
    """Return JSON-compatible dict formatted for Raveler annotations.

    Orphan bodies (and, when `traverse` is True, non-traversing
    bodies) are flagged with status 'not sure' for human review.

    NOTE(review): `compute_non_traversing_bodies` is not defined in
    this section of the file — confirm it exists elsewhere (compare
    with the graph-based `non_traversing_bodies` method).
    """
    orphans = self.compute_orphans()
    non_traversing_bodies = self.compute_non_traversing_bodies() \
                            if traverse else []
    data = \
        [{'status':'not sure', 'comment':'orphan', 'body ID':int(o)}
         for o in orphans] +\
        [{'status':'not sure', 'comment':'does not traverse',
          'body ID':int(n)} for n in non_traversing_bodies]
    metadata = {'description':'body annotations', 'file version':2}
    return {'data':data, 'metadata':metadata}
def at_volume_boundary(self, n):
    """Return True if node `n` touches the volume boundary."""
    # The artificial boundary body is itself considered "at" the boundary.
    if n == self.boundary_body:
        return True
    return self.has_edge(n, self.boundary_body)
def should_merge(self, n1, n2):
    """True if both nodes' RIG rows peak at the same ground truth body."""
    best1 = self.rig[n1].argmax()
    best2 = self.rig[n2].argmax()
    return best1 == best2
def get_pixel_label(self, n1, n2):
    """Return (pixel index, class) for edge (n1, n2).

    The pixel is the boundary voxel with the lowest channel-0
    probability; the class is 2 for a false split (nodes should
    merge) and 1 for a true boundary.
    """
    boundary = array(self[n1][n2]['boundary'])
    probs = self.probabilities_r[boundary, 0]
    min_idx = boundary[probs.argmin()]
    label_class = 2 if self.should_merge(n1, n2) else 1
    return min_idx, label_class
def pixel_labels_array(self, false_splits_only=False):
    """Return a volume with each edge's representative pixel labeled.

    Pixels take value 2 for false splits and 1 for true boundaries
    (see ``get_pixel_label``); everything else is 0. When
    `false_splits_only` is True, only class-2 pixels are painted.
    """
    out = zeros_like(self.watershed_r)
    edge_labels = [self.get_pixel_label(*edge) for edge in self.real_edges()]
    if false_splits_only:
        edge_labels = [lab for lab in edge_labels if lab[1] == 2]
    idxs, classes = list(map(array, zip(*edge_labels)))
    out[idxs] = classes.astype(out.dtype)
    return out.reshape(self.watershed.shape)
def split_vi(self, gt=None):
    """Return the split VI (under-, over-segmentation) vs ground truth.

    NOTE: the bare name `split_vi` inside this method resolves to the
    module-level function imported from `.evaluate`, not to this
    method — the method merely dispatches to it.
    """
    if self.gt is None and gt is None:
        # No ground truth available at all: VI is trivially (0, 0).
        return array([0,0])
    elif self.gt is not None:
        # Stored ground truth: the RIG already holds the contingency info.
        return split_vi(None, None, self.rig)
    else:
        # Ad-hoc ground truth: compute from the current segmentation.
        return split_vi(self.get_segmentation(), gt, None, [0], [0])
def boundary_indices(self, n1, n2):
    """Return the raveled voxel indices of the (n1, n2) boundary."""
    edge_props = self[n1][n2]
    return edge_props['boundary']
def get_edge_coordinates(self, n1, n2, arbitrary=False):
    """Find where in the segmentation the edge (n1, n2) is most visible.

    Thin wrapper around the module-level ``get_edge_coordinates``.
    """
    coords = get_edge_coordinates(self, n1, n2, arbitrary)
    return coords
def write(self, fout, output_format='GraphML'):
    """Serialize the graph to `fout` in the given format.

    NOTE(review): despite the default of 'GraphML', only
    'Plaza JSON' is handled here — any other value (including the
    default) raises ValueError. Confirm whether GraphML support was
    removed intentionally or is implemented elsewhere.
    """
    if output_format == 'Plaza JSON':
        self.write_plaza_json(fout)
    else:
        raise ValueError('Unsupported output format for agglo.Rag: %s'
                         % output_format)
def write_plaza_json(self, fout, synapsejson=None, offsetz=0):
    """Write graph to Steve Plaza's JSON spec.

    Parameters
    ----------
    fout : string
        Path of the output JSON file.
    synapsejson : string, optional
        Path to a synapse JSON file; when given, per-body synapse
        counts are included under 'synapse_bodies'.
    offsetz : int, optional
        Accepted for API compatibility; not used here.
    """
    json_vals = {}
    if synapsejson is not None:
        # BUG FIX: the synapse file was previously opened and never
        # closed; use a context manager.
        with open(synapsejson) as synapse_file:
            synapse_data = json.load(synapse_file)
        # Count T-bar and partner (PSD) appearances per body id.
        body_count = {}
        for item in synapse_data["data"]:
            bodyid = item["T-bar"]["body ID"]
            body_count[bodyid] = body_count.get(bodyid, 0) + 1
            for psd in item["partners"]:
                bodyid = psd["body ID"]
                body_count[bodyid] = body_count.get(bodyid, 0) + 1
        json_vals["synapse_bodies"] = [[body, count]
                                       for body, count in body_count.items()]
    # One record per real edge; coordinates are reversed (z, y, x -> x, y, z).
    edge_list = [
        {'location': list(map(int, self.get_edge_coordinates(i, j)[-1::-1])),
         'node1': int(i), 'node2': int(j),
         'edge_size': len(self[i][j]['boundary']),
         'size1': self.node[i]['size'],
         'size2': self.node[j]['size'],
         'weight': float(self[i][j]['weight'])}
        for i, j in self.real_edges()
    ]
    json_vals['edge_list'] = edge_list
    with open(fout, 'w') as f:
        json.dump(json_vals, f, indent=4)
def ncut(self, num_clusters=10, kmeans_iters=5, sigma=255.0*20, nodes=None,
         **kwargs):
    """Run normalized cuts on the current set of superpixels.

    Keyword arguments:
        num_clusters -- number of clusters to compute
        kmeans_iters -- # iterations to run kmeans when clustering
        sigma -- sigma value when setting up weight matrix
    Return value: None

    NOTE: `sigma` is accepted for API compatibility but is not
    forwarded to ``compute_W``, which uses its own default.
    """
    if nodes is None:
        nodes = self.nodes()
    # Affinity matrix over the chosen nodes.
    W = self.compute_W(self.merge_priority_function, nodes=nodes)
    labels, eigvec, eigval = ncutW(W, num_clusters, kmeans_iters, **kwargs)
    # Nodes assigned to the same cluster get merged into one segment.
    self.cluster_by_labels(labels, nodes)
def cluster_by_labels(self, labels, nodes=None):
    """Merge all superpixels that share a cluster label (1 label per sp).

    Parameters
    ----------
    labels : array-like of int
        Cluster label for each entry of `nodes`.
    nodes : array of int (node id), optional
        The nodes being clustered; defaults to all nodes.

    Raises
    ------
    ValueError
        If `labels` and `nodes` have different lengths.
    """
    if nodes is None:
        nodes = array(self.nodes())
    if len(labels) != len(nodes):
        # BUG FIX: the original passed the format arguments as extra
        # positional args to ValueError (logging style), so the '%d'
        # placeholders were never interpolated; it also reported
        # number_of_nodes() instead of the length actually compared.
        raise ValueError('Number of labels should be %d but is %d.' %
                         (len(nodes), len(labels)))
    for lab in unique(labels):
        cluster_nodes = nodes[nonzero(labels == lab)[0]]
        anchor = cluster_nodes[0]
        for node in cluster_nodes[1:]:
            self.merge_nodes(anchor, node)
def compute_W(self, merge_priority_function, sigma=255.0*20, nodes=None):
    """Compute the affinity matrix used for normalized-cut clustering."""
    if nodes is None:
        nodes = array(self.nodes())
    num = len(nodes)
    index_of = dict(zip(nodes, range(num)))
    W = lil_matrix((num, num))
    for u, v in self.real_edges(nodes):
        # Edges reaching outside the requested node set are skipped.
        if u not in index_of or v not in index_of:
            continue
        i, j = index_of[u], index_of[v]
        w = merge_priority_function(self, u, v)
        # Gaussian affinity: low merge priority -> high similarity.
        W[i, j] = W[j, i] = exp(-w ** 2 / sigma)
    return W
def update_frozen_sets(self, n1, n2):
    """After merging n2 into n1, retarget frozen nodes/edges at n2 to n1."""
    self.frozen_nodes.discard(n1)
    self.frozen_nodes.discard(n2)
    # Iterate over a snapshot: the set is mutated during the loop.
    for u, v in self.frozen_edges.copy():
        if n2 not in (u, v):
            continue
        self.frozen_edges.discard((u, v))
        if u == n2:
            self.frozen_edges.add((n1, v))
        if v == n2:
            self.frozen_edges.add((u, n1))
def get_edge_coordinates(g, n1, n2, arbitrary=False):
    """Find where in the segmentation the edge (n1, n2) is most visible.

    With `arbitrary=True`, any single boundary voxel is returned
    (cheap); otherwise the per-dimension mode of the boundary voxel
    coordinates is used. Coordinates are shifted back by the volume's
    pad thickness.
    """
    boundary = g[n1][n2]['boundary']
    if arbitrary:
        # Peek at the last boundary voxel without altering the list.
        idx = boundary.pop()
        boundary.append(idx)
        coords = unravel_index(idx, g.watershed.shape)
    else:
        unraveled = unravel_index(boundary, g.watershed.shape)
        coords = [bincount(dimcoords).argmax() for dimcoords in unraveled]
    return array(coords) - g.pad_thickness
def is_mito_boundary(g, n1, n2, channel=2, threshold=0.5):
    """True if the (n1, n2) boundary looks mitochondrial.

    The mean probability along the boundary is computed for each
    requested channel; the edge qualifies if any channel's mean
    exceeds `threshold`.

    BUG FIX: `channel` defaults to the int 2, but the original code
    iterated over it directly (`for c in channel`), raising TypeError
    for the default. A scalar channel is now wrapped in a list.
    """
    channels = channel if hasattr(channel, '__iter__') else [channel]
    return max(np.mean(g.probabilities_r[g[n1][n2]['boundary'], c])
               for c in channels) > threshold
def is_mito(g, n, channel=2, threshold=0.5):
    """True if segment `n` looks mitochondrial.

    The mean probability over the segment's extent is computed for
    each requested channel; the segment qualifies if any channel's
    mean exceeds `threshold`.

    BUG FIX: `channel` defaults to the int 2, but the original code
    iterated over it directly (`for c in channel`), raising TypeError
    for the default. A scalar channel is now wrapped in a list.
    """
    channels = channel if hasattr(channel, '__iter__') else [channel]
    return max(np.mean(g.probabilities_r[g.extent(n), c])
               for c in channels) > threshold
def best_possible_segmentation(ws, gt):
    """Build the best segmentation reachable from superpixel map `ws`.

    Each superpixel is assigned to the ground truth segment with which
    it overlaps most, and all superpixels sharing an assignment are
    merged.
    """
    cnt = contingency_table(ws, gt)
    assignment = cnt == cnt.max(axis=1)[:, newaxis]
    # Superpixels tied between two or more gt bodies ("hard" cases)
    # are currently left unassigned/unmerged.
    ambiguous = where(assignment.sum(axis=1) > 1)[0]
    assignment[ambiguous, :] = 0
    rag = Rag(ws)
    for gt_node in range(1, cnt.shape[1]):
        rag.merge_subgraph(where(assignment[:, gt_node])[0])
    return rag.get_segmentation()