diff --git a/environment.yml b/environment.yml index e48a81f..3c48d3a 100644 --- a/environment.yml +++ b/environment.yml @@ -8,6 +8,7 @@ dependencies: - openff-units - pip - tqdm + - prolif - pyyaml # for testing - coverage diff --git a/src/openfe_analysis/prolif.py b/src/openfe_analysis/prolif.py new file mode 100644 index 0000000..283eff9 --- /dev/null +++ b/src/openfe_analysis/prolif.py @@ -0,0 +1,329 @@ +from __future__ import annotations + +import numpy as np +from typing import Any, Dict, Optional, Sequence, Tuple +import warnings + +import MDAnalysis as mda +from MDAnalysis.analysis.base import AnalysisBase +from MDAnalysis.guesser.tables import vdwradii as MDA_VDWRADII + +import prolif as plf + + +class ProLIFAnalysis(AnalysisBase): + """ + ProLIF interaction fingerprint analysis for an OpenFEReader Universe. + """ + + def __init__( + self, + universe: mda.Universe, + ligand_ag: mda.AtomGroup, + water_order: int = 3, + protein_cutoff: int = 12, + water_cutoff: int = 8, + interactions: Optional[Sequence[str] | str] = None, + guess_bonds: bool = True, + vdwradii: Optional[Dict[str, float]] = None, + **kwargs, + ) -> None: + """ + Initialize the ProLIF analysis. + + Parameters + ---------- + universe + MDAnalysis Universe containing topology and trajectory. + ligand_ag + mda.AtomGroup representing the ligand. + water_order + Maximum WaterBridge interaction order (water-water interaction). + Only used if "WaterBridge" is tracked. + protein_cutoff + Distance cutoff in angstrom used to define the protein pocket + around the ligand. + water_cutoff + Distance cutoff in angstrom used to define waters considered + around the ligand/protein pocket. + interactions + Which interactions to track: + - None: ProLIF defaults + - "all": all registered (non-bridged; depends on ProLIF version) + - Sequence[str]: explicit list like ["VdWContact", "HBDonor"] + guess_bonds + If True, guess bonds for (protein, ligand, water) so ProLIF can + recognize donors/acceptors and bonded hydrogens. + vdwradii + Optional dict of van der Waals radii used by MDAnalysis bond guesser. + Useful when your topology contains types the guesser doesn't know + (e.g. "Cl", "Na"). If None, uses coded defaults. + """ + self.universe = universe + self.ligand_ag = ligand_ag + self.water_order = water_order + + super().__init__(universe.trajectory, **kwargs) + + # --- Guess bonds once on stable selections so RDKit/ProLIF can detect HBonds --- + if guess_bonds: + if vdwradii is None: + vdwradii = dict(MDA_VDWRADII) + vdwradii.update( + { + "Cl": vdwradii["CL"], + "Br": vdwradii["BR"], + "Na": vdwradii["NA"], + } + ) + + # Protein: guess on the full protein so any pocket residue later has bonds + universe.select_atoms("protein").guess_bonds(vdwradii=vdwradii) + + # Ligand: stable group + self.ligand_ag.guess_bonds(vdwradii=vdwradii) + + # Water: only if you care about water-mediated interactions + if guess_bonds: + wat_all = universe.select_atoms("water") + if wat_all.n_atoms: + wat_all.guess_bonds(vdwradii=vdwradii) + + self.protein_ag = self.universe.select_atoms( + f"protein and byres around {protein_cutoff} group ligand", + ligand=self.ligand_ag, + updating=True, + ) + self.water_ag = self.universe.select_atoms( + f"water and byres around {water_cutoff} (group ligand or group pocket)", + ligand=self.ligand_ag, + pocket=self.protein_ag, + updating=True, + ) + + available = plf.Fingerprint.list_available(show_bridged=True) + + if interactions is None: + fp_interactions = None + + elif interactions == "all": + fp_interactions = "all" + + else: + # Cover case of false interaction + missing = [i for i in interactions if i not in available] + if missing: + raise ValueError( + f"Unknown interaction(s): {missing}. " f"Available: {available}" + ) + fp_interactions = list(interactions) + + self._parameters = None + if ( + fp_interactions is not None + and fp_interactions != "all" + and "WaterBridge" in fp_interactions + ): + if self.water_ag.n_atoms == 0: + warnings.warn( + "WaterBridge selected but water selection is empty at the initial " + "frame; removing WaterBridge from the requested interactions.", + UserWarning, + stacklevel=2, + ) + fp_interactions = [ + interaction + for interaction in fp_interactions + if interaction != "WaterBridge" + ] + else: + self._parameters = { + "WaterBridge": {"water": self.water_ag, "order": self.water_order} + } + + if fp_interactions is None: + self.fp = plf.Fingerprint(parameters=self._parameters) + elif len(fp_interactions) == 0: + self.fp = plf.Fingerprint(parameters=self._parameters) + else: + self.fp = plf.Fingerprint( + interactions=fp_interactions, + parameters=self._parameters, + ) + + def _prepare(self): + self.results.ifp = None + self.results.ifp_df = None + + def _conclude(self): + self.results.ifp = getattr(self.fp, "ifp", None) + + def run( + self, + *, + start: Optional[int] = None, + stop: Optional[int] = None, + step: Optional[int] = None, + residues: Optional[bool] = None, + progress: bool = True, + n_jobs: Optional[int] = None, + parallel_strategy: Optional[str] = None, + converter_kwargs: Optional[Tuple[Dict[str, Any], Dict[str, Any]]] = None, + ) -> "ProLIFAnalysis": + """ + Run the fingerprint calculation over a slice of the trajectory. + + Parameters + ---------- + start, stop, step + Trajectory slicing parameters. + residues + Passed to ProLIF: whether to aggregate interactions with residues. + If None, ProLIF's default is used and interactions with atoms are identified. + progress + Show progress bar. + n_jobs + Number of workers for parallel execution.). + parallel_strategy + ProLIF parallel strategy. If None, this wrapper sets: + - "chunk" for n_jobs None/1 + - "queue" for n_jobs > 1 + converter_kwargs + Two dicts: (ligand_kwargs, protein_kwargs) forwarded to the MDAnalysis→RDKit + converter. If None, we default to: + - ligand: {"inferrer": None, "implicit_hydrogens": False} (avoid valence issues) + - protein: {"implicit_hydrogens": False} (use topology bonds) + + Returns + ------- + self + Returned for fluent chaining. + """ + # Due to FEReader trajectory only certain strategies work with the format + if parallel_strategy is None: + # avoid ProLIF trying to pickle FEReader/netCDF trajectory to auto-pick strategy + parallel_strategy = "chunk" if (n_jobs is None or n_jobs == 1) else "queue" + + _slice = slice(start, stop, step) + traj = self.universe.trajectory[_slice] + + try: + n_total = len(self.universe.trajectory) + s0, s1, s2 = _slice.indices(n_total) + self.frames = np.arange(s0, s1, s2, dtype=int) + self.n_frames = len(traj) + + if ( + hasattr(self.universe.trajectory, "times") + and self.universe.trajectory.times is not None + ): + self.times = np.asarray(self.universe.trajectory.times)[self.frames] + elif getattr(self.universe.trajectory, "dt", None) is not None: + self.times = self.frames * self.universe.trajectory.dt + else: + self.times = None + except Exception: + self.frames = None + self.times = None + self.n_frames = None + + if converter_kwargs is None: + # Avoid Valence errors + converter_kwargs = ( + {"inferrer": None, "implicit_hydrogens": False}, # ligand + {"implicit_hydrogens": False}, # protein + ) + + self.fp.run( + traj, + self.ligand_ag, + self.protein_ag, + residues=residues, + converter_kwargs=converter_kwargs, + progress=progress, + n_jobs=n_jobs, + parallel_strategy=parallel_strategy, + ) + + self._conclude() + return self + + # For now, depending on what we do withe the data + def to_dataframe(self, **kwargs): + """ + Transform fingerprint results to pd.DataFrame. + """ + df = self.fp.to_dataframe(**kwargs) + self.results.ifp_df = df + return df + + def plot_lignetwork( + self, + ligand_mol=None, + *, + frame: Optional[int] = None, + kind: Literal["aggregate", "frame"] = "frame", + display_all: bool = False, + threshold: float = 0.3, + use_coordinates: bool = True, + flatten_coordinates: bool = True, + kekulize: bool = False, + molsize: int = 35, + rotation: float = 0, + carbon: float = 0.16, + width: str = "100%", + height: str = "500px", + fontsize: int = 20, + show_interaction_data: bool = False, + ): + """ + 2D ProLIF ligand-network visualization. + """ + if not hasattr(self.fp, "ifp") or not self.fp.ifp: + raise RuntimeError( + "No ProLIF fingerprint data found. Run `analysis.run(...)` first." + ) + + available_frames = list(self.fp.ifp.keys()) + + if frame is None: + frame = available_frames[0] + + if kind == "frame" and frame not in self.fp.ifp: + preview = available_frames[:10] + suffix = " ..." if len(available_frames) > 10 else "" + raise ValueError( + f"frame={frame} not present in fingerprint results. " + f"Available frames: {preview}{suffix}" + ) + + if frame is not None: + self.universe.trajectory[frame] + + if ligand_mol is None: + ligand_mol = plf.Molecule.from_mda( + self.ligand_ag, + inferrer=None, + implicit_hydrogens=False, + use_segid=self.fp.use_segid, + ) + + return self.fp.plot_lignetwork( + ligand_mol, + kind=kind, + frame=frame, + display_all=display_all, + threshold=threshold, + use_coordinates=use_coordinates, + flatten_coordinates=flatten_coordinates, + kekulize=kekulize, + molsize=molsize, + rotation=rotation, + carbon=carbon, + width=width, + height=height, + fontsize=fontsize, + show_interaction_data=show_interaction_data, + ) + + plot_2d = plot_lignetwork diff --git a/src/openfe_analysis/tests/test_prolif.py b/src/openfe_analysis/tests/test_prolif.py new file mode 100644 index 0000000..d38c3bc --- /dev/null +++ b/src/openfe_analysis/tests/test_prolif.py @@ -0,0 +1,173 @@ +import MDAnalysis as mda +import numpy as np +import pytest +from rdkit.Chem import Lipinski + +from openfe_analysis.reader import FEReader +from openfe_analysis.prolif import ProLIFAnalysis + + +def test_prolifanalysis_runs_vdwcontact( + simulation_skipped_nc, hybrid_system_skipped_pdb +): + """ + Test for identification of interactions + """ + u = mda.Universe( + hybrid_system_skipped_pdb, simulation_skipped_nc, format=FEReader, index=0 + ) + ligand_ag = u.select_atoms("resname UNK") + + analysis = ProLIFAnalysis( + u, ligand_ag, interactions=["VdWContact"], guess_bonds=True + ) + analysis.run(stop=5, step=1, n_jobs=1, progress=False) + + df = analysis.to_dataframe(dtype=np.uint8) + assert df.shape[0] == 5 + assert hasattr(analysis.fp, "ifp") + assert len(analysis.fp.ifp) == 5 + + # Check AnalysisBase + assert hasattr(analysis, "results") + assert hasattr(analysis.results, "ifp") + assert analysis.results.ifp is analysis.fp.ifp + assert len(analysis.results.ifp) == 5 + + assert analysis.results.ifp_df is df + + # Ensure there is at least one detected interaction across all processed frames + assert sum(len(v) for v in analysis.fp.ifp.values()) > 0 + + +def test_guess_bonds_enables_protein_chemistry( + simulation_skipped_nc, hybrid_system_skipped_pdb +): + """ + Test for protein connectivity + """ + u = mda.Universe( + hybrid_system_skipped_pdb, simulation_skipped_nc, format=FEReader, index=0 + ) + ligand_ag = u.select_atoms("resname UNK") + + analysis = ProLIFAnalysis( + u, ligand_ag, interactions=["VdWContact"], guess_bonds=True + ) + + # pick a residue from the pocket and check it has connectivity in RDKit + u.trajectory[0] + res_atoms = analysis.protein_ag.residues[0].atoms + res_mol = res_atoms.convert_to("RDKIT", implicit_hydrogens=False) + assert res_mol.GetNumBonds() > 0 + + # ensure the protein donors/acceptors exist + prot_mol = analysis.protein_ag.convert_to("RDKIT", implicit_hydrogens=False) + assert Lipinski.NumHDonors(prot_mol) + Lipinski.NumHAcceptors(prot_mol) > 0 + + +def test_prolifanalysis_accepts_all_keyword( + simulation_skipped_nc, hybrid_system_skipped_pdb +): + """ + The string "all" should be accepted as the special keyword for + all available ProLIF interactions. + """ + u = mda.Universe( + hybrid_system_skipped_pdb, simulation_skipped_nc, format=FEReader, index=0 + ) + ligand_ag = u.select_atoms("resname UNK") + + analysis = ProLIFAnalysis(u, ligand_ag, interactions="all", guess_bonds=True) + + assert analysis.fp is not None + + +def test_waterbridge_empty_selection_warns_and_skips_parameters( + simulation_skipped_nc, hybrid_system_skipped_pdb, monkeypatch +): + """ + Requesting WaterBridge with an empty water selection should warn + instead of raising, and should not configure WaterBridge parameters. + """ + u = mda.Universe( + hybrid_system_skipped_pdb, simulation_skipped_nc, format=FEReader, index=0 + ) + ligand_ag = u.select_atoms("resname UNK") + + original_select_atoms = u.select_atoms + + def patched_select_atoms(selection, *args, **kwargs): + if selection == "water and byres around 8 (group ligand or group pocket)": + return u.atoms[[]] + return original_select_atoms(selection, *args, **kwargs) + + monkeypatch.setattr(u, "select_atoms", patched_select_atoms) + + with pytest.warns(UserWarning, match="WaterBridge selected"): + analysis = ProLIFAnalysis( + u, + ligand_ag, + interactions=["WaterBridge"], + guess_bonds=True, + ) + + assert analysis._parameters is None + + +def test_plot_2d_builds_ligand_mol_and_delegates(monkeypatch): + """ + plot_2d should build a ligand molecule internally when one is not + provided and delegate to ProLIF's plot_lignetwork. + """ + + class DummyTrajectory: + def __init__(self): + self.last_frame = None + + def __getitem__(self, frame): + self.last_frame = frame + return None + + class DummyFP: + def __init__(self): + self.ifp = {0: {"dummy": []}} + self.use_segid = False + + def plot_lignetwork(self, ligand_mol, **kwargs): + calls["plot_lignetwork"] = (ligand_mol, kwargs) + return "fake-view" + + ligand_ag = object() + calls = {} + analysis = object.__new__(ProLIFAnalysis) + analysis.ligand_ag = ligand_ag + analysis.universe = type( + "DummyUniverse", + (), + {"trajectory": DummyTrajectory()}, + )() + analysis.fp = DummyFP() + + fake_ligand_mol = object() + + def fake_from_mda(atomgroup, **kwargs): + calls["from_mda"] = (atomgroup, kwargs) + return fake_ligand_mol + + monkeypatch.setattr( + "openfe_analysis.prolif.plf.Molecule.from_mda", + fake_from_mda, + ) + + view = analysis.plot_2d(frame=0, kind="frame") + + assert view == "fake-view" + assert calls["from_mda"][0] is ligand_ag + assert calls["from_mda"][1]["inferrer"] is None + assert calls["from_mda"][1]["implicit_hydrogens"] is False + assert calls["from_mda"][1]["use_segid"] == analysis.fp.use_segid + assert calls["plot_lignetwork"][0] is fake_ligand_mol + assert calls["plot_lignetwork"][1]["frame"] == 0 + assert calls["plot_lignetwork"][1]["kind"] == "frame" + assert analysis.universe.trajectory.last_frame == 0