Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/cxr/covid19cxr_tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1339,7 +1339,7 @@
" # Input size is inferred automatically from image dimensions\n",
" result = chefer_gen.attribute(\n",
" interpolate=True,\n",
" class_index=pred_class,\n",
" target_class_idx=pred_class,\n",
" **batch\n",
" )\n",
" attr_map = result[\"image\"] # Keyed by task schema's feature key\n",
Expand Down
2 changes: 1 addition & 1 deletion examples/cxr/covid19cxr_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@
# Compute attribution for each class in the prediction set
overlays = []
for class_idx in predset_class_indices:
attr_map = chefer.attribute(class_index=class_idx, **batch)["image"]
attr_map = chefer.attribute(target_class_idx=class_idx, **batch)["image"]
_, _, overlay = visualize_image_attr(
image=batch["image"][0],
attribution=attr_map[0, 0],
Expand Down
2 changes: 1 addition & 1 deletion examples/cxr/covid19cxr_tutorial_display.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@
# Compute attribution for each class in the prediction set
overlays = []
for class_idx in predset_class_indices:
attr_map = chefer.attribute(class_index=class_idx, **batch)["image"]
attr_map = chefer.attribute(target_class_idx=class_idx, **batch)["image"]
_, _, overlay = visualize_image_attr(
image=batch["image"][0],
attribution=attr_map[0, 0],
Expand Down
46 changes: 43 additions & 3 deletions pyhealth/interpret/methods/base_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"""

from abc import ABC, abstractmethod
from typing import Dict, cast
from typing import Dict, Optional, cast

import torch
import torch.nn as nn
Expand Down Expand Up @@ -138,8 +138,12 @@ def attribute(
by the task's ``input_schema``.
- Label key (optional): Ground truth labels, may be needed
by some methods for loss computation.
- ``class_index`` (optional): Target class for attribution.
If not provided, uses the predicted class.
- ``target_class_idx`` (Optional[int]): Target class for
attribution. For binary classification (single logit
output), this is a no-op because there is only one
output. For multi-class or multi-label classification,
specifies which class index to explain. If not provided,
uses the argmax of logits.
- Additional method-specific parameters (e.g., ``baseline``,
``steps``, ``interpolate``).

Expand Down Expand Up @@ -207,6 +211,42 @@ def attribute(
"""
pass

def _resolve_target_indices(
self,
logits: torch.Tensor,
target_class_idx: Optional[int],
) -> torch.Tensor:
"""Resolve target class indices for attribution.

Returns a ``[batch]`` tensor of class indices identifying which
logit to explain. All prediction modes share this single code
path:

* **Binary** (single logit): ``target_class_idx`` is a no-op
because there is only one output. Always returns zeros
(index 0).
* **Multi-class / multi-label**: uses ``target_class_idx`` if
given, otherwise the argmax of logits.

Args:
logits: Model output logits, shape ``[batch, num_classes]``.
target_class_idx: Optional user-specified class index.

Returns:
``torch.LongTensor`` of shape ``[batch]``.
"""
if logits.shape[-1] == 1:
# Single logit output — nothing to select.
return torch.zeros(
logits.shape[0], device=logits.device, dtype=torch.long,
)
if target_class_idx is not None:
return torch.full(
(logits.shape[0],), target_class_idx,
device=logits.device, dtype=torch.long,
)
return logits.argmax(dim=-1)

def _prediction_mode(self) -> str:
"""Resolve the prediction mode from the model.

Expand Down
19 changes: 8 additions & 11 deletions pyhealth/interpret/methods/chefer.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ class CheferRelevance(BaseInterpreter):
>>> print(attributions["conditions"].shape) # [batch, num_tokens]
>>>
>>> # Optional: attribute to a specific class (e.g., class 1)
>>> attributions = interpreter.attribute(class_index=1, **batch)
>>> attributions = interpreter.attribute(target_class_idx=1, **batch)
"""

def __init__(self, model: BaseModel):
Expand All @@ -139,14 +139,16 @@ def __init__(self, model: BaseModel):

def attribute(
self,
class_index: Optional[int] = None,
target_class_idx: Optional[int] = None,
**data,
) -> Dict[str, torch.Tensor]:
"""Compute relevance scores for each input token.

Args:
class_index: Target class index to compute attribution for.
If None (default), uses the model's predicted class.
target_class_idx: Target class index to compute attribution for.
If None (default), uses the argmax of model output.
For binary classification (single logit output), this is
a no-op because there is only one output.
**data: Input data from dataloader batch containing feature
keys and label key.

Expand All @@ -163,15 +165,10 @@ def attribute(
self.model.set_attention_hooks(False)

# --- 2. Backward from target class ---
if class_index is None:
class_index_t = torch.argmax(logits, dim=-1)
elif isinstance(class_index, int):
class_index_t = torch.tensor(class_index)
else:
class_index_t = class_index
target_indices = self._resolve_target_indices(logits, target_class_idx)

one_hot = F.one_hot(
class_index_t.detach().clone(), logits.size(1)
target_indices.detach().clone(), logits.size(1)
).float()
one_hot = one_hot.requires_grad_(True)
scalar = torch.sum(one_hot.to(logits.device) * logits)
Expand Down
79 changes: 13 additions & 66 deletions pyhealth/interpret/methods/deeplift.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,35 +411,7 @@ def attribute(
with torch.no_grad():
base_logits = self.model.forward(**inputs)["logit"]

mode = self._prediction_mode()
if mode == "binary":
if target_class_idx is not None:
target = torch.tensor([target_class_idx], device=device)
else:
target = (torch.sigmoid(base_logits) > 0.5).long()
elif mode == "multiclass":
if target_class_idx is not None:
target = F.one_hot(
torch.tensor(target_class_idx, device=device),
num_classes=base_logits.shape[-1],
).float()
else:
target = torch.argmax(base_logits, dim=-1)
target = F.one_hot(
target, num_classes=base_logits.shape[-1]
).float()
elif mode == "multilabel":
if target_class_idx is not None:
target = F.one_hot(
torch.tensor(target_class_idx, device=device),
num_classes=base_logits.shape[-1],
).float()
else:
target = (torch.sigmoid(base_logits) > 0.5).float()
else:
raise ValueError(
"Unsupported prediction mode for DeepLIFT attribution."
)
target_indices = self._resolve_target_indices(base_logits, target_class_idx)

# Generate baselines
if baseline is None:
Expand Down Expand Up @@ -491,7 +463,7 @@ def attribute(
inputs=inputs,
xs=values,
bs=baselines,
target=target,
target_indices=target_indices,
token_keys=token_keys,
)

Expand All @@ -505,7 +477,7 @@ def _deeplift(
inputs: Dict[str, tuple[torch.Tensor, ...]],
xs: Dict[str, torch.Tensor],
bs: Dict[str, torch.Tensor],
target: torch.Tensor,
target_indices: torch.Tensor,
token_keys: set[str],
) -> Dict[str, torch.Tensor]:
"""Core DeepLIFT computation using the Rescale rule.
Expand All @@ -517,8 +489,7 @@ def _deeplift(
inputs: Full input tuples keyed by feature name.
xs: Input values (embedded if token features with use_embeddings).
bs: Baseline values (embedded if token features with use_embeddings).
target: Target tensor for computing the scalar output to
differentiate (one-hot for multiclass, class idx for binary).
target_indices: [batch] tensor of target class indices.
token_keys: Set of feature keys that are token (already embedded).

Returns:
Expand Down Expand Up @@ -590,9 +561,9 @@ def _maybe_embed_continuous(value_dict: dict[str, torch.Tensor]) -> dict[str, to
baseline_logits = baseline_output["logit"] # type: ignore[index]

# Compute per-sample target outputs
target_output = self._compute_target_output(logits, target)
target_output = self._compute_target_output(logits, target_indices)
baseline_target_output = self._compute_target_output(
baseline_logits, target
baseline_logits, target_indices
)

self.model.zero_grad(set_to_none=True)
Expand Down Expand Up @@ -626,46 +597,22 @@ def _maybe_embed_continuous(value_dict: dict[str, torch.Tensor]) -> dict[str, to
def _compute_target_output(
self,
logits: torch.Tensor,
target: torch.Tensor,
target_indices: torch.Tensor,
) -> torch.Tensor:
"""Compute per-sample target output.

Creates a differentiable per-sample scalar from the model logits
that, when summed and differentiated, gives the gradient of the
target class logit w.r.t. the input.
Selects the target-class logit for each sample.

Args:
logits: Model output logits, shape [batch, num_classes] or
[batch, 1].
target: Target tensor. For binary: [batch] or [1] with 0/1
class indices. For multiclass/multilabel: [batch, num_classes]
one-hot or multi-hot tensor.
logits: Model output logits, shape [batch, num_classes].
target_indices: [batch] tensor of target class indices.

Returns:
Per-sample target output tensor, shape [batch].
"""
target_f = target.to(logits.device).float()
mode = self._prediction_mode()

if mode == "binary":
while target_f.dim() < logits.dim():
target_f = target_f.unsqueeze(-1)
target_f = target_f.expand_as(logits)
signs = 2.0 * target_f - 1.0
# Sum over all dims except batch to get per-sample scalar
per_sample = (signs * logits)
if per_sample.dim() > 1:
per_sample = per_sample.sum(dim=tuple(range(1, per_sample.dim())))
return per_sample
else:
# multiclass or multilabel: target is one-hot/multi-hot
while target_f.dim() < logits.dim():
target_f = target_f.unsqueeze(0)
target_f = target_f.expand_as(logits)
per_sample = (target_f * logits)
if per_sample.dim() > 1:
per_sample = per_sample.sum(dim=tuple(range(1, per_sample.dim())))
return per_sample
return logits.gather(
1, target_indices.unsqueeze(1)
).squeeze(1)

# ------------------------------------------------------------------
# Completeness enforcement
Expand Down
69 changes: 13 additions & 56 deletions pyhealth/interpret/methods/gim.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,8 +365,9 @@ def attribute(
"""Compute GIM attributions for a batch.

Args:
target_class_idx: Target class index for attribution. If None,
uses the model's predicted class.
target_class_idx: Target class index for attribution. For
binary classification (single logit output), this is a
no-op. If None, uses the argmax of model output.
**kwargs: Input data dictionary from a dataloader batch containing
feature tensors or tuples of tensors for each modality, plus
optional label tensors.
Expand Down Expand Up @@ -419,35 +420,7 @@ def attribute(
with torch.no_grad():
base_logits = self.model.forward(**inputs)["logit"]

mode = self._prediction_mode()
if mode == "binary":
if target_class_idx is not None:
target = torch.tensor([target_class_idx], device=device)
else:
target = (torch.sigmoid(base_logits) > 0.5).long()
elif mode == "multiclass":
if target_class_idx is not None:
target = F.one_hot(
torch.tensor(target_class_idx, device=device),
num_classes=base_logits.shape[-1],
).float()
else:
target = torch.argmax(base_logits, dim=-1)
target = F.one_hot(
target, num_classes=base_logits.shape[-1]
).float()
elif mode == "multilabel":
if target_class_idx is not None:
target = F.one_hot(
torch.tensor(target_class_idx, device=device),
num_classes=base_logits.shape[-1],
).float()
else:
target = (torch.sigmoid(base_logits) > 0.5).float()
else:
raise ValueError(
"Unsupported prediction mode for GIM attribution."
)
target_indices = self._resolve_target_indices(base_logits, target_class_idx)

# Embed values and detach for gradient attribution.
# Split features by type using is_token():
Expand Down Expand Up @@ -521,7 +494,7 @@ def attribute(
output = self.model.forward_from_embedding(**forward_inputs)

logits = output["logit"] # type: ignore[assignment]
target_output = self._compute_target_output(logits, target)
target_output = self._compute_target_output(logits, target_indices)

# Clear stale gradients, then backpropagate through the
# GIM-modified computational graph.
Expand Down Expand Up @@ -552,39 +525,23 @@ def attribute(
def _compute_target_output(
self,
logits: torch.Tensor,
target: torch.Tensor,
target_indices: torch.Tensor,
) -> torch.Tensor:
"""Compute scalar target output for backpropagation.

Creates a differentiable scalar from the model logits that,
when differentiated, gives the gradient of the target class
logit w.r.t. the input.
Selects the target-class logit for each sample and sums over
the batch to produce a single differentiable scalar.

Args:
logits: Model output logits, shape [batch, num_classes] or
[batch, 1].
target: Target tensor. For binary: [batch] or [1] with 0/1
class indices. For multiclass/multilabel: [batch, num_classes]
one-hot or multi-hot tensor.
logits: Model output logits, shape [batch, num_classes].
target_indices: [batch] tensor of target class indices.

Returns:
Scalar tensor for backpropagation.
"""
target_f = target.to(logits.device).float()
mode = self._prediction_mode()

if mode == "binary":
while target_f.dim() < logits.dim():
target_f = target_f.unsqueeze(-1)
target_f = target_f.expand_as(logits)
signs = 2.0 * target_f - 1.0
return (signs * logits).sum()
else:
# multiclass or multilabel: target is one-hot/multi-hot
while target_f.dim() < logits.dim():
target_f = target_f.unsqueeze(0)
target_f = target_f.expand_as(logits)
return (target_f * logits).sum()
return logits.gather(
1, target_indices.unsqueeze(1)
).squeeze(1).sum()

# ------------------------------------------------------------------
# Utility helpers
Expand Down
Loading
Loading