Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 161 additions & 9 deletions src/microplex_us/pipelines/ecps_replacement_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,26 +89,35 @@ class ComparisonGateError(ValueError):
def _assert_refit_effective(
    label: str, refit: dict[str, Any], min_reduction: float
) -> None:
    """Fail if a refit did not move at all (a frozen no-op refit).

    A frozen refit (optimized loss == initial loss) means that side was never
    actually reweighted, so its loss is meaningless for comparison -- usually a
    degenerate loss matrix or a total-weight/population mismatch under
    ``preserve_input``. A refit that moves the loss is effective even if the
    full-set loss rises slightly: the refit minimizes the train objective, so an
    already-well-calibrated dataset can legitimately see full loss tick up from
    the held-out split. Only a frozen no-movement refit is a failure.

    ``label`` names the side being checked and is used only in the error
    message. ``refit`` must carry ``initial_full_loss`` and
    ``optimized_full_loss``. ``min_reduction`` is the absolute loss change
    that must be exceeded, in either direction, for the refit to count as
    having moved.

    Raises ``ComparisonGateError`` when the loss changed by no more than
    ``min_reduction``.
    """
    if not _refit_moved(refit, min_reduction):
        # The losses are re-read only on the failure path, purely to build
        # the diagnostic message.
        initial = float(refit["initial_full_loss"])
        optimized = float(refit["optimized_full_loss"])
        raise ComparisonGateError(
            f"{label} refit was a no-op: optimized loss {optimized:.6g} is "
            f"unchanged from initial {initial:.6g} (no movement beyond "
            f"{min_reduction:g}). The refit never reweighted this dataset, so the "
            f"comparison is meaningless -- likely a degenerate loss matrix or a "
            f"total-weight/population mismatch under preserve_input. Pass "
            f"assert_refit_effective=False only to deliberately accept this."
        )


def _refit_moved(refit: dict[str, Any], min_reduction: float) -> bool:
initial = float(refit["initial_full_loss"])
optimized = float(refit["optimized_full_loss"])
return abs(optimized - initial) > float(min_reduction)


def _assert_baseline_sane(
score_summary: dict[str, Any], max_msre: float
) -> dict[str, Any]:
Expand Down Expand Up @@ -351,6 +360,12 @@ def build_sound_ecps_replacement_comparison(
if assert_refit_effective:
_assert_refit_effective("candidate", candidate_refit, min_refit_loss_reduction)
_assert_refit_effective("baseline", baseline_refit, min_refit_loss_reduction)
candidate_refit_effective_passed = _refit_moved(
candidate_refit, min_refit_loss_reduction
)
baseline_refit_effective_passed = _refit_moved(
baseline_refit, min_refit_loss_reduction
)

protected_family_losses = _protected_family_losses(
target_names=target_names,
Expand Down Expand Up @@ -484,6 +499,9 @@ def build_sound_ecps_replacement_comparison(
"score_source": score_source,
"exact_rescore_requested": bool(exact_rescore),
"exact_rescore_status": exact_rescore_status,
"candidate_refit_effective_passed": candidate_refit_effective_passed,
"baseline_refit_effective_passed": baseline_refit_effective_passed,
"ecps_refit_effective_passed": baseline_refit_effective_passed,
"candidate_refit_config": refit_config,
"baseline_refit_config": refit_config,
"symmetric_refit": True,
Expand All @@ -504,6 +522,23 @@ def build_sound_ecps_replacement_comparison(
),
}
)
frozen_baseline_certificate = _frozen_ecps_baseline_certificate(
baseline_dataset_path=baseline_path,
policyengine_targets_db_path=resolved_targets_db,
policyengine_us_data_repo=policyengine_us_data_repo,
period=period,
target_names=target_names,
target_scope=target_scope,
holdout_target_fraction=holdout_target_fraction,
holdout_target_seed=holdout_target_seed,
matched_sample_method=matched_sample_method,
refit_config=refit_config,
skip_tax_expenditure_targets=skip_tax_expenditure_targets,
exact_rescore=exact_rescore,
score_source=score_source,
baseline_sanity=baseline_sanity,
score_summary=score_summary,
)
payload = {
"schema_version": 1,
"metric": "sound_ecps_replacement_comparison",
Expand All @@ -523,11 +558,13 @@ def build_sound_ecps_replacement_comparison(
"score_candidate_only": False,
"refit_objective_matches_scoring": objective_identity_passed,
"ecps_refit_recovery_passed": ecps_refit_recovery_passed,
"ecps_refit_effective_passed": baseline_refit_effective_passed,
"holdout_target_fraction": float(holdout_target_fraction),
"holdout_targets": int(holdout_mask.sum()),
"target_scope_filter": target_scope,
"protected_family_losses": protected_family_losses,
},
"frozen_ecps_baseline_certificate": frozen_baseline_certificate,
"entity_structure": {
"candidate_source": _entity_structure_summary(
candidate_path,
Expand Down Expand Up @@ -1686,6 +1723,121 @@ def _sha256(path: Path) -> str:
return digest.hexdigest()


def _frozen_ecps_baseline_certificate(
    *,
    baseline_dataset_path: Path,
    policyengine_targets_db_path: Path | None,
    policyengine_us_data_repo: str | Path | None,
    period: int,
    target_names: list[str],
    target_scope: str,
    holdout_target_fraction: float,
    holdout_target_seed: int,
    matched_sample_method: str,
    refit_config: dict[str, Any],
    skip_tax_expenditure_targets: bool,
    exact_rescore: bool,
    score_source: str,
    baseline_sanity: dict[str, Any],
    score_summary: dict[str, Any],
) -> dict[str, Any]:
    """Build the certificate pinning the eCPS baseline behind this verdict.

    The certificate records descriptors of the baseline dataset and (when a
    path is given) the target DB, the git descriptor of the
    policyengine-us-data checkout, the target count plus a hash of the target
    names, the scoring configuration together with its canonical-JSON hash,
    the baseline metrics present in ``score_summary``, and a copy of
    ``baseline_sanity``.

    It is meant to be checked by promotion gates against the pinned benchmark
    manifest, so that a release cannot pass on a live recomputation that used
    a different eCPS H5, target DB, scorer checkout, or scoring config.
    """
    profile = "pe_native_broad"
    scope = str(target_scope)

    # Everything that shapes the score; hashed below so any drift is visible.
    config: dict[str, Any] = {
        "period": int(period),
        "target_profile": profile,
        "target_scope": scope,
        "holdout_target_fraction": float(holdout_target_fraction),
        "holdout_target_seed": int(holdout_target_seed),
        "matched_sample_method": str(matched_sample_method),
        "refit_config": dict(refit_config),
        "skip_tax_expenditure_targets": bool(skip_tax_expenditure_targets),
        "exact_rescore": bool(exact_rescore),
        "score_source": str(score_source),
        "comparison_bad_targets": list(_comparison_bad_targets()),
    }

    # Carry over only the metrics the score summary actually provides.
    metric_keys = (
        "baseline_initial_enhanced_cps_native_loss",
        "baseline_enhanced_cps_native_loss",
        "baseline_train_loss",
        "baseline_holdout_loss",
        "baseline_unweighted_msre",
        "n_targets_kept",
        "n_national_targets",
        "n_state_targets",
    )
    metrics: dict[str, Any] = {}
    for metric_key in metric_keys:
        metric_value = score_summary.get(metric_key)
        if metric_value is not None:
            metrics[metric_key] = metric_value

    certificate: dict[str, Any] = {
        "schema_version": 1,
        "certificate_type": "frozen_production_ecps_baseline",
        "period": int(period),
    }
    certificate["baseline_dataset"] = _dataset_descriptor(baseline_dataset_path)
    if policyengine_targets_db_path is None:
        certificate["target_db"] = None
    else:
        certificate["target_db"] = _dataset_descriptor(policyengine_targets_db_path)
    certificate["policyengine_us_data"] = _git_repo_descriptor(
        policyengine_us_data_repo
    )
    certificate["target_surface"] = {
        "target_profile": profile,
        "target_scope": scope,
        "target_count": len(target_names),
        "target_names_sha256": _canonical_json_sha256(list(target_names)),
    }
    # The hash covers the config without the "sha256" entry itself.
    certificate["scoring_config"] = dict(
        config, sha256=_canonical_json_sha256(config)
    )
    certificate["baseline_metrics"] = metrics
    certificate["baseline_sanity"] = dict(baseline_sanity)
    return certificate


def _git_repo_descriptor(repo_path: str | Path | None) -> dict[str, Any] | None:
if repo_path is None:
return None
repo = Path(repo_path).expanduser().resolve()
descriptor: dict[str, Any] = {"repo": str(repo)}
commit = _git_output_or_none(repo, "rev-parse", "HEAD")
if commit:
descriptor["commit"] = commit
status = _git_output_or_none(repo, "status", "--porcelain")
if status is not None:
descriptor["dirty"] = bool(status)
return descriptor


def _git_output_or_none(repo: Path, *args: str) -> str | None:
completed = subprocess.run(
["git", "-C", str(repo), *args],
check=False,
capture_output=True,
text=True,
)
if completed.returncode != 0:
return None
return completed.stdout.strip()


def _canonical_json_sha256(payload: Any) -> str:
encoded = json.dumps(
payload,
sort_keys=True,
separators=(",", ":"),
default=str,
).encode("utf-8")
return hashlib.sha256(encoded).hexdigest()


def main(argv: list[str] | None = None) -> int:
parser = argparse.ArgumentParser(
description=(
Expand Down
Loading
Loading