From ee660038921ff3a59f16538d9610c9ff96d19ec6 Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Sat, 6 Jun 2026 09:43:55 +0100 Subject: [PATCH] Allow nonnumeric regime-aware donor targets --- src/microplex_us/pipelines/donor_imputers.py | 4 +- .../test_regime_aware_donor_imputer.py | 39 +++++++++++++++++++ 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/src/microplex_us/pipelines/donor_imputers.py b/src/microplex_us/pipelines/donor_imputers.py index 50bd4dd..a03e1e9 100644 --- a/src/microplex_us/pipelines/donor_imputers.py +++ b/src/microplex_us/pipelines/donor_imputers.py @@ -249,7 +249,9 @@ def fit( self._predictor_columns = predictor_vars self._fitted = {column: fitted for column in self._fitted_columns} self._regimes = { - column: wrapper.get_regime(column) for column in self._fitted_columns + column: regime + for column, regime in getattr(wrapper, "_regimes", {}).items() + if column in target_set } return self diff --git a/tests/pipelines/test_regime_aware_donor_imputer.py b/tests/pipelines/test_regime_aware_donor_imputer.py index 35e3938..b3dcdd9 100644 --- a/tests/pipelines/test_regime_aware_donor_imputer.py +++ b/tests/pipelines/test_regime_aware_donor_imputer.py @@ -245,6 +245,45 @@ def test_duplicate_input_columns_are_collapsed_before_microimpute(self) -> None: synthetic[["first_income_leaf", "second_income_leaf"]].notna().all().all() ) + def test_nonnumeric_targets_do_not_require_numeric_regimes(self) -> None: + from microplex_us.pipelines.us import RegimeAwareDonorImputer + + rng = np.random.default_rng(2026060603) + n = 300 + age = rng.integers(18, 80, size=n).astype(float) + income = rng.normal(loc=age * 250.0, scale=1_000.0, size=n) + train = pd.DataFrame( + { + "age": age, + "self_employment_income": income, + "business_is_sstb": income > np.median(income), + } + ) + + imputer = RegimeAwareDonorImputer( + condition_vars=["age"], + target_vars=["self_employment_income", "business_is_sstb"], + n_estimators=25, + ) + imputer.fit(train) + + assert "self_employment_income" in imputer._regimes + assert "business_is_sstb" not in imputer._regimes + + conditions = pd.DataFrame({"age": [25.0, 45.0, 65.0]}) + synthetic = imputer.generate(conditions, seed=20260606) + assert list(synthetic.columns) == [ + "age", + "self_employment_income", + "business_is_sstb", + ] + assert ( + synthetic[["self_employment_income", "business_is_sstb"]] + .notna() + .all() + .all() + ) + def _fit_generate( self, n_train: int = 1500, n_gen: int = 2000, seed: int = 0 ) -> np.ndarray: