#!/usr/bin/env python -u
"""Command line driver for OpenMM
"""
#-----------------------------------------------------------------------------
# Imports
#-----------------------------------------------------------------------------
from __future__ import print_function
import os
import sys
import shutil
import logging
import platform
from datetime import datetime
try:
    from collections import OrderedDict
except ImportError:
    # using ordereddict is just cosmetic, so on python2.6
    # we can just use regular dict
    OrderedDict = dict

# openmm
try:
    from simtk import unit
    from simtk.openmm import app
    from simtk import openmm as mm
except ImportError as err:
    print("Failed to import OpenMM packages:", str(err))
    print("Make sure OpenMM is installed and the library path is set correctly.")
    sys.exit(1)

from ipcfg.progressreporter import ProgressReporter
from ipcfg.restartreporter import RestartReporter, loadRestartFile
from ipcfg.velocityverlet import VelocityVerletIntegrator

# XML parsing
import xml.etree.ElementTree as etree

# command line configuration system
from ipcfg.extratraitlets import Quantity
from ipcfg.openmmapplication import OpenMMApplication, AppConfigurable
from ipcfg.IPython.traitlets import (CInt, CBool, CBytes, CaselessStrEnum, List,
                                     Instance, Enum, CFloat, TraitError)
from ipcfg.IPython.loader import AliasError
from ipcfg.IPython.text import wrap_paragraphs

#-----------------------------------------------------------------------------
# Classes
#-----------------------------------------------------------------------------


