Allow torch model to be used as Integrator #56

Open
jchodera opened this issue Jan 15, 2022 · 5 comments
Labels
enhancement New feature or request


@jchodera
Member

jchodera commented Jan 15, 2022

I wonder if we could also support torch models used as Integrators in OpenMM.

Perhaps something like this could work:

import torch

class IntegratorModule(torch.nn.Module):
    """A BAOAB Langevin integrator (this sketch shows only the deterministic velocity Verlet update)"""
    def forward(self, positions, velocities, masses, timestep, temperature):
        """The forward method returns the updated positions and velocities given the current timestep and temperature.

        Parameters
        ----------
        positions : torch.Tensor with shape (nparticles,3)
           positions[i,k] is the position (in nanometers) of spatial dimension k of particle i
        velocities : torch.Tensor with shape (nparticles,3)
           velocities[i,k] is the velocity (in nanometers/picosecond) of spatial dimension k of particle i
        masses : torch.Tensor with shape (nparticles,3)
           masses[i,k] is the mass (in amu) of particle i for all k
        timestep : torch.Tensor with shape ()
            the integration timestep (in picoseconds, for consistency with the velocity units)
        temperature : torch.Tensor with shape ()
            the temperature (in kelvin)

        Returns
        -------
        positions : torch.Tensor with shape (nparticles,3)
           positions[i,k] is the position (in nanometers) of spatial dimension k of particle i
        velocities : torch.Tensor with shape (nparticles,3)
           velocities[i,k] is the velocity (in nanometers/picosecond) of spatial dimension k of particle i
        """
        forces = openmm_compute_forces(positions)
        velocities = velocities + (timestep/2) * (forces/masses)
        positions = positions + timestep * velocities
        forces = openmm_compute_forces(positions)
        velocities = velocities + (timestep/2) * (forces/masses)

        return positions, velocities

# Render the compute graph to a TorchScript module
module = torch.jit.script(IntegratorModule())

# Serialize the compute graph to a file
module.save('integrator.pt')

Here, we would have to extend TorchScript with custom ops openmm_compute_forces, openmm_compute_potential, and openmm_compute_potential_and_forces, which would wrap the normal OpenMM energy/force computation. Ideally, these C++ functions would know when the force or potential does not need to be recomputed (because no particles have moved), for example when they are called at the end of one step and again at the beginning of the next step.
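
The caching idea can be illustrated without OpenMM or TorchScript. The following is only a minimal pure-Python sketch: CachedForceEvaluator, harmonic_force, and velocity_verlet_step are hypothetical names, plain lists stand in for tensors, and a toy spring force stands in for the real OpenMM force computation.

```python
from typing import Callable, List, Optional, Sequence, Tuple


class CachedForceEvaluator:
    """Hypothetical wrapper that recomputes forces only when positions change."""

    def __init__(self, force_fn: Callable[[Sequence[float]], List[float]]):
        self._force_fn = force_fn
        self._cached_positions: Optional[Tuple[float, ...]] = None
        self._cached_forces: List[float] = []
        self.evaluations = 0  # number of real force computations performed

    def compute_forces(self, positions: Sequence[float]) -> List[float]:
        key = tuple(positions)
        if key != self._cached_positions:
            # Positions moved since the last call, so the cache is stale.
            self._cached_forces = self._force_fn(positions)
            self._cached_positions = key
            self.evaluations += 1
        return list(self._cached_forces)


def harmonic_force(positions: Sequence[float]) -> List[float]:
    """Toy force field (an assumption for this sketch): F = -x per coordinate."""
    return [-x for x in positions]


def velocity_verlet_step(evaluator, positions, velocities, masses, dt):
    """Same update order as the forward() sketch above, on plain lists."""
    forces = evaluator.compute_forces(positions)
    velocities = [v + 0.5 * dt * f / m for v, f, m in zip(velocities, forces, masses)]
    positions = [x + dt * v for x, v in zip(positions, velocities)]
    forces = evaluator.compute_forces(positions)
    velocities = [v + 0.5 * dt * f / m for v, f, m in zip(velocities, forces, masses)]
    return positions, velocities


evaluator = CachedForceEvaluator(harmonic_force)
x, v, m = [1.0, -0.5], [0.0, 0.2], [1.0, 2.0]
n_steps = 10
for _ in range(n_steps):
    x, v = velocity_verlet_step(evaluator, x, v, m, dt=0.01)
```

In this sketch the force call at the start of every step after the first hits the cache, so the number of real evaluations is one per step plus one for the initial positions.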

To use the integrator in a simulation, the user would create a TorchIntegrator object that would behave much like a normal Integrator:

import openmm
from openmm import unit
timestep = 4.0*unit.femtoseconds
temperature = 300*unit.kelvin

# Create the TorchIntegrator from the serialized compute graph
from openmmtorch import TorchIntegrator
torch_integrator = TorchIntegrator('integrator.pt', temperature, timestep)

# Create a Context with the integrator (system is an existing openmm.System)
context = openmm.Context(system, torch_integrator)

# Run some dynamics
torch_integrator.step(100)

# Change the temperature and timestep
torch_integrator.setTemperature(100*unit.kelvin)
torch_integrator.setTimestep(2.0*unit.femtoseconds)

Edit: It would also be important to enable the integrator to modify global parameters, as well as define its own parameters that can be accessed through the OpenMM API. I'm not quite sure how that would work, however.

@peastman
Member

That assumes a different computational model from how OpenMM works. For example, it has no concept of computing forces for an arbitrary set of positions. It can only compute them for the current positions that are set in the context.

Instead, it would be best to define functions that closely map to the operations in CustomIntegrator: get and set positions, get and set velocities, compute forces and/or energy for the current positions, apply constraints, etc.
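
The difference between the two computational models can be sketched in plain Python. ToyContext and forces_at below are hypothetical stand-ins, not the OpenMM API, and the harmonic force is an assumption made only to keep the example self-contained. The point is that forces exist only for the positions stored in the context, so evaluating them elsewhere means setting, computing, and restoring.

```python
class ToyContext:
    """Hypothetical stand-in for an OpenMM Context (not the real API).

    It owns the state, and forces are only defined for the positions
    currently stored in it.
    """

    def __init__(self, positions, stiffness=1.0):
        self._positions = list(positions)
        self._stiffness = stiffness

    def get_positions(self):
        return list(self._positions)

    def set_positions(self, positions):
        self._positions = list(positions)

    def compute_forces(self):
        # Toy harmonic force F = -k x, always for the stored positions.
        return [-self._stiffness * x for x in self._positions]


def forces_at(context, trial_positions):
    """Evaluate forces at trial positions by set / compute / restore."""
    saved = context.get_positions()
    context.set_positions(trial_positions)
    try:
        return context.compute_forces()
    finally:
        # Put the original positions back so the caller's state is untouched.
        context.set_positions(saved)


context = ToyContext([1.0, -2.0])
trial_forces = forces_at(context, [0.5, 0.25])
```

An interface built from get, set, and compute operations on the current state avoids this save and restore round trip, which is why it maps more directly onto how a context works.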

@raimis raimis added the enhancement New feature or request label Jan 18, 2022
@raimis
Contributor

raimis commented Jan 18, 2022

It would be possible to implement support for models like this:

class IntegratorModule(torch.nn.Module):
    def forward(self, positions, velocities, forces, arbitrary_number_of_scalars):

        # Do some computation

        return new_positions, new_velocities

@peastman
Member

def forward(self, positions, velocities, forces, arbitrary_number_of_scalars):

That assumes you only need the forces at the start of the step. Integrators often need the forces at multiple points throughout the step. Here's an example of a velocity Verlet integrator implemented with CustomIntegrator:

CustomIntegrator integrator(0.001);
integrator.addPerDofVariable("x1", 0);
integrator.addUpdateContextState();
integrator.addComputePerDof("v", "v+0.5*dt*f/m");
integrator.addComputePerDof("x", "x+dt*v");
integrator.addComputePerDof("x1", "x");
integrator.addConstrainPositions();
integrator.addComputePerDof("v", "v+0.5*dt*f/m+(x-x1)/dt");
integrator.addConstrainVelocities();
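
Written out as equations, the steps above read roughly as follows. This is a sketch under stated assumptions: the symbols are introduced here for illustration, with \tilde{x} playing the role of x1, \mathcal{C} denoting the position constraint projection, and f in the last line taken to be the force at the constrained positions, since those are the current positions at that point.

```latex
\begin{align*}
v_{n+1/2} &= v_n + \tfrac{\Delta t}{2}\,\frac{f(x_n)}{m} \\
\tilde{x} &= x_n + \Delta t\, v_{n+1/2} \\
x_{n+1} &= \mathcal{C}(\tilde{x}) \\
v_{n+1} &= v_{n+1/2} + \tfrac{\Delta t}{2}\,\frac{f(x_{n+1})}{m} + \frac{x_{n+1} - \tilde{x}}{\Delta t}
\end{align*}
```

The last term corrects the velocity for the displacement applied by the position constraints. Without constraints, x_{n+1} = \tilde{x}, the term vanishes, and the scheme reduces to plain velocity Verlet. The final addConstrainVelocities() call then projects v_{n+1} onto the velocity constraints.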

I imagine that the PyTorch implementation might look something like this:

class VelocityVerletIntegrator(openmmtorch.Integrator):
  def forward(self):
    dt = self.dt
    m = self.m
    self.updateContextState()
    self.v = self.v+0.5*dt*self.f/m
    x1 = self.x+dt*self.v
    self.x = x1
    self.constrainPositions()
    self.v = self.v+0.5*dt*self.f/m+(self.x-x1)/dt
    self.constrainVelocities()

I assumed that openmmtorch.Integrator is a subclass of torch.nn.Module, and that it defines lots of properties like x and f, which are implemented as functions that get or set information in the Context.
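
The property idea can be tried out in plain Python. In this minimal sketch, OscillatorContext and ContextBoundIntegrator are hypothetical stand-ins for the Context and for openmmtorch.Integrator, lists replace tensors, the force is a toy unit spring, and constraints are left out.

```python
import math


class OscillatorContext:
    """Hypothetical toy context: stores the state and computes F = -x."""

    def __init__(self, positions, velocities, masses):
        self.positions = list(positions)
        self.velocities = list(velocities)
        self.masses = list(masses)

    def compute_forces(self):
        # Forces always refer to the positions currently stored here.
        return [-x for x in self.positions]


class ContextBoundIntegrator:
    """Hypothetical base class exposing x, v, f, m as context-backed properties."""

    def __init__(self, context, dt):
        self._context = context
        self.dt = dt

    @property
    def x(self):
        return list(self._context.positions)

    @x.setter
    def x(self, value):
        self._context.positions = list(value)

    @property
    def v(self):
        return list(self._context.velocities)

    @v.setter
    def v(self, value):
        self._context.velocities = list(value)

    @property
    def f(self):
        return self._context.compute_forces()

    @property
    def m(self):
        return list(self._context.masses)

    def step(self, n_steps):
        for _ in range(n_steps):
            self.forward()


class VelocityVerlet(ContextBoundIntegrator):
    def forward(self):
        dt, m = self.dt, self.m
        self.v = [v + 0.5 * dt * f / mi for v, f, mi in zip(self.v, self.f, m)]
        self.x = [x + dt * v for x, v in zip(self.x, self.v)]
        # self.f now reflects the updated positions stored in the context.
        self.v = [v + 0.5 * dt * f / mi for v, f, mi in zip(self.v, self.f, m)]


context = OscillatorContext([1.0], [0.0], [1.0])
integrator = VelocityVerlet(context, dt=0.01)
integrator.step(1000)

final_x = context.positions[0]
final_v = context.velocities[0]
energy = 0.5 * final_v**2 + 0.5 * final_x**2
exact_x = math.cos(1000 * 0.01)  # analytic solution x(t) = cos(t) for this oscillator
```

Because every read of f goes back to the context, the second half-kick automatically uses forces for the updated positions, which is the behavior the forward(positions, velocities, forces, ...) signature cannot express.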

@raimis
Contributor

raimis commented Jan 19, 2022

Under the hood, PyTorch constructs a computational graph, which represents the operations and associated input-output dependencies.

It is probably possible to wrap openmm::Context into torch::CustomClass and let PyTorch act on it:

class Integrator(openmmtorch.Integrator):

   def forward(self, state):
     dt = self.dt
     m = self.m

     state = updateState(state)
     v = getVelocities(state)
     f = getForces(state)
     x = getPositions(state)
     v = v+0.5*dt*f/m
     x1 = x+dt*v
     state = setPositions(state, x1)
     state = constrainPositions(state)
     x = getPositions(state)
     f = getForces(state)
     v = v+0.5*dt*f/m+(x-x1)/dt
     state = setVelocities(state, v)
     state = constrainVelocities(state)

     return state
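
The state-passing style can be mimicked in plain Python with an immutable record. This is only a sketch under assumptions: State and the helper functions below are hypothetical, tuples stand in for tensors, a toy spring force replaces the real force computation, and the constraint and updateState calls are omitted. Each helper returns a new state instead of mutating the old one, so the data dependencies between operations stay explicit.

```python
from dataclasses import dataclass, replace
from typing import Tuple


@dataclass(frozen=True)
class State:
    """Hypothetical immutable snapshot standing in for a wrapped Context."""
    positions: Tuple[float, ...]
    velocities: Tuple[float, ...]
    masses: Tuple[float, ...]


def get_positions(state: State) -> Tuple[float, ...]:
    return state.positions


def get_velocities(state: State) -> Tuple[float, ...]:
    return state.velocities


def get_forces(state: State) -> Tuple[float, ...]:
    # Toy harmonic force F = -x for the positions held in this state.
    return tuple(-x for x in state.positions)


def set_positions(state: State, positions) -> State:
    return replace(state, positions=tuple(positions))


def set_velocities(state: State, velocities) -> State:
    return replace(state, velocities=tuple(velocities))


def forward(state: State, dt: float) -> State:
    """One unconstrained velocity Verlet step in state-passing style."""
    m = state.masses
    v = get_velocities(state)
    f = get_forces(state)
    x = get_positions(state)
    v = tuple(vi + 0.5 * dt * fi / mi for vi, fi, mi in zip(v, f, m))
    x1 = tuple(xi + dt * vi for xi, vi in zip(x, v))
    state = set_positions(state, x1)
    f = get_forces(state)  # forces for the new positions
    v = tuple(vi + 0.5 * dt * fi / mi for vi, fi, mi in zip(v, f, m))
    state = set_velocities(state, v)
    return state


initial = State(positions=(1.0,), velocities=(0.0,), masses=(1.0,))
updated = forward(initial, dt=0.1)
```

Here initial is left untouched and updated carries the result of the step. Whether a real wrapper around a mutable Context could offer the same guarantee is an open design question.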

@raimis
Contributor

raimis commented Jan 19, 2022

@jchodera, do you have some examples that go beyond the capabilities of CustomIntegrator? How much more flexibility do we want?
