From ac677ee6e44d683c047beef5c8a9c6afaec88cb1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Pulido?= Date: Tue, 26 Apr 2022 20:37:07 -0400 Subject: [PATCH 1/9] Handling hybrid topology factor (htf) from input file --- perses/app/setup_relative_calculation.py | 78 ++++++++++++++---------- 1 file changed, 46 insertions(+), 32 deletions(-) diff --git a/perses/app/setup_relative_calculation.py b/perses/app/setup_relative_calculation.py index dc819c2cd..b8655dc49 100644 --- a/perses/app/setup_relative_calculation.py +++ b/perses/app/setup_relative_calculation.py @@ -7,7 +7,7 @@ import logging from pathlib import Path -from perses.annihilation.relative import HybridTopologyFactory +from perses.annihilation.relative import HybridTopologyFactory, RESTCapableHybridTopologyFactory from perses.app.relative_setup import RelativeFEPSetup from perses.annihilation.lambda_protocol import LambdaProtocol @@ -306,6 +306,12 @@ def getSetupOptions(filename): if not 'rmsd_restraint' in setup_options: setup_options['rmsd_restraint'] = False + # Handling htf input parameter + if 'hybrid_topology_factory' not in setup_options: + default_htf_class_name = "HybridTopologyFactory" + setup_options['hybrid_topology_factory'] = default_htf_class_name + _logger.info(f"\t 'hybrid_topology_factory' not specified: default to {default_htf_class_name}") + os.makedirs(trajectory_directory, exist_ok=True) @@ -562,33 +568,24 @@ def run_setup(setup_options, serialize_systems=True, build_samplers=True): else: _internal_parallelism = None - ne_fep = dict() for phase in phases: _logger.info(f"\t\tphase: {phase}") - hybrid_factory = HybridTopologyFactory(top_prop['%s_topology_proposal' % phase], - top_prop['%s_old_positions' % phase], - top_prop['%s_new_positions' % phase], - neglected_new_angle_terms = top_prop[f"{phase}_forward_neglected_angles"], - neglected_old_angle_terms = top_prop[f"{phase}_reverse_neglected_angles"], - softcore_LJ_v2 = setup_options['softcore_v2'], - interpolate_old_and_new_14s = setup_options['anneal_1,4s'], - rmsd_restraint=setup_options['rmsd_restraint'], - ) + hybrid_factory = _generate_htf(phase, top_prop, setup_options) if build_samplers: - ne_fep[phase] = SequentialMonteCarlo(factory = hybrid_factory, - lambda_protocol = setup_options['lambda_protocol'], - temperature = temperature, - trajectory_directory = trajectory_directory, - trajectory_prefix = f"{trajectory_prefix}_{phase}", - atom_selection = atom_selection, - timestep = timestep, - eq_splitting_string = eq_splitting, - neq_splitting_string = neq_splitting, - collision_rate = setup_options['ncmc_collision_rate_ps'], - ncmc_save_interval = ncmc_save_interval, - internal_parallelism = _internal_parallelism) + ne_fep[phase] = SequentialMonteCarlo(factory=hybrid_factory, + lambda_protocol=setup_options['lambda_protocol'], + temperature=temperature, + trajectory_directory=trajectory_directory, + trajectory_prefix=f"{trajectory_prefix}_{phase}", + atom_selection=atom_selection, + timestep=timestep, + eq_splitting_string=eq_splitting, + neq_splitting_string=neq_splitting, + collision_rate=setup_options['ncmc_collision_rate_ps'], + ncmc_save_interval=ncmc_save_interval, + internal_parallelism=_internal_parallelism) print("Nonequilibrium switching driver class constructed") @@ -604,15 +601,7 @@ def run_setup(setup_options, serialize_systems=True, build_samplers=True): _logger.info(f"\t\tphase: {phase}:") #TODO write a SAMSFEP class that mirrors NonequilibriumSwitchingFEP _logger.info(f"\t\twriting HybridTopologyFactory for phase {phase}...") - htf[phase] = HybridTopologyFactory(top_prop['%s_topology_proposal' % phase], - top_prop['%s_old_positions' % phase], - top_prop['%s_new_positions' % phase], - neglected_new_angle_terms = top_prop[f"{phase}_forward_neglected_angles"], - neglected_old_angle_terms = top_prop[f"{phase}_reverse_neglected_angles"], - softcore_LJ_v2 = setup_options['softcore_v2'], - interpolate_old_and_new_14s = setup_options['anneal_1,4s'], - rmsd_restraint=setup_options['rmsd_restraint'] - ) + htf[phase] = _generate_htf(phase, top_prop, setup_options) for phase in phases: # Define necessary vars to check energy bookkeeping @@ -1024,5 +1013,30 @@ def _resume_run(setup_options): raise("Can't resume") +def _generate_htf(phase: str, topology_proposal_dictionary: dict, setup_options: dict): + """ + Generates topology proposal for phase. + """ + factory_name = setup_options['hybrid_topology_factory'] + if factory_name == HybridTopologyFactory.__name__: + factory = HybridTopologyFactory + elif factory_name == RESTCapableHybridTopologyFactory.__name__: + factory = RESTCapableHybridTopologyFactory + try: + htf = factory(topology_proposal_dictionary[f'{phase}_topology_proposal'], + topology_proposal_dictionary[f'{phase}_old_positions'], + topology_proposal_dictionary[f'{phase}_new_positions'], + neglected_new_angle_terms=topology_proposal_dictionary[f"{phase}_forward_neglected_angles"], + neglected_old_angle_terms=topology_proposal_dictionary[f"{phase}_reverse_neglected_angles"], + softcore_LJ_v2=setup_options['softcore_v2'], + interpolate_old_and_new_14s=setup_options['anneal_1,4s'], + rmsd_restraint=setup_options['rmsd_restraint'] + ) + except NameError as error: + _logger.error(f"{error}. Check 'hybrid_topology_factory' name in input file.") + raise + return htf + + if __name__ == "__main__": run() From 86fe7b6a120646f5601af1d94557132b652fe78a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Pulido?= Date: Thu, 5 May 2022 16:06:09 -0400 Subject: [PATCH 2/9] Support for multisampler platform spec from YAML input file. --- perses/app/setup_relative_calculation.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/perses/app/setup_relative_calculation.py b/perses/app/setup_relative_calculation.py index b8655dc49..21a7d2ff7 100644 --- a/perses/app/setup_relative_calculation.py +++ b/perses/app/setup_relative_calculation.py @@ -312,6 +312,10 @@ def getSetupOptions(filename): setup_options['hybrid_topology_factory'] = default_htf_class_name _logger.info(f"\t 'hybrid_topology_factory' not specified: default to {default_htf_class_name}") + # Handling absence platform name input (backwards compatibility) + if 'platform' not in setup_options: + setup_options['platform'] = None # defaults to choosing best platform + os.makedirs(trajectory_directory, exist_ok=True) @@ -659,7 +663,7 @@ def run_setup(setup_options, serialize_systems=True, build_samplers=True): return {'topology_proposals': top_prop, 'hybrid_topology_factories': htf} # get platform - platform = get_openmm_platform(platform_name=None) + platform = get_openmm_platform(platform_name=setup_options['platform']) # Setup context caches for multistate samplers energy_context_cache = cache.ContextCache(capacity=None, time_to_live=None, platform=platform) sampler_context_cache = cache.ContextCache(capacity=None, time_to_live=None, platform=platform) @@ -958,7 +962,7 @@ def _resume_run(setup_options): from openmmtools.cache import ContextCache from perses.samplers.multistate import HybridSAMSSampler, HybridRepexSampler # get platform - platform = get_openmm_platform(platform_name=None) + platform = get_openmm_platform(platform_name=setup_options['platform']) # Setup context caches for multistate samplers energy_context_cache = ContextCache(capacity=None, time_to_live=None, platform=platform) sampler_context_cache = ContextCache(capacity=None, time_to_live=None, platform=platform) From 8614b8f636027daa34a428a55abaca9e200a685b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Pulido?= Date: Thu, 5 May 2022 19:55:19 -0400 Subject: [PATCH 3/9] Validate endstate energies for different HTFs. Supports new REST htf. --- perses/app/setup_relative_calculation.py | 47 +++++++++++++++++++----- 1 file changed, 37 insertions(+), 10 deletions(-) diff --git a/perses/app/setup_relative_calculation.py b/perses/app/setup_relative_calculation.py index 21a7d2ff7..1f047978f 100644 --- a/perses/app/setup_relative_calculation.py +++ b/perses/app/setup_relative_calculation.py @@ -14,7 +14,7 @@ from openmmtools import mcmc, cache from openmmtools.multistate import MultiStateReporter from perses.utils.smallmolecules import render_atom_mapping -from perses.tests.utils import validate_endstate_energies +from perses.tests.utils import validate_endstate_energies, validate_endstate_energies_point from perses.dispersed.smc import SequentialMonteCarlo import datetime @@ -593,6 +593,7 @@ def run_setup(setup_options, serialize_systems=True, build_samplers=True): print("Nonequilibrium switching driver class constructed") + # TODO: Should this function return a single thing instead of two different objects for neq vs others? return {'topology_proposals': top_prop, 'ne_fep': ne_fep} else: @@ -608,16 +609,10 @@ def run_setup(setup_options, serialize_systems=True, build_samplers=True): htf[phase] = _generate_htf(phase, top_prop, setup_options) for phase in phases: - # Define necessary vars to check energy bookkeeping - _top_prop = top_prop['%s_topology_proposal' % phase] - _htf = htf[phase] - _forward_added_valence_energy = top_prop['%s_added_valence_energy' % phase] - _reverse_subtracted_valence_energy = top_prop['%s_subtracted_valence_energy' % phase] - if not use_given_geometries: - zero_state_error, one_state_error = validate_endstate_energies(_top_prop, _htf, _forward_added_valence_energy, _reverse_subtracted_valence_energy, beta = 1.0/(kB*temperature), ENERGY_THRESHOLD = ENERGY_THRESHOLD)#, trajectory_directory=f'{xml_directory}{phase}') - _logger.info(f"\t\terror in zero state: {zero_state_error}") - _logger.info(f"\t\terror in one state: {one_state_error}") + _validate_endstate_energies_for_htf(htf, top_prop, phase, + beta=1.0 / (kB * temperature), + ENERGY_THRESHOLD=ENERGY_THRESHOLD) else: _logger.info(f"'use_given_geometries' was passed to setup; skipping endstate validation") @@ -832,6 +827,7 @@ def run(yaml_filename=None): _forward_added_valence_energy = setup_dict['topology_proposals'][f"{phase}_added_valence_energy"] _reverse_subtracted_valence_energy = setup_dict['topology_proposals'][f"{phase}_subtracted_valence_energy"] + # TODO: Validation here should be done with the same _validate_endstate_energies_for_htf function. zero_state_error, one_state_error = validate_endstate_energies(hybrid_factory._topology_proposal, hybrid_factory, _forward_added_valence_energy, _reverse_subtracted_valence_energy, beta = 1.0/(kB*temperature), ENERGY_THRESHOLD = ENERGY_THRESHOLD, trajectory_directory=f'{setup_options["trajectory_directory"]}/xml/{phase}') _logger.info(f"\t\terror in zero state: {zero_state_error}") _logger.info(f"\t\terror in one state: {one_state_error}") @@ -1042,5 +1038,36 @@ def _generate_htf(phase: str, topology_proposal_dictionary: dict, setup_options: return htf +def _validate_endstate_energies_for_htf(hybrid_topology_factory_dict: dict, topology_proposal_dict: dict, phase: str, + **kwargs): + """ + Validates endstate energies according to different hybrid topology factories and phases. + + Parameters + ---------- + hybrid_topology_factory: dict + Dictionary with different hybrid topology factories for different phases. Phase as key, HTF as value. + topology_proposal_dict: dict + Dictionary with different topology proposals for different phases. Phase as key, top_pro as value. + phase: str + Name of the phase. + """ + htf = hybrid_topology_factory_dict[phase] + if htf.__name__ == "Hybrid TopologyFactory": + topology_proposal = topology_proposal_dict[f"{phase}_topology_proposal"] + forward_added_valence_energy = topology_proposal_dict[f"{phase}_added_valence_energy"] + reverse_substracted_valence_energy = topology_proposal_dict[f"{phase}_substracted_valence_energy"] + zero_state_error, one_state_error = validate_endstate_energies(topology_proposal, + htf, + forward_added_valence_energy, + reverse_substracted_valence_energy, + **kwargs) + _logger.info(f"\t\terror in zero state: {zero_state_error}") + _logger.info(f"\t\terror in one state: {one_state_error}") + elif htf.__name__ == "RESTCapableHybridTopologyFactory": + for endstate in [0, 1]: + validate_endstate_energies_point(htf, endstate=endstate, minimize=True) + + if __name__ == "__main__": run() From 8711b65d29a69ee162ed51e7db7edf51bdd56052 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Pulido?= Date: Fri, 6 May 2022 17:26:38 -0400 Subject: [PATCH 4/9] Handling specific REST HTF parameters. --- perses/app/setup_relative_calculation.py | 43 ++++++++++++++++-------- 1 file changed, 29 insertions(+), 14 deletions(-) diff --git a/perses/app/setup_relative_calculation.py b/perses/app/setup_relative_calculation.py index 1f047978f..7dc8c7bf1 100644 --- a/perses/app/setup_relative_calculation.py +++ b/perses/app/setup_relative_calculation.py @@ -1015,26 +1015,41 @@ def _resume_run(setup_options): def _generate_htf(phase: str, topology_proposal_dictionary: dict, setup_options: dict): """ - Generates topology proposal for phase. + Generates topology proposal for phase. Supports both HybridTopologyFactory and new RESTCapableHybridTopologyFactory """ factory_name = setup_options['hybrid_topology_factory'] + htf_setup_dict = { + "neglected_new_angle_terms": topology_proposal_dictionary[f"{phase}_forward_neglected_angles"], + "neglected_old_angle_terms": topology_proposal_dictionary[f"{phase}_reverse_neglected_angles"], + "softcore_LJ_v2": setup_options['softcore_v2'], + "interpolate_old_and_new_14s": setup_options['anneal_1,4s'], + "rmsd_restraint": setup_options['rmsd_restraint'] + } + if factory_name == HybridTopologyFactory.__name__: factory = HybridTopologyFactory elif factory_name == RESTCapableHybridTopologyFactory.__name__: factory = RESTCapableHybridTopologyFactory - try: - htf = factory(topology_proposal_dictionary[f'{phase}_topology_proposal'], - topology_proposal_dictionary[f'{phase}_old_positions'], - topology_proposal_dictionary[f'{phase}_new_positions'], - neglected_new_angle_terms=topology_proposal_dictionary[f"{phase}_forward_neglected_angles"], - neglected_old_angle_terms=topology_proposal_dictionary[f"{phase}_reverse_neglected_angles"], - softcore_LJ_v2=setup_options['softcore_v2'], - interpolate_old_and_new_14s=setup_options['anneal_1,4s'], - rmsd_restraint=setup_options['rmsd_restraint'] - ) - except NameError as error: - _logger.error(f"{error}. Check 'hybrid_topology_factory' name in input file.") - raise + # Add/use specified REST HTF parameters if present + rest_specific_options = dict() + try: + rest_specific_options.update({'rest_radius': setup_options['rest_radius']}) + except KeyError: + _logger.info("'rest_radius' not specified. Using default value.") + try: + rest_specific_options.update({'w_scale': setup_options['w_scale']}) + except KeyError: + _logger.info("'w_scale' not specified. Using default value.") + + # update htf_setup_dictionary with new parameters + htf_setup_dict.update(rest_specific_options) + else: + raise ValueError(f"Unsupported Hybrid Topology Factory. Check 'hybrid_topology_factory' name in input file.") + htf = factory(topology_proposal_dictionary[f'{phase}_topology_proposal'], + topology_proposal_dictionary[f'{phase}_old_positions'], + topology_proposal_dictionary[f'{phase}_new_positions'], + **htf_setup_dict + ) return htf From d70674ba0a1889617b9950e25358cd83ba4fc05d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Pulido?= Date: Mon, 9 May 2022 12:20:03 -0400 Subject: [PATCH 5/9] Using instances types instead of class name. --- perses/app/setup_relative_calculation.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/perses/app/setup_relative_calculation.py b/perses/app/setup_relative_calculation.py index 7dc8c7bf1..969399d84 100644 --- a/perses/app/setup_relative_calculation.py +++ b/perses/app/setup_relative_calculation.py @@ -1067,21 +1067,21 @@ def _validate_endstate_energies_for_htf(hybrid_topology_factory_dict: dict, topo phase: str Name of the phase. """ - htf = hybrid_topology_factory_dict[phase] - if htf.__name__ == "Hybrid TopologyFactory": + current_htf = hybrid_topology_factory_dict[phase] + if isinstance(current_htf, HybridTopologyFactory): topology_proposal = topology_proposal_dict[f"{phase}_topology_proposal"] forward_added_valence_energy = topology_proposal_dict[f"{phase}_added_valence_energy"] - reverse_substracted_valence_energy = topology_proposal_dict[f"{phase}_substracted_valence_energy"] + reverse_substracted_valence_energy = topology_proposal_dict[f"{phase}_subtracted_valence_energy"] zero_state_error, one_state_error = validate_endstate_energies(topology_proposal, - htf, + current_htf, forward_added_valence_energy, reverse_substracted_valence_energy, **kwargs) _logger.info(f"\t\terror in zero state: {zero_state_error}") _logger.info(f"\t\terror in one state: {one_state_error}") - elif htf.__name__ == "RESTCapableHybridTopologyFactory": + elif isinstance(current_htf, RESTCapableHybridTopologyFactory): for endstate in [0, 1]: - validate_endstate_energies_point(htf, endstate=endstate, minimize=True) + validate_endstate_energies_point(current_htf, endstate=endstate, minimize=True) if __name__ == "__main__": From 2fcf01088e2bed6f3d40a629a234e5f94c528380 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Pulido?= Date: Wed, 11 May 2022 14:45:30 -0400 Subject: [PATCH 6/9] Making fah generator add new htf key. Should change in the future. --- perses/app/fah_generator.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/perses/app/fah_generator.py b/perses/app/fah_generator.py index 1cef0b09e..971768f2f 100644 --- a/perses/app/fah_generator.py +++ b/perses/app/fah_generator.py @@ -457,6 +457,12 @@ def run_neq_fah_setup(ligand_file, setup_options['bond_expr'] = generate_expression(setup_options['bond_expression']) # Generate topology proposals and hybrid topology factories via perses run_setup + # Manually add the new 'hybrid_topology_factory' key + # TODO: This should be changed once we have a new API (i.e. OpenFE Settings classes) + if 'hybrid_topology_factory' not in setup_options: + default_htf_class_name = "HybridTopologyFactory" + setup_options['hybrid_topology_factory'] = default_htf_class_name + _logger.info(f"\t 'hybrid_topology_factory' not specified: default to {default_htf_class_name}") _logger.info(f"spectators: {setup_options['spectators']}") if setup == 'small_molecule': _logger.info(f"Setting up a small molecule transformation") From 4bb686957cf1d15102736b8090c4773b2db1187a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Pulido?= Date: Thu, 12 May 2022 11:42:28 -0400 Subject: [PATCH 7/9] More informative error message --- perses/app/setup_relative_calculation.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/perses/app/setup_relative_calculation.py b/perses/app/setup_relative_calculation.py index 1d5c4a30e..bac0ae8cd 100644 --- a/perses/app/setup_relative_calculation.py +++ b/perses/app/setup_relative_calculation.py @@ -1091,7 +1091,8 @@ def _generate_htf(phase: str, topology_proposal_dictionary: dict, setup_options: # update htf_setup_dictionary with new parameters htf_setup_dict.update(rest_specific_options) else: - raise ValueError(f"Unsupported Hybrid Topology Factory. Check 'hybrid_topology_factory' name in input file.") + raise ValueError(f"You specified an unsupported factory type: {factory_name}. Currently, the supported " + f"factories are: HybridTopologyFactory and RESTCapableHybridTopologyFactory.") htf = factory(topology_proposal_dictionary[f'{phase}_topology_proposal'], topology_proposal_dictionary[f'{phase}_old_positions'], topology_proposal_dictionary[f'{phase}_new_positions'], From 9335d035f293c72dcbb67b64c386f2c42649679d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Pulido?= Date: Thu, 12 May 2022 12:00:28 -0400 Subject: [PATCH 8/9] Copying htf inside validate energies function. --- perses/app/relative_point_mutation_setup.py | 2 +- perses/tests/test_topology_proposal.py | 3 +-- perses/tests/utils.py | 14 ++++++++++---- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/perses/app/relative_point_mutation_setup.py b/perses/app/relative_point_mutation_setup.py index cfca69c15..167a59622 100644 --- a/perses/app/relative_point_mutation_setup.py +++ b/perses/app/relative_point_mutation_setup.py @@ -436,7 +436,7 @@ def __init__(self, if generate_rest_capable_hybrid_topology_factory: from perses.tests.utils import validate_endstate_energies_point for endstate in [0, 1]: - htf = copy.deepcopy(self.get_complex_rest_htf()) if is_complex else copy.deepcopy(self.get_apo_rest_htf()) + htf = self.get_complex_rest_htf() if is_complex else self.get_apo_rest_htf() validate_endstate_energies_point(htf, endstate=endstate, minimize=True) else: diff --git a/perses/tests/test_topology_proposal.py b/perses/tests/test_topology_proposal.py index 8808abb96..355495b0d 100644 --- a/perses/tests/test_topology_proposal.py +++ b/perses/tests/test_topology_proposal.py @@ -304,8 +304,7 @@ def generate_dipeptide_top_pos_sys(topology, if generate_rest_capable_hybrid_topology_factory: from perses.tests.utils import validate_endstate_energies_point for endstate in [0, 1]: - htf = copy.deepcopy(forward_htf) - validate_endstate_energies_point(htf, endstate=endstate, minimize=True) + validate_endstate_energies_point(forward_htf, endstate=endstate, minimize=True) else: from perses.tests.utils import validate_endstate_energies diff --git a/perses/tests/utils.py b/perses/tests/utils.py index 28976b747..f368561d1 100644 --- a/perses/tests/utils.py +++ b/perses/tests/utils.py @@ -846,7 +846,7 @@ def validate_endstate_energies(topology_proposal, return zero_error, one_error -def validate_endstate_energies_point(htf, endstate=0, minimize=False): +def validate_endstate_energies_point(input_htf, endstate=0, minimize=False): """ ** Used for validating endstate energies for RESTCapableHybridTopologyFactory ** @@ -856,7 +856,7 @@ def validate_endstate_energies_point(htf, endstate=0, minimize=False): Parameters ---------- - htf : RESTCapableHybridTopologyFactory + input_htf : RESTCapableHybridTopologyFactory the RESTCapableHybridTopologyFactory to test endstate : int, default=0 the endstate to test (0 or 1) @@ -869,6 +869,9 @@ def validate_endstate_energies_point(htf, endstate=0, minimize=False): # Check that endstate is 0 or 1 assert endstate in [0, 1], "Endstate must be 0 or 1" + # Make deep copy to ensure original object remains unaltered + htf = copy.deepcopy(input_htf) + # Get original system system = htf._topology_proposal.old_system if endstate == 0 else htf._topology_proposal.new_system @@ -971,7 +974,7 @@ def validate_endstate_energies_point(htf, endstate=0, minimize=False): print(f"Success! Energies are equal at lambda {endstate}!") -def validate_endstate_energies_md(htf, T_max=300 * unit.kelvin, endstate=0, n_steps=125000): +def validate_endstate_energies_md(input_htf, T_max=300 * unit.kelvin, endstate=0, n_steps=125000): """ Check that the hybrid system's energy (without unique old/new valence energy) matches the original system's energy for snapshots extracted (every 1 ps) from a MD simulation. @@ -980,7 +983,7 @@ def validate_endstate_energies_md(htf, T_max=300 * unit.kelvin, endstate=0, n_st Parameters ---------- - htf : RESTCapableHybridTopologyFactory + input_htf : RESTCapableHybridTopologyFactory the RESTCapableHybridTopologyFactory to test T_max : unit.kelvin default=300 * unit.kelvin T_max at which to test the factory. This should not actually affect the energy differences, since T_max should equal T_min at the endstates @@ -996,6 +999,9 @@ def validate_endstate_energies_md(htf, T_max=300 * unit.kelvin, endstate=0, n_st # Check that endstate is 0 or 1 assert endstate in [0, 1], "Endstate must be 0 or 1" + # Make deep copy to ensure original object remains unaltered + htf = copy.deepcopy(input_htf) + # Set temperature T_min = temperature From f1d844b75dea1555d3368168380a076e97d20666 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Pulido?= Date: Thu, 12 May 2022 13:49:58 -0400 Subject: [PATCH 9/9] No need to pickle htfs now. --- perses/tests/test_relative.py | 25 +++++++------------------ 1 file changed, 7 insertions(+), 18 deletions(-) diff --git a/perses/tests/test_relative.py b/perses/tests/test_relative.py index 4684064ea..7a13d5718 100644 --- a/perses/tests/test_relative.py +++ b/perses/tests/test_relative.py @@ -956,9 +956,6 @@ def run_RESTCapableHybridTopologyFactory_energies(test_name, phase, use_point_en """ - import tempfile - import pickle - from perses.tests.test_topology_proposal import generate_atp, generate_dipeptide_top_pos_sys from perses.app.relative_point_mutation_setup import PointMutationExecutor from perses.tests.utils import validate_endstate_energies_point, validate_endstate_energies_md @@ -1006,21 +1003,13 @@ def run_RESTCapableHybridTopologyFactory_energies(test_name, phase, use_point_en ) htf = solvent_delivery.get_apo_rest_htf() - # Save htf as temporary pickled file - with tempfile.TemporaryDirectory() as temp_dir: - with open(os.path.join(temp_dir, "htf.pickle"), "wb") as f: - pickle.dump(htf, f) - - if use_point_energies: - for endstate in [0, 1]: - with open(os.path.join(temp_dir, "htf.pickle"), "rb") as f: - htf = pickle.load(f) - validate_endstate_energies_point(htf, endstate=endstate, minimize=True) - else: - for endstate in [0, 1]: - with open(os.path.join(temp_dir, "htf.pickle"), "rb") as f: - htf = pickle.load(f) - validate_endstate_energies_md(htf, endstate=endstate, n_steps=10) + # validating endstate energies + if use_point_energies: + for endstate in [0, 1]: + validate_endstate_energies_point(htf, endstate=endstate, minimize=True) + else: + for endstate in [0, 1]: + validate_endstate_energies_md(htf, endstate=endstate, n_steps=10) def test_RESTCapableHybridTopologyFactory_energies(): """