HMM and MSM Timescales for Ala2

This example builds HMMs and MSMs on the alanine dipeptide dataset using varying lag times and numbers of states, and compares the resulting relaxation timescales.
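
For reference, the relaxation (implied) timescales compared here come from the eigenvalues of each model's transition matrix. This is the standard relation, stated here as background rather than taken from the msmbuilder source. For a model estimated at lag time $\tau$, with eigenvalues $\lambda_0 = 1 > \lambda_1 \ge \lambda_2 \ge \dots$,

$$t_i = -\frac{\tau}{\ln \lambda_i(\tau)}, \qquad i = 1, 2, \dots$$

If the dynamics are Markovian, $\lambda_i(k\tau) = \lambda_i(\tau)^k$, so $t_i = -k\tau / (k \ln \lambda_i(\tau))$ is independent of the lag time. Timescales that level off as $\tau$ grows are therefore the usual sign that the lag time is long enough.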

In [1]:
from __future__ import print_function
import os
from matplotlib.pyplot import *
from msmbuilder.featurizer import SuperposeFeaturizer
from msmbuilder.example_datasets import AlanineDipeptide
from msmbuilder.hmm import GaussianFusionHMM
from msmbuilder.cluster import KCenters
from msmbuilder.msm import MarkovStateModel

First: load and "featurize"

Featurization refers to the process of converting the conformational snapshots from your MD trajectories into vectors in some space $\mathbb{R}^N$ that can be manipulated and modeled by subsequent analyses. The Gaussian HMM, for instance, uses Gaussian emission distributions, so it models the trajectory as a time-dependent mixture of multivariate Gaussians.
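
To make "time-dependent mixture" concrete, here is a minimal NumPy sketch of the likelihood behind a Gaussian HMM. Each hidden state emits feature vectors from its own Gaussian, and the forward recursion carries a probability distribution over hidden states from one frame to the next. It assumes diagonal covariances and already-known start probabilities and transition matrix. The function names are hypothetical, and this is not msmbuilder's GaussianFusionHMM implementation, which also fits these parameters.

```python
import numpy as np

# Hypothetical helpers: a sketch of a Gaussian-emission HMM likelihood,
# assuming diagonal covariances; not msmbuilder's implementation.


def diag_gaussian_logpdf(X, means, variances):
    """Log-density of every frame under every state's diagonal Gaussian.

    X: (n_frames, n_features); means, variances: (n_states, n_features).
    Returns an array of shape (n_frames, n_states).
    """
    diff = X[:, None, :] - means[None, :, :]
    return -0.5 * np.sum(
        np.log(2.0 * np.pi * variances)[None, :, :] + diff**2 / variances[None, :, :],
        axis=2,
    )


def forward_loglikelihood(X, startprob, transmat, means, variances):
    """log P(X) for a Gaussian-emission HMM via the scaled forward algorithm."""
    log_b = diag_gaussian_logpdf(X, means, variances)
    # Shift each frame's log-densities by their max before exponentiating,
    # so the emission probabilities do not underflow.
    shift = log_b.max(axis=1)
    b = np.exp(log_b - shift[:, None])

    alpha = startprob * b[0]  # proportional to P(state at t=0, first frame)
    log_like = 0.0
    for t in range(len(X)):
        if t > 0:
            # Propagate the state distribution one step, then weight by emissions.
            alpha = (alpha @ transmat) * b[t]
        norm = alpha.sum()
        log_like += np.log(norm) + shift[t]
        alpha = alpha / norm  # now P(state at t | frames 0..t)
    return log_like


# Usage on made-up parameters (two states in a 2-D feature space).
rng = np.random.default_rng(0)
means = np.array([[0.0, 0.0], [3.0, 3.0]])
variances = np.ones_like(means)
startprob = np.array([0.5, 0.5])
transmat = np.array([[0.95, 0.05], [0.05, 0.95]])
X = rng.normal(size=(50, 2))
print(forward_loglikelihood(X, startprob, transmat, means, variances))
```

The per-frame normalization keeps the recursion numerically stable for long trajectories, and the normalized alpha is exactly the "current mixture weights" over hidden states at each frame.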

In general, featurization is something of an art. For this example, we use MSMBuilder's SuperposeFeaturizer, which superposes each snapshot onto a reference frame (trajectories[0][0] in this example) and then uses the distance from each selected atom to its position in the reference conformation as a feature. The atoms selected below are all carbon, oxygen, and nitrogen atoms, i.e., every non-hydrogen atom in alanine dipeptide.
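
The superpose-then-measure idea can be sketched in plain NumPy. This is a minimal illustration under our own assumptions, not msmbuilder's code. It uses the Kabsch algorithm for the superposition and, as we understand msmbuilder's default, superposes on the same atoms whose distances it reports. Coordinates are plain (n_atoms, 3) arrays rather than mdtraj trajectories, and the function names are hypothetical.

```python
import numpy as np

# Hypothetical helpers: a sketch of a superposition-based featurizer,
# not msmbuilder's SuperposeFeaturizer.


def kabsch_superpose(mobile, reference):
    """Rigidly rotate and translate `mobile` (n_atoms, 3) onto `reference`.

    Kabsch algorithm: the optimal rotation comes from the SVD of the 3x3
    covariance matrix of the centered coordinates.
    """
    ref_center = reference.mean(axis=0)
    P = mobile - mobile.mean(axis=0)
    Q = reference - ref_center
    U, _, Vt = np.linalg.svd(P.T @ Q)
    # Flip the last axis if needed so that R is a proper rotation (det = +1).
    d = np.sign(np.linalg.det(Vt.T @ U.T))
    R = Vt.T @ np.diag([1.0, 1.0, d]) @ U.T
    return P @ R.T + ref_center


def superpose_features(frames, reference, atom_indices):
    """Distance of each selected atom from its reference position after superposition.

    frames: (n_frames, n_atoms, 3); reference: (n_atoms, 3).
    Returns an array of shape (n_frames, len(atom_indices)).
    """
    ref = reference[atom_indices]
    features = np.empty((len(frames), len(atom_indices)))
    for k, frame in enumerate(frames):
        aligned = kabsch_superpose(frame[atom_indices], ref)
        features[k] = np.linalg.norm(aligned - ref, axis=1)
    return features


# Usage on made-up coordinates.
rng = np.random.default_rng(0)
reference = rng.normal(size=(10, 3))
theta = 0.7
Rz = np.array([[np.cos(theta), -np.sin(theta), 0.0],
               [np.sin(theta), np.cos(theta), 0.0],
               [0.0, 0.0, 1.0]])
rigid = reference @ Rz.T + np.array([1.0, -2.0, 0.5])       # rotated and shifted copy
noisy = reference + 0.1 * rng.normal(size=reference.shape)  # slightly deformed copy
mirrored = reference * np.array([1.0, 1.0, -1.0])           # reflection, not a rotation
feats = superpose_features(np.stack([rigid, noisy, mirrored]), reference, np.arange(10))
# feats[0] is ~0 everywhere: a rigid motion leaves no residual distance.
```

Because superposition removes overall translation and rotation, the features respond only to internal conformational change, which is what the downstream models are meant to capture.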

In [2]:
print(AlanineDipeptide.description())

dataset = AlanineDipeptide().get()
trajectories = dataset.trajectories
topology = trajectories[0].topology

indices = [atom.index for atom in topology.atoms if atom.element.symbol in ['C', 'O', 'N']]
featurizer = SuperposeFeaturizer(indices, trajectories[0][0])
sequences = featurizer.transform(trajectories)
The dataset consists of ten 10ns trajectories of of alanine dipeptide,
simulated using OpenMM 6.0.1 (CUDA platform, NVIDIA GTX660) with the
AMBER99SB-ILDN force field at 300K (langevin dynamics, friction coefficient
of 91/ps, timestep of 2fs) with GBSA implicit solvent. The coordinates are
saved every 1ps. Each trajectory contains 9,999 snapshots.

The dataset, including the script used to generate the dataset
is available on figshare at

http://dx.doi.org/10.6084/m9.figshare.1026131


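Now sequences holds the featurized data: a list with one array of shape (n_frames, n_features) per trajectory, where n_features is the number of selected atoms.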

In [3]:
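lag_times = [1, 10, 20, 30, 40]
hmm_ts0 = {}
hmm_ts1 = {}
n_states = [3, 5]

for n in n_states:
    hmm_ts0[n] = []
    hmm_ts1[n] = []
    for lag_time in lag_times:
        # Split every trajectory into lag_time interleaved subsequences, so that
        # consecutive frames seen by the HMM are lag_time frames apart.
        strided_data = [s[i::lag_time] for s in sequences for i in range(lag_time)]
        hmm = GaussianFusionHMM(n_states=n, n_features=sequences[0].shape[1], n_init=1).fit(strided_data)
        # hmm.timescales_ is in units of the strided time step; multiply by
        # lag_time to express it in frames (1 frame = 1 ps in this dataset).
        timescales = hmm.timescales_ * lag_time
        hmm_ts0[n].append(timescales[0])
        hmm_ts1[n].append(timescales[1])
        print('n_states=%d\tlag_time=%d\ttimescales=%s' % (n, lag_time, timescales))
    print()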
n_states=3	lag_time=1	timescales=[ 125.90518188    3.76165414]
n_states=3	lag_time=10	timescales=[ 208.22389221    5.93706512]
n_states=3	lag_time=20	timescales=[ 221.70808411    6.2912302 ]
n_states=3	lag_time=30	timescales=[ 227.06163025    7.76600266]
n_states=3	lag_time=40	timescales=[ 230.3684845    8.1902256]

n_states=5	lag_time=1	timescales=[ 126.70064545    3.75312209    2.00503063    0.78047842]
n_states=5	lag_time=10	timescales=[ 210.25354004    5.94292164    2.72830915    1.99519026]
n_states=5	lag_time=20	timescales=[ 221.25015259    6.4437809     4.15999699    3.39460754]
n_states=5	lag_time=30	timescales=[ 227.30914307    7.20348692    6.25866508]
n_states=5	lag_time=40	timescales=[ 230.55610657    9.14086342    8.62874317    6.26740265]
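
The striding trick in the HMM cell above can be checked on a toy trajectory. The sketch below (plain NumPy, hypothetical helper names) reproduces the same list comprehension and the conversion from model steps to frames. Because consecutive frames of each strided piece are lag_time frames apart, a timescale of k model steps corresponds to k * lag_time frames, and each frame is 1 ps for this dataset.

```python
import numpy as np

# Hypothetical helpers mirroring the HMM cell's striding and unit conversion.


def strided_subsequences(sequences, lag_time):
    """Split each trajectory into `lag_time` interleaved subsequences.

    Consecutive rows in every piece are `lag_time` frames apart, so a model
    fitted to the pieces takes one step per `lag_time` frames.
    """
    return [s[i::lag_time] for s in sequences for i in range(lag_time)]


def steps_to_frames(timescales_in_steps, lag_time):
    """Convert timescales from model steps to frames (here 1 frame = 1 ps)."""
    return np.asarray(timescales_in_steps, dtype=float) * lag_time


# Toy trajectory: frame t has the single feature value t.
toy = np.arange(10).reshape(-1, 1)
pieces = strided_subsequences([toy], lag_time=3)
print([p.ravel().tolist() for p in pieces])  # [[0, 3, 6, 9], [1, 4, 7], [2, 5, 8]]
print(steps_to_frames([2.0, 0.5], lag_time=10))
```

Starting one piece at each offset 0, ..., lag_time - 1 keeps every frame in the training data, instead of discarding all but one out of every lag_time frames.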


In [4]:
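figure(figsize=(14,3))

for i, n in enumerate(n_states):
    subplot(1,len(n_states),1+i)
    # Slowest and second-slowest HMM timescales versus lag time.
    plot(lag_times, hmm_ts0[n], label='slowest')
    plot(lag_times, hmm_ts1[n], label='second slowest')
    if i == 0:
        ylabel('Relaxation Timescale (ps)')
        legend()
    xlabel('Lag Time (ps)')
    title('%d states' % n)

show()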
In [5]:
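msmts0, msmts1 = {}, {}
lag_times = [1, 10, 20, 30, 40]
n_states = [4, 8, 16, 32, 64]

for n in n_states:
    msmts0[n] = []
    msmts1[n] = []
    for lag_time in lag_times:
        # Note: KCenters is re-run for every lag time, so the state
        # decomposition (and hence the timescales) can vary between lag times.
        assignments = KCenters(n_clusters=n).fit_predict(sequences)
        msm = MarkovStateModel(lag_time=lag_time, verbose=False).fit(assignments)
        # Unlike the HMM above, msm.timescales_ is already in frames (ps),
        # so no extra multiplication by lag_time is needed.
        timescales = msm.timescales_
        msmts0[n].append(timescales[0])
        msmts1[n].append(timescales[1])
        print('n_states=%d\tlag_time=%d\ttimescales=%s' % (n, lag_time, timescales[0:2]))
    print()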
n_states=4	lag_time=1	timescales=[ 77.91153755   2.66159204]
n_states=4	lag_time=10	timescales=[ 174.01732276    4.36825276]
n_states=4	lag_time=20	timescales=[ 161.62416609    3.7265922 ]
n_states=4	lag_time=30	timescales=[ 213.33495369    6.83763059]
n_states=4	lag_time=40	timescales=[ 225.795004      8.38044144]

n_states=8	lag_time=1	timescales=[ 50.33475167   1.45723125]
n_states=8	lag_time=10	timescales=[ 147.01073233    5.19755161]
n_states=8	lag_time=20	timescales=[ 206.10701759    5.99616219]
n_states=8	lag_time=30	timescales=[ 187.9113262     7.69534766]
n_states=8	lag_time=40	timescales=[ 200.38302424    9.11645206]

n_states=16	lag_time=1	timescales=[ 105.00923068    3.40019785]
n_states=16	lag_time=10	timescales=[ 213.60491966    6.0556039 ]
n_states=16	lag_time=20	timescales=[ 215.05088944    6.26049243]
n_states=16	lag_time=30	timescales=[ 225.56068259    7.73438312]
n_states=16	lag_time=40	timescales=[ 228.6546226    10.51247598]

n_states=32	lag_time=1	timescales=[ 127.01704704    4.50036787]
n_states=32	lag_time=10	timescales=[ 222.01936062    6.30417916]
n_states=32	lag_time=20	timescales=[ 230.1909104     6.55949809]
n_states=32	lag_time=30	timescales=[ 232.04222208    8.39279602]
n_states=32	lag_time=40	timescales=[ 232.53979699   10.41725168]

n_states=64	lag_time=1	timescales=[ 158.69866183    4.87452951]
n_states=64	lag_time=10	timescales=[ 225.90616501    6.50118758]
n_states=64	lag_time=20	timescales=[ 232.97342926    6.6469737 ]
n_states=64	lag_time=30	timescales=[ 235.78370131    8.99127601]
n_states=64	lag_time=40	timescales=[ 237.28066114   12.19009133]
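
The MSM cell above delegates clustering, counting, and eigen-analysis to msmbuilder. The NumPy sketch below spells out those three steps under simplifying assumptions of our own: greedy farthest-point k-centers in Euclidean feature space, sliding-window transition counts at the lag time, and implied timescales t_i = -tau / ln(lambda_i) from a transition matrix built by symmetrizing the counts. This is not msmbuilder's estimator. As we understand it, MarkovStateModel defaults to a maximum-likelihood reversible estimate restricted to the largest ergodic subset of states, which this sketch does not attempt. All function names and the toy data are hypothetical.

```python
import numpy as np

# Hypothetical helpers: a simplified k-centers + MSM pipeline,
# not msmbuilder's KCenters / MarkovStateModel.


def kcenters(X, n_clusters, seed=0):
    """Greedy farthest-point (Gonzalez) k-centers; returns a label per row of X."""
    rng = np.random.default_rng(seed)
    first = int(rng.integers(len(X)))
    dist = np.linalg.norm(X - X[first], axis=1)  # distance to nearest center so far
    labels = np.zeros(len(X), dtype=int)
    for k in range(1, n_clusters):
        new = int(np.argmax(dist))  # farthest frame becomes the next center
        d_new = np.linalg.norm(X - X[new], axis=1)
        labels[d_new < dist] = k
        dist = np.minimum(dist, d_new)
    return labels


def cluster_trajectories(sequences, n_clusters, seed=0):
    """Cluster all frames jointly, then split the labels back per trajectory."""
    labels = kcenters(np.concatenate(sequences), n_clusters, seed)
    return np.split(labels, np.cumsum([len(s) for s in sequences])[:-1])


def count_transitions(assignments, n_states, lag_time):
    """Sliding-window counts of transitions i -> j separated by lag_time frames."""
    C = np.zeros((n_states, n_states))
    for a in assignments:
        np.add.at(C, (a[:-lag_time], a[lag_time:]), 1.0)
    return C


def implied_timescales(C, lag_time, n_timescales=2):
    """Timescales -lag_time / ln(lambda_i) from a count matrix.

    Symmetrizing the counts (C + C^T) gives a reversible transition matrix
    with real eigenvalues; assumes every state was visited at least once.
    """
    C_sym = C + C.T
    T = C_sym / C_sym.sum(axis=1, keepdims=True)
    eigvals = np.sort(np.linalg.eigvals(T).real)[::-1]
    # Skip lambda_0 = 1, which belongs to the stationary distribution.
    return -lag_time / np.log(eigvals[1:1 + n_timescales])


# Toy data: a sticky two-state chain (switch probability 0.02) plus noise.
rng = np.random.default_rng(0)
state, states = 0, []
for _ in range(5000):
    if rng.random() < 0.02:
        state = 1 - state
    states.append(state)
X = (3.0 * np.array(states) + 0.3 * rng.normal(size=5000)).reshape(-1, 1)

assignments = cluster_trajectories([X], n_clusters=2)
for lag in (1, 10):
    C = count_transitions(assignments, 2, lag)
    print(lag, implied_timescales(C, lag, n_timescales=1))
```

For this toy chain the exact slowest timescale is -1/ln(1 - 2 * 0.02), about 24.5 frames, so the estimates at both lag times should land in that neighbourhood. That lag-time independence is what the timescale plots check for.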


In [6]:
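figure(figsize=(14,3))

for i, n in enumerate(n_states):
    subplot(1,len(n_states),1+i)
    # Slowest and second-slowest MSM timescales versus lag time.
    plot(lag_times, msmts0[n], label='slowest')
    plot(lag_times, msmts1[n], label='second slowest')
    if i == 0:
        ylabel('Relaxation Timescale (ps)')
        legend()
    xlabel('Lag Time (ps)')
    title('%d states' % n)

show()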
