
vllm.v1.attention.ops.triton_unified_attention

_cast_kv_tile

_cast_kv_tile(
    data, Q, tensor_scale, KV_QUANT_MODE: constexpr
)

Cast a loaded KV tile to Q's dtype, dequantizing if needed.

Modes handled inside the core kernel:

  • KV_QUANT_MODE == 0 (NONE), 2 (INT8 per-token-head), and 3 (FP8 per-token-head): plain cast. The per-token-head modes apply their scales separately to S and P inside the loop.
  • KV_QUANT_MODE == 1 (FP8 per-tensor): dequantize using the tensor-wide scale, unless Q is itself FP8, in which case the tile is cast to Q's dtype directly.

Sub-byte packed modes (INT4 / INT2) are dispatched to their own backends in vllm.v1.attention.ops.triton_quant_kv and never reach this kernel.
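Why can the per-token-head modes get away with a plain cast? With one scale per token (for a given head), the scale factors out of the dot products: scaling column j of S = Q·Kᵀ by the scale of key token j gives the same result as dequantizing K first, and scaling column j of P by the scale of value token j before P·V gives the same result as dequantizing V first. The pure-Python sketch below checks both identities on a tiny single-head example. It is an illustration under stated assumptions, not the kernel: the names (k_q, k_scale, v_q, v_scale, p) and all numbers are hypothetical, and the real kernel works on Triton tiles rather than Python lists.

```python
import math

# Hypothetical single-head example: 2 queries, 3 KV tokens, head_dim 2.
# k_q / v_q stand in for plain-cast INT8 K/V tiles; k_scale / v_scale hold
# one scale per token (per-token-head quantization for this one head).


def matmul_t(a, b):
    """Return a @ b.T for row-major lists of lists."""
    return [[sum(x * y for x, y in zip(ra, rb)) for rb in b] for ra in a]


def matmul(a, b):
    """Return a @ b for row-major lists of lists."""
    cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]


def allclose(a, b, tol=1e-9):
    """True if two equally shaped lists of lists match elementwise."""
    if len(a) != len(b) or any(len(ra) != len(rb) for ra, rb in zip(a, b)):
        return False
    return all(
        math.isclose(x, y, rel_tol=tol, abs_tol=tol)
        for ra, rb in zip(a, b)
        for x, y in zip(ra, rb)
    )


q = [[1.0, 2.0], [0.5, -1.0]]
k_q = [[3, -1], [0, 4], [2, 2]]      # quantized keys, one row per token
v_q = [[1, 0], [-2, 5], [3, 3]]      # quantized values, one row per token
k_scale = [0.5, 0.25, 2.0]           # one scale per key token
v_scale = [0.125, 1.0, 0.5]          # one scale per value token

# Reference path for S: dequantize K first, then S = Q @ K^T.
k_deq = [[x * s for x in row] for row, s in zip(k_q, k_scale)]
s_ref = matmul_t(q, k_deq)

# Deferred path for S: use the raw tile, then scale column j by k_scale[j].
s_raw = matmul_t(q, k_q)
s_deferred = [[x * s for x, s in zip(row, k_scale)] for row in s_raw]

# Same idea for V: fold v_scale[j] into column j of P before P @ V.
p = [[0.2, 0.3, 0.5], [0.6, 0.1, 0.3]]   # made-up attention probabilities
v_deq = [[x * s for x in row] for row, s in zip(v_q, v_scale)]
out_ref = matmul(p, v_deq)
p_scaled = [[x * s for x, s in zip(row, v_scale)] for row in p]
out_deferred = matmul(p_scaled, v_q)

print("S matches:", allclose(s_ref, s_deferred))
print("P @ V matches:", allclose(out_ref, out_deferred))
```

Deferring the scale this way lets the K/V tile stay a plain cast; the scaling cost moves onto S and P instead.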

Source code in vllm/v1/attention/ops/triton_unified_attention.py
@triton.jit
def _cast_kv_tile(data, Q, tensor_scale, KV_QUANT_MODE: tl.constexpr):
    """Cast a loaded KV tile to Q's dtype, dequantizing if needed.

    Modes handled inside the core kernel:

    - ``KV_QUANT_MODE == 0`` (NONE), ``2`` (INT8 per-token-head), and
      ``3`` (FP8 per-token-head): plain cast.  The per-token-head modes
      apply their scales separately to S and P inside the loop.
    - ``KV_QUANT_MODE == 1`` (FP8 per-tensor): dequantize using the
      tensor-wide scale, unless Q is itself FP8, in which case the tile
      is cast directly.

    Sub-byte packed modes (INT4 / INT2) are dispatched to their own
    backends in :mod:`vllm.v1.attention.ops.triton_quant_kv` and never
    reach this kernel.
    """
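    if KV_QUANT_MODE == 1:
        # FP8 per-tensor: when Q is itself FP8, cast the tile directly
        # without applying the scale here.
        if Q.dtype.is_fp8():
            return data.to(Q.dtype)
        # Otherwise upcast to fp32, multiply by the tensor-wide scale,
        # then cast to Q's dtype.
        return (data.to(tl.float32) * tl.load(tensor_scale)).to(Q.dtype)
    # Modes 0, 2 and 3: plain cast (per-token-head scales are applied on S/P).
    return data.to(Q.dtype)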

_get_tile_size

_get_tile_size(
    head_size: int,
    sliding_window: int,
    element_size: int,
    is_prefill: bool,
) -> int

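Select the attention tile size, with a Gemma3-specific override. Gemma3 attention (as detected by _is_gemma3_attention) and prefill both use a tile size of 32; otherwise the tile size is 16 when element_size >= 2 and 32 for smaller element sizes.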

Source code in vllm/v1/attention/ops/triton_unified_attention.py
def _get_tile_size(
    head_size: int,
    sliding_window: int,
    element_size: int,
    is_prefill: bool,
) -> int:
    """Select tile size with Gemma3-specific optimization."""
    if _is_gemma3_attention(head_size, sliding_window):
        return 32
    if is_prefill:
        return 32
    return 16 if element_size >= 2 else 32
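To see how the three branches interact, the sketch below re-implements _get_tile_size and _is_gemma3_attention as local, hypothetical copies (get_tile_size, is_gemma3_attention) and evaluates them for a few configurations. It assumes element_size is the size in bytes of one element (2 for FP16/BF16, 1 for 8-bit types), and it uses sliding_window=0 as a hypothetical placeholder for "no sliding window"; for this logic, any value other than 1024 behaves the same.

```python
def is_gemma3_attention(head_size: int, sliding_window: int) -> bool:
    # Local copy of the Gemma3 signature check, for illustration only.
    return sliding_window == 1024 and head_size in (128, 256)


def get_tile_size(head_size: int, sliding_window: int,
                  element_size: int, is_prefill: bool) -> int:
    # Same decision order as _get_tile_size: Gemma3 first, then prefill,
    # then decode split by element size (assumed to be bytes per element).
    if is_gemma3_attention(head_size, sliding_window):
        return 32
    if is_prefill:
        return 32
    return 16 if element_size >= 2 else 32


# (label, head_size, sliding_window, element_size, is_prefill); all values
# are hypothetical example configurations.
configs = [
    ("Gemma3-like decode, 2-byte elements", 256, 1024, 2, False),
    ("generic decode, 2-byte elements", 128, 0, 2, False),
    ("generic decode, 1-byte elements", 128, 0, 1, False),
    ("generic prefill, 2-byte elements", 128, 0, 2, True),
]
for label, hs, sw, es, pf in configs:
    print(f"{label}: tile_size={get_tile_size(hs, sw, es, pf)}")
```

Under these assumptions, only non-Gemma3 decode with elements of 2 bytes or more gets the 16-wide tile; every other path uses 32.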

_is_gemma3_attention

_is_gemma3_attention(
    head_size: int, sliding_window: int
) -> bool

Detect Gemma3 models via unique (head_size, sliding_window) signature.

Gemma3 models are the only ones using sliding_window=1024 with head_size 128 (27B) or 256 (1B, 4B, 12B). Other sliding-window attention (SWA) models use different window sizes (Mistral=4096, Phi-3=2047).

Source code in vllm/v1/attention/ops/triton_unified_attention.py
def _is_gemma3_attention(head_size: int, sliding_window: int) -> bool:
    """Detect Gemma3 models via unique (head_size, sliding_window) signature.

    Gemma3 models are the only ones using sliding_window=1024 with
    head_size 128 (27B) or 256 (1B, 4B, 12B). Other SWA models use
    different window sizes (Mistral=4096, Phi-3=2047).
    """
    return sliding_window == 1024 and head_size in (128, 256)
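The check keys only on the two integers, so it can be tried directly against the cases named in the docstring. The sketch below uses a local, hypothetical copy of the function; the window sizes come from the docstring, while the head sizes for the non-Gemma3 rows are illustrative assumptions.

```python
def is_gemma3_attention(head_size: int, sliding_window: int) -> bool:
    # Local copy of _is_gemma3_attention, for illustration only.
    return sliding_window == 1024 and head_size in (128, 256)


# (label, head_size, sliding_window); non-Gemma3 head sizes are assumed.
cases = [
    ("Gemma3 27B", 128, 1024),
    ("Gemma3 1B/4B/12B", 256, 1024),
    ("Mistral-style SWA", 128, 4096),
    ("Phi-3-style SWA", 96, 2047),
    ("window 1024, head_size 64", 64, 1024),
]
for label, hs, sw in cases:
    print(f"{label}: gemma3={is_gemma3_attention(hs, sw)}")
```

Because nothing beyond these two values is inspected, any other model that uses a 1024-token window with head size 128 or 256 would also be treated as Gemma3 and receive its tile size.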