From 85feb9412696e92d781a07b7ac32246140699651 Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Sat, 6 Jun 2026 08:35:36 +0100 Subject: [PATCH] Handle overlapping regime-aware donor predictors --- src/microplex_us/pipelines/donor_imputers.py | 21 ++++++---- .../test_regime_aware_donor_imputer.py | 40 +++++++++++++++++++ 2 files changed, 53 insertions(+), 8 deletions(-) diff --git a/src/microplex_us/pipelines/donor_imputers.py b/src/microplex_us/pipelines/donor_imputers.py index 8e545ea..4a90e8b 100644 --- a/src/microplex_us/pipelines/donor_imputers.py +++ b/src/microplex_us/pipelines/donor_imputers.py @@ -180,6 +180,7 @@ def __init__( self.seed = int(seed) self._fitted: dict[str, Any] = {} self._fitted_columns: tuple[str, ...] = () + self._predictor_columns: tuple[str, ...] = () self._regimes: dict[str, str] = {} def fit( @@ -209,12 +210,15 @@ def fit( self._fitted = {} self._fitted_columns = () + self._predictor_columns = () self._regimes = {} - subset = ( - data[self.condition_vars + self.target_vars] - .replace([np.inf, -np.inf], np.nan) - .dropna() + target_vars = tuple(dict.fromkeys(self.target_vars)) + target_set = set(target_vars) + predictor_vars = tuple( + dict.fromkeys(var for var in self.condition_vars if var not in target_set) ) + fit_columns = tuple(dict.fromkeys((*predictor_vars, *target_vars))) + subset = data[list(fit_columns)].replace([np.inf, -np.inf], np.nan).dropna() if len(subset) < 25: return self @@ -227,10 +231,11 @@ def fit( ) fitted = wrapper.fit( subset, - predictors=list(self.condition_vars), - imputed_variables=list(self.target_vars), + predictors=list(predictor_vars), + imputed_variables=list(target_vars), ) - self._fitted_columns = tuple(self.target_vars) + self._fitted_columns = target_vars + self._predictor_columns = predictor_vars self._fitted = {column: fitted for column in self._fitted_columns} self._regimes = { column: wrapper.get_regime(column) for column in self._fitted_columns @@ -251,7 +256,7 @@ def generate( prediction_seed = self.seed if seed is None else int(seed) self._reset_prediction_rngs(fitted, seed=prediction_seed) - preds = fitted.predict(synthetic[self.condition_vars]) + preds = fitted.predict(synthetic[list(self._predictor_columns)]) for column in self.target_vars: if column in preds.columns: synthetic[column] = preds[column].to_numpy(dtype=float) diff --git a/tests/pipelines/test_regime_aware_donor_imputer.py b/tests/pipelines/test_regime_aware_donor_imputer.py index fbb8546..dd57fe5 100644 --- a/tests/pipelines/test_regime_aware_donor_imputer.py +++ b/tests/pipelines/test_regime_aware_donor_imputer.py @@ -150,6 +150,46 @@ def test_multi_target_fit_uses_one_chained_zero_inflated_imputer(self) -> None: second_bundle = second_fitted._per_variable["second_income_leaf"] assert second_bundle["predictors"] == ["age", "first_income_leaf"] + def test_target_predictor_overlap_is_owned_by_sequential_chain(self) -> None: + from microplex_us.pipelines.us import RegimeAwareDonorImputer + + rng = np.random.default_rng(2026060601) + n = 300 + age = rng.integers(18, 80, size=n).astype(float) + first = rng.normal(loc=age * 300.0, scale=1_000.0, size=n) + second = 0.5 * first + rng.normal(scale=250.0, size=n) + train = pd.DataFrame( + { + "age": age, + "first_income_leaf": first, + "second_income_leaf": second, + } + ) + + imputer = RegimeAwareDonorImputer( + condition_vars=["age", "first_income_leaf"], + target_vars=["first_income_leaf", "second_income_leaf"], + n_estimators=25, + ) + imputer.fit(train) + + fitted = imputer._fitted["first_income_leaf"] + first_bundle = fitted._per_variable["first_income_leaf"] + second_bundle = fitted._per_variable["second_income_leaf"] + assert first_bundle["predictors"] == ["age"] + assert second_bundle["predictors"] == ["age", "first_income_leaf"] + + conditions = pd.DataFrame({"age": [25.0, 45.0, 65.0]}) + synthetic = imputer.generate(conditions, seed=20260606) + assert list(synthetic.columns) == [ + "age", + "first_income_leaf", + "second_income_leaf", + ] + assert ( + synthetic[["first_income_leaf", "second_income_leaf"]].notna().all().all() + ) + def _fit_generate( self, n_train: int = 1500, n_gen: int = 2000, seed: int = 0 ) -> np.ndarray: