vllm.v1.attention.ops.triton_reshape_and_cache_flash

Core paged-cache reshape kernels.

This file owns the canonical (mode NONE / FP8 per-tensor) reshape kernels and the diff-kv variant. All per-token-head and packed-int modes (INT8 / FP8 / INT4 / INT2) live in dedicated backend modules under vllm.v1.attention.ops.triton_quant_kv.

For backward compatibility this module still exposes triton_reshape_and_cache_flash_per_token_head_quant, fast_hadamard_transform, and _single_rht as thin re-exports / dispatchers, so existing tests and benchmarks keep working.

triton_reshape_and_cache_flash_per_token_head_quant

triton_reshape_and_cache_flash_per_token_head_quant(
    key: Tensor,
    value: Tensor,
    key_cache: Tensor,
    value_cache: Tensor,
    k_scale_cache: Tensor,
    v_scale_cache: Tensor,
    slot_mapping: Tensor,
    kv_quant_mode: KVQuantMode | None = None,
)

Quantize key/value per (token, head) and write to the paged cache.

Dispatches to the appropriate backend in vllm.v1.attention.ops.triton_quant_kv. When kv_quant_mode is None (legacy callers), the mode is inferred from the cache dtype. This inference cannot distinguish INT4 from INT2 (both are stored as torch.uint8) and is deprecated; pass kv_quant_mode explicitly.
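The dtype-based fallback can be sketched in isolation. This is a minimal, self-contained illustration of why the inference is ambiguous, not the real vLLM code: the enum and the string-based dtype check here are stand-ins for KVQuantMode and the comparison against FP8_DTYPE (torch.float8_e4m3fn).

```python
from enum import Enum, auto


class KVQuantMode(Enum):
    # Hypothetical stand-in for vLLM's KVQuantMode; only the two modes
    # reachable by the legacy dtype-inference path are modeled here.
    FP8_PER_TOKEN_HEAD = auto()
    INT8_PER_TOKEN_HEAD = auto()


def infer_mode_from_cache_dtype(cache_dtype: str) -> KVQuantMode:
    """Sketch of the deprecated fallback when kv_quant_mode is None.

    An FP8 cache dtype is unambiguous, but INT8, INT4, and INT2 caches
    are all stored as uint8, so anything non-FP8 falls back to INT8 --
    INT4/INT2 callers would be silently mis-dispatched. This is why the
    fallback is deprecated in favor of an explicit kv_quant_mode.
    """
    if cache_dtype == "float8_e4m3fn":  # real code: key_cache.dtype == FP8_DTYPE
        return KVQuantMode.FP8_PER_TOKEN_HEAD
    return KVQuantMode.INT8_PER_TOKEN_HEAD
```

Note that a packed INT4 cache (dtype uint8) passed through this path would come back as INT8_PER_TOKEN_HEAD, which is exactly the failure mode the deprecation warning guards against.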

Source code in vllm/v1/attention/ops/triton_reshape_and_cache_flash.py
def triton_reshape_and_cache_flash_per_token_head_quant(
    key: torch.Tensor,  # [num_tokens, num_kv_heads, head_size]
    value: torch.Tensor,  # [num_tokens, num_kv_heads, head_size_v]
    key_cache: torch.Tensor,  # [num_blocks, block_size, num_kv_heads, head_size]
    value_cache: torch.Tensor,  # [num_blocks, block_size, num_kv_heads, head_size_v]
    k_scale_cache: torch.Tensor,  # [num_blocks, block_size, num_kv_heads] float32
    v_scale_cache: torch.Tensor,  # [num_blocks, block_size, num_kv_heads] float32
    slot_mapping: torch.Tensor,  # [num_tokens]
    kv_quant_mode: KVQuantMode | None = None,
):
    """Quantize key/value per (token, head) and write to the paged cache.

    Dispatches to the appropriate backend in
    :mod:`vllm.v1.attention.ops.triton_quant_kv`.  When *kv_quant_mode* is
    ``None`` (legacy callers) the mode is inferred from the cache dtype
    — but this disambiguation cannot tell INT4 from INT2 (both stored as
    ``torch.uint8``) and is deprecated; pass ``kv_quant_mode``
    explicitly.
    """
    if kv_quant_mode is None:
        from vllm.model_executor.layers.quantization.utils.quant_utils import (
            FP8_DTYPE,
        )

        warnings.warn(
            "triton_reshape_and_cache_flash_per_token_head_quant: calling "
            "without `kv_quant_mode` is deprecated and will be removed in a "
            "future release.  Pass the KVQuantMode explicitly.",
            DeprecationWarning,
            stacklevel=2,
        )
        if key_cache.dtype == FP8_DTYPE:
            kv_quant_mode = KVQuantMode.FP8_PER_TOKEN_HEAD
        else:
            kv_quant_mode = KVQuantMode.INT8_PER_TOKEN_HEAD

    from vllm.v1.attention.ops.triton_quant_kv import get_backend

    backend = get_backend(kv_quant_mode)
    backend.reshape_and_cache(
        key=key,
        value=value,
        key_cache=key_cache,
        value_cache=value_cache,
        slot_mapping=slot_mapping,
        k_scale_cache=k_scale_cache,
        v_scale_cache=v_scale_cache,
    )
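The final dispatch step looks up a backend for the resolved mode and forwards all tensors to it. The registry sketch below illustrates that pattern under stated assumptions: get_backend, the Backend shape, and the mode names are modeled after the call site above, not copied from vllm.v1.attention.ops.triton_quant_kv, whose internals are not shown in this page.

```python
from dataclasses import dataclass
from enum import Enum, auto
from typing import Callable, Dict


class KVQuantMode(Enum):
    # Hypothetical subset of vLLM's KVQuantMode for illustration.
    FP8_PER_TOKEN_HEAD = auto()
    INT8_PER_TOKEN_HEAD = auto()


@dataclass(frozen=True)
class Backend:
    name: str
    # In the real module this would launch the mode-specific Triton kernel;
    # a no-op callable stands in for it here.
    reshape_and_cache: Callable[..., None]


_BACKENDS: Dict[KVQuantMode, Backend] = {
    KVQuantMode.FP8_PER_TOKEN_HEAD: Backend("fp8", lambda **kw: None),
    KVQuantMode.INT8_PER_TOKEN_HEAD: Backend("int8", lambda **kw: None),
}


def get_backend(mode: KVQuantMode) -> Backend:
    """Resolve the backend for a quant mode, failing loudly on gaps."""
    try:
        return _BACKENDS[mode]
    except KeyError:
        raise ValueError(f"no backend registered for {mode!r}") from None
```

Keeping the per-mode kernels behind a single lookup like this is what lets the legacy wrapper above stay a thin dispatcher: adding a new packed-int mode only requires registering a backend, not touching the wrapper.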