From 58b336c5cfd55558861bd036d18761e581329c22 Mon Sep 17 00:00:00 2001 From: Yozen Liu Date: Fri, 12 Jun 2026 13:38:34 -0700 Subject: [PATCH] Differentiate non-missing GT pairwise bias --- gigl/nn/graph_transformer.py | 39 ++++++ gigl/transforms/graph_transformer.py | 124 +++++++++++++++--- tests/unit/nn/graph_transformer_test.py | 41 ++++++ .../unit/transforms/graph_transformer_test.py | 101 ++++++++++++++ 4 files changed, 284 insertions(+), 21 deletions(-) diff --git a/gigl/nn/graph_transformer.py b/gigl/nn/graph_transformer.py index f9d8b345a..0825d80bb 100644 --- a/gigl/nn/graph_transformer.py +++ b/gigl/nn/graph_transformer.py @@ -640,12 +640,16 @@ def __init__( ) self._pairwise_pe_attention_bias_projection: Optional[nn.Linear] = None + self._pairwise_nonmissing_attention_bias: Optional[nn.Parameter] = None if self._pairwise_attention_bias_attr_names: self._pairwise_pe_attention_bias_projection = nn.Linear( len(self._pairwise_attention_bias_attr_names), num_heads, bias=False, ) + self._pairwise_nonmissing_attention_bias = nn.Parameter( + torch.zeros(num_heads) + ) # Transformer encoder layers # Default feedforward ratio: 4.0 for standard activations, 8/3 for XGLU @@ -966,6 +970,7 @@ def _build_attention_bias( attention_bias_data: Dictionary containing optional PE tensors: - "anchor_bias": (batch, seq, num_anchor_attrs) or None - "pairwise_bias": (batch, seq, seq, num_pairwise_attrs) or None + - "pairwise_nonmissing_indices": (num_pairs, 3) or None Returns: Combined attention bias tensor of shape (batch_size, num_heads, seq_len, seq_len) @@ -1029,6 +1034,40 @@ def _build_attention_bias( ) # (batch, num_heads, seq, seq) attn_bias = attn_bias + pairwise_bias + pairwise_nonmissing_indices = attention_bias_data.get( + "pairwise_nonmissing_indices" + ) + if pairwise_nonmissing_indices is not None: + if self._pairwise_nonmissing_attention_bias is None: + raise ValueError( + "Pairwise nonmissing attention bias is not initialized." + ) + if pairwise_nonmissing_indices.numel() > 0: + if attn_bias.shape[1:] != (self._num_heads, seq_len, seq_len): + attn_bias = attn_bias.expand( + batch_size, + self._num_heads, + seq_len, + seq_len, + ).clone() + pairwise_nonmissing_indices = pairwise_nonmissing_indices.to( + device=device, + dtype=torch.long, + ) + attn_bias_by_position = attn_bias.permute(0, 2, 3, 1) + nonmissing_bias = self._pairwise_nonmissing_attention_bias.to( + dtype=attn_bias.dtype + ).view(1, -1) + attn_bias_by_position.index_put_( + ( + pairwise_nonmissing_indices[:, 0], + pairwise_nonmissing_indices[:, 1], + pairwise_nonmissing_indices[:, 2], + ), + nonmissing_bias.expand(pairwise_nonmissing_indices.size(0), -1), + accumulate=True, + ) + return attn_bias def _encode_and_readout( diff --git a/gigl/transforms/graph_transformer.py b/gigl/transforms/graph_transformer.py index 602f95bde..3be7ed7ed 100644 --- a/gigl/transforms/graph_transformer.py +++ b/gigl/transforms/graph_transformer.py @@ -71,6 +71,7 @@ class SequenceAuxiliaryData(TypedDict): anchor_bias: Optional[Tensor] pairwise_bias: Optional[Tensor] + pairwise_nonmissing_indices: Optional[Tensor] token_input: Optional[TokenInputData] @@ -143,6 +144,9 @@ def heterodata_to_graph_transformer_input( ``"anchor_bias"`` shaped ``(batch, seq, num_anchor_attrs)`` or None ``"pairwise_bias"`` shaped ``(batch, seq, seq, num_pairwise_attrs)`` or None + ``"pairwise_nonmissing_indices"`` shaped ``(num_pairs, 3)`` or None, + storing ``(batch_idx, row_pos, col_pos)`` coordinates for + nonmissing pairwise entries ``"token_input"`` as a dict mapping attribute name to a ``(batch, seq, 1)`` tensor, or None @@ -306,11 +310,14 @@ def heterodata_to_graph_transformer_input( device=device, ) - pairwise_feature_sequences = _lookup_pairwise_relative_features( - node_index_sequences=node_index_sequences, - valid_mask=valid_mask, - csr_matrices=pairwise_pe_matrices if pairwise_pe_matrices else None, - device=device, + pairwise_feature_sequences, pairwise_nonmissing_indices = ( + _lookup_pairwise_relative_features( + node_index_sequences=node_index_sequences, + valid_mask=valid_mask, + csr_matrices=pairwise_pe_matrices if pairwise_pe_matrices else None, + attr_names=pairwise_bias_attr_names, + device=device, + ) ) anchor_bias_features = _compose_anchor_feature_tensor( @@ -332,6 +339,7 @@ def heterodata_to_graph_transformer_input( { "anchor_bias": anchor_bias_features, "pairwise_bias": pairwise_feature_sequences, + "pairwise_nonmissing_indices": pairwise_nonmissing_indices, "token_input": token_input_features, }, ) @@ -798,8 +806,9 @@ def _lookup_pairwise_relative_features( node_index_sequences: Tensor, valid_mask: Tensor, csr_matrices: Optional[list[Tensor]], + attr_names: Optional[list[str]], device: torch.device, -) -> Optional[Tensor]: +) -> tuple[Optional[Tensor], Optional[Tensor]]: """ Look up pairwise sparse values for each valid token pair in the sequence. @@ -815,13 +824,20 @@ def _lookup_pairwise_relative_features( node_index_sequences: (batch_size, max_seq_len) node indices for each sequence position valid_mask: (batch_size, max_seq_len) bool tensor indicating valid positions csr_matrices: List of sparse CSR matrices, each (num_nodes, num_nodes) + attr_names: Optional names for the pairwise attributes. Used only to + produce clearer error messages when multiple attrs disagree on + sparse support. device: Device for output tensor Returns: features: (batch_size, max_seq_len, max_seq_len, num_attrs) tensor where features[b, i, j, k] = csr_matrices[k][node_index_sequences[b, i], node_index_sequences[b, j]] for valid (i, j) pairs, 0.0 for padding positions. - Returns None if csr_matrices is empty. + nonmissing_indices: (num_nonmissing_pairs, 3) long tensor containing + ``(batch_idx, row_pos, col_pos)`` coordinates for valid diagonal + self pairs and valid sparse entries. Missing non-self pairs and + padding are omitted. + Returns (None, None) if csr_matrices is empty. Example: # batch_size=2, max_seq_len=3, num_attrs=1 (e.g., random_walk_se) @@ -844,7 +860,7 @@ def _lookup_pairwise_relative_features( # (pad) [0.0, 0.0, 0.0] """ if not csr_matrices: - return None + return None, None batch_size, max_seq_len = node_index_sequences.shape num_attrs = len(csr_matrices) @@ -856,23 +872,62 @@ def _lookup_pairwise_relative_features( pair_valid_mask = valid_mask.unsqueeze(2) & valid_mask.unsqueeze(1) if not pair_valid_mask.any(): - return features - - row_indices = node_index_sequences.unsqueeze(2).expand(-1, -1, max_seq_len) - col_indices = node_index_sequences.unsqueeze(1).expand(-1, max_seq_len, -1) - - valid_row_indices = row_indices[pair_valid_mask] - valid_col_indices = col_indices[pair_valid_mask] + return features, torch.zeros((0, 3), dtype=torch.long, device=device) + valid_batch_indices, valid_row_positions, valid_col_positions = torch.nonzero( + pair_valid_mask, + as_tuple=True, + ) + valid_row_indices = node_index_sequences[ + valid_batch_indices, + valid_row_positions, + ] + valid_col_indices = node_index_sequences[ + valid_batch_indices, + valid_col_positions, + ] + self_pair_mask = valid_row_positions == valid_col_positions + + first_attr_name = attr_names[0] if attr_names else "attr_0" + nonmissing_support: Optional[Tensor] = None for attr_idx, pe_matrix in enumerate(csr_matrices): - pe_values = _lookup_csr_values( + pe_values, found_mask = _lookup_csr_values_and_found( csr_matrix=pe_matrix, row_indices=valid_row_indices, col_indices=valid_col_indices, ) - features[..., attr_idx][pair_valid_mask] = pe_values + features[ + valid_batch_indices, + valid_row_positions, + valid_col_positions, + attr_idx, + ] = pe_values + attr_nonmissing_support = found_mask | self_pair_mask + if attr_idx == 0: + nonmissing_support = attr_nonmissing_support + continue + if nonmissing_support is None or not torch.equal( + nonmissing_support, + attr_nonmissing_support, + ): + attr_name = attr_names[attr_idx] if attr_names else f"attr_{attr_idx}" + raise ValueError( + "Pairwise attention bias attributes must share identical " + "nonmissing support after treating valid diagonal self pairs " + f"as nonmissing, but '{first_attr_name}' and '{attr_name}' " + "differ." + ) - return features + assert nonmissing_support is not None + pairwise_nonmissing_indices = torch.stack( + [ + valid_batch_indices[nonmissing_support], + valid_row_positions[nonmissing_support], + valid_col_positions[nonmissing_support], + ], + dim=1, + ) + return features, pairwise_nonmissing_indices def _get_k_hop_neighbors_sparse( @@ -964,11 +1019,35 @@ def _lookup_csr_values( Returns: (n,) values from csr_matrix[row, col], or default_value if not present """ + values, _ = _lookup_csr_values_and_found( + csr_matrix=csr_matrix, + row_indices=row_indices, + col_indices=col_indices, + default_value=default_value, + ) + return values + + +def _lookup_csr_values_and_found( + csr_matrix: Tensor, + row_indices: Tensor, + col_indices: Tensor, + default_value: float = 0.0, +) -> tuple[Tensor, Tensor]: + """ + Look up values in a CSR sparse matrix and report which entries were present. + + Returns both the looked-up values and a boolean found-mask so callers can + distinguish missing sparse entries from explicit zero-valued entries. + """ n = row_indices.size(0) device = row_indices.device if n == 0: - return torch.zeros(0, device=device, dtype=torch.float) + return ( + torch.zeros(0, device=device, dtype=torch.float), + torch.zeros(0, device=device, dtype=torch.bool), + ) crow_indices = csr_matrix.crow_indices() col_indices_csr = csr_matrix.col_indices() @@ -981,7 +1060,10 @@ def _lookup_csr_values( max_row_len = row_lengths.max().item() if max_row_len == 0: - return torch.full((n,), default_value, device=device, dtype=torch.float) + return ( + torch.full((n,), default_value, device=device, dtype=torch.float), + torch.zeros((n,), device=device, dtype=torch.bool), + ) # Build offset matrix: (n, max_row_len) offsets = row_starts.unsqueeze(1) + torch.arange(max_row_len, device=device) @@ -1007,4 +1089,4 @@ def _lookup_csr_values( value_indices = row_starts[found] + match_offsets[found] result[found] = values_csr[value_indices].float() - return result + return result, found diff --git a/tests/unit/nn/graph_transformer_test.py b/tests/unit/nn/graph_transformer_test.py index d0fce10c3..7ff07e601 100644 --- a/tests/unit/nn/graph_transformer_test.py +++ b/tests/unit/nn/graph_transformer_test.py @@ -388,6 +388,7 @@ def test_attention_bias_features_are_projected_per_head(self) -> None: assert encoder._anchor_pe_attention_bias_projection is not None assert encoder._pairwise_pe_attention_bias_projection is not None + assert encoder._pairwise_nonmissing_attention_bias is not None with torch.no_grad(): encoder._anchor_pe_attention_bias_projection.weight.copy_( @@ -411,6 +412,7 @@ def test_attention_bias_features_are_projected_per_head(self) -> None: ] ] ), + "pairwise_nonmissing_indices": None, "token_input": None, }, ) @@ -447,6 +449,7 @@ def test_attention_bias_supports_anchor_relative_attrs_and_ppr_weights( [[[1.0, 0.5], [2.0, 0.25], [3.0, 0.125]]] ), "pairwise_bias": None, + "pairwise_nonmissing_indices": None, "token_input": None, }, ) @@ -457,6 +460,44 @@ def test_attention_bias_supports_anchor_relative_attrs_and_ppr_weights( self.assertEqual(attn_bias[0, 0, 0, 2].item(), 4.25) self.assertEqual(attn_bias[0, 1, 0, 2].item(), 8.5) + def test_pairwise_nonmissing_indices_add_head_specific_bias(self) -> None: + encoder = self._create_encoder( + pairwise_attention_bias_attr_names=["pairwise_distance"], + ) + + assert encoder._pairwise_pe_attention_bias_projection is not None + assert encoder._pairwise_nonmissing_attention_bias is not None + + with torch.no_grad(): + encoder._pairwise_pe_attention_bias_projection.weight.zero_() + encoder._pairwise_nonmissing_attention_bias.copy_(torch.tensor([0.5, 1.5])) + + attn_bias = encoder._build_attention_bias( + valid_mask=torch.ones((1, 3), dtype=torch.bool), + sequences=torch.zeros((1, 3, 8), dtype=torch.float), + attention_bias_data={ + "anchor_bias": None, + "pairwise_bias": torch.zeros((1, 3, 3, 1), dtype=torch.float), + "pairwise_nonmissing_indices": torch.tensor( + [ + [0, 0, 0], + [0, 0, 1], + [0, 1, 0], + [0, 1, 1], + [0, 2, 2], + ], + dtype=torch.long, + ), + "token_input": None, + }, + ) + + self.assertEqual(attn_bias.shape, (1, 2, 3, 3)) + self.assertEqual(attn_bias[0, 0, 0, 0].item(), 0.5) + self.assertEqual(attn_bias[0, 1, 0, 0].item(), 1.5) + self.assertEqual(attn_bias[0, 0, 0, 2].item(), 0.0) + self.assertEqual(attn_bias[0, 1, 1, 2].item(), 0.0) + def test_sinusoidal_sequence_positional_encoding_masks_padding(self) -> None: encoder = self._create_encoder( sequence_construction_method="ppr", diff --git a/tests/unit/transforms/graph_transformer_test.py b/tests/unit/transforms/graph_transformer_test.py index 18551014b..8e8df387d 100644 --- a/tests/unit/transforms/graph_transformer_test.py +++ b/tests/unit/transforms/graph_transformer_test.py @@ -122,6 +122,27 @@ def create_ppr_sequence_hetero_data() -> HeteroData: return data +def _dense_nonmissing_mask_from_indices( + pairwise_nonmissing_indices: torch.Tensor | None, + batch_size: int, + seq_len: int, + device: torch.device, +) -> torch.Tensor: + dense_mask = torch.zeros( + (batch_size, seq_len, seq_len), + dtype=torch.bool, + device=device, + ) + if pairwise_nonmissing_indices is None or pairwise_nonmissing_indices.numel() == 0: + return dense_mask + dense_mask[ + pairwise_nonmissing_indices[:, 0], + pairwise_nonmissing_indices[:, 1], + pairwise_nonmissing_indices[:, 2], + ] = True + return dense_mask + + class TestGetKHopNeighborsSparse(TestCase): """Tests for _get_k_hop_neighbors_sparse helper function.""" @@ -792,6 +813,7 @@ def test_transform_returns_base_sequences_and_anchor_relative_bias(self) -> None assert anchor_bias is not None self.assertEqual(anchor_bias.shape, (1, 4, 1)) self.assertIsNone(attention_bias_data["pairwise_bias"]) + self.assertIsNone(attention_bias_data["pairwise_nonmissing_indices"]) self.assertTrue(valid_mask[0, 0].item()) def test_attention_bias_outputs_include_valid_mask_and_relative_features( @@ -817,17 +839,96 @@ def test_attention_bias_outputs_include_valid_mask_and_relative_features( self.assertEqual(valid_mask.shape, (1, 4)) anchor_bias = attention_bias_data["anchor_bias"] pairwise_bias = attention_bias_data["pairwise_bias"] + pairwise_nonmissing_indices = attention_bias_data["pairwise_nonmissing_indices"] assert anchor_bias is not None assert pairwise_bias is not None + assert pairwise_nonmissing_indices is not None + pairwise_nonmissing_mask = _dense_nonmissing_mask_from_indices( + pairwise_nonmissing_indices=pairwise_nonmissing_indices, + batch_size=1, + seq_len=4, + device=pairwise_bias.device, + ) self.assertEqual(anchor_bias.shape, (1, 4, 1)) self.assertEqual(pairwise_bias.shape, (1, 4, 4, 1)) + self.assertEqual(pairwise_nonmissing_indices.shape[1], 3) self.assertAlmostEqual(anchor_bias[0, 0, 0].item(), 0.0, places=5) self.assertAlmostEqual(anchor_bias[0, 1, 0].item(), 1.0, places=5) self.assertAlmostEqual(anchor_bias[0, 2, 0].item(), 3.0, places=5) self.assertAlmostEqual(pairwise_bias[0, 0, 0, 0].item(), 0.1, places=5) + self.assertTrue(torch.all(pairwise_nonmissing_mask[0, :3, :3])) invalid_pair_mask = ~(valid_mask.unsqueeze(2) & valid_mask.unsqueeze(1)) self.assertTrue(torch.all(pairwise_bias[..., 0][invalid_pair_mask] == 0)) + self.assertTrue(torch.all(~pairwise_nonmissing_mask[invalid_pair_mask])) + + def test_pairwise_nonmissing_indices_distinguish_self_from_missing(self) -> None: + data = _create_hetero_data_with_relative_pe() + data.pairwise_distance = torch.sparse_csr_tensor( + crow_indices=torch.tensor([0, 1, 1, 1, 2, 2]), + col_indices=torch.tensor([3, 0]), + values=torch.tensor([0.0, 0.7]), + size=(5, 5), + ) + + _, valid_mask, attention_bias_data = heterodata_to_graph_transformer_input( + data=data, + batch_size=1, + max_seq_len=4, + anchor_node_type="user", + hop_distance=2, + pairwise_attention_bias_attr_names=["pairwise_distance"], + ) + + pairwise_bias = attention_bias_data["pairwise_bias"] + pairwise_nonmissing_indices = attention_bias_data["pairwise_nonmissing_indices"] + assert pairwise_bias is not None + assert pairwise_nonmissing_indices is not None + pairwise_nonmissing_mask = _dense_nonmissing_mask_from_indices( + pairwise_nonmissing_indices=pairwise_nonmissing_indices, + batch_size=1, + seq_len=4, + device=pairwise_bias.device, + ) + + self.assertTrue( + torch.equal(valid_mask[0], torch.tensor([True, True, True, False])) + ) + self.assertEqual(pairwise_bias[0, 0, 2, 0].item(), 0.0) + self.assertTrue(pairwise_nonmissing_mask[0, 0, 0].item()) + self.assertTrue(pairwise_nonmissing_mask[0, 1, 1].item()) + self.assertTrue(pairwise_nonmissing_mask[0, 2, 2].item()) + self.assertTrue(pairwise_nonmissing_mask[0, 0, 2].item()) + self.assertTrue(pairwise_nonmissing_mask[0, 2, 0].item()) + self.assertFalse(pairwise_nonmissing_mask[0, 0, 1].item()) + self.assertFalse(pairwise_nonmissing_mask[0, 1, 0].item()) + self.assertFalse(pairwise_nonmissing_mask[0, 1, 2].item()) + self.assertFalse(pairwise_nonmissing_mask[0, 3, 3].item()) + + def test_pairwise_attention_bias_attr_support_mismatch_raises(self) -> None: + data = _create_hetero_data_with_relative_pe() + data.pairwise_distance_sparse_mismatch = torch.sparse_csr_tensor( + crow_indices=torch.tensor([0, 1, 1, 1, 1, 1]), + col_indices=torch.tensor([0]), + values=torch.tensor([1.0]), + size=(5, 5), + ) + + with self.assertRaisesRegex( + ValueError, + "Pairwise attention bias attributes must share identical nonmissing support", + ): + heterodata_to_graph_transformer_input( + data=data, + batch_size=1, + max_seq_len=4, + anchor_node_type="user", + hop_distance=2, + pairwise_attention_bias_attr_names=[ + "pairwise_distance", + "pairwise_distance_sparse_mismatch", + ], + ) if __name__ == "__main__":