Source code for msmbuilder.dataset

# Author: Robert McGibbon <rmcgibbo@gmail.com>
# Contributors:
# Copyright (c) 2014, Stanford University
# All rights reserved.

from __future__ import absolute_import, print_function, division

import sys
import os
import re
import glob
from os.path import join, exists, expanduser
import socket
import getpass
import itertools
from datetime import datetime
from collections import Sequence
import warnings

import tables
import mdtraj as md
from mdtraj.core.trajectory import _parse_topology
import numpy as np
from . import version

# PyTables filter settings with ``complevel=0`` (no compression, as the name
# says). HDF5Dataset passes this as ``filters=`` to ``tables.open_file``.
_PYTABLES_DISABLE_COMPRESSION = tables.Filters(complevel=0)


# Only the ``dataset`` factory function is part of the star-import API.
__all__ = ['dataset']


def dataset(path, mode='r', fmt=None, verbose=False, **kwargs):
    """Open a dataset object

    MSMBuilder supports several dataset 'formats' for storing
    lists of sequences on disk.

    This function can also be used as a context manager.

    Parameters
    ----------
    path : str
        The path to the dataset on the filesystem
    mode : {'r', 'w', 'a'}
        Open a dataset for reading, writing, or appending. Note that
        some formats only support a subset of these modes.
    fmt : {'dir-npy', 'hdf5', 'mdtraj', 'dir-npy-union', 'hdf5-union'}
        The format of the data on disk. If omitted, it is guessed from
        ``path``, which is only possible when ``mode == 'r'``.

        ``dir-npy``
            A directory of binary numpy files, one file per sequence

        ``hdf5``
            A single hdf5 file with each sequence as an array node

        ``mdtraj``
            A read-only set of trajectory files that can be loaded
            with mdtraj

        ``dir-npy-union`` or ``hdf5-union``
            Several datasets of the respective type which will have
            their features union-ed together.
    verbose : bool
        Whether to print information about the dataset
    **kwargs
        Extra keyword arguments. They are forwarded to ``MDTrajDataset``
        when ``fmt == 'mdtraj'`` and are not used for any other format.

    Returns
    -------
    The dataset object matching ``fmt``.

    Raises
    ------
    ValueError
        If ``fmt`` is not given and ``mode`` is anything other than 'r'.
    NotImplementedError
        If ``fmt`` is not one of the known formats.
    """
    if fmt is None:
        # The format can only be inferred for data that already exists on
        # disk. For every other mode (valid or not) it must be explicit;
        # previously an unrecognised mode fell through to ``None.endswith``.
        if mode != 'r':
            raise ValueError('mode="%s", but no fmt. fmt=%s' % (mode, fmt))
        fmt = _guess_format(path)

    if fmt == 'dir-npy':
        return NumpyDirDataset(path, mode=mode, verbose=verbose)
    elif fmt == 'mdtraj':
        return MDTrajDataset(path, mode=mode, verbose=verbose, **kwargs)
    elif fmt == 'hdf5':
        return HDF5Dataset(path, mode=mode, verbose=verbose)
    elif fmt.endswith("-union"):
        sub_fmt = fmt[:-len('-union')]
        return UnionDataset(path, fmt=sub_fmt, mode=mode, verbose=verbose)
    else:
        raise NotImplementedError("unknown fmt: %s" % fmt)
