mstar.utils.flashinfer_utils

FlashInfer utility wrappers for batched paged attention.

Provides:

  • run_rms_norm / run_attention: simple single-request helpers

  • FlashInferPrefillWrapper: batched prefill with paged KV cache, optional CUDA graph mode

  • FlashInferDecodeWrapper: batched decode with paged KV cache, optional CUDA graph mode

CUDA graph mode requires:

  • Static buffers (qo_indptr_buf, paged_kv_indptr_buf, etc.) passed to the underlying FlashInfer wrapper at construction

  • plan() to update their values in place via .copy_(), without reallocating

  • The same wrapper object to be used for both capture and replay
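The requirements above amount to an allocate-once, update-in-place pattern: a captured graph records buffer addresses, so later plans may change the values but never the storage. A minimal sketch in plain PyTorch follows; the class and buffer names are hypothetical and do not describe this module's internals, and the buffer sizes assume FlashInfer's CSR convention of batch_size + 1 indptr entries.

```python
import torch


class StaticPlanBuffers:
    """Hypothetical illustration of the allocate-once, copy-in-place pattern."""

    def __init__(self, batch_size: int, max_num_pages: int, device: torch.device):
        # Sized for the largest batch a captured graph will ever replay with
        # (CSR convention: batch_size + 1 indptr entries).
        self.paged_kv_indptr_buf = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
        self.paged_kv_indices_buf = torch.zeros(max_num_pages, dtype=torch.int32, device=device)
        self.paged_kv_last_page_len_buf = torch.zeros(batch_size, dtype=torch.int32, device=device)

    def plan(self, indptr: torch.Tensor, indices: torch.Tensor, last_page_len: torch.Tensor) -> None:
        # Overwrite the leading entries in place; the storage, and therefore the
        # address recorded during graph capture, never changes.
        self.paged_kv_indptr_buf[: indptr.numel()].copy_(indptr)
        self.paged_kv_indices_buf[: indices.numel()].copy_(indices)
        self.paged_kv_last_page_len_buf[: last_page_len.numel()].copy_(last_page_len)


bufs = StaticPlanBuffers(batch_size=4, max_num_pages=16, device=torch.device("cpu"))
addr = bufs.paged_kv_indices_buf.data_ptr()
bufs.plan(
    torch.tensor([0, 2, 3], dtype=torch.int32),
    torch.tensor([5, 6, 9], dtype=torch.int32),
    torch.tensor([7, 1], dtype=torch.int32),
)
assert bufs.paged_kv_indices_buf.data_ptr() == addr  # same storage as before plan()
```

Allocating a fresh tensor inside plan() would silently break replay, because the graph would keep reading the old address.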

Adapted from VoxServe’s flashinfer_utils.py for our KV cache layout:

[num_layers, max_pages, 2, page_size, num_kv_heads, head_dim]

(VoxServe uses [n_pages, 2, page_size, n_heads, head_dim] without layer dim.)
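Because the layer dimension comes first, kv_cache[layer] is a zero-copy view with VoxServe's 5-D layout, which is the kv_cache_layer shape that the run and set_kv_cache methods take. A small sketch with arbitrary toy sizes (all values are illustrative):

```python
import torch

# Toy sizes, chosen only for illustration.
num_layers, max_pages, page_size, num_kv_heads, head_dim = 2, 8, 4, 2, 16

# Full cache in this module's layout:
# [num_layers, max_pages, 2, page_size, num_kv_heads, head_dim]
kv_cache = torch.zeros(num_layers, max_pages, 2, page_size, num_kv_heads, head_dim)

# Indexing the layer dim yields a view shaped like VoxServe's cache,
# [max_pages, 2, page_size, num_kv_heads, head_dim]; no data is copied.
layer_idx = 1
kv_cache_layer = kv_cache[layer_idx]
k_pages = kv_cache_layer[:, 0]  # [max_pages, page_size, num_kv_heads, head_dim]
v_pages = kv_cache_layer[:, 1]

# Writes through the view land in the full cache (layer 1, page 3, K, slot 0).
k_pages[3, 0] = 1.0
```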

Functions

run_attention(q, k, v[, scale, causal])

run_rms_norm(input, weight[, eps, ...])

Classes

FlashInferDecodeWrapper(workspace_buffer, ...)

Batched decode attention with paged KV cache.

FlashInferPrefillWrapper(workspace_buffer, ...)

Batched prefill attention with paged KV cache.

class mstar.utils.flashinfer_utils.FlashInferDecodeWrapper(workspace_buffer, num_qo_heads, num_kv_heads, head_dim, page_size, batch_size=None, max_num_pages=None, device=torch.device('cuda'), use_cuda_graph=False, enable_nvtx=False)

Bases: object

Batched decode attention with paged KV cache.

Optimized for the common decode case, in which each request appends exactly 1 new token. Uses flashinfer.BatchDecodeWithPagedKVCacheWrapper.

Parameters:
  • workspace_buffer (Tensor) – FlashInfer workspace

  • num_qo_heads (int) – number of query/output heads

  • num_kv_heads (int) – number of key/value heads

  • head_dim (int) – dimension per head

  • page_size (int) – KV cache page size

  • batch_size (int | None) – required for CUDA graph mode (max requests in batch)

  • max_num_pages (int | None) – required for CUDA graph mode (max pages across all requests)

  • device (device) – torch device

  • use_cuda_graph (bool) – if True, pre-allocate static buffers for graph capture

  • enable_nvtx (bool)

plan(paged_kv_indptr, paged_kv_indices, paged_kv_last_page_len, kv_cache_locations=None, dtype=torch.bfloat16)

Plan decode attention and compute KV write locations.

For decode, each request appends exactly 1 token. That token is written to the request's last page, at the slot index equal to the page's token count before the append; after the append, the page holds one more token.

Inputs may be on CPU, which is preferred; see FlashInferPrefillWrapper.plan for the reason.
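The write-location rule above is plain index arithmetic. This sketch works from the number of tokens cached before the append rather than from paged_kv_last_page_len, so it does not depend on which length convention that argument uses; the function name and its inputs are hypothetical.

```python
from typing import List, Tuple


def decode_write_locations(
    page_tables: List[List[int]], seq_lens_before: List[int], page_size: int
) -> List[Tuple[int, int]]:
    """Hypothetical helper: (page_id, slot) for each request's single new token.

    page_tables[r] lists request r's physical pages in order, and
    seq_lens_before[r] is how many tokens are already cached for r. The page
    table is assumed to already contain a fresh page if the previous one is full.
    """
    locations = []
    for pages, n_cached in zip(page_tables, seq_lens_before):
        # The new token's 0-based logical position is n_cached.
        page_id = pages[n_cached // page_size]
        slot = n_cached % page_size
        locations.append((page_id, slot))
    return locations


# Request 0: 5 cached tokens in pages [3, 7] -> page 7, slot 1.
# Request 1: 4 cached tokens fill page 2 exactly -> fresh page 9, slot 0.
print(decode_write_locations([[3, 7], [2, 9]], [5, 4], page_size=4))  # [(7, 1), (9, 0)]
```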

Parameters:
run(q, kv_cache_layer)

Run planned batched decode attention.

Parameters:
  • q (Tensor) – [n_req, num_qo_heads, head_dim]

  • kv_cache_layer (Tensor) – [max_pages, 2, page_size, num_kv_heads, head_dim]

Returns:

output – [n_req, num_qo_heads, head_dim]

Return type:

Tensor
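To make the semantics of run concrete, here is a slow reference in plain PyTorch that gathers each request's pages and applies grouped-query softmax attention for the single decode query. It assumes a softmax scale of 1/sqrt(head_dim) and that query head h reads KV head h // (num_qo_heads // num_kv_heads); the function name and the page_tables/kv_lens inputs are hypothetical. It is not this module's code path, only a way to spot-check outputs on small batches.

```python
import math

import torch


def reference_paged_decode(
    q: torch.Tensor,               # [n_req, num_qo_heads, head_dim]
    kv_cache_layer: torch.Tensor,  # [max_pages, 2, page_size, num_kv_heads, head_dim]
    page_tables: list,             # page_tables[r]: physical page ids of request r
    kv_lens: list,                 # kv_lens[r]: tokens to attend to, including the new one
) -> torch.Tensor:
    """Hypothetical slow reference for batched paged decode attention."""
    n_req, num_qo_heads, head_dim = q.shape
    num_kv_heads = kv_cache_layer.shape[3]
    group = num_qo_heads // num_kv_heads  # grouped-query attention ratio
    out = torch.empty_like(q)
    for r in range(n_req):
        pages = torch.tensor(page_tables[r], dtype=torch.long)
        # Gather this request's pages and flatten them into one token axis.
        k = kv_cache_layer[pages, 0].reshape(-1, num_kv_heads, head_dim)[: kv_lens[r]]
        v = kv_cache_layer[pages, 1].reshape(-1, num_kv_heads, head_dim)[: kv_lens[r]]
        # Repeat each KV head so every query head in its group can use it.
        k = k.repeat_interleave(group, dim=1)  # [kv_len, num_qo_heads, head_dim]
        v = v.repeat_interleave(group, dim=1)
        scores = torch.einsum("hd,thd->ht", q[r], k) / math.sqrt(head_dim)
        probs = scores.softmax(dim=-1)
        out[r] = torch.einsum("ht,thd->hd", probs, v)
    return out
```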

set_kv_cache(kv_cache_layer, k, v)

Write K, V for decode (1 token per request).

Parameters:
  • kv_cache_layer (Tensor) – [max_pages, 2, page_size, num_kv_heads, head_dim]

  • k (Tensor) – [n_req, num_kv_heads, head_dim]

  • v (Tensor) – [n_req, num_kv_heads, head_dim]
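Given per-request write locations, a decode-time write reduces to one advanced-indexing scatter per tensor. A minimal sketch, assuming the page ids and in-page slots were already computed during plan; the helper and its page_ids/slots arguments are hypothetical.

```python
import torch


def write_decode_kv(
    kv_cache_layer: torch.Tensor,  # [max_pages, 2, page_size, num_kv_heads, head_dim]
    k: torch.Tensor,               # [n_req, num_kv_heads, head_dim]
    v: torch.Tensor,               # [n_req, num_kv_heads, head_dim]
    page_ids: torch.Tensor,        # [n_req] long, hypothetical precomputed pages
    slots: torch.Tensor,           # [n_req] long, hypothetical precomputed slots
) -> None:
    # Row r of k lands at kv_cache_layer[page_ids[r], 0, slots[r]]; likewise for v.
    kv_cache_layer[page_ids, 0, slots] = k
    kv_cache_layer[page_ids, 1, slots] = v
```

Distinct requests never share a (page, slot) pair, so the scatter has no write conflicts.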

class mstar.utils.flashinfer_utils.FlashInferPrefillWrapper(workspace_buffer, num_qo_heads, num_kv_heads, head_dim, page_size, batch_size=None, max_total_tokens=None, max_num_pages=None, device=torch.device('cuda'), use_cuda_graph=False, enable_nvtx=False)

Bases: object

Batched prefill attention with paged KV cache.

Wraps flashinfer.BatchPrefillWithPagedKVCacheWrapper with:

  • Pre-computed token_to_page / token_to_cache for vectorized KV writes

  • Optional CUDA graph mode with static buffers
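Precomputing where each new token lands is CSR arithmetic over the plan inputs. The sketch below is an illustration under stated assumptions, not this module's code: it assumes FlashInfer's convention that paged_kv_last_page_len already counts the tokens being appended, so each request's new tokens occupy the last qo_len positions of its sequence. The helper name is hypothetical, and it returns page ids and in-page slots rather than whatever encoding token_to_cache uses internally.

```python
from typing import List, Tuple


def prefill_token_locations(
    qo_indptr: List[int],
    paged_kv_indptr: List[int],
    paged_kv_indices: List[int],
    paged_kv_last_page_len: List[int],
    page_size: int,
) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: physical page and in-page slot for every new token."""
    token_to_page, token_to_slot = [], []
    for r in range(len(qo_indptr) - 1):
        qo_len = qo_indptr[r + 1] - qo_indptr[r]
        pages = paged_kv_indices[paged_kv_indptr[r] : paged_kv_indptr[r + 1]]
        # Total tokens in this request's KV after the append.
        kv_len = (len(pages) - 1) * page_size + paged_kv_last_page_len[r]
        # New tokens are the last qo_len logical positions.
        for pos in range(kv_len - qo_len, kv_len):
            token_to_page.append(pages[pos // page_size])
            token_to_slot.append(pos % page_size)
    return token_to_page, token_to_slot


# Request 0: fresh 5-token prefill in pages [4, 6].
# Request 1: 2 new tokens appended to 3 cached ones in pages [2, 8].
page_ids, slots = prefill_token_locations(
    qo_indptr=[0, 5, 7],
    paged_kv_indptr=[0, 2, 4],
    paged_kv_indices=[4, 6, 2, 8],
    paged_kv_last_page_len=[1, 1],
    page_size=4,
)
print(page_ids)  # [4, 4, 4, 4, 6, 2, 8]
print(slots)     # [0, 1, 2, 3, 0, 3, 0]
```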

Parameters:
  • workspace_buffer (Tensor) – FlashInfer workspace (256MB+ recommended)

  • num_qo_heads (int) – number of query/output heads

  • num_kv_heads (int) – number of key/value heads

  • head_dim (int) – dimension per head

  • page_size (int) – KV cache page size

  • batch_size (int | None) – required for CUDA graph mode (max requests in batch)

  • max_total_tokens (int | None) – required for CUDA graph mode (max total new tokens across batch)

  • max_num_pages (int | None) – required for CUDA graph mode (max pages across all requests)

  • device (device) – torch device

  • use_cuda_graph (bool) – if True, pre-allocate static buffers for graph capture

  • enable_nvtx (bool)

plan(qo_indptr, paged_kv_indptr, paged_kv_indices, paged_kv_last_page_len, causal=True, dtype=torch.bfloat16)

Plan attention and compute KV write indices.

In CUDA graph mode, updates static buffers via .copy_() so that the same GPU addresses are used during graph replay.

Inputs may be on CPU, and CPU is preferred: FlashInfer's BatchPrefillWithPagedKVCacheWrapper.plan calls indptr.to("cpu") and last_page_len.to("cpu") internally, so passing GPU tensors forces a blocking device-to-host copy that synchronizes the default stream and drains the speculatively queued next decode step. The inner plan therefore consumes the CPU tensors directly, and this wrapper issues asynchronous host-to-device copies for its own per-token bookkeeping.
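Following that advice, plan inputs can be assembled directly as CPU int32 tensors in FlashInfer's CSR form: an indptr of length batch + 1, the flattened page ids, and the token count of each request's last page. The helper below is hypothetical and assumes every request holds at least one token after the append; the guarded branch shows the pinned-memory, non_blocking host-to-device copy that makes such a transfer asynchronous.

```python
import torch


def build_paged_kv_inputs(page_tables, kv_lens, page_size):
    """Hypothetical helper: CPU int32 CSR tensors for plan().

    page_tables[r]: physical page ids reserved for request r.
    kv_lens[r]: tokens in request r's KV cache after the append (>= 1).
    """
    indptr, indices, last_page_len = [0], [], []
    for pages, kv_len in zip(page_tables, kv_lens):
        n_pages = (kv_len + page_size - 1) // page_size  # ceiling division
        indices.extend(pages[:n_pages])
        indptr.append(len(indices))
        last_page_len.append(kv_len - (n_pages - 1) * page_size)
    return (
        torch.tensor(indptr, dtype=torch.int32),
        torch.tensor(indices, dtype=torch.int32),
        torch.tensor(last_page_len, dtype=torch.int32),
    )


indptr, indices, last_page_len = build_paged_kv_inputs([[4, 6], [2, 8, 9]], [5, 9], page_size=4)

if torch.cuda.is_available():
    # Pinned host memory lets non_blocking=True enqueue a truly asynchronous
    # H2D copy; from pageable memory the copy would block the host.
    indptr_gpu = indptr.pin_memory().to("cuda", non_blocking=True)
```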

Parameters:
run(q, kv_cache_layer)

Run planned batched prefill attention.

Parameters:
  • q (Tensor) – [total_tokens, num_qo_heads, head_dim]

  • kv_cache_layer (Tensor) – [max_pages, 2, page_size, num_kv_heads, head_dim] (single layer slice of the full KV cache)

Returns:

output – [total_tokens, num_qo_heads, head_dim]

Return type:

Tensor
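When plan is called with causal=True and a request already has cached tokens (kv_len > qo_len), the causal mask is aligned to the end of the sequence, as in FlashInfer's paged prefill: each new token sees every cached token plus the new tokens up to and including itself. A small sketch of that mask (the helper name is hypothetical):

```python
from typing import List


def prefill_causal_mask(qo_len: int, kv_len: int) -> List[List[bool]]:
    """Hypothetical helper: mask[i][j] is True when new token i may attend to KV position j.

    The queries are the last qo_len tokens of a kv_len-long sequence, so query i
    sits at absolute position kv_len - qo_len + i.
    """
    offset = kv_len - qo_len
    return [[j <= offset + i for j in range(kv_len)] for i in range(qo_len)]


# Two new tokens appended to three cached ones.
for row in prefill_causal_mask(qo_len=2, kv_len=5):
    print("".join("x" if allowed else "." for allowed in row))
# xxxx.
# xxxxx
```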

set_kv_cache(kv_cache_layer, k, v)

Write K, V to the paged KV cache at pre-computed positions.

Parameters:
  • kv_cache_layer (Tensor) – [max_pages, 2, page_size, num_kv_heads, head_dim]

  • k (Tensor) – [total_tokens, num_kv_heads, head_dim]

  • v (Tensor) – [total_tokens, num_kv_heads, head_dim]
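One way to turn precomputed per-token positions into a single vectorized write is to view the contiguous layer slice as a flat array of [num_kv_heads, head_dim] rows and call index_copy_. This is an illustrative sketch, not necessarily how this module defines token_to_cache: the helper and its token_to_page/token_to_slot inputs are hypothetical, and kv_cache_layer must be contiguous (true for kv_cache[layer] of a contiguous full cache).

```python
import torch


def write_prefill_kv_flat(
    kv_cache_layer: torch.Tensor,  # [max_pages, 2, page_size, num_kv_heads, head_dim], contiguous
    k: torch.Tensor,               # [total_tokens, num_kv_heads, head_dim]
    v: torch.Tensor,               # [total_tokens, num_kv_heads, head_dim]
    token_to_page: torch.Tensor,   # [total_tokens] long, hypothetical precomputed pages
    token_to_slot: torch.Tensor,   # [total_tokens] long, hypothetical precomputed slots
) -> None:
    max_pages, _, page_size, num_kv_heads, head_dim = kv_cache_layer.shape
    # Row index of (page, k_or_v, slot) in the contiguous layout is
    # page * 2 * page_size + k_or_v * page_size + slot.
    rows = kv_cache_layer.view(max_pages * 2 * page_size, num_kv_heads, head_dim)
    k_rows = token_to_page * (2 * page_size) + token_to_slot
    v_rows = k_rows + page_size  # the V block follows the K block within each page
    rows.index_copy_(0, k_rows, k)
    rows.index_copy_(0, v_rows, v)
```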

mstar.utils.flashinfer_utils.run_attention(q, k, v, scale=1.0, causal=True)
Parameters:
mstar.utils.flashinfer_utils.run_rms_norm(input, weight, eps=1e-06, rms_norm_dtype=None)
Parameters:
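For spot-checking run_rms_norm, the standard RMSNorm formula, y = x / sqrt(mean(x², last dim) + eps) · weight, is easy to state in plain PyTorch. The sketch assumes input is [..., hidden] and weight is [hidden], and it computes in float32 before casting back; how the real helper uses rms_norm_dtype is not documented here, so that choice is an assumption, and the function name is hypothetical.

```python
import torch


def reference_rms_norm(input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Hypothetical reference RMSNorm over the last dimension."""
    x = input.float()  # assumption: accumulate in float32 regardless of input dtype
    inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x * inv_rms * weight.float()).to(input.dtype)
```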