Source code for gala.classify

#!/usr/bin/env python

# system modules
import os
import logging
import random
import pickle as pck

# libraries
import h5py
import numpy as np
np.seterr(divide='ignore')

from sklearn.ensemble import RandomForestClassifier
try:
    from sklearn.externals import joblib
except ImportError:
    # sklearn.externals.joblib was removed in scikit-learn 0.23;
    # fall back to the standalone joblib package.
    import joblib

try:
    from vigra.learning import RandomForest as BaseVigraRandomForest
    from vigra.__version__ import version as vigra_version
    vigra_version = tuple(map(int, vigra_version.split('.')))
except ImportError:
    vigra_available = False
else:
    vigra_available = True


def default_classifier_extension(cl, use_joblib=True):
    """Return the default classifier file extension for the given classifier.

    Parameters
    ----------
    cl : sklearn estimator or VigraRandomForest object
        A classifier to be saved.
    use_joblib : bool, optional
        Whether or not joblib will be used to save the classifier.

    Returns
    -------
    ext : string
        File extension

    Examples
    --------
    >>> cl = RandomForestClassifier()
    >>> default_classifier_extension(cl)
    '.classifier.joblib'
    >>> default_classifier_extension(cl, False)
    '.classifier'
    """
    if isinstance(cl, VigraRandomForest):
        return ".classifier.h5"
    elif use_joblib:
        return ".classifier.joblib"
    else:
        return ".classifier"

def load_classifier(fn):
    """Load a classifier previously saved to disk, given a filename.

    Supported classifier types are:
    - scikit-learn classifiers saved using either pickle or joblib
      persistence
    - vigra random forest classifiers saved in HDF5 format

    Parameters
    ----------
    fn : string
        Filename in which the classifier is stored.

    Returns
    -------
    cl : classifier object
        cl is one of the supported classifier types; these support at
        least the standard scikit-learn interface of `fit()` and
        `predict_proba()`
    """
    if not os.path.exists(fn):
        raise IOError("No such file or directory: '%s'" % fn)
    try:
        # pickle files must be opened in binary mode
        with open(fn, 'rb') as f:
            cl = pck.load(f)
        return cl
    except (pck.UnpicklingError, UnicodeDecodeError):
        pass
    try:
        cl = joblib.load(fn)
        return cl
    except KeyError:
        pass
    if vigra_available:
        cl = VigraRandomForest()
        try:
            cl.load_from_disk(fn)
            return cl
        except (IOError, RuntimeError) as e:
            logging.error(e)
    raise IOError("File '%s' does not appear to be a valid classifier file"
                  % fn)

def save_classifier(cl, fn, use_joblib=True, **kwargs):
    """Save a classifier to disk.

    Parameters
    ----------
    cl : classifier object
        Pickleable object or a classify.VigraRandomForest object.
    fn : string
        Writeable path/filename.
    use_joblib : bool, optional
        Whether to prefer joblib persistence to pickle.
    kwargs : keyword arguments
        Keyword arguments to be passed on to either `pck.dump` or
        `joblib.dump`.

    Returns
    -------
    None

    Notes
    -----
    For joblib persistence, `compress=3` is the default.
    """
    if isinstance(cl, VigraRandomForest):
        cl.save_to_disk(fn)
    elif use_joblib:
        if 'compress' not in kwargs:
            kwargs['compress'] = 3
        joblib.dump(cl, fn, **kwargs)
    else:
        with open(fn, 'wb') as f:
            pck.dump(cl, f, protocol=kwargs.get('protocol', 2))

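# Illustrative sketch, not part of the original gala source: round-trip a
# classifier through save_classifier/load_classifier. The helper name, the
# filename, and the toy data are invented for this example.
def _example_classifier_roundtrip(fn='toy.classifier.joblib'):
    rng = np.random.RandomState(0)
    X = rng.rand(20, 4)
    y = np.repeat([0, 1], 10)
    cl = DefaultRandomForest(n_estimators=10).fit(X, y)
    save_classifier(cl, fn)  # joblib persistence, compress=3 by default
    cl2 = load_classifier(fn)
    # the reloaded classifier makes identical predictions
    assert np.allclose(cl.predict_proba(X), cl2.predict_proba(X))
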
def get_classifier(name='random forest', *args, **kwargs):
    """Return a classifier given a name.

    Parameters
    ----------
    name : string
        The name of the classifier, e.g. 'random forest' or 'naive bayes'.
    *args, **kwargs :
        Additional arguments to pass to the constructor of the classifier.

    Returns
    -------
    cl : classifier
        A classifier object implementing the scikit-learn interface.

    Raises
    ------
    NotImplementedError
        If the classifier name is not recognized.

    Examples
    --------
    >>> cl = get_classifier('random forest', n_estimators=47)
    >>> isinstance(cl, RandomForestClassifier)
    True
    >>> cl.n_estimators
    47
    >>> from numpy.testing import assert_raises
    >>> assert_raises(NotImplementedError, get_classifier, 'perfect class')
    """
    name = name.lower()
    is_random_forest = name.find('random') > -1 and name.find('forest') > -1
    is_naive_bayes = name.find('naive') > -1
    is_logistic = name.startswith('logis')
    if vigra_available and is_random_forest:
        if 'random_state' in kwargs:
            del kwargs['random_state']
        return VigraRandomForest(*args, **kwargs)
    elif is_random_forest:
        return DefaultRandomForest(*args, **kwargs)
    elif is_naive_bayes:
        from sklearn.naive_bayes import GaussianNB
        if 'random_state' in kwargs:
            del kwargs['random_state']
        return GaussianNB(*args, **kwargs)
    elif is_logistic:
        from sklearn.linear_model import LogisticRegression
        return LogisticRegression(*args, **kwargs)
    else:
        raise NotImplementedError('Classifier "%s" is either not installed '
                                  'or not implemented in Gala.' % name)

class DefaultRandomForest(RandomForestClassifier):
    def __init__(self, n_estimators=100, criterion='entropy', max_depth=20,
                 bootstrap=False, random_state=None, n_jobs=-1):
        super(DefaultRandomForest, self).__init__(
            n_estimators=n_estimators, criterion=criterion,
            max_depth=max_depth, bootstrap=bootstrap,
            random_state=random_state, n_jobs=n_jobs)


class VigraRandomForest(object):
    def __init__(self, ntrees=255, use_feature_importance=False,
                 sample_classes_individually=False):
        self.rf = BaseVigraRandomForest(
            treeCount=ntrees,
            sample_classes_individually=sample_classes_individually)
        self.use_feature_importance = use_feature_importance
        self.sample_classes_individually = sample_classes_individually

    def fit(self, features, labels):
        features = self.check_features_vector(features)
        labels = self.check_labels_vector(labels)
        if self.use_feature_importance:
            self.oob, self.feature_importance = \
                self.rf.learnRFWithFeatureSelection(features, labels)
        else:
            self.oob = self.rf.learnRF(features, labels)
        return self

    def predict_proba(self, features):
        features = self.check_features_vector(features)
        return self.rf.predictProbabilities(features)

    def predict(self, features):
        features = self.check_features_vector(features)
        return self.rf.predictLabels(features)

    def check_features_vector(self, features):
        if features.dtype != np.float32:
            features = features.astype(np.float32)
        if features.ndim == 1:
            features = features[np.newaxis, :]
        return features

    def check_labels_vector(self, labels):
        if labels.dtype != np.uint32:
            if len(np.unique(labels[labels < 0])) == 1 \
                    and not (labels == 0).any():
                labels[labels < 0] = 0
            else:
                # shift labels so the minimum is 0 before the unsigned cast
                labels = labels - labels.min()
            labels = labels.astype(np.uint32)
        labels = labels.reshape((labels.size, 1))
        return labels

    def save_to_disk(self, fn, rfgroupname='rf'):
        self.rf.writeHDF5(fn, rfgroupname)
        attr_list = ['oob', 'feature_importance',
                     'use_feature_importance', 'feature_description']
        with h5py.File(fn, 'a') as f:
            for attr in attr_list:
                if hasattr(self, attr):
                    f[rfgroupname].attrs[attr] = getattr(self, attr)

    def load_from_disk(self, fn, rfgroupname='rf'):
        self.rf = BaseVigraRandomForest(str(fn), rfgroupname)
        with h5py.File(fn, 'r') as f:
            for attr in f[rfgroupname].attrs:
                logging.debug("f[%s] = %s" %
                              (attr, f[rfgroupname].attrs[attr]))
                setattr(self, attr, f[rfgroupname].attrs[attr])


def read_rf_info(fn):
    with h5py.File(fn, 'r') as f:
        return list(map(np.array, [f['oob'], f['feature_importance']]))

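# Illustrative sketch, not part of the original gala source: every
# classifier returned by get_classifier follows the scikit-learn
# fit/predict_proba interface, whichever backend (sklearn or vigra) is
# installed. The helper name and toy data are invented for this example.
def _example_get_classifier():
    rng = np.random.RandomState(0)
    X = rng.rand(30, 5)
    y = rng.randint(0, 2, size=30)
    for name in ('random forest', 'naive bayes', 'logistic'):
        cl = get_classifier(name).fit(X, y)
        probs = cl.predict_proba(X)  # one column of probabilities per class
        assert probs.shape == (30, 2)
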
def concatenate_data_elements(alldata):
    """Return one big learning set from a list of learning sets.

    A learning set is a list/tuple of length 4 containing features,
    labels, weights, and node merge history.
    """
    return list(map(np.concatenate, zip(*alldata)))

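# Illustrative sketch, not part of the original gala source: two toy
# learning sets are concatenated element-wise, so features stack with
# features, labels with labels, and so on.
def _example_concatenate():
    set1 = (np.zeros((2, 3)), np.array([1, 1]),
            np.ones(2), np.array([[0, 1], [1, 2]]))
    set2 = (np.ones((1, 3)), np.array([-1]),
            np.ones(1), np.array([[2, 3]]))
    f, l, w, h = concatenate_data_elements([set1, set2])
    assert f.shape == (3, 3) and l.shape == (3,) and h.shape == (3, 2)
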
def unique_learning_data_elements(alldata):
    """Return the learning set restricted to unique feature vectors."""
    if type(alldata[0]) not in (list, tuple):
        alldata = [alldata]
    f, l, w, h = concatenate_data_elements(alldata)
    # view each feature row as a single byte string so np.unique can
    # detect repeated rows in one pass
    af = f.view('|S%d' % (f.itemsize * (len(f[0]))))
    _, uids, iids = np.unique(af, return_index=True, return_inverse=True)
    bcs = np.bincount(iids)
    logging.debug(
        'repeat feature vec min %d, mean %.2f, median %.2f, max %d.' %
        (bcs.min(), np.mean(bcs), np.median(bcs), bcs.max())
    )

    def get_uniques(ar):
        return ar[uids]
    return list(map(get_uniques, [f, l, w, h]))

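# Illustrative sketch, not part of the original gala source: duplicate
# feature vectors (and their matching labels, weights, and history rows)
# are dropped, keeping one copy of each distinct row.
def _example_unique():
    f = np.array([[0., 1.], [0., 1.], [2., 3.]])
    l = np.array([1, 1, -1])
    w = np.ones(3)
    h = np.array([[0, 1], [0, 1], [1, 2]])
    uf, ul, uw, uh = unique_learning_data_elements([(f, l, w, h)])
    assert len(uf) == 2  # the repeated [0., 1.] row appears only once
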
def sample_training_data(features, labels, num_samples=None):
    """Get a random sample from a classification training dataset.

    Parameters
    ----------
    features : np.ndarray [M x N]
        The M (number of samples) by N (number of features) feature
        matrix.
    labels : np.ndarray [M] or [M x 1]
        The training label for each feature vector.
    num_samples : int, optional
        The size of the training sample to draw. Return full dataset
        if `None` or if num_samples >= M.

    Returns
    -------
    feat : np.ndarray [num_samples x N]
        The sampled feature vectors.
    lab : np.ndarray [num_samples] or [num_samples x 1]
        The sampled training labels
    """
    m = len(features)
    if num_samples is None or num_samples >= m:
        return features, labels
    idxs = random.sample(list(range(m)), num_samples)
    return features[idxs], labels[idxs]

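# Illustrative sketch, not part of the original gala source: drawing a
# fixed-size random subsample keeps feature rows and labels aligned.
def _example_sample():
    X = np.arange(100).reshape(50, 2)  # row i is [2*i, 2*i + 1]
    y = np.arange(50)
    Xs, ys = sample_training_data(X, y, num_samples=10)
    assert Xs.shape == (10, 2) and ys.shape == (10,)
    assert (Xs[:, 0] == 2 * ys).all()  # rows still match their labels
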
def save_training_data_to_disk(data, fn, names=None, info='N/A'):
    if names is None:
        names = ['features', 'labels', 'weights', 'history']
    fout = h5py.File(fn, 'w')
    for data_elem, name in zip(data, names):
        fout[name] = data_elem
    fout.attrs['info'] = info
    fout.close()


def load_training_data_from_disk(fn, names=None, info='N/A'):
    if names is None:
        names = ['features', 'labels', 'weights', 'history']
    data = []
    with h5py.File(fn, 'r') as fin:
        for name in names:
            data.append(np.array(fin[name]))
    return data

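# Illustrative sketch, not part of the original gala source: a learning
# set written with save_training_data_to_disk reads back unchanged. The
# helper name and the filename are invented for this example.
def _example_training_data_roundtrip(fn='toy_training.h5'):
    data = [np.random.rand(4, 3), np.array([1, -1, 1, -1]),
            np.ones(4), np.zeros((4, 2))]
    save_training_data_to_disk(data, fn, info='toy example')
    data2 = load_training_data_from_disk(fn)
    for orig, loaded in zip(data, data2):
        assert np.allclose(orig, loaded)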