def _guess_format(path):
    """Guess the format of a dataset based on its filename / filenames.

    Parameters
    ----------
    path : str, or list/tuple of str
        Path of a single dataset, or several paths to be union-ed.

    Returns
    -------
    fmt : str
        ``'dir-npy'`` if ``path`` is a directory, ``'hdf5'`` if it ends in
        ``.h5`` / ``.hdf5``, ``'mdtraj'`` otherwise. For a list or tuple the
        result is ``'<fmt>-union'``, where ``<fmt>`` is the format shared by
        every entry (only 'dir-npy' and 'hdf5' are accepted).
    """
    if isinstance(path, (list, tuple)):
        # Is concatenating features "horizontally" the most obvious
        # behavior here? I don't think so. Passing a list to dataset()
        # should probably include the additional paths as additional
        # trajectories. E.g.
        # `ds = dataset(['traj1.dcd, traj2.dcd'], top='struct.pdb')`
        # `ds = dataset(['tica1/', 'tica2/'])
        #
        # In the second case, it is not as straightforward what the
        # expected behavior is, but for the first we should def. be
        # concatenating "vertically" rather than "horizontally" - mph
        warnings.warn("Prior to MSMB 3.2, passing a list of paths would"
                      " result in features being 'union-ed'."
                      " This behavior is deprecated as of v3.2 and will"
                      " be changed for v3.3."
                      " To retain the current functionality, specify"
                      " fmt='dir-npy-union' or fmt='hdf5-union' explicitly.")
        fmt = _guess_format(path[0])
        err = "Only the union of 'dir-npy' and 'hdf5' formats is supported"
        assert fmt in ['dir-npy', 'hdf5'], err
        err = "All datasets must be the same format"
        for p in path[1:]:
            assert _guess_format(p) == fmt, err
        return "{}-union".format(fmt)

    if os.path.isdir(path):
        return 'dir-npy'

    if path.endswith(('.h5', '.hdf5')):
        # TODO: Check for mdtraj .h5 file
        return 'hdf5'

    # TODO: What about a list of trajectories, e.g. from command line nargs='+'
    return 'mdtraj'


class _BaseDataset(Sequence):
    """Common behaviour of all dataset containers.

    Subclasses provide ``keys``, ``get``, ``set``, ``provenance`` and
    ``_write_provenance``; the sequence protocol, iteration, slicing
    support and context-manager support are built on top of those here.
    """

    _PROVENANCE_TEMPLATE = '''MSMBuilder Dataset:
  MSMBuilder:\t{version}
  Command:\t{cmdline}
  Path:\t\t{path}
  Username:\t{user}
  Hostname:\t{hostname}
  Date:\t\t{date}
  Comments:\t\t{comments}
'''

    _PREV_TEMPLATE = '''
== Derived from ==
{previous}
'''

    def __init__(self, path, mode='r', verbose=False):
        """Record the parameters and, when writing, create the directory.

        Raises
        ------
        ValueError
            If ``mode`` is not 'r', 'w' or 'a', or if ``mode == 'w'`` and
            ``path`` already exists.
        OSError
            If the directory cannot be created and does not already exist.
        """
        self.path = path
        self.mode = mode
        self.verbose = verbose

        if mode not in ('r', 'w', 'a'):
            raise ValueError('mode must be one of "r", "w", "a"')
        if mode in ('w', 'a'):
            if mode == 'w' and exists(path):
                raise ValueError('File exists: %s' % path)
            # os.makedirs(path, exist_ok=True)  # (py3 only)
            try:
                os.makedirs(path)
            except OSError:
                # An already-present directory is expected when appending.
                # Any other failure (permissions, a regular file in the
                # way, ...) is a real error and must not be hidden.
                if not os.path.isdir(path):
                    raise
            self._write_provenance()

    def create_derived(self, out_path, comments='', fmt=None):
        """Create a new, writable dataset whose provenance points at this one.

        Parameters
        ----------
        out_path : str
            Location of the new dataset.
        comments : str
            Free text stored in the new dataset's provenance.
        fmt : str, optional
            Format of the new dataset. By default the class of ``self``
            is reused.
        """
        if fmt is None:
            out_dataset = self.__class__(out_path, mode='w',
                                         verbose=self.verbose)
        else:
            out_dataset = dataset(out_path, mode='w', verbose=self.verbose,
                                  fmt=fmt)
        out_dataset._write_provenance(previous=self.provenance,
                                      comments=comments)
        return out_dataset

    def apply(self, fn):
        """Lazily yield ``fn(item)`` for every item, in ``keys()`` order."""
        for key in self.keys():
            yield fn(self.get(key))

    def _build_provenance(self, previous=None, comments=''):
        """Format the provenance text from the class templates."""
        val = self._PROVENANCE_TEMPLATE.format(
            version=version.full_version,
            cmdline=' '.join(sys.argv),
            user=getpass.getuser(),
            hostname=socket.gethostname(),
            path=self.path,
            comments=comments,
            date=datetime.now().strftime("%B %d, %Y %I:%M %p"))
        if previous:
            val += self._PREV_TEMPLATE.format(previous=previous)
        return val

    @property
    def provenance(self):
        raise NotImplementedError('implemented in subclass')

    def _write_provenance(self, previous=None, comments=''):
        raise NotImplementedError('implemented in subclass')

    def _keys_for_slice(self, slc):
        """Return the keys selected by the slice ``slc``, in slice order.

        Slicing the actual key list (rather than counting up from zero)
        supports negative steps and keys that are not contiguous.
        """
        return list(self.keys())[slc]

    def __len__(self):
        return sum(1 for xx in self.keys())

    def __getitem__(self, i):
        return self.get(i)

    def __setitem__(self, i, x):
        return self.set(i, x)

    def __iter__(self):
        for key in self.keys():
            yield self.get(key)

    def keys(self):
        # keys()[i], get(i) and set(i, x) should all follow
        # the same ordering convention for the indices / items.
        raise NotImplementedError('implemeneted in subclass')

    def items(self):
        """Yield ``(key, item)`` pairs in ``keys()`` order."""
        for key in self.keys():
            yield (key, self.get(key))

    def get(self, i):
        raise NotImplementedError('implemeneted in subclass')

    def set(self, i, x):
        raise NotImplementedError('implemeneted in subclass')

    def close(self):
        pass

    def flush(self):
        pass

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        self.close()


