From c320a4cd0b5ed0a70edbcbd8d2eb365ca4214c26 Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Sat, 6 Jun 2026 11:02:24 +0100 Subject: [PATCH] Bound regime-aware QRF donor fits --- pyproject.toml | 2 +- src/microplex_us/pipelines/donor_imputers.py | 26 +++++++++++++- src/microplex_us/pipelines/us.py | 2 ++ .../test_regime_aware_donor_imputer.py | 35 +++++++++++++++++++ tests/pipelines/test_us.py | 4 +++ uv.lock | 8 ++--- 6 files changed, 71 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 856c48b..1229c73 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,7 @@ hf = [ "huggingface_hub>=0.24", ] policyengine = [ - "microimpute==3.1.0; python_full_version >= '3.12' and python_full_version < '3.15'", + "microimpute==3.1.1; python_full_version >= '3.12' and python_full_version < '3.15'", "policyengine-us==1.715.2; python_version >= '3.11' and python_version < '3.15'", "spm-calculator>=0.3.1", # Standalone tax-unit construction engine (the extraction of eCPS's diff --git a/src/microplex_us/pipelines/donor_imputers.py b/src/microplex_us/pipelines/donor_imputers.py index a03e1e9..9241dc1 100644 --- a/src/microplex_us/pipelines/donor_imputers.py +++ b/src/microplex_us/pipelines/donor_imputers.py @@ -178,12 +178,18 @@ def __init__( condition_vars: list[str], target_vars: list[str], n_estimators: int = 100, + max_train_samples: int | None = 50_000, classifier_type: str = "hist_gb", seed: int = 42, ) -> None: self.condition_vars = list(condition_vars) self.target_vars = list(target_vars) self.n_estimators = int(n_estimators) + if max_train_samples is not None and int(max_train_samples) < 1: + raise ValueError("max_train_samples must be a positive integer") + self.max_train_samples = ( + None if max_train_samples is None else int(max_train_samples) + ) self.classifier_type = str(classifier_type) self.seed = int(seed) self._fitted: dict[str, Any] = {} @@ -191,6 +197,24 @@ def __init__( self._predictor_columns: tuple[str, ...] = () self._regimes: dict[str, str] = {} + def _configured_qrf_class(self, qrf_class: type[Any]) -> type[Any]: + n_estimators = self.n_estimators + max_train_samples = self.max_train_samples + + class ConfiguredQRF(qrf_class): + def __init__(self, *args: Any, **kwargs: Any) -> None: + if max_train_samples is not None: + kwargs.setdefault("max_train_samples", max_train_samples) + super().__init__(*args, **kwargs) + + def fit(self, *args: Any, **kwargs: Any) -> Any: + kwargs.setdefault("n_estimators", n_estimators) + kwargs.setdefault("n_jobs", -1) + return super().fit(*args, **kwargs) + + ConfiguredQRF.__name__ = "ConfiguredRegimeAwareQRF" + return ConfiguredQRF + def fit( self, data: pd.DataFrame, @@ -234,7 +258,7 @@ def fit( return self wrapper = ZeroInflatedImputer( - base_imputer_class=QRF, + base_imputer_class=self._configured_qrf_class(QRF), base_imputer_kwargs={}, classifier_type=self.classifier_type, sequential=True, diff --git a/src/microplex_us/pipelines/us.py b/src/microplex_us/pipelines/us.py index f351a75..06da948 100644 --- a/src/microplex_us/pipelines/us.py +++ b/src/microplex_us/pipelines/us.py @@ -2186,6 +2186,7 @@ class USMicroplexBuildConfig: donor_imputer_hidden_dim: int = 32 donor_imputer_backend: Literal["maf", "qrf", "zi_qrf", "regime_aware"] = "maf" donor_imputer_qrf_n_estimators: int = 100 + donor_imputer_qrf_max_train_samples: int | None = 50_000 donor_imputer_qrf_zero_threshold: float = 0.05 donor_imputer_condition_selection: Literal[ "all_shared", @@ -6012,6 +6013,7 @@ def _build_donor_imputer( condition_vars=condition_vars, target_vars=list(target_vars), n_estimators=self.config.donor_imputer_qrf_n_estimators, + max_train_samples=self.config.donor_imputer_qrf_max_train_samples, seed=self.config.random_seed, ) zero_inflated_vars = ( diff --git a/tests/pipelines/test_regime_aware_donor_imputer.py b/tests/pipelines/test_regime_aware_donor_imputer.py index b3dcdd9..1299df2 100644 --- a/tests/pipelines/test_regime_aware_donor_imputer.py +++ b/tests/pipelines/test_regime_aware_donor_imputer.py @@ -120,6 +120,41 @@ def test_factory_dispatches_to_regime_aware(self) -> None: class TestRegimeAwareFitGenerate: """Fit/generate contract and tripartite-specific guarantees.""" + def test_qrf_budget_reaches_microimpute_base(self, monkeypatch) -> None: + from microplex_us.pipelines.us import RegimeAwareDonorImputer + + captured: dict[str, object] = {} + + class FakeQRF: + def __init__(self, *args, **kwargs): + captured["init_args"] = args + captured["init_kwargs"] = kwargs + + def fit(self, *args, **kwargs): + captured["fit_args"] = args + captured["fit_kwargs"] = kwargs + return self + + monkeypatch.setattr("microimpute.models.qrf.QRF", FakeQRF) + + train = pd.DataFrame( + { + "age": [25.0, 35.0, 45.0, 55.0] * 10, + "income_leaf": [100.0, 200.0, 300.0, 400.0] * 10, + } + ) + imputer = RegimeAwareDonorImputer( + condition_vars=["age"], + target_vars=["income_leaf"], + n_estimators=7, + max_train_samples=17, + ) + imputer.fit(train) + + assert captured["init_kwargs"]["max_train_samples"] == 17 + assert captured["fit_kwargs"]["n_estimators"] == 7 + assert captured["fit_kwargs"]["n_jobs"] == -1 + def test_multi_target_fit_uses_one_chained_zero_inflated_imputer(self) -> None: from microplex_us.pipelines.us import RegimeAwareDonorImputer diff --git a/tests/pipelines/test_us.py b/tests/pipelines/test_us.py index 6a91503..cca7b5d 100644 --- a/tests/pipelines/test_us.py +++ b/tests/pipelines/test_us.py @@ -3600,6 +3600,8 @@ def __init__(self, **kwargs): USMicroplexBuildConfig( n_synthetic=4, donor_imputer_backend="regime_aware", + donor_imputer_qrf_n_estimators=77, + donor_imputer_qrf_max_train_samples=1234, ) ) regime_pipeline._build_donor_imputer( @@ -3619,6 +3621,8 @@ def __init__(self, **kwargs): ) assert "nonnegative_vars" not in captured["regime_aware"] + assert captured["regime_aware"]["n_estimators"] == 77 + assert captured["regime_aware"]["max_train_samples"] == 1234 assert captured["zi_qrf"]["nonnegative_vars"] == set() assert captured["zi_qrf"]["zero_inflated_vars"] == { "partnership_s_corp_income", diff --git a/uv.lock b/uv.lock index f50adf7..c61afca 100644 --- a/uv.lock +++ b/uv.lock @@ -1124,7 +1124,7 @@ wheels = [ [[package]] name = "microimpute" -version = "3.1.0" +version = "3.1.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "joblib" }, @@ -1141,9 +1141,9 @@ dependencies = [ { name = "statsmodels" }, { name = "tqdm" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/71/e4/d9f5f06eaab4cd09d7700190da784a16d7e57e95802db8542658f97bb7ef/microimpute-3.1.0.tar.gz", hash = "sha256:1e7d0f69e99390127755ef10db94cfa5d91029b1075e6a1c51a1cc6168e15336", size = 145868, upload-time = "2026-06-06T06:44:14.005Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a0/a0/15e25e78b7fa48d100f52d210290f2ba5820ebc47e4859748a7d89a3cae9/microimpute-3.1.1.tar.gz", hash = "sha256:70aa5bd28e7cef254695b8317c0f88e11e39ea204e0f6362cb33a94163438c3e", size = 146197, upload-time = "2026-06-06T09:54:07.848Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/1f/82/1d4ae0872aab0bafdb167bf4de6fa232dd599bc710425e8875e20ab23149/microimpute-3.1.0-py3-none-any.whl", hash = "sha256:4e4fcd21e9c4b1c5928075ba23858b4df47bdecba97fe2ae6b8540027ca68184", size = 127161, upload-time = "2026-06-06T06:44:12.867Z" }, + { url = "https://files.pythonhosted.org/packages/7f/d9/2b1ae246461f88388e2a43ff29f0f3477ac89cc52b11abaf775089b7553d/microimpute-3.1.1-py3-none-any.whl", hash = "sha256:c6a8fcb2ab129486fce48299cf89901b12632d69c0483251bae5f8d68a0d326b", size = 127432, upload-time = "2026-06-06T09:54:06.717Z" }, ] [[package]] @@ -1211,7 +1211,7 @@ requires-dist = [ { name = "h5py", specifier = ">=3.10" }, { name = "huggingface-hub", marker = "extra == 'hf'", specifier = ">=0.24" }, { name = "jupyter-book", marker = "extra == 'docs'", specifier = ">=0.15,<0.16" }, - { name = "microimpute", marker = "python_full_version >= '3.12' and python_full_version < '3.15' and extra == 'policyengine'", specifier = "==3.1.0" }, + { name = "microimpute", marker = "python_full_version >= '3.12' and python_full_version < '3.15' and extra == 'policyengine'", specifier = "==3.1.1" }, { name = "microplex", extras = ["calibrate"], git = "https://github.com/PolicyEngine/microplex.git?rev=1e0627182f9df40aacd7043c96956c2895bf9d30" }, { name = "microunit", marker = "extra == 'policyengine'", specifier = ">=0.1.0" }, { name = "policyengine-us", marker = "python_full_version >= '3.11' and python_full_version < '3.15' and extra == 'policyengine'", specifier = "==1.715.2" },