From ea6e58a6ae2941637b98dcac8765d2a5c20149dd Mon Sep 17 00:00:00 2001 From: hannahbaumann Date: Thu, 28 May 2026 12:04:36 +0200 Subject: [PATCH 01/17] First pass at adding structural analysis to the SepTop protocol --- .../protocols/openmm_septop/base_units.py | 311 +++++++++++++++++- .../protocols/openmm_septop/septop_units.py | 8 + 2 files changed, 317 insertions(+), 2 deletions(-) diff --git a/src/openfe/protocols/openmm_septop/base_units.py b/src/openfe/protocols/openmm_septop/base_units.py index 821c1931a..dcafe5407 100644 --- a/src/openfe/protocols/openmm_septop/base_units.py +++ b/src/openfe/protocols/openmm_septop/base_units.py @@ -20,6 +20,8 @@ from typing import Any, Literal, Optional import gufe +import matplotlib.pyplot as plt +import numpy as np import numpy.typing as npt import openmm import openmmtools @@ -43,6 +45,7 @@ ThermodynamicState, create_thermodynamic_state_protocol, ) +from rdkit import Chem import openfe from openfe.protocols.openmm_afe.equil_afe_settings import ( @@ -1503,11 +1506,282 @@ def _analyze_multistate_energies( reporter.close() return analyzer.unit_results_dict + @classmethod + def _run_complex_analysis( + cls, + ds, + pdb_file: pathlib.Path, + skip: int, + ligand_A_indices: list[int], + ligand_B_indices: list[int], + rdmol_A: Chem.Mol, + rdmol_B: Chem.Mol, + ) -> dict[str, Any]: + """ + Run structural analysis for the complex phase. + + Parameters + ---------- + ds : netCDF4.Dataset + Open NetCDF dataset for the multistate trajectory. + pdb_file : pathlib.Path + Path to the subsampled PDB file. + skip : int + Frame stride for analysis. + ligand_A_indices : list[int] + Atom indices of ligand A in the subsampled system. + ligand_B_indices : list[int] + Atom indices of ligand B in the subsampled system. + rdmol_A : Chem.Mol + RDKit molecule for ligand A, used for symmetry-corrected RMSD. + rdmol_B : Chem.Mol + RDKit molecule for ligand B, used for symmetry-corrected RMSD. + + Returns + ------- + dict[str, Any] + Per-state lists for ligand RMSD, COM drift, protein 2D RMSD, + and a single time_ps array. + """ + from openfe_analysis.rmsd import ( + LigandCOMDrift, + Protein2DRMSD, + SymmetryCorrectedLigandRMSD, + ) + from openfe_analysis.utils.apply_transformations import ( + apply_alignment_transformations, + ) + from openfe_analysis.utils.universe_utils import \ + create_universe_single_state + + n_lambda = ds.dimensions["state"].size + data: dict[str, Any] = { + "ligand_A_RMSD": [], + "ligand_B_RMSD": [], + "ligand_A_COM_drift": [], + "ligand_B_COM_drift": [], + "protein_2D_RMSD": [], + "time_ps": None, + } + + for state_idx in range(n_lambda): + u = create_universe_single_state(pdb_file, ds, state=state_idx) + prot = u.select_atoms("protein and name CA") + lig_A = u.atoms[ligand_A_indices] + lig_B = u.atoms[ligand_B_indices] + apply_alignment_transformations(u, protein=prot, + ligand=lig_A + lig_B) + + if prot: + prot_rmsd2d = Protein2DRMSD(prot).run(step=skip) + data["protein_2D_RMSD"].append(prot_rmsd2d.results.rmsd2d) + + for label, lig, rdmol in [ + ("ligand_A", lig_A, rdmol_A), + ("ligand_B", lig_B, rdmol_B), + ]: + lig_rmsd = SymmetryCorrectedLigandRMSD(lig, rdmol=rdmol).run( + step=skip) + data[f"{label}_RMSD"].append(lig_rmsd.results.rmsd) + + lig_drift = LigandCOMDrift(lig).run(step=skip) + data[f"{label}_COM_drift"].append(lig_drift.results.com_drift) + + if data["time_ps"] is None: + data["time_ps"] = ( + np.arange(len(u.trajectory))[::skip] * u.trajectory.dt + ) + + return data + + @classmethod + def _run_solvent_analysis( + cls, + ds, + pdb_file: pathlib.Path, + skip: int, + ligand_A_indices: list[int], + ligand_B_indices: list[int], + rdmol_A: Chem.Mol, + rdmol_B: Chem.Mol, + ) -> dict[str, Any]: + """ + Run structural analysis for the solvent phase. + + Parameters + ---------- + ds : netCDF4.Dataset + Open NetCDF dataset for the multistate trajectory. + pdb_file : pathlib.Path + Path to the subsampled PDB file. + skip : int + Frame stride for analysis. + ligand_A_indices : list[int] + Atom indices of ligand A in the subsampled system. + ligand_B_indices : list[int] + Atom indices of ligand B in the subsampled system. + rdmol_A : Chem.Mol + RDKit molecule for ligand A, used for symmetry-corrected RMSD. + rdmol_B : Chem.Mol + RDKit molecule for ligand B, used for symmetry-corrected RMSD. + + Returns + ------- + dict[str, Any] + Per-state lists for ligand RMSD and a single + time_ps array. + """ + from openfe_analysis.rmsd import SymmetryCorrectedLigandRMSD + from openfe_analysis.utils.apply_transformations import ( + apply_alignment_transformations, + ) + from openfe_analysis.utils.universe_utils import \ + create_universe_single_state + + n_lambda = ds.dimensions["state"].size + data: dict[str, Any] = { + "ligand_A_RMSD": [], + "ligand_B_RMSD": [], + "time_ps": None, + } + + for state_idx in range(n_lambda): + for label, indices, rdmol in [ + ("ligand_A", ligand_A_indices, rdmol_A), + ("ligand_B", ligand_B_indices, rdmol_B), + ]: + u = create_universe_single_state(pdb_file, ds, state=state_idx) + lig = u.atoms[indices] + apply_alignment_transformations(u, ligand=lig) + + lig_rmsd = SymmetryCorrectedLigandRMSD(lig, rdmol=rdmol).run( + step=skip) + data[f"{label}_RMSD"].append(lig_rmsd.results.rmsd) + + if data["time_ps"] is None: + data["time_ps"] = ( + np.arange(len(u.trajectory))[ + ::skip] * u.trajectory.dt + ) + + return data + + @classmethod + def _structural_analysis( + cls, + pdb_file: pathlib.Path, + trj_file: pathlib.Path, + output_directory: pathlib.Path, + dry: bool, + simtype: str, + ligand_A_indices: list[int], + ligand_B_indices: list[int], + rdmol_A: Chem.Mol, + rdmol_B: Chem.Mol, + ) -> dict[str, str | pathlib.Path]: + """ + Run structural analysis using ``openfe-analysis``. + + Parameters + ---------- + pdb_file : pathlib.Path + Path to the subsampled PDB file. + trj_file : pathlib.Path + Path to the trajectory NetCDF file. + output_directory : pathlib.Path + Directory where plots and the NPZ file will be written. + dry : bool + Whether or not we are running a dry run. + simtype : str + Either ``"complex"`` or ``"solvent"``. Controls whether protein + analyses are run and how alignment is applied. + ligand_A_indices : list[int] + Atom indices of ligand A in the subsampled system. + ligand_B_indices : list[int] + Atom indices of ligand B in the subsampled system. + rdmol_A : Chem.Mol + RDKit molecule for ligand A, used for symmetry-corrected RMSD. + rdmol_B : Chem.Mol + RDKit molecule for ligand B, used for symmetry-corrected RMSD. + + Returns + ------- + dict[str, str | pathlib.Path] + Dictionary containing either the path to the NPZ file with the + structural data, or the analysis error. + """ + import netCDF4 as nc + from openfe_analysis import plotting + + try: + with nc.Dataset(trj_file) as ds: + # Frame stride: aim for ~500 frames per state + if hasattr(ds, "PositionInterval"): + n_frames = len( + range(0, ds.dimensions["iteration"].size, + ds.PositionInterval) + ) + else: + n_frames = ds.dimensions["iteration"].size + skip = max(n_frames // 500, 1) + + if simtype == "complex": + data = cls._run_complex_analysis( + ds, pdb_file, skip, + ligand_A_indices, ligand_B_indices, + rdmol_A, rdmol_B, + ) + else: + data = cls._run_solvent_analysis( + ds, pdb_file, skip, + ligand_A_indices, ligand_B_indices, + rdmol_A, rdmol_B, + ) + + except Exception as e: + return {"structural_analysis_error": str(e)} + + if not dry: + time_ps = data.get("time_ps", []) + + if data.get("protein_2D_RMSD"): + fig = plotting.plot_2D_rmsd(data["protein_2D_RMSD"]) + fig.savefig(output_directory / "protein_2D_RMSD.png") + plt.close(fig) + + for label in ["ligand_A", "ligand_B"]: + if data.get(f"{label}_RMSD"): + fig = plotting.plot_ligand_RMSD(time_ps, + data[f"{label}_RMSD"]) + fig.savefig(output_directory / f"{label}_RMSD.png") + plt.close(fig) + + if data.get(f"{label}_COM_drift"): + fig = plotting.plot_ligand_COM_drift( + time_ps, data[f"{label}_COM_drift"] + ) + fig.savefig(output_directory / f"{label}_COM_drift.png") + plt.close(fig) + + npz_file = output_directory / "structural_analysis.npz" + npz_data = { + k: np.asarray(v, dtype=np.float32) + for k, v in data.items() + if v is not None + } + np.savez_compressed(npz_file, **npz_data) + return {"structural_analysis": npz_file} + def run( self, *, trajectory: pathlib.Path, checkpoint: pathlib.Path, + pdb_file: pathlib.Path, + ligand_A_indices: list[int], + ligand_B_indices: list[int], + rdmol_A: Chem.Mol, + rdmol_B: Chem.Mol, dry: bool = False, verbose: bool = True, scratch_basepath: pathlib.Path | None = None, @@ -1521,6 +1795,16 @@ def run( Path to the MultiStateReporter generated NetCDF file. checkpoint : pathlib.Path Path to the checkpoint file generated by MultiStateReporter. + pdb_file : pathlib.Path + Path to the subsampled PDB file. + ligand_A_indices : list[int] + Atom indices of ligand A in the subsampled system. + ligand_B_indices : list[int] + Atom indices of ligand B in the subsampled system. + rdmol_A : Chem.Mol + RDKit molecule for ligand A. + rdmol_B : Chem.Mol + RDKit molecule for ligand B. dry : bool Do a dry run of the calculation, creating all necessary hybrid system components (topology, system, sampler, etc...) but without @@ -1560,7 +1844,22 @@ def run( dry=dry, ) - return energy_analysis + if verbose: + self.logger.info("Analyzing structural outputs") + + structural_analysis = self._structural_analysis( + pdb_file=pdb_file, + trj_file=trajectory, + output_directory=self.shared_basepath, + dry=dry, + simtype=self.simtype, + ligand_A_indices=ligand_A_indices, + ligand_B_indices=ligand_B_indices, + rdmol_A=rdmol_A, + rdmol_B=rdmol_B, + ) + + return energy_analysis | structural_analysis def _execute( self, @@ -1579,14 +1878,22 @@ def _execute( trajectory = simulation.outputs["trajectory"] checkpoint = simulation.outputs["checkpoint"] + alchem_comps = self._inputs["alchemical_components"] + rdmol_A = alchem_comps["stateA"][0].to_rdkit() + rdmol_B = alchem_comps["stateB"][0].to_rdkit() + outputs = self.run( trajectory=trajectory, checkpoint=checkpoint, + pdb_file=setup.outputs["subsampled_pdb_structure"], + ligand_A_indices=setup.outputs["ligand_A_indices"], + ligand_B_indices=setup.outputs["ligand_B_indices"], + rdmol_A=rdmol_A, + rdmol_B=rdmol_B, scratch_basepath=ctx.scratch, shared_basepath=ctx.shared, ) - # We re-include things here to make life easier when gathering results if self.simtype == "complex": previous_outputs = { "standard_state_correction_A": setup.outputs["standard_state_correction_A"], diff --git a/src/openfe/protocols/openmm_septop/septop_units.py b/src/openfe/protocols/openmm_septop/septop_units.py index 6ad261837..12fde00e5 100644 --- a/src/openfe/protocols/openmm_septop/septop_units.py +++ b/src/openfe/protocols/openmm_septop/septop_units.py @@ -855,6 +855,8 @@ def run( "restraint_geometry_B": restraint_geom_B.model_dump(), "selection_indices": selection_indices, "subsampled_pdb_structure": sub_pdb_structure, + "ligand_A_indices": atom_indices_AB_A, + "ligand_B_indices": atom_indices_AB_B, } else: return { @@ -870,6 +872,8 @@ def run( "positions": equil_positions_AB, "selection_indices": selection_indices, "subsampled_pdb_structure": sub_pdb_structure, + "ligand_A_indices": atom_indices_AB_A, + "ligand_B_indices": atom_indices_AB_B, } @@ -1133,6 +1137,8 @@ def run( "standard_state_correction": corr.to("kilocalorie_per_mole"), "selection_indices": selection_indices, "subsampled_pdb_structure": sub_pdb_structure, + "ligand_A_indices": atom_indices_AB_A, + "ligand_B_indices": atom_indices_AB_B, } else: return { @@ -1146,6 +1152,8 @@ def run( "positions": positions_AB, "selection_indices": selection_indices, "subsampled_pdb_structure": sub_pdb_structure, + "ligand_A_indices": atom_indices_AB_A, + "ligand_B_indices": atom_indices_AB_B, } From b0538acead34a0cd2afefc45f367502fa4ce400d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 28 May 2026 10:06:16 +0000 Subject: [PATCH 02/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../protocols/openmm_septop/base_units.py | 62 ++++++++----------- 1 file changed, 25 insertions(+), 37 deletions(-) diff --git a/src/openfe/protocols/openmm_septop/base_units.py b/src/openfe/protocols/openmm_septop/base_units.py index dcafe5407..bb9a86162 100644 --- a/src/openfe/protocols/openmm_septop/base_units.py +++ b/src/openfe/protocols/openmm_septop/base_units.py @@ -1551,8 +1551,7 @@ def _run_complex_analysis( from openfe_analysis.utils.apply_transformations import ( apply_alignment_transformations, ) - from openfe_analysis.utils.universe_utils import \ - create_universe_single_state + from openfe_analysis.utils.universe_utils import create_universe_single_state n_lambda = ds.dimensions["state"].size data: dict[str, Any] = { @@ -1569,8 +1568,7 @@ def _run_complex_analysis( prot = u.select_atoms("protein and name CA") lig_A = u.atoms[ligand_A_indices] lig_B = u.atoms[ligand_B_indices] - apply_alignment_transformations(u, protein=prot, - ligand=lig_A + lig_B) + apply_alignment_transformations(u, protein=prot, ligand=lig_A + lig_B) if prot: prot_rmsd2d = Protein2DRMSD(prot).run(step=skip) @@ -1580,17 +1578,14 @@ def _run_complex_analysis( ("ligand_A", lig_A, rdmol_A), ("ligand_B", lig_B, rdmol_B), ]: - lig_rmsd = SymmetryCorrectedLigandRMSD(lig, rdmol=rdmol).run( - step=skip) + lig_rmsd = SymmetryCorrectedLigandRMSD(lig, rdmol=rdmol).run(step=skip) data[f"{label}_RMSD"].append(lig_rmsd.results.rmsd) lig_drift = LigandCOMDrift(lig).run(step=skip) data[f"{label}_COM_drift"].append(lig_drift.results.com_drift) if data["time_ps"] is None: - data["time_ps"] = ( - np.arange(len(u.trajectory))[::skip] * u.trajectory.dt - ) + data["time_ps"] = np.arange(len(u.trajectory))[::skip] * u.trajectory.dt return data @@ -1635,8 +1630,7 @@ def _run_solvent_analysis( from openfe_analysis.utils.apply_transformations import ( apply_alignment_transformations, ) - from openfe_analysis.utils.universe_utils import \ - create_universe_single_state + from openfe_analysis.utils.universe_utils import create_universe_single_state n_lambda = ds.dimensions["state"].size data: dict[str, Any] = { @@ -1654,15 +1648,11 @@ def _run_solvent_analysis( lig = u.atoms[indices] apply_alignment_transformations(u, ligand=lig) - lig_rmsd = SymmetryCorrectedLigandRMSD(lig, rdmol=rdmol).run( - step=skip) + lig_rmsd = SymmetryCorrectedLigandRMSD(lig, rdmol=rdmol).run(step=skip) data[f"{label}_RMSD"].append(lig_rmsd.results.rmsd) if data["time_ps"] is None: - data["time_ps"] = ( - np.arange(len(u.trajectory))[ - ::skip] * u.trajectory.dt - ) + data["time_ps"] = np.arange(len(u.trajectory))[::skip] * u.trajectory.dt return data @@ -1717,25 +1707,30 @@ def _structural_analysis( with nc.Dataset(trj_file) as ds: # Frame stride: aim for ~500 frames per state if hasattr(ds, "PositionInterval"): - n_frames = len( - range(0, ds.dimensions["iteration"].size, - ds.PositionInterval) - ) + n_frames = len(range(0, ds.dimensions["iteration"].size, ds.PositionInterval)) else: n_frames = ds.dimensions["iteration"].size skip = max(n_frames // 500, 1) if simtype == "complex": data = cls._run_complex_analysis( - ds, pdb_file, skip, - ligand_A_indices, ligand_B_indices, - rdmol_A, rdmol_B, + ds, + pdb_file, + skip, + ligand_A_indices, + ligand_B_indices, + rdmol_A, + rdmol_B, ) else: data = cls._run_solvent_analysis( - ds, pdb_file, skip, - ligand_A_indices, ligand_B_indices, - rdmol_A, rdmol_B, + ds, + pdb_file, + skip, + ligand_A_indices, + ligand_B_indices, + rdmol_A, + rdmol_B, ) except Exception as e: @@ -1751,24 +1746,17 @@ def _structural_analysis( for label in ["ligand_A", "ligand_B"]: if data.get(f"{label}_RMSD"): - fig = plotting.plot_ligand_RMSD(time_ps, - data[f"{label}_RMSD"]) + fig = plotting.plot_ligand_RMSD(time_ps, data[f"{label}_RMSD"]) fig.savefig(output_directory / f"{label}_RMSD.png") plt.close(fig) if data.get(f"{label}_COM_drift"): - fig = plotting.plot_ligand_COM_drift( - time_ps, data[f"{label}_COM_drift"] - ) + fig = plotting.plot_ligand_COM_drift(time_ps, data[f"{label}_COM_drift"]) fig.savefig(output_directory / f"{label}_COM_drift.png") plt.close(fig) npz_file = output_directory / "structural_analysis.npz" - npz_data = { - k: np.asarray(v, dtype=np.float32) - for k, v in data.items() - if v is not None - } + npz_data = {k: np.asarray(v, dtype=np.float32) for k, v in data.items() if v is not None} np.savez_compressed(npz_file, **npz_data) return {"structural_analysis": npz_file} From 0ae11d9b7c187e1a06d09df9c199f075da2823f1 Mon Sep 17 00:00:00 2001 From: hannahbaumann Date: Thu, 28 May 2026 14:45:37 +0200 Subject: [PATCH 03/17] Small fix --- src/openfe/protocols/openmm_septop/equil_septop_method.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/openfe/protocols/openmm_septop/equil_septop_method.py b/src/openfe/protocols/openmm_septop/equil_septop_method.py index 05d84a1bf..2f3fd8443 100644 --- a/src/openfe/protocols/openmm_septop/equil_septop_method.py +++ b/src/openfe/protocols/openmm_septop/equil_septop_method.py @@ -577,6 +577,9 @@ def _create( analysis = unit_classes[phase]["analysis"]( protocol=self, + stateA=stateA, + stateB=stateB, + alchemical_components=alchem_comps, setup=setup, simulation=simulation, generation=0, From 00d30667527ed7a360f4e5c27153388e996b494c Mon Sep 17 00:00:00 2001 From: hannahbaumann Date: Thu, 28 May 2026 15:41:34 +0200 Subject: [PATCH 04/17] Small fix --- src/openfe/protocols/openmm_septop/base_units.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/openfe/protocols/openmm_septop/base_units.py b/src/openfe/protocols/openmm_septop/base_units.py index bb9a86162..40569913c 100644 --- a/src/openfe/protocols/openmm_septop/base_units.py +++ b/src/openfe/protocols/openmm_septop/base_units.py @@ -1701,7 +1701,7 @@ def _structural_analysis( structural data, or the analysis error. """ import netCDF4 as nc - from openfe_analysis import plotting + from openfe_analysis.utils import plotting try: with nc.Dataset(trj_file) as ds: From 14381f9abcf16acbbd44a449729e5d58939d6b3b Mon Sep 17 00:00:00 2001 From: hannahbaumann Date: Fri, 29 May 2026 09:43:22 +0200 Subject: [PATCH 05/17] Fix ligand indices mapping --- src/openfe/protocols/openmm_septop/base_units.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/openfe/protocols/openmm_septop/base_units.py b/src/openfe/protocols/openmm_septop/base_units.py index 40569913c..2c3709250 100644 --- a/src/openfe/protocols/openmm_septop/base_units.py +++ b/src/openfe/protocols/openmm_septop/base_units.py @@ -1870,12 +1870,21 @@ def _execute( rdmol_A = alchem_comps["stateA"][0].to_rdkit() rdmol_B = alchem_comps["stateB"][0].to_rdkit() + # Remap ligand indices from full system to subsampled system + selection_indices = np.array(setup.outputs["selection_indices"]) + ligand_A_indices = np.where( + np.isin(selection_indices, setup.outputs["ligand_A_indices"]) + )[0].tolist() + ligand_B_indices = np.where( + np.isin(selection_indices, setup.outputs["ligand_B_indices"]) + )[0].tolist() + outputs = self.run( trajectory=trajectory, checkpoint=checkpoint, pdb_file=setup.outputs["subsampled_pdb_structure"], - ligand_A_indices=setup.outputs["ligand_A_indices"], - ligand_B_indices=setup.outputs["ligand_B_indices"], + ligand_A_indices=ligand_A_indices, + ligand_B_indices=ligand_B_indices, rdmol_A=rdmol_A, rdmol_B=rdmol_B, scratch_basepath=ctx.scratch, From 811a1b675ab28193cf9851c427fd93acb909cbbb Mon Sep 17 00:00:00 2001 From: hannahbaumann Date: Fri, 29 May 2026 12:42:02 +0200 Subject: [PATCH 06/17] Fix complex alignment --- src/openfe/protocols/openmm_septop/base_units.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/openfe/protocols/openmm_septop/base_units.py b/src/openfe/protocols/openmm_septop/base_units.py index 2c3709250..a9101b661 100644 --- a/src/openfe/protocols/openmm_septop/base_units.py +++ b/src/openfe/protocols/openmm_septop/base_units.py @@ -1549,7 +1549,7 @@ def _run_complex_analysis( SymmetryCorrectedLigandRMSD, ) from openfe_analysis.utils.apply_transformations import ( - apply_alignment_transformations, + apply_complex_alignment_transformations, ) from openfe_analysis.utils.universe_utils import create_universe_single_state @@ -1568,7 +1568,7 @@ def _run_complex_analysis( prot = u.select_atoms("protein and name CA") lig_A = u.atoms[ligand_A_indices] lig_B = u.atoms[ligand_B_indices] - apply_alignment_transformations(u, protein=prot, ligand=lig_A + lig_B) + apply_complex_alignment_transformations(u, protein=prot, ligand=[lig_A, lig_B]) if prot: prot_rmsd2d = Protein2DRMSD(prot).run(step=skip) @@ -1628,7 +1628,7 @@ def _run_solvent_analysis( """ from openfe_analysis.rmsd import SymmetryCorrectedLigandRMSD from openfe_analysis.utils.apply_transformations import ( - apply_alignment_transformations, + apply_ligand_alignment_transformations, ) from openfe_analysis.utils.universe_utils import create_universe_single_state @@ -1646,7 +1646,7 @@ def _run_solvent_analysis( ]: u = create_universe_single_state(pdb_file, ds, state=state_idx) lig = u.atoms[indices] - apply_alignment_transformations(u, ligand=lig) + apply_ligand_alignment_transformations(u, ligand=lig) lig_rmsd = SymmetryCorrectedLigandRMSD(lig, rdmol=rdmol).run(step=skip) data[f"{label}_RMSD"].append(lig_rmsd.results.rmsd) From 54d48392530aec27a71d90e267fef5a9195e0648 Mon Sep 17 00:00:00 2001 From: hannahbaumann Date: Fri, 29 May 2026 12:52:22 +0200 Subject: [PATCH 07/17] Small fix --- src/openfe/protocols/openmm_septop/base_units.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/openfe/protocols/openmm_septop/base_units.py b/src/openfe/protocols/openmm_septop/base_units.py index a9101b661..40b8e2d54 100644 --- a/src/openfe/protocols/openmm_septop/base_units.py +++ b/src/openfe/protocols/openmm_septop/base_units.py @@ -1568,7 +1568,7 @@ def _run_complex_analysis( prot = u.select_atoms("protein and name CA") lig_A = u.atoms[ligand_A_indices] lig_B = u.atoms[ligand_B_indices] - apply_complex_alignment_transformations(u, protein=prot, ligand=[lig_A, lig_B]) + apply_complex_alignment_transformations(u, protein=prot, ligands=[lig_A, lig_B]) if prot: prot_rmsd2d = Protein2DRMSD(prot).run(step=skip) From e1c101ee735489ffa1a94474869443839d977cec Mon Sep 17 00:00:00 2001 From: hannahbaumann Date: Fri, 29 May 2026 14:53:30 +0200 Subject: [PATCH 08/17] Add ligand indices to test --- .../protocols/openmm_septop/test_septop_protocol_results.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/openfe/tests/protocols/openmm_septop/test_septop_protocol_results.py b/src/openfe/tests/protocols/openmm_septop/test_septop_protocol_results.py index fded8adff..cfca396a6 100644 --- a/src/openfe/tests/protocols/openmm_septop/test_septop_protocol_results.py +++ b/src/openfe/tests/protocols/openmm_septop/test_septop_protocol_results.py @@ -76,6 +76,8 @@ def patcher(): ] ), "subsampled_pdb_structure": "subsampled.pdb", + "ligand_A_indices": [0], + "ligand_B_indices": [0], }, ), mock.patch( @@ -90,6 +92,8 @@ def patcher(): ] ), "subsampled_pdb_structure": "subsampled.pdb", + "ligand_A_indices": [0], + "ligand_B_indices": [0], }, ), mock.patch( From 14ce7fd3eab778f0197bee53267da3e180a5119c Mon Sep 17 00:00:00 2001 From: hannahbaumann Date: Mon, 1 Jun 2026 15:40:09 +0200 Subject: [PATCH 09/17] Some updates --- .../protocols/openmm_septop/base_units.py | 53 +++++++++++++------ 1 file changed, 36 insertions(+), 17 deletions(-) diff --git a/src/openfe/protocols/openmm_septop/base_units.py b/src/openfe/protocols/openmm_septop/base_units.py index 40b8e2d54..8564ff27e 100644 --- a/src/openfe/protocols/openmm_septop/base_units.py +++ b/src/openfe/protocols/openmm_septop/base_units.py @@ -21,6 +21,7 @@ import gufe import matplotlib.pyplot as plt +import MDAnalysis as mda import numpy as np import numpy.typing as npt import openmm @@ -1516,6 +1517,7 @@ def _run_complex_analysis( ligand_B_indices: list[int], rdmol_A: Chem.Mol, rdmol_B: Chem.Mol, + protein_selection: str = "protein and name CA", ) -> dict[str, Any]: """ Run structural analysis for the complex phase. @@ -1536,6 +1538,9 @@ def _run_complex_analysis( RDKit molecule for ligand A, used for symmetry-corrected RMSD. rdmol_B : Chem.Mol RDKit molecule for ligand B, used for symmetry-corrected RMSD. + protein_selection : str + MDAnalysis selection string for the protein atoms used for + alignment and RMSD calculations. Default: "protein and name CA". Returns ------- @@ -1562,13 +1567,13 @@ def _run_complex_analysis( "protein_2D_RMSD": [], "time_ps": None, } - + u_top = mda.Universe(pdb_file) for state_idx in range(n_lambda): - u = create_universe_single_state(pdb_file, ds, state=state_idx) - prot = u.select_atoms("protein and name CA") - lig_A = u.atoms[ligand_A_indices] - lig_B = u.atoms[ligand_B_indices] - apply_complex_alignment_transformations(u, protein=prot, ligands=[lig_A, lig_B]) + universe = create_universe_single_state(u_top._topology, ds, state=state_idx) + prot = universe.select_atoms(protein_selection) + lig_A = universe.atoms[ligand_A_indices] + lig_B = universe.atoms[ligand_B_indices] + apply_complex_alignment_transformations(universe, protein=prot, ligands=[lig_A, lig_B]) if prot: prot_rmsd2d = Protein2DRMSD(prot).run(step=skip) @@ -1585,7 +1590,7 @@ def _run_complex_analysis( data[f"{label}_COM_drift"].append(lig_drift.results.com_drift) if data["time_ps"] is None: - data["time_ps"] = np.arange(len(u.trajectory))[::skip] * u.trajectory.dt + data["time_ps"] = np.arange(len(universe.trajectory))[::skip] * universe.trajectory.dt return data @@ -1638,21 +1643,21 @@ def _run_solvent_analysis( "ligand_B_RMSD": [], "time_ps": None, } - + u_top = mda.Universe(pdb_file) for state_idx in range(n_lambda): for label, indices, rdmol in [ ("ligand_A", ligand_A_indices, rdmol_A), ("ligand_B", ligand_B_indices, rdmol_B), ]: - u = create_universe_single_state(pdb_file, ds, state=state_idx) - lig = u.atoms[indices] - apply_ligand_alignment_transformations(u, ligand=lig) + universe = create_universe_single_state(u_top._topology, ds, state=state_idx) + lig = universe.atoms[indices] + apply_ligand_alignment_transformations(universe, ligand=lig) lig_rmsd = SymmetryCorrectedLigandRMSD(lig, rdmol=rdmol).run(step=skip) data[f"{label}_RMSD"].append(lig_rmsd.results.rmsd) if data["time_ps"] is None: - data["time_ps"] = np.arange(len(u.trajectory))[::skip] * u.trajectory.dt + data["time_ps"] = np.arange(len(universe.trajectory))[::skip] * universe.trajectory.dt return data @@ -1668,6 +1673,7 @@ def _structural_analysis( ligand_B_indices: list[int], rdmol_A: Chem.Mol, rdmol_B: Chem.Mol, + protein_selection: str = "protein and name CA", ) -> dict[str, str | pathlib.Path]: """ Run structural analysis using ``openfe-analysis``. @@ -1675,11 +1681,12 @@ def _structural_analysis( Parameters ---------- pdb_file : pathlib.Path - Path to the subsampled PDB file. + Path to the PDB file. trj_file : pathlib.Path Path to the trajectory NetCDF file. output_directory : pathlib.Path - Directory where plots and the NPZ file will be written. + The output directory where plots and the data NPZ file + will be stored. dry : bool Whether or not we are running a dry run. simtype : str @@ -1693,6 +1700,10 @@ def _structural_analysis( RDKit molecule for ligand A, used for symmetry-corrected RMSD. rdmol_B : Chem.Mol RDKit molecule for ligand B, used for symmetry-corrected RMSD. + protein_selection : str + MDAnalysis selection string for the protein atoms used for + alignment and RMSD calculations in the complex phase. + Ignored for the solvent phase. Default "protein and name CA". Returns ------- @@ -1705,7 +1716,6 @@ def _structural_analysis( try: with nc.Dataset(trj_file) as ds: - # Frame stride: aim for ~500 frames per state if hasattr(ds, "PositionInterval"): n_frames = len(range(0, ds.dimensions["iteration"].size, ds.PositionInterval)) else: @@ -1721,6 +1731,7 @@ def _structural_analysis( ligand_B_indices, rdmol_A, rdmol_B, + protein_selection=protein_selection, ) else: data = cls._run_solvent_analysis( @@ -1736,6 +1747,7 @@ def _structural_analysis( except Exception as e: return {"structural_analysis_error": str(e)} + # Generate relevant plots if not a dry run if not dry: time_ps = data.get("time_ps", []) @@ -1755,6 +1767,7 @@ def _structural_analysis( fig.savefig(output_directory / f"{label}_COM_drift.png") plt.close(fig) + # Write out an NPZ with all the relevant analysis data npz_file = output_directory / "structural_analysis.npz" npz_data = {k: np.asarray(v, dtype=np.float32) for k, v in data.items() if v is not None} np.savez_compressed(npz_file, **npz_data) @@ -1770,6 +1783,7 @@ def run( ligand_B_indices: list[int], rdmol_A: Chem.Mol, rdmol_B: Chem.Mol, + protein_selection: str = "protein and name CA", dry: bool = False, verbose: bool = True, scratch_basepath: pathlib.Path | None = None, @@ -1793,6 +1807,10 @@ def run( RDKit molecule for ligand A. rdmol_B : Chem.Mol RDKit molecule for ligand B. + protein_selection : str + MDAnalysis selection string for the protein atoms used for + alignment and RMSD calculations in the complex phase. + Ignored for the solvent phase. Default "protein and name CA". dry : bool Do a dry run of the calculation, creating all necessary hybrid system components (topology, system, sampler, etc...) but without @@ -1821,7 +1839,7 @@ def run( settings = self._get_settings() # Energies analysis - if verbose: + if self.verbose: self.logger.info("Analyzing energies") energy_analysis = self._analyze_multistate_energies( @@ -1832,7 +1850,7 @@ def run( dry=dry, ) - if verbose: + if self.verbose: self.logger.info("Analyzing structural outputs") structural_analysis = self._structural_analysis( @@ -1845,6 +1863,7 @@ def run( ligand_B_indices=ligand_B_indices, rdmol_A=rdmol_A, rdmol_B=rdmol_B, + protein_selection=protein_selection, ) return energy_analysis | structural_analysis From d9538480c2c541aa695bab07fc99cb58a28447b5 Mon Sep 17 00:00:00 2001 From: hannahbaumann Date: Tue, 2 Jun 2026 11:11:44 +0200 Subject: [PATCH 10/17] Add tests for structural analysis --- .../test_septop_protocol_analysis.py | 219 ++++++++++++++++++ 1 file changed, 219 insertions(+) create mode 100644 src/openfe/tests/protocols/openmm_septop/test_septop_protocol_analysis.py diff --git a/src/openfe/tests/protocols/openmm_septop/test_septop_protocol_analysis.py b/src/openfe/tests/protocols/openmm_septop/test_septop_protocol_analysis.py new file mode 100644 index 000000000..fafd7307d --- /dev/null +++ b/src/openfe/tests/protocols/openmm_septop/test_septop_protocol_analysis.py @@ -0,0 +1,219 @@ +import pathlib +import numpy as np +import pytest +from rdkit.Chem import SDMolSupplier +from openfe.protocols.openmm_septop.base_units import BaseSepTopAnalysisUnit + + +import pooch +from openfe.data._registry import POOCH_CACHE + +pooch_septop_structural = pooch.create( + path=POOCH_CACHE, + base_url="https://zenodo.org/records/20507106/files/", + registry={ + "septop_structural_results.zip": "md5:cffc193dacb5ad8d26c5467ae05b1a00", + }, +) + + +@pytest.fixture(scope="session") +def septop_structural_results_dir(): + pooch_septop_structural.fetch( + "septop_structural_results.zip", processor=pooch.Unzip() + ) + return pathlib.Path( + POOCH_CACHE / "septop_structural_results.zip.unzip/septop_structural_results" + ) + + +@pytest.fixture(scope="session") +def septop_complex_data(septop_structural_results_dir): + base = septop_structural_results_dir / "complex" + return { + "pdb": base / "alchemical_system.pdb", + "nc": base / "complex.nc", + "ligand_A_indices": np.load(base / "ligand_A_indices.npy").tolist(), + "ligand_B_indices": np.load(base / "ligand_B_indices.npy").tolist(), + "rdmol_A": SDMolSupplier(str(base / "ligand_A.sdf"), removeHs=False)[0], + "rdmol_B": SDMolSupplier(str(base / "ligand_B.sdf"), removeHs=False)[0], + } + + +@pytest.fixture(scope="session") +def septop_solvent_data(septop_structural_results_dir): + base = septop_structural_results_dir / "solvent" + return { + "pdb": base / "alchemical_system.pdb", + "nc": base / "solvent.nc", + "ligand_A_indices": np.load(base / "ligand_A_indices.npy").tolist(), + "ligand_B_indices": np.load(base / "ligand_B_indices.npy").tolist(), + "rdmol_A": SDMolSupplier(str(base / "ligand_A.sdf"), removeHs=False)[0], + "rdmol_B": SDMolSupplier(str(base / "ligand_B.sdf"), removeHs=False)[0], + } + + +@pytest.fixture(scope="session") +def complex_structural_analysis_result(septop_complex_data, tmp_path_factory): + d = septop_complex_data + tmp = tmp_path_factory.mktemp("complex_structural_analysis") + result = BaseSepTopAnalysisUnit._structural_analysis( + pdb_file=d["pdb"], + trj_file=d["nc"], + output_directory=tmp, + dry=False, + simtype="complex", + ligand_A_indices=d["ligand_A_indices"], + ligand_B_indices=d["ligand_B_indices"], + rdmol_A=d["rdmol_A"], + rdmol_B=d["rdmol_B"], + ) + return result, tmp + + +@pytest.fixture(scope="session") +def solvent_structural_analysis_result(septop_solvent_data, tmp_path_factory): + d = septop_solvent_data + tmp = tmp_path_factory.mktemp("solvent_structural_analysis") + result = BaseSepTopAnalysisUnit._structural_analysis( + pdb_file=d["pdb"], + trj_file=d["nc"], + output_directory=tmp, + dry=False, + simtype="solvent", + ligand_A_indices=d["ligand_A_indices"], + ligand_B_indices=d["ligand_B_indices"], + rdmol_A=d["rdmol_A"], + rdmol_B=d["rdmol_B"], + ) + return result, tmp + + +class TestComplexStructuralAnalysis: + + def test_npz_written(self, complex_structural_analysis_result): + result, _ = complex_structural_analysis_result + assert "structural_analysis" in result + assert result["structural_analysis"].exists() + + def test_npz_keys_values_shape(self, complex_structural_analysis_result): + result, _ = complex_structural_analysis_result + npz = np.load(result["structural_analysis"]) + expected_keys = { + "ligand_A_RMSD", "ligand_B_RMSD", + "ligand_A_COM_drift", "ligand_B_COM_drift", + "protein_2D_RMSD", "time_ps", + } + assert set(npz.files) == expected_keys + + # First frame RMSD should be zero (reference frame) + assert npz["ligand_A_RMSD"][0][0] == pytest.approx(0.0, abs=1e-5) + assert npz["ligand_B_RMSD"][0][0] == pytest.approx(0.0, abs=1e-5) + assert npz["ligand_A_COM_drift"][0][0] == pytest.approx(0.0, abs=1e-5) + assert npz["ligand_B_COM_drift"][0][0] == pytest.approx(0.0, abs=1e-5) + + # Time should start at zero + assert npz["time_ps"][0] == pytest.approx(0.0, abs=1e-5) + + # All RMSD and COM drift values should be non-negative + for state in npz["ligand_A_RMSD"]: + assert np.all(state >= 0) + for state in npz["ligand_B_RMSD"]: + assert np.all(state >= 0) + for state in npz["ligand_A_COM_drift"]: + assert np.all(state >= 0) + for state in npz["ligand_B_COM_drift"]: + assert np.all(state >= 0) + + # Should be 13 lambda windows + n_lambda = 13 + for key in expected_keys - {"time_ps"}: + assert len(npz[key]) == n_lambda + + def test_plots_written(self, complex_structural_analysis_result): + result, tmp = complex_structural_analysis_result + expected_plots = { + "ligand_A_RMSD.png", "ligand_B_RMSD.png", + "ligand_A_COM_drift.png", "ligand_B_COM_drift.png", + "protein_2D_RMSD.png", + } + written = {f.name for f in tmp.glob("*.png")} + assert written == expected_plots + + + def test_bad_trajectory_returns_error_dict( + self, septop_complex_data, tmp_path + ): + d = septop_complex_data + result = BaseSepTopAnalysisUnit._structural_analysis( + pdb_file=d["pdb"], + trj_file=tmp_path / "nonexistent.nc", + output_directory=tmp_path, + dry=True, + simtype="complex", + ligand_A_indices=d["ligand_A_indices"], + ligand_B_indices=d["ligand_B_indices"], + rdmol_A=d["rdmol_A"], + rdmol_B=d["rdmol_B"], + ) + + assert "structural_analysis_error" in result + assert "structural_analysis" not in result + + +class TestSolventStructuralAnalysis: + + def test_npz_written(self, solvent_structural_analysis_result): + result, _ = solvent_structural_analysis_result + assert "structural_analysis" in result + assert result["structural_analysis"].exists() + + def test_npz_keys_values_shape(self, solvent_structural_analysis_result): + result, _ = solvent_structural_analysis_result + npz = np.load(result["structural_analysis"]) + expected_keys = {"ligand_A_RMSD", "ligand_B_RMSD", "time_ps"} + assert set(npz.files) == expected_keys + + # First frame RMSD should be zero (reference frame) + assert npz["ligand_A_RMSD"][0][0] == pytest.approx(0.0, abs=1e-5) + assert npz["ligand_B_RMSD"][0][0] == pytest.approx(0.0, abs=1e-5) + + # Time should start at zero + assert npz["time_ps"][0] == pytest.approx(0.0, abs=1e-5) + + # All RMSD and COM drift values should be non-negative + for state in npz["ligand_A_RMSD"]: + assert np.all(state >= 0) + for state in npz["ligand_B_RMSD"]: + assert np.all(state >= 0) + + # Should be 13 lambda windows + n_lambda = 13 + for key in expected_keys - {"time_ps"}: + assert len(npz[key]) == n_lambda + + def test_plots_written(self, solvent_structural_analysis_result): + result, tmp = solvent_structural_analysis_result + expected_plots = {"ligand_A_RMSD.png", "ligand_B_RMSD.png"} + written = {f.name for f in tmp.glob("*.png")} + assert written == expected_plots + + + def test_bad_trajectory_returns_error_dict( + self, septop_solvent_data, tmp_path + ): + d = septop_solvent_data + result = BaseSepTopAnalysisUnit._structural_analysis( + pdb_file=d["pdb"], + trj_file=tmp_path / "nonexistent.nc", + output_directory=tmp_path, + dry=True, + simtype="solvent", + ligand_A_indices=d["ligand_A_indices"], + ligand_B_indices=d["ligand_B_indices"], + rdmol_A=d["rdmol_A"], + rdmol_B=d["rdmol_B"], + ) + + assert "structural_analysis_error" in result + assert "structural_analysis" not in result \ No newline at end of file From 850b1a5b9bb04285986a85786e46ca5d7f06dd6f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 2 Jun 2026 09:12:14 +0000 Subject: [PATCH 11/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../protocols/openmm_septop/base_units.py | 20 ++++++---- .../test_septop_protocol_analysis.py | 39 ++++++++----------- 2 files changed, 29 insertions(+), 30 deletions(-) diff --git a/src/openfe/protocols/openmm_septop/base_units.py b/src/openfe/protocols/openmm_septop/base_units.py index 8564ff27e..253797ddd 100644 --- a/src/openfe/protocols/openmm_septop/base_units.py +++ b/src/openfe/protocols/openmm_septop/base_units.py @@ -1590,7 +1590,9 @@ def _run_complex_analysis( data[f"{label}_COM_drift"].append(lig_drift.results.com_drift) if data["time_ps"] is None: - data["time_ps"] = np.arange(len(universe.trajectory))[::skip] * universe.trajectory.dt + data["time_ps"] = ( + np.arange(len(universe.trajectory))[::skip] * universe.trajectory.dt + ) return data @@ -1657,7 +1659,9 @@ def _run_solvent_analysis( data[f"{label}_RMSD"].append(lig_rmsd.results.rmsd) if data["time_ps"] is None: - data["time_ps"] = np.arange(len(universe.trajectory))[::skip] * universe.trajectory.dt + data["time_ps"] = ( + np.arange(len(universe.trajectory))[::skip] * universe.trajectory.dt + ) return data @@ -1891,12 +1895,12 @@ def _execute( # Remap ligand indices from full system to subsampled system selection_indices = np.array(setup.outputs["selection_indices"]) - ligand_A_indices = np.where( - np.isin(selection_indices, setup.outputs["ligand_A_indices"]) - )[0].tolist() - ligand_B_indices = np.where( - np.isin(selection_indices, setup.outputs["ligand_B_indices"]) - )[0].tolist() + ligand_A_indices = np.where(np.isin(selection_indices, setup.outputs["ligand_A_indices"]))[ + 0 + ].tolist() + ligand_B_indices = np.where(np.isin(selection_indices, setup.outputs["ligand_B_indices"]))[ + 0 + ].tolist() outputs = self.run( trajectory=trajectory, diff --git a/src/openfe/tests/protocols/openmm_septop/test_septop_protocol_analysis.py b/src/openfe/tests/protocols/openmm_septop/test_septop_protocol_analysis.py index fafd7307d..5d57acf4e 100644 --- a/src/openfe/tests/protocols/openmm_septop/test_septop_protocol_analysis.py +++ b/src/openfe/tests/protocols/openmm_septop/test_septop_protocol_analysis.py @@ -1,12 +1,12 @@ import pathlib + import numpy as np +import pooch import pytest from rdkit.Chem import SDMolSupplier -from openfe.protocols.openmm_septop.base_units import BaseSepTopAnalysisUnit - -import pooch from openfe.data._registry import POOCH_CACHE +from openfe.protocols.openmm_septop.base_units import BaseSepTopAnalysisUnit pooch_septop_structural = pooch.create( path=POOCH_CACHE, @@ -19,9 +19,7 @@ @pytest.fixture(scope="session") def septop_structural_results_dir(): - pooch_septop_structural.fetch( - "septop_structural_results.zip", processor=pooch.Unzip() - ) + pooch_septop_structural.fetch("septop_structural_results.zip", processor=pooch.Unzip()) return pathlib.Path( POOCH_CACHE / "septop_structural_results.zip.unzip/septop_structural_results" ) @@ -90,7 +88,6 @@ def solvent_structural_analysis_result(septop_solvent_data, tmp_path_factory): class TestComplexStructuralAnalysis: - def test_npz_written(self, complex_structural_analysis_result): result, _ = complex_structural_analysis_result assert "structural_analysis" in result @@ -100,9 +97,12 @@ def test_npz_keys_values_shape(self, complex_structural_analysis_result): result, _ = complex_structural_analysis_result npz = np.load(result["structural_analysis"]) expected_keys = { - "ligand_A_RMSD", "ligand_B_RMSD", - "ligand_A_COM_drift", "ligand_B_COM_drift", - "protein_2D_RMSD", "time_ps", + "ligand_A_RMSD", + "ligand_B_RMSD", + "ligand_A_COM_drift", + "ligand_B_COM_drift", + "protein_2D_RMSD", + "time_ps", } assert set(npz.files) == expected_keys @@ -133,17 +133,16 @@ def test_npz_keys_values_shape(self, complex_structural_analysis_result): def test_plots_written(self, complex_structural_analysis_result): result, tmp = complex_structural_analysis_result expected_plots = { - "ligand_A_RMSD.png", "ligand_B_RMSD.png", - "ligand_A_COM_drift.png", "ligand_B_COM_drift.png", + "ligand_A_RMSD.png", + "ligand_B_RMSD.png", + "ligand_A_COM_drift.png", + "ligand_B_COM_drift.png", "protein_2D_RMSD.png", } written = {f.name for f in tmp.glob("*.png")} assert written == expected_plots - - def test_bad_trajectory_returns_error_dict( - self, septop_complex_data, tmp_path - ): + def test_bad_trajectory_returns_error_dict(self, septop_complex_data, tmp_path): d = septop_complex_data result = BaseSepTopAnalysisUnit._structural_analysis( pdb_file=d["pdb"], @@ -162,7 +161,6 @@ def test_bad_trajectory_returns_error_dict( class TestSolventStructuralAnalysis: - def test_npz_written(self, solvent_structural_analysis_result): result, _ = solvent_structural_analysis_result assert "structural_analysis" in result @@ -198,10 +196,7 @@ def test_plots_written(self, solvent_structural_analysis_result): written = {f.name for f in tmp.glob("*.png")} assert written == expected_plots - - def test_bad_trajectory_returns_error_dict( - self, septop_solvent_data, tmp_path - ): + def test_bad_trajectory_returns_error_dict(self, septop_solvent_data, tmp_path): d = septop_solvent_data result = BaseSepTopAnalysisUnit._structural_analysis( pdb_file=d["pdb"], @@ -216,4 +211,4 @@ def test_bad_trajectory_returns_error_dict( ) assert "structural_analysis_error" in result - assert "structural_analysis" not in result \ No newline at end of file + assert "structural_analysis" not in result From 53947857b79af74da9351e6882707075eba48337 Mon Sep 17 00:00:00 2001 From: hannahbaumann Date: Tue, 2 Jun 2026 11:33:50 +0200 Subject: [PATCH 12/17] fix mypy --- .../protocols/openmm_septop/base_units.py | 26 +++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/src/openfe/protocols/openmm_septop/base_units.py b/src/openfe/protocols/openmm_septop/base_units.py index 253797ddd..9ce6f7876 100644 --- a/src/openfe/protocols/openmm_septop/base_units.py +++ b/src/openfe/protocols/openmm_septop/base_units.py @@ -1773,8 +1773,30 @@ def _structural_analysis( # Write out an NPZ with all the relevant analysis data npz_file = output_directory / "structural_analysis.npz" - npz_data = {k: np.asarray(v, dtype=np.float32) for k, v in data.items() if v is not None} - np.savez_compressed(npz_file, **npz_data) + if simtype == "complex": + np.savez_compressed( + npz_file, + ligand_A_RMSD=np.asarray(data["ligand_A_RMSD"], + dtype=np.float32), + ligand_B_RMSD=np.asarray(data["ligand_B_RMSD"], + dtype=np.float32), + ligand_A_COM_drift=np.asarray(data["ligand_A_COM_drift"], + dtype=np.float32), + ligand_B_COM_drift=np.asarray(data["ligand_B_COM_drift"], + dtype=np.float32), + protein_2D_RMSD=np.asarray(data["protein_2D_RMSD"], + dtype=np.float32), + time_ps=np.asarray(data["time_ps"], dtype=np.float32), + ) + else: + np.savez_compressed( + npz_file, + ligand_A_RMSD=np.asarray(data["ligand_A_RMSD"], + dtype=np.float32), + ligand_B_RMSD=np.asarray(data["ligand_B_RMSD"], + dtype=np.float32), + time_ps=np.asarray(data["time_ps"], dtype=np.float32), + ) return {"structural_analysis": npz_file} def run( From 62a34dcba2d94a0436c8146372881e84f4f78e57 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 2 Jun 2026 09:34:23 +0000 Subject: [PATCH 13/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../protocols/openmm_septop/base_units.py | 21 +++++++------------ 1 file changed, 7 insertions(+), 14 deletions(-) diff --git a/src/openfe/protocols/openmm_septop/base_units.py b/src/openfe/protocols/openmm_septop/base_units.py index 9ce6f7876..7d1724013 100644 --- a/src/openfe/protocols/openmm_septop/base_units.py +++ b/src/openfe/protocols/openmm_septop/base_units.py @@ -1776,25 +1776,18 @@ def _structural_analysis( if simtype == "complex": np.savez_compressed( npz_file, - ligand_A_RMSD=np.asarray(data["ligand_A_RMSD"], - dtype=np.float32), - ligand_B_RMSD=np.asarray(data["ligand_B_RMSD"], - dtype=np.float32), - ligand_A_COM_drift=np.asarray(data["ligand_A_COM_drift"], - dtype=np.float32), - ligand_B_COM_drift=np.asarray(data["ligand_B_COM_drift"], - dtype=np.float32), - protein_2D_RMSD=np.asarray(data["protein_2D_RMSD"], - dtype=np.float32), + ligand_A_RMSD=np.asarray(data["ligand_A_RMSD"], dtype=np.float32), + ligand_B_RMSD=np.asarray(data["ligand_B_RMSD"], dtype=np.float32), + ligand_A_COM_drift=np.asarray(data["ligand_A_COM_drift"], dtype=np.float32), + ligand_B_COM_drift=np.asarray(data["ligand_B_COM_drift"], dtype=np.float32), + protein_2D_RMSD=np.asarray(data["protein_2D_RMSD"], dtype=np.float32), time_ps=np.asarray(data["time_ps"], dtype=np.float32), ) else: np.savez_compressed( npz_file, - ligand_A_RMSD=np.asarray(data["ligand_A_RMSD"], - dtype=np.float32), - ligand_B_RMSD=np.asarray(data["ligand_B_RMSD"], - dtype=np.float32), + ligand_A_RMSD=np.asarray(data["ligand_A_RMSD"], dtype=np.float32), + ligand_B_RMSD=np.asarray(data["ligand_B_RMSD"], dtype=np.float32), time_ps=np.asarray(data["time_ps"], dtype=np.float32), ) return {"structural_analysis": npz_file} From d1e3abe9d00da66d35c714414fb2e26db671c8fa Mon Sep 17 00:00:00 2001 From: hannahbaumann Date: Tue, 2 Jun 2026 12:15:58 +0200 Subject: [PATCH 14/17] Update env --- environment.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/environment.yml b/environment.yml index 247f36f1c..9da35a67c 100644 --- a/environment.yml +++ b/environment.yml @@ -12,7 +12,6 @@ dependencies: - lomap2>=3.2.1 - networkx - numpy - - openfe-analysis>=0.4.0 # min pin https://github.com/OpenFreeEnergy/openfe/issues/1834#issuecomment-3920079481, no max to check issues with new versions - openff-interchange-base >=0.5.0,!= 0.5.1 # https://github.com/openforcefield/openff-interchange/issues/1450 and https://github.com/OpenFreeEnergy/openfe/pull/1901 - openff-nagl-base >=0.3.3 - openff-nagl-models>=0.1.2 @@ -54,6 +53,7 @@ dependencies: - threadpoolctl - pip: - git+https://github.com/OpenFreeEnergy/gufe@main + - git+https://github.com/OpenFreeEnergy/openfe_analysis@main - run_constrained: # drop this pin when handled upstream in espaloma-feedstock - smirnoff99frosst>=1.1.0.1 #https://github.com/openforcefield/smirnoff99Frosst/issues/109 From 14d70723c70dbba4e8984c2dfc702136671b6da2 Mon Sep 17 00:00:00 2001 From: hannahbaumann Date: Thu, 4 Jun 2026 14:59:37 +0200 Subject: [PATCH 15/17] Handle empty ligand case --- .../protocols/openmm_septop/base_units.py | 86 ++++++++++++------- .../test_septop_protocol_analysis.py | 34 ++++++++ 2 files changed, 90 insertions(+), 30 deletions(-) diff --git a/src/openfe/protocols/openmm_septop/base_units.py b/src/openfe/protocols/openmm_septop/base_units.py index 7d1724013..15af4e56d 100644 --- a/src/openfe/protocols/openmm_septop/base_units.py +++ b/src/openfe/protocols/openmm_septop/base_units.py @@ -1517,7 +1517,7 @@ def _run_complex_analysis( ligand_B_indices: list[int], rdmol_A: Chem.Mol, rdmol_B: Chem.Mol, - protein_selection: str = "protein and name CA", + protein_selection: str, ) -> dict[str, Any]: """ Run structural analysis for the complex phase. @@ -1540,7 +1540,7 @@ def _run_complex_analysis( RDKit molecule for ligand B, used for symmetry-corrected RMSD. protein_selection : str MDAnalysis selection string for the protein atoms used for - alignment and RMSD calculations. Default: "protein and name CA". + alignment and RMSD calculations. Returns ------- @@ -1568,9 +1568,10 @@ def _run_complex_analysis( "time_ps": None, } u_top = mda.Universe(pdb_file) + prot_indices = u_top.select_atoms(protein_selection).indices for state_idx in range(n_lambda): universe = create_universe_single_state(u_top._topology, ds, state=state_idx) - prot = universe.select_atoms(protein_selection) + prot = universe.atoms[prot_indices] lig_A = universe.atoms[ligand_A_indices] lig_B = universe.atoms[ligand_B_indices] apply_complex_alignment_transformations(universe, protein=prot, ligands=[lig_A, lig_B]) @@ -1672,12 +1673,13 @@ def _structural_analysis( trj_file: pathlib.Path, output_directory: pathlib.Path, dry: bool, - simtype: str, + simtype: Literal["complex", "solvent"], ligand_A_indices: list[int], ligand_B_indices: list[int], rdmol_A: Chem.Mol, rdmol_B: Chem.Mol, - protein_selection: str = "protein and name CA", + protein_selection: str, + skip: int, ) -> dict[str, str | pathlib.Path]: """ Run structural analysis using ``openfe-analysis``. @@ -1693,7 +1695,7 @@ def _structural_analysis( will be stored. dry : bool Whether or not we are running a dry run. - simtype : str + simtype : Literal["complex", "solvent"] Either ``"complex"`` or ``"solvent"``. Controls whether protein analyses are run and how alignment is applied. ligand_A_indices : list[int] @@ -1707,7 +1709,11 @@ def _structural_analysis( protein_selection : str MDAnalysis selection string for the protein atoms used for alignment and RMSD calculations in the complex phase. - Ignored for the solvent phase. Default "protein and name CA". + Ignored for the solvent phase. + skip : int + Frame stride for structural analysis. If ``None``, a stride is + chosen such that approximately (max.) 500 frames are analyzed per state. + Set to 1 to analyze every frame. Returns ------- @@ -1718,13 +1724,27 @@ def _structural_analysis( import netCDF4 as nc from openfe_analysis.utils import plotting + if len(ligand_A_indices) == 0 or len(ligand_B_indices) == 0: + errmsg = ( + "No ligand atoms found in the subsampled trajectory. " + "This likely means the output_indices selection does not include " + "the ligands. Check the output_indices setting in your protocol " + "settings to ensure ligand atoms are included in the trajectory output." + ) + logger.warning(errmsg) + return {"structural_analysis_error": errmsg} + try: with nc.Dataset(trj_file) as ds: if hasattr(ds, "PositionInterval"): n_frames = len(range(0, ds.dimensions["iteration"].size, ds.PositionInterval)) else: n_frames = ds.dimensions["iteration"].size - skip = max(n_frames // 500, 1) + + if skip is None: + # find skip that would give ~500 frames of output + # max against 1 to avoid skip=0 case + skip = max(n_frames // 500, 1) if simtype == "complex": data = cls._run_complex_analysis( @@ -1773,23 +1793,21 @@ def _structural_analysis( # Write out an NPZ with all the relevant analysis data npz_file = output_directory / "structural_analysis.npz" + npz_data = { + "ligand_A_RMSD": np.asarray(data["ligand_A_RMSD"], dtype=np.float32), + "ligand_B_RMSD": np.asarray(data["ligand_B_RMSD"], dtype=np.float32), + "time_ps": np.asarray(data["time_ps"], dtype=np.float32), + } + if simtype == "complex": - np.savez_compressed( - npz_file, - ligand_A_RMSD=np.asarray(data["ligand_A_RMSD"], dtype=np.float32), - ligand_B_RMSD=np.asarray(data["ligand_B_RMSD"], dtype=np.float32), - ligand_A_COM_drift=np.asarray(data["ligand_A_COM_drift"], dtype=np.float32), - ligand_B_COM_drift=np.asarray(data["ligand_B_COM_drift"], dtype=np.float32), - protein_2D_RMSD=np.asarray(data["protein_2D_RMSD"], dtype=np.float32), - time_ps=np.asarray(data["time_ps"], dtype=np.float32), - ) - else: - np.savez_compressed( - npz_file, - ligand_A_RMSD=np.asarray(data["ligand_A_RMSD"], dtype=np.float32), - ligand_B_RMSD=np.asarray(data["ligand_B_RMSD"], dtype=np.float32), - time_ps=np.asarray(data["time_ps"], dtype=np.float32), - ) + npz_data.update({ + "ligand_A_COM_drift": np.asarray(data["ligand_A_COM_drift"], dtype=np.float32), + "ligand_B_COM_drift": np.asarray(data["ligand_B_COM_drift"], dtype=np.float32), + "protein_2D_RMSD": np.asarray(data["protein_2D_RMSD"], dtype=np.float32), + }) + + np.savez_compressed(npz_file, **npz_data) + return {"structural_analysis": npz_file} def run( @@ -1803,6 +1821,7 @@ def run( rdmol_A: Chem.Mol, rdmol_B: Chem.Mol, protein_selection: str = "protein and name CA", + skip: int | None = None, dry: bool = False, verbose: bool = True, scratch_basepath: pathlib.Path | None = None, @@ -1830,6 +1849,10 @@ def run( MDAnalysis selection string for the protein atoms used for alignment and RMSD calculations in the complex phase. Ignored for the solvent phase. Default "protein and name CA". + skip : int | None + Frame stride for structural analysis. If ``None``, a stride is + chosen such that approximately (max.) 500 frames are analyzed per state. + Set to 1 to analyze every frame. dry : bool Do a dry run of the calculation, creating all necessary hybrid system components (topology, system, sampler, etc...) but without @@ -1883,6 +1906,7 @@ def run( rdmol_A=rdmol_A, rdmol_B=rdmol_B, protein_selection=protein_selection, + skip=skip, ) return energy_analysis | structural_analysis @@ -1909,13 +1933,15 @@ def _execute( rdmol_B = alchem_comps["stateB"][0].to_rdkit() # Remap ligand indices from full system to subsampled system + # selection_indices maps: position in subsampled system to index in full system selection_indices = np.array(setup.outputs["selection_indices"]) - ligand_A_indices = np.where(np.isin(selection_indices, setup.outputs["ligand_A_indices"]))[ - 0 - ].tolist() - ligand_B_indices = np.where(np.isin(selection_indices, setup.outputs["ligand_B_indices"]))[ - 0 - ].tolist() + ligand_A_full_indices = np.array(setup.outputs["ligand_A_indices"]) + ligand_B_full_indices = np.array(setup.outputs["ligand_B_indices"]) + + # Find where the full-system ligand indices appear in the subsampled system + # np.isin returns a boolean mask, np.where converts to positional indices + ligand_A_indices = np.where(np.isin(selection_indices, ligand_A_full_indices))[0].tolist() + ligand_B_indices = np.where(np.isin(selection_indices, ligand_B_full_indices))[0].tolist() outputs = self.run( trajectory=trajectory, diff --git a/src/openfe/tests/protocols/openmm_septop/test_septop_protocol_analysis.py b/src/openfe/tests/protocols/openmm_septop/test_septop_protocol_analysis.py index 5d57acf4e..4add98873 100644 --- a/src/openfe/tests/protocols/openmm_septop/test_septop_protocol_analysis.py +++ b/src/openfe/tests/protocols/openmm_septop/test_septop_protocol_analysis.py @@ -65,6 +65,8 @@ def complex_structural_analysis_result(septop_complex_data, tmp_path_factory): ligand_B_indices=d["ligand_B_indices"], rdmol_A=d["rdmol_A"], rdmol_B=d["rdmol_B"], + protein_selection="protein and name CA", + skip=None, ) return result, tmp @@ -83,6 +85,8 @@ def solvent_structural_analysis_result(septop_solvent_data, tmp_path_factory): ligand_B_indices=d["ligand_B_indices"], rdmol_A=d["rdmol_A"], rdmol_B=d["rdmol_B"], + protein_selection="protein and name CA", + skip=None, ) return result, tmp @@ -154,6 +158,8 @@ def test_bad_trajectory_returns_error_dict(self, septop_complex_data, tmp_path): ligand_B_indices=d["ligand_B_indices"], rdmol_A=d["rdmol_A"], rdmol_B=d["rdmol_B"], + protein_selection="protein and name CA", + skip=None, ) assert "structural_analysis_error" in result @@ -208,7 +214,35 @@ def test_bad_trajectory_returns_error_dict(self, septop_solvent_data, tmp_path): ligand_B_indices=d["ligand_B_indices"], rdmol_A=d["rdmol_A"], rdmol_B=d["rdmol_B"], + protein_selection="protein and name CA", + skip=None, ) assert "structural_analysis_error" in result assert "structural_analysis" not in result + + def test_empty_ligand_indices_warning_and_error( + self, septop_solvent_data, tmp_path, caplog + ): + import logging + + d = septop_solvent_data + + with caplog.at_level(logging.WARNING): + result = BaseSepTopAnalysisUnit._structural_analysis( + pdb_file=d["pdb"], + trj_file=tmp_path / "nonexistent.nc", # won't be accessed + output_directory=tmp_path, + dry=True, + simtype="solvent", + ligand_A_indices=[], + ligand_B_indices=[], + rdmol_A=d["rdmol_A"], + rdmol_B=d["rdmol_B"], + protein_selection="protein and name CA", + skip=None, + ) + + assert "structural_analysis_error" in result + assert "structural_analysis" not in result + assert any("No ligand atoms found" in msg for msg in caplog.messages) From fedc1ea5c41a4760925d1c9762a0f57c62710df7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 4 Jun 2026 13:01:37 +0000 Subject: [PATCH 16/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/openfe/protocols/openmm_septop/base_units.py | 12 +++++++----- .../openmm_septop/test_septop_protocol_analysis.py | 4 +--- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/openfe/protocols/openmm_septop/base_units.py b/src/openfe/protocols/openmm_septop/base_units.py index 15af4e56d..807993038 100644 --- a/src/openfe/protocols/openmm_septop/base_units.py +++ b/src/openfe/protocols/openmm_septop/base_units.py @@ -1800,11 +1800,13 @@ def _structural_analysis( } if simtype == "complex": - npz_data.update({ - "ligand_A_COM_drift": np.asarray(data["ligand_A_COM_drift"], dtype=np.float32), - "ligand_B_COM_drift": np.asarray(data["ligand_B_COM_drift"], dtype=np.float32), - "protein_2D_RMSD": np.asarray(data["protein_2D_RMSD"], dtype=np.float32), - }) + npz_data.update( + { + "ligand_A_COM_drift": np.asarray(data["ligand_A_COM_drift"], dtype=np.float32), + "ligand_B_COM_drift": np.asarray(data["ligand_B_COM_drift"], dtype=np.float32), + "protein_2D_RMSD": np.asarray(data["protein_2D_RMSD"], dtype=np.float32), + } + ) np.savez_compressed(npz_file, **npz_data) diff --git a/src/openfe/tests/protocols/openmm_septop/test_septop_protocol_analysis.py b/src/openfe/tests/protocols/openmm_septop/test_septop_protocol_analysis.py index 4add98873..ca29381df 100644 --- a/src/openfe/tests/protocols/openmm_septop/test_septop_protocol_analysis.py +++ b/src/openfe/tests/protocols/openmm_septop/test_septop_protocol_analysis.py @@ -221,9 +221,7 @@ def test_bad_trajectory_returns_error_dict(self, septop_solvent_data, tmp_path): assert "structural_analysis_error" in result assert "structural_analysis" not in result - def test_empty_ligand_indices_warning_and_error( - self, septop_solvent_data, tmp_path, caplog - ): + def test_empty_ligand_indices_warning_and_error(self, septop_solvent_data, tmp_path, caplog): import logging d = septop_solvent_data From 78f01d5785dfb441dda0107abc762eb3206a6ead Mon Sep 17 00:00:00 2001 From: hannahbaumann Date: Thu, 4 Jun 2026 15:05:59 +0200 Subject: [PATCH 17/17] mypy fixes --- src/openfe/protocols/openmm_septop/base_units.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/openfe/protocols/openmm_septop/base_units.py b/src/openfe/protocols/openmm_septop/base_units.py index 807993038..eb463c86a 100644 --- a/src/openfe/protocols/openmm_septop/base_units.py +++ b/src/openfe/protocols/openmm_septop/base_units.py @@ -1512,7 +1512,7 @@ def _run_complex_analysis( cls, ds, pdb_file: pathlib.Path, - skip: int, + skip: int | None, ligand_A_indices: list[int], ligand_B_indices: list[int], rdmol_A: Chem.Mol, @@ -1528,7 +1528,7 @@ def _run_complex_analysis( Open NetCDF dataset for the multistate trajectory. pdb_file : pathlib.Path Path to the subsampled PDB file. - skip : int + skip : int | None Frame stride for analysis. ligand_A_indices : list[int] Atom indices of ligand A in the subsampled system. @@ -1602,7 +1602,7 @@ def _run_solvent_analysis( cls, ds, pdb_file: pathlib.Path, - skip: int, + skip: int | None, ligand_A_indices: list[int], ligand_B_indices: list[int], rdmol_A: Chem.Mol, @@ -1617,7 +1617,7 @@ def _run_solvent_analysis( Open NetCDF dataset for the multistate trajectory. pdb_file : pathlib.Path Path to the subsampled PDB file. - skip : int + skip : int | None Frame stride for analysis. ligand_A_indices : list[int] Atom indices of ligand A in the subsampled system. @@ -1679,7 +1679,7 @@ def _structural_analysis( rdmol_A: Chem.Mol, rdmol_B: Chem.Mol, protein_selection: str, - skip: int, + skip: int | None, ) -> dict[str, str | pathlib.Path]: """ Run structural analysis using ``openfe-analysis``. @@ -1710,7 +1710,7 @@ def _structural_analysis( MDAnalysis selection string for the protein atoms used for alignment and RMSD calculations in the complex phase. Ignored for the solvent phase. - skip : int + skip : int | None Frame stride for structural analysis. If ``None``, a stride is chosen such that approximately (max.) 500 frames are analyzed per state. Set to 1 to analyze every frame. @@ -1808,7 +1808,7 @@ def _structural_analysis( } ) - np.savez_compressed(npz_file, **npz_data) + np.savez_compressed(npz_file, **npz_data) # type: ignore[arg-type] return {"structural_analysis": npz_file}