
vllm.v1.attention.ops.triton_quant_kv.packed_per_token_head

Sub-byte per-token-head KV cache quantization backends (INT4 + INT2).

Both modes share the same skeleton — per-(token, head) dynamic scale + Hadamard pre-rotation on the inputs and inverse Hadamard on the output — but differ in their quantization math and packing factor:

+----------+------------+---------------------+----------------------+
| Mode     | Packing    | Pre-rotation        | Scale encodes        |
+==========+============+=====================+======================+
| INT4     | 2 / byte   | Single RHT          | scale + 4-bit zp     |
|          |            | (random Hadamard)   | (stego in mantissa)  |
+----------+------------+---------------------+----------------------+
| INT2     | 4 / byte   | Full Hadamard       | norm / d^1.5         |
|          |            | (no random sign)    | (centroid lookup)    |
+----------+------------+---------------------+----------------------+

Attention kernel

A single _attn_packed kernel handles both modes by branching on PACKING_FACTOR: tl.constexpr (2 → INT4, 4 → INT2). Three parts are mode-specific and gated by a constexpr if on PACKING_FACTOR: how many Q/KV streams are split, how KV is dequantized (nibble unpack vs Lloyd-Max centroid lookup), and how the score is corrected (asymmetric zero-point subtraction vs plain). Everything else (prologue, masking, online softmax, tile loop, 2D/3D epilogue) is shared.

Both backends register on module import; the attention and reshape kernels are Triton-compiled lazily with the PACKING_FACTOR / quant-specific constexprs of each mode.

Int2PerTokenHeadBackend

Bases: _PackedBackend

KV cache backend for KVQuantMode.INT2_PER_TOKEN_HEAD.

Source code in vllm/v1/attention/ops/triton_quant_kv/packed_per_token_head.py
class Int2PerTokenHeadBackend(_PackedBackend):
    """KV cache backend for ``KVQuantMode.INT2_PER_TOKEN_HEAD``."""

    mode = KVQuantMode.INT2_PER_TOKEN_HEAD
    packing_factor = 4  # 4 × int2 per byte
    _reshape_kernel = _reshape_cache_int2_kernel

    # Full Hadamard (no random sign).  Its own inverse — so the output
    # rotation is identical.  No softmax_scale adjustment: the ``d^1.5``
    # factor is absorbed into the stored scale at write time.
    @staticmethod
    def _rotate_kv(x: torch.Tensor) -> torch.Tensor:
        return fast_hadamard_transform(x)

    @staticmethod
    def _rotate_q(q: torch.Tensor) -> torch.Tensor:
        return fast_hadamard_transform(q)

    @staticmethod
    def _unrotate_out(out: torch.Tensor, head_size: int) -> torch.Tensor:
        return fast_hadamard_transform(out.float())
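The scale bookkeeping described in the comment above can be checked with a small pure-Python sketch. It assumes that `fast_hadamard_transform` is the unnormalized Walsh-Hadamard transform (so applying it twice multiplies by `d`) and that dequantization multiplies each centroid by the stored scale; neither is stated explicitly on this page. `fwht` and `int2_ideal_dequant` are hypothetical helpers, and the quantizer is replaced by a lossless stand-in so that only the scaling is visible.

```python
import math


def fwht(x):
    """Unnormalized fast Walsh-Hadamard transform (length must be a power of two)."""
    x = list(x)
    h = 1
    while h < len(x):
        for i in range(0, len(x), 2 * h):
            for j in range(i, i + h):
                a, b = x[j], x[j + h]
                x[j], x[j + h] = a + b, a - b
        h *= 2
    return x


def int2_ideal_dequant(rotated):
    """Mimic the INT2 write path with a lossless quantizer (assumption)."""
    d = len(rotated)
    norm = math.sqrt(sum(v * v for v in rotated) + 1e-12)
    z = [v * math.sqrt(d) / norm for v in rotated]  # unit-variance values
    scale = norm / d**1.5                           # what the scale cache stores
    return [zi * scale for zi in z]                 # equals rotated / d


q = [0.5, -1.0, 2.0, 0.25]
k = [1.0, 3.0, -0.5, 0.75]
v = [0.1, 0.2, -0.3, 0.4]

# Score in the rotated domain matches the plain dot product q . k.
k_hat = int2_ideal_dequant(fwht(k))
score = sum(a * b for a, b in zip(fwht(q), k_hat))
plain = sum(a * b for a, b in zip(q, k))

# A second Hadamard on the dequantized V recovers v with no extra division.
v_back = fwht(int2_ideal_dequant(fwht(v)))
```

Under these assumptions the `1 / d` left over from the stored scale cancels the factor `d` produced by the two rotations, which is why INT2 needs no `_transform_softmax_scale` override.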

Int4PerTokenHeadBackend

Bases: _PackedBackend

KV cache backend for KVQuantMode.INT4_PER_TOKEN_HEAD.

Source code in vllm/v1/attention/ops/triton_quant_kv/packed_per_token_head.py
class Int4PerTokenHeadBackend(_PackedBackend):
    """KV cache backend for ``KVQuantMode.INT4_PER_TOKEN_HEAD``."""

    mode = KVQuantMode.INT4_PER_TOKEN_HEAD
    packing_factor = 2  # 2 × int4 per byte
    _reshape_kernel = _reshape_cache_int4_kernel

    # RHT pre-rotation gaussianizes data → better quantization.  The
    # forward RHT has norm ``sqrt(head_size)``, so ``softmax_scale`` is
    # divided by ``head_size`` and the inverse RHT divides the output
    # by ``head_size`` as well.
    @staticmethod
    def _rotate_kv(x: torch.Tensor) -> torch.Tensor:
        return single_rht(x)

    @staticmethod
    def _rotate_q(q: torch.Tensor) -> torch.Tensor:
        return single_rht(q)

    @staticmethod
    def _unrotate_out(out: torch.Tensor, head_size: int) -> torch.Tensor:
        return single_rht(out.float(), inverse=True) / head_size

    @staticmethod
    def _transform_softmax_scale(scale: float, head_size: int) -> float:
        return scale / head_size
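The two divisions by `head_size` can be illustrated without Triton or PyTorch. The sketch below assumes that `single_rht` applies a fixed random sign flip followed by an unnormalized Hadamard transform, and that `inverse=True` applies the same two steps in reverse order; the real implementation is not shown on this page, so `rht`, `fwht`, and `SIGNS` are hypothetical stand-ins.

```python
import math


def fwht(x):
    """Unnormalized fast Walsh-Hadamard transform (length must be a power of two)."""
    x = list(x)
    h = 1
    while h < len(x):
        for i in range(0, len(x), 2 * h):
            for j in range(i, i + h):
                a, b = x[j], x[j + h]
                x[j], x[j + h] = a + b, a - b
        h *= 2
    return x


SIGNS = [1.0, -1.0, -1.0, 1.0]  # hypothetical fixed random sign vector


def rht(x, inverse=False):
    """Assumed shape of a randomized Hadamard transform: H @ diag(SIGNS) @ x."""
    if inverse:
        return [s * v for s, v in zip(SIGNS, fwht(x))]
    return fwht([s * v for s, v in zip(SIGNS, x)])


q = [0.5, -1.0, 2.0, 0.25]
k = [1.0, 3.0, -0.5, 0.75]
v = [0.1, 0.2, -0.3, 0.4]
d = len(q)
softmax_scale = 1.0 / math.sqrt(d)

# Rotating both q and k multiplies the dot product by d, so the scale is divided by d.
rot_score = sum(a * b for a, b in zip(rht(q), rht(k))) * (softmax_scale / d)
ref_score = sum(a * b for a, b in zip(q, k)) * softmax_scale

# Forward plus unnormalized inverse multiplies by d, so the output is divided by d.
out = [x / d for x in rht(rht(v), inverse=True)]
```

With these stand-ins, `rot_score` agrees with `ref_score` and `out` reproduces `v`, matching the roles of `_transform_softmax_scale` and `_unrotate_out`.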

_PackedBackend

Bases: QuantKVBackend

Shared Backend for sub-byte packed per-token-head modes.

Subclasses declare the mode-specific pieces as class attributes / static methods; the reshape_and_cache / unified_attention bodies are identical and live here.

Mode-specific hooks (must be set/overridden by subclasses)

_reshape_kernel
    The @triton.jit reshape kernel for this mode.
_rotate_kv(x)
    Pre-rotation applied to K / V before packing (RHT for INT4, full Hadamard for INT2).
_rotate_q(q)
    Pre-rotation applied to Q before attention. Typically the same rotation as _rotate_kv so the dot product is preserved.
_unrotate_out(out, head_size)
    Inverse rotation on the kernel output, written back in-place.
_transform_softmax_scale(scale, head_size)
    Optional rescaling of softmax_scale before the kernel (INT4 divides by head_size to absorb the RHT scale; INT2 is a no-op).

Source code in vllm/v1/attention/ops/triton_quant_kv/packed_per_token_head.py
class _PackedBackend(QuantKVBackend):
    """Shared Backend for sub-byte packed per-token-head modes.

    Subclasses declare the mode-specific pieces as class attributes /
    classmethods; the ``reshape_and_cache`` / ``unified_attention``
    bodies are identical and live here.

    Mode-specific hooks (must be set/overridden by subclasses)
    ---------------------------------------------------------
    ``_reshape_kernel``
        The ``@triton.jit`` reshape kernel for this mode.
    ``_rotate_kv(x)``
        Pre-rotation applied to K / V before packing (RHT for INT4,
        full Hadamard for INT2).
    ``_rotate_q(q)``
        Pre-rotation applied to Q before attention.  Typically the same
        rotation as ``_rotate_kv`` so the dot product is preserved.
    ``_unrotate_out(out, head_size)``
        Inverse rotation on the kernel output, written back in-place.
    ``_transform_softmax_scale(scale, head_size)``
        Optional rescaling of ``softmax_scale`` before the kernel (INT4
        divides by ``head_size`` to absorb the RHT scale; INT2 is a
        no-op).
    """

    needs_scale_caches = True

    # Filled in by subclasses.
    _reshape_kernel: object

    @staticmethod
    def _rotate_kv(x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    @staticmethod
    def _rotate_q(q: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    @staticmethod
    def _unrotate_out(out: torch.Tensor, head_size: int) -> torch.Tensor:
        raise NotImplementedError

    @staticmethod
    def _transform_softmax_scale(scale: float, head_size: int) -> float:
        return scale

    def reshape_and_cache(
        self,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        *,
        k_scale_cache: torch.Tensor | None = None,
        v_scale_cache: torch.Tensor | None = None,
    ) -> None:
        assert k_scale_cache is not None and v_scale_cache is not None, (
            f"{self.mode.name} requires k_scale_cache / v_scale_cache"
        )
        key = self._rotate_kv(key.float()).to(key.dtype)
        value = self._rotate_kv(value.float()).to(value.dtype)
        _run_reshape_kernel(
            self._reshape_kernel,
            key=key,
            value=value,
            key_cache=key_cache,
            value_cache=value_cache,
            k_scale_cache=k_scale_cache,
            v_scale_cache=v_scale_cache,
            slot_mapping=slot_mapping,
            packing_factor=self.packing_factor,
        )

    def unified_attention(
        self,
        q: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        out: torch.Tensor,
        *,
        cu_seqlens_q: torch.Tensor,
        max_seqlen_q: int,
        seqused_k: torch.Tensor,
        max_seqlen_k: int,
        softmax_scale: float,
        window_size: tuple[int, int],
        block_table: torch.Tensor,
        softcap: float,
        sinks: torch.Tensor | None,
        alibi_slopes: torch.Tensor | None,
        use_alibi_sqrt: bool,
        qq_bias: torch.Tensor | None,
        output_scale: torch.Tensor | None,
        mm_prefix_range: torch.Tensor | None,
        k_scale_cache: torch.Tensor | None = None,
        v_scale_cache: torch.Tensor | None = None,
        seq_threshold_3D: int | None = None,
        num_par_softmax_segments: int | None = None,
        softmax_segm_output: torch.Tensor | None = None,
        softmax_segm_max: torch.Tensor | None = None,
        softmax_segm_expsum: torch.Tensor | None = None,
    ) -> None:
        assert k_scale_cache is not None and v_scale_cache is not None

        q_orig_dtype = q.dtype
        q = self._rotate_q(q.float()).to(q_orig_dtype)
        head_size = q.shape[2]
        softmax_scale = self._transform_softmax_scale(softmax_scale, head_size)

        _launch_packed_attn(
            q=q,
            k_cache=k_cache,
            v_cache=v_cache,
            out=out,
            cu_seqlens_q=cu_seqlens_q,
            max_seqlen_q=max_seqlen_q,
            seqused_k=seqused_k,
            softmax_scale=softmax_scale,
            window_size=window_size,
            block_table=block_table,
            softcap=softcap,
            sinks=sinks,
            alibi_slopes=alibi_slopes,
            use_alibi_sqrt=use_alibi_sqrt,
            qq_bias=qq_bias,
            output_scale=output_scale,
            mm_prefix_range=mm_prefix_range,
            k_scale_cache=k_scale_cache,
            v_scale_cache=v_scale_cache,
            seq_threshold_3D=seq_threshold_3D,
            num_par_softmax_segments=num_par_softmax_segments,
            softmax_segm_output=softmax_segm_output,
            softmax_segm_max=softmax_segm_max,
            softmax_segm_expsum=softmax_segm_expsum,
            packing_factor=self.packing_factor,
        )

        out_f = self._unrotate_out(out, head_size)
        out.copy_(out_f.to(q_orig_dtype))

_launch_packed_attn

_launch_packed_attn(
    *,
    q,
    k_cache,
    v_cache,
    out,
    cu_seqlens_q,
    max_seqlen_q,
    seqused_k,
    softmax_scale,
    window_size,
    block_table,
    softcap,
    sinks,
    alibi_slopes,
    use_alibi_sqrt,
    qq_bias,
    output_scale,
    mm_prefix_range,
    k_scale_cache,
    v_scale_cache,
    seq_threshold_3D,
    num_par_softmax_segments,
    softmax_segm_output,
    softmax_segm_max,
    softmax_segm_expsum,
    packing_factor: int,
)

Launch _attn_packed for one of the sub-byte modes.

Handles 2D-vs-3D dispatch, placeholder pointers for the unused side of that split, and the trailing reduce_segments pass. Writes into out (directly for 2D; via the segm buffers for 3D).

Source code in vllm/v1/attention/ops/triton_quant_kv/packed_per_token_head.py
def _launch_packed_attn(
    *,
    q,
    k_cache,
    v_cache,
    out,
    cu_seqlens_q,
    max_seqlen_q,
    seqused_k,
    softmax_scale,
    window_size,
    block_table,
    softcap,
    sinks,
    alibi_slopes,
    use_alibi_sqrt,
    qq_bias,
    output_scale,
    mm_prefix_range,
    k_scale_cache,
    v_scale_cache,
    seq_threshold_3D,
    num_par_softmax_segments,
    softmax_segm_output,
    softmax_segm_max,
    softmax_segm_expsum,
    packing_factor: int,
):
    """Launch ``_attn_packed`` for one of the sub-byte modes.

    Handles 2D-vs-3D dispatch, placeholder pointers for the unused side
    of that split, and the trailing ``reduce_segments`` pass.  Writes
    into ``out`` (directly for 2D; via the segm buffers for 3D).
    """
    import vllm.envs as envs
    from vllm.v1.attention.ops.triton_unified_attention import _get_tile_size

    is_batch_invariant = envs.VLLM_BATCH_INVARIANT

    use_mm_prefix = False
    max_mm_ranges = 0
    if mm_prefix_range is not None:
        assert mm_prefix_range.ndim == 3, (
            f"Unsupported mm_prefix_range shape: {mm_prefix_range.shape}"
        )
        use_mm_prefix = True
        max_mm_ranges = mm_prefix_range.shape[1]

    block_size = v_cache.shape[1]
    num_seqs = len(seqused_k)
    num_query_heads = q.shape[1]
    num_kv_heads = k_cache.shape[2]
    num_queries_per_kv = num_query_heads // num_kv_heads
    head_size = q.shape[2]

    BLOCK_M = (
        16 if num_queries_per_kv <= 16 else triton.next_power_of_2(num_queries_per_kv)
    )
    BLOCK_Q = BLOCK_M // num_queries_per_kv
    total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs
    sliding_window_val = 1 + window_size[0] if window_size[0] >= 0 else 0
    TILE_SIZE_PREFILL = _get_tile_size(
        head_size, sliding_window_val, q.element_size(), is_prefill=True
    )
    TILE_SIZE_DECODE = _get_tile_size(
        head_size, sliding_window_val, q.element_size(), is_prefill=False
    )

    use_3d = not (
        seq_threshold_3D is None
        or num_par_softmax_segments is None
        or softmax_segm_output is None
        or softmax_segm_max is None
        or softmax_segm_expsum is None
        or max_seqlen_q > 1
        or num_seqs > seq_threshold_3D
        or is_batch_invariant
    )

    # 3D never reads ``output_ptr`` and 2D never reads the segm tensors,
    # but Triton needs a non-null pointer everywhere; reuse ``out`` as
    # the placeholder for the unused side.
    segm_output_ptr = softmax_segm_output if use_3d else out
    segm_max_ptr = softmax_segm_max if use_3d else out
    segm_expsum_ptr = softmax_segm_expsum if use_3d else out
    num_segments = num_par_softmax_segments if use_3d else 1

    grid: tuple[Any, ...]
    if use_3d:
        grid = (total_num_q_blocks, num_kv_heads, num_par_softmax_segments)
        tile_size = TILE_SIZE_DECODE
    else:
        grid = (total_num_q_blocks, num_kv_heads)
        tile_size = TILE_SIZE_PREFILL

    _attn_packed[grid](
        output_ptr=out,
        segm_output_ptr=segm_output_ptr,
        segm_max_ptr=segm_max_ptr,
        segm_expsum_ptr=segm_expsum_ptr,
        query_ptr=q,
        key_cache_ptr=k_cache,
        value_cache_ptr=v_cache,
        sink_ptr=sinks,
        block_tables_ptr=block_table,
        seq_lens_ptr=seqused_k,
        alibi_slopes_ptr=alibi_slopes,
        qq_bias_ptr=qq_bias,
        scale=softmax_scale,
        out_scale=1 / output_scale if output_scale is not None else 1.0,
        softcap=softcap,
        k_scale_cache_ptr=k_scale_cache,
        v_scale_cache_ptr=v_scale_cache,
        num_query_heads=num_query_heads,
        num_queries_per_kv=num_queries_per_kv,
        block_table_stride=block_table.stride(0),
        query_stride_0=q.stride(0),
        query_stride_1=q.stride(1),
        output_stride_0=out.stride(0),
        output_stride_1=out.stride(1),
        qq_bias_stride_0=qq_bias.stride(0) if qq_bias is not None else 0,
        BLOCK_SIZE=block_size,
        TILE_SIZE=tile_size,
        HEAD_SIZE=head_size,
        HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
        PACKED_HEAD_PADDED=triton.next_power_of_2(head_size) // packing_factor,
        USE_ALIBI_SLOPES=alibi_slopes is not None,
        USE_ALIBI_SQRT=use_alibi_sqrt,
        USE_QQ_BIAS=qq_bias is not None,
        USE_SOFTCAP=(softcap > 0),
        USE_SINKS=(sinks is not None),
        SLIDING_WINDOW=(1 + window_size[0]),
        USE_MM_PREFIX=use_mm_prefix,
        MAX_MM_RANGES=max_mm_ranges,
        mm_prefix_range_ptr=mm_prefix_range,
        stride_k_cache_0=k_cache.stride(0),
        stride_k_cache_1=k_cache.stride(1),
        stride_k_cache_2=k_cache.stride(2),
        stride_k_cache_3=k_cache.stride(3),
        stride_v_cache_0=v_cache.stride(0),
        stride_v_cache_1=v_cache.stride(1),
        stride_v_cache_2=v_cache.stride(2),
        stride_v_cache_3=v_cache.stride(3),
        stride_ks_blk=k_scale_cache.stride(0),
        stride_ks_slot=k_scale_cache.stride(1),
        stride_ks_head=k_scale_cache.stride(2),
        stride_vs_blk=v_scale_cache.stride(0),
        stride_vs_slot=v_scale_cache.stride(1),
        stride_vs_head=v_scale_cache.stride(2),
        query_start_len_ptr=cu_seqlens_q,
        BLOCK_Q=BLOCK_Q,
        num_seqs=num_seqs,
        BLOCK_M=BLOCK_M,
        NUM_SEGMENTS_PER_SEQ=num_segments,
        USE_FP8=output_scale is not None,
        IS_3D=use_3d,
        PACKING_FACTOR=packing_factor,
    )

    if use_3d:
        reduce_segments[(q.shape[0], num_query_heads)](
            output_ptr=out,
            segm_output_ptr=softmax_segm_output,
            segm_max_ptr=softmax_segm_max,
            segm_expsum_ptr=softmax_segm_expsum,
            seq_lens_ptr=seqused_k,
            num_seqs=num_seqs,
            num_query_heads=num_query_heads,
            out_scale_inv=1 / output_scale if output_scale is not None else 1.0,
            output_stride_0=out.stride(0),
            output_stride_1=out.stride(1),
            block_table_stride=block_table.stride(0),
            TILE_SIZE=TILE_SIZE_DECODE,
            HEAD_SIZE=head_size,
            HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
            query_start_len_ptr=cu_seqlens_q,
            BLOCK_Q=BLOCK_Q,
            NUM_SEGMENTS_PER_SEQ=num_par_softmax_segments,
            USE_FP8=output_scale is not None,
        )
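The dispatch arithmetic in the launcher can be followed more easily outside of the kernel call. The sketch below restates the `BLOCK_M` / `BLOCK_Q` computation and the `use_3d` predicate in plain Python. `plan_launch` and `have_segm_buffers` are hypothetical names introduced for illustration (the latter stands for the three `softmax_segm_*` tensors all being provided), and tile-size selection is left out.

```python
def next_power_of_2(n):
    """Smallest power of two >= n, for n >= 1."""
    return 1 << (n - 1).bit_length()


def plan_launch(num_tokens, num_seqs, num_query_heads, num_kv_heads,
                max_seqlen_q, seq_threshold_3D, num_segments,
                have_segm_buffers, batch_invariant=False):
    num_queries_per_kv = num_query_heads // num_kv_heads
    block_m = (16 if num_queries_per_kv <= 16
               else next_power_of_2(num_queries_per_kv))
    block_q = block_m // num_queries_per_kv
    total_q_blocks = num_tokens // block_q + num_seqs

    # 3D (segmented softmax) is only used for small pure-decode batches.
    use_3d = (
        seq_threshold_3D is not None
        and num_segments is not None
        and have_segm_buffers
        and max_seqlen_q <= 1
        and num_seqs <= seq_threshold_3D
        and not batch_invariant
    )
    if use_3d:
        grid = (total_q_blocks, num_kv_heads, num_segments)
    else:
        grid = (total_q_blocks, num_kv_heads)
    return use_3d, block_m, block_q, grid


# Small decode batch: one query token per sequence, segment buffers available.
decode = plan_launch(4, 4, 32, 8, max_seqlen_q=1, seq_threshold_3D=8,
                     num_segments=16, have_segm_buffers=True)

# Prefill: multi-token queries force the 2D path.
prefill = plan_launch(512, 2, 32, 8, max_seqlen_q=256, seq_threshold_3D=8,
                      num_segments=16, have_segm_buffers=True)
```

In this sketch the decode case selects the 3D grid and the prefill case falls back to 2D, mirroring the negated `or` chain in the source.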

_lloyd_max_dequant_4

_lloyd_max_dequant_4(idx)

Look up INT2 Lloyd-Max centroid for N(0,1). idx in [0..3].

Source code in vllm/v1/attention/ops/triton_quant_kv/packed_per_token_head.py
@triton.jit
def _lloyd_max_dequant_4(idx):
    """Look up INT2 Lloyd-Max centroid for N(0,1).  idx in [0..3]."""
    return tl.where(
        idx < 2,
        tl.where(idx == 0, -1.5104, -0.4528),
        tl.where(idx == 2, 0.4528, 1.5104),
    )

_lloyd_max_quantize_4

_lloyd_max_quantize_4(z)

Quantize N(0,1) values to 4 Lloyd-Max centroids (INT2).

Returns index in [0, 3]. Boundaries: [-0.9816, 0, 0.9816].

Source code in vllm/v1/attention/ops/triton_quant_kv/packed_per_token_head.py
@triton.jit
def _lloyd_max_quantize_4(z):
    """Quantize N(0,1) values to 4 Lloyd-Max centroids (INT2).

    Returns index in [0, 3].  Boundaries: [-0.9816, 0, 0.9816].
    """
    return tl.where(
        z < 0.0,
        tl.where(z < -0.9816, 0, 1).to(tl.uint8),
        tl.where(z < 0.9816, 2, 3).to(tl.uint8),
    )
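The two Lloyd-Max helpers can be mirrored in plain Python to make the relationship between the constants explicit. This is a scalar sketch for illustration, not the Triton code path: `quantize` and `dequant` are hypothetical names, while the constants are copied from the kernels above. Each non-zero decision boundary is the midpoint of two adjacent centroids.

```python
CENTROIDS = (-1.5104, -0.4528, 0.4528, 1.5104)
BOUNDARY = 0.9816


def quantize(z):
    """Map a (roughly) N(0, 1) value to a 2-bit index in [0, 3]."""
    if z < 0.0:
        return 0 if z < -BOUNDARY else 1
    return 2 if z < BOUNDARY else 3


def dequant(idx):
    """Look up the centroid for a 2-bit index."""
    return CENTROIDS[idx]


# The outer boundary sits halfway between the inner and outer centroids.
midpoint = (CENTROIDS[2] + CENTROIDS[3]) / 2

# Re-quantizing a centroid returns its own index.
roundtrip = [quantize(dequant(i)) for i in range(4)]
```

Note that `z == 0.0` falls into index 2, since the first comparison is a strict `z < 0.0`.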

_reshape_cache_int2_kernel

_reshape_cache_int2_kernel(
    key_ptr,
    value_ptr,
    key_cache_ptr,
    value_cache_ptr,
    k_scale_cache_ptr,
    v_scale_cache_ptr,
    slot_mapping_ptr,
    stride_key_tok: int64,
    stride_key_head: int64,
    stride_val_tok: int64,
    stride_val_head: int64,
    stride_kc_blk: int64,
    stride_kc_slot: int64,
    stride_kc_head: int64,
    stride_vc_blk: int64,
    stride_vc_slot: int64,
    stride_vc_head: int64,
    stride_ks_blk: int64,
    stride_ks_slot: int64,
    stride_ks_head: int64,
    stride_vs_blk: int64,
    stride_vs_slot: int64,
    stride_vs_head: int64,
    block_size: constexpr,
    head_size: constexpr,
    head_size_v: constexpr,
    PACKED_HEAD_PADDED: constexpr,
)

INT2 Hadamard + Lloyd-Max 4-centroid quantization.

Packs 4 × 2-bit indices per byte → head_size/4 bytes per head.

Source code in vllm/v1/attention/ops/triton_quant_kv/packed_per_token_head.py
@triton.jit
def _reshape_cache_int2_kernel(
    key_ptr,
    value_ptr,
    key_cache_ptr,
    value_cache_ptr,
    k_scale_cache_ptr,
    v_scale_cache_ptr,
    slot_mapping_ptr,
    stride_key_tok: tl.int64,
    stride_key_head: tl.int64,
    stride_val_tok: tl.int64,
    stride_val_head: tl.int64,
    stride_kc_blk: tl.int64,
    stride_kc_slot: tl.int64,
    stride_kc_head: tl.int64,
    stride_vc_blk: tl.int64,
    stride_vc_slot: tl.int64,
    stride_vc_head: tl.int64,
    stride_ks_blk: tl.int64,
    stride_ks_slot: tl.int64,
    stride_ks_head: tl.int64,
    stride_vs_blk: tl.int64,
    stride_vs_slot: tl.int64,
    stride_vs_head: tl.int64,
    block_size: tl.constexpr,
    head_size: tl.constexpr,
    head_size_v: tl.constexpr,
    PACKED_HEAD_PADDED: tl.constexpr,
):
    """INT2 Hadamard + Lloyd-Max 4-centroid quantization.

    Packs 4 × 2-bit indices per byte → head_size/4 bytes per head.
    """
    tok = tl.program_id(0)
    head = tl.program_id(1)

    slot = tl.load(slot_mapping_ptr + tok).to(tl.int64)
    if slot < 0:
        return

    blk = slot // block_size
    slot_in_blk = slot % block_size

    qtr_offs = tl.arange(0, PACKED_HEAD_PADDED)
    offs_0 = qtr_offs * 4
    offs_1 = qtr_offs * 4 + 1
    offs_2 = qtr_offs * 4 + 2
    offs_3 = qtr_offs * 4 + 3

    qtr_k = head_size // 4
    mask_0k = offs_0 < head_size
    mask_1k = offs_1 < head_size
    mask_2k = offs_2 < head_size
    mask_3k = offs_3 < head_size
    key_base = key_ptr + tok * stride_key_tok + head * stride_key_head

    k0 = tl.load(key_base + offs_0, mask=mask_0k, other=0.0).to(tl.float32)
    k1 = tl.load(key_base + offs_1, mask=mask_1k, other=0.0).to(tl.float32)
    k2 = tl.load(key_base + offs_2, mask=mask_2k, other=0.0).to(tl.float32)
    k3 = tl.load(key_base + offs_3, mask=mask_3k, other=0.0).to(tl.float32)

    k_sq = (
        tl.sum(tl.where(mask_0k, k0 * k0, 0.0))
        + tl.sum(tl.where(mask_1k, k1 * k1, 0.0))
        + tl.sum(tl.where(mask_2k, k2 * k2, 0.0))
        + tl.sum(tl.where(mask_3k, k3 * k3, 0.0))
    )
    k_norm = tl.sqrt(k_sq + 1e-12)

    k_inv_sigma = tl.sqrt(float(head_size)) / k_norm
    q0 = _lloyd_max_quantize_4(k0 * k_inv_sigma)
    q1 = _lloyd_max_quantize_4(k1 * k_inv_sigma)
    q2 = _lloyd_max_quantize_4(k2 * k_inv_sigma)
    q3 = _lloyd_max_quantize_4(k3 * k_inv_sigma)

    k_packed = pack_int2_quartet(q0, q1, q2, q3)
    tl.store(
        key_cache_ptr
        + blk * stride_kc_blk
        + slot_in_blk * stride_kc_slot
        + head * stride_kc_head
        + qtr_offs,
        k_packed,
        mask=qtr_offs < qtr_k,
    )

    # Store norm/d^1.5 as scale; see module docstring for the math.
    k_scale = k_norm / float(head_size**1.5)
    tl.store(
        k_scale_cache_ptr
        + blk * stride_ks_blk
        + slot_in_blk * stride_ks_slot
        + head * stride_ks_head,
        k_scale,
    )

    qtr_v = head_size_v // 4
    mask_0v = offs_0 < head_size_v
    mask_1v = offs_1 < head_size_v
    mask_2v = offs_2 < head_size_v
    mask_3v = offs_3 < head_size_v
    val_base = value_ptr + tok * stride_val_tok + head * stride_val_head

    v0 = tl.load(val_base + offs_0, mask=mask_0v, other=0.0).to(tl.float32)
    v1 = tl.load(val_base + offs_1, mask=mask_1v, other=0.0).to(tl.float32)
    v2 = tl.load(val_base + offs_2, mask=mask_2v, other=0.0).to(tl.float32)
    v3 = tl.load(val_base + offs_3, mask=mask_3v, other=0.0).to(tl.float32)

    v_sq = (
        tl.sum(tl.where(mask_0v, v0 * v0, 0.0))
        + tl.sum(tl.where(mask_1v, v1 * v1, 0.0))
        + tl.sum(tl.where(mask_2v, v2 * v2, 0.0))
        + tl.sum(tl.where(mask_3v, v3 * v3, 0.0))
    )
    v_norm = tl.sqrt(v_sq + 1e-12)
    v_inv_sigma = tl.sqrt(float(head_size_v)) / v_norm
    vq0 = _lloyd_max_quantize_4(v0 * v_inv_sigma)
    vq1 = _lloyd_max_quantize_4(v1 * v_inv_sigma)
    vq2 = _lloyd_max_quantize_4(v2 * v_inv_sigma)
    vq3 = _lloyd_max_quantize_4(v3 * v_inv_sigma)

    v_packed = pack_int2_quartet(vq0, vq1, vq2, vq3)
    tl.store(
        value_cache_ptr
        + blk * stride_vc_blk
        + slot_in_blk * stride_vc_slot
        + head * stride_vc_head
        + qtr_offs,
        v_packed,
        mask=qtr_offs < qtr_v,
    )

    v_scale = v_norm / float(head_size_v**1.5)
    tl.store(
        v_scale_cache_ptr
        + blk * stride_vs_blk
        + slot_in_blk * stride_vs_slot
        + head * stride_vs_head,
        v_scale,
    )
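As a reading aid, the per-(token, head) work of the INT2 kernel can be written as a scalar Python reference. This is a sketch under two assumptions that the page does not confirm: that `pack_int2_quartet` places the first index in the lowest two bits, and that dequantization multiplies the centroid by the stored scale. The function names are hypothetical, and the input is taken to be an already rotated head vector.

```python
import math

CENTROIDS = (-1.5104, -0.4528, 0.4528, 1.5104)
BOUNDARY = 0.9816


def lloyd_max_index(z):
    if z < 0.0:
        return 0 if z < -BOUNDARY else 1
    return 2 if z < BOUNDARY else 3


def quantize_head_int2(rotated):
    """Return (packed bytes, stored scale) for one rotated head vector."""
    d = len(rotated)
    assert d % 4 == 0
    norm = math.sqrt(sum(v * v for v in rotated) + 1e-12)
    inv_sigma = math.sqrt(d) / norm  # brings elements to unit variance
    idx = [lloyd_max_index(v * inv_sigma) for v in rotated]
    # Assumed bit order: element 4i in bits 0-1, element 4i+3 in bits 6-7.
    packed = bytes(
        idx[i] | (idx[i + 1] << 2) | (idx[i + 2] << 4) | (idx[i + 3] << 6)
        for i in range(0, d, 4)
    )
    return packed, norm / d**1.5


def dequantize_head_int2(packed, scale):
    """Assumed inverse: centroid lookup times the stored scale."""
    return [
        CENTROIDS[(byte >> shift) & 0x3] * scale
        for byte in packed
        for shift in (0, 2, 4, 6)
    ]


rotated = [2.0, -2.0, 0.5, -0.5, 2.0, 0.5, -0.5, -2.0]
packed, scale = quantize_head_int2(rotated)
restored = dequantize_head_int2(packed, scale)
```

Under these assumptions `restored` approximates `rotated / d` rather than `rotated`, because the stored scale already carries the extra `1 / d` that the Hadamard rotations would otherwise require.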

_reshape_cache_int4_kernel

_reshape_cache_int4_kernel(
    key_ptr,
    value_ptr,
    key_cache_ptr,
    value_cache_ptr,
    k_scale_cache_ptr,
    v_scale_cache_ptr,
    slot_mapping_ptr,
    stride_key_tok: int64,
    stride_key_head: int64,
    stride_val_tok: int64,
    stride_val_head: int64,
    stride_kc_blk: int64,
    stride_kc_slot: int64,
    stride_kc_head: int64,
    stride_vc_blk: int64,
    stride_vc_slot: int64,
    stride_vc_head: int64,
    stride_ks_blk: int64,
    stride_ks_slot: int64,
    stride_ks_head: int64,
    stride_vs_blk: int64,
    stride_vs_slot: int64,
    stride_vs_head: int64,
    block_size: constexpr,
    head_size: constexpr,
    head_size_v: constexpr,
    PACKED_HEAD_PADDED: constexpr,
)

INT4 asymmetric quantization with zero-point steganography.

Source code in vllm/v1/attention/ops/triton_quant_kv/packed_per_token_head.py
@triton.jit
def _reshape_cache_int4_kernel(
    key_ptr,
    value_ptr,
    key_cache_ptr,
    value_cache_ptr,
    k_scale_cache_ptr,
    v_scale_cache_ptr,
    slot_mapping_ptr,
    stride_key_tok: tl.int64,
    stride_key_head: tl.int64,
    stride_val_tok: tl.int64,
    stride_val_head: tl.int64,
    stride_kc_blk: tl.int64,
    stride_kc_slot: tl.int64,
    stride_kc_head: tl.int64,
    stride_vc_blk: tl.int64,
    stride_vc_slot: tl.int64,
    stride_vc_head: tl.int64,
    stride_ks_blk: tl.int64,
    stride_ks_slot: tl.int64,
    stride_ks_head: tl.int64,
    stride_vs_blk: tl.int64,
    stride_vs_slot: tl.int64,
    stride_vs_head: tl.int64,
    block_size: tl.constexpr,
    head_size: tl.constexpr,
    head_size_v: tl.constexpr,
    PACKED_HEAD_PADDED: tl.constexpr,
):
    """INT4 asymmetric quantization with zero-point steganography."""
    tok = tl.program_id(0)
    head = tl.program_id(1)

    slot = tl.load(slot_mapping_ptr + tok).to(tl.int64)
    if slot < 0:
        return

    blk = slot // block_size
    slot_in_blk = slot % block_size

    half_offs = tl.arange(0, PACKED_HEAD_PADDED)
    even_offs = half_offs * 2
    odd_offs = half_offs * 2 + 1

    half_k = head_size // 2
    even_k_mask = even_offs < head_size
    odd_k_mask = odd_offs < head_size
    key_base = key_ptr + tok * stride_key_tok + head * stride_key_head

    k_even = tl.load(key_base + even_offs, mask=even_k_mask, other=0.0).to(tl.float32)
    k_odd = tl.load(key_base + odd_offs, mask=odd_k_mask, other=0.0).to(tl.float32)

    k_min = tl.minimum(
        tl.min(tl.where(even_k_mask, k_even, float("inf"))),
        tl.min(tl.where(odd_k_mask, k_odd, float("inf"))),
    )
    k_max = tl.maximum(
        tl.max(tl.where(even_k_mask, k_even, float("-inf"))),
        tl.max(tl.where(odd_k_mask, k_odd, float("-inf"))),
    )
    k_scale = tl.maximum((k_max - k_min) / 15.0, 1e-6)
    k_zp_f = tl.clamp(
        tl.where(
            -k_min / k_scale >= 0,
            (-k_min / k_scale + 0.5).to(tl.int32),
            (-k_min / k_scale - 0.5).to(tl.int32),
        ).to(tl.float32),
        0.0,
        15.0,
    )

    inv_k = 1.0 / k_scale
    k_even_s = k_even * inv_k + k_zp_f
    k_odd_s = k_odd * inv_k + k_zp_f
    k_even_q = tl.clamp(
        tl.where(
            k_even_s >= 0,
            (k_even_s + 0.5).to(tl.int32),
            (k_even_s - 0.5).to(tl.int32),
        ).to(tl.float32),
        0.0,
        15.0,
    )
    k_odd_q = tl.clamp(
        tl.where(
            k_odd_s >= 0,
            (k_odd_s + 0.5).to(tl.int32),
            (k_odd_s - 0.5).to(tl.int32),
        ).to(tl.float32),
        0.0,
        15.0,
    )

    k_zp_int = k_zp_f.to(tl.int32)
    k_scale_bits = k_scale.to(tl.int32, bitcast=True)
    k_scale_packed = ((k_scale_bits & -16) | (k_zp_int & 0xF)).to(
        tl.float32, bitcast=True
    )

    tl.store(
        k_scale_cache_ptr
        + blk * stride_ks_blk
        + slot_in_blk * stride_ks_slot
        + head * stride_ks_head,
        k_scale_packed,
    )

    k_packed = pack_int4_nibbles(k_even_q.to(tl.uint8), k_odd_q.to(tl.uint8))
    tl.store(
        key_cache_ptr
        + blk * stride_kc_blk
        + slot_in_blk * stride_kc_slot
        + head * stride_kc_head
        + half_offs,
        k_packed,
        mask=half_offs < half_k,
    )

    half_v = head_size_v // 2
    even_v_mask = even_offs < head_size_v
    odd_v_mask = odd_offs < head_size_v
    val_base = value_ptr + tok * stride_val_tok + head * stride_val_head

    v_even = tl.load(val_base + even_offs, mask=even_v_mask, other=0.0).to(tl.float32)
    v_odd = tl.load(val_base + odd_offs, mask=odd_v_mask, other=0.0).to(tl.float32)

    v_min = tl.minimum(
        tl.min(tl.where(even_v_mask, v_even, float("inf"))),
        tl.min(tl.where(odd_v_mask, v_odd, float("inf"))),
    )
    v_max = tl.maximum(
        tl.max(tl.where(even_v_mask, v_even, float("-inf"))),
        tl.max(tl.where(odd_v_mask, v_odd, float("-inf"))),
    )
    v_scale = tl.maximum((v_max - v_min) / 15.0, 1e-6)
    v_zp_f = tl.clamp(
        tl.where(
            -v_min / v_scale >= 0,
            (-v_min / v_scale + 0.5).to(tl.int32),
            (-v_min / v_scale - 0.5).to(tl.int32),
        ).to(tl.float32),
        0.0,
        15.0,
    )

    inv_v = 1.0 / v_scale
    v_even_s = v_even * inv_v + v_zp_f
    v_odd_s = v_odd * inv_v + v_zp_f
    v_even_q = tl.clamp(
        tl.where(
            v_even_s >= 0,
            (v_even_s + 0.5).to(tl.int32),
            (v_even_s - 0.5).to(tl.int32),
        ).to(tl.float32),
        0.0,
        15.0,
    )
    v_odd_q = tl.clamp(
        tl.where(
            v_odd_s >= 0,
            (v_odd_s + 0.5).to(tl.int32),
            (v_odd_s - 0.5).to(tl.int32),
        ).to(tl.float32),
        0.0,
        15.0,
    )

    v_zp_int = v_zp_f.to(tl.int32)
    v_scale_bits = v_scale.to(tl.int32, bitcast=True)
    v_scale_packed = ((v_scale_bits & -16) | (v_zp_int & 0xF)).to(
        tl.float32, bitcast=True
    )

    tl.store(
        v_scale_cache_ptr
        + blk * stride_vs_blk
        + slot_in_blk * stride_vs_slot
        + head * stride_vs_head,
        v_scale_packed,
    )

    v_packed = pack_int4_nibbles(v_even_q.to(tl.uint8), v_odd_q.to(tl.uint8))
    tl.store(
        value_cache_ptr
        + blk * stride_vc_blk
        + slot_in_blk * stride_vc_slot
        + head * stride_vc_head
        + half_offs,
        v_packed,
        mask=half_offs < half_v,
    )
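The zero-point steganography step is compact in the kernel, so a scalar sketch may help. It uses Python's `struct` module to emulate the float32 bitcast: the low four mantissa bits of the scale are cleared (`& -16` in the kernel) and replaced by the 4-bit zero point. The helper names are hypothetical, the arithmetic runs in float64 rather than float32, and the dequantization formula `(code - zp) * scale` is the usual asymmetric form, assumed here because the attention kernel body is not shown.

```python
import struct


def f32_bits(x):
    """Bit pattern of x rounded to float32, as an unsigned int."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def bits_f32(b):
    return struct.unpack("<f", struct.pack("<I", b))[0]


def round_half_away(x):
    """Round to nearest, ties away from zero (the +/-0.5 then truncate trick)."""
    return int(x + 0.5) if x >= 0 else int(x - 0.5)


def clamp_u4(n):
    return min(max(n, 0), 15)


def quantize_head_int4(x):
    lo, hi = min(x), max(x)
    scale = max((hi - lo) / 15.0, 1e-6)
    zp = clamp_u4(round_half_away(-lo / scale))
    codes = [clamp_u4(round_half_away(v / scale + zp)) for v in x]
    # Hide the zero point in the 4 least significant mantissa bits.
    packed_scale = bits_f32((f32_bits(scale) & 0xFFFFFFF0) | (zp & 0xF))
    return codes, packed_scale


def unpack_scale(packed_scale):
    bits = f32_bits(packed_scale)
    return bits_f32(bits & 0xFFFFFFF0), bits & 0xF


x = [-1.0, 0.0, 0.6, 2.0]
codes, packed_scale = quantize_head_int4(x)
scale, zp = unpack_scale(packed_scale)
restored = [(c - zp) * scale for c in codes]
```

Overwriting four mantissa bits perturbs the scale by at most 15 float32 ulps, which is small next to the 4-bit quantization error itself.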

_run_reshape_kernel

_run_reshape_kernel(
    kernel,
    *,
    key: Tensor,
    value: Tensor,
    key_cache: Tensor,
    value_cache: Tensor,
    k_scale_cache: Tensor,
    v_scale_cache: Tensor,
    slot_mapping: Tensor,
    packing_factor: int,
) -> None

Launch a packed reshape kernel (INT4 or INT2).

Source code in vllm/v1/attention/ops/triton_quant_kv/packed_per_token_head.py
def _run_reshape_kernel(
    kernel,
    *,
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    k_scale_cache: torch.Tensor,
    v_scale_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    packing_factor: int,
) -> None:
    """Launch a packed reshape kernel (INT4 or INT2)."""
    num_tokens, num_kv_heads, head_size = key.shape
    head_size_v = value.shape[2]
    assert head_size % packing_factor == 0 and head_size_v % packing_factor == 0
    packed_padded = triton.next_power_of_2(
        max(head_size, head_size_v) // packing_factor
    )
    if current_platform.is_rocm() or current_platform.is_xpu():
        num_warps = 4
    else:
        num_warps = min(16, max(1, packed_padded // 32))

    kernel[(num_tokens, num_kv_heads)](
        key_ptr=key,
        value_ptr=value,
        key_cache_ptr=key_cache,
        value_cache_ptr=value_cache,
        k_scale_cache_ptr=k_scale_cache,
        v_scale_cache_ptr=v_scale_cache,
        slot_mapping_ptr=slot_mapping,
        stride_key_tok=key.stride(0),
        stride_key_head=key.stride(1),
        stride_val_tok=value.stride(0),
        stride_val_head=value.stride(1),
        stride_kc_blk=key_cache.stride(0),
        stride_kc_slot=key_cache.stride(1),
        stride_kc_head=key_cache.stride(2),
        stride_vc_blk=value_cache.stride(0),
        stride_vc_slot=value_cache.stride(1),
        stride_vc_head=value_cache.stride(2),
        stride_ks_blk=k_scale_cache.stride(0),
        stride_ks_slot=k_scale_cache.stride(1),
        stride_ks_head=k_scale_cache.stride(2),
        stride_vs_blk=v_scale_cache.stride(0),
        stride_vs_slot=v_scale_cache.stride(1),
        stride_vs_head=v_scale_cache.stride(2),
        block_size=key_cache.shape[1],
        head_size=head_size,
        head_size_v=head_size_v,
        PACKED_HEAD_PADDED=packed_padded,
        num_warps=num_warps,
    )
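The launch parameters chosen by `_run_reshape_kernel` follow from a few lines of integer arithmetic, restated below in plain Python. `reshape_launch_params` and the `fixed_warps` flag are hypothetical; the flag stands in for the ROCm / XPU branch of the source.

```python
def next_power_of_2(n):
    """Smallest power of two >= n, for n >= 1."""
    return 1 << (n - 1).bit_length()


def reshape_launch_params(head_size, head_size_v, packing_factor,
                          fixed_warps=False):
    # Both head sizes must pack into whole bytes.
    assert head_size % packing_factor == 0
    assert head_size_v % packing_factor == 0
    packed_padded = next_power_of_2(
        max(head_size, head_size_v) // packing_factor
    )
    if fixed_warps:
        num_warps = 4
    else:
        num_warps = min(16, max(1, packed_padded // 32))
    return packed_padded, num_warps


int4_params = reshape_launch_params(128, 128, packing_factor=2)
int2_params = reshape_launch_params(128, 128, packing_factor=4)
odd_head = reshape_launch_params(96, 96, packing_factor=2)
```

For a head size of 128 this yields 64 packed bytes per head for INT4 and 32 for INT2; a head size of 96 packs into 48 bytes, which is padded up to 64 lanes.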