class General(AppConfigurable):
    """General options, including the force field, platform, and coordinates.
    """

    # Builtin force field selection. get_forcefield() turns these names into
    # the XML file names that are handed to app.ForceField.
    # NOTE(review): the default 'None' is not one of the listed protein
    # values; get_forcefield() treats both 'None' and None as "no builtin
    # protein force field" -- confirm that an unlisted default is intended.
    protein = CaselessStrEnum(['amber96', 'amber99sb', 'amber99sb-ildn',
        'amber99sb-nmr', 'amber03', 'amber10', 'amoeba2009'], allow_none=True, config=True,
        default_value='None', help=''' Force Field to use for the protein atoms.
        For details, consult the literature.''')
    water = CaselessStrEnum(['SPC/E', 'TIP3P', 'TIP4P-Ew', 'TIP5P', 'Implicit', 'None'],
        config=True, default_value='None', allow_none=True, help='''Water model 
        to use in the simulation.''')
    ffxml = List(config=True, help='''Supply one or more custom force field files,
        in the OpenMM XML format. This can be used to specify a force field for
        ligands, nonstandard amino acids, etc.''', nargs='+')

    # Alternative inputs that provide a complete system and/or topology.
    # validate() rejects combining them with the force field options above.
    sysxml = CBytes(config=True, help='''Supply one OpenMM system XML file, which
        comes from a serialized System object and provides a complete system.''')
    serialize = CBytes(config=True, help='''Write a system XML file, which this program
        can read using the "sysxml" option.''')
    prmtop = CBytes(config=True, help='''Supply one AMBER prmtop file, which provides
        a complete topology and system.''')
    gmxtop = CBytes(config=True, help='''Supply one GROMACS .top file, which provides
        a complete topology and system.''')

    # Compute platform selection; 'NotSpecified' makes get_platform() fall
    # back to fastest_platform.
    platform = CaselessStrEnum(['Reference', 'OpenCL', 'CUDA', 'CPU', 'NotSpecified'],
        default_value='NotSpecified', allow_none=False, help='''OpenMM runs
        simulations on four platforms: Reference, CUDA, CPU, and OpenCL. If not
        specified, the fastest available platform will be selected
        automatically.''', config=True)
    precision = CaselessStrEnum(['Single', 'Mixed', 'Double'], config=True,
        allow_none=False, default_value='Mixed', help='''Level of numeric
        precision to use for calculations.''')
    device = CInt(config=True, help='''Supply the device index of the CUDA device
        (i.e. NVidia GPU) or OpenCL device that you want to run on. Defaults to
        the fastest device available.''')
    coords = CBytes(config=True, help='''OpenMM can take a pdb, which contains
        the coordinates and topology, or AMBER inpcrd, which contains coordinates.''')

    # nonconfigurable traits
    # The *_file traits cache the parsed input files; load_coords() fills them
    # in the first time they are needed.
    pdb_file = Instance(app.PDBFile)
    gro_file = Instance(app.GromacsGroFile)
    inpcrd_file = Instance(app.AmberInpcrdFile)
    prmtop_file = Instance(app.AmberPrmtopFile)
    gmxtop_file = Instance(app.GromacsTopFile)
    fastest_platform = CBytes(help='The name of the fastest platform on the system')
    # Replaced (not mutated) by active_config_traits() with the trait names
    # that a system XML file supersedes.
    xml_override = []
    def _fastest_platform_default(self):
        """Default value for ``fastest_platform``.

        Walks every platform registered with OpenMM and returns the name of
        the one reporting the highest speed. Returns -1 when no platform
        reports a speed above -1 (for instance when none are registered).
        """
        best_name = -1
        best_speed = -1
        candidates = (mm.Platform.getPlatform(index)
                      for index in range(mm.Platform.getNumPlatforms()))
        for candidate in candidates:
            candidate_speed = candidate.getSpeed()
            # strict comparison: on a tie the earlier platform is kept
            if candidate_speed > best_speed:
                best_name, best_speed = candidate.getName(), candidate_speed

        return best_name

    def active_config_traits(self):
        """List the general options that currently affect the simulation.

        'platform' and 'coords' are always active. 'precision' (and 'device',
        when the user set it) are added if the platform is not 'Reference' and
        the fastest platform is CUDA or OpenCL. 'ffxml' is added when custom
        force field files were given. If 'sysxml' was specified, 'protein' and
        'water' are stored in ``self.xml_override`` instead of being reported
        as active.
        """
        active = ['platform', 'coords']

        # fastest_platform is only consulted when the platform is not Reference
        if self.platform != 'Reference':
            if self.fastest_platform in ['CUDA', 'OpenCL']:
                active.append('precision')
                if 'device' in self.specified_config_traits:
                    active.append('device')

        if len(self.ffxml) > 0:
            active.append('ffxml')

        overridable = ['protein', 'water']

        if 'sysxml' not in self.specified_config_traits:
            active.extend(overridable)
            return active

        # A system XML file supersedes the builtin force field choices.
        active.append('sysxml')
        self.xml_override = overridable
        return active

    def validate(self):
        """Check the general options for contradictory combinations.

        Raises TraitError when precision/device are set for a platform other
        than OpenCL or CUDA, when sysxml / prmtop / gmxtop are combined with
        options they make redundant, or when .inpcrd / .gro coordinates are
        given without a prmtop or gmxtop file.
        """
        self.log.debug('Running general options validations.')
        specified = self.specified_config_traits

        if 'precision' in specified and self.platform not in ['OpenCL', 'CUDA']:
            raise TraitError('Manually setting the precision is only '
                             'appropriate on the OpenCL and CUDA platforms')
        if 'device' in specified and self.platform not in ['OpenCL', 'CUDA']:
            raise TraitError('Manually setting the device is only '
                             'appropriate on the OpenCL and CUDA platforms')

        # (option, options it excludes, complaint) -- checked in this order
        exclusive = (
            ('sysxml', ['ffxml', 'prmtop', 'gmxtop', 'protein', 'water'],
             'Since sysxml was specified, you should not specify ffxml / prmtop / gmxtop / protein / water.'),
            ('prmtop', ['ffxml', 'gmxtop', 'protein', 'water'],
             'Since prmtop was specified, you should not specify any xml / gmxtop / protein / water.'),
            ('gmxtop', ['ffxml', 'protein', 'water'],
             'Since gmxtop was specified, you should not specify any xml / prmtop / protein / water.'),
        )
        for owner, rivals, complaint in exclusive:
            if owner in specified and any([rival in specified for rival in rivals]):
                raise TraitError(complaint)

        lacks_topology_file = not ('prmtop' in specified or 'gmxtop' in specified)
        if self.coords.endswith('inpcrd') and lacks_topology_file:
            raise TraitError('You specified AMBER coordinates (.inpcrd) which also requires a prmtop or gmxtop file.')
        if self.coords.endswith('gro') and lacks_topology_file:
            raise TraitError('You specified GROMACS coordinates (.gro) which also requires a prmtop or gmxtop file.')
                
    def load_coords(self):
        """Load coordinate/topology files from disk.

        The coordinate file named by ``self.coords`` is parsed according to
        its extension (.pdb, .gro or .inpcrd) and cached on ``pdb_file``,
        ``gro_file`` or ``inpcrd_file``. If 'prmtop' or 'gmxtop' was
        specified, the corresponding topology file is parsed and cached as
        well. Files that are already cached are not read again, and every
        load is mirrored by a line in the generated script.

        Raises
        ------
        NotImplementedError
            If ``self.coords`` is non-empty and has an unsupported extension.
        """
        script = self.application.script

        if self.coords.endswith('.pdb'):
            if self.pdb_file is None:
                script("pdb = app.PDBFile('%s')" % self.coords)
                self.pdb_file = app.PDBFile(self.coords)
        elif self.coords.endswith('.gro'):
            if self.gro_file is None:
                script("gro = app.GromacsGroFile('%s')" % self.coords)
                self.gro_file = app.GromacsGroFile(self.coords)
        elif self.coords.endswith('.inpcrd'):
            if self.inpcrd_file is None:
                script("inpcrd = app.AmberInpcrdFile('%s')" % self.coords)
                self.inpcrd_file = app.AmberInpcrdFile(self.coords)
        elif self.coords == '':
            self.application.error('You must provide a coordinate file, either in the '
                                   'configuration file or on the command line using, '
                                   'for example, --coords protein.pdb or --coords protein.inpcrd')
        else:
            raise NotImplementedError('Currently, only reading from .pdb, .gro, and .inpcrd files '
                                      'is implemented')

        if 'prmtop' in self.specified_config_traits:
            if self.prmtop_file is None:
                script("prmtop = app.AmberPrmtopFile('%s')" % self.prmtop)
                self.prmtop_file = app.AmberPrmtopFile(self.prmtop)
        elif 'gmxtop' in self.specified_config_traits:
            if self.gmxtop_file is None:
                # The unit cell is taken from whichever coordinate file was
                # loaded above; the script line must mirror the call made here.
                if self.pdb_file is not None:
                    script("gmxtop = app.GromacsTopFile('%s', unitCellDimensions=pdb.topology.getUnitCellDimensions())" % self.gmxtop)
                    self.gmxtop_file = app.GromacsTopFile(self.gmxtop, unitCellDimensions=self.pdb_file.topology.getUnitCellDimensions())
                elif self.gro_file is not None:
                    # the .gro object is queried directly, not via a topology
                    script("gmxtop = app.GromacsTopFile('%s', unitCellDimensions=gro.getUnitCellDimensions())" % self.gmxtop)
                    self.gmxtop_file = app.GromacsTopFile(self.gmxtop, unitCellDimensions=self.gro_file.getUnitCellDimensions())
                else:
                    script("gmxtop = app.GromacsTopFile('%s')" % self.gmxtop)
                    self.gmxtop_file = app.GromacsTopFile(self.gmxtop)
            
    def get_forcefield(self):
        """Create the force field object.

        Collects the user-supplied ``ffxml`` files plus the builtin XML files
        implied by the ``protein`` and ``water`` options, writes the
        equivalent ``app.ForceField(...)`` call to the generated script, and
        scans each XML file for AMOEBA tags.

        Returns
        -------
        (forcefield, is_amoeba_ff) : tuple
            The ``app.ForceField`` built from the collected files, and a bool
            that is True when any file has a child tag containing 'amoeba'.
        """
        files = list(self.ffxml)  # copy, so the trait value is left untouched

        # protein may be the string 'None' (the default) or an actual None
        has_protein = self.protein not in ['None', None]
        if has_protein:
            files.append(self.protein.replace('-', '').lower() + '.xml')
        else:
            self.log.warning('No builtin protein force field being used')

        if self.water in ['SPC/E', 'TIP3P', 'TIP4P-Ew', 'TIP5P']:
            files.append(
                self.water.replace('/', '').replace('-', '').lower() + '.xml')
        elif self.water == 'Implicit':
            implicit_files = {
                'amber96': 'amber96_obc.xml',
                'amber03': 'amber03_obc.xml',
                'amber10': 'amber10_obc.xml',
                'amoeba2009': 'amoeba2009_gk.xml',
            }
            # every amber99 variant shares one implicit solvent file; the
            # has_protein guard avoids calling startswith on None
            if has_protein and self.protein.startswith('amber99'):
                solvent_file = 'amber99_obc.xml'
            else:
                solvent_file = implicit_files.get(self.protein)

            if solvent_file is None:
                self.application.error("Without specifying the protein force field, I can't "
                                       "unambiguously pick the implicit solvent parameters.")
            else:
                files.append(solvent_file)
        elif self.protein != 'amoeba2009':  # AMOEBA automatically includes a water model.
            self.log.warning("No builtin water model is being used. If you'd "
                             "like to specify a water model, use --water.")

        if not files:
            self.application.error('No force fields were loaded. You can use OpenMM supplied '
                                   'force fields with the "protein" and "water" configurables '
                                   '(--protein / --water flags on the command line), or supply '
                                   'custom OpenMM XML force field files with the "ffxml" option')

        self.application.script('forcefield = app.ForceField(%s)' %
               ', '.join(["'%s'" % f for f in files]))

        is_amoeba_ff = False
        for fn in files:
            try:
                tree = etree.parse(fn)
            except IOError:
                # not found as given: fall back to the data directory that
                # sits next to the app package
                tree = etree.parse(os.path.join(os.path.dirname(app.__file__), 'data', fn))
            xmlroot = tree.getroot()
            if xmlroot.tag.lower() != 'forcefield':
                self.application.error('Tried to load %s as a force field XML file, but the first tag is %s. '
                                       'You may load system XML files using the "sysxml" option'  % (fn, xmlroot.tag))
            # log the detection only once, however many files/tags match
            if not is_amoeba_ff and any('amoeba' in child.tag.lower() for child in xmlroot):
                self.log.info('Detected AMOEBA force field!')
                is_amoeba_ff = True

        return app.ForceField(*files), is_amoeba_ff

    def get_system_from_sysxml(self):
        """Create the System object from a system XML file.

        The file named by ``self.sysxml`` is parsed first to check that its
        root element carries ``type="System"`` and to log whether any AMOEBA
        tags are present; it is then deserialized with
        ``mm.XmlSerializer.deserializeSystem``.

        Returns
        -------
        system
            The deserialized System object.

        Raises
        ------
        RuntimeError
            If the XML file cannot be read or parsed.
        """
        fn = self.sysxml

        try:
            tree = etree.parse(fn)
        except Exception as err:
            # Exception rather than a bare except, so Ctrl-C and SystemExit
            # are not converted into a misleading "failed to load" error.
            raise RuntimeError('Failed to load XML file: %s (%s)' % (fn, err))

        xmlroot = tree.getroot()
        if xmlroot.attrib.get('type', '').lower() != 'system':
            self.application.error('Tried to load %s as a system XML file, but the root must have attribute type="System". '
                                   'You may load force field XML files using the "ffxml" option'  % (fn))

        if any('amoeba' in child.tag.lower() for child in xmlroot):
            self.log.info('Detected AMOEBA system!')

        # Quote the file name and qualify XmlSerializer so the generated
        # script line is valid Python, like the other script lines.
        self.application.script("system = mm.XmlSerializer.deserializeSystem(open('%s').read())" % fn)
        with open(fn) as handle:
            system = mm.XmlSerializer.deserializeSystem(handle.read())

        return system

    def get_positions(self):
        """Return the atomic positions from the loaded coordinate file.

        Coordinate files are loaded first; the .pdb file takes precedence,
        then .gro, then .inpcrd. The matching assignment is written to the
        generated script. Raises RuntimeError if none of them is loaded.
        """
        self.load_coords()

        # (loaded file object, script line) in order of precedence
        sources = (
            (self.pdb_file, 'positions = pdb.positions'),
            (self.gro_file, 'positions = gro.positions'),
            (self.inpcrd_file, 'positions = inpcrd.positions'),
        )
        for loaded, script_line in sources:
            if loaded is not None:
                self.application.script(script_line)
                return loaded.positions

        raise RuntimeError('Only .pdb, .inpcrd, and .gro files are currently supported for reading coordinates.')

    def get_topology(self):
        """Return the topology from the loaded input files.

        Input files are loaded first; the .pdb file takes precedence, then
        the AMBER prmtop, then the GROMACS top file. The matching assignment
        is written to the generated script. Raises RuntimeError if none of
        them is loaded.
        """
        self.load_coords()

        # (loaded file object, script line) in order of precedence
        sources = (
            (self.pdb_file, 'topology = pdb.topology'),
            (self.prmtop_file, 'topology = prmtop.topology'),
            (self.gmxtop_file, 'topology = gmxtop.topology'),
        )
        for loaded, script_line in sources:
            if loaded is not None:
                self.application.script(script_line)
                return loaded.topology

        raise RuntimeError('Only .pdb, .prmtop, or Gromacs .top files are currently supported for reading the topology.')

    def get_unitcell(self):
        """Return the unit cell dimensions from the loaded coordinate file.

        Coordinate files are loaded first. For a .pdb file the dimensions are
        read from its topology; for a .gro file they are read from the file
        object itself. The line written to the generated script mirrors the
        call that is actually made.

        Raises
        ------
        RuntimeError
            If neither a .pdb nor a .gro file is loaded.
        """
        self.load_coords()
        if self.pdb_file is not None:
            # the script previously recorded pdb.getUnitCellDimensions(),
            # which is not the call executed on the next line
            self.application.script('unitcell = pdb.topology.getUnitCellDimensions()')
            return self.pdb_file.topology.getUnitCellDimensions()
        if self.gro_file is not None:
            self.application.script('unitcell = gro.getUnitCellDimensions()')
            return self.gro_file.getUnitCellDimensions()
        raise RuntimeError('Only .pdb and .gro files are currently supported for reading the unit cell.')

    def get_platform(self):
        """Get the platform.

        When the platform option is 'NotSpecified', the fastest available
        platform is chosen and also stored back on ``self.platform``. The
        lookup is written to the generated script.

        Returns
        -------
        platform
            The object returned by ``mm.Platform.getPlatformByName``.
        """
        # Compare by value: 'is' only matched through string interning.
        if self.platform == 'NotSpecified':
            platform = self.fastest_platform
            self.platform = self.fastest_platform
        else:
            platform = self.platform

        self.application.script("platform = mm.Platform.getPlatformByName('%s')" %
                                platform)
        return mm.Platform.getPlatformByName(platform)


    def get_platform_properties(self):
        """Get any specified platform properties.

        Returns
        -------
        dict or None
            None for the Reference and CPU platforms. For CUDA and OpenCL, a
            dict holding the lower-cased precision and, if 'device' was
            specified, the device index as a string. The value is also
            written to the generated script.

        Raises
        ------
        RuntimeError
            If the resolved platform name is not one of the known platforms.
        """
        # Compare by value: 'is' only matched through string interning.
        if self.platform == 'NotSpecified':
            platform = self.fastest_platform
        else:
            platform = self.platform

        precision_keys = {'CUDA': 'CudaPrecision', 'OpenCL': 'OpenCLPrecision'}
        device_keys = {'CUDA': 'CudaDeviceIndex', 'OpenCL': 'OpenCLDeviceIndex'}

        if platform in ('Reference', 'CPU'):
            pp = None
        elif platform in precision_keys:
            pp = {precision_keys[platform]: self.precision.lower()}
        else:
            raise RuntimeError('unknown platform')

        # the device index only applies to the GPU platforms
        if pp is not None and 'device' in self.specified_config_traits:
            pp[device_keys[platform]] = str(self.device)

        self.application.script('platformProperties = %s' % pp)
        return pp


class System(AppConfigurable):
    """Parameters for the system, including the method for calculating nonbonded
    forces, constraints, and initialization of velocities."""

    # --- options that apply to every force field ---
    nb_method = CaselessStrEnum(['NoCutoff', 'CutoffNonPeriodic',
        'CutoffPeriodic', 'Ewald', 'PME'], config=True, default_value='PME',
        allow_none=False, help='''Method for calculating long range
        non-bonded interactions. Refer to the user guide for a detailed
        discussion.''')
    ewald_tol = CFloat(0.0005, config=True, allow_none=False, help='''The error
        tolerance is roughly equal to the fractional error in the forces due
        to truncating the Ewald summation.''')
    constraints = CaselessStrEnum(['None', 'HBonds', 'AllBonds', 'HAngles'],
        default_value='HBonds', allow_none=True, config=True, help='''Applying
        constraints to some of the atoms can enable you to take longer
        timesteps.''')
    rigid_water = CBool(True, config=True, help='''Keep water rigid. Be aware
        that flexible water may require you to further reduce the integration
        step size, typically to about 0.5 fs.''')
    cutoff = Quantity(1.0 * unit.nanometers, config=True,
        help='''Cutoff for long-range non-bonded interactions. This option is
        used for all non-bonded methods except for "NoCutoff".''')

    # --- options that active_config_traits() only reports for AMOEBA ---
    vdw_cutoff = Quantity(1.0 * unit.nanometers, config=True,
        help='''Specific cutoff for van der Waals interactions used in the 
        AMOEBA force field.''')
    pme_grid = List([24, 24, 24], config=True, help='''Specify a 3-vector for the PME grid
        dimensions in AMOEBA (also requires aewald parameter.)''', nargs=3)
    # The help text here used to be a copy of the ewald_tol description.
    aewald = CFloat(5.4459052, config=True, allow_none=False, help='''The Ewald
        alpha parameter for PME in AMOEBA. Specify it together with pme_grid
        to set the PME parameters directly instead of using ewald_tol.''')
    polarization = CaselessStrEnum(['direct', 'mutual'],
        default_value='mutual', allow_none=False, config=True, help='''Choose
        direct or mutual polarization for the AMOEBA polarizable force field.''')
    polar_eps = CFloat(1e-5, allow_none=False, config=True, help='''Choose
        SCF tolerance for polarizable force field (e.g. AMOEBA with polarization mutual).''')

    # --- corrections and initial velocities ---
    disp_corr = CBool(True, config=True, help='''Apply an isotropic long-range
    correction for the vdW interactions.''')
    rand_vels = CBool(True, config=True, help='''Initialize the system
        with random initial velocities, drawn from the Maxwell Boltzmann
        distribution.''')
    gen_temp = Quantity(300 * unit.kelvin, config=True, help='''Temperature
        used for generating initial velocities. This option is only used if
        rand_vels == True.''')

    # --- internal flags (config=False): not settable by the user ---
    is_amoeba_ff = CBool(False, config=False, help='''This flag is set after 
        the force field is read in, and signifies whether we have an AMOEBA
        force field.''')
    is_prmtop = CBool(False, config=False, help='''This flag is set after 
        the Amber .prmtop file is read in.''')
    is_gmxtop = CBool(False, config=False, help='''This flag is set after 
        the Gromacs .top file is read in.''')
    aewald_pmegrid = CBool(False, config=False, help='''This flag is set after 
        the force field is read in, and signifies whether we are specifying PME
        using "aewald + pme_grid" rather than "ewald_tol" for the AMOEBA force field.''')
    from_sysxml = CBool(False, config=False, help='''This flag is set if the system
        is obtained from a system XML file, which will override most user-provided options.''')

    # nonconfigurable traits
    # Replaced (not mutated) by active_config_traits() when from_sysxml is set.
    xml_override = []

    def active_config_traits(self):
        """Construct a list of all of the configurable traits that are currently
        'active', in the sense that their value will have some effect on the
        simulation.

        When ``from_sysxml`` is set, the computed names are stored in
        ``self.xml_override`` and an empty list is returned, because the
        system XML file overrides them.
        """
        # So named because the system XML has the ability to override these.
        xmltraits = ['nb_method', 'constraints', 'rigid_water', 'rand_vels']

        if self.nb_method in ['CutoffPeriodic', 'PME', 'Ewald'] and not self.is_prmtop and not self.is_gmxtop:
            xmltraits.append('disp_corr')

        if self.nb_method in ['PME', 'Ewald']:
            xmltraits.append('ewald_tol')

        if self.nb_method != 'NoCutoff':
            xmltraits.append('cutoff')

        if self.rand_vels:
            xmltraits.append('gen_temp')

        if self.is_amoeba_ff:
            if self.aewald_pmegrid:
                # LPW: I introduced this because the user should be able
                # to specify aewald and pme_grid to override ewald_tol.
                # ewald_tol is only listed for PME/Ewald, so guard the
                # removal; an unconditional remove() raises ValueError.
                if 'ewald_tol' in xmltraits:
                    xmltraits.remove('ewald_tol')
                xmltraits.append('aewald')
                xmltraits.append('pme_grid')
            xmltraits.append('vdw_cutoff')
            xmltraits.append('polarization')
            if self.polarization == 'mutual':
                xmltraits.append('polar_eps')

        if self.from_sysxml:
            self.xml_override = xmltraits
            return []

        return xmltraits

    def validate(self):
        """Check the system options for settings that cannot take effect.

        Raises TraitError when 'ewald_tol' is specified without PME/Ewald, or
        when 'gen_temp' is specified while random velocities are disabled.
        """
        self.log.debug('Running system options validations.')
        # These overlap with active_config_traits(), but give the user a
        # plain-English explanation of what is wrong with the configuration.
        specified = self.specified_config_traits
        uses_ewald = self.nb_method in ['PME', 'Ewald']

        if not uses_ewald and 'ewald_tol' in specified:
            raise TraitError("The Ewald summation tolerance option, 'ewald_tol', "
                             "is only appropriate to set when 'nb_method' is "
                             "PME or Ewald.")

        if 'gen_temp' in specified and not self.rand_vels:
            raise TraitError("The generation temperature option, 'gen_temp' "
                             "is only appropriate when 'rand_vels' is True")