class NumpyDirDataset(_BaseDataset):
    """Dataset stored as a directory of ``.npy`` files, one per sequence.

    Item ``i`` lives in the file ``'%08d.npy' % i`` inside ``path``; the
    provenance text lives in ``PROVENANCE.txt`` in the same directory.

    Parameters
    ----------
    path : str
    mode : {'r', 'w', 'a'}
        Read, write, or append. If mode is set to 'a' or 'w',
        duplicate keys will be overwritten.
    verbose : bool
        Print the filename of every item that is loaded or saved.

    Examples
    --------
    for X in NumpyDirDataset('path/to/dataset'):
        print(X)
    """

    _ITEM_FORMAT = '%08d.npy'
    # Raw string, literal dot and end anchor, so that e.g. a stray
    # '00000001.npy.bak' is not reported as a second item with key 1.
    _ITEM_RE = re.compile(r'(\d{8})\.npy$')
    _PROVENANCE_FILE = 'PROVENANCE.txt'

    def get(self, i, mmap=False):
        """Load item ``i``, or a list of items if ``i`` is a slice.

        Parameters
        ----------
        i : int or slice
        mmap : bool
            Load through ``np.load(..., mmap_mode='r')`` instead of reading
            the array into memory.

        Raises
        ------
        IndexError
            If the file for key ``i`` cannot be read.
        """
        if isinstance(i, slice):
            return [self.get(key, mmap=mmap)
                    for key in self._keys_for_slice(i)]

        mmap_mode = 'r' if mmap else None
        filename = join(self.path, self._ITEM_FORMAT % i)
        if self.verbose:
            print('[NumpydirDataset] loading %s' % filename)
        try:
            return np.load(filename, mmap_mode)
        except IOError as e:
            # The sequence protocol expects IndexError for a missing item.
            raise IndexError(e)

    def set(self, i, x):
        """Save array ``x`` as item ``i``, replacing any existing file."""
        if self.mode not in ('w', 'a'):
            raise IOError('Dataset not opened for writing')
        filename = join(self.path, self._ITEM_FORMAT % i)
        if self.verbose:
            print('[NumpydirDataset] saving %s' % filename)
        return np.save(filename, x)

    def keys(self):
        """Yield the integer keys of all item files, naturally sorted."""
        filenames = os.listdir(os.path.expanduser(self.path))
        for fn in sorted(filenames, key=_keynat):
            match = self._ITEM_RE.match(fn)
            if match:
                yield int(match.group(1))

    @property
    def provenance(self):
        try:
            with open(join(self.path, self._PROVENANCE_FILE), 'r') as f:
                return f.read()
        except IOError:
            return 'No available provenance'

    def _write_provenance(self, previous=None, comments=''):
        with open(join(self.path, self._PROVENANCE_FILE), 'w') as f:
            p = self._build_provenance(previous=previous, comments=comments)
            f.write(p)


