diff --git a/src/openfe/protocols/openmm_rfe/_rfe_utils/topologyhelpers.py b/src/openfe/protocols/openmm_rfe/_rfe_utils/topologyhelpers.py index 7181f4ce5..b410d583b 100644 --- a/src/openfe/protocols/openmm_rfe/_rfe_utils/topologyhelpers.py +++ b/src/openfe/protocols/openmm_rfe/_rfe_utils/topologyhelpers.py @@ -10,6 +10,7 @@ import itertools import logging import warnings +from collections import Counter from copy import deepcopy from typing import Optional, Union @@ -21,30 +22,68 @@ from openmm import NonbondedForce, System, app from openmm import unit as omm_unit -from openfe import SolventComponent - logger = logging.getLogger(__name__) -def _get_ion_and_water_parameters( +def _get_ff_parameters( + forcefield: app.ForceField, + smiles: str, +) -> dict[str, tuple]: + """ + Get NonbondedForce parameters for all atoms in a molecule by creating + a minimal dummy system from a SMILES string with the given forcefield. + + Parameters + ---------- + forcefield : app.ForceField + The force field to use for parameterisation. + smiles : str + The SMILES string of the molecule to parameterise + (e.g. '[Na+]', '[Cl-]', 'O'). + + Returns + ------- + dict[str, tuple] + A dictionary keyed by element symbol, with (charge, sigma, epsilon) + tuples as values. + """ + from openff.toolkit import Molecule + + offmol = Molecule.from_smiles(smiles) + dummy_top = offmol.to_topology().to_openmm() + dummy_system = forcefield.createSystem(dummy_top) + nbf = [f for f in dummy_system.getForces() if isinstance(f, NonbondedForce)][0] + + return { + atom.element.symbol: nbf.getParticleParameters(atom.index) + for atom in dummy_top.atoms() + } + +def _get_ion_parameters( topology: app.Topology, system: System, - ion_resname: str, - water_resname: str = 'HOH', -): + charge_difference: int, + forcefield: app.ForceField, +) -> tuple: """ - Get ion, and water (oxygen and hydrogen) atoms parameters. + Get NonbondedForce parameters for the most abundant monovalent ion of + the appropriate charge sign found in the topology. Falls back to + 'NA' or 'CL' parameters if no ion is found in the topology. Parameters ---------- topology : app.Topology - The topology to search for the ion and water + The topology to search for the most abundant ion type. system : app.System The system associated with the input topology object. - ion_resname : str - The residue name of the ion to get parameters for - water_resname : str - The residue name of the water to get parameters for. Default 'HOH'. + charge_difference : int + The charge difference between state A and state B, + calculated as stateA_formal_charge - stateB_formal_charge. + Positive values indicate a positive ion is needed, + negative values a negative ion. + forcefield : app.ForceField + The force field to use for parameterization if no ion is found + in the topology. Returns ------- @@ -54,44 +93,48 @@ def _get_ion_and_water_parameters( The NonbondedForce sigma parameter of the ion atom ion_epsilon : float The NonbondedForce epsilon parameter of the ion atom - o_charge : float - The partial charge of the water oxygen. - h_charge : float - The partial charge of the water hydrogen. - - Raises - ------ - ValueError - If there are no ``ion_resname`` or ``water_resname`` named residues in - the input ``topology``. - - Attribution - ----------- - Based on `perses.utils.charge_changing.get_ion_and_water_parameters`. """ - def _find_atom(topology, resname, elementname): - for atom in topology.atoms(): - if atom.residue.name == resname: - if (elementname is None or atom.element.symbol == elementname): - return atom.index - errmsg = ("Error encountered when attempting to explicitly handle " - "charge changes using an alchemical water. No residue " - f"named: {resname} found, with element {elementname}") - raise ValueError(errmsg) - - ion_index = _find_atom(topology, ion_resname, None) - oxygen_index = _find_atom(topology, water_resname, 'O') - hydrogen_index = _find_atom(topology, water_resname, 'H') - - nbf = [i for i in system.getForces() - if isinstance(i, NonbondedForce)][0] - - ion_charge, ion_sigma, ion_epsilon = nbf.getParticleParameters(ion_index) - o_charge, _, _ = nbf.getParticleParameters(oxygen_index) - h_charge, _, _ = nbf.getParticleParameters(hydrogen_index) + nbf = [i for i in system.getForces() if isinstance(i, NonbondedForce)][0] + + desired_sign = np.sign(charge_difference) + ion_counts: Counter = Counter() + ion_atom_indices: dict[str, int] = {} + + for residue in topology.residues(): + atoms = list(residue.atoms()) + if len(atoms) != 1: + continue + atom = atoms[0] + if atom.element is None: + continue + charge, _, _ = nbf.getParticleParameters(atom.index) + charge_val = charge.value_in_unit(omm_unit.elementary_charge) + if np.isclose(abs(charge_val), 1.0, atol=0.01) and np.sign(charge_val) == desired_sign: + ion_counts[residue.name] += 1 + if residue.name not in ion_atom_indices: + ion_atom_indices[residue.name] = atom.index + + if ion_counts: + # Use the most abundant matching ion found in the topology + best_resname = ion_counts.most_common(1)[0][0] + return nbf.getParticleParameters(ion_atom_indices[best_resname]) + + # No matching ion in topology: fall back to NA/CL from forcefield + if charge_difference > 0: + fallback_smiles = '[Na+]' + else: + fallback_smiles = '[Cl-]' - return ion_charge, ion_sigma, ion_epsilon, o_charge, h_charge + wmsg = ( + f"No monovalent ion with the appropriate charge sign found in " + f"the topology. Defaulting to '{fallback_smiles}' and obtaining " + f"parameters from the forcefield directly." + ) + warnings.warn(wmsg) + logger.warning(wmsg) + ion_params = _get_ff_parameters(forcefield, fallback_smiles) + return next(iter(ion_params.values())) def _fix_alchemical_water_atom_mapping( system_mapping: dict[str, Union[dict[int, int], list[int]]], @@ -123,10 +166,12 @@ def _fix_alchemical_water_atom_mapping( def handle_alchemical_waters( - water_resids: list[int], topology: app.Topology, - system: System, system_mapping: dict, + water_resids: list[int], + topology: app.Topology, + system: System, + system_mapping: dict, charge_difference: int, - solvent_component: SolventComponent, + forcefield: app.ForceField, ): """ Add alchemical waters from a pre-defined list. @@ -142,15 +187,12 @@ def handle_alchemical_waters( system_mapping : dictionary A dictionary of system mappings between the stateA and stateB systems charge_difference : int - The charge difference between state A and state B. - positive_ion_resname : str - The name of a positive ion to replace the water with if the absolute - charge difference is positive. - negative_ion_resname : str - The name of a negative ion to replace the water with if the absolute - charge difference is negative. - water_resname : str - The residue name of the water to get parameters for. Default 'HOH'. + The charge difference between state A and state B, + calculated as stateA_formal_charge - stateB_formal_charge. + Positive values indicate a positive ion is needed, + negative values a negative ion. + forcefield : app.ForceField + The forcefield to use for ion parameterization. Raises ------ @@ -171,19 +213,9 @@ def handle_alchemical_waters( f"difference: {abs(charge_difference)}") raise ValueError(errmsg) - if charge_difference > 0: - ion_resname = solvent_component.positive_ion.strip('-+').upper() - elif charge_difference < 0: - ion_resname = solvent_component.negative_ion.strip('-+').upper() - # if there's no charge difference then just skip altogether - else: + if charge_difference == 0: return None - ion_charge, ion_sigma, ion_epsilon, o_charge, h_charge = _get_ion_and_water_parameters( - topology, system, ion_resname, - 'HOH', # Modeller always adds HOH waters - ) - # get the nonbonded forces nbfrcs = [i for i in system.getForces() if isinstance(i, NonbondedForce)] @@ -193,8 +225,7 @@ def handle_alchemical_waters( # for convenience just grab the first & only entry nbf = nbfrcs[0] - # Loop through residues, check if they match the residue index - # mutate the atom as necessary + # Check for virtual site waters for res in topology.residues(): if res.index in water_resids: # if the number of atoms > 3, then we have virtual sites which are @@ -204,6 +235,18 @@ def handle_alchemical_waters( "are not currently supported as alchemical waters") raise ValueError(errmsg) + ion_charge, ion_sigma, ion_epsilon = _get_ion_parameters( + topology, system, charge_difference, forcefield, + ) + # Get water parameters + water_params = _get_ff_parameters(forcefield, '[H]O[H]') + o_charge = water_params['O'][0] + h_charge = water_params['H'][0] + + # Loop through residues, check if they match the residue index + # mutate the atom as necessary + for res in topology.residues(): + if res.index in water_resids: for at in res.atoms(): idx = at.index charge, sigma, epsilon = nbf.getParticleParameters(idx) @@ -433,7 +476,6 @@ def _remove_constraints(old_to_new_atom_map, old_system, old_topology, * Very slow, needs refactoring * Can we drop having topologies as inputs here? """ - from collections import Counter no_const_old_to_new_atom_map = deepcopy(old_to_new_atom_map) diff --git a/src/openfe/protocols/openmm_rfe/hybridtop_protocols.py b/src/openfe/protocols/openmm_rfe/hybridtop_protocols.py index ed83e0459..35ef51375 100644 --- a/src/openfe/protocols/openmm_rfe/hybridtop_protocols.py +++ b/src/openfe/protocols/openmm_rfe/hybridtop_protocols.py @@ -28,6 +28,7 @@ ProteinComponent, ProteinMembraneComponent, SmallMoleculeComponent, + SolvatedPDBComponent, SolventComponent, settings, ) @@ -39,6 +40,7 @@ settings_validation, system_validation, ) +from . import _rfe_utils from .equil_rfe_settings import ( AlchemicalSettings, IntegratorSettings, @@ -372,7 +374,7 @@ def _validate_charge_difference( mapping: LigandAtomMapping, nonbonded_method: str, explicit_charge_correction: bool, - solvent_component: SolventComponent | None, + solvent_component: SolventComponent | SolvatedPDBComponent | None, ): """ Validates the net charge difference between the two states. @@ -435,12 +437,10 @@ def _validate_charge_difference( ) raise ValueError(errmsg) - ion = {-1: solvent_component.positive_ion, 1: solvent_component.negative_ion}[difference] - wmsg = ( f"A charge difference of {difference} is observed " "between the end states. This will be addressed by " - f"transforming a water into a {ion} ion" + "transforming a water into an ion of opposite charge." ) logger.info(wmsg) @@ -560,6 +560,8 @@ def _validate( # Note: validation depends on the mapping & solvent component checks if stateA.contains(SolventComponent): solv_comp = stateA.get_components_of_type(SolventComponent)[0] + elif stateA.contains(SolvatedPDBComponent): + solv_comp = stateA.get_components_of_type(SolvatedPDBComponent)[0] else: solv_comp = None diff --git a/src/openfe/protocols/openmm_rfe/hybridtop_units.py b/src/openfe/protocols/openmm_rfe/hybridtop_units.py index cd634f391..5f4c1c9f3 100644 --- a/src/openfe/protocols/openmm_rfe/hybridtop_units.py +++ b/src/openfe/protocols/openmm_rfe/hybridtop_units.py @@ -390,7 +390,7 @@ def _handle_net_charge( charge_difference: int, system_mappings: dict[str, dict[int, int]], distance_cutoff: Quantity, - solvent_component: SolventComponent | None, + forcefield: openmm.app.ForceField, ) -> None: """ Handle system net charge by adding an alchemical water. @@ -404,7 +404,7 @@ def _handle_net_charge( charge_difference : int system_mappings : dict[str, dict[int, int]] distance_cutoff : Quantity - solvent_component : SolventComponent | None + forcefield: openmm.app.ForceField """ # Base case, return if no net charge if charge_difference == 0: @@ -425,7 +425,7 @@ def _handle_net_charge( system=stateB_system, system_mapping=system_mappings, charge_difference=charge_difference, - solvent_component=solvent_component, + forcefield=forcefield, ) def _get_omm_objects( @@ -545,6 +545,7 @@ def _filter_small_mols(smols, state): # Net charge: add alchemical water if needed # Must be done here as we in-place modify the particles of state B. if settings["alchemical_settings"].explicit_charge_correction: + forcefield = states_inputs["A"]["generator"].forcefield self._handle_net_charge( stateA_topology=stateA_topology, stateA_positions=stateA_positions, @@ -553,7 +554,7 @@ def _filter_small_mols(smols, state): charge_difference=mapping.get_alchemical_charge_difference(), system_mappings=system_mappings, distance_cutoff=settings["alchemical_settings"].explicit_charge_correction_cutoff, - solvent_component=solvent_component, + forcefield=forcefield, ) # Finally get the state B positions diff --git a/src/openfe/tests/protocols/openmm_rfe/test_hybrid_top_protocol.py b/src/openfe/tests/protocols/openmm_rfe/test_hybrid_top_protocol.py index fb297e248..62f3072c8 100644 --- a/src/openfe/tests/protocols/openmm_rfe/test_hybrid_top_protocol.py +++ b/src/openfe/tests/protocols/openmm_rfe/test_hybrid_top_protocol.py @@ -1215,6 +1215,36 @@ def test_dry_run_membrane_complex( ) +def test_validate_charge_difference_membrane_system( + a2a_protein_membrane_component, + a2a_ligands, +): + ligA = next(c for c in a2a_ligands if c.name == "4g") + ligB = next(c for c in a2a_ligands if c.name == "4h") + + mapping = openfe.LigandAtomMapping( + componentA=ligA, + componentB=ligB, + componentA_to_componentB={i: i for i in range(36)}, + ) + + settings = openmm_rfe.RelativeHybridTopologyProtocol.default_settings() + + # Mock get_alchemical_charge_difference to return non-zero + with mock.patch.object( + mapping, + "get_alchemical_charge_difference", + return_value=1, + ): + protocol = openmm_rfe.RelativeHybridTopologyProtocol(settings=settings) + protocol._validate_charge_difference( + mapping=mapping, + nonbonded_method=settings.forcefield_settings.nonbonded_method, + explicit_charge_correction=True, + solvent_component=a2a_protein_membrane_component, + ) + + def test_lambda_schedule_default(): lambdas = openmm_rfe._rfe_utils.lambdaprotocol.LambdaProtocol(functions="default") assert len(lambdas.lambda_schedule) == 10 @@ -1860,7 +1890,7 @@ def benzene_solvent_openmm_system(benzene_modifications): positions = to_openmm(from_openmm(modeller.getPositions())) system = system_generator.create_system(topology, molecules=[offmol]) - return system, topology, positions + return system, topology, positions, system_generator.forcefield @pytest.fixture(scope="session") @@ -1910,7 +1940,7 @@ def benzene_self_system_mapping(benzene_solvent_openmm_system): alchemical transformation (this technically doesn't work in practice because the RFE protocol expects an alchemical component). """ - system, topology, positions = benzene_solvent_openmm_system + system, topology, positions, forcefield = benzene_solvent_openmm_system res = [r for r in topology.residues()] benzene_res = [r for r in res if r.name == "UNK"][0] @@ -1932,28 +1962,53 @@ def benzene_self_system_mapping(benzene_solvent_openmm_system): return system_mapping -@pytest.mark.parametrize( - "ion, water", - [ - ["NA", "SOL"], - ["NX", "WAT"], - ], -) -def test_get_ion_water_parameters_unknownresname(ion, water, benzene_solvent_openmm_system): - system, topology, positions = benzene_solvent_openmm_system +def test_get_ion_parameters_fallback_to_forcefield(benzene_modifications): + """ + Check that when no matching ion is found in the topology, + parameters are obtained from the forcefield (NA fallback). + """ + smc = benzene_modifications["benzene"] + offmol = smc.to_openff() + settings = openmm_rfe.RelativeHybridTopologyProtocol.default_settings() + + system_generator = system_creation.get_system_generator( + forcefield_settings=settings.forcefield_settings, + integrator_settings=settings.integrator_settings, + thermo_settings=settings.thermo_settings, + cache=None, + has_solvent=True, + ) + system_generator.create_system( + offmol.to_topology().to_openmm(), + molecules=[offmol], + ) - errmsg = "Error encountered when attempting to explicitly handle" + modeller, _ = system_creation.get_omm_modeller( + protein_comp=None, + solvent_comp=openfe.SolventComponent(ion_concentration=0 * unit.molar), + small_mols={smc: offmol}, + omm_forcefield=system_generator.forcefield, + solvent_settings=settings.solvation_settings, + ) - with pytest.raises(ValueError, match=errmsg): - topologyhelpers._get_ion_and_water_parameters( - topology, system, ion_resname=ion, water_resname=water + topology = modeller.getTopology() + system = system_generator.create_system(topology, molecules=[offmol]) + + with pytest.warns(UserWarning, match="No monovalent ion"): + ion_charge, ion_sigma, ion_epsilon = topologyhelpers._get_ion_parameters( + topology, + system, + charge_difference=1, + forcefield=system_generator.forcefield, ) + assert ion_charge.value_in_unit(omm_unit.elementary_charge) == pytest.approx(1.0) + def test_get_alchemical_waters_no_waters( benzene_solvent_openmm_system, ): - system, topology, positions = benzene_solvent_openmm_system + system, topology, positions, forcefield = benzene_solvent_openmm_system errmsg = "There are no waters" @@ -1969,7 +2024,7 @@ def test_handle_alchemwats_incorrect_count( """ Check that an error is thrown when charge_difference != len(water_resids) """ - system, topology, positions = benzene_solvent_openmm_system + system, topology, positions, forcefield = benzene_solvent_openmm_system errmsg = "There should be as many alchemical water residues:" @@ -1980,7 +2035,7 @@ def test_handle_alchemwats_incorrect_count( system=system, system_mapping={}, charge_difference=1, - solvent_component=openfe.SolventComponent(), + forcefield=forcefield, ) @@ -1990,7 +2045,7 @@ def test_handle_alchemwats_too_many_nbf( """ Check that an error is thrown when there are multiple NonbondedForces """ - system, topology, positions = benzene_solvent_openmm_system + system, topology, positions, forcefield = benzene_solvent_openmm_system new_system = copy.deepcopy(system) new_system.addForce(NonbondedForce()) @@ -2004,7 +2059,7 @@ def test_handle_alchemwats_too_many_nbf( system=new_system, system_mapping={}, charge_difference=1, - solvent_component=openfe.SolventComponent(), + forcefield=forcefield, ) @@ -2026,7 +2081,7 @@ def test_handle_alchemwats_vsite_water( system=system, system_mapping={}, charge_difference=1, - solvent_component=openfe.SolventComponent(), + forcefield=app.ForceField(), ) @@ -2037,7 +2092,7 @@ def test_handle_alchemwats_incorrect_atom( """ Check that an error is thrown when charge_difference != len(water_resids) """ - system, topology, positions = benzene_solvent_openmm_system + system, topology, positions, forcefield = benzene_solvent_openmm_system # modify the charge of hydrogen atom 25 new_system = copy.deepcopy(system) # protect the session scoped object @@ -2054,7 +2109,7 @@ def test_handle_alchemwats_incorrect_atom( system=new_system, system_mapping=benzene_self_system_mapping, charge_difference=1, - solvent_component=openfe.SolventComponent(), + forcefield=forcefield, ) @@ -2062,7 +2117,7 @@ def test_handle_alchemical_wats( benzene_solvent_openmm_system, benzene_self_system_mapping, ): - system, topology, positions = benzene_solvent_openmm_system + system, topology, positions, forcefield = benzene_solvent_openmm_system n_env = len(benzene_self_system_mapping["old_to_new_env_atom_map"]) n_core = len(benzene_self_system_mapping["old_to_new_core_atom_map"]) @@ -2073,7 +2128,7 @@ def test_handle_alchemical_wats( system=system, system_mapping=benzene_self_system_mapping, charge_difference=1, - solvent_component=openfe.SolventComponent(), + forcefield=forcefield, ) # check the mappings @@ -2089,8 +2144,8 @@ def test_handle_alchemical_wats( # system parameters checks nbf = [i for i in system.getForces() if isinstance(i, NonbondedForce)][0] # check the oxygen parameters - i_chg, i_sig, i_eps, o_chg, h_chg = topologyhelpers._get_ion_and_water_parameters( - topology, system, "NA", "HOH" + i_chg, i_sig, i_eps = topologyhelpers._get_ion_parameters( + topology, system, charge_difference=1, forcefield=forcefield ) charge, sigma, epsilon = nbf.getParticleParameters(24) diff --git a/src/openfe/tests/protocols/openmm_rfe/test_hybrid_top_validation.py b/src/openfe/tests/protocols/openmm_rfe/test_hybrid_top_validation.py index 6d5db1249..f9d219e34 100644 --- a/src/openfe/tests/protocols/openmm_rfe/test_hybrid_top_validation.py +++ b/src/openfe/tests/protocols/openmm_rfe/test_hybrid_top_validation.py @@ -441,11 +441,10 @@ def test_get_charge_difference(mapping_name, result, request, caplog): mapping = request.getfixturevalue(mapping_name) caplog.set_level(logging.INFO) - ion = r"Na+" if result == -1 else r"Cl-" msg = ( f"A charge difference of {result} is observed " "between the end states. This will be addressed by " - f"transforming a water into a {ion} ion" + f"transforming a water into an ion of opposite charge" ) openmm_rfe.RelativeHybridTopologyProtocol._validate_charge_difference(