Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 157 additions & 0 deletions aai_cli/core/der.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
"""Diarization error rate (DER) scoring for `assembly eval`, dependency-free.

The companion to :mod:`aai_cli.core.wer`: where WER scores *what* was said, DER
scores *who* spoke when. It is the standard NIST/pyannote metric — the fraction
of reference speech time that is misattributed — computed here in plain Python so
adding diarization metrics costs no new dependency (``pyannote.metrics`` drags in
numpy/scipy/pandas; the lighter PyPI options still pull numpy or compile a C++
extension). No SDK, no Rich: the command layer owns all rendering.

Both reference and hypothesis are sequences of speaker-labelled :class:`Segment`s.
Speaker labels are arbitrary, so the speakers are optimally mapped one-to-one
before errors are counted (an exact search — diarization speaker counts are
small). The timeline is partitioned at every segment boundary into atomic
intervals; within each, the per-interval NIST tally (missed / false-alarm /
confusion) is summed, all weighted by reference *speaker*-time so overlapping
speech is counted once per concurrent speaker.
"""

from __future__ import annotations

import itertools
from collections.abc import Sequence
from dataclasses import dataclass


@dataclass(frozen=True)
class Segment:
"""A stretch of speech (``start`` to ``end``, in seconds) attributed to one speaker."""

speaker: str
start: float
end: float


@dataclass(frozen=True)
class Score:
"""Diarization error against ``speech`` seconds of reference speech.

The three NIST components are kept separately (so a caller can show the
breakdown) and ``errors`` is their sum; pooled across files for corpus DER.
"""

missed: float
false_alarm: float
confusion: float
speech: float

@property
def errors(self) -> float:
"""Total misattributed time: missed + false-alarm + speaker-confusion seconds."""
return self.missed + self.false_alarm + self.confusion

@property
def der(self) -> float:
"""Diarization error rate: error time over reference speech time.

The caller guarantees a reference with speech (an empty reference makes
DER undefined, the same contract :class:`wer.Score` keeps for ``words``).
"""
return self.errors / self.speech


def _boundaries(reference: Sequence[Segment], hypothesis: Sequence[Segment]) -> list[float]:
"""The sorted, de-duplicated segment endpoints that partition the timeline.

Between two consecutive boundaries every segment is wholly present or wholly
absent, so each atomic interval has a fixed set of active speakers.
"""
return sorted({point for seg in (*reference, *hypothesis) for point in (seg.start, seg.end)})


def _active(segments: Sequence[Segment], start: float, end: float) -> set[str]:
"""The distinct speakers whose segment covers the atomic interval ``[start, end)``."""
return {seg.speaker for seg in segments if seg.start <= start and seg.end >= end}


def _speakers(segments: Sequence[Segment]) -> list[str]:
"""The distinct speaker labels in ``segments``, in a deterministic order."""
return sorted({seg.speaker for seg in segments})


def _weight(weights: list[list[float]], row: int, col: int) -> float:
"""The matrix weight at ``(row, col)``, or 0 outside it — the zero-padding
that lets the larger side's unmatched speakers cost nothing."""
if row < len(weights) and col < len(weights[0]):
return weights[row][col]
return 0.0


def _max_weight_assignment(weights: list[list[float]]) -> float:
"""The greatest total weight of a one-to-one row-to-column assignment.

Exact search: the matrix is zero-padded to ``size x size`` and every
permutation (speaker mapping) is tried. Diarization files have few speakers,
so the factorial search stays cheap, and padding sidesteps the orientation
branch a rectangular search would need.
"""
size = max(len(weights), len(weights[0]))
return max(
sum(_weight(weights, row, col) for row, col in enumerate(perm))
for perm in itertools.permutations(range(size))
)


def _correct_time(
reference: Sequence[Segment],
hypothesis: Sequence[Segment],
cooccurrence: dict[tuple[str, str], float],
) -> float:
"""Correctly attributed speech time under the optimal speaker mapping.

``cooccurrence[(ref, hyp)]`` is how long that reference and hypothesis
speaker were concurrently active; the best one-to-one mapping maximises the
matched total (an empty hypothesis maps to nothing, so it scores 0).
"""
ref_speakers, hyp_speakers = _speakers(reference), _speakers(hypothesis)
if not ref_speakers:
return 0.0
weights = [[cooccurrence.get((ref, hyp), 0.0) for hyp in hyp_speakers] for ref in ref_speakers]
return _max_weight_assignment(weights)


def score(reference: Sequence[Segment], hypothesis: Sequence[Segment]) -> Score:
"""Score a hypothesis diarization against a reference.

Walks the shared timeline once, tallying missed speech (reference speakers
with no hypothesis speaker to cover them), false alarms (the reverse), and
co-occurrence per speaker pair; speaker confusion is then the matched
overlap that the optimal mapping could *not* account for.
"""
cooccurrence: dict[tuple[str, str], float] = {}
missed = false_alarm = matched = speech = 0.0
boundaries = _boundaries(reference, hypothesis)
for start, end in itertools.pairwise(boundaries):
duration = end - start
ref_active = _active(reference, start, end)
hyp_active = _active(hypothesis, start, end)
speech += duration * len(ref_active)
missed += duration * max(0, len(ref_active) - len(hyp_active))
false_alarm += duration * max(0, len(hyp_active) - len(ref_active))
matched += duration * min(len(ref_active), len(hyp_active))
for ref in ref_active:
for hyp in hyp_active:
cooccurrence[(ref, hyp)] = cooccurrence.get((ref, hyp), 0.0) + duration
confusion = matched - _correct_time(reference, hypothesis, cooccurrence)
return Score(missed=missed, false_alarm=false_alarm, confusion=confusion, speech=speech)


def pooled(scores: list[Score]) -> Score:
"""Corpus-level score: error and speech time summed across files (DER is then
total error time over total reference time, not a mean of per-file rates)."""
return Score(
missed=sum(item.missed for item in scores),
false_alarm=sum(item.false_alarm for item in scores),
confusion=sum(item.confusion for item in scores),
speech=sum(item.speech for item in scores),
)
126 changes: 126 additions & 0 deletions tests/test_der.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
"""DER scoring (`aai_cli.core.der`): timeline tally, optimal speaker mapping, pooling."""

import dataclasses

import pytest

from aai_cli.core import der


def _assign(obj, attribute, value):
setattr(obj, attribute, value)


def seg(speaker: str, start: float, end: float) -> der.Segment:
return der.Segment(speaker, start, end)


def test_score_is_an_immutable_value():
with pytest.raises(dataclasses.FrozenInstanceError):
_assign(der.Score(missed=0.0, false_alarm=0.0, confusion=0.0, speech=1.0), "missed", 1.0)


def test_segment_is_an_immutable_value():
with pytest.raises(dataclasses.FrozenInstanceError):
_assign(seg("A", 0, 10), "speaker", "B")


def test_errors_and_der_combine_the_components():
score = der.Score(missed=1.0, false_alarm=2.0, confusion=3.0, speech=10.0)
assert score.errors == 6.0
assert score.der == 0.6


def test_identical_diarization_scores_zero():
ref = [seg("A", 0, 10), seg("B", 10, 20)]
score = der.score(ref, ref)
assert score == der.Score(missed=0.0, false_alarm=0.0, confusion=0.0, speech=20.0)
assert score.der == 0.0


def test_speaker_labels_are_mapped_optimally():
# Same timing, relabelled and reversed: the optimal 1:1 mapping recovers it,
# so a correct scorer reports no error despite none of the labels matching.
ref = [seg("A", 0, 10), seg("B", 10, 20)]
hyp = [seg("spk_1", 10, 20), seg("spk_0", 0, 10)]
assert der.score(ref, hyp).der == 0.0


def test_missing_hypothesis_is_all_missed_speech():
score = der.score([seg("A", 0, 10)], [])
assert score == der.Score(missed=10.0, false_alarm=0.0, confusion=0.0, speech=10.0)
assert score.der == 1.0


