[Feature] 集成 TurboQuant KV Cache 压缩#6
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces TurboQuant KV cache quantization to optimize memory usage during prefill and decode phases. The implementation includes a new turboquant module featuring MSE-optimal compression using random rotation and Lloyd-Max quantization. Feedback focuses on critical performance bottlenecks: the current compression logic re-processes the entire sequence prefix in each step, leading to
| if alloc.request_id not in pool.compressed_segments: | ||
| pool.compressed_segments[alloc.request_id] = {} | ||
|
|
||
| for layer_idx in range(pool.num_layers): |
| for row in range(keys.shape[0]): | ||
| token_index = start_token_index + row | ||
| page_idx = token_index // pool.page_size | ||
| offset = token_index % pool.page_size | ||
| physical_page = alloc.page_ids[page_idx] | ||
| pool.key_pages[layer_idx, physical_page, :, offset, :] = keys[row].to(cache_dtype) | ||
| pool.value_pages[layer_idx, physical_page, :, offset, :] = values[row].to(cache_dtype) |
There was a problem hiding this comment.
在 Python 循环中遍历 token 并写入分页缓存(paged cache)非常低效,特别是对于设备张量(NPU/GPU)。应使用 PyTorch 的高级索引(advanced indexing)进行向量化,以在单个操作中完成更新。
| for row in range(keys.shape[0]): | |
| token_index = start_token_index + row | |
| page_idx = token_index // pool.page_size | |
| offset = token_index % pool.page_size | |
| physical_page = alloc.page_ids[page_idx] | |
| pool.key_pages[layer_idx, physical_page, :, offset, :] = keys[row].to(cache_dtype) | |
| pool.value_pages[layer_idx, physical_page, :, offset, :] = values[row].to(cache_dtype) | |
| token_indices = torch.arange(start_token_index, start_token_index + keys.shape[0], device=keys.device) | |
| page_indices = token_indices // pool.page_size | |
| offsets = token_indices % pool.page_size | |
| physical_pages = torch.tensor(alloc.page_ids, device=keys.device)[page_indices] | |
| pool.key_pages[layer_idx, physical_pages, :, offsets, :] = keys.to(cache_dtype) | |
| pool.value_pages[layer_idx, physical_pages, :, offsets, :] = values.to(cache_dtype) |
| idx_powers = torch.tensor( | ||
| [2 ** (self.bits * i) for i in range(indices_per_byte - 1, -1, -1)], | ||
| dtype=torch.long, | ||
| device=idx_flat.device, | ||
| ) |
| pdf_vals = torch.tensor([pdf(x) for x in xs]) | ||
| weighted = xs * pdf_vals |
ee84d61 to
4b65bf3
Compare
63f8ac7 to
0c160bf
Compare
将 TurboQuant V3(纯 MSE 模式)集成到 pypto-serving 的 KV Cache 管理中,支持在线压缩/解压缩历史 token,减少 KV Cache 内存占用。