Source code for msmbuilder.cluster

# Author: Robert McGibbon <rmcgibbo@gmail.com>
# Contributors:
# Copyright (c) 2014, Stanford University
# All rights reserved.

#-----------------------------------------------------------------------------
# Imports
#-----------------------------------------------------------------------------

from __future__ import absolute_import, print_function, division
import numpy as np
from sklearn import cluster
from sklearn import mixture

import mdtraj as md
from ..base import BaseEstimator
from ..utils import check_iter_of_sequences

from .base import MultiSequenceClusterMixin
from .kcenters import KCenters
from .ndgrid import NDGrid
from .agglomerative import LandmarkAgglomerative
from .regularspatial import RegularSpatial
from .kmedoids import KMedoids
from .minibatchkmedoids import MiniBatchKMedoids

__all__ = ['KMeans', 'MiniBatchKMeans', 'AffinityPropagation', 'MeanShift',
           'GMM', 'SpectralClustering', 'Ward', 'KCenters', 'NDGrid',
           'LandmarkAgglomerative', 'RegularSpatial', 'KMedoids',
           'MiniBatchKMedoids', 'MultiSequenceClusterMixin']


def _replace_labels(doc):
    """Really hacky find-and-replace method that modifies one of the sklearn
    docstrings to change the semantics of labels_ for the subclasses"""
    lines = doc.splitlines()
    labelstart, labelend = None, None
    foundattributes = False
    for i, line in enumerate(lines):
        if 'Attributes' in line:
            foundattributes = True
        if 'labels' in line and not labelstart and foundattributes:
            labelstart = len('\n'.join(lines[:i]))
        if labelstart and line.strip() == '' and not labelend:
            labelend = len('\n'.join(lines[:i + 1]))

    replace = '''\n    `labels_` : list of arrays, each of shape [sequence_length, ]
        The label of each point is an integer in [0, n_clusters).
    '''

    return doc[:labelstart] + replace + doc[labelend:]

#-----------------------------------------------------------------------------
# New "multisequence" versions of all of the clustering algorithims in sklearn
#-----------------------------------------------------------------------------


[docs]class KMeans(MultiSequenceClusterMixin, cluster.KMeans, BaseEstimator): __doc__ = _replace_labels(cluster.KMeans.__doc__)
[docs]class MiniBatchKMeans(MultiSequenceClusterMixin, cluster.MiniBatchKMeans, BaseEstimator): __doc__ = _replace_labels(cluster.MiniBatchKMeans.__doc__)
[docs]class AffinityPropagation(MultiSequenceClusterMixin, cluster.AffinityPropagation, BaseEstimator): __doc__ = _replace_labels(cluster.AffinityPropagation.__doc__)
[docs]class MeanShift(MultiSequenceClusterMixin, cluster.MeanShift, BaseEstimator): __doc__ = _replace_labels(cluster.MeanShift.__doc__)
[docs]class SpectralClustering(MultiSequenceClusterMixin, cluster.SpectralClustering, BaseEstimator): __doc__ = _replace_labels(cluster.SpectralClustering.__doc__)
[docs]class Ward(MultiSequenceClusterMixin, cluster.Ward, BaseEstimator): __doc__ = _replace_labels(cluster.Ward.__doc__)
[docs]class GMM(MultiSequenceClusterMixin, mixture.GMM, BaseEstimator): __doc__ = _replace_labels(mixture.GMM.__doc__)
Versions