class HDF5Dataset(_BaseDataset):
    """Dataset stored in one HDF5 file, each item a root node ``arr_<i>``.

    Parameters
    ----------
    path : str
    mode : {'r', 'w'}
        Appending is not supported by this format.
    verbose : bool
    """

    _ITEM_FORMAT = 'arr_%d'
    # Raw string and end anchor: only nodes named exactly 'arr_<digits>'.
    _ITEM_RE = re.compile(r'arr_(\d+)$')

    def __init__(self, path, mode='r', verbose=False):
        if mode not in ('r', 'w'):
            raise ValueError('mode must be one of "r", "w"')
        if mode == 'w':
            if exists(path):
                raise ValueError('File exists: %s' % path)

        self._handle = tables.open_file(path, mode=mode,
                                        filters=_PYTABLES_DISABLE_COMPRESSION)
        self.path = path
        self.mode = mode
        self.verbose = verbose

        if mode == 'w':
            self._write_provenance()

    def __getstate__(self):
        # pickle does not like to pickle the pytables handle, so...
        # self.flush()
        return {'path': self.path, 'mode': self.mode, 'verbose': self.verbose}

    def __setstate__(self, state):
        self.path = state['path']
        self.mode = state['mode']
        self.verbose = state['verbose']
        # NOTE(review): the file is reopened with the pickled mode. For a
        # dataset that was opened with mode='w' this presumably recreates
        # the file on unpickling - confirm against the PyTables docs.
        self._handle = tables.open_file(self.path, mode=self.mode,
                                        filters=_PYTABLES_DISABLE_COMPRESSION)

    def get(self, i, mmap=False):
        """Read item ``i`` into memory, or a list of items for a slice.

        ``mmap`` is accepted for signature compatibility with
        ``NumpyDirDataset.get`` and is not used.
        """
        if isinstance(i, slice):
            return [self.get(key) for key in self._keys_for_slice(i)]

        return self._handle.get_node('/', self._ITEM_FORMAT % i)[:]

    def keys(self):
        """Yield the integer keys of all ``arr_<i>`` nodes, naturally sorted."""
        nodes = self._handle.list_nodes('/')
        for node in sorted(nodes, key=lambda x: _keynat(x.name)):
            match = self._ITEM_RE.match(node.name)
            if match:
                yield int(match.group(1))

    def set(self, i, x):
        """Store array ``x`` as item ``i``, replacing an existing node."""
        if 'w' not in self.mode:
            raise IOError('Dataset not opened for writing')
        name = self._ITEM_FORMAT % i
        try:
            self._handle.create_carray('/', name, obj=x)
        except tables.exceptions.NodeError:
            # The node is already there: drop it and write the new data.
            # Done as a single retry rather than by recursion, so a
            # persistent failure propagates instead of looping.
            self._handle.remove_node('/', name)
            self._handle.create_carray('/', name, obj=x)

    @property
    def provenance(self):
        try:
            return self._handle.root._v_attrs['provenance']
        except KeyError:
            return 'No available provenance'

    def _write_provenance(self, previous=None, comments=''):
        p = self._build_provenance(previous=previous, comments=comments)
        self._handle.root._v_attrs['provenance'] = p

    def close(self):
        # __init__ may have raised before the handle was created, and
        # __del__ still runs in that case.
        if hasattr(self, '_handle'):
            self._handle.close()

    def flush(self):
        self._handle.flush()

    def __del__(self):
        self.close()


