Skip to content

[Feature] 集成 TurboQuant KV Cache 压缩#6

Open
sunghajung6688 wants to merge 7 commits into
hw-native-sys:mainfrom
sunghajung6688:turboquant
Open

[Feature] 集成 TurboQuant KV Cache 压缩#6
sunghajung6688 wants to merge 7 commits into
hw-native-sys:mainfrom
sunghajung6688:turboquant

Conversation

@sunghajung6688
Copy link
Copy Markdown

将 TurboQuant V3(纯 MSE 模式)集成到 pypto-serving 的 KV Cache 管理中,支持在线压缩/解压缩历史 token,减少 KV Cache 内存占用。

Copy link
Copy Markdown

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces TurboQuant KV cache quantization to optimize memory usage during prefill and decode phases. The implementation includes a new turboquant module featuring MSE-optimal compression using random rotation and Lloyd-Max quantization. Feedback focuses on critical performance bottlenecks: the current compression logic re-processes the entire sequence prefix in each step, leading to $O(N^2)$ complexity; the manual Python loops used for writing to paged cache should be replaced with vectorized advanced indexing; and several constant tensor allocations and PDF calculations should be pre-computed or vectorized to reduce overhead.

Comment thread python/core/kv_cache.py
if alloc.request_id not in pool.compressed_segments:
pool.compressed_segments[alloc.request_id] = {}

for layer_idx in range(pool.num_layers):
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

当前实现在每次调用时都会重新压缩整个序列前缀(从 token 0 到 tokens_used - residual_window)。在 decode 循环中调用时,这将导致相对于序列长度的 $O(N^2)$ 复杂度,随着上下文增长,性能将严重下降。建议实现增量压缩,仅处理刚刚超出 residual_window 的新 token。此外,调用的 read_context 方法内部也存在类似的循环读取瓶颈,建议一并优化。

Comment thread python/core/kv_cache.py
Comment on lines +379 to +385
for row in range(keys.shape[0]):
token_index = start_token_index + row
page_idx = token_index // pool.page_size
offset = token_index % pool.page_size
physical_page = alloc.page_ids[page_idx]
pool.key_pages[layer_idx, physical_page, :, offset, :] = keys[row].to(cache_dtype)
pool.value_pages[layer_idx, physical_page, :, offset, :] = values[row].to(cache_dtype)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

在 Python 循环中遍历 token 并写入分页缓存(paged cache)非常低效,特别是对于设备张量(NPU/GPU)。应使用 PyTorch 的高级索引(advanced indexing)进行向量化,以在单个操作中完成更新。

Suggested change
for row in range(keys.shape[0]):
token_index = start_token_index + row
page_idx = token_index // pool.page_size
offset = token_index % pool.page_size
physical_page = alloc.page_ids[page_idx]
pool.key_pages[layer_idx, physical_page, :, offset, :] = keys[row].to(cache_dtype)
pool.value_pages[layer_idx, physical_page, :, offset, :] = values[row].to(cache_dtype)
token_indices = torch.arange(start_token_index, start_token_index + keys.shape[0], device=keys.device)
page_indices = token_indices // pool.page_size
offsets = token_indices % pool.page_size
physical_pages = torch.tensor(alloc.page_ids, device=keys.device)[page_indices]
pool.key_pages[layer_idx, physical_pages, :, offsets, :] = keys.to(cache_dtype)
pool.value_pages[layer_idx, physical_pages, :, offsets, :] = values.to(cache_dtype)

Comment on lines +68 to +72
idx_powers = torch.tensor(
[2 ** (self.bits * i) for i in range(indices_per_byte - 1, -1, -1)],
dtype=torch.long,
device=idx_flat.device,
)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

idx_powers 是一个常量张量。在每次 compress 调用时创建它是低效的。建议在 __init__ 中根据 bits 预先计算并存储为 buffer,以减少冗余分配和 H2D 拷贝。

Comment on lines +47 to +48
pdf_vals = torch.tensor([pdf(x) for x in xs])
weighted = xs * pdf_vals
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

高斯 PDF 计算目前在 Python 列表推导式中执行,对于 2048 个样本来说较慢。可以使用 PyTorch 操作轻松实现向量化,提高初始化速度。

            xs = torch.linspace(a, b, n_samples)
            pdf_vals = (1.0 / (math.sqrt(2 * math.pi) * sigma)) * torch.exp(-xs**2 / (2 * sigma**2))
            weighted = xs * pdf_vals

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant