Model selection with Randomized CVΒΆ

quadwell-n-states_evaluated

This example demonstrates the use of randomized search to select the number of states via cross validation, using sklearn's RandomizedSearchCV.

In [1]:
from __future__ import print_function
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import scipy.stats.distributions
from msmbuilder.example_datasets import load_quadwell
from msmbuilder.example_datasets import quadwell_eigs
from msmbuilder.cluster import NDGrid
from msmbuilder.msm import MarkovStateModel
from sklearn.pipeline import Pipeline
from sklearn.grid_search import RandomizedSearchCV
loading from /home/rmcgibbo/miniconda/lib/plugins
In [2]:
pipeline = Pipeline([
    ('grid', NDGrid(min=-1.2, max=1.2)),
    ('msm', MarkovStateModel(n_timescales=3, reversible_type='transpose', verbose=False)),
])

true_gmrq4 = quadwell_eigs(250)[0][:4].sum()
In [3]:
search = RandomizedSearchCV(pipeline, n_iter=10, cv=3, refit=False, param_distributions={
    'grid__n_bins_per_feature': scipy.stats.distributions.randint(10, 500),
})

# take only the first 500 data points from each trajectory.
# this will create a smaller dataset that's easier to overfit,
# by using too many states
dataset = [t[0:500] for t in load_quadwell().trajectories]
search.fit(dataset)
Out[3]:
RandomizedSearchCV(cv=3,
          estimator=Pipeline(steps=[('grid', NDGrid(max=1.2, min=-1.2, n_bins_per_feature=2)), ('msm', MarkovStateModel(ergodic_cutoff=1.0, lag_time=1, n_timescales=3,
         prior_counts=0, reversible_type='transpose', sliding_window=True,
         verbose=False))]),
          fit_params={}, iid=True, n_iter=10, n_jobs=1,
          param_distributions={'grid__n_bins_per_feature': <scipy.stats._distn_infrastructure.rv_frozen object at 0x2b969060f9d0>},
          pre_dispatch='2*n_jobs', random_state=None, refit=False,
          scoring=None, verbose=0)
In [4]:
scores = np.array([[np.mean(e.cv_validation_scores),
                    np.std(e.cv_validation_scores),
                    e.parameters['grid__n_bins_per_feature']]
                   for e in search.grid_scores_])

plt.scatter(scores[:,2], scores[:, 0])
plt.plot(plt.xlim(), [search.best_score_]*2, 'k-.', label='best')

plt.plot(plt.xlim(), [true_gmrq4]*2, 'k-', label='true')
print('Best params:', search.best_params_)
plt.legend(loc=4)
plt.show()
Best params: {'grid__n_bins_per_feature': 219}

(quadwell-n-states.ipynb; quadwell-n-states_evaluated.ipynb; quadwell-n-states.py)

Versions