class Dynamics(AppConfigurable):
    "Parameters for the integrator, thermostats and barostats."

    # --- integrator selection and its step-size controls ---
    integrator = CaselessStrEnum(['Langevin', 'Verlet', 'Brownian',
        'VariableLangevin', 'VariableVerlet', 'VelocityVerlet'], config=True, allow_none=False,
        default_value='Langevin', help='''OpenMM offers a choice of several
        different integration methods. Refer to the user guide for
        details.''')
    tolerance = CFloat(0.0001, config=True, help='''Tolerance for variable
        timestep integrators ('VariableLangevin', 'VariableVerlet'). Smaller
        values will produce a smaller average step size.''')

    # --- temperature control ---
    # The help text previously misspelled the thermostat name as "Anderson".
    collision_rate = Quantity(1.0 / unit.picoseconds, config=True,
        help='''Friction coefficient, for use with stochastic integrators or
        the Andersen thermostat.''')
    temp = Quantity(300 * unit.kelvin, config=True, help='''Temperature
        of the heat bath, used either by a stochastic integrator or the
        Andersen thermostat to maintain a constant temperature ensemble.''')

    # --- pressure control ---
    barostat = CaselessStrEnum(['MonteCarlo', 'MonteCarloAnisotropic', 'None'], allow_none=True,
        config=True, default_value='None', help='''Activate a barostat for
        pressure coupling. The MC barostat requires temperature control
        (stochastic integrator or Andersen thermostat) to be in effect
        as well.''')
    pressure = Quantity(1 * unit.atmosphere, config=True, help='''Pressure
        target, used by a barostat.''')
    pressure3 = List([1, 1, 1], config=True, help='''Pressure
        target, used by the Monte Carlo anisotropic barostat.''', nargs=3)
    scalex = CBool(True, config=True, help='''Switch for scaling the x-axis,
        used by the Monte Carlo anisotropic barostat.''')
    scaley = CBool(True, config=True, help='''Switch for scaling the y-axis,
        used by the Monte Carlo anisotropic barostat.''')
    scalez = CBool(True, config=True, help='''Switch for scaling the z-axis,
        used by the Monte Carlo anisotropic barostat.''')
    barostat_interval = CInt(25, config=True, help='''The frequency (in time
        steps) at which Monte Carlo pressure changes should be attempted.
        This option is only invoked when barostat in [MonteCarlo, MonteCarloAnisotropic].''')
    thermostat = CaselessStrEnum(['Andersen', 'None'], allow_none=True,
        config=True, default_value=None, help='''Activate a thermostat to
        maintain a constant temperature simulation.''')
    dt = Quantity(2 * unit.femtoseconds, config=True, help='''Timestep
        for fixed-timestep integrators.''')

    # --- internal flag (config=False): not settable by the user ---
    from_sysxml = CBool(False, config=False, help='''This flag is set if the system
        is obtained from a system XML file, which will override most user-provided options.''')

    # nonconfigurable traits
    # Replaced (not mutated) by active_config_traits() when from_sysxml is set.
    xml_override = []

    def active_config_traits(self):
        """Return the names of the dynamics options that currently have an
        effect on the simulation.

        Integrator-level options are always reported. Barostat/thermostat
        options can be overridden by a system XML file: when ``from_sysxml``
        is set they are stored in ``self.xml_override`` instead of being
        returned.
        """
        fixed_step = self.integrator in ['Langevin', 'Verlet', 'VelocityVerlet', 'Brownian']
        active = ['integrator', 'dt' if fixed_step else 'tolerance']

        # Options that a system XML file is able to override.
        overridable = ['barostat', 'thermostat']

        if self.barostat == 'MonteCarlo':
            overridable.extend(['pressure', 'barostat_interval'])

        if self.barostat == 'MonteCarloAnisotropic':
            overridable.extend(['pressure3', 'barostat_interval',
                                'scalex', 'scaley', 'scalez'])

        bath_options = ['temp', 'collision_rate']
        if self.integrator in ['Langevin', 'VariableLangevin', 'Brownian']:
            # a stochastic integrator always uses the heat-bath settings
            active.extend(bath_options)
        elif self.thermostat == 'Andersen':
            overridable.extend(bath_options)

        if self.from_sysxml:
            self.xml_override = overridable
            return active

        return active + overridable

    def validate(self):
        """Run some validation checks.

        Raises TraitError when an option is specified that cannot take effect
        with the chosen integrator, thermostat or barostat, or when a Monte
        Carlo barostat is requested without temperature control.
        """
        self.log.debug('Running dynamics options validations.')
        # note that many of these checks are sort of redundant with the computation
        # of the active traits, but they provide a nicer english explanation of
        # what's wrong with the configuration, which is important for the user.
        specified = self.specified_config_traits
        variable_step = ['VariableLangevin', 'VariableVerlet']
        fixed_step = ['Langevin', 'Verlet', 'VelocityVerlet', 'Brownian']
        mc_barostats = ['MonteCarlo', 'MonteCarloAnisotropic']

        thermostatted = (self.integrator in ['Langevin', 'Brownian', 'VariableLangevin'] or
                         self.thermostat == 'Andersen')

        if 'tolerance' in specified and self.integrator not in variable_step:
            # a space was missing between the first two pieces of this message
            raise TraitError("The variable integrator error threshold option, 'tolerance', "
                             "is only appropriate when using the VariableLangevin or "
                             "VariableVerlet integrators.")
        if 'dt' in specified and self.integrator not in fixed_step:
            raise TraitError("The timestep option, 'dt', is only appropriate when using "
                             "a fixed timestep integrator.")

        if 'collision_rate' in specified and not thermostatted:
            raise TraitError("The friction coefficient option, 'collision_rate', is only "
                             "appropriate when using a stochastic integrator (e.g. Langevin, "
                             "Brownian, VariableLangevin) or an Andersen thermostat.")
        if 'temp' in specified and not thermostatted:
            raise TraitError("The temperature target option, 'temp', is only "
                             "appropriate when using a thermostat or stochastic integrator.")

        if 'pressure' in specified and self.barostat != 'MonteCarlo':
            raise TraitError("The pressure target option, 'pressure', is only "
                             "appropriate when using the MonteCarlo barostat.")

        if 'barostat_interval' in specified and self.barostat not in mc_barostats:
            raise TraitError("The barostat interval option, 'barostat_interval', is only "
                             "appropriate when using the MonteCarlo or MonteCarloAnisotropic barostat.")

        if self.barostat in mc_barostats and not thermostatted:
            raise TraitError("You should only use the MonteCarlo barostat on a system that is "
                             "under temperature control.")

    def get_integrator(self):
        """Fetch the integrator.

        Builds the integrator selected by ``self.integrator`` and writes the
        equivalent constructor call to the generated script.

        Raises
        ------
        RuntimeError
            If ``self.integrator`` is not one of the known integrator names.
        """
        script = self.application.script

        if self.integrator == 'Langevin':
            script('integrator = mm.LangevinIntegrator(%s, %s, %s)'
                   % (self.temp, self.collision_rate, self.dt))
            return mm.LangevinIntegrator(self.temp, self.collision_rate, self.dt)
        elif self.integrator == 'Brownian':
            script('integrator = mm.BrownianIntegrator(%s, %s, %s)'
                   % (self.temp, self.collision_rate, self.dt))
            return mm.BrownianIntegrator(self.temp, self.collision_rate, self.dt)
        elif self.integrator == 'Verlet':
            script('integrator = mm.VerletIntegrator(%s)' % self.dt)
            return mm.VerletIntegrator(self.dt)
        elif self.integrator == 'VelocityVerlet':
            script('integrator = VelocityVerletIntegrator(%s)' % self.dt)
            return VelocityVerletIntegrator(self.dt)
        elif self.integrator == 'VariableVerlet':
            script('integrator = mm.VariableVerletIntegrator(%s)' %
                   self.tolerance)
            # the class lives in the mm namespace; the bare name is undefined
            return mm.VariableVerletIntegrator(self.tolerance)
        elif self.integrator == 'VariableLangevin':
            # Like the other stochastic integrators, this one is given the
            # bath temperature and friction in addition to the tolerance.
            script('integrator = mm.VariableLangevinIntegrator(%s, %s, %s)'
                   % (self.temp, self.collision_rate, self.tolerance))
            return mm.VariableLangevinIntegrator(self.temp, self.collision_rate,
                                                 self.tolerance)
        else:
            raise RuntimeError('unknown integrator')

    def get_forces(self):
        """Get additional OpenMM force objects to be added to the system.

        Returns
        -------
        list
            The barostat and/or thermostat force objects selected by the
            ``barostat`` and ``thermostat`` options; empty when neither is
            active. Each one is mirrored by a ``system.addForce(...)`` line
            in the generated script.
        """
        script = self.application.script
        forces = []

        # Each script line closes both the constructor call and addForce();
        # the final parenthesis used to be missing.
        if self.barostat == 'MonteCarlo':
            script('system.addForce(mm.MonteCarloBarostat(%s, %s, %s))' %
                   (self.pressure, self.temp, self.barostat_interval))
            forces.append(mm.MonteCarloBarostat(self.pressure, self.temp,
                                                self.barostat_interval))
        if self.barostat == 'MonteCarloAnisotropic':
            script('system.addForce(mm.MonteCarloAnisotropicBarostat(%s, %s, %s, %s, %s, %s))' %
                   (self.pressure3, self.temp, self.barostat_interval,
                    self.scalex, self.scaley, self.scalez))
            forces.append(mm.MonteCarloAnisotropicBarostat(self.pressure3, self.temp,
                                                           self.barostat_interval,
                                                           self.scalex, self.scaley, self.scalez))
        if self.thermostat == 'Andersen':
            script('system.addForce(mm.AndersenThermostat(%s, %s))' %
                   (self.temp, self.collision_rate))
            forces.append(
                mm.AndersenThermostat(self.temp, self.collision_rate))

        return forces


class Simulation(AppConfigurable):
    """Run-control options: the number of steps, whether to minimize first,
    and where / how often trajectory, progress and restart output is
    produced."""

    # -- length of the run and optional pre-minimization --
    n_steps = CInt(
        10000, config=True,
        help='''Number of steps of simulation
        to run.''')
    minimize = CBool(
        True, config=True,
        help='''First perform local energy
        minimization, to find a local potential energy minimum near the
        starting structure.''')

    # -- trajectory and progress output --
    traj_file = CBytes(
        'output.dcd', config=True,
        help='''Filename to save the
        resulting trajectory to, in DCD format.''')
    traj_freq = CInt(
        1000, config=True,
        help='''Frequency, in steps, to
        save the state to disk in the DCD format.''')
    progress_freq = CInt(
        1000, config=True,
        help='''Frequency, in steps,
        to print summary statistics on the state of the simulation.''')

    # -- restart handling --
    restart_file = CBytes(
        'restart.json.bz2', config=True,
        help='''Filename for
        reading/writing the restart file.''')
    restart_freq = CInt(
        5000, config=True,
        help='''Frequency, in steps, to
        save the restart file.''')
    read_restart = CBool(
        False, config=True,
        help='''Switch for whether to
        read restart information from file.''')
    write_restart = CBool(
        True, config=True,
        help='''Switch for whether to
        write restart information to file.''')

    def validate(self):
        """Check the simulation options for internal consistency.

        Raises
        ------
        TraitError
            If ``read_restart`` is set but ``restart_file`` is not an
            existing file.
        """
        self.log.debug('Running simulation options validations.')
        if not self.read_restart:
            # Nothing to read, so the restart file need not exist yet.
            return
        if os.path.isfile(self.restart_file):
            return
        raise TraitError("The simulation cannot be restarted, because the restart file does not exist.")


