Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 165 additions & 0 deletions src/microplex_us/pipelines/ecps_replacement_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,55 @@
"employment_income_before_lsr",
)

_BENCHMARK_MANIFEST_EVIDENCE_PATHS: dict[str, tuple[tuple[str, ...], ...]] = {
"certificate_type": (("certificate_type",),),
"period": (("period",),),
"baseline_dataset.sha256": (
("baseline_dataset", "sha256"),
("enhanced_cps", "sha256"),
("baseline_dataset_sha256",),
("enhanced_cps_sha256",),
),
"target_db.sha256": (
("target_db", "sha256"),
("targets_db", "sha256"),
("policyengine_targets_db", "sha256"),
("target_db_sha256",),
("policyengine_targets_db_sha256",),
),
"policyengine_us_data.commit": (
("policyengine_us_data", "commit"),
("policyengine_us_data", "commit_sha"),
("policyengine_us_data_commit",),
("policyengine_us_data_commit_sha",),
),
"policyengine_us.version": (
("policyengine_us", "version"),
("policyengine_us_version",),
),
"target_surface.target_profile": (
("target_surface", "target_profile"),
("target_profile",),
),
"target_surface.target_scope": (
("target_surface", "target_scope"),
("target_scope",),
("target_scope_filter",),
),
"target_surface.target_count": (
("target_surface", "target_count"),
("target_count",),
),
"target_surface.target_names_sha256": (
("target_surface", "target_names_sha256"),
("target_names_sha256",),
),
"scoring_config.sha256": (
("scoring_config", "sha256"),
("scoring_config_sha256",),
),
}


def _comparison_bad_targets() -> tuple[str, ...]:
return tuple(
Expand Down Expand Up @@ -236,6 +285,7 @@ def build_sound_ecps_replacement_comparison(
assert_baseline_sane: bool = True,
baseline_sanity_mode: str = "msre",
max_baseline_unweighted_msre: float = 2.0,
benchmark_manifest_path: str | Path | None = None,
) -> dict[str, Any]:
"""Build a release-contract eCPS comparison payload.

Expand Down Expand Up @@ -541,6 +591,14 @@ def build_sound_ecps_replacement_comparison(
baseline_sanity=baseline_sanity,
score_summary=score_summary,
)
benchmark_manifest = (
_assert_certificate_matches_benchmark_manifest(
frozen_baseline_certificate,
benchmark_manifest_path,
)
if benchmark_manifest_path is not None
else None
)
payload = {
"schema_version": 1,
"metric": "sound_ecps_replacement_comparison",
Expand All @@ -567,6 +625,7 @@ def build_sound_ecps_replacement_comparison(
"protected_family_losses": protected_family_losses,
},
"frozen_ecps_baseline_certificate": frozen_baseline_certificate,
"benchmark_manifest": benchmark_manifest,
"entity_structure": {
"candidate_source": _entity_structure_summary(
candidate_path,
Expand Down Expand Up @@ -1806,6 +1865,102 @@ def _frozen_ecps_baseline_certificate(
}


def _assert_certificate_matches_benchmark_manifest(
certificate: dict[str, Any],
benchmark_manifest_path: str | Path,
) -> dict[str, Any]:
"""Fail before writing a comparison if it is not on the pinned surface."""

manifest_path = Path(benchmark_manifest_path).expanduser().resolve()
if not manifest_path.exists():
raise FileNotFoundError(f"benchmark manifest not found: {manifest_path}")
try:
manifest = json.loads(manifest_path.read_text())
except json.JSONDecodeError as exc:
raise ComparisonGateError(
f"benchmark manifest is not valid JSON: {manifest_path}: {exc}"
) from exc

manifest_evidence = _benchmark_manifest_evidence(manifest)
certificate_evidence = _benchmark_manifest_evidence(certificate)
missing = [
field
for field, value in manifest_evidence.items()
if not _valid_benchmark_evidence_value(field, value)
]
mismatches = [
{
"field": field,
"benchmark_manifest_value": expected,
"certificate_value": certificate_evidence.get(field),
}
for field, expected in manifest_evidence.items()
if _valid_benchmark_evidence_value(field, expected)
and str(certificate_evidence.get(field)) != str(expected)
]
if missing or mismatches:
problems = []
if missing:
problems.append("missing manifest evidence: " + ", ".join(missing))
if mismatches:
problems.append(
"mismatched evidence: "
+ ", ".join(str(item["field"]) for item in mismatches)
)
raise ComparisonGateError(
"Comparison does not match pinned production eCPS benchmark manifest; "
+ "; ".join(problems)
)

return {
**_dataset_descriptor(manifest_path),
"certificate_match": {
"status": "passed",
"checked_evidence": manifest_evidence,
},
}


def _benchmark_manifest_evidence(payload: dict[str, Any]) -> dict[str, Any]:
return {
field: _first_nested_path_value(payload, paths)
for field, paths in _BENCHMARK_MANIFEST_EVIDENCE_PATHS.items()
}


def _first_nested_path_value(
payload: dict[str, Any],
paths: tuple[tuple[str, ...], ...],
) -> Any:
for path in paths:
current: Any = payload
for part in path:
if not isinstance(current, dict) or part not in current:
current = None
break
current = current[part]
if current is not None:
return current
return None


def _valid_benchmark_evidence_value(field: str, value: Any) -> bool:
if value is None:
return False
if isinstance(value, str) and not value:
return False
if field.endswith(".sha256"):
return isinstance(value, str) and len(value) == 64
if field.endswith(".commit"):
return isinstance(value, str) and len(value) >= 7
if field.endswith(".target_count"):
try:
return int(value) > 0
except (TypeError, ValueError):
return False
return True


def _installed_policyengine_us_version() -> str:
try:
return importlib.metadata.version("policyengine-us")
Expand Down Expand Up @@ -1904,6 +2059,15 @@ def main(argv: list[str] | None = None) -> int:
"storage folder so the target surface is pinned."
),
)
parser.add_argument(
"--benchmark-manifest",
help=(
"Pre-existing frozen benchmark manifest to enforce before writing "
"the comparison. The comparison certificate must match its baseline "
"H5, target DB, target surface, scorer checkout, PolicyEngine-US "
"version, and scoring config."
),
)
parser.add_argument("--skip-tax-expenditure-targets", action="store_true")
parser.add_argument(
"--target-scope",
Expand Down Expand Up @@ -1996,6 +2160,7 @@ def main(argv: list[str] | None = None) -> int:
assert_baseline_sane=args.assert_baseline_sane,
baseline_sanity_mode=args.baseline_sanity_mode,
max_baseline_unweighted_msre=args.max_baseline_unweighted_msre,
benchmark_manifest_path=args.benchmark_manifest,
)
print(str(written))
return 0
Expand Down
90 changes: 90 additions & 0 deletions tests/pipelines/test_ecps_replacement_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,6 +622,96 @@ def fail_exact_rescore(**_kwargs):
assert payload["score"]["score_source"] == "refit_loss_matrix"


def test_sound_ecps_replacement_comparison_enforces_benchmark_manifest(
monkeypatch,
tmp_path,
):
candidate = _write_minimal_policyengine_dataset(tmp_path / "candidate.h5")
baseline = _write_minimal_policyengine_dataset(tmp_path / "baseline.h5")
targets_db = tmp_path / "policyengine_targets.db"
targets_db.write_bytes(b"pinned target database")
scorer_repo = _write_clean_git_repo(tmp_path / "policyengine-us-data")
monkeypatch.setattr(ecps, "_extract_pe_native_loss_inputs", _fake_loss_inputs)
monkeypatch.setattr(ecps, "compute_us_pe_native_support_audit", _fake_support_audit)

bootstrap = ecps.build_sound_ecps_replacement_comparison(
candidate_dataset_path=candidate,
baseline_dataset_path=baseline,
output_dir=tmp_path / "bootstrap",
optimizer_max_iter=50,
policyengine_targets_db_path=targets_db,
policyengine_us_data_repo=scorer_repo,
)
benchmark_manifest = tmp_path / "benchmark_manifest.json"
_benchmark_manifest(
benchmark_manifest,
certificate=bootstrap["frozen_ecps_baseline_certificate"],
)

payload = ecps.build_sound_ecps_replacement_comparison(
candidate_dataset_path=candidate,
baseline_dataset_path=baseline,
output_dir=tmp_path / "comparison",
optimizer_max_iter=50,
policyengine_targets_db_path=targets_db,
policyengine_us_data_repo=scorer_repo,
benchmark_manifest_path=benchmark_manifest,
)

assert payload["benchmark_manifest"]["certificate_match"]["status"] == "passed"
assert (
payload["benchmark_manifest"]["certificate_match"]["checked_evidence"][
"target_surface.target_names_sha256"
]
== bootstrap["frozen_ecps_baseline_certificate"]["target_surface"][
"target_names_sha256"
]
)


def test_sound_ecps_replacement_comparison_rejects_benchmark_manifest_mismatch(
monkeypatch,
tmp_path,
):
candidate = _write_minimal_policyengine_dataset(tmp_path / "candidate.h5")
baseline = _write_minimal_policyengine_dataset(tmp_path / "baseline.h5")
targets_db = tmp_path / "policyengine_targets.db"
targets_db.write_bytes(b"pinned target database")
scorer_repo = _write_clean_git_repo(tmp_path / "policyengine-us-data")
monkeypatch.setattr(ecps, "_extract_pe_native_loss_inputs", _fake_loss_inputs)
monkeypatch.setattr(ecps, "compute_us_pe_native_support_audit", _fake_support_audit)

bootstrap = ecps.build_sound_ecps_replacement_comparison(
candidate_dataset_path=candidate,
baseline_dataset_path=baseline,
output_dir=tmp_path / "bootstrap",
optimizer_max_iter=50,
policyengine_targets_db_path=targets_db,
policyengine_us_data_repo=scorer_repo,
)
benchmark_manifest = tmp_path / "benchmark_manifest.json"
_benchmark_manifest(
benchmark_manifest,
certificate=bootstrap["frozen_ecps_baseline_certificate"],
)
manifest = json.loads(benchmark_manifest.read_text())
manifest["target_surface"]["target_names_sha256"] = "f" * 64
benchmark_manifest.write_text(json.dumps(manifest, indent=2, sort_keys=True))

with pytest.raises(ecps.ComparisonGateError) as excinfo:
ecps.build_sound_ecps_replacement_comparison(
candidate_dataset_path=candidate,
baseline_dataset_path=baseline,
output_dir=tmp_path / "comparison",
optimizer_max_iter=50,
policyengine_targets_db_path=targets_db,
policyengine_us_data_repo=scorer_repo,
benchmark_manifest_path=benchmark_manifest,
)

assert "target_surface.target_names_sha256" in str(excinfo.value)


def test_sound_ecps_replacement_comparison_writes_target_diagnostics_sidecar(
monkeypatch,
tmp_path,
Expand Down
Loading