Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 64 additions & 2 deletions diff_diff/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2372,6 +2372,7 @@ def solve_poisson(
tol: float = 1e-8,
init_beta: Optional[np.ndarray] = None,
rank_deficient_action: str = "warn",
weights: Optional[np.ndarray] = None,
) -> Tuple[np.ndarray, np.ndarray]:
"""Poisson IRLS (Newton-Raphson with log link).

Expand All @@ -2389,6 +2390,9 @@ def solve_poisson(
log(mean(y)) to improve convergence for large-scale outcomes.
rank_deficient_action : {"warn", "error", "silent"}
How to handle rank-deficient design matrices. Mirrors solve_ols/solve_logit.
weights : (n,) array, optional
Observation weights (e.g. survey sampling weights). When provided, the
weighted pseudo-log-likelihood is maximised, with
score = X'(w*(y - mu)) and negative Hessian = X'diag(w*mu)X.

Returns
-------
Expand All @@ -2397,6 +2401,20 @@ def solve_poisson(
"""
n, k_orig = X.shape

# Validate weights (mirrors solve_logit validation)
if weights is not None:
weights = np.asarray(weights, dtype=np.float64)
if weights.shape != (n,):
raise ValueError(f"weights must have shape ({n},), got {weights.shape}")
if np.any(np.isnan(weights)):
raise ValueError("weights contain NaN values")
if np.any(~np.isfinite(weights)):
raise ValueError("weights contain Inf values")
if np.any(weights < 0):
raise ValueError("weights must be non-negative")
if np.sum(weights) <= 0:
raise ValueError("weights sum to zero — no observations have positive weight")

# Validate rank_deficient_action (same as solve_logit/solve_ols)
valid_actions = ("warn", "error", "silent")
if rank_deficient_action not in valid_actions:
Expand Down Expand Up @@ -2425,6 +2443,46 @@ def solve_poisson(
X = X[:, kept_cols]

n, k = X.shape

# Validate effective weighted sample when weights have zeros
# (mirrors solve_logit's positive-weight safeguards)
if weights is not None and np.any(weights == 0):
pos_mask = weights > 0
n_pos = int(np.sum(pos_mask))
X_eff = X[pos_mask]
eff_rank_info = _detect_rank_deficiency(X_eff)
if len(eff_rank_info[1]) > 0:
n_dropped_eff = len(eff_rank_info[1])
if rank_deficient_action == "error":
raise ValueError(
f"Effective (positive-weight) sample is rank-deficient: "
f"{n_dropped_eff} linearly dependent column(s). "
f"Cannot identify Poisson model on this subpopulation."
)
elif rank_deficient_action == "warn":
warnings.warn(
f"Effective (positive-weight) sample is rank-deficient: "
f"dropping {n_dropped_eff} column(s). Poisson estimates "
f"may be unreliable on this subpopulation.",
UserWarning,
stacklevel=2,
)
eff_dropped = set(int(d) for d in eff_rank_info[1])
eff_kept = np.array([i for i in range(k) if i not in eff_dropped])
X = X[:, eff_kept]
if len(dropped_cols) > 0:
kept_cols = kept_cols[eff_kept]
else:
kept_cols = eff_kept
dropped_cols = list(eff_dropped)
n, k = X.shape
if n_pos <= k:
raise ValueError(
f"Only {n_pos} positive-weight observation(s) for "
f"{k} parameters (after rank reduction). "
f"Cannot identify Poisson model."
)

if init_beta is not None:
beta = init_beta[kept_cols].copy() if len(dropped_cols) > 0 else init_beta.copy()
else:
Expand All @@ -2438,8 +2496,12 @@ def solve_poisson(
for _ in range(max_iter):
eta = np.clip(X @ beta, -500, 500)
mu = np.exp(eta)
score = X.T @ (y - mu) # gradient of log-likelihood
hess = X.T @ (mu[:, None] * X) # -Hessian = X'WX, W=diag(mu)
if weights is not None:
score = X.T @ (weights * (y - mu))
hess = X.T @ ((weights * mu)[:, None] * X)
else:
score = X.T @ (y - mu)
hess = X.T @ (mu[:, None] * X)
try:
delta = np.linalg.solve(hess + 1e-12 * np.eye(k), score)
except np.linalg.LinAlgError:
Expand Down
Loading
Loading