diff --git a/examples/cxr/covid19cxr_tutorial.ipynb b/examples/cxr/covid19cxr_tutorial.ipynb index 2a04844c5..ec10756a1 100644 --- a/examples/cxr/covid19cxr_tutorial.ipynb +++ b/examples/cxr/covid19cxr_tutorial.ipynb @@ -1339,7 +1339,7 @@ " # Input size is inferred automatically from image dimensions\n", " result = chefer_gen.attribute(\n", " interpolate=True,\n", - " class_index=pred_class,\n", + " target_class_idx=pred_class,\n", " **batch\n", " )\n", " attr_map = result[\"image\"] # Keyed by task schema's feature key\n", diff --git a/examples/cxr/covid19cxr_tutorial.py b/examples/cxr/covid19cxr_tutorial.py index 0f24f4b58..06b134f93 100644 --- a/examples/cxr/covid19cxr_tutorial.py +++ b/examples/cxr/covid19cxr_tutorial.py @@ -131,7 +131,7 @@ # Compute attribution for each class in the prediction set overlays = [] for class_idx in predset_class_indices: - attr_map = chefer.attribute(class_index=class_idx, **batch)["image"] + attr_map = chefer.attribute(target_class_idx=class_idx, **batch)["image"] _, _, overlay = visualize_image_attr( image=batch["image"][0], attribution=attr_map[0, 0], diff --git a/examples/cxr/covid19cxr_tutorial_display.py b/examples/cxr/covid19cxr_tutorial_display.py index 3f6a33b82..f3a4acddb 100644 --- a/examples/cxr/covid19cxr_tutorial_display.py +++ b/examples/cxr/covid19cxr_tutorial_display.py @@ -128,7 +128,7 @@ # Compute attribution for each class in the prediction set overlays = [] for class_idx in predset_class_indices: - attr_map = chefer.attribute(class_index=class_idx, **batch)["image"] + attr_map = chefer.attribute(target_class_idx=class_idx, **batch)["image"] _, _, overlay = visualize_image_attr( image=batch["image"][0], attribution=attr_map[0, 0], diff --git a/pyhealth/interpret/methods/base_interpreter.py b/pyhealth/interpret/methods/base_interpreter.py index de75c897a..fb17690ae 100644 --- a/pyhealth/interpret/methods/base_interpreter.py +++ b/pyhealth/interpret/methods/base_interpreter.py @@ -10,7 +10,7 @@ """ from abc import ABC, abstractmethod -from typing import Dict, cast +from typing import Dict, Optional, cast import torch import torch.nn as nn @@ -138,8 +138,12 @@ def attribute( by the task's ``input_schema``. - Label key (optional): Ground truth labels, may be needed by some methods for loss computation. - - ``class_index`` (optional): Target class for attribution. - If not provided, uses the predicted class. + - ``target_class_idx`` (Optional[int]): Target class for + attribution. For binary classification (single logit + output), this is a no-op because there is only one + output. For multi-class or multi-label classification, + specifies which class index to explain. If not provided, + uses the argmax of logits. - Additional method-specific parameters (e.g., ``baseline``, ``steps``, ``interpolate``). @@ -207,6 +211,42 @@ def attribute( """ pass + def _resolve_target_indices( + self, + logits: torch.Tensor, + target_class_idx: Optional[int], + ) -> torch.Tensor: + """Resolve target class indices for attribution. + + Returns a ``[batch]`` tensor of class indices identifying which + logit to explain. All prediction modes share this single code + path: + + * **Binary** (single logit): ``target_class_idx`` is a no-op + because there is only one output. Always returns zeros + (index 0). + * **Multi-class / multi-label**: uses ``target_class_idx`` if + given, otherwise the argmax of logits. + + Args: + logits: Model output logits, shape ``[batch, num_classes]``. + target_class_idx: Optional user-specified class index. + + Returns: + ``torch.LongTensor`` of shape ``[batch]``. + """ + if logits.shape[-1] == 1: + # Single logit output — nothing to select. + return torch.zeros( + logits.shape[0], device=logits.device, dtype=torch.long, + ) + if target_class_idx is not None: + return torch.full( + (logits.shape[0],), target_class_idx, + device=logits.device, dtype=torch.long, + ) + return logits.argmax(dim=-1) + def _prediction_mode(self) -> str: """Resolve the prediction mode from the model. diff --git a/pyhealth/interpret/methods/chefer.py b/pyhealth/interpret/methods/chefer.py index 5efca68eb..26ce6ffd2 100644 --- a/pyhealth/interpret/methods/chefer.py +++ b/pyhealth/interpret/methods/chefer.py @@ -128,7 +128,7 @@ class CheferRelevance(BaseInterpreter): >>> print(attributions["conditions"].shape) # [batch, num_tokens] >>> >>> # Optional: attribute to a specific class (e.g., class 1) - >>> attributions = interpreter.attribute(class_index=1, **batch) + >>> attributions = interpreter.attribute(target_class_idx=1, **batch) """ def __init__(self, model: BaseModel): @@ -139,14 +139,16 @@ def __init__(self, model: BaseModel): def attribute( self, - class_index: Optional[int] = None, + target_class_idx: Optional[int] = None, **data, ) -> Dict[str, torch.Tensor]: """Compute relevance scores for each input token. Args: - class_index: Target class index to compute attribution for. - If None (default), uses the model's predicted class. + target_class_idx: Target class index to compute attribution for. + If None (default), uses the argmax of model output. + For binary classification (single logit output), this is + a no-op because there is only one output. **data: Input data from dataloader batch containing feature keys and label key. @@ -163,15 +165,10 @@ def attribute( self.model.set_attention_hooks(False) # --- 2. Backward from target class --- - if class_index is None: - class_index_t = torch.argmax(logits, dim=-1) - elif isinstance(class_index, int): - class_index_t = torch.tensor(class_index) - else: - class_index_t = class_index + target_indices = self._resolve_target_indices(logits, target_class_idx) one_hot = F.one_hot( - class_index_t.detach().clone(), logits.size(1) + target_indices.detach().clone(), logits.size(1) ).float() one_hot = one_hot.requires_grad_(True) scalar = torch.sum(one_hot.to(logits.device) * logits) diff --git a/pyhealth/interpret/methods/deeplift.py b/pyhealth/interpret/methods/deeplift.py index 29f99d795..8f00f8a8b 100644 --- a/pyhealth/interpret/methods/deeplift.py +++ b/pyhealth/interpret/methods/deeplift.py @@ -411,35 +411,7 @@ def attribute( with torch.no_grad(): base_logits = self.model.forward(**inputs)["logit"] - mode = self._prediction_mode() - if mode == "binary": - if target_class_idx is not None: - target = torch.tensor([target_class_idx], device=device) - else: - target = (torch.sigmoid(base_logits) > 0.5).long() - elif mode == "multiclass": - if target_class_idx is not None: - target = F.one_hot( - torch.tensor(target_class_idx, device=device), - num_classes=base_logits.shape[-1], - ).float() - else: - target = torch.argmax(base_logits, dim=-1) - target = F.one_hot( - target, num_classes=base_logits.shape[-1] - ).float() - elif mode == "multilabel": - if target_class_idx is not None: - target = F.one_hot( - torch.tensor(target_class_idx, device=device), - num_classes=base_logits.shape[-1], - ).float() - else: - target = (torch.sigmoid(base_logits) > 0.5).float() - else: - raise ValueError( - "Unsupported prediction mode for DeepLIFT attribution." - ) + target_indices = self._resolve_target_indices(base_logits, target_class_idx) # Generate baselines if baseline is None: @@ -491,7 +463,7 @@ def attribute( inputs=inputs, xs=values, bs=baselines, - target=target, + target_indices=target_indices, token_keys=token_keys, ) @@ -505,7 +477,7 @@ def _deeplift( inputs: Dict[str, tuple[torch.Tensor, ...]], xs: Dict[str, torch.Tensor], bs: Dict[str, torch.Tensor], - target: torch.Tensor, + target_indices: torch.Tensor, token_keys: set[str], ) -> Dict[str, torch.Tensor]: """Core DeepLIFT computation using the Rescale rule. @@ -517,8 +489,7 @@ def _deeplift( inputs: Full input tuples keyed by feature name. xs: Input values (embedded if token features with use_embeddings). bs: Baseline values (embedded if token features with use_embeddings). - target: Target tensor for computing the scalar output to - differentiate (one-hot for multiclass, class idx for binary). + target_indices: [batch] tensor of target class indices. token_keys: Set of feature keys that are token (already embedded). Returns: @@ -590,9 +561,9 @@ def _maybe_embed_continuous(value_dict: dict[str, torch.Tensor]) -> dict[str, to baseline_logits = baseline_output["logit"] # type: ignore[index] # Compute per-sample target outputs - target_output = self._compute_target_output(logits, target) + target_output = self._compute_target_output(logits, target_indices) baseline_target_output = self._compute_target_output( - baseline_logits, target + baseline_logits, target_indices ) self.model.zero_grad(set_to_none=True) @@ -626,46 +597,22 @@ def _maybe_embed_continuous(value_dict: dict[str, torch.Tensor]) -> dict[str, to def _compute_target_output( self, logits: torch.Tensor, - target: torch.Tensor, + target_indices: torch.Tensor, ) -> torch.Tensor: """Compute per-sample target output. - Creates a differentiable per-sample scalar from the model logits - that, when summed and differentiated, gives the gradient of the - target class logit w.r.t. the input. + Selects the target-class logit for each sample. Args: - logits: Model output logits, shape [batch, num_classes] or - [batch, 1]. - target: Target tensor. For binary: [batch] or [1] with 0/1 - class indices. For multiclass/multilabel: [batch, num_classes] - one-hot or multi-hot tensor. + logits: Model output logits, shape [batch, num_classes]. + target_indices: [batch] tensor of target class indices. Returns: Per-sample target output tensor, shape [batch]. """ - target_f = target.to(logits.device).float() - mode = self._prediction_mode() - - if mode == "binary": - while target_f.dim() < logits.dim(): - target_f = target_f.unsqueeze(-1) - target_f = target_f.expand_as(logits) - signs = 2.0 * target_f - 1.0 - # Sum over all dims except batch to get per-sample scalar - per_sample = (signs * logits) - if per_sample.dim() > 1: - per_sample = per_sample.sum(dim=tuple(range(1, per_sample.dim()))) - return per_sample - else: - # multiclass or multilabel: target is one-hot/multi-hot - while target_f.dim() < logits.dim(): - target_f = target_f.unsqueeze(0) - target_f = target_f.expand_as(logits) - per_sample = (target_f * logits) - if per_sample.dim() > 1: - per_sample = per_sample.sum(dim=tuple(range(1, per_sample.dim()))) - return per_sample + return logits.gather( + 1, target_indices.unsqueeze(1) + ).squeeze(1) # ------------------------------------------------------------------ # Completeness enforcement diff --git a/pyhealth/interpret/methods/gim.py b/pyhealth/interpret/methods/gim.py index d1b74b573..abbb9388a 100644 --- a/pyhealth/interpret/methods/gim.py +++ b/pyhealth/interpret/methods/gim.py @@ -365,8 +365,9 @@ def attribute( """Compute GIM attributions for a batch. Args: - target_class_idx: Target class index for attribution. If None, - uses the model's predicted class. + target_class_idx: Target class index for attribution. For + binary classification (single logit output), this is a + no-op. If None, uses the argmax of model output. **kwargs: Input data dictionary from a dataloader batch containing feature tensors or tuples of tensors for each modality, plus optional label tensors. @@ -419,35 +420,7 @@ def attribute( with torch.no_grad(): base_logits = self.model.forward(**inputs)["logit"] - mode = self._prediction_mode() - if mode == "binary": - if target_class_idx is not None: - target = torch.tensor([target_class_idx], device=device) - else: - target = (torch.sigmoid(base_logits) > 0.5).long() - elif mode == "multiclass": - if target_class_idx is not None: - target = F.one_hot( - torch.tensor(target_class_idx, device=device), - num_classes=base_logits.shape[-1], - ).float() - else: - target = torch.argmax(base_logits, dim=-1) - target = F.one_hot( - target, num_classes=base_logits.shape[-1] - ).float() - elif mode == "multilabel": - if target_class_idx is not None: - target = F.one_hot( - torch.tensor(target_class_idx, device=device), - num_classes=base_logits.shape[-1], - ).float() - else: - target = (torch.sigmoid(base_logits) > 0.5).float() - else: - raise ValueError( - "Unsupported prediction mode for GIM attribution." - ) + target_indices = self._resolve_target_indices(base_logits, target_class_idx) # Embed values and detach for gradient attribution. # Split features by type using is_token(): @@ -521,7 +494,7 @@ def attribute( output = self.model.forward_from_embedding(**forward_inputs) logits = output["logit"] # type: ignore[assignment] - target_output = self._compute_target_output(logits, target) + target_output = self._compute_target_output(logits, target_indices) # Clear stale gradients, then backpropagate through the # GIM-modified computational graph. @@ -552,39 +525,23 @@ def attribute( def _compute_target_output( self, logits: torch.Tensor, - target: torch.Tensor, + target_indices: torch.Tensor, ) -> torch.Tensor: """Compute scalar target output for backpropagation. - Creates a differentiable scalar from the model logits that, - when differentiated, gives the gradient of the target class - logit w.r.t. the input. + Selects the target-class logit for each sample and sums over + the batch to produce a single differentiable scalar. Args: - logits: Model output logits, shape [batch, num_classes] or - [batch, 1]. - target: Target tensor. For binary: [batch] or [1] with 0/1 - class indices. For multiclass/multilabel: [batch, num_classes] - one-hot or multi-hot tensor. + logits: Model output logits, shape [batch, num_classes]. + target_indices: [batch] tensor of target class indices. Returns: Scalar tensor for backpropagation. """ - target_f = target.to(logits.device).float() - mode = self._prediction_mode() - - if mode == "binary": - while target_f.dim() < logits.dim(): - target_f = target_f.unsqueeze(-1) - target_f = target_f.expand_as(logits) - signs = 2.0 * target_f - 1.0 - return (signs * logits).sum() - else: - # multiclass or multilabel: target is one-hot/multi-hot - while target_f.dim() < logits.dim(): - target_f = target_f.unsqueeze(0) - target_f = target_f.expand_as(logits) - return (target_f * logits).sum() + return logits.gather( + 1, target_indices.unsqueeze(1) + ).squeeze(1).sum() # ------------------------------------------------------------------ # Utility helpers diff --git a/pyhealth/interpret/methods/ig_gim.py b/pyhealth/interpret/methods/ig_gim.py index a33f5529b..49c2fa6c0 100644 --- a/pyhealth/interpret/methods/ig_gim.py +++ b/pyhealth/interpret/methods/ig_gim.py @@ -105,8 +105,9 @@ def attribute( near-zero for continuous features). steps: Number of interpolation steps. Overrides the instance default when given. - target_class_idx: Target class for attribution. ``None`` - uses the model's predicted class. + target_class_idx: Target class for attribution. For binary + classification (single logit output), this is a no-op. + ``None`` uses the argmax of model output. **kwargs: Dataloader batch (feature tensors + optional labels). Returns: @@ -155,10 +156,7 @@ def attribute( with torch.no_grad(): base_logits = self.model.forward(**inputs)["logit"] - mode = self._prediction_mode() - target = self._resolve_target( - base_logits, mode, target_class_idx, device - ) + target_indices = self._resolve_target_indices(base_logits, target_class_idx) # ----- baselines ----- if baseline is None: @@ -201,7 +199,7 @@ def attribute( xs=values, bs=baselines, steps=steps, - target=target, + target_indices=target_indices, token_keys=token_keys, continuous_keys=continuous_keys, ) @@ -217,7 +215,7 @@ def _integrated_gradients_gim( xs: Dict[str, torch.Tensor], bs: Dict[str, torch.Tensor], steps: int, - target: torch.Tensor, + target_indices: torch.Tensor, token_keys: set[str], continuous_keys: set[str], ) -> Dict[str, torch.Tensor]: @@ -280,7 +278,7 @@ def _integrated_gradients_gim( with _GIMHookContext(self.model, self.temperature): output = self.model.forward_from_embedding(**forward_inputs) logits = output["logit"] - target_output = self._compute_target_output(logits, target) + target_output = self._compute_target_output(logits, target_indices) self.model.zero_grad(set_to_none=True) target_output.backward(retain_graph=True) @@ -308,58 +306,16 @@ def _integrated_gradients_gim( # ------------------------------------------------------------------ # Target helpers (shared logic with IG / GIM) # ------------------------------------------------------------------ - @staticmethod - def _resolve_target( - logits: torch.Tensor, - mode: str, - target_class_idx: Optional[int], - device: torch.device, - ) -> torch.Tensor: - """Convert logits and optional class index into a target tensor.""" - if mode == "binary": - if target_class_idx is not None: - return torch.tensor([target_class_idx], device=device) - return (torch.sigmoid(logits) > 0.5).long() - - if mode == "multiclass": - if target_class_idx is not None: - return F.one_hot( - torch.tensor(target_class_idx, device=device), - num_classes=logits.shape[-1], - ).float() - target = torch.argmax(logits, dim=-1) - return F.one_hot(target, num_classes=logits.shape[-1]).float() - - if mode == "multilabel": - if target_class_idx is not None: - return F.one_hot( - torch.tensor(target_class_idx, device=device), - num_classes=logits.shape[-1], - ).float() - return (torch.sigmoid(logits) > 0.5).float() - - raise ValueError(f"Unsupported prediction mode: {mode}") def _compute_target_output( self, logits: torch.Tensor, - target: torch.Tensor, + target_indices: torch.Tensor, ) -> torch.Tensor: """Scalar target output for backpropagation.""" - target_f = target.to(logits.device).float() - mode = self._prediction_mode() - - if mode == "binary": - while target_f.dim() < logits.dim(): - target_f = target_f.unsqueeze(-1) - target_f = target_f.expand_as(logits) - signs = 2.0 * target_f - 1.0 - return (signs * logits).sum() - else: - while target_f.dim() < logits.dim(): - target_f = target_f.unsqueeze(0) - target_f = target_f.expand_as(logits) - return (target_f * logits).sum() + return logits.gather( + 1, target_indices.unsqueeze(1) + ).squeeze(1).sum() # ------------------------------------------------------------------ # Baseline generation diff --git a/pyhealth/interpret/methods/integrated_gradients.py b/pyhealth/interpret/methods/integrated_gradients.py index a529a6f3f..249f5a5e9 100644 --- a/pyhealth/interpret/methods/integrated_gradients.py +++ b/pyhealth/interpret/methods/integrated_gradients.py @@ -217,9 +217,11 @@ def attribute( the integral. If None, uses self.steps (set during initialization). More steps lead to better approximation but slower computation. - target_class_idx: Target class index for attribution - computation. If None, uses the predicted class (argmax of - model output). + target_class_idx: Target class index for attribution. + For binary classification (single logit output), this is + a no-op because there is only one output. For multi-class + or multi-label, specifies which class to explain. If None, + uses the argmax of model output. **kwargs: Input data dictionary from a dataloader batch containing: - Feature keys (e.g., 'conditions', 'procedures'): @@ -324,35 +326,7 @@ def attribute( with torch.no_grad(): base_logits = self.model.forward(**inputs)["logit"] - mode = self._prediction_mode() - if mode == "binary": - if target_class_idx is not None: - target = torch.tensor([target_class_idx], device=device) - else: - target = (torch.sigmoid(base_logits) > 0.5).long() - elif mode == "multiclass": - if target_class_idx is not None: - target = F.one_hot( - torch.tensor(target_class_idx, device=device), - num_classes=base_logits.shape[-1], - ).float() - else: - target = torch.argmax(base_logits, dim=-1) - target = F.one_hot( - target, num_classes=base_logits.shape[-1] - ).float() - elif mode == "multilabel": - if target_class_idx is not None: - target = F.one_hot( - torch.tensor(target_class_idx, device=device), - num_classes=base_logits.shape[-1], - ).float() - else: - target = (torch.sigmoid(base_logits) > 0.5).float() - else: - raise ValueError( - "Unsupported prediction mode for Integrated Gradients attribution." - ) + target_indices = self._resolve_target_indices(base_logits, target_class_idx) # Generate baselines if baseline is None: @@ -405,7 +379,7 @@ def attribute( xs=values, bs=baselines, steps=steps, - target=target, + target_indices=target_indices, ) return self._map_to_input_shapes(attributions, shapes) @@ -419,7 +393,7 @@ def _integrated_gradients( xs: Dict[str, torch.Tensor], bs: Dict[str, torch.Tensor], steps: int, - target: torch.Tensor, + target_indices: torch.Tensor, ) -> Dict[str, torch.Tensor]: """Compute integrated gradients via Riemann sum approximation. @@ -438,8 +412,7 @@ def _integrated_gradients( xs: Input values (embedded if use_embeddings=True). bs: Baseline values (embedded if use_embeddings=True). steps: Number of interpolation steps. - target: Target tensor for computing the scalar output to - differentiate (one-hot for multiclass, class idx for binary). + target_indices: [batch] tensor of target class indices. Returns: Dictionary mapping feature keys to attribution tensors. @@ -513,7 +486,7 @@ def _integrated_gradients( logits = output["logit"] # Compute target output and backward pass - target_output = self._compute_target_output(logits, target) + target_output = self._compute_target_output(logits, target_indices) self.model.zero_grad() target_output.backward(retain_graph=True) @@ -550,41 +523,23 @@ def _integrated_gradients( def _compute_target_output( self, logits: torch.Tensor, - target: torch.Tensor, + target_indices: torch.Tensor, ) -> torch.Tensor: """Compute scalar target output for backpropagation. - Creates a differentiable scalar from the model logits that, - when differentiated, gives the gradient of the target class - logit w.r.t. the input. + Selects the target-class logit for each sample and sums over + the batch to produce a single differentiable scalar. Args: - logits: Model output logits, shape [batch, num_classes] or - [batch, 1]. - target: Target tensor. For binary: [batch] or [1] with 0/1 - class indices. For multiclass/multilabel: [batch, num_classes] - one-hot or multi-hot tensor. + logits: Model output logits, shape [batch, num_classes]. + target_indices: [batch] tensor of target class indices. Returns: Scalar tensor for backpropagation. """ - target_f = target.to(logits.device).float() - mode = self._prediction_mode() - - if mode == "binary": - # target shape: [1] or [batch, 1] with 0/1 values - # Convert to signs: 0 -> -1, 1 -> 1 - while target_f.dim() < logits.dim(): - target_f = target_f.unsqueeze(-1) - target_f = target_f.expand_as(logits) - signs = 2.0 * target_f - 1.0 - return (signs * logits).sum() - else: - # multiclass or multilabel: target is one-hot/multi-hot - while target_f.dim() < logits.dim(): - target_f = target_f.unsqueeze(0) - target_f = target_f.expand_as(logits) - return (target_f * logits).sum() + return logits.gather( + 1, target_indices.unsqueeze(1) + ).squeeze(1).sum() # ------------------------------------------------------------------ # Baseline generation diff --git a/pyhealth/interpret/methods/lime.py b/pyhealth/interpret/methods/lime.py index 5176bfeaf..4e407fccf 100644 --- a/pyhealth/interpret/methods/lime.py +++ b/pyhealth/interpret/methods/lime.py @@ -250,25 +250,7 @@ def attribute( # Extract and prepare inputs base_logits = self.model.forward(**inputs)["logit"] - # Enforce target class selection for multi-class models to avoid class flipping - if self._prediction_mode() == "binary": - if target_class_idx is not None: - target = torch.tensor([target_class_idx], device=device) - else: - target = (torch.sigmoid(base_logits) > 0.5).long() - elif self._prediction_mode() == "multiclass": - if target_class_idx is not None: - target = torch.nn.functional.one_hot(torch.tensor(target_class_idx, device=device), num_classes=base_logits.shape[-1]) - else: - target = torch.argmax(base_logits, dim=-1) - target = torch.nn.functional.one_hot(target, num_classes=base_logits.shape[-1]) - elif self._prediction_mode() == "multilabel": - if target_class_idx is not None: - target = torch.nn.functional.one_hot(torch.tensor(target_class_idx, device=device), num_classes=base_logits.shape[-1]) - else: - target = torch.sigmoid(base_logits) > 0.5 - else: - raise ValueError("Unsupported prediction mode for LIME attribution.") + target_indices = self._resolve_target_indices(base_logits, target_class_idx) if baseline is None: baselines = self._generate_baseline(values, use_embeddings=self.use_embeddings) @@ -309,7 +291,7 @@ def attribute( xs=values, bs=baselines, n_features=n_features, - target=target, + target_indices=target_indices, ) return self._map_to_input_shapes(out, shapes) @@ -323,7 +305,7 @@ def _compute_lime( xs: Dict[str, torch.Tensor], bs: Dict[str, torch.Tensor], n_features: dict[str, int], - target: torch.Tensor, + target_indices: torch.Tensor, ) -> Dict[str, torch.Tensor]: """Compute LIME coefficients using interpretable linear model. @@ -376,7 +358,7 @@ def _compute_lime( pred = self._evaluate_sample( inputs, perturb, - target, + target_indices, ) # Create perturbed sample for each batch item @@ -475,15 +457,19 @@ def _evaluate_sample( self, inputs: dict[str, tuple[torch.Tensor, ...]], perturb: dict[str, torch.Tensor], - target: torch.Tensor, + target_indices: torch.Tensor, ) -> torch.Tensor: """Evaluate model prediction for a perturbed sample. + Returns the model's prediction for the target class, so the + weighted linear regression approximates the model's actual + output (not a distance to a label). + Args: inputs: Original input tuples (used for non-value fields like time/mask). perturb: Perturbed sample tensors. Token features are already embedded; continuous features are still in raw space. - target: Target class tensor. + target_indices: [batch] tensor of target class indices. Returns: Model prediction for the perturbed sample, shape (batch_size, ). @@ -523,8 +509,10 @@ def _evaluate_sample( # model's regular forward pass handle embedding internally. logits = self.model.forward(**inputs)["logit"] - # Reduce to [batch_size, ] by taking absolute difference from target class logit - return (target - logits).abs().mean(dim=tuple(range(1, logits.ndim))) + # Extract the target class prediction (logits, not label distances) + return logits.gather( + 1, target_indices.unsqueeze(1) + ).squeeze(1) def _compute_similarity( self, diff --git a/pyhealth/interpret/methods/shap.py b/pyhealth/interpret/methods/shap.py index 46d40a977..df52ea732 100644 --- a/pyhealth/interpret/methods/shap.py +++ b/pyhealth/interpret/methods/shap.py @@ -220,33 +220,7 @@ def attribute( # Extract and prepare inputs base_logits = self.model.forward(**inputs)["logit"] - # Enforce target class selection for multi-class models to avoid class flipping - if self._prediction_mode() == "binary": - if target_class_idx is not None: - target = torch.tensor([target_class_idx], device=device) - else: - target = (torch.sigmoid(base_logits) > 0.5).long() - elif self._prediction_mode() == "multiclass": - if target_class_idx is not None: - target = torch.nn.functional.one_hot( - torch.tensor(target_class_idx, device=device), - num_classes=base_logits.shape[-1], - ) - else: - target = torch.argmax(base_logits, dim=-1) - target = torch.nn.functional.one_hot( - target, num_classes=base_logits.shape[-1] - ) - elif self._prediction_mode() == "multilabel": - if target_class_idx is not None: - target = torch.nn.functional.one_hot( - torch.tensor(target_class_idx, device=device), - num_classes=base_logits.shape[-1], - ) - else: - target = torch.sigmoid(base_logits) > 0.5 - else: - raise ValueError("Unsupported prediction mode for SHAP attribution.") + target_indices = self._resolve_target_indices(base_logits, target_class_idx) if baseline is None: baselines = self._generate_background_samples( @@ -295,7 +269,7 @@ def attribute( xs=values, bs=baselines, n_features=n_features, - target=target, + target_indices=target_indices, ) return self._map_to_input_shapes(out, shapes) @@ -309,7 +283,7 @@ def _compute_kernel_shap( xs: Dict[str, torch.Tensor], bs: Dict[str, torch.Tensor], n_features: dict[str, int], - target: torch.Tensor, + target_indices: torch.Tensor, ) -> Dict[str, torch.Tensor]: """Compute SHAP values using the Kernel SHAP approximation method. @@ -325,7 +299,7 @@ def _compute_kernel_shap( xs: Dictionary of input values (or embeddings). bs: Dictionary of baseline values (or embeddings). n_features: Dictionary mapping feature keys to feature counts. - target: Target tensor for prediction comparison. + target_indices: [batch] tensor of target class indices. Returns: Dictionary mapping feature keys to SHAP value tensors. @@ -353,7 +327,7 @@ def _compute_kernel_shap( coalition, keys, n_features, batch_size ) perturb = self._create_perturbed_sample(xs, bs, gates) - pred = self._evaluate_sample(inputs, perturb, target) + pred = self._evaluate_sample(inputs, perturb, target_indices) coalition_vectors.append(coalition.float()) coalition_preds.append(pred.detach()) @@ -374,7 +348,7 @@ def _compute_kernel_shap( coalition, keys, n_features, batch_size ) perturb = self._create_perturbed_sample(xs, bs, gates) - pred = self._evaluate_sample(inputs, perturb, target) + pred = self._evaluate_sample(inputs, perturb, target_indices) coalition_vectors.append(coalition.float()) coalition_preds.append(pred.detach()) @@ -480,7 +454,7 @@ def _evaluate_sample( self, inputs: dict[str, tuple[torch.Tensor, ...]], perturb: dict[str, torch.Tensor], - target: torch.Tensor, + target_indices: torch.Tensor, ) -> torch.Tensor: """Evaluate model prediction for a perturbed sample. @@ -492,9 +466,7 @@ def _evaluate_sample( Args: inputs: Original input tuples from the dataloader. perturb: Dictionary of perturbed value tensors. - target: Target tensor used to select which class prediction to - return. For binary this is a 0/1 scalar or (batch,1) tensor; - for multiclass/multilabel it is a one-hot vector. + target_indices: [batch] tensor of target class indices. Returns: Target-class prediction scalar per batch item, shape (batch_size,). @@ -537,56 +509,9 @@ def _evaluate_sample( # model's regular forward pass handle embedding internally. logits = self.model.forward(**inputs)["logit"] - return self._extract_target_prediction(logits, target) - - def _extract_target_prediction( - self, - logits: torch.Tensor, - target: torch.Tensor, - ) -> torch.Tensor: - """Extract the model's prediction for the target class. - - Kernel SHAP decomposes f(x) ≈ φ₀ + Σ φᵢ zᵢ via weighted least squares. - Using **raw logits** (unbounded) rather than probabilities (bounded - [0, 1]) is critical: sigmoid compression squashes coalition differences - in the saturated regions, producing uniformly small SHAP values and - degraded feature rankings. - - Args: - logits: Raw model logits, shape (batch_size, n_classes) or - (batch_size, 1). - target: Target indicator. Binary: scalar/tensor with 0 or 1. - Multiclass: one-hot tensor. Multilabel: multi-hot tensor. - - Returns: - Scalar prediction per batch item, shape (batch_size,). - """ - mode = self._prediction_mode() - - if mode == "binary": - # Use raw logit — not sigmoid probability — to preserve the - # dynamic range that Kernel SHAP's linear decomposition needs. - logit = logits.squeeze(-1) # (batch,) - t = target.float() - if t.dim() > 1: - t = t.squeeze(-1) - # target=1 → logit (higher logit ⇒ more positive class) - # target=0 → −logit (higher value ⇒ more negative class) - return t * logit + (1 - t) * (-logit) - - elif mode == "multiclass": - # target is one-hot; dot-product extracts the target-class logit - return (target.float() * logits).sum(dim=-1) # (batch,) - - elif mode == "multilabel": - # target is multi-hot; average logits over active labels - t = target.float() - n_active = t.sum(dim=-1).clamp(min=1) # avoid div-by-zero - return (t * logits).sum(dim=-1) / n_active # (batch,) - - else: - # regression or unknown — just return the logit - return logits.squeeze(-1) + return logits.gather( + 1, target_indices.unsqueeze(1) + ).squeeze(1) # ------------------------------------------------------------------ # Weighted least squares solver diff --git a/pyhealth/metrics/interpretability/base.py b/pyhealth/metrics/interpretability/base.py index d3dc120f9..61299803f 100644 --- a/pyhealth/metrics/interpretability/base.py +++ b/pyhealth/metrics/interpretability/base.py @@ -111,7 +111,7 @@ def _detect_classifier_type(self): self.num_classes = 2 print("[RemovalBasedMetric] Detected BINARY classifier") print(" - Output shape: [batch, 1] with P(class=1)") - print(" - Only evaluates positive predictions (>=threshold)") + print(" - Evaluates both positive and negative predictions") elif mode == "multiclass": self.classifier_type = "multiclass" # Get num_classes from processor @@ -365,9 +365,11 @@ def compute( samples have value 0. Note: - For binary classifiers, the valid_mask indicates samples with - P(class=1) >= threshold (default 0.5). Use this mask to filter - scores during averaging or analysis. + For binary classifiers, all samples are evaluated + (both positive and negative predictions). For class 0 + predictions, attributions are negated internally so that + feature importance is measured relative to the predicted + class. """ # Get original predictions (returns 3 values) original_probs, pred_classes, original_class_probs = get_model_predictions( @@ -390,17 +392,27 @@ def compute( self._positive_threshold, ) - # For binary: determine which samples to evaluate + # All samples are evaluated regardless of predicted class + positive_mask = torch.ones( + batch_size, dtype=torch.bool, device=original_probs.device + ) + num_positive = batch_size + num_negative = 0 + + # For binary class 0 predictions, negate attributions so that + # "top features" become those most important for the predicted + # class (features with low class-1 attribution support class 0). if self.classifier_type == "binary": - positive_mask = pred_classes == 1 - num_positive = positive_mask.sum().item() - num_negative = (~positive_mask).sum().item() - else: - positive_mask = torch.ones( - batch_size, dtype=torch.bool, device=original_probs.device - ) - num_positive = batch_size - num_negative = 0 + neg_mask = pred_classes == 0 + if neg_mask.any(): + attributions = { + key: torch.where( + neg_mask.view(-1, *([1] * (attr.dim() - 1))), + -attr, + attr, + ) + for key, attr in attributions.items() + } # Debug output (if requested and returning per percentage) if debug and return_per_percentage: @@ -411,9 +423,9 @@ def compute( print(f"Classifier type: {self.classifier_type}") if self.classifier_type == "binary": - print(f"Positive class samples: {num_positive}") - print(f"Negative class samples: {num_negative}") - print("NOTE: Only computing metrics for POSITIVE class") + print(f"Positive class samples: {(pred_classes == 1).sum().item()}") + print(f"Negative class samples: {(pred_classes == 0).sum().item()}") + print("NOTE: Evaluating BOTH positive and negative predictions") print(f"Original probs shape: {original_probs.shape}") print(f"Predicted classes: {pred_classes.tolist()}") @@ -421,8 +433,8 @@ def compute( if self.classifier_type == "binary": print("\nOriginal probabilities P(class=1):") for i, prob in enumerate(original_probs): - status = "EVAL" if positive_mask[i] else "SKIP" - print(f" Sample {i} [{status}]: {prob.item():.6f}") + cls = pred_classes[i].item() + print(f" Sample {i} [class={cls}]: {prob.item():.6f}") else: print("\nOriginal probabilities (all classes):") for i, probs in enumerate(original_probs): @@ -430,11 +442,8 @@ def compute( print("\nOriginal probs for predicted class:") for i, prob in enumerate(original_class_probs): - if self.classifier_type == "binary": - status = "EVAL" if positive_mask[i] else "SKIP" - print(f" Sample {i} [{status}]: {prob.item():.6f}") - else: - print(f" Sample {i}: {prob.item():.6f}") + cls = pred_classes[i].item() + print(f" Sample {i} [class={cls}]: {prob.item():.6f}") # Store results per percentage if return_per_percentage: @@ -476,8 +485,8 @@ def compute( if self.classifier_type == "binary": print("\nAblated probabilities P(class=1):") for i, prob in enumerate(ablated_probs): - status = "EVAL" if positive_mask[i] else "SKIP" - print(f" Sample {i} [{status}]: {prob.item():.6f}") + cls = pred_classes[i].item() + print(f" Sample {i} [class={cls}]: {prob.item():.6f}") else: print("\nAblated probabilities (all classes):") for i, probs in enumerate(ablated_probs): @@ -485,15 +494,13 @@ def compute( print("\nProbability drops (original - ablated):") for i, drop in enumerate(prob_drop): - if drop == 0 and not positive_mask[i]: - print(f" Sample {i} [SKIP]: " f"0.000000 (negative class)") - else: - orig = original_class_probs[i].item() - abl = ablated_class_probs[i].item() - print( - f" Sample {i} [EVAL]: {drop.item():.6f} " - f"({orig:.6f} - {abl:.6f})" - ) + orig = original_class_probs[i].item() + abl = ablated_class_probs[i].item() + cls = pred_classes[i].item() + print( + f" Sample {i} [class={cls}]: {drop.item():.6f} " + f"({orig:.6f} - {abl:.6f})" + ) # Check for unexpected negative values evaluated_drops = prob_drop[positive_mask] diff --git a/pyhealth/metrics/interpretability/evaluator.py b/pyhealth/metrics/interpretability/evaluator.py index fe6f99f61..13f79c019 100644 --- a/pyhealth/metrics/interpretability/evaluator.py +++ b/pyhealth/metrics/interpretability/evaluator.py @@ -113,8 +113,9 @@ def evaluate( Example: {'comprehensiveness': {10: tensor(...), 20: ...}} Note: - For binary classifiers, valid_mask indicates samples with - P(class=1) >= threshold. Use: scores[valid_mask].mean() + For binary classifiers, all samples are evaluated + (both positive and negative predictions). + Use: scores[valid_mask].mean() Examples: >>> # Default: averaged scores @@ -245,13 +246,15 @@ def evaluate_attribution( ) # Accumulate statistics incrementally (no tensor storage) + first_metric = metrics[0] + batch_size = len(batch_results[first_metric][0]) + total_samples += batch_size + for metric_name in metrics: scores, valid_mask = batch_results[metric_name] # Track statistics efficiently - batch_size = len(scores) num_valid = valid_mask.sum().item() - total_samples += batch_size total_valid[metric_name] += num_valid # Update running sum (valid scores only) @@ -325,8 +328,8 @@ def evaluate_attribution( print(" * Important features not correctly identified") print(" * Consider checking attribution method") - valid_ratio = sum(total_valid.values()) / (len(metrics) * total_samples) - if valid_ratio < 0.1: + valid_ratio = sum(total_valid.values()) / (len(metrics) * total_samples) if total_samples > 0 else 0 + if valid_ratio < 0.1 and total_samples > 0: print(f"\n⚠ WARNING: Only {valid_ratio*100:.1f}% valid samples") print(" - Most predictions are negative class") print(" - Consider:") diff --git a/pyhealth/metrics/interpretability/utils.py b/pyhealth/metrics/interpretability/utils.py index 5206bbf1d..2eb00debb 100644 --- a/pyhealth/metrics/interpretability/utils.py +++ b/pyhealth/metrics/interpretability/utils.py @@ -61,8 +61,10 @@ def get_model_predictions( if classifier_type == "binary": # For binary: class 1 if P(class=1) >= threshold, else 0 pred_classes = (y_prob.squeeze(-1) >= positive_threshold).long() if pred_classes is None else pred_classes - # For binary, class_probs is P(class=1) - class_probs = y_prob.squeeze(-1) + # For binary, class_probs is P(predicted_class) + # class 1: P(class=1), class 0: 1 - P(class=1) + p1 = y_prob.squeeze(-1) + class_probs = torch.where(pred_classes == 1, p1, 1 - p1) else: # For multiclass/multilabel: argmax pred_classes = torch.argmax(y_prob, dim=-1) if pred_classes is None else pred_classes @@ -94,8 +96,8 @@ def create_validity_mask( batch_size = y_prob.shape[0] if classifier_type == "binary": - # For binary: valid = P(class=1) >= threshold - return y_prob.squeeze(-1) >= positive_threshold + # All samples are valid (evaluate both positive and negative predictions) + return torch.ones(batch_size, dtype=torch.bool, device=y_prob.device) else: # For multiclass/multilabel: all samples are valid return torch.ones(batch_size, dtype=torch.bool, device=y_prob.device) diff --git a/tests/core/test_deeplift.py b/tests/core/test_deeplift.py index 8fe832d82..29dbf8328 100644 --- a/tests/core/test_deeplift.py +++ b/tests/core/test_deeplift.py @@ -88,15 +88,15 @@ def test_basic_attribution(self): self.assertIsInstance(attributions["procedures"], torch.Tensor) def test_attribution_with_target_class(self): - """Test attribution computation with specific target class.""" + """For binary (single logit), target_class_idx is a no-op.""" dl = DeepLift(self.model) data_batch = next(iter(self.test_loader)) attr_class_0 = dl.attribute(**data_batch, target_class_idx=0) attr_class_1 = dl.attribute(**data_batch, target_class_idx=1) - # Attributions should differ for different classes - self.assertFalse( + # Single-logit binary: target_class_idx is a no-op, both should match + self.assertTrue( torch.allclose(attr_class_0["conditions"], attr_class_1["conditions"]) ) diff --git a/tests/core/test_gim.py b/tests/core/test_gim.py index 284931883..4e7cfb753 100644 --- a/tests/core/test_gim.py +++ b/tests/core/test_gim.py @@ -282,8 +282,8 @@ def _manual_token_attribution( output = model.forward_from_embedding(codes=tuple(parts), label=labels) logits = output["logit"] - # Binary mode: target class 0 → sign = -1 (2*0 - 1) - target = (-1.0 * logits).sum() + # Binary (single logit): _resolve_target_indices always selects index 0. + target = logits.sum() model.zero_grad(set_to_none=True) if embeddings.grad is not None: diff --git a/tests/core/test_ig_gim.py b/tests/core/test_ig_gim.py index 38565227d..fed1b436f 100644 --- a/tests/core/test_ig_gim.py +++ b/tests/core/test_ig_gim.py @@ -567,7 +567,7 @@ def test_auto_target_class(self): self.assertEqual(attrs["codes"].shape, self.tokens.shape) def test_different_target_classes(self): - """Attributions for different target classes should differ.""" + """For binary (single logit), target_class_idx is a no-op.""" model = _ToyModel() ig_gim = IntegratedGradientGIM(model, temperature=1.0, steps=10) @@ -578,9 +578,10 @@ def test_different_target_classes(self): codes=self.tokens, label=self.labels, target_class_idx=1, )["codes"] - self.assertFalse( + # Single-logit binary: target_class_idx is a no-op, both should match + self.assertTrue( torch.allclose(attrs_0, attrs_1), - "Different target classes should give different attributions", + "Single-logit binary: target_class_idx is a no-op", ) # ----- Temporal tuple inputs ----- diff --git a/tests/core/test_integrated_gradients.py b/tests/core/test_integrated_gradients.py index 7c9212280..78ec4fab5 100644 --- a/tests/core/test_integrated_gradients.py +++ b/tests/core/test_integrated_gradients.py @@ -93,7 +93,7 @@ def test_basic_attribution(self): self.assertIsInstance(attributions["procedures"], torch.Tensor) def test_attribution_with_target_class(self): - """Test attribution computation with specific target class.""" + """For binary (single logit), target_class_idx is a no-op.""" ig = IntegratedGradients(self.model) data_batch = next(iter(self.test_loader)) @@ -103,8 +103,8 @@ def test_attribution_with_target_class(self): # Compute attributions for class 1 attr_class_1 = ig.attribute(**data_batch, target_class_idx=1, steps=10) - # Check that attributions are different for different classes - self.assertFalse( + # Single-logit binary: target_class_idx is a no-op, both should match + self.assertTrue( torch.allclose(attr_class_0["conditions"], attr_class_1["conditions"]) ) @@ -307,7 +307,7 @@ def test_attribution_shapes_stagenet(self): self.assertEqual(attributions[key].shape, value_tensor.shape) def test_attribution_with_target_class_stagenet(self): - """Test attribution with specific target class for StageNet.""" + """For binary (single logit), target_class_idx is a no-op.""" ig = IntegratedGradients(self.model) data_batch = next(iter(self.test_loader)) @@ -315,8 +315,8 @@ def test_attribution_with_target_class_stagenet(self): attr_0 = ig.attribute(**data_batch, target_class_idx=0, steps=10) attr_1 = ig.attribute(**data_batch, target_class_idx=1, steps=10) - # Check that attributions differ for different classes - self.assertFalse(torch.allclose(attr_0["codes"], attr_1["codes"])) + # Single-logit binary: target_class_idx is a no-op, both should match + self.assertTrue(torch.allclose(attr_0["codes"], attr_1["codes"])) def test_attribution_values_finite_stagenet(self): """Test that StageNet attributions are finite.""" diff --git a/tests/core/test_lime.py b/tests/core/test_lime.py index ab061cb4a..55dcf0ff0 100644 --- a/tests/core/test_lime.py +++ b/tests/core/test_lime.py @@ -308,7 +308,7 @@ def test_target_class_idx_none(self): self.assertEqual(attributions["x"].shape, inputs.shape) def test_target_class_idx_specified(self): - """Should handle specific target class index.""" + """For binary (single logit), target_class_idx is a no-op.""" inputs = torch.tensor([[1.0, 0.5, -0.3]]) attr_class_0 = self.explainer.attribute( @@ -323,8 +323,8 @@ def test_target_class_idx_specified(self): target_class_idx=1, ) - # Attributions should differ for different classes - self.assertFalse(torch.allclose(attr_class_0["x"], attr_class_1["x"], atol=0.01)) + # Single-logit binary: target_class_idx is a no-op, both should match + self.assertTrue(torch.allclose(attr_class_0["x"], attr_class_1["x"], atol=0.01)) def test_attribution_values_are_finite(self): """Test that attribution values are finite (no NaN or Inf).""" @@ -777,7 +777,7 @@ def test_lime_mlp_basic_attribution(self): self.assertIsInstance(attributions["procedures"], torch.Tensor) def test_lime_mlp_with_target_class(self): - """Test LIME attribution with specific target class.""" + """For binary (single logit), target_class_idx is a no-op.""" explainer = LimeExplainer( self.model, use_embeddings=True, @@ -792,8 +792,8 @@ def test_lime_mlp_with_target_class(self): # Compute attributions for class 1 attr_class_1 = explainer.attribute(**data_batch, target_class_idx=1) - # Check that attributions are different for different classes - self.assertFalse( + # Single-logit binary: target_class_idx is a no-op, both should match + self.assertTrue( torch.allclose(attr_class_0["conditions"], attr_class_1["conditions"], atol=0.01) ) diff --git a/tests/core/test_shap.py b/tests/core/test_shap.py index 8c03a1c1f..25d703a0d 100644 --- a/tests/core/test_shap.py +++ b/tests/core/test_shap.py @@ -318,7 +318,7 @@ def test_target_class_idx_none(self): self.assertEqual(attributions["x"].shape, inputs.shape) def test_target_class_idx_specified(self): - """Should handle specific target class index.""" + """For binary (single logit), target_class_idx is a no-op.""" inputs = torch.tensor([[1.0, 0.5, -0.3]]) attr_class_0 = self.explainer.attribute( @@ -333,8 +333,8 @@ def test_target_class_idx_specified(self): target_class_idx=1, ) - # Attributions should differ for different classes - self.assertFalse(torch.allclose(attr_class_0["x"], attr_class_1["x"], atol=0.01)) + # Single-logit binary: target_class_idx is a no-op, both should match + self.assertTrue(torch.allclose(attr_class_0["x"], attr_class_1["x"], atol=0.01)) def test_attribution_values_are_finite(self): """Test that attribution values are finite (no NaN or Inf).""" @@ -696,7 +696,7 @@ def test_shap_mlp_basic_attribution(self): self.assertIsInstance(attributions["procedures"], torch.Tensor) def test_shap_mlp_with_target_class(self): - """Test SHAP attribution with specific target class.""" + """For binary (single logit), target_class_idx is a no-op.""" explainer = ShapExplainer(self.model) data_batch = next(iter(self.test_loader)) @@ -706,8 +706,8 @@ def test_shap_mlp_with_target_class(self): # Compute attributions for class 1 attr_class_1 = explainer.attribute(**data_batch, target_class_idx=1) - # Check that attributions are different for different classes - self.assertFalse( + # Single-logit binary: target_class_idx is a no-op, both should match + self.assertTrue( torch.allclose(attr_class_0["conditions"], attr_class_1["conditions"], atol=0.01) ) @@ -1023,23 +1023,13 @@ def test_kernel_weight_computation_edge_cases(self): self.assertTrue(torch.isfinite(weight_partial)) def test_target_prediction_extraction_binary(self): - """Test target prediction extraction for binary classification.""" - explainer = ShapExplainer( - self.model, - use_embeddings=False, - ) + """Test target prediction for binary classification via gather.""" # Single logit (binary classification) logits_binary = torch.tensor([[0.5], [1.0], [-0.3]]) - - # Class 1 target tensor - target_1 = torch.tensor([1, 1, 1]) - pred_1 = explainer._extract_target_prediction(logits_binary, target_1) - self.assertEqual(pred_1.shape, (3,)) - - # Class 0 target tensor - target_0 = torch.tensor([0, 0, 0]) - pred_0 = explainer._extract_target_prediction(logits_binary, target_0) - self.assertEqual(pred_0.shape, (3,)) + target_indices = torch.zeros(3, dtype=torch.long) + pred = logits_binary.gather(1, target_indices.unsqueeze(1)).squeeze(1) + self.assertEqual(pred.shape, (3,)) + torch.testing.assert_close(pred, torch.tensor([0.5, 1.0, -0.3])) def test_shape_mapping_simple(self): """Test mapping SHAP values back to input shapes.""" diff --git a/tests/core/test_transformer.py b/tests/core/test_transformer.py index a5fa6cc6b..e74468fc5 100644 --- a/tests/core/test_transformer.py +++ b/tests/core/test_transformer.py @@ -154,8 +154,7 @@ def test_chefer_relevance(self): relevance = CheferRelevance(model) # Test with explicitly specified class index - data_batch["class_index"] = 0 - scores = relevance.get_relevance_matrix(**data_batch) + scores = relevance.get_relevance_matrix(target_class_idx=0, **data_batch) # Verify that scores are returned for all feature keys self.assertIsInstance(scores, dict) @@ -167,9 +166,8 @@ def test_chefer_relevance(self): # Verify scores are non-negative (due to clamping in relevance computation) self.assertTrue(torch.all(scores[feature_key] >= 0)) - # Test without specifying class_index (should use predicted class) - data_batch_no_idx = {k: v for k, v in data_batch.items() if k != "class_index"} - scores_auto = relevance.get_relevance_matrix(**data_batch_no_idx) + # Test without specifying target_class_idx (should use predicted class) + scores_auto = relevance.get_relevance_matrix(**data_batch) # Verify that scores are returned self.assertIsInstance(scores_auto, dict)