Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions gigl/nn/graph_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,12 +674,16 @@ def __init__(
)

self._pairwise_pe_attention_bias_projection: Optional[nn.Linear] = None
self._pairwise_nonmissing_attention_bias: Optional[nn.Parameter] = None
if self._pairwise_attention_bias_attr_names:
self._pairwise_pe_attention_bias_projection = nn.Linear(
len(self._pairwise_attention_bias_attr_names),
num_heads,
bias=False,
)
self._pairwise_nonmissing_attention_bias = nn.Parameter(
torch.zeros(num_heads)
)

# Transformer encoder layers
# Default feedforward ratio: 4.0 for standard activations, 8/3 for XGLU
Expand Down Expand Up @@ -1003,6 +1007,7 @@ def _build_attention_bias(
attention_bias_data: Dictionary containing optional PE tensors:
- "anchor_bias": (batch, seq, num_anchor_attrs) or None
- "pairwise_bias": (batch, seq, seq, num_pairwise_attrs) or None
- "pairwise_nonmissing_indices": (num_pairs, 3) or None

Returns:
Combined attention bias tensor of shape (batch_size, num_heads, seq_len, seq_len)
Expand Down Expand Up @@ -1066,6 +1071,40 @@ def _build_attention_bias(
) # (batch, num_heads, seq, seq)
attn_bias = attn_bias + pairwise_bias

pairwise_nonmissing_indices = attention_bias_data.get(
"pairwise_nonmissing_indices"
)
if pairwise_nonmissing_indices is not None:
if self._pairwise_nonmissing_attention_bias is None:
raise ValueError(
"Pairwise nonmissing attention bias is not initialized."
)
if pairwise_nonmissing_indices.numel() > 0:
if attn_bias.shape[1:] != (self._num_heads, seq_len, seq_len):
attn_bias = attn_bias.expand(
batch_size,
self._num_heads,
seq_len,
seq_len,
).clone()
pairwise_nonmissing_indices = pairwise_nonmissing_indices.to(
device=device,
dtype=torch.long,
)
attn_bias_by_position = attn_bias.permute(0, 2, 3, 1)
nonmissing_bias = self._pairwise_nonmissing_attention_bias.to(
dtype=attn_bias.dtype
).view(1, -1)
attn_bias_by_position.index_put_(
(
pairwise_nonmissing_indices[:, 0],
pairwise_nonmissing_indices[:, 1],
pairwise_nonmissing_indices[:, 2],
),
nonmissing_bias.expand(pairwise_nonmissing_indices.size(0), -1),
accumulate=True,
)

return attn_bias

def _encode_and_readout(
Expand Down
124 changes: 103 additions & 21 deletions gigl/transforms/graph_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
class SequenceAuxiliaryData(TypedDict):
anchor_bias: Optional[Tensor]
pairwise_bias: Optional[Tensor]
pairwise_nonmissing_indices: Optional[Tensor]
token_input: Optional[TokenInputData]


Expand Down Expand Up @@ -150,6 +151,9 @@ def heterodata_to_graph_transformer_input(
``"anchor_bias"`` shaped ``(batch, seq, num_anchor_attrs)`` or None
``"pairwise_bias"`` shaped
``(batch, seq, seq, num_pairwise_attrs)`` or None
``"pairwise_nonmissing_indices"`` shaped ``(num_pairs, 3)`` or None,
storing ``(batch_idx, row_pos, col_pos)`` coordinates for
nonmissing pairwise entries
``"token_input"`` as a dict mapping attribute name to a
``(batch, seq, 1)`` tensor, or None

Expand Down Expand Up @@ -326,11 +330,14 @@ def heterodata_to_graph_transformer_input(
device=device,
)

pairwise_feature_sequences = _lookup_pairwise_relative_features(
node_index_sequences=node_index_sequences,
valid_mask=valid_mask,
csr_matrices=pairwise_pe_matrices if pairwise_pe_matrices else None,
device=device,
pairwise_feature_sequences, pairwise_nonmissing_indices = (
_lookup_pairwise_relative_features(
node_index_sequences=node_index_sequences,
valid_mask=valid_mask,
csr_matrices=pairwise_pe_matrices if pairwise_pe_matrices else None,
attr_names=pairwise_bias_attr_names,
device=device,
)
)

anchor_bias_features = _compose_anchor_feature_tensor(
Expand All @@ -352,6 +359,7 @@ def heterodata_to_graph_transformer_input(
{
"anchor_bias": anchor_bias_features,
"pairwise_bias": pairwise_feature_sequences,
"pairwise_nonmissing_indices": pairwise_nonmissing_indices,
"token_input": token_input_features,
},
)
Expand Down Expand Up @@ -818,8 +826,9 @@ def _lookup_pairwise_relative_features(
node_index_sequences: Tensor,
valid_mask: Tensor,
csr_matrices: Optional[list[Tensor]],
attr_names: Optional[list[str]],
device: torch.device,
) -> Optional[Tensor]:
) -> tuple[Optional[Tensor], Optional[Tensor]]:
"""
Look up pairwise sparse values for each valid token pair in the sequence.

Expand All @@ -835,13 +844,20 @@ def _lookup_pairwise_relative_features(
node_index_sequences: (batch_size, max_seq_len) node indices for each sequence position
valid_mask: (batch_size, max_seq_len) bool tensor indicating valid positions
csr_matrices: List of sparse CSR matrices, each (num_nodes, num_nodes)
attr_names: Optional names for the pairwise attributes. Used only to
produce clearer error messages when multiple attrs disagree on
sparse support.
device: Device for output tensor

Returns:
features: (batch_size, max_seq_len, max_seq_len, num_attrs) tensor where
features[b, i, j, k] = csr_matrices[k][node_index_sequences[b, i], node_index_sequences[b, j]]
for valid (i, j) pairs, 0.0 for padding positions.
Returns None if csr_matrices is empty.
nonmissing_indices: (num_nonmissing_pairs, 3) long tensor containing
``(batch_idx, row_pos, col_pos)`` coordinates for valid diagonal
self pairs and valid sparse entries. Missing non-self pairs and
padding are omitted.
Returns (None, None) if csr_matrices is empty.

Example:
# batch_size=2, max_seq_len=3, num_attrs=1 (e.g., random_walk_se)
Expand All @@ -864,7 +880,7 @@ def _lookup_pairwise_relative_features(
# (pad) [0.0, 0.0, 0.0]
"""
if not csr_matrices:
return None
return None, None

batch_size, max_seq_len = node_index_sequences.shape
num_attrs = len(csr_matrices)
Expand All @@ -876,23 +892,62 @@ def _lookup_pairwise_relative_features(

pair_valid_mask = valid_mask.unsqueeze(2) & valid_mask.unsqueeze(1)
if not pair_valid_mask.any():
return features

row_indices = node_index_sequences.unsqueeze(2).expand(-1, -1, max_seq_len)
col_indices = node_index_sequences.unsqueeze(1).expand(-1, max_seq_len, -1)

valid_row_indices = row_indices[pair_valid_mask]
valid_col_indices = col_indices[pair_valid_mask]
return features, torch.zeros((0, 3), dtype=torch.long, device=device)

valid_batch_indices, valid_row_positions, valid_col_positions = torch.nonzero(
pair_valid_mask,
as_tuple=True,
)
valid_row_indices = node_index_sequences[
valid_batch_indices,
valid_row_positions,
]
valid_col_indices = node_index_sequences[
valid_batch_indices,
valid_col_positions,
]
self_pair_mask = valid_row_positions == valid_col_positions

first_attr_name = attr_names[0] if attr_names else "attr_0"
nonmissing_support: Optional[Tensor] = None
for attr_idx, pe_matrix in enumerate(csr_matrices):
pe_values = _lookup_csr_values(
pe_values, found_mask = _lookup_csr_values_and_found(
csr_matrix=pe_matrix,
row_indices=valid_row_indices,
col_indices=valid_col_indices,
)
features[..., attr_idx][pair_valid_mask] = pe_values
features[
valid_batch_indices,
valid_row_positions,
valid_col_positions,
attr_idx,
] = pe_values
attr_nonmissing_support = found_mask | self_pair_mask
if attr_idx == 0:
nonmissing_support = attr_nonmissing_support
continue
if nonmissing_support is None or not torch.equal(
nonmissing_support,
attr_nonmissing_support,
):
attr_name = attr_names[attr_idx] if attr_names else f"attr_{attr_idx}"
raise ValueError(
"Pairwise attention bias attributes must share identical "
"nonmissing support after treating valid diagonal self pairs "
f"as nonmissing, but '{first_attr_name}' and '{attr_name}' "
"differ."
)

return features
assert nonmissing_support is not None
pairwise_nonmissing_indices = torch.stack(
[
valid_batch_indices[nonmissing_support],
valid_row_positions[nonmissing_support],
valid_col_positions[nonmissing_support],
],
dim=1,
)
return features, pairwise_nonmissing_indices


def _get_k_hop_neighbors_sparse(
Expand Down Expand Up @@ -984,11 +1039,35 @@ def _lookup_csr_values(
Returns:
(n,) values from csr_matrix[row, col], or default_value if not present
"""
values, _ = _lookup_csr_values_and_found(
csr_matrix=csr_matrix,
row_indices=row_indices,
col_indices=col_indices,
default_value=default_value,
)
return values


def _lookup_csr_values_and_found(
csr_matrix: Tensor,
row_indices: Tensor,
col_indices: Tensor,
default_value: float = 0.0,
) -> tuple[Tensor, Tensor]:
"""
Look up values in a CSR sparse matrix and report which entries were present.

Returns both the looked-up values and a boolean found-mask so callers can
distinguish missing sparse entries from explicit zero-valued entries.
"""
n = row_indices.size(0)
device = row_indices.device

if n == 0:
return torch.zeros(0, device=device, dtype=torch.float)
return (
torch.zeros(0, device=device, dtype=torch.float),
torch.zeros(0, device=device, dtype=torch.bool),
)

crow_indices = csr_matrix.crow_indices()
col_indices_csr = csr_matrix.col_indices()
Expand All @@ -1001,7 +1080,10 @@ def _lookup_csr_values(
max_row_len = row_lengths.max().item()

if max_row_len == 0:
return torch.full((n,), default_value, device=device, dtype=torch.float)
return (
torch.full((n,), default_value, device=device, dtype=torch.float),
torch.zeros((n,), device=device, dtype=torch.bool),
)

# Build offset matrix: (n, max_row_len)
offsets = row_starts.unsqueeze(1) + torch.arange(max_row_len, device=device)
Expand All @@ -1027,4 +1109,4 @@ def _lookup_csr_values(
value_indices = row_starts[found] + match_offsets[found]
result[found] = values_csr[value_indices].float()

return result
return result, found
41 changes: 41 additions & 0 deletions tests/unit/nn/graph_transformer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,7 @@ def test_attention_bias_features_are_projected_per_head(self) -> None:

assert encoder._anchor_pe_attention_bias_projection is not None
assert encoder._pairwise_pe_attention_bias_projection is not None
assert encoder._pairwise_nonmissing_attention_bias is not None

with torch.no_grad():
encoder._anchor_pe_attention_bias_projection.weight.copy_(
Expand All @@ -570,6 +571,7 @@ def test_attention_bias_features_are_projected_per_head(self) -> None:
]
]
),
"pairwise_nonmissing_indices": None,
"token_input": None,
},
)
Expand Down Expand Up @@ -606,6 +608,7 @@ def test_attention_bias_supports_anchor_relative_attrs_and_ppr_weights(
[[[1.0, 0.5], [2.0, 0.25], [3.0, 0.125]]]
),
"pairwise_bias": None,
"pairwise_nonmissing_indices": None,
"token_input": None,
},
)
Expand All @@ -616,6 +619,44 @@ def test_attention_bias_supports_anchor_relative_attrs_and_ppr_weights(
self.assertEqual(attn_bias[0, 0, 0, 2].item(), 4.25)
self.assertEqual(attn_bias[0, 1, 0, 2].item(), 8.5)

def test_pairwise_nonmissing_indices_add_head_specific_bias(self) -> None:
encoder = self._create_encoder(
pairwise_attention_bias_attr_names=["pairwise_distance"],
)

assert encoder._pairwise_pe_attention_bias_projection is not None
assert encoder._pairwise_nonmissing_attention_bias is not None

with torch.no_grad():
encoder._pairwise_pe_attention_bias_projection.weight.zero_()
encoder._pairwise_nonmissing_attention_bias.copy_(torch.tensor([0.5, 1.5]))

attn_bias = encoder._build_attention_bias(
valid_mask=torch.ones((1, 3), dtype=torch.bool),
sequences=torch.zeros((1, 3, 8), dtype=torch.float),
attention_bias_data={
"anchor_bias": None,
"pairwise_bias": torch.zeros((1, 3, 3, 1), dtype=torch.float),
"pairwise_nonmissing_indices": torch.tensor(
[
[0, 0, 0],
[0, 0, 1],
[0, 1, 0],
[0, 1, 1],
[0, 2, 2],
],
dtype=torch.long,
),
"token_input": None,
},
)

self.assertEqual(attn_bias.shape, (1, 2, 3, 3))
self.assertEqual(attn_bias[0, 0, 0, 0].item(), 0.5)
self.assertEqual(attn_bias[0, 1, 0, 0].item(), 1.5)
self.assertEqual(attn_bias[0, 0, 0, 2].item(), 0.0)
self.assertEqual(attn_bias[0, 1, 1, 2].item(), 0.0)

def test_sinusoidal_sequence_positional_encoding_masks_padding(self) -> None:
encoder = self._create_encoder(
sequence_construction_method="ppr",
Expand Down
Loading