This example demonstrates how to select the number of states by combining randomized hyperparameter search with cross-validation, using scikit-learn's `RandomizedSearchCV`.
from __future__ import print_function
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import scipy.stats.distributions
from msmbuilder.example_datasets import load_quadwell
from msmbuilder.example_datasets import quadwell_eigs
from msmbuilder.cluster import NDGrid
from msmbuilder.msm import MarkovStateModel
from sklearn.pipeline import Pipeline
try:
    # scikit-learn >= 0.18; sklearn.grid_search was removed in 0.20.
    from sklearn.model_selection import RandomizedSearchCV
except ImportError:  # older scikit-learn
    from sklearn.grid_search import RandomizedSearchCV
# Two-stage model: discretize the coordinate onto a uniform grid spanning
# [-1.2, 1.2], then fit a Markov state model to the discretized trajectories.
pipeline = Pipeline([
    ('grid', NDGrid(min=-1.2, max=1.2)),
    ('msm', MarkovStateModel(n_timescales=3, reversible_type='transpose', verbose=False)),
])

# Reference score: the sum of the first four entries of quadwell_eigs(250)[0].
# NOTE(review): presumably these are the exact eigenvalues of the quad-well
# dynamics at a lag of 250 (stationary eigenvalue + the 3 modeled timescales);
# confirm against msmbuilder.example_datasets.quadwell_eigs.
true_gmrq4 = quadwell_eigs(250)[0][:4].sum()

# Sample 10 grid sizes uniformly from the integers in [10, 500) and score each
# one with 3-fold cross-validation. No refit is needed: only the scores are used.
search = RandomizedSearchCV(pipeline, n_iter=10, cv=3, refit=False, param_distributions={
    'grid__n_bins_per_feature': scipy.stats.distributions.randint(10, 500),
})

# Take only the first 500 data points from each trajectory. This creates a
# smaller dataset that is easier to overfit by using too many states.
dataset = [t[0:500] for t in load_quadwell().trajectories]
search.fit(dataset)

if hasattr(search, 'cv_results_'):
    # sklearn.model_selection API: column-oriented results dict.
    results = search.cv_results_
    n_bins = np.asarray(results['param_grid__n_bins_per_feature'], dtype=float)
    mean_scores = np.asarray(results['mean_test_score'], dtype=float)
    std_scores = np.asarray(results['std_test_score'], dtype=float)
else:
    # Legacy sklearn.grid_search API: one record per sampled parameter setting.
    records = search.grid_scores_
    n_bins = np.array([e.parameters['grid__n_bins_per_feature'] for e in records], dtype=float)
    mean_scores = np.array([np.mean(e.cv_validation_scores) for e in records])
    std_scores = np.array([np.std(e.cv_validation_scores) for e in records])

# Mean CV score per grid size, with the fold-to-fold spread as error bars.
plt.errorbar(n_bins, mean_scores, yerr=std_scores, fmt='o', label='CV score (mean +/- std)')
# axhline spans the whole axes width. The original plotted against plt.xlim()
# twice; autoscaling widens the limits after the first line, so the two
# reference lines previously came out with different lengths.
plt.axhline(search.best_score_, color='k', linestyle='-.', label='best')
plt.axhline(true_gmrq4, color='k', linestyle='-', label='true')
plt.xlabel('grid__n_bins_per_feature')
plt.ylabel('cross-validated score')
print('Best params:', search.best_params_)
plt.legend(loc=4)
plt.show()
(quadwell-n-states.ipynb; quadwell-n-states_evaluated.ipynb; quadwell-n-states.py)