From d0483b37349af8350c63645d9b5f4a7d6c7fecd4 Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Sat, 6 Jun 2026 09:02:08 +0100 Subject: [PATCH] Enforce benchmark manifests in eCPS comparisons --- .../pipelines/ecps_replacement_comparison.py | 165 ++++++++++++++++++ .../test_ecps_replacement_comparison.py | 90 ++++++++++ 2 files changed, 255 insertions(+) diff --git a/src/microplex_us/pipelines/ecps_replacement_comparison.py b/src/microplex_us/pipelines/ecps_replacement_comparison.py index dcf76cb..1fc023e 100644 --- a/src/microplex_us/pipelines/ecps_replacement_comparison.py +++ b/src/microplex_us/pipelines/ecps_replacement_comparison.py @@ -66,6 +66,55 @@ "employment_income_before_lsr", ) +_BENCHMARK_MANIFEST_EVIDENCE_PATHS: dict[str, tuple[tuple[str, ...], ...]] = { + "certificate_type": (("certificate_type",),), + "period": (("period",),), + "baseline_dataset.sha256": ( + ("baseline_dataset", "sha256"), + ("enhanced_cps", "sha256"), + ("baseline_dataset_sha256",), + ("enhanced_cps_sha256",), + ), + "target_db.sha256": ( + ("target_db", "sha256"), + ("targets_db", "sha256"), + ("policyengine_targets_db", "sha256"), + ("target_db_sha256",), + ("policyengine_targets_db_sha256",), + ), + "policyengine_us_data.commit": ( + ("policyengine_us_data", "commit"), + ("policyengine_us_data", "commit_sha"), + ("policyengine_us_data_commit",), + ("policyengine_us_data_commit_sha",), + ), + "policyengine_us.version": ( + ("policyengine_us", "version"), + ("policyengine_us_version",), + ), + "target_surface.target_profile": ( + ("target_surface", "target_profile"), + ("target_profile",), + ), + "target_surface.target_scope": ( + ("target_surface", "target_scope"), + ("target_scope",), + ("target_scope_filter",), + ), + "target_surface.target_count": ( + ("target_surface", "target_count"), + ("target_count",), + ), + "target_surface.target_names_sha256": ( + ("target_surface", "target_names_sha256"), + ("target_names_sha256",), + ), + "scoring_config.sha256": ( + ("scoring_config", "sha256"), + ("scoring_config_sha256",), + ), +} + def _comparison_bad_targets() -> tuple[str, ...]: return tuple( @@ -236,6 +285,7 @@ def build_sound_ecps_replacement_comparison( assert_baseline_sane: bool = True, baseline_sanity_mode: str = "msre", max_baseline_unweighted_msre: float = 2.0, + benchmark_manifest_path: str | Path | None = None, ) -> dict[str, Any]: """Build a release-contract eCPS comparison payload. @@ -541,6 +591,14 @@ def build_sound_ecps_replacement_comparison( baseline_sanity=baseline_sanity, score_summary=score_summary, ) + benchmark_manifest = ( + _assert_certificate_matches_benchmark_manifest( + frozen_baseline_certificate, + benchmark_manifest_path, + ) + if benchmark_manifest_path is not None + else None + ) payload = { "schema_version": 1, "metric": "sound_ecps_replacement_comparison", @@ -567,6 +625,7 @@ def build_sound_ecps_replacement_comparison( "protected_family_losses": protected_family_losses, }, "frozen_ecps_baseline_certificate": frozen_baseline_certificate, + "benchmark_manifest": benchmark_manifest, "entity_structure": { "candidate_source": _entity_structure_summary( candidate_path, @@ -1806,6 +1865,102 @@ def _frozen_ecps_baseline_certificate( } +def _assert_certificate_matches_benchmark_manifest( + certificate: dict[str, Any], + benchmark_manifest_path: str | Path, +) -> dict[str, Any]: + """Fail before writing a comparison if it is not on the pinned surface.""" + + manifest_path = Path(benchmark_manifest_path).expanduser().resolve() + if not manifest_path.exists(): + raise FileNotFoundError(f"benchmark manifest not found: {manifest_path}") + try: + manifest = json.loads(manifest_path.read_text()) + except json.JSONDecodeError as exc: + raise ComparisonGateError( + f"benchmark manifest is not valid JSON: {manifest_path}: {exc}" + ) from exc + + manifest_evidence = _benchmark_manifest_evidence(manifest) + certificate_evidence = _benchmark_manifest_evidence(certificate) + missing = [ + field + for field, value in manifest_evidence.items() + if not _valid_benchmark_evidence_value(field, value) + ] + mismatches = [ + { + "field": field, + "benchmark_manifest_value": expected, + "certificate_value": certificate_evidence.get(field), + } + for field, expected in manifest_evidence.items() + if _valid_benchmark_evidence_value(field, expected) + and str(certificate_evidence.get(field)) != str(expected) + ] + if missing or mismatches: + problems = [] + if missing: + problems.append("missing manifest evidence: " + ", ".join(missing)) + if mismatches: + problems.append( + "mismatched evidence: " + + ", ".join(str(item["field"]) for item in mismatches) + ) + raise ComparisonGateError( + "Comparison does not match pinned production eCPS benchmark manifest; " + + "; ".join(problems) + ) + + return { + **_dataset_descriptor(manifest_path), + "certificate_match": { + "status": "passed", + "checked_evidence": manifest_evidence, + }, + } + + +def _benchmark_manifest_evidence(payload: dict[str, Any]) -> dict[str, Any]: + return { + field: _first_nested_path_value(payload, paths) + for field, paths in _BENCHMARK_MANIFEST_EVIDENCE_PATHS.items() + } + + +def _first_nested_path_value( + payload: dict[str, Any], + paths: tuple[tuple[str, ...], ...], +) -> Any: + for path in paths: + current: Any = payload + for part in path: + if not isinstance(current, dict) or part not in current: + current = None + break + current = current[part] + if current is not None: + return current + return None + + +def _valid_benchmark_evidence_value(field: str, value: Any) -> bool: + if value is None: + return False + if isinstance(value, str) and not value: + return False + if field.endswith(".sha256"): + return isinstance(value, str) and len(value) == 64 + if field.endswith(".commit"): + return isinstance(value, str) and len(value) >= 7 + if field.endswith(".target_count"): + try: + return int(value) > 0 + except (TypeError, ValueError): + return False + return True + + def _installed_policyengine_us_version() -> str: try: return importlib.metadata.version("policyengine-us") @@ -1904,6 +2059,15 @@ def main(argv: list[str] | None = None) -> int: "storage folder so the target surface is pinned." ), ) + parser.add_argument( + "--benchmark-manifest", + help=( + "Pre-existing frozen benchmark manifest to enforce before writing " + "the comparison. The comparison certificate must match its baseline " + "H5, target DB, target surface, scorer checkout, PolicyEngine-US " + "version, and scoring config." + ), + ) parser.add_argument("--skip-tax-expenditure-targets", action="store_true") parser.add_argument( "--target-scope", @@ -1996,6 +2160,7 @@ def main(argv: list[str] | None = None) -> int: assert_baseline_sane=args.assert_baseline_sane, baseline_sanity_mode=args.baseline_sanity_mode, max_baseline_unweighted_msre=args.max_baseline_unweighted_msre, + benchmark_manifest_path=args.benchmark_manifest, ) print(str(written)) return 0 diff --git a/tests/pipelines/test_ecps_replacement_comparison.py b/tests/pipelines/test_ecps_replacement_comparison.py index 406bdda..68f0cef 100644 --- a/tests/pipelines/test_ecps_replacement_comparison.py +++ b/tests/pipelines/test_ecps_replacement_comparison.py @@ -622,6 +622,96 @@ def fail_exact_rescore(**_kwargs): assert payload["score"]["score_source"] == "refit_loss_matrix" +def test_sound_ecps_replacement_comparison_enforces_benchmark_manifest( + monkeypatch, + tmp_path, +): + candidate = _write_minimal_policyengine_dataset(tmp_path / "candidate.h5") + baseline = _write_minimal_policyengine_dataset(tmp_path / "baseline.h5") + targets_db = tmp_path / "policyengine_targets.db" + targets_db.write_bytes(b"pinned target database") + scorer_repo = _write_clean_git_repo(tmp_path / "policyengine-us-data") + monkeypatch.setattr(ecps, "_extract_pe_native_loss_inputs", _fake_loss_inputs) + monkeypatch.setattr(ecps, "compute_us_pe_native_support_audit", _fake_support_audit) + + bootstrap = ecps.build_sound_ecps_replacement_comparison( + candidate_dataset_path=candidate, + baseline_dataset_path=baseline, + output_dir=tmp_path / "bootstrap", + optimizer_max_iter=50, + policyengine_targets_db_path=targets_db, + policyengine_us_data_repo=scorer_repo, + ) + benchmark_manifest = tmp_path / "benchmark_manifest.json" + _benchmark_manifest( + benchmark_manifest, + certificate=bootstrap["frozen_ecps_baseline_certificate"], + ) + + payload = ecps.build_sound_ecps_replacement_comparison( + candidate_dataset_path=candidate, + baseline_dataset_path=baseline, + output_dir=tmp_path / "comparison", + optimizer_max_iter=50, + policyengine_targets_db_path=targets_db, + policyengine_us_data_repo=scorer_repo, + benchmark_manifest_path=benchmark_manifest, + ) + + assert payload["benchmark_manifest"]["certificate_match"]["status"] == "passed" + assert ( + payload["benchmark_manifest"]["certificate_match"]["checked_evidence"][ + "target_surface.target_names_sha256" + ] + == bootstrap["frozen_ecps_baseline_certificate"]["target_surface"][ + "target_names_sha256" + ] + ) + + +def test_sound_ecps_replacement_comparison_rejects_benchmark_manifest_mismatch( + monkeypatch, + tmp_path, +): + candidate = _write_minimal_policyengine_dataset(tmp_path / "candidate.h5") + baseline = _write_minimal_policyengine_dataset(tmp_path / "baseline.h5") + targets_db = tmp_path / "policyengine_targets.db" + targets_db.write_bytes(b"pinned target database") + scorer_repo = _write_clean_git_repo(tmp_path / "policyengine-us-data") + monkeypatch.setattr(ecps, "_extract_pe_native_loss_inputs", _fake_loss_inputs) + monkeypatch.setattr(ecps, "compute_us_pe_native_support_audit", _fake_support_audit) + + bootstrap = ecps.build_sound_ecps_replacement_comparison( + candidate_dataset_path=candidate, + baseline_dataset_path=baseline, + output_dir=tmp_path / "bootstrap", + optimizer_max_iter=50, + policyengine_targets_db_path=targets_db, + policyengine_us_data_repo=scorer_repo, + ) + benchmark_manifest = tmp_path / "benchmark_manifest.json" + _benchmark_manifest( + benchmark_manifest, + certificate=bootstrap["frozen_ecps_baseline_certificate"], + ) + manifest = json.loads(benchmark_manifest.read_text()) + manifest["target_surface"]["target_names_sha256"] = "f" * 64 + benchmark_manifest.write_text(json.dumps(manifest, indent=2, sort_keys=True)) + + with pytest.raises(ecps.ComparisonGateError) as excinfo: + ecps.build_sound_ecps_replacement_comparison( + candidate_dataset_path=candidate, + baseline_dataset_path=baseline, + output_dir=tmp_path / "comparison", + optimizer_max_iter=50, + policyengine_targets_db_path=targets_db, + policyengine_us_data_repo=scorer_repo, + benchmark_manifest_path=benchmark_manifest, + ) + + assert "target_surface.target_names_sha256" in str(excinfo.value) + + def test_sound_ecps_replacement_comparison_writes_target_diagnostics_sidecar( monkeypatch, tmp_path,