Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions ms2pip/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,18 @@ def _infer_output_name(


@click.group()
@click.option("--logging-level", "-l", type=click.Choice(LOGGING_LEVELS.keys()), default="INFO")
@click.option(
"--logging-level",
"-l",
type=click.Choice(LOGGING_LEVELS.keys(), case_sensitive=False),
default="INFO",
)
@click.version_option(version=__version__)
def cli(*args, **kwargs):
logging.basicConfig(
format="%(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
level=LOGGING_LEVELS[kwargs["logging_level"]],
level=LOGGING_LEVELS[kwargs["logging_level"].upper()],
handlers=[RichHandler(rich_tracebacks=True, show_level=True, show_path=False)],
)
rich.print(build_credits())
Expand Down
120 changes: 26 additions & 94 deletions ms2pip/_spectrum_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@
from typing import NamedTuple

import numpy as np
from psm_utils import PSM, PSMList, Peptidoform
from ms2rescore_rs import (
AnnotatedMS2Spectrum, # type: ignore[ty:unresolved-import]
MS2Spectrum, # type: ignore[ty:unresolved-import]
Precursor, # type: ignore[ty:unresolved-import]
annotate_ms2_spectra, # type: ignore[ty:unresolved-import]
get_ms2_spectra, # type: ignore[ty:unresolved-import]
)
from psm_utils import PSM, Peptidoform, PSMList

import ms2pip.exceptions as exceptions
from ms2pip.constants import MODELS
Expand All @@ -28,12 +28,12 @@


class MatchedSpectrum(NamedTuple):
"""A PSM matched to its preprocessed observed spectrum and peak annotations."""
"""A PSM matched to its preprocessed observed spectrum and annotated spectrum."""

psm_index: int
psm: PSM
spectrum: ObservedSpectrum
peak_annotations: list
annotated_spectrum: AnnotatedMS2Spectrum


def _read_raw_spectra(spectrum_file: str) -> Generator[MS2Spectrum, None, None]:
Expand Down Expand Up @@ -108,11 +108,11 @@ def annotate_spectrum(
model: str,
ms2_tolerance: float,
ms2_tolerance_mode: str,
) -> list[list[tuple]]:
) -> AnnotatedMS2Spectrum:
"""
Annotate an ObservedSpectrum using ms2rescore-rs.

Returns peak annotations as plain Python lists of ``(series, position, charge)`` tuples.
Returns the AnnotatedMS2Spectrum object from ms2rescore-rs.
"""
ms2_spectrum = MS2Spectrum(
identifier=spectrum.identifier or "",
Expand All @@ -126,75 +126,16 @@ def annotate_spectrum(
)
frag_model = MODELS[model]["fragmentation"]
proforma = proforma_to_mass_shift(psm.peptidoform)
seq_len = len(psm.peptidoform.parsed_sequence)

annotated = annotate_ms2_spectra(
spectra=[ms2_spectrum],
proformas=[proforma],
seq_lens=[seq_len],
fragmentation_model=frag_model,
mass_mode="monoisotopic",
tolerance_value=float(ms2_tolerance),
tolerance_mode=ms2_tolerance_mode.lower(),
)
return [
[(a.series, a.position, a.charge) for a in peak_anns]
for peak_anns in annotated[0].peak_annotations
]


def targets_from_annotations(
peak_annotations: list,
intensity: np.ndarray,
ion_types: list[str],
seq_len: int,
) -> dict[str, np.ndarray]:
"""
Extract observed intensity targets from peak annotations.

Converts per-peak fragment annotations into per-ion-type intensity arrays.

Parameters
----------
peak_annotations
Per-peak annotations. Each element is a list of annotations for that peak.
Annotations can be :py:class:`ms2rescore_rs.FragmentAnnotation` objects or
``(series, position, charge)`` tuples.
intensity
Preprocessed intensity array (TIC-normalized, log2-transformed).
ion_types
Ion types to extract, e.g. ``["b", "y"]`` or ``["b", "y", "b2", "y2"]``.
seq_len
Length of the peptide sequence (number of amino acids).

Returns
-------
targets
Dict mapping ion type to intensity array of length ``seq_len - 1``.

"""
n_ions = seq_len - 1
floor_value = np.float32(np.log2(0.001))
targets = {ion: np.full(n_ions, floor_value, dtype=np.float32) for ion in ion_types}

for peak_idx, peak_anns in enumerate(peak_annotations):
for ann in peak_anns:
if isinstance(ann, tuple):
series, position, charge = ann
else:
series, position, charge = ann.series, ann.position, ann.charge

ion_key = series if charge == 1 else f"{series}{charge}"

if ion_key not in targets:
continue

pos = position - 1
if 0 <= pos < n_ions:
if intensity[peak_idx] > targets[ion_key][pos]:
targets[ion_key][pos] = intensity[peak_idx]

return targets
return annotated[0]


def _load_and_match_spectra(
Expand Down Expand Up @@ -224,6 +165,7 @@ def _load_and_match_spectra(
psms_by_specid[str(psm.spectrum_id)].append((i, psm))

# Step 1: Read raw spectra and match to PSMs (no conversion yet)
logger.info("Reading spectra from file...")
matched_raw: list[tuple[str, MS2Spectrum, list[tuple[int, PSM]]]] = []
for spectrum in _read_raw_spectra(str(spectrum_file)):
match = spectrum_id_regex.search(str(spectrum.identifier))
Expand All @@ -245,28 +187,28 @@ def _load_and_match_spectra(
return []

# Step 2: Batch annotate all matched spectra (single Rust call, Rayon-parallelized)
logger.debug("Annotating %d matched spectra...", len(matched_raw))
batch_spectra = []
batch_proformas = []
batch_seq_lens = []
batch_indices = [] # (matched_raw_idx, psm_within_spectrum_idx)

for raw_idx, (_, spectrum, psm_pairs) in enumerate(matched_raw):
for psm_idx, (_, psm) in enumerate(psm_pairs):
batch_spectra.append(spectrum)
batch_proformas.append(proforma_to_mass_shift(psm.peptidoform))
batch_seq_lens.append(len(psm.peptidoform.parsed_sequence))
batch_indices.append((raw_idx, psm_idx))

frag_model = MODELS[model]["fragmentation"]
logger.debug("Starting annotation...")
annotated_spectra = annotate_ms2_spectra(
spectra=batch_spectra,
proformas=batch_proformas,
seq_lens=batch_seq_lens,
fragmentation_model=frag_model,
mass_mode="monoisotopic",
tolerance_value=float(ms2_tolerance),
tolerance_mode=ms2_tolerance_mode.lower(),
)
logger.debug("Annotation complete.")

# Step 3: Convert to ObservedSpectrum, preprocess, and assemble results
preprocessed_cache: dict[str, ObservedSpectrum] = {}
Expand All @@ -281,12 +223,9 @@ def _load_and_match_spectra(
_preprocess_spectrum(obs, model)
preprocessed_cache[spec_id] = obs

peak_annotations = [
[(a.series, a.position, a.charge) for a in peak_anns]
for peak_anns in annotated_spectra[batch_idx].peak_annotations
]

results.append(MatchedSpectrum(psm_index, psm, preprocessed_cache[spec_id], peak_annotations))
results.append(
MatchedSpectrum(psm_index, psm, preprocessed_cache[spec_id], annotated_spectra[batch_idx])
)

return results

Expand All @@ -309,7 +248,9 @@ def _preloaded_to_annotations(
# Convert to ObservedSpectrum and preprocess; store raw spectra and annotations
preloaded_spectra: dict[str, ObservedSpectrum] = {}
raw_spectra: dict[str, MS2Spectrum] = {}
preloaded_annotations: dict[str, list] | None = {} if spectra_are_annotated else None
preloaded_annotated: dict[str, AnnotatedMS2Spectrum] | None = (
{} if spectra_are_annotated else None
)
for psm in psm_list:
spec_id = str(psm.spectrum_id)
if spec_id in preloaded_spectra:
Expand All @@ -322,11 +263,8 @@ def _preloaded_to_annotations(
raw_spectra[spec_id] = spectrum # keep original for annotation
if spectra_are_annotated:
assert isinstance(spectrum, AnnotatedMS2Spectrum)
assert preloaded_annotations is not None
preloaded_annotations[spec_id] = [
[(a.series, a.position, a.charge) for a in peak_anns]
for peak_anns in spectrum.peak_annotations
]
assert preloaded_annotated is not None
preloaded_annotated[spec_id] = spectrum

# Build MatchedSpectrum list
psm_spectrum_annotations: list[MatchedSpectrum] = []
Expand All @@ -337,30 +275,30 @@ def _preloaded_to_annotations(
obs_spectrum = preloaded_spectra.get(spec_id)
if obs_spectrum is None:
continue
if preloaded_annotations is not None and spec_id in preloaded_annotations:
if preloaded_annotated is not None and spec_id in preloaded_annotated:
psm_spectrum_annotations.append(
MatchedSpectrum(i, psm, obs_spectrum, preloaded_annotations[spec_id])
MatchedSpectrum(i, psm, obs_spectrum, preloaded_annotated[spec_id])
)
else:
psm_spectrum_annotations.append(MatchedSpectrum(i, psm, obs_spectrum, []))
# Placeholder -- will be replaced after batch annotation below
psm_spectrum_annotations.append(
MatchedSpectrum(i, psm, obs_spectrum, None) # type: ignore[arg-type]
)
needs_annotation.append(len(psm_spectrum_annotations) - 1)

# Batch annotate any unannotated spectra using original MS2Spectrum objects
if needs_annotation:
frag_model = MODELS[model]["fragmentation"]
batch_spectra = []
batch_proformas = []
batch_seq_lens = []
for idx in needs_annotation:
m = psm_spectrum_annotations[idx]
batch_spectra.append(raw_spectra[str(m.psm.spectrum_id)])
batch_proformas.append(proforma_to_mass_shift(m.psm.peptidoform))
batch_seq_lens.append(len(m.psm.peptidoform.parsed_sequence))

annotated = annotate_ms2_spectra(
spectra=batch_spectra,
proformas=batch_proformas,
seq_lens=batch_seq_lens,
fragmentation_model=frag_model,
mass_mode="monoisotopic",
tolerance_value=float(ms2_tolerance),
Expand All @@ -369,11 +307,7 @@ def _preloaded_to_annotations(

for j, idx in enumerate(needs_annotation):
m = psm_spectrum_annotations[idx]
peak_anns = [
[(a.series, a.position, a.charge) for a in anns]
for anns in annotated[j].peak_annotations
]
psm_spectrum_annotations[idx] = m._replace(peak_annotations=peak_anns)
psm_spectrum_annotations[idx] = m._replace(annotated_spectrum=annotated[j])

return psm_spectrum_annotations

Expand All @@ -397,9 +331,7 @@ def resolve_spectra(
]
if all(has_spectrum):
if spectrum_file is not None:
logger.warning(
"PSMs already have preloaded spectra; `spectrum_file` will be ignored."
)
logger.warning("PSMs already have preloaded spectra; `spectrum_file` will be ignored.")
matched = _preloaded_to_annotations(psm_list, model, ms2_tolerance, ms2_tolerance_mode)
elif not any(has_spectrum):
if spectrum_file is None:
Expand Down
2 changes: 1 addition & 1 deletion ms2pip/_utils/xgb_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

logger = logging.getLogger(__name__)

_MAX_PREDICTION_THREADS = 16
_MAX_PREDICTION_THREADS = 32


def validate_model(model: str, model_dir: str | Path | None = None) -> Path:
Expand Down
Loading
Loading