class MDTrajDataset(_BaseDataset):
    """Read-only dataset of trajectory files loaded through mdtraj.

    Parameters
    ----------
    path : str, or list/tuple of str
        A glob pattern (matches are naturally sorted), or an explicit
        sequence of filenames whose order is kept.
    mode : {'r'}
    topology : str, optional
        Path of a topology file; parsed once and passed as ``top=`` to
        every load.
    stride : int
        Passed to ``md.load`` / ``md.iterload``.
    atom_indices : optional
        Passed to ``md.load`` / ``md.iterload``.
    verbose : bool
        Print the filename of every trajectory that is loaded.
    """

    _PROVENANCE_TEMPLATE = '''MDTraj dataset:
  path:\t\t{path}
  topology:\t{topology}
  stride:\t{stride}
  atom_indices\t{atom_indices}
'''

    def __init__(self, path, mode='r', topology=None, stride=1,
                 atom_indices=None, verbose=False):
        if mode != 'r':
            raise ValueError('mode must be "r"')
        self.path = path
        self.mode = mode  # kept for parity with the other dataset classes
        self.topology = topology
        self.stride = stride
        self.atom_indices = atom_indices
        self.verbose = verbose

        if isinstance(path, (list, tuple)):
            self.glob_matches = [expanduser(fn) for fn in path]
        else:
            self.glob_matches = sorted(glob.glob(expanduser(path)),
                                       key=_keynat)

        if topology is None:
            self._topology = None
        else:
            self._topology = _parse_topology(os.path.expanduser(topology))

    def get(self, i):
        """Load trajectory ``i`` with ``md.load``."""
        if self.verbose:
            print('[MDTraj dataset] loading %s' % self.filename(i))

        if self._topology is None:
            t = md.load(self.filename(i), stride=self.stride,
                        atom_indices=self.atom_indices)
        else:
            t = md.load(self.filename(i), stride=self.stride,
                        atom_indices=self.atom_indices, top=self._topology)
        return t

    def filename(self, i):
        """Return the filename backing item ``i``."""
        return self.glob_matches[i]

    def iterload(self, i, chunk):
        """Return ``md.iterload`` over trajectory ``i`` with chunk size ``chunk``."""
        if self.verbose:
            print('[MDTraj dataset] iterloading %s' % self.filename(i))

        if self._topology is None:
            return md.iterload(
                self.filename(i), chunk=chunk, stride=self.stride,
                atom_indices=self.atom_indices)
        else:
            return md.iterload(
                self.filename(i), chunk=chunk, stride=self.stride,
                atom_indices=self.atom_indices, top=self._topology)

    def keys(self):
        return iter(range(len(self.glob_matches)))

    @property
    def provenance(self):
        return self._PROVENANCE_TEMPLATE.format(
            path=self.path, topology=self.topology,
            atom_indices=self.atom_indices, stride=self.stride)


def _dim_match(arr):
    """Return ``arr`` with a trailing axis added if it is one-dimensional."""
    if arr.ndim == 1:
        return arr[:, np.newaxis]
    return arr


class UnionDataset(_BaseDataset):
    """Read-only view joining the items of several datasets along axis 1.

    Parameters
    ----------
    paths : list of str
        One path per underlying dataset.
    mode : {'r'}
    fmt : {'dir-npy', 'hdf5'}
        Format shared by all underlying datasets.
    verbose : bool
    """

    def __init__(self, paths, mode, fmt='dir-npy', verbose=False):
        # Check mode
        if mode != 'r':
            raise ValueError("Union datasets are read only")

        # Check format
        supported_subformats = ['dir-npy', 'hdf5']
        if fmt not in supported_subformats:
            err = "Format must be one of {}. You gave {}"
            err = err.format(supported_subformats, fmt)
            raise ValueError(err)

        # Save parameters
        self.path = paths  # kept for parity with the other dataset classes
        self.mode = mode
        self.verbose = verbose
        self.datasets = [dataset(path, mode, fmt, verbose) for path in paths]

        # Sanity check
        self._check_same_length()

    def _check_same_length(self):
        """Check that the datasets are the same length"""
        lens = [sum(1 for _ in ds.keys()) for ds in self.datasets]
        if len(set(lens)) > 1:
            err = "Each dataset must be the same length. You gave: {}"
            err = err.format(lens)
            raise ValueError(err)

    def keys(self):
        return self.datasets[0].keys()

    def get(self, i):
        """Return item ``i`` of every dataset, concatenated along axis 1."""
        return np.concatenate([_dim_match(ds.get(i)) for ds in self.datasets],
                              axis=1)

    def close(self):
        for ds in self.datasets:
            ds.close()

    def flush(self):
        # Flush, do not close: the datasets must stay usable afterwards.
        for ds in self.datasets:
            ds.flush()

    @property
    def provenance(self):
        return "\n\n".join(ds.provenance for ds in self.datasets)


def _keynat(string):
    """A natural sort helper function for sort() and sorted()
    without using regular expression.

    Each run of consecutive digits becomes one integer and every other
    character becomes ``9 + ord(c)``, so ``'f2'`` sorts before ``'f10'``.

    Parameters
    ----------
    string : str

    Returns
    -------
    key : list of int
    """
    r = []
    # Whether the last entry of ``r`` is a number that is still being read.
    # Character codes are ints as well, so an isinstance() test on r[-1]
    # cannot tell the two apart and would merge digits into the preceding
    # character (which made 'b1' sort before 'a10').
    in_number = False
    for c in string:
        if c.isdigit():
            if in_number:
                r[-1] = r[-1] * 10 + int(c)
            else:
                r.append(int(c))
                in_number = True
        else:
            r.append(9 + ord(c))
            in_number = False
    return r
Versions