class OpenMM(OpenMMApplication):
    short_description = 'OpenMM: GPU Accelerated Molecular Dynamics'
    long_description = '''Run a molecular simulaton using the OpenMM toolkit.

    All options can be either specified in the configuration file or on the
    command line. Command line options override those specified in a config
    file.

    Note: If you have issues specifying units on the command like, like
    `openmm --dt 2*fs` causing a "no matches found" error beacuase your shell
    is trying to interpret the '*' as a wildcard, you can put the expression
    in single quotes (`openmm --dt '2*fs'`) or change the shell's nomatch behavior.
    This can be done with `setopt nonomatch` (zsh), `set nonomatch` (tcsh), or
    `shopt -u nullglob` (bash, but this behavior is already the default in bash).
    '''

    # Configured Classes. During initialization, these guys are
    # instantiated based on the config file / command line when
    # initialize_configured_classes() is executed
    classes = [General, System, Dynamics, Simulation]
    general = Instance(General)
    system = Instance(System)
    dynamics = Instance(Dynamics)
    simulation = Instance(Simulation)

    config_file_path = CBytes('config.in.ini', config=True, help="""Path to a
        configuration file to load from. The configuration files contains settings
        for all of the MD options. Every option can be either set in the config
        file and/or the command line. (see `--help-all`).""")
    config_file_out = CBytes('config.out.ini', config=True, help="""Write a
        config file containing all of the active options used by this
        simulation.""")
    show_script = CBool(True, config=True, help="""Print a script that can be
        used to run this simulation using the OpenMM python API. This is useful
        for learning the API, and as a starting point for producing more complex
        OpenMM applications for making further customizations.""")
    def _show_script_default(self):
        """Return False when ``log_level`` is a quiet level, True otherwise.

        Presumably picked up by the traitlets machinery as the dynamic
        default of ``show_script`` (naming convention) -- TODO confirm.
        """
        # NOTE(review): the numeric entries are 30/40 while the names are
        # ERROR/CRITICAL (40/50 in the logging module) -- confirm which set
        # of levels is intended.
        quiet_levels = [30, 40, 'ERROR', 'CRITICAL']
        return self.log_level not in quiet_levels
    def _log_level_default(self):
        """Return ``logging.INFO``.

        Presumably used by the traitlets machinery as the dynamic default of
        ``log_level`` (naming convention) -- TODO confirm.
        """
        return logging.INFO

    _script_initialized = CBool(False)

    # The alias table gives all of the options that are shown on -h.
    # The other options from the general, system, dynamics and simulation
    # configurables are only shown to the user on --help-all.
    aliases = {'log_level': 'OpenMM.log_level',
               'out': 'OpenMM.config_file_out',
               'config': 'OpenMM.config_file_path',
               'script': 'OpenMM.show_script'}

    def initialize(self, argv=None):
        """Set the application up from the command line / config file.

        Runs the parent class initialization with ``argv``, instantiates the
        configured classes and validates the result. A ``TraitError`` or
        ``AliasError`` raised by any of these steps is handed to
        ``self.error`` instead of propagating.
        """
        try:
            super(OpenMM, self).initialize(argv)
            self.initialize_configured_classes()
            self.validate()
        except (AliasError, TraitError) as config_problem:
            self.error(config_problem)

    def validate(self):
        """Validate the complete configuration tree.

        The parent implementation is called first (it delegates to each
        AppConfigurable for its within-class checks); afterwards the checks
        that span several AppConfigurables are run here.

        Raises
        ------
        TraitError
            If the timestep is too large for the chosen integrator and
            constraints, if the PME error tolerance is too small for the
            selected precision, if periodic boundaries are combined with
            implicit solvent, or if a barostat is combined with a
            nonperiodic nonbonded method.
        """
        super(OpenMM, self).validate()
        self.log.debug('Running global options validations.')

        # Largest recommended timestep (in fs) for the fixed-step
        # integrators, chosen by the kind of constraints in use.
        if self.dynamics.integrator in ['Langevin', 'Verlet', 'VelocityVerlet']:
            constraints = self.system.constraints
            if constraints is None:
                max_fs = 1
                advice = ('Langevin or Verlet integrators, without constraints a '
                          'timestep over 1 femtosecond is not recommended.')
            elif constraints in ['HBonds', 'AllBonds']:
                max_fs = 2
                advice = ('Langevin or Verlet integrators and bond constraints, a '
                          'timestep over 2 femtoseconds is not recommended.')
            elif constraints == 'HAngles':
                max_fs = 4
                advice = ('Langevin or Verlet integrators and HAngle constraints, a '
                          'timestep over 4 femtoseconds is not recommended.')
            else:
                max_fs = None
            if max_fs is not None and self.dynamics.dt > max_fs*unit.femtoseconds:
                raise TraitError('You are likely using too large a timestep. With the '
                                 + advice)

        general = self.general
        reduced_precision = (general.platform != 'Reference' and
                             general.precision in ['Single', 'Mixed'])
        if (reduced_precision and self.system.nb_method == 'PME'
                and self.system.ewald_tol < 5e-5):
            raise TraitError('Your ewald error tolerance is so low that is numerical '
                             'error is likely to cause the forces to become less accurate, '
                             'not more. Very small error tolerances only work in double '
                             'precision. (This only applies to PME. Ewald has no problem '
                             'with them.')

        if general.water == 'Implicit':
            if self.system.nb_method in ['CutoffPeriodic', 'Ewald', 'PME']:
                raise TraitError('Using periodic boundary conditions with implict solvent? '
                                 'That\'s a very strange choice.  You don\'t really want '
                                 'periodic boundary conditions with implicit solvent, do you?')

        if self.dynamics.barostat is not None:
            if self.system.nb_method in ['NoCutoff', 'CutoffNonPeriodic']:
                raise TraitError("It doesn't make sense to use a barostat with no cutoffs, "
                                 "since %s implies you're using a nonperiodic system. But adjusting "
                                 "the box volume (the way that the barostat controls the pressure) "
                                 "will have no effect." % self.system.nb_method)

    def validate_system(self):
        """Run validation on the force field and build the option dictionary
        that gets passed to createSystem().

        Problems are reported through ``self.error``. As a side effect,
        ``self.system.aewald_pmegrid`` is set to True when the user specified
        both ``aewald`` and ``pme_grid`` with the AMOEBA force field.

        Returns
        -------
        option_dict : OrderedDict
            Keyword arguments for the ``createSystem()`` call.
        """
        system = self.system
        specified = system.specified_config_traits

        if system.is_amoeba_ff:
            # When no platform was requested explicitly, the relevant one is
            # the fastest available platform, so that is the one to check.
            platform_name = self.general.platform
            if platform_name == 'NotSpecified':
                platform_name = self.general.fastest_platform
            if platform_name not in ['Reference', 'CUDA']:
                self.error("The AMOEBA force field is only implemented on the Reference or CUDA platforms.")

        # The option dictionary that gets passed to createSystem() is initialized.
        option_dict = OrderedDict([('nonbondedMethod', getattr(app, system.nb_method)),
                                   ('constraints', getattr(app, system.constraints)),
                                   ('rigidWater', system.rigid_water)])

        # Set the nonbonded cutoff if we're using nonbonded cutoffs.
        if system.nb_method in ['CutoffNonPeriodic', 'CutoffPeriodic', 'Ewald', 'PME']:
            option_dict['nonbondedCutoff'] = system.cutoff

        # Set the dispersion correction if we have periodic boundary conditions
        # (not passed along when the system comes from a prmtop or gmxtop file).
        if (system.nb_method in ['CutoffPeriodic', 'Ewald', 'PME']
                and not system.is_prmtop and not system.is_gmxtop):
            option_dict['useDispersionCorrection'] = system.disp_corr

        if system.is_amoeba_ff:
            # Set AMOEBA specific options.
            if 'vdw_cutoff' in specified:
                option_dict['vdwCutoff'] = system.vdw_cutoff
            option_dict['polarization'] = system.polarization
            # Set polar_eps if using mutual polarization
            if system.polarization == 'mutual':
                option_dict['polar_eps'] = system.polar_eps
            elif 'polar_eps' in specified:
                self.error("The polar_eps option is used in mutual polarization only.")
            # The logic here allows the user to override ewald_tol if they specify both aewald and pme_grid.
            # However, if all three are specified, it errors out.
            grid_options = ['aewald', 'pme_grid']
            if any(name in specified for name in grid_options):
                if not all(name in specified for name in grid_options):
                    self.error("To use aewald/pme_grid with the AMOEBA force field, both options must be specified.")
                else:
                    system.aewald_pmegrid = True
                    if 'ewald_tol' in specified:
                        self.error("You may specify either aewald+pme_grid or ewald_tol, but you cannot specify all three.")
        else:
            # Crash if the user specifies AMOEBA options without the AMOEBA force field.
            for name in ['aewald', 'pme_grid', 'vdw_cutoff', 'polarization', 'polar_eps']:
                if name in specified:
                    # The offending option name is interpolated into the message.
                    self.error("The %s option is only valid when using the AMOEBA polarizable force field." % name)

        if system.nb_method in ['Ewald', 'PME']:
            if system.is_amoeba_ff and system.aewald_pmegrid:
                # Explicit grid and alpha replace the error tolerance.
                option_dict['pmeGridDimensions'] = [int(i) for i in system.pme_grid]
                option_dict['aEwald'] = system.aewald
            else:
                option_dict['ewaldErrorTolerance'] = system.ewald_tol

        return option_dict
        
    def start(self):
        """Build the simulation from the configured options and run it.

        In order: load the topology and positions; obtain a System (from a
        system XML file, an AMBER prmtop, a GROMACS top, or force field XML
        files); optionally serialize the System to disk; create the
        integrator, platform and Simulation; initialize the state (from the
        restart file, or from the input positions with optional minimization
        and velocity assignment); attach the progress, trajectory and
        restart reporters; and finally take ``n_steps`` steps. The OpenMM
        API calls are echoed through ``self.script`` along the way.
        """
        topology = self.general.get_topology()
        positions = self.general.get_positions()

        # Set up the system from system XML file.
        if 'sysxml' in self.general.specified_config_traits:
            self.system.from_sysxml = True
            self.dynamics.from_sysxml = True
            system = self.general.get_system_from_sysxml()

        else:
            # Set up the system from AMBER prmtop file.
            if 'prmtop' in self.general.specified_config_traits:
                prmtop = self.general.prmtop_file
                self.system.is_prmtop = True

            # Set up the system from GROMACS top file.
            elif 'gmxtop' in self.general.specified_config_traits:
                gmxtop = self.general.gmxtop_file
                self.system.is_gmxtop = True

            # Set up the system from force field XML files.
            elif len(self.general.ffxml) > 0 or any(i in self.general.specified_config_traits for i in ['protein', 'water']):
                forcefield, self.system.is_amoeba_ff = self.general.get_forcefield()
                modeller = app.Modeller(topology, positions)
                modeller.addExtraParticles(forcefield)
                topology = modeller.topology
                positions = modeller.positions

            # Perform a second validation step and generate options for setting up the system.
            system_options = self.validate_system()

            # Need to create a second dictionary for printing because of "enums":
            # the script should show the option names, not the enum objects.
            print_options = OrderedDict(system_options.items())
            print_options['nonbondedMethod'] = self.system.nb_method
            print_options['constraints'] = self.system.constraints
            script_kwargs = ','.join(["%s=%s" % (key, val) for key, val in print_options.items()])

            # Create the system object.
            if 'prmtop' in self.general.specified_config_traits:
                self.script('system = prmtop.createSystem(' + script_kwargs + ')')
                system = prmtop.createSystem(**system_options)

            elif 'gmxtop' in self.general.specified_config_traits:
                self.script('system = gmxtop.createSystem(' + script_kwargs + ')')
                system = gmxtop.createSystem(**system_options)

            elif len(self.general.ffxml) > 0 or any(i in self.general.specified_config_traits for i in ['protein', 'water']):
                self.script('system = forcefield.createSystem(topology,' + script_kwargs + ')')
                system = forcefield.createSystem(topology, **system_options)

            else:
                # NOTE(review): the trait checked above is 'water'; confirm
                # that "--forcefield" is the right flag name in this message.
                self.error("You did not provide enough information to create "
                           "the System object!  Valid options are:\n"
                           "(1) Specify a protein force field and/or water model using "
                           "--protein and --forcefield arguments\n"
                           "(2) Specify a force field XML file using --ffxml argument\n"
                           "(3) Specify a GROMACS or AMBER prmtop file using --gmxtop or --prmtop argument ")

            # Add thermostat and barostat forces.
            for force in self.dynamics.get_forces():
                system.addForce(force)

        # Serialize the system to an XML file if requested.
        if 'serialize' in self.general.specified_config_traits:
            self.script("serial = mm.XmlSerializer.serializeSystem(system)")
            serial = mm.XmlSerializer.serializeSystem(system)
            backup_file(self.general.serialize, self.log)
            # %r so that the file name appears as a quoted literal in the script.
            self.script("with open(%r, 'w') as f: f.write(serial)" % self.general.serialize)
            with open(self.general.serialize, 'w') as f:
                f.write(serial)

        integrator = self.dynamics.get_integrator()
        # This local shadows the stdlib ``platform`` module inside this method only.
        platform = self.general.get_platform()
        properties = self.general.get_platform_properties()

        self.print_config()
        self.generate_config_file()

        self.script('simulation = app.Simulation(topology, system, integrator, platform, properties)')
        simulation = app.Simulation(topology, system, integrator, platform, properties)

        if self.simulation.read_restart:
            self.log.info("Restarting simulation by reading from %s." % self.simulation.restart_file)
            loadRestartFile(simulation, self.simulation.restart_file)
        else:
            self.script('simulation.context.setPositions(positions)')
            simulation.context.setPositions(positions)

            if self.simulation.minimize:
                self.script('simulation.minimizeEnergy()')
                simulation.minimizeEnergy()

            if self.system.rand_vels:
                # Echo the same argument that is passed to the real call.
                self.script('simulation.context.setVelocitiesToTemperature(%s)'
                            % self.system.gen_temp)
                simulation.context.setVelocitiesToTemperature(self.system.gen_temp)

        if self.simulation.progress_freq > 0:
            self.script('simulation.reporters.append(ProgressReporter(sys.stdout, %s, %s))'
                        % (self.simulation.progress_freq, self.simulation.n_steps))
            simulation.reporters.append(ProgressReporter(sys.stdout,
                self.simulation.progress_freq, self.simulation.n_steps))

        if self.simulation.traj_freq > 0:
            backup_file(self.simulation.traj_file, self.log)
            # The script only imports ``app``, so the reporter is referenced
            # through it; the file name is echoed as a quoted literal.
            self.script('simulation.reporters.append(app.DCDReporter(%r, %s))'
                        % (self.simulation.traj_file, self.simulation.traj_freq))
            simulation.reporters.append(app.DCDReporter(self.simulation.traj_file,
                self.simulation.traj_freq))

        if self.simulation.write_restart and self.simulation.restart_freq > 0:
            backup_file(self.simulation.restart_file, self.log)
            self.log.info("Will write restart information every %i steps to %s."
                          % (self.simulation.restart_freq, self.simulation.restart_file))
            self.script('simulation.reporters.append(RestartReporter(%r, %s))'
                        % (self.simulation.restart_file, self.simulation.restart_freq))
            simulation.reporters.append(RestartReporter(self.simulation.restart_file, self.simulation.restart_freq))

        self.script('simulation.step(%s)' % self.simulation.n_steps)
        if self.show_script:
            # Blank line after the script. With print_function in effect a
            # bare ``print`` would be a no-op, so print('') is required.
            print('')

        self.log.info('Number of available platforms: %d' % mm.Platform.getNumPlatforms())
        self.log.info('Selected Platform: %s', platform.getName())
        for key in platform.getPropertyNames():
            self.log.info('%s = %s', key, platform.getPropertyValue(simulation.context, key))
        print('')

        # Make every reporter report once before the first step is taken.
        force_reporters(simulation)
        simulation.step(self.simulation.n_steps)

        # before exiting, write a restart file
        force_reporters(simulation, RestartReporter)
        print("#=================================================#")
        print("#| Congratulations, your simulation has finished |#")
        print("#|      And if you don't know, now you know!     |#")
        print("#=================================================#")

    def script(self, msg):
        """Print one line of the equivalent OpenMM python script.

        Does nothing unless ``show_script`` is set. The text is wrapped at
        80 columns and formatted like interpreter input: the first line is
        prefixed with '>>> ' and continuation lines with '... '. On the
        first call a banner with usage instructions and the import lines is
        printed before the message.

        Parameters
        ----------
        msg : str
            The python statement to show.
        """
        if not self.show_script:
            return

        def c(line):
            # Wrap the statement and add the interpreter prompts.
            wrapped = wrap_paragraphs(line, 80)[0].splitlines()
            if not wrapped:
                wrapped = ['']
            prompted = ['>>> ' + wrapped[0]]
            prompted.extend('... ' + rest for rest in wrapped[1:])
            return os.linesep.join(prompted)

        if not self._script_initialized:
            lines = [
                '###################################################################',
                '# To use this script, paste all of the lines starting with the',
                '# three greater than signs (>>>) into a python interpreter.',
                '###################################################################',
                ' ',
                'from simtk.unit import *',
                'from simtk import openmm as mm',
                'from simtk.openmm import app',
                ' ']
            # Blank line before the banner. With print_function in effect a
            # bare ``print`` would be a no-op, so print('') is required.
            print('')
            print(os.linesep.join(map(c, lines)))
            self._script_initialized = True

        print(c(msg))

    def generate_config_file(self):
        """Write a configuration file containing all of the active options
        to ``self.config_file_out``.

        Nothing is written when ``config_file_out`` is empty. An existing
        file with the same name is passed to ``backup_file`` before the new
        one is written. The file starts with a comment header (host name,
        time, OpenMM version, command line) followed by the config section
        of each of the general, system, dynamics and simulation
        configurables.
        """
        # Truthiness instead of ``== ''``: an empty bytes value (b'') never
        # compares equal to '' on python 3, but it is still "no file".
        if not self.config_file_out:
            return

        backup_file(self.config_file_out, self.log)

        # create the config file
        lines = ['# Configuration file for openmm',
                 '# Generated on %s, %s' % (platform.node(), datetime.now()),
                 '# OpenMM version %s' % mm.Platform.getOpenMMVersion(),
                 '# Invocation command line: %s' % ' '.join(sys.argv),
                 '']
        for section in (self.general, self.system, self.dynamics, self.simulation):
            lines.append(section.config_section())

        self.log.info('Writing config file to %s' % self.config_file_out)
        with open(self.config_file_out, 'w') as f:
            print('\n'.join(lines), file=f)


    def print_config(self):
        """Print a two-column summary of the active options to stdout.

        For each configured class that has a matching lower-case attribute
        on this application, a ``[ClassName]`` heading is printed followed
        by one ``name = value`` line per active trait, annotated with
        whether the value was specified by the user or is the default.
        Nothing is printed unless the logger is enabled for INFO.
        """
        if not self.log.isEnabledFor(logging.INFO):
            # we need to be printing INFO or better to print the config
            return

        # Each row is (option text, trailing comment).
        rows = []
        for configurable in self.classes:
            section = getattr(self, configurable.__name__.lower(), None)
            if section is None:
                continue

            rows.append(('\n[%s]' % section.__class__.__name__, ''))
            for name in section.active_config_traits():
                value = getattr(section, name)
                if isinstance(value, list):
                    # lists are shown as space-separated items
                    value = ' '.join([str(item) for item in value])

                entry = '%s = %s' % (name, value)
                if name in section.specified_config_traits:
                    rows.append((entry, '# your selection'))
                else:
                    rows.append((entry, '# default value'))

        # Pad the first column so that the comments line up.
        width = max(len(option) for option, _ in rows)

        header = '# Option Summary. A more detailed config file is saved to %s' % self.config_file_out
        breaker = '#' * len(header)
        print(os.linesep.join([breaker, header, breaker]))
        print(os.linesep.join(option.ljust(width) + '  ' + note for option, note in rows))
        print('')


#----------------------------------------------------------------------------
# Utilities
#----------------------------------------------------------------------------

def backup_file(fnm, logger):
    """Move aside a file that is about to be overwritten.

    If ``fnm`` (basename.ext) exists, it is moved to basename_#.ext, where #
    is the first positive integer for which no such file exists yet. If
    ``fnm`` does not exist, nothing happens.

    Parameters
    ----------
    fnm : str or bytes
        Path of the file that is about to be overwritten. A bytes path is
        decoded with the file system encoding first.
    logger : object with an ``info(msg)`` method, e.g. a logging.Logger
        Used to report the rename.
    """
    # On python 3 a bytes path cannot go through the %-formatting below: it
    # would be rendered as its b'...' repr. On python 2 bytes is str, so
    # this branch is never taken there.
    if isinstance(fnm, bytes) and not isinstance(fnm, str):
        fnm = fnm.decode(sys.getfilesystemencoding() or 'utf-8')

    if not os.path.exists(fnm):
        return

    base, ext = os.path.splitext(fnm)
    i = 1
    backup = "%s_%i%s" % (base, i, ext)
    while os.path.exists(backup):
        i += 1
        backup = "%s_%i%s" % (base, i, ext)

    logger.info("Backing up %s -> %s" % (fnm, backup))
    shutil.move(fnm, backup)


