
How to create a CompoundThermodynamicState from a torchForce object with global parameter #147

Open
xiaowei-xie2 opened this issue Jun 24, 2024 · 3 comments


@xiaowei-xie2

Hi,

I would like to run Hamiltonian REMD with custom-defined states, where each state is specified by a TorchForce object with different values of its global parameters. However, I am having trouble creating a CompoundThermodynamicState to use with ReplicaExchangeSampler. I can create a GlobalParameterState, but when I use it to create a CompoundThermodynamicState, it complains that the global parameter does not exist in the system. I have no trouble doing the same thing with an MM force field. Any idea what might be going wrong?

Thank you!

Here is the structure of the code I was using:

```python
from openmm import MonteCarloBarostat, unit
from openmmtorch import TorchForce
from openmmtools.states import (
    CompoundThermodynamicState,
    GlobalParameterState,
    ThermodynamicState,
)

force = TorchForce('model.pt')
force.addGlobalParameter('a', 0.5)
force.addGlobalParameter('b', 0.3)
force.setUsesPeriodicBoundaryConditions(True)

# define system
system = ...

# Remove MM constraints
while system.getNumConstraints() > 0:
    system.removeConstraint(0)

# Remove MM forces
while system.getNumForces() > 0:
    system.removeForce(0)

assert system.getNumConstraints() == 0
assert system.getNumForces() == 0

# The ML potential is now the only energy term
system.addForce(force)

barostat = MonteCarloBarostat(1 * unit.bar, 298.15 * unit.kelvin)
system.addForce(barostat)

class LambdaState(GlobalParameterState):
    a = GlobalParameterState.GlobalParameter('a', standard_value=1.0)
    b = GlobalParameterState.GlobalParameter('b', standard_value=1.0)

    def set_rest_parameters(self, value_a, value_b):
        """Set the defined lambda parameters to the given values.

        The undefined parameters (i.e. those set to None) remain undefined.

        Parameters
        ----------
        value_a : float
            The new value for parameter 'a'.
        value_b : float
            The new value for parameter 'b'.
        """
        new_values = {'a': value_a, 'b': value_b}

        for parameter_name in self._parameters:
            # Only touch parameters that are defined in this state
            if self._parameters[parameter_name] is not None:
                setattr(self, parameter_name, new_values[parameter_name])


lambda_state = LambdaState(a=0.5, b=0.3)
print('lambda_state.a:', lambda_state.a)
print('lambda_state.b:', lambda_state.b)

thermostate = ThermodynamicState(system, temperature=298.15 * unit.kelvin)
compound_thermodynamic_state = CompoundThermodynamicState(thermostate, composable_states=[lambda_state])
```

And I am getting the following error:

```
lambda_state.a: 0.5
lambda_state.b: 0.3
Traceback (most recent call last):
  File "/scr/xie1/training_xtb_test/openmm_FEP_lambdastate_REMD.py", line 129, in <module>
    compound_thermodynamic_state = CompoundThermodynamicState(thermostate, composable_states=[lambda_state])
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xie1/miniconda3/lib/python3.12/site-packages/openmmtools/states.py", line 2790, in __init__
    self.set_system(self._standard_system, fix_state=True)
  File "/home/xie1/miniconda3/lib/python3.12/site-packages/openmmtools/states.py", line 2843, in set_system
    s.apply_to_system(system)
  File "/home/xie1/miniconda3/lib/python3.12/site-packages/openmmtools/states.py", line 3521, in apply_to_system
    raise self._GLOBAL_PARAMETER_ERROR(err_msg.format(parameter_name))
openmmtools.states.GlobalParameterError: Could not find global parameter a in the system.
```
@peastman
Member

Is this specific to TorchForce? What if you change `force = TorchForce('model.pt')` to `force = CustomBondForce('r')`, leaving everything else the same, including the calls to addGlobalParameter()? Do you get the same error? If so, the problem isn't related to TorchForce, and you should probably ask at https://github.com/choderalab/openmmtools. On the other hand, if that works and the problem really is related to TorchForce, can you post a complete example with all the files needed to reproduce it?

@xiaowei-xie2
Author

Thank you for the reply! I tried your suggestion, and the problem does seem to be specific to TorchForce. I created a simple example that reproduces this behavior (attached below). I also think I may have found a workaround, inspired by the openmmml package, which is to add the following lines:

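```python
import copy

import openmm

# `force` (the TorchForce) and `system` come from the script above.
# Wrap the TorchForce in a CustomCVForce, which is a built-in OpenMM class
# whose Python wrapper exposes getNumGlobalParameters().
cv = openmm.CustomCVForce("")
cv.addGlobalParameter("param_a", 1)
cv.addGlobalParameter("param_b", 1)
tempSystem = openmm.System()
tempSystem.addForce(force)
interactingVarNames = []
for idx, tempForce in enumerate(tempSystem.getForces()):
    name = f"allForce{idx+1}"
    cv.addCollectiveVariable(name, copy.deepcopy(tempForce))
    interactingVarNames.append(name)

assert len(interactingVarNames) > 0

# The CV energy is the sum of all wrapped forces
interactingSum = "+".join(interactingVarNames)

cv.setEnergyFunction(
    f"({interactingSum})"
)

system.addForce(cv)
```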

The attached openmm_files.tar.gz contains three scripts, mmforce.py, torchforce.py and torchforce_workaround.py, along with their corresponding outputs. Only the force object differs between the scripts.

Could you let me know whether my workaround is correct?

openmm_files.tar.gz

@peastman
Member

@mikemhenry @ijpulidos can you take a look at this? This error is happening because of an interaction between SWIG and openmmtools.

_get_system_controlled_parameters() tries to find the list of global parameters by looping over all forces and looking for methods called getNumGlobalParameters() and getGlobalParameterName().

```python
for force_index in range(system.getNumForces()):
    force = system.getForce(force_index)
    try:
        n_global_parameters = force.getNumGlobalParameters()
    except AttributeError:
        continue
    for parameter_id in range(n_global_parameters):
        parameter_name = force.getGlobalParameterName(parameter_id)
        if parameter_name in searched_parameters:
            yield force, parameter_name, parameter_id
```

The problem is that SWIG can only return the correct Python Force subclass from getForce() for built-in classes. If the force was defined by a plugin, it just returns an instance of the abstract Force class. That applies only to the Python wrapper, of course; the C++ object it wraps has the correct class. The TorchForce Python wrapper provides static isinstance() and cast() methods for checking whether something is a wrapped TorchForce and casting it to the correct Python class.
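The following self-contained sketch makes this failure mode concrete, and it runs without OpenMM. The classes `Force`, `TorchForce` and `System` are hypothetical stand-ins that only imitate the SWIG behavior described here: a plugin force comes back from `getForce()` as the bare base wrapper. `find_global_parameters()` is modelled loosely on the openmmtools loop quoted earlier. Its `casters` argument is an assumed extension for illustration, not an openmmtools option.

```python
# Hypothetical stand-ins: these are NOT the real openmm/openmmtorch classes.
class Force:
    """Imitates the abstract openmm.Force Python wrapper."""

    def __init__(self, cpp_object):
        # The wrapped "C++" object, which still knows its concrete class
        self._cpp = cpp_object


class TorchForce(Force):
    """Imitates a plugin wrapper that offers static isinstance()/cast()."""

    def getNumGlobalParameters(self):
        return len(self._cpp["globals"])

    def getGlobalParameterName(self, index):
        return self._cpp["globals"][index]

    @staticmethod
    def isinstance(force):
        return force._cpp["cpp_class"] == "TorchForce"

    @staticmethod
    def cast(force):
        return TorchForce(force._cpp)


class System:
    def __init__(self):
        self._forces = []

    def addForce(self, cpp_object):
        self._forces.append(cpp_object)

    def getNumForces(self):
        return len(self._forces)

    def getForce(self, index):
        # Like SWIG for a plugin force: only the base wrapper comes back
        return Force(self._forces[index])


def find_global_parameters(system, searched_parameters, casters=()):
    """Duck-typed search; `casters` holds hypothetical (isinstance, cast) pairs."""
    found = []
    for force_index in range(system.getNumForces()):
        force = system.getForce(force_index)
        for is_type, cast in casters:
            if is_type(force):
                force = cast(force)  # recover the concrete wrapper
                break
        try:
            n_global_parameters = force.getNumGlobalParameters()
        except AttributeError:
            continue  # the base wrapper lacks the method, so it is skipped
        for parameter_id in range(n_global_parameters):
            name = force.getGlobalParameterName(parameter_id)
            if name in searched_parameters:
                found.append(name)
    return found


system = System()
system.addForce({"cpp_class": "TorchForce", "globals": ["a", "b"]})

without_cast = find_global_parameters(system, {"a", "b"})
with_cast = find_global_parameters(
    system, {"a", "b"}, casters=[(TorchForce.isinstance, TorchForce.cast)]
)
print(without_cast)  # []
print(with_cast)     # ['a', 'b']
```

Without the cast, the plugin force is silently skipped, which matches the "Could not find global parameter a" error above. With the cast, both names are found. A fix along these lines in real code would presumably call TorchForce's own isinstance()/cast(). Each plugin would need its own pair, though, so the approach does not generalize well.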

The robust way of getting a list of all global parameters is to call getParameters() on a Context.