def test_hypothesis_speech_outside_the_reference_is_false_alarm():
score = der.score([seg("A", 0, 10)], [seg("X", 0, 15)])
assert score == der.Score(missed=0.0, false_alarm=5.0, confusion=0.0, speech=10.0)
assert score.der == 0.5


def test_no_reference_speech_is_pure_false_alarm():
# An empty reference: every hypothesis second is a false alarm and there is
# no speech to map against (DER itself is undefined, but the tally holds).
score = der.score([], [seg("X", 0, 10)])
assert score == der.Score(missed=0.0, false_alarm=10.0, confusion=0.0, speech=0.0)


def test_one_hypothesis_speaker_split_across_two_references_is_confusion():
# A single hypothesis speaker covers both reference speakers; the mapping can
# only credit the one it overlaps most (B, 10s), the rest (A, 3s) is confusion.
ref = [seg("A", 0, 3), seg("B", 3, 13)]
hyp = [seg("X", 0, 13)]
score = der.score(ref, hyp)
assert score == der.Score(missed=0.0, false_alarm=0.0, confusion=3.0, speech=13.0)
assert score.der == pytest.approx(3 / 13)


def test_one_reference_speaker_split_across_two_hypotheses_is_confusion():
# The mirror case (more hypothesis than reference speakers): only one of the
# two hypothesis speakers can be mapped to A, the other 5s is confusion.
ref = [seg("A", 0, 10)]
hyp = [seg("X", 0, 5), seg("Y", 5, 10)]
score = der.score(ref, hyp)
assert score == der.Score(missed=0.0, false_alarm=0.0, confusion=5.0, speech=10.0)
assert score.der == 0.5


def test_disjoint_extra_speakers_are_missed_and_false_alarm():
# A<->X overlap perfectly; reference B (10s) has no hypothesis (missed) and
# hypothesis Y (10s) has no reference (false alarm). The leftover B/Y pair
# never co-occurs, so the mapping cannot credit it as correct.
ref = [seg("A", 0, 10), seg("B", 10, 20)]
hyp = [seg("X", 0, 10), seg("Y", 20, 30)]
score = der.score(ref, hyp)
assert score == der.Score(missed=10.0, false_alarm=10.0, confusion=0.0, speech=20.0)
assert score.der == 1.0


def test_optimal_mapping_beats_a_greedy_one():
# Co-occurrence is (A,X)=10, (A,Y)=9, (B,X)=8. Greedy takes A↔X (10) and is
# then stuck with B↔Y (0) for 10 correct; the optimal A↔Y + B↔X gives 17
# correct, so confusion is 27-17=10, not 27-10=17. Pins the max-assignment.
ref = [seg("A", 0, 19), seg("B", 19, 27)]
hyp = [seg("X", 0, 10), seg("Y", 10, 19), seg("X", 19, 27)]
score = der.score(ref, hyp)
assert score == der.Score(missed=0.0, false_alarm=0.0, confusion=10.0, speech=27.0)
assert score.der == pytest.approx(10 / 27)


def test_overlapping_reference_speakers_count_per_speaker():
# [5,10) has two reference speakers talking at once, so it contributes 2x its
# 5s of wall-clock to reference speech (15s wall-clock -> 20s speaker-time).
ref = [seg("A", 0, 10), seg("B", 5, 15)]
score = der.score(ref, ref)
assert score == der.Score(missed=0.0, false_alarm=0.0, confusion=0.0, speech=20.0)


def test_pooled_sums_components_for_corpus_der():
total = der.pooled(
[
der.Score(missed=1.0, false_alarm=2.0, confusion=3.0, speech=10.0),
der.Score(missed=0.0, false_alarm=1.0, confusion=1.0, speech=30.0),
]
)
assert total == der.Score(missed=1.0, false_alarm=3.0, confusion=4.0, speech=40.0)
assert total.der == 0.2
Loading