
vllm.v1.attention.ops.triton_quant_kv

Per-mode KV cache quantization backends.

The core attention kernel (vllm.v1.attention.ops.triton_unified_attention) handles modes NONE, FP8_PER_TENSOR, INT8_PER_TOKEN_HEAD and FP8_PER_TOKEN_HEAD directly via constexpr branches. Backends registered here own:

  • the write side for any mode that needs more than a plain copy (per-token-head absmax, asymmetric INT4 with zero-point packing, INT2 Lloyd-Max + Hadamard, …); and
  • the attention read side for sub-byte packed modes (INT4 / INT2) whose inner loop is structurally different from the core kernel (split-dot, centroid lookup, etc.).
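The ownership split above can be pictured as a lookup from a mode to the component that writes the cache and the component that runs attention. This is a minimal sketch, not vLLM code. Mode is a stand-in enum, and INT4_PER_TOKEN_HEAD and INT2_PER_TOKEN_HEAD are hypothetical names for the sub-byte modes. Treating NONE and FP8_PER_TENSOR as needing no backend on the write side is an assumption. The real routing happens through constexpr branches and lazy backend lookup, not through Python sets.

```python
from enum import Enum, auto


class Mode(Enum):
    """Stand-in for KVQuantMode; the two sub-byte names are hypothetical."""
    NONE = auto()
    FP8_PER_TENSOR = auto()
    INT8_PER_TOKEN_HEAD = auto()
    FP8_PER_TOKEN_HEAD = auto()
    INT4_PER_TOKEN_HEAD = auto()
    INT2_PER_TOKEN_HEAD = auto()


# Modes the core attention kernel reads directly via constexpr branches.
CORE_READ = {
    Mode.NONE,
    Mode.FP8_PER_TENSOR,
    Mode.INT8_PER_TOKEN_HEAD,
    Mode.FP8_PER_TOKEN_HEAD,
}
# Assumption: these modes are written without a registered backend.
PLAIN_WRITE = {Mode.NONE, Mode.FP8_PER_TENSOR}


def owners(mode: Mode) -> tuple[str, str]:
    """Return (who writes the cache, who runs attention) for *mode*."""
    write = "core" if mode in PLAIN_WRITE else "backend"
    read = "core" if mode in CORE_READ else "backend"
    return write, read


for m in Mode:
    write, read = owners(m)
    print(f"{m.name:<20} write={write:<8} read={read}")
```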

Adding a new quantization mode

  1. Add a new value to KVQuantMode in vllm/v1/kv_cache_interface.py.
  2. Add an entry to _MODULES in vllm/v1/attention/ops/triton_quant_kv/__init__.py that maps the mode to the module path of its backend.
  3. Create a new file under triton_quant_kv/ that defines a subclass of QuantKVBackend (setting its mode class attribute) and calls register at module level. If the mode can use the core attention kernel, implement only reshape_and_cache, and override allocate_scale_caches if the default scale layout does not fit. Otherwise, also override unified_attention.

Modules:

  • base: Backend protocol for KV cache quantization modes.
  • int8_fp8_per_token_head: INT8 and FP8 per-token-head KV cache quantization backends.
  • packed_per_token_head: Sub-byte per-token-head KV cache quantization backends (INT4 + INT2).

QuantKVBackend

Bases: ABC

Cache write + attention read for one KV quantization mode.

Subclasses implement reshape_and_cache (and, for modes that need a bespoke attention loop, unified_attention) for a single KVQuantMode, and call vllm.v1.attention.ops.triton_quant_kv.register at module import. The dispatcher in triton_unified_attention and triton_reshape_and_cache_flash looks up the backend lazily on first use, so unused modes pay no import or compile cost.

Source code in vllm/v1/attention/ops/triton_quant_kv/base.py
class QuantKVBackend(ABC):
    """Cache write + attention read for one KV quantization mode.

    Subclasses implement ``reshape_and_cache`` and ``unified_attention``
    for a single :class:`KVQuantMode`, and call
    :func:`vllm.v1.attention.ops.triton_quant_kv.register` at module import.
    The dispatcher in :mod:`triton_unified_attention` and
    :mod:`triton_reshape_and_cache_flash` looks up the backend lazily on
    first use, so unused modes pay zero import or compile cost.
    """

    # ----- Static metadata --------------------------------------------------
    #: Mode this backend implements.  Must be set by subclasses.
    mode: KVQuantMode
    #: Number of logical KV elements packed into each cache byte (1 unless packed).
    packing_factor: int = 1
    #: Whether this mode allocates its own per-(token, head) scale buffers.
    needs_scale_caches: bool = False

    # ----- Cache shape introspection ----------------------------------------
    def packed_head_size(self, head_size: int) -> int:
        """Storage head size after packing: ``head_size // packing_factor``."""
        assert head_size % self.packing_factor == 0, (
            f"head_size={head_size} is not divisible by packing factor "
            f"{self.packing_factor} required by {self.mode.name}"
        )
        return head_size // self.packing_factor

    def allocate_scale_caches(
        self,
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        device: torch.device,
    ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
        """Allocate aux per-(token, head) scale buffers.

        Default: when ``needs_scale_caches`` is True, allocate one
        ``float32`` per (block, slot, kv_head) for both K and V — the
        layout shared by every per-token-head mode (INT8 / FP8 store
        one absmax-derived scale; INT4 steganographs the zero-point in
        the low 4 mantissa bits of that scale; INT2 stores
        ``norm / d^1.5``).  Modes that need a different shape or dtype
        override this method.  Modes that don't need scale caches at
        all (``needs_scale_caches = False``) get ``(None, None)``.
        """
        if not self.needs_scale_caches:
            return (None, None)
        shape = (num_blocks, block_size, num_kv_heads)
        return (
            torch.zeros(shape, dtype=torch.float32, device=device),
            torch.zeros(shape, dtype=torch.float32, device=device),
        )

    # ----- Cache write path -------------------------------------------------
    @abstractmethod
    def reshape_and_cache(
        self,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        *,
        k_scale_cache: torch.Tensor | None = None,
        v_scale_cache: torch.Tensor | None = None,
    ) -> None:
        """Write *key*/*value* into the paged cache for this mode.

        Per-token-head modes also write into ``k_scale_cache`` /
        ``v_scale_cache``.
        """

    # ----- Attention read path ----------------------------------------------
    # Only modes that need a bespoke attention loop (INT4 / INT2 with
    # split-dot + sub-byte unpack) override this.  INT8 / FP8 per-token-head
    # use the core kernel via a constexpr branch and never call this method.
    def unified_attention(
        self,
        q: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        out: torch.Tensor,
        *,
        cu_seqlens_q: torch.Tensor,
        max_seqlen_q: int,
        seqused_k: torch.Tensor,
        max_seqlen_k: int,
        softmax_scale: float,
        window_size: tuple[int, int],
        block_table: torch.Tensor,
        softcap: float,
        sinks: torch.Tensor | None,
        alibi_slopes: torch.Tensor | None,
        use_alibi_sqrt: bool,
        qq_bias: torch.Tensor | None,
        output_scale: torch.Tensor | None,
        mm_prefix_range: torch.Tensor | None,
        k_scale_cache: torch.Tensor | None = None,
        v_scale_cache: torch.Tensor | None = None,
        # Optional 3D-decode pre-allocated buffers (same as the core kernel)
        seq_threshold_3D: int | None = None,
        num_par_softmax_segments: int | None = None,
        softmax_segm_output: torch.Tensor | None = None,
        softmax_segm_max: torch.Tensor | None = None,
        softmax_segm_expsum: torch.Tensor | None = None,
    ) -> None:
        """Run paged attention with this mode's KV layout, writing into *out*."""
        raise NotImplementedError(
            f"{type(self).__name__} does not implement a bespoke attention "
            f"kernel.  This mode should be handled by the core kernel."
        )

allocate_scale_caches

allocate_scale_caches(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    device: device,
) -> tuple[Tensor | None, Tensor | None]

Allocate aux per-(token, head) scale buffers.

Default: when needs_scale_caches is True, allocate one float32 per (block, slot, kv_head) for both K and V — the layout shared by every per-token-head mode (INT8 / FP8 store one absmax-derived scale; INT4 steganographs the zero-point in the low 4 mantissa bits of that scale; INT2 stores norm / d^1.5). Modes that need a different shape or dtype override this method. Modes that don't need scale caches at all (needs_scale_caches = False) get (None, None).
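The default layout is easy to size: two float32 buffers of shape (num_blocks, block_size, num_kv_heads), which comes to 8 bytes per (slot, kv_head). The note about INT4 keeping its zero-point in the low 4 mantissa bits of the scale can be illustrated with the struct module. This is a sketch under assumptions. The helper names are hypothetical, and this page does not show the exact bit handling in the INT4 backend (for example, whether the scale is rounded or truncated before the bits are overwritten).

```python
import struct


def scale_cache_bytes(num_blocks: int, block_size: int, num_kv_heads: int) -> int:
    """Bytes for the default K and V scale buffers (float32, one per slot and head)."""
    return 2 * num_blocks * block_size * num_kv_heads * 4


def _f32_bits(x: float) -> int:
    """Reinterpret *x*, rounded to float32, as its 32-bit IEEE-754 pattern."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def _bits_f32(bits: int) -> float:
    """Reinterpret a 32-bit pattern as a float32 value."""
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def embed_zero_point(scale: float, zero_point: int) -> float:
    """Overwrite the 4 lowest mantissa bits of *scale* with *zero_point* (0..15)."""
    assert 0 <= zero_point < 16
    return _bits_f32((_f32_bits(scale) & ~0xF) | zero_point)


def extract_zero_point(packed: float) -> tuple[float, int]:
    """Split a packed value into (scale with low bits cleared, zero_point)."""
    bits = _f32_bits(packed)
    return _bits_f32(bits & ~0xF), bits & 0xF


packed = embed_zero_point(0.0123, 9)
scale, zp = extract_zero_point(packed)
print(zp)  # 9
print(abs(scale - 0.0123) / 0.0123 < 2**-19)  # True: fewer than 16 float32 ulps lost
```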

packed_head_size

packed_head_size(head_size: int) -> int

Storage head size after packing: head_size // packing_factor.
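As a worked example, assume INT4 uses packing_factor = 2 (two 4-bit codes per byte) and INT2 uses 4. Under that assumption, a 128-element head occupies 64 or 32 bytes. The sketch below packs unsigned 4-bit codes two per byte and checks that arithmetic. The low-nibble-first order and the helper names are illustrative assumptions, not the backend's documented layout.

```python
def packed_head_size(head_size: int, packing_factor: int) -> int:
    """Free-function mirror of QuantKVBackend.packed_head_size."""
    assert head_size % packing_factor == 0, (head_size, packing_factor)
    return head_size // packing_factor


def pack_int4(codes: list[int]) -> bytes:
    """Pack unsigned 4-bit codes two per byte; element 2i goes in the low nibble."""
    assert len(codes) % 2 == 0 and all(0 <= c < 16 for c in codes)
    return bytes(lo | (hi << 4) for lo, hi in zip(codes[0::2], codes[1::2]))


def unpack_int4(packed: bytes) -> list[int]:
    """Inverse of pack_int4."""
    codes: list[int] = []
    for byte in packed:
        codes.extend((byte & 0xF, byte >> 4))
    return codes


head = [i % 16 for i in range(128)]  # one 128-element head of 4-bit codes
packed = pack_int4(head)
print(len(packed), packed_head_size(128, 2))  # 64 64
print(unpack_int4(packed) == head)  # True
print(packed_head_size(128, 4))  # 32 (an INT2-style factor of 4)
```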

reshape_and_cache abstractmethod

reshape_and_cache(
    key: Tensor,
    value: Tensor,
    key_cache: Tensor,
    value_cache: Tensor,
    slot_mapping: Tensor,
    *,
    k_scale_cache: Tensor | None = None,
    v_scale_cache: Tensor | None = None,
) -> None

Write key/value into the paged cache for this mode.

Per-token-head modes also write into k_scale_cache / v_scale_cache.
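As an illustration of a per-token-head write, the sketch below quantizes each (token, head) vector with a symmetric absmax-derived INT8 scale. It then scatters the codes and scale through slot_mapping into a paged cache. It uses nested lists instead of tensors and rests on several assumptions: scale = absmax / 127, block = slot // block_size with offset = slot % block_size, and a negative slot marks a token to skip. The function names are hypothetical.

```python
def quantize_absmax_int8(vec: list[float]) -> tuple[list[int], float]:
    """Symmetric INT8 (assumed): scale = absmax / 127, codes clamped to [-127, 127]."""
    absmax = max(abs(x) for x in vec)
    scale = absmax / 127 if absmax > 0 else 1.0  # avoid dividing by zero
    codes = [max(-127, min(127, round(x / scale))) for x in vec]
    return codes, scale


def reshape_and_cache_int8(key, key_cache, k_scale_cache, slot_mapping, block_size):
    """Scatter key[token][head] into key_cache[block][offset][head].

    The V side would be identical with value / value_cache / v_scale_cache.
    """
    for token, slot in enumerate(slot_mapping):
        if slot < 0:  # assumed: a negative slot marks a token that is not cached
            continue
        block, offset = divmod(slot, block_size)
        for head, vec in enumerate(key[token]):
            codes, scale = quantize_absmax_int8(vec)
            key_cache[block][offset][head] = codes
            k_scale_cache[block][offset][head] = scale


num_blocks, block_size, num_kv_heads = 2, 4, 1
key_cache = [[[None] * num_kv_heads for _ in range(block_size)] for _ in range(num_blocks)]
k_scale_cache = [[[0.0] * num_kv_heads for _ in range(block_size)] for _ in range(num_blocks)]
key = [[[0.5, -1.0, 0.25]], [[9.0, 9.0, 9.0]], [[2.0, 0.0, -0.5]]]  # [token][head][dim]
reshape_and_cache_int8(key, key_cache, k_scale_cache, [5, -1, 2], block_size)
print(key_cache[1][1][0], k_scale_cache[1][1][0])  # token 0 -> block 1, offset 1
```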

unified_attention

unified_attention(
    q: Tensor,
    k_cache: Tensor,
    v_cache: Tensor,
    out: Tensor,
    *,
    cu_seqlens_q: Tensor,
    max_seqlen_q: int,
    seqused_k: Tensor,
    max_seqlen_k: int,
    softmax_scale: float,
    window_size: tuple[int, int],
    block_table: Tensor,
    softcap: float,
    sinks: Tensor | None,
    alibi_slopes: Tensor | None,
    use_alibi_sqrt: bool,
    qq_bias: Tensor | None,
    output_scale: Tensor | None,
    mm_prefix_range: Tensor | None,
    k_scale_cache: Tensor | None = None,
    v_scale_cache: Tensor | None = None,
    seq_threshold_3D: int | None = None,
    num_par_softmax_segments: int | None = None,
    softmax_segm_output: Tensor | None = None,
    softmax_segm_max: Tensor | None = None,
    softmax_segm_expsum: Tensor | None = None,
) -> None

Run paged attention with this mode's KV layout, writing into out.
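Consider asymmetric INT4 with two codes per byte. One plausible reading of the "split-dot" named in the class comments is the following identity. If each key element dequantizes as k_j = s(n_j - z), then q·k = s(Σ_j q_j n_j - z Σ_j q_j). The first sum splits into a dot of the even q elements with the low nibbles plus a dot of the odd q elements with the high nibbles. The sketch checks that identity against a dequantize-then-dot reference. The nibble order and function names are assumptions, and this page does not confirm that the vLLM kernel computes it this way.

```python
import math
import random


def qk_reference(q, packed, scale, zero_point):
    """Unpack (low nibble first), dequantize k_j = scale * (n_j - z), then dot."""
    k = []
    for byte in packed:
        for n in (byte & 0xF, byte >> 4):
            k.append(scale * (n - zero_point))
    return sum(qi * ki for qi, ki in zip(q, k))


def qk_split_dot(q, packed, scale, zero_point):
    """Same score without materializing k: two half-length dots on the raw codes."""
    lo = sum(q_even * (b & 0xF) for q_even, b in zip(q[0::2], packed))
    hi = sum(q_odd * (b >> 4) for q_odd, b in zip(q[1::2], packed))
    # q.k = scale * (sum_j q_j n_j - z * sum_j q_j)
    return scale * (lo + hi - zero_point * sum(q))


rng = random.Random(0)
head_size = 64
q = [rng.uniform(-1, 1) for _ in range(head_size)]
packed = bytes(rng.randrange(256) for _ in range(head_size // 2))
ref = qk_reference(q, packed, 0.02, 7)
fast = qk_split_dot(q, packed, 0.02, 7)
print(math.isclose(ref, fast, rel_tol=1e-9, abs_tol=1e-9))  # True
```

Hoisting s and z out of the sum lets the inner loop work on raw integer codes and apply the per-(token, head) scale and zero-point once per score.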

get_backend

get_backend(mode: KVQuantMode) -> QuantKVBackend

Lazy-import and return the backend for mode.

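Raises ValueError for KVQuantMode.NONE (the unquantized path) or when no backend module is configured for mode, and RuntimeError when the configured module imports without registering a backend.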

Source code in vllm/v1/attention/ops/triton_quant_kv/__init__.py
def get_backend(mode: KVQuantMode) -> QuantKVBackend:
    """Lazy-import and return the backend for *mode*.

    Raises ``ValueError`` if no backend module is configured for *mode*.
    """
    if mode == KVQuantMode.NONE:
        raise ValueError("KVQuantMode.NONE is the unquantized path and has no backend")
    if mode not in _REGISTRY:
        module_path = _MODULES.get(mode)
        if module_path is None:
            raise ValueError(
                f"No QuantKVBackend module configured for {mode.name}.  "
                f"Add an entry to _MODULES in "
                f"vllm/v1/attention/ops/triton_quant_kv/__init__.py."
            )
        importlib.import_module(module_path)
        if mode not in _REGISTRY:
            raise RuntimeError(
                f"Module {module_path} did not register a backend for "
                f"{mode.name}.  Each backend module must call "
                f"`triton_quant_kv.register(MyBackend())` at the bottom."
            )
    return _REGISTRY[mode]

has_backend

has_backend(mode: KVQuantMode) -> bool

Return True if mode has a backend (loaded or lazily available).

Source code in vllm/v1/attention/ops/triton_quant_kv/__init__.py
def has_backend(mode: KVQuantMode) -> bool:
    """Return True if *mode* has a backend (loaded or lazily available)."""
    return mode in _REGISTRY or mode in _MODULES
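The lazy lookup performed by get_backend and has_backend can be reproduced end to end with a throwaway module. In this sketch, demo_quant_kv stands in for the triton_quant_kv package and demo_backend for a backend module; both are hypothetical. The backend file is written to a temporary directory and is imported only on the first get_backend call, where its module-level register call fills the registry.

```python
import enum
import importlib
import sys
import tempfile
import textwrap
import types
from pathlib import Path


class KVQuantMode(enum.Enum):
    """Stand-in for vllm's KVQuantMode."""
    NONE = 0
    DEMO = 1  # hypothetical quantized mode


# Stand-in package holding the registry; published in sys.modules so the
# backend module can `import demo_quant_kv` just like a real package.
pkg = types.ModuleType("demo_quant_kv")
pkg.KVQuantMode = KVQuantMode
pkg._REGISTRY = {}
pkg._MODULES = {KVQuantMode.DEMO: "demo_backend"}


def register(backend) -> None:
    pkg._REGISTRY[backend.mode] = backend


pkg.register = register
sys.modules["demo_quant_kv"] = pkg

# The backend module registers itself at import time, at the bottom of the file.
tmp_dir = Path(tempfile.mkdtemp())
(tmp_dir / "demo_backend.py").write_text(textwrap.dedent("""
    import demo_quant_kv

    class DemoBackend:
        mode = demo_quant_kv.KVQuantMode.DEMO

    demo_quant_kv.register(DemoBackend())
"""))
sys.path.insert(0, str(tmp_dir))
importlib.invalidate_caches()  # make the freshly written file visible


def get_backend(mode: KVQuantMode):
    """Same control flow as get_backend above, against the stand-in package."""
    if mode == KVQuantMode.NONE:
        raise ValueError("KVQuantMode.NONE is the unquantized path and has no backend")
    if mode not in pkg._REGISTRY:
        module_path = pkg._MODULES.get(mode)
        if module_path is None:
            raise ValueError(f"No backend module configured for {mode.name}")
        importlib.import_module(module_path)  # side effect: register(...)
        if mode not in pkg._REGISTRY:
            raise RuntimeError(f"{module_path} did not register {mode.name}")
    return pkg._REGISTRY[mode]


def has_backend(mode: KVQuantMode) -> bool:
    return mode in pkg._REGISTRY or mode in pkg._MODULES


print(has_backend(KVQuantMode.DEMO), "demo_backend" in sys.modules)  # True False
backend = get_backend(KVQuantMode.DEMO)  # first use triggers the import
print("demo_backend" in sys.modules, get_backend(KVQuantMode.DEMO) is backend)  # True True
```

has_backend answers from _MODULES alone, so it can return True for a mode whose module has never been imported.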