def force_reporters(simulation, reporter_class=None):
    """Make the reporters attached to a simulation report immediately.

    Each selected reporter is asked what it needs through
    ``describeNextReport(simulation)``; everything after the first element
    of the returned sequence is OR-ed position-wise into four flags. A
    single state is then requested from the context with those flags plus
    ``True`` and a flag telling whether the topology has unit cell
    dimensions, and is handed to every selected reporter.
    (Presumably the four flags are positions/velocities/forces/energy and
    the trailing two are getParameters/enforcePeriodicBox -- verify against
    the OpenMM version in use.)

    Parameters
    ----------
    simulation : Simulation
        The simulation to which the reporters are attached
    reporter_class : Reporter or None
        If supplied, only reporters that are instances of reporter_class will
        be triggered.
    """
    selected = simulation.reporters
    if reporter_class is not None:
        selected = [r for r in selected if isinstance(r, reporter_class)]

    if not selected:
        # no reporter to trigger, so no state needs to be computed
        return

    wanted = [False, False, False, False]
    for reporter in selected:
        request = reporter.describeNextReport(simulation)[1:]
        wanted = [(have or need) for have, need in zip(wanted, request)]

    has_box = simulation.topology.getUnitCellDimensions() is not None
    state = simulation.context.getState(*(wanted + [True, has_box]))

    for reporter in selected:
        reporter.report(simulation, state)


if __name__ == '__main__':
    # Script entry point: obtain the application object, let it read the
    # command line / config file, then run the simulation.
    application = OpenMM.instance()
    application.initialize()
    application.start()
