整体架构图#

┌─────────────────────────────────────────────────────────────────┐
│                         用户入口层                               │
│  example.py → LLM.generate() → add_request() + step() 循环      │
└─────────────────────────────────────────────────────────────────┘
                                │
                                ▼
┌─────────────────────────────────────────────────────────────────┐
│                         引擎层 (Engine)                          │
│  ┌─────────────┐   ┌─────────────┐   ┌──────────────────┐      │
│  │  Scheduler  │ → │ ModelRunner │ → │  BlockManager    │      │
│  │  (调度器)    │   │ (模型执行器) │   │  (KV Cache管理)  │      │
│  └─────────────┘   └─────────────┘   └──────────────────┘      │
└─────────────────────────────────────────────────────────────────┘
                                │
                                ▼
┌─────────────────────────────────────────────────────────────────┐
│                         模型层 (Models)                          │
│  Qwen3ForCausalLM → Qwen3Model → [Qwen3DecoderLayer x N]       │
└─────────────────────────────────────────────────────────────────┘
                                │
                                ▼
┌─────────────────────────────────────────────────────────────────┐
│                         算子层 (Layers)                          │
│  Attention │ Linear │ LayerNorm │ RoPE │ Activation │ Sampler  │
└─────────────────────────────────────────────────────────────────┘

推理服务流程详解#

阶段 1: 初始化阶段#

# example.py
llm = LLM(path, enforce_eager=True, tensor_parallel_size=1)

调用链:

LLM.__init__() 
  → Config.__post_init__()        # 加载 HF config,设置参数
  → ModelRunner.__init__()        # 核心初始化
       → dist.init_process_group()   # 初始化分布式环境
       → Qwen3ForCausalLM()          # 构建模型结构
       → load_model()                # 加载权重
       → Sampler()                   # 初始化采样器
       → warmup_model()              # 预热模型
       → allocate_kv_cache()         # 分配 KV Cache
       → capture_cudagraph()         # 捕获 CUDA Graph (可选)
  → Scheduler.__init__()          # 初始化调度器
       → BlockManager.__init__()     # 初始化块管理器

显存分配#

关键代码:

    def allocate_kv_cache(self):
        config = self.config
        hf_config = config.hf_config
        free, total = torch.cuda.mem_get_info()
        used = total - free
        peak = torch.cuda.memory_stats()["allocated_bytes.all.peak"]
        current = torch.cuda.memory_stats()["allocated_bytes.all.current"]
        num_kv_heads = hf_config.num_key_value_heads // self.world_size
        head_dim = getattr(hf_config, "head_dim", hf_config.hidden_size // hf_config.num_attention_heads)
        block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * hf_config.torch_dtype.itemsize
        config.num_kvcache_blocks = int(total * config.gpu_memory_utilization - used - peak + current) // block_bytes
        assert config.num_kvcache_blocks > 0
        self.kv_cache = torch.empty(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, head_dim)
        layer_id = 0
        for module in self.model.modules():
            if hasattr(module, "k_cache") and hasattr(module, "v_cache"):
                module.k_cache = self.kv_cache[0, layer_id]
                module.v_cache = self.kv_cache[1, layer_id]
                layer_id += 1
计算可用显存#
# 获取 GPU 显存信息
free, total = torch.cuda.mem_get_info()      # free=当前空闲, total=总显存
used = total - free                           # 已使用显存

# 获取 PyTorch 内存统计
peak = torch.cuda.memory_stats()["allocated_bytes.all.peak"]     # 历史峰值
current = torch.cuda.memory_stats()["allocated_bytes.all.current"]  # 当前占用

used 表示 GPU 上所有被占用的显存,包含但不限于:

  • PyTorch 分配的张量 (current)
  • PyTorch 的缓存池(已分配但暂未使用)
  • CUDA context 开销
  • cuDNN/cuBLAS 工作空间
  • 其他 CUDA 库的内存

peak 和 current 是 PyTorch 分配器统计的内存。

为什么需要 peak 和 current?

  • 在调用 allocate_kv_cache 之前,已经执行过 warmup_model() 并做了 torch.cuda.empty_cache()
  • peak 记录了 warmup 时的峰值内存(推理时可能达到的最大值)
  • current 是当前模型参数占用的内存
  • 公式 peak - current 表示推理过程中的额外临时内存开销
# 计算每个 block 的字节数
block_bytes = 2 * num_hidden_layers * block_size * num_kv_heads * head_dim * dtype_size
#             │   │                   │            │              │          │
#             │   │                   │            │              │          └─ 数据类型大小(如 bf16=2)
#             │   │                   │            │              └─ 每个 head 的维度
#             │   │                   │            └─ KV head 数量(GQA时可能小于Q)
#             │   │                   └─ 每个 block 存储的 token 数(默认256)
#             │   └─ 模型层数
#             └─ K 和 V 两个 cache

# 计算可分配的 block 数量
available_memory = total * gpu_memory_utilization - used - peak + current
#                  │       │                        │      │      │
#                  │       │                        │      │      └─ 加回当前占用(因为它已在used中)
#                  │       │                        │      └─ 减去峰值内存需求
#                  │       │                        └─ 减去已使用内存
#                  │       └─ 显存利用率(默认0.9)
#                  └─ 总显存

num_kvcache_blocks = available_memory // block_bytes
临时内存(peak - current)主要用途#
  • 激活值 (Activations)

模型前向传播时每层的中间输出:

Input → Embedding → [Layer 1] → [Layer 2] → ... → [Layer N] → Output
                      ↓            ↓                 ↓
                   激活值1       激活值2           激活值N

主要包括:

  • Embedding 输出:[batch, seq_len, hidden_size]
  • 每层 Attention 的 Q/K/V:[batch, seq_len, num_heads, head_dim]
  • Attention 输出:[batch, seq_len, hidden_size]
  • MLP 中间层:[batch, seq_len, intermediate_size](通常是 hidden_size 的 4 倍)
  • LayerNorm 输出

  • Attention 计算的临时张量

Flash Attention 内部也需要工作空间:

# 在 attention.py 中
o = flash_attn_varlen_func(q, k, v, ...)

包括:

  • Softmax 的中间结果
  • 分块计算时的临时缓冲区
  • Flash Attention 的 O(N) 额外空间(相比传统 O(N²) 已大幅减少)

  • Linear 层的计算
# 矩阵乘法 Y = X @ W
# X: [batch * seq_len, hidden_size]
# W: [hidden_size, output_size]
# Y: [batch * seq_len, output_size]  ← 需要分配临时空间

典型大小(以 Qwen3-0.6B 为例):

  • QKV Projection:输出 [batch * seq_len, 3 * hidden_size]
  • MLP gate_up:输出 [batch * seq_len, 2 * intermediate_size]

  • cuBLAS/cuDNN 工作空间

CUDA 库执行 GEMM 等操作时需要的临时内存:

┌─────────────────────────────────────┐
│         cuBLAS Workspace            │
│  - 用于矩阵分块乘法的临时缓冲区       │
│  - 大小取决于矩阵维度和算法选择       │
│  - 通常几十到几百 MB                 │
└─────────────────────────────────────┘

  • Triton Kernel 临时变量

nano-vllm 中的 Triton 内核:

# attention.py 中的 store_kvcache_kernel
@triton.jit
def store_kvcache_kernel(...):
    # Triton 编译时可能分配临时寄存器/共享内存
    key = tl.load(key_ptr + key_offsets)   # 寄存器
    value = tl.load(value_ptr + value_offsets)

  • 采样阶段
# sampler.py
def forward(self, logits: torch.Tensor, temperatures: torch.Tensor):
    logits = logits.float().div_(temperatures.unsqueeze(dim=1))  # 类型转换可能产生副本
    probs = torch.softmax(logits, dim=-1)   # [batch, vocab_size] 临时张量
    sample_tokens = probs.div_(...)         # Gumbel sampling 临时张量

大小[batch_size, vocab_size],vocab_size 通常很大(如 151,936)


  • 内存占用估算(Prefill 阶段)

以 Qwen3-0.6B(hidden=1024, intermediate=2816, layers=28)处理 4096 tokens 为例:

用途 估算大小
Embedding 输出 4096 × 1024 × 2 = 8 MB
每层 QKV 4096 × 3 × 1024 × 2 = 24 MB
每层 MLP 中间 4096 × 2 × 2816 × 2 = 44 MB
每层 Attention 输出 4096 × 1024 × 2 = 8 MB
28 层累计 ~(24 + 44 + 8) × 28 ≈ 2.1 GB
Logits 4096 × 151936 × 4 = 2.4 GB
cuBLAS workspace ~100-500 MB

注意:PyTorch 会复用内存,实际峰值比简单累加要小。


分配 KV Cache 张量#
self.kv_cache = torch.empty(2, num_hidden_layers, num_kvcache_blocks, block_size, num_kv_heads, head_dim)
#                           │  │                  │                   │           │            │
#                           │  │                  │                   │           │            └─ D: head维度
#                           │  │                  │                   │           └─ H: KV head数
#                           │  │                  │                   └─ T: 每block的token数
#                           │  │                  └─ B: block总数
#                           │  └─ L: 层数
#                           └─ 2: K和V

张量布局: [K/V, Layer, Block, Token, Head, Dim]

例如对于 Qwen3-0.6B (假设参数):

  • 28 层, 4 KV heads, head_dim=128, block_size=256, 假设分配 100 blocks
  • Shape: [2, 28, 100, 256, 4, 128]
  • 大小: 2 * 28 * 100 * 256 * 4 * 128 * 2 bytes ≈ 1.4 GB
挂载 KV Cache#
layer_id = 0
for module in self.model.modules():
    if hasattr(module, "k_cache") and hasattr(module, "v_cache"):
        module.k_cache = self.kv_cache[0, layer_id]  # shape: [num_blocks, block_size, num_kv_heads, head_dim]
        module.v_cache = self.kv_cache[1, layer_id]
        layer_id += 1

这里遍历模型所有模块,找到每层的 Attention 模块(在 attention.py 中初始化时设置了空的 k_cachev_cache),然后将预分配的 KV Cache 切片赋值给它们。

内存计算示意图#
┌─────────────────────────────────────────────────────────────┐
│                      GPU 总显存 (total)                      │
├─────────────────────────────────────────────────────────────┤
│  已使用 (used)  │  空闲 (free)                               │
├─────────────────┴───────────────────────────────────────────┤
│                                                             │
│  ┌─────────────┐  ┌──────────────────┐  ┌───────────────┐  │
│  │ 模型参数     │  │ 推理临时内存      │  │  KV Cache     │  │
│  │ (current)   │  │ (peak - current) │  │  (计算得到)    │  │
│  └─────────────┘  └──────────────────┘  └───────────────┘  │
│                                                             │
│  ←────────────── total * 0.9 (gpu_memory_utilization) ────→ │
└─────────────────────────────────────────────────────────────┘

关键设计点:

  1. Paged 设计: KV Cache 按 block 分配,而非按 sequence 分配,支持动态调度
  2. 统一存储: 所有层的 KV Cache 在一个连续张量中,便于管理
  3. 预分配: 启动时一次性分配,避免运行时内存碎片

阶段 2: 请求添加阶段#

# example.py
outputs = llm.generate(prompts, sampling_params)

调用链:

LLM.generate()
  → add_request(prompt, sampling_params)  # 对每个 prompt
       → tokenizer.encode(prompt)          # 文本转 token ids
       → Sequence(token_ids, sampling_params)  # 创建序列对象
       → scheduler.add(seq)                # 加入等待队列

Sequence 结构:

class Sequence:
    block_size = 256
    counter = count()

    def __init__(self, token_ids: list[int], sampling_params = SamplingParams()):
        self.seq_id = next(Sequence.counter)
        self.status = SequenceStatus.WAITING
        self.token_ids = copy(token_ids)
        self.last_token = token_ids[-1]
        self.num_tokens = len(self.token_ids)
        self.num_prompt_tokens = len(token_ids)
        self.num_cached_tokens = 0
        self.block_table = []
        self.temperature = sampling_params.temperature
        self.max_tokens = sampling_params.max_tokens
        self.ignore_eos = sampling_params.ignore_eos

Sequence 对象是 nano-vllm 中用来追踪和管理单个推理请求的核心数据结构。

Sequence 的核心作用#

1. 唯一标识一个推理请求#
self.seq_id = next(Sequence.counter)  # 全局递增的唯一 ID

每个用户 prompt 对应一个 Sequence,通过 seq_id 追踪。

2. 维护 Token 状态#
self.token_ids = [...]           # 所有 tokens(prompt + 已生成)
self.num_prompt_tokens = N       # prompt 长度(固定)
self.num_tokens = M              # 当前总长度(随生成增长)
self.last_token = token_id       # 最新的 token(decode 时只需要这个)

示意图:

token_ids: [prompt_token_1, ..., prompt_token_N, gen_token_1, ..., gen_token_K]
           ←────── num_prompt_tokens ──────────→←── completion_tokens ──→
           ←─────────────────── num_tokens ─────────────────────────────→
3. 追踪生命周期状态#
class SequenceStatus(Enum):
    WAITING = auto()    # 在等待队列,等待被调度
    RUNNING = auto()    # 正在执行推理
    FINISHED = auto()   # 生成完成(遇到 EOS 或达到 max_tokens)

状态转换:

WAITING ──(schedule)──→ RUNNING ──(生成完成)──→ FINISHED
    ↑                      │
    └───(preempt/抢占)─────┘
4. 管理 KV Cache 块映射#
self.block_table = []           # 分配给该序列的 KV Cache block IDs
self.num_cached_tokens = 0      # 已缓存的 token 数(prefix cache 命中时 > 0)

block_table 示意:

block_table = [5, 12, 3, 8]  # 该序列占用的 block 编号

KV Cache 物理布局:
┌───────┬───────┬───────┬───────┬───────┬───────┬─...
│ Blk 0 │ Blk 1 │ Blk 2 │ Blk 3 │ Blk 4 │ Blk 5 │
└───────┴───────┴───────┴───────┴───────┴───────┴─...
                          ↑               ↑
                   该序列的第3块    该序列的第0块

逻辑视图(该序列):
┌───────┬───────┬───────┬───────┐
│ 位置0 │ 位置1 │ 位置2 │ 位置3 │  ← block_table 索引
│ Blk 5 │ Blk 12│ Blk 3 │ Blk 8 │  ← 实际物理 block
└───────┴───────┴───────┴───────┘
5. 携带采样参数#
self.temperature = sampling_params.temperature  # 采样温度
self.max_tokens = sampling_params.max_tokens    # 最大生成长度
self.ignore_eos = sampling_params.ignore_eos    # 是否忽略 EOS

每个请求可以有不同的采样参数。

Sequence 在推理流程中的使用#

┌─────────────────────────────────────────────────────────────────┐
│                         推理流程                                 │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  1. 创建 Sequence                                               │
│     add_request(prompt) → Sequence(token_ids, sampling_params)  │
│                                                                 │
│  2. 调度                                                        │
│     scheduler.schedule() → 选择哪些 Sequence 参与本次推理        │
│     - 读取 seq.status, len(seq), seq.block_table               │
│     - 分配/释放 KV Cache blocks                                 │
│                                                                 │
│  3. 准备输入                                                    │
│     prepare_prefill(seqs) / prepare_decode(seqs)               │
│     - Prefill: 读取 seq.token_ids[num_cached_tokens:]          │
│     - Decode:  读取 seq.last_token                             │
│     - 使用 seq.block_table 计算 slot_mapping                   │
│                                                                 │
│  4. 后处理                                                      │
│     postprocess(seqs, token_ids)                               │
│     - seq.append_token(new_token)                              │
│     - 检查是否完成 (EOS / max_tokens)                           │
│     - 更新 seq.status                                          │
│                                                                 │
│  5. 返回结果                                                    │
│     outputs[seq.seq_id] = seq.completion_token_ids             │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

功能 对应属性/方法
唯一标识 seq_id
Token 管理 token_ids, last_token, num_tokens
状态追踪 status (WAITING/RUNNING/FINISHED)
KV Cache 映射 block_table, num_cached_tokens
采样控制 temperature, max_tokens, ignore_eos
生成进度 num_completion_tokens, completion_token_ids

阶段 3: 调度阶段 (每次 step)#

# llm_engine.py
def step(self):
    seqs, is_prefill = self.scheduler.schedule()
    # ...

调度逻辑:

    def schedule(self) -> tuple[list[Sequence], bool]:
        # prefill
        scheduled_seqs = []
        num_seqs = 0
        num_batched_tokens = 0
        while self.waiting and num_seqs < self.max_num_seqs:
            seq = self.waiting[0]
            if num_batched_tokens + len(seq) > self.max_num_batched_tokens or not self.block_manager.can_allocate(seq):
                break
            num_seqs += 1
            self.block_manager.allocate(seq)
            num_batched_tokens += len(seq) - seq.num_cached_tokens
            seq.status = SequenceStatus.RUNNING
            self.waiting.popleft()
            self.running.append(seq)
            scheduled_seqs.append(seq)
        if scheduled_seqs:
            return scheduled_seqs, True

        # decode
        while self.running and num_seqs < self.max_num_seqs:
            seq = self.running.popleft()
            while not self.block_manager.can_append(seq):
                if self.running:
                    self.preempt(self.running.pop())
                else:
                    self.preempt(seq)
                    break
            else:
                num_seqs += 1
                self.block_manager.may_append(seq)
                scheduled_seqs.append(seq)
        assert scheduled_seqs
        self.running.extendleft(reversed(scheduled_seqs))
        return scheduled_seqs, False

调度策略:

  • 优先 Prefill:等待队列有请求时,优先处理 prefill 阶段
  • 批量限制max_num_seqsmax_num_batched_tokens 限制批大小
  • 抢占机制:KV Cache 不足时,通过 preempt() 回收正在运行的序列

调度器核心数据结构#

class Scheduler:
    def __init__(self, config: Config):
        self.waiting: deque[Sequence] = deque()  # 等待 prefill 的序列
        self.running: deque[Sequence] = deque()  # 正在 decode 的序列
        self.block_manager = BlockManager(...)    # KV Cache 管理
        self.max_num_seqs = 512                   # 最大并发序列数
        self.max_num_batched_tokens = 16384      # 每批最大 token 数

调度策略:Prefill 优先#

schedule() 返回: (seqs: list[Sequence], is_prefill: bool)

核心原则

  1. Prefill 优先:有等待的新请求时,优先处理 prefill
  2. Prefill 和 Decode 不混合:每次调度只返回一种类型
  3. 资源约束:受 max_num_seqsmax_num_batched_tokens 限制
Prefill 调度#
# prefill 阶段
scheduled_seqs = []
num_seqs = 0
num_batched_tokens = 0

while self.waiting and num_seqs < self.max_num_seqs:
    seq = self.waiting[0]  # 查看队首(不弹出)
    
    # 检查约束条件
    if num_batched_tokens + len(seq) > self.max_num_batched_tokens:
        break  # token 数量超限
    if not self.block_manager.can_allocate(seq):
        break  # KV Cache 空间不足
    
    # 可以调度该序列
    num_seqs += 1
    self.block_manager.allocate(seq)           # 分配 KV Cache blocks
    num_batched_tokens += len(seq) - seq.num_cached_tokens  # 累计 token 数
    seq.status = SequenceStatus.RUNNING
    self.waiting.popleft()                     # 从等待队列移除
    self.running.append(seq)                   # 加入运行队列
    scheduled_seqs.append(seq)

if scheduled_seqs:
    return scheduled_seqs, True  # is_prefill = True

Prefill 调度流程图:

waiting 队列: [Seq A (len=1000), Seq B (len=2000), Seq C (len=500)]
              ↓
              检查 Seq A
              ├── tokens + 1000 <= 16384? ✓
              ├── can_allocate(A)? ✓
              └── 调度 A, tokens = 1000
              ↓
              检查 Seq B  
              ├── tokens + 2000 = 3000 <= 16384? ✓
              ├── can_allocate(B)? ✓
              └── 调度 B, tokens = 3000
              ↓
              检查 Seq C
              ├── tokens + 500 = 3500 <= 16384? ✓
              ├── can_allocate(C)? ✗ (KV Cache 不足)
              └── 停止调度
              ↓
返回: ([Seq A, Seq B], is_prefill=True)
Decode 调度#

只有当 waiting 队列为空无法调度任何 prefill 时才进入 decode:

# decode 阶段
while self.running and num_seqs < self.max_num_seqs:
    seq = self.running.popleft()  # 取出队首
    
    # 检查是否有足够空间追加新 token
    while not self.block_manager.can_append(seq):
        # 空间不足,需要抢占其他序列
        if self.running:
            self.preempt(self.running.pop())  # 抢占队尾序列
        else:
            self.preempt(seq)  # 连自己都要抢占(极端情况)
            break
    else:
        # 有足够空间
        num_seqs += 1
        self.block_manager.may_append(seq)  # 可能需要分配新 block
        scheduled_seqs.append(seq)

# 将调度的序列放回 running 队列头部
self.running.extendleft(reversed(scheduled_seqs))
return scheduled_seqs, False  # is_prefill = False

Decode 调度流程图:

running 队列: [Seq A, Seq B, Seq C, Seq D]
              ↓
              取出 Seq A
              ├── can_append(A)? ✓
              └── 调度 A
              ↓
              取出 Seq B
              ├── can_append(B)? ✗ (需要新 block,但 KV Cache 满)
              │   ├── preempt(Seq D)  → D 回到 waiting
              │   └── can_append(B)? ✓
              └── 调度 B
              ↓
              取出 Seq C
              ├── can_append(C)? ✓
              └── 调度 C
              ↓
running 队列恢复: [Seq A, Seq B, Seq C]  (D 被抢占)
返回: ([Seq A, Seq B, Seq C], is_prefill=False)

抢占机制 (Preemption)#

def preempt(self, seq: Sequence):
    seq.status = SequenceStatus.WAITING
    self.block_manager.deallocate(seq)  # 释放 KV Cache
    self.waiting.appendleft(seq)        # 放回等待队列头部(优先重新调度)

抢占时机:Decode 阶段需要新 block 但 KV Cache 已满

抢占策略:LIFO(后进先出)- 抢占 running 队列尾部的序列

抢占前:
  running: [A, B, C, D]  ← D 最后加入,优先被抢占
  waiting: []
  
抢占后:
  running: [A, B, C]
  waiting: [D]  ← D 放到队首,下次优先 prefill

完整调度决策树#

schedule()
    │
    ├── waiting 非空?
    │       │
    │       ├── Yes → 尝试 Prefill
    │       │           │
    │       │           ├── 能调度至少 1 个? → 返回 (seqs, True)
    │       │           │
    │       │           └── 一个都调度不了 → 进入 Decode
    │       │
    │       └── No → 进入 Decode
    │
    └── Decode 阶段
            │
            ├── 遍历 running 队列
            │       │
            │       ├── can_append? → 调度
            │       │
            │       └── 不能? → 抢占其他序列,再试
            │
            └── 返回 (seqs, False)

阶段 4: 模型执行阶段#

# llm_engine.py
token_ids = self.model_runner.call("run", seqs, is_prefill)

执行流程:

ModelRunner.run(seqs, is_prefill)
  │
  ├─ if is_prefill:
  │     prepare_prefill()      # 准备 prefill 输入
  │       → 构建 input_ids, positions
  │       → 构建 cu_seqlens_q/k (用于 flash attention)
  │       → 计算 slot_mapping (KV Cache 写入位置)
  │       → set_context()      # 设置全局上下文
  │
  └─ else (decode):
        prepare_decode()       # 准备 decode 输入
          → 只取 last_token 作为 input_ids
          → 构建 context_lens, block_tables
          → set_context()

  → run_model(input_ids, positions, is_prefill)
      │
      ├─ if is_prefill or enforce_eager:
      │     model.forward() → model.compute_logits()  # 直接执行
      │
      └─ else (decode + CUDA Graph):
            graph.replay()  # 使用预捕获的 CUDA Graph

  → sampler(logits, temperatures)  # 采样得到 token_ids

prepare_prefill#

任务 说明
收集 input_ids 所有序列的 token IDs(跳过已缓存部分)
计算 positions 每个 token 的位置索引(用于 RoPE)
构建 cu_seqlens Flash Attention 需要的累积序列长度
计算 slot_mapping KV Cache 写入位置
详细流程#
# 假设有 2 个序列:
# Seq A: 长度 1000, num_cached_tokens = 0 (无缓存)
# Seq B: 长度 500,  num_cached_tokens = 256 (有 prefix cache)

input_ids = []
positions = []
cu_seqlens_q = [0]  # Query 累积长度
cu_seqlens_k = [0]  # Key 累积长度

for seq in [Seq A, Seq B]:
    # Seq A:
    input_ids.extend(seq[0:1000])      # 全部 1000 个 token
    positions.extend([0, 1, ..., 999]) # 位置 0-999
    cu_seqlens_q = [0, 1000]           # Q 长度 1000
    cu_seqlens_k = [0, 1000]           # K 长度 1000
    
    # Seq B (有 prefix cache):
    input_ids.extend(seq[256:500])     # 只需要 244 个新 token
    positions.extend([256, 257, ..., 499])  # 位置从 256 开始
    cu_seqlens_q = [0, 1000, 1244]     # Q 长度只有 244
    cu_seqlens_k = [0, 1000, 1500]     # K 长度是 500(需要 attend 到全部)

num_cached_tokens 在 BlockManager allocate 时更新。

Prefill 输入构建:

Seq A (1000 tokens, 无缓存):
input_ids:  [t0, t1, t2, ..., t999]     ← 全部 token
positions:  [0,  1,  2,  ..., 999]      ← 对应位置

Seq B (500 tokens, 256 已缓存):
input_ids:  [t256, t257, ..., t499]     ← 只有新 token
positions:  [256,  257,  ..., 499]      ← 从缓存位置开始

合并后:
input_ids:  [A的t0-t999] + [B的t256-t499]  = 1244 个 token
positions:  [0-999]      + [256-499]        = 1244 个位置

cu_seqlens_q: [0, 1000, 1244]  ← Query 的累积长度
cu_seqlens_k: [0, 1000, 1500]  ← Key 的累积长度(包含缓存)
slot_mapping 计算#
# slot_mapping: 告诉 Triton kernel 把 KV 写到哪里

for seq in seqs:
    for i in range(seq.num_cached_blocks, seq.num_blocks):
        block_id = seq.block_table[i]
        start = block_id * block_size  # 该 block 在 KV Cache 中的起始位置
        end = start + ( block  token )
        slot_mapping.extend(range(start, end))

# 示例: Seq A 有 4 个 block [5, 12, 3, 8], block_size=256
# slot_mapping = [5*256..5*256+255, 12*256..12*256+255, ...]

prepare_decode#

任务 说明
收集 input_ids 每个序列只取 last_token
计算 positions 每个序列只有 一个位置
计算 slot_mapping 新 token 的 KV 写入位置
收集 context_lens 每个序列的总长度(用于 attention)
构建 block_tables KV Cache 的 block 映射
详细流程#
# 假设有 3 个序列正在 decode:
# Seq A: 长度 1001 (刚生成第 1001 个 token)
# Seq B: 长度 502
# Seq C: 长度 300

for seq in [Seq A, Seq B, Seq C]:
    input_ids.append(seq.last_token)     # 只取最后一个 token
    positions.append(len(seq) - 1)       # 最后位置
    context_lens.append(len(seq))        # 总长度
    
    # 计算新 token 的 KV Cache 写入位置
    last_block = seq.block_table[-1]
    offset_in_block = seq.last_block_num_tokens - 1
    slot = last_block * block_size + offset_in_block
    slot_mapping.append(slot)

# 结果:
# input_ids = [A.last, B.last, C.last]  ← 3 个 token
# positions = [1000, 501, 299]
# context_lens = [1001, 502, 300]
# slot_mapping = [slot_A, slot_B, slot_C]
Decode 输入构建 (batch_size=3):

Seq A (len=1001):  input_id = last_token, position = 1000
Seq B (len=502):   input_id = last_token, position = 501
Seq C (len=300):   input_id = last_token, position = 299

input_ids:   [tok_A, tok_B, tok_C]     ← 只有 3 个 token!
positions:   [1000,  501,   299]
context_lens:[1001,  502,   300]       ← attention 需要知道要看多少历史

slot_mapping 示例 (block_size=256):
Seq A: block_table=[5,12,3,8], len=1001
       last_block=8, offset=1001%256-1=232
       slot = 8*256 + 232 = 2280

block_tables: 
  [[5, 12, 3, 8, -1, -1],   ← Seq A 的 block 映射
   [0, 1, -1, -1, -1, -1],  ← Seq B 的 block 映射
   [2, 3, -1, -1, -1, -1]]  ← Seq C 的 block 映射
为什么 decode 阶段只需要取 last_token#

这是 自回归生成(Autoregressive Generation)+ KV Cache 机制的核心设计。

自回归生成的原理#

LLM 是逐个 token 生成的:

输入: "Hello world"
      ↓
模型预测下一个 token → "!" 
      ↓
输入: "Hello world!"
      ↓
模型预测下一个 token → "How"
      ↓
...

每次只生成一个 token,这就是为什么 decode 阶段输入只有 1 个 token。

没有 KV Cache 时的问题#

如果没有 KV Cache,每次生成都要重新计算:

Step 1: 输入 [t0, t1, t2] → 计算全部 attention → 得到 t3
Step 2: 输入 [t0, t1, t2, t3] → 重新计算全部 attention → 得到 t4
Step 3: 输入 [t0, t1, t2, t3, t4] → 重新计算全部 attention → 得到 t5
...

问题: 计算量是 O(N²),随着序列变长,越来越慢!
有 KV Cache 时#

Attention 计算公式

Attention(Q, K, V) = softmax(Q @ K^T / sqrt(d)) @ V

对于位置 i 的 token:

  • Q_i:只依赖当前 token 的 hidden state
  • K, V:需要所有历史 token 的 hidden state
KV Cache 的作用#
Prefill 阶段:
输入: [t0, t1, t2]
      ↓
计算每个 token 的 K, V 并存入 cache:
  KV Cache: [K0,V0], [K1,V1], [K2,V2]
      ↓
用 Q0,Q1,Q2 与全部 K,V 做 attention
      ↓
得到 logits → 采样得 t3

Decode 阶段:
输入: [t3]  ← 只需要这一个!
      ↓
计算 t3 的 K3, V3 并存入 cache:
  KV Cache: [K0,V0], [K1,V1], [K2,V2], [K3,V3]
      ↓
用 Q3 与全部历史 K,V 做 attention
      ↓
得到 logits → 采样得 t4

关键洞察

  • 历史 token 的 K, V 已经在 cache 里了
  • 新 token 只需要计算自己的 Q, K, V
  • Attention 时,用新的 Q 去查询所有历史的 K, V
图解#
Prefill (一次处理多个 token):
┌─────────────────────────────────────────────┐
│  Token:    t0      t1      t2               │
│            ↓       ↓       ↓                │
│  Embed:   [e0]    [e1]    [e2]              │
│            ↓       ↓       ↓                │
│  Q,K,V:  Q0,K0,V0  Q1,K1,V1  Q2,K2,V2       │
│            │       │       │                │
│            ↓       ↓       ↓                │
│  KV Cache: [K0,V0, K1,V1, K2,V2] ← 写入     │
│            ↓       ↓       ↓                │
│  Attention: Q0→K0:2, Q1→K0:1, Q2→K0:2       │
│            ↓       ↓       ↓                │
│  Output:  [o0]    [o1]    [o2] → predict t3 │
└─────────────────────────────────────────────┘

Decode (一次只处理一个 token):
┌─────────────────────────────────────────────┐
│  Token:                    t3               │
│                            ↓                │
│  Embed:                   [e3]              │
│                            ↓                │
│  Q,K,V:                 Q3,K3,V3            │
│                            │                │
│                            ↓                │
│  KV Cache: [K0,V0, K1,V1, K2,V2, K3,V3] ←追加│
│                            ↓                │
│  Attention: Q3 attend to K0,K1,K2,K3        │
│             (从 cache 读取历史 K,V)          │
│                            ↓                │
│  Output:                  [o3] → predict t4 │
└─────────────────────────────────────────────┘
代码对应#
# model_runner.py - prepare_decode
def prepare_decode(self, seqs: list[Sequence]):
    for seq in seqs:
        input_ids.append(seq.last_token)  # ← 只取最后一个 token
        positions.append(len(seq) - 1)     # ← 它的位置
        context_lens.append(len(seq))      # ← 告诉 attention 要看多少历史
# attention.py - decode 时使用 flash_attn_with_kvcache
else:  # decode
    o = flash_attn_with_kvcache(
        q.unsqueeze(1),           # Q: [batch, 1, heads, dim] ← 只有 1 个 query
        k_cache, v_cache,         # 从 cache 读取所有历史 K, V
        cache_seqlens=context_lens,  # 每个序列的历史长度
        block_table=block_tables,
        softmax_scale=self.scale, 
        causal=True
    )
效率对比#
方式 每步计算量 生成 N 个 token 总量
无 KV Cache O(N × d) O(N² × d)
有 KV Cache O(1 × d) O(N × d)

KV Cache 让 decode 阶段从 O(N²) 降到 O(N)!

总结#
阶段 输入 计算 原因
Prefill 全部 prompt tokens 全部 Q,K,V + 写入 cache 第一次见到这些 token
Decode 仅 last token 只算新 Q,K,V + 读 cache 历史 K,V 已缓存,只需计算新 token

一句话:因为 attention 需要的历史 K,V 已经存在 cache 里了,新 token 只需要计算自己的 K,V 并追加到 cache,然后用自己的 Q 去查询全部历史即可。

对比总结#
特征 prepare_prefill prepare_decode
input_ids 数量 所有待处理 token(可能上千) 每序列仅 1 个
positions 连续范围 [start, end) 单个位置
cu_seqlens 需要(Flash Attention varlen) 不需要
context_lens 不需要 需要(告诉 attention 历史长度)
block_tables 仅 prefix cache 时需要 始终需要
slot_mapping 多个连续槽位 每序列仅 1 个槽位
数据流向#
prepare_prefill / prepare_decode
        │
        ├── input_ids, positions → Model.forward() → Embedding + Transformer
        │
        └── set_context() → 全局 Context 对象
                │
                ├── cu_seqlens_q/k, max_seqlen → flash_attn_varlen_func (prefill)
                │
                ├── context_lens, block_tables → flash_attn_with_kvcache (decode)
                │
                └── slot_mapping → store_kvcache kernel (写入 KV Cache)

阶段 5: Attention 计算#

Attention 层核心逻辑:

class Attention(nn.Module):

    def __init__(
        self,
        num_heads,
        head_dim,
        scale,
        num_kv_heads,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.scale = scale
        self.num_kv_heads = num_kv_heads
        self.k_cache = self.v_cache = torch.tensor([])

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        context = get_context()
        k_cache, v_cache = self.k_cache, self.v_cache
        if k_cache.numel() and v_cache.numel():
            store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
        if context.is_prefill:
            if context.block_tables is not None:    # prefix cache
                k, v = k_cache, v_cache
            o = flash_attn_varlen_func(q, k, v,
                                       max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
                                       max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
                                       softmax_scale=self.scale, causal=True, block_table=context.block_tables)
        else:    # decode
            o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
                                        cache_seqlens=context.context_lens, block_table=context.block_tables, 
                                        softmax_scale=self.scale, causal=True)
        return o

整体流程#

                    ┌─────────────────────────────────────────┐
                    │          Qwen3Attention.forward()       │
                    └─────────────────────────────────────────┘
                                       │
                                       ▼
┌──────────────────────────────────────────────────────────────────┐
│  qkv = self.qkv_proj(hidden_states)                              │
│  q, k, v = qkv.split([q_size, kv_size, kv_size])                │
│  q = q.view(-1, num_heads, head_dim)                            │
│  k = k.view(-1, num_kv_heads, head_dim)                         │
│  v = v.view(-1, num_kv_heads, head_dim)                         │
│  q, k = self.rotary_emb(positions, q, k)  ← RoPE                │
└──────────────────────────────────────────────────────────────────┘
                                       │
                                       ▼
                    ┌─────────────────────────────────────────┐
                    │          Attention.forward(q, k, v)     │
                    │               (本文重点)                 │
                    └─────────────────────────────────────────┘
                                       │
                                       ▼
┌──────────────────────────────────────────────────────────────────┐
│  output = self.o_proj(o.flatten(1, -1))                         │
└──────────────────────────────────────────────────────────────────┘

Attention.forward() 详解#

Step 1: 获取全局上下文#
context = get_context()
k_cache, v_cache = self.k_cache, self.v_cache

context 包含当前 batch 的元信息(在 prepare_prefill/decode 中设置):

字段 Prefill Decode
is_prefill True False
cu_seqlens_q 累积 Q 长度 -
cu_seqlens_k 累积 K 长度 -
slot_mapping KV 写入位置 单个位置
context_lens - 每序列历史长度
block_tables (prefix cache时) 必需
Step 2: 写入 KV Cache#
if k_cache.numel() and v_cache.numel():
    store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)

使用 Triton kernel 将当前计算的 K, V 写入 cache:

@triton.jit
def store_kvcache_kernel(key_ptr, key_stride, value_ptr, value_stride,
                         k_cache_ptr, v_cache_ptr, slot_mapping_ptr, D):
    idx = tl.program_id(0)                    # 当前处理第几个 token
    slot = tl.load(slot_mapping_ptr + idx)    # 该 token 写入的槽位
    if slot == -1: return                     # 跳过无效槽位
    
    # 读取当前 token 的 K, V
    key = tl.load(key_ptr + idx * key_stride + tl.arange(0, D))
    value = tl.load(value_ptr + idx * value_stride + tl.arange(0, D))
    
    # 写入 cache 的对应位置
    cache_offsets = slot * D + tl.arange(0, D)
    tl.store(k_cache_ptr + cache_offsets, key)
    tl.store(v_cache_ptr + cache_offsets, value)

图解

当前 batch: 3 个新 token
k: [[k0], [k1], [k2]]  shape: [3, num_kv_heads, head_dim]
slot_mapping: [1280, 1281, 1536]

KV Cache 写入:
┌─────────────────────────────────────────────────────────┐
│  ...  │slot1280│slot1281│  ...  │slot1536│  ...        │
│       │   k0   │   k1   │       │   k2   │             │
└─────────────────────────────────────────────────────────┘
Step 3: Prefill 阶段的 Attention#
if context.is_prefill:
    if context.block_tables is not None:    # prefix cache 命中
        k, v = k_cache, v_cache             # 使用 cache 中的 K, V
    
    o = flash_attn_varlen_func(
        q, k, v,
        max_seqlen_q=context.max_seqlen_q,
        cu_seqlens_q=context.cu_seqlens_q,
        max_seqlen_k=context.max_seqlen_k,
        cu_seqlens_k=context.cu_seqlens_k,
        softmax_scale=self.scale,
        causal=True,
        block_table=context.block_tables
    )

flash_attn_varlen_func:处理变长序列的 Flash Attention

输入示例 (2 个序列):
Seq A: 1000 tokens (其中 512 已缓存)
Seq B: 500 tokens (无缓存)

q: [tokens to compute] = [488 + 500] = 988 个 query
k, v: 取决于是否有 prefix cache

cu_seqlens_q: [0, 488, 988]      ← Q 的累积长度
cu_seqlens_k: [0, 1000, 1500]    ← K 的累积长度

Flash Attention 内部:
  Seq A: Q[0:488] attend to K[0:1000]
  Seq B: Q[488:988] attend to K[1000:1500]

Prefix Cache 时的特殊处理

当 cu_seqlens_k > cu_seqlens_q 时(有前缀缓存):
  - k, v 切换为 k_cache, v_cache
  - 提供 block_tables 让 Flash Attention 从 paged cache 读取
  
无 Prefix Cache 时:
  - 直接用当前计算的 k, v
  - block_tables = None
Step 4: Decode 阶段的 Attention#
else:  # decode
    o = flash_attn_with_kvcache(
        q.unsqueeze(1),        # [batch, 1, heads, dim]
        k_cache, v_cache,      # 完整的 KV Cache
        cache_seqlens=context.context_lens,  # 每个序列的历史长度
        block_table=context.block_tables,    # block 映射表
        softmax_scale=self.scale,
        causal=True
    )

flash_attn_with_kvcache:专为 decode 优化的 Attention

输入示例 (batch_size=3):
q: [3, 1, num_heads, head_dim]  ← 每个序列只有 1 个 query

context_lens: [1001, 502, 300]  ← 每个序列要 attend 到多少历史

block_tables:
  Seq 0: [5, 12, 3, 8]     ← 1001 tokens 分布在这 4 个 block
  Seq 1: [0, 1]            ← 502 tokens 分布在这 2 个 block
  Seq 2: [2, 3]            ← 300 tokens 分布在这 2 个 block

Flash Attention 内部:
  对于每个 query,根据 block_table 从 paged KV Cache 读取历史 K, V
完整数据流图#
┌─────────────────────────────────────────────────────────────────────┐
│                        Attention.forward(q, k, v)                   │
├─────────────────────────────────────────────────────────────────────┤
│                                                                     │
│  输入:                                                              │
│    q: [total_tokens, num_heads, head_dim]                          │
│    k: [total_tokens, num_kv_heads, head_dim]                       │
│    v: [total_tokens, num_kv_heads, head_dim]                       │
│                                                                     │
│                              │                                      │
│                              ▼                                      │
│  ┌───────────────────────────────────────────────────────────┐     │
│  │  Step 1: store_kvcache(k, v, k_cache, v_cache, slots)     │     │
│  │          将新的 K, V 写入 cache 对应位置                    │     │
│  └───────────────────────────────────────────────────────────┘     │
│                              │                                      │
│                              ▼                                      │
│              ┌───────────────┴───────────────┐                     │
│              │                               │                     │
│        is_prefill?                      is_decode?                 │
│              │                               │                     │
│              ▼                               ▼                     │
│  ┌─────────────────────┐         ┌─────────────────────────┐      │
│  │ flash_attn_varlen   │         │ flash_attn_with_kvcache │      │
│  │                     │         │                         │      │
│  │ - 多 query          │         │ - 单 query per seq      │      │
│  │ - 变长序列          │         │ - 读 paged cache        │      │
│  │ - cu_seqlens 定位   │         │ - context_lens 定界     │      │
│  └─────────────────────┘         └─────────────────────────┘      │
│              │                               │                     │
│              └───────────────┬───────────────┘                     │
│                              │                                      │
│                              ▼                                      │
│  输出:                                                              │
│    o: [total_tokens, num_heads, head_dim]                          │
│                                                                     │
└─────────────────────────────────────────────────────────────────────┘
GQA (Grouped Query Attention) 支持#
self.num_heads = num_heads        # Q heads (如 16)
self.num_kv_heads = num_kv_heads  # KV heads (如 4)

GQA 让多个 Q head 共享同一组 K, V

Q heads:  [q0, q1, q2, q3, q4, q5, q6, q7, q8, q9, q10, q11, q12, q13, q14, q15]
           ↓   ↓   ↓   ↓   ↓   ↓   ↓   ↓   ↓   ↓    ↓    ↓    ↓    ↓    ↓    ↓
KV heads: [  kv0     ] [  kv1     ] [  kv2      ] [   kv3      ]

每 4 个 Q head 共享 1 个 KV head
KV Cache 大小减少 4 倍!
为什么用 Flash Attention?#
传统 Attention Flash Attention
显存 O(N²) 显存 O(N)
需要存完整 attention matrix 分块计算,不存中间结果
内存带宽瓶颈 IO-aware,减少 HBM 访问

Flash Attention 的核心思想

传统: Q @ K^T → [N, N] attention matrix → softmax → @ V
      ↑ 需要 O(N²) 显存

Flash: 分块处理,在 SRAM 中完成 softmax
       - 不需要存储完整 attention matrix
       - 通过 online softmax 技巧保证正确性
总结#
阶段 函数 特点
Prefill flash_attn_varlen_func 多 token,变长 batch,自己算的 K,V
Prefill + Prefix Cache flash_attn_varlen_func + block_table 部分 K,V 从 cache 读取
Decode flash_attn_with_kvcache 单 token,全部 K,V 从 cache 读取

阶段 6: 后处理与输出#

# llm_engine.py
self.scheduler.postprocess(seqs, token_ids)

后处理在 Scheduler.postprocess() 中完成,主要做 追加新 token检查完成状态

    def postprocess(self, seqs: list[Sequence], token_ids: list[int]) -> list[bool]:
        for seq, token_id in zip(seqs, token_ids):
            seq.append_token(token_id)
            if (not seq.ignore_eos and token_id == self.eos) or seq.num_completion_tokens == seq.max_tokens:
                seq.status = SequenceStatus.FINISHED
                self.block_manager.deallocate(seq)
                self.running.remove(seq)

详细流程#

Step 1: 追加新 token#
seq.append_token(token_id)
    def append_token(self, token_id: int):
        self.token_ids.append(token_id)
        self.last_token = token_id
        self.num_tokens += 1

作用

  • 将模型采样得到的新 token 加入序列
  • 更新 last_token(decode 阶段只需要这个)
  • 递增 num_tokens
Step 2: 检查完成条件#
if (not seq.ignore_eos and token_id == self.eos) or seq.num_completion_tokens == seq.max_tokens:

两种完成条件

条件 说明
token_id == self.eos 生成了 EOS (End of Sequence) token
num_completion_tokens == max_tokens 达到最大生成长度
# sampling_params.py
@dataclass
class SamplingParams:
    temperature: float = 1.0
    max_tokens: int = 64        # ← 最大生成长度
    ignore_eos: bool = False    # ← 是否忽略 EOS
Step 3: 完成时的清理#
if 完成:
    seq.status = SequenceStatus.FINISHED   # 标记完成
    self.block_manager.deallocate(seq)     # 释放 KV Cache
    self.running.remove(seq)               # 从运行队列移除

BlockManager.deallocate

    def deallocate(self, seq: Sequence):
        for block_id in reversed(seq.block_table):
            block = self.blocks[block_id]
            block.ref_count -= 1
            if block.ref_count == 0:
                self._deallocate_block(block_id)
        seq.num_cached_tokens = 0
        seq.block_table.clear()
完整数据流#
model_runner.run(seqs, is_prefill)
        │
        ▼
    token_ids = [103, 256, 78]  ← 采样得到的新 tokens
        │
        ▼
scheduler.postprocess(seqs, token_ids)
        │
        ├── Seq A: append_token(103)
        │          ├── token_ids: [..., 103]
        │          ├── last_token: 103
        │          ├── num_tokens: 1001
        │          └── 103 == EOS? No → 继续运行
        │
        ├── Seq B: append_token(256)
        │          ├── token_ids: [..., 256]
        │          ├── last_token: 256
        │          ├── num_tokens: 502
        │          └── num_completion_tokens(65) == max_tokens(64)? 
        │              Yes → FINISHED, deallocate, remove
        │
        └── Seq C: append_token(78)
                   ├── token_ids: [..., 78]
                   ├── last_token: 78
                   ├── num_tokens: 301
                   └── 78 == EOS(151643)? 
                       Yes → FINISHED, deallocate, remove

核心特性总结#

特性 实现方式
Continuous Batching Scheduler 的 waiting/running 双队列管理
Paged Attention BlockManager + flash_attn_with_kvcache
Prefix Cache 基于 xxhash 的 block hash 匹配
CUDA Graph decode 阶段预捕获多种 batch size 的 graph
Tensor Parallelism SharedMemory + Event 同步多 GPU
抢占机制 preempt() 回收 KV Cache 资源

完整一轮推理的调用序列#

generate()
  └── while not is_finished():
        step()
          ├── scheduler.schedule()
          │     ├── block_manager.allocate() / can_allocate()
          │     └── 返回 (seqs, is_prefill)
          │
          ├── model_runner.call("run", seqs, is_prefill)
          │     ├── prepare_prefill() / prepare_decode()
          │     │     └── set_context()
          │     ├── run_model()
          │     │     └── model.forward() → Attention.forward()
          │     │           ├── store_kvcache()  # Triton kernel
          │     │           └── flash_attn_*()
          │     └── sampler()
          │
          └── scheduler.postprocess()
                ├── seq.append_token()
                └── block_manager.deallocate()  # if finished

BlockManager 内存管理#

BlockManager 是 KV Cache 的内存管理器,负责以 分页(Paged) 的方式管理 GPU 上的 KV Cache 空间。

BlockManager 解决的核心问题

  1. 内存碎片:固定大小的 block 避免碎片
  2. 动态长度:序列长度变化时按需分配 block
  3. 内存复用:Prefix Cache 让相同前缀共享 KV Cache
  4. 资源控制:为调度器提供资源可用性判断
┌─────────────────────────────────────────────────────────────────┐
│                         Scheduler                                │
│                            │                                     │
│           ┌────────────────┼────────────────┐                   │
│           ▼                ▼                ▼                   │
│      can_allocate    can_append       deallocate                │
│           │                │                │                   │
└───────────┼────────────────┼────────────────┼───────────────────┘
            │                │                │
            ▼                ▼                ▼
┌─────────────────────────────────────────────────────────────────┐
│                       BlockManager                               │
│  ┌─────────────────────────────────────────────────────────┐    │
│  │  free_block_ids: [5, 6, 7, 8, ...]                      │    │
│  │  used_block_ids: {0, 1, 2, 3, 4}                        │    │
│  │  hash_to_block_id: {0xABC: 0, 0xDEF: 1, ...}           │    │
│  └─────────────────────────────────────────────────────────┘    │
│                            │                                     │
│                            ▼                                     │
│  ┌─────────────────────────────────────────────────────────┐    │
│  │  Sequence.block_table (逻辑 → 物理映射)                  │    │
│  │  Seq A: [0, 1, 2]                                       │    │
│  │  Seq B: [3, 4]                                          │    │
│  │  Seq C: [0, 5]  ← Block 0 被共享 (Prefix Cache)         │    │
│  └─────────────────────────────────────────────────────────┘    │
└─────────────────────────────────────────────────────────────────┘
                            │
                            ▼
┌─────────────────────────────────────────────────────────────────┐
│                   GPU KV Cache (物理存储)                        │
│  ┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬...  │
│  │Blk 0  │Blk 1  │Blk 2  │Blk 3  │Blk 4  │Blk 5  │Blk 6  │     │
│  │Seq A,C│Seq A  │Seq A  │Seq B  │Seq B  │Seq C  │ free  │     │
│  └───────┴───────┴───────┴───────┴───────┴───────┴───────┴...  │
└─────────────────────────────────────────────────────────────────┘

核心数据结构#

BlockManager:

class BlockManager:

    def __init__(self, num_blocks: int, block_size: int):
        self.block_size = block_size
        self.blocks: list[Block] = [Block(i) for i in range(num_blocks)]
        self.hash_to_block_id: dict[int, int] = dict()
        self.free_block_ids: deque[int] = deque(range(num_blocks))
        self.used_block_ids: set[int] = set()
属性 作用
blocks 所有 Block 对象的列表
free_block_ids 空闲 block 的 ID 队列
used_block_ids 已使用 block 的 ID 集合
hash_to_block_id 用于 Prefix Cache 的哈希映射

Block:

class Block:

    def __init__(self, block_id):
        self.block_id = block_id
        self.ref_count = 0      # 引用计数(多个序列可共享)
        self.hash = -1          # 内容哈希(用于 Prefix Cache)
        self.token_ids = []     # 存储的 token IDs(用于验证)

每个 Block 对应 KV Cache 中固定大小的一段空间:

KV Cache 物理布局(假设 block_size=256):
┌─────────────┬─────────────┬─────────────┬─────────────┬─...
│   Block 0   │   Block 1   │   Block 2   │   Block 3   │
│ 256 tokens  │ 256 tokens  │ 256 tokens  │ 256 tokens  │
└─────────────┴─────────────┴─────────────┴─────────────┴─...

核心功能#

1. 分配 KV Cache(allocate)#

当新序列进入 prefill 时调用:

    def allocate(self, seq: Sequence):
        assert not seq.block_table
        h = -1
        cache_miss = False
        for i in range(seq.num_blocks):
            token_ids = seq.block(i)
            h = self.compute_hash(token_ids, h) if len(token_ids) == self.block_size else -1
            block_id = self.hash_to_block_id.get(h, -1)
            if block_id == -1 or self.blocks[block_id].token_ids != token_ids:
                cache_miss = True
            if cache_miss:
                block_id = self.free_block_ids[0]
                block = self._allocate_block(block_id)
            else:
                seq.num_cached_tokens += self.block_size
                if block_id in self.used_block_ids:
                    block = self.blocks[block_id]
                    block.ref_count += 1
                else:
                    block = self._allocate_block(block_id)
            if h != -1:
                block.update(h, token_ids)
                self.hash_to_block_id[h] = block_id
            seq.block_table.append(block_id)

分配逻辑图解:

序列 A (1000 tokens, block_size=256):
需要 ceil(1000/256) = 4 个 blocks

分配前:
  free_block_ids: [0, 1, 2, 3, 4, 5, ...]
  seq.block_table: []

分配后:
  free_block_ids: [4, 5, ...]
  seq.block_table: [0, 1, 2, 3]
  
  Block 映射:
  ┌─────────┬─────────┬─────────┬─────────┐
  │ Block 0 │ Block 1 │ Block 2 │ Block 3 │
  │tok 0-255│tok256-511│tok512-767│tok768-999│
  └─────────┴─────────┴─────────┴─────────┘

2. 释放 KV Cache(deallocate)#

当序列完成或被抢占时调用:

    def deallocate(self, seq: Sequence):
        for block_id in reversed(seq.block_table):
            block = self.blocks[block_id]
            block.ref_count -= 1
            if block.ref_count == 0:
                self._deallocate_block(block_id)
        seq.num_cached_tokens = 0
        seq.block_table.clear()

引用计数机制:Block 只在 ref_count == 0 时才真正释放,支持 Prefix Cache 共享。

3. 追加新 token(may_append)#

Decode 阶段每生成一个 token,可能需要新 block:

    def can_append(self, seq: Sequence) -> bool:
        return len(self.free_block_ids) >= (len(seq) % self.block_size == 1)

    def may_append(self, seq: Sequence):
        block_table = seq.block_table
        last_block = self.blocks[block_table[-1]]
        if len(seq) % self.block_size == 1:
            # 当前 block 刚好满了,需要新 block
            assert last_block.hash != -1
            block_id = self.free_block_ids[0]
            self._allocate_block(block_id)
            block_table.append(block_id)
        elif len(seq) % self.block_size == 0:
            # 刚填满一个 block,更新哈希
            token_ids = seq.block(seq.num_blocks-1)
            prefix = self.blocks[block_table[-2]].hash if len(block_table) > 1 else -1
            h = self.compute_hash(token_ids, prefix)
            last_block.update(h, token_ids)
            self.hash_to_block_id[h] = last_block.block_id

何时需要新 block?

block_size = 256

序列长度 255 → 256: 不需要新 block(当前 block 还有空间)
序列长度 256 → 257: 需要新 block(当前 block 满了)

判断条件: len(seq) % block_size == 1
  - 257 % 256 = 1 → 需要新 block
  - 258 % 256 = 2 → 不需要

4. Prefix Cache(前缀缓存)#

核心思想:相同 prompt 前缀的 KV Cache 可以复用。

    @classmethod
    def compute_hash(cls, token_ids: list[int], prefix: int = -1):
        h = xxhash.xxh64()
        if prefix != -1:
            h.update(prefix.to_bytes(8, "little"))
        h.update(np.array(token_ids).tobytes())
        return h.intdigest()

每个 block 的 hash 依赖于前一个 block 的 hash,形成链式结构:

Block 0 hash = hash(tokens_0)
Block 1 hash = hash(tokens_1, prefix=Block_0_hash)
Block 2 hash = hash(tokens_2, prefix=Block_1_hash)
...

链式 hash 的设计保证了 KV Cache 的正确性。Transformer 的 attention 是 因果的(causal),每个 token 只能 attend 到之前的 token。

如果中间 block 复用了不同前缀的缓存:

序列 A: [t0, t1, t2, t3, t4, t5, t6, t7]
               ↓
        Block 0: KV(t0,t1,t2,t3)
        Block 1: KV(t4,t5,t6,t7) ← 基于 t0-t3 计算的

序列 B: [x0, x1, x2, x3, t4, t5, t6, t7]
               ↓
        Block 0: KV(x0,x1,x2,x3)  ← 新的
        Block 1: 复用 A 的?       ← 错误!因为 t4-t7 需要 attend 到 x0-x3
                                    而不是 t0-t3

KV Cache 中存的是基于特定上下文计算的 Key 和 Value,前缀不同就不能复用。

同时,链式 hash 设计让查找变得简单:

  • 只需要检查 hash 是否匹配
  • 不需要额外验证前缀一致性

Prefix Cache 工作流程:

请求 1: "Hello, how are you? I am fine."
  Block 0: hash(["Hello", ",", " how", ...])  = 0xABC123
  Block 1: hash([...], prefix=0xABC123)       = 0xDEF456
  
请求 2: "Hello, how are you? What's your name?"
  Block 0: 查找 hash → 命中 0xABC123 → 复用!不需要重新计算
  Block 1: 内容不同 → cache miss → 分配新 block

效果: 请求 2 跳过了 Block 0 的 256 tokens 的 prefill 计算

总结#

功能 方法 调用时机
检查能否分配 can_allocate(seq) Prefill 调度前
分配 blocks allocate(seq) Prefill 调度时
检查能否追加 can_append(seq) Decode 调度前
追加 block may_append(seq) Decode 调度时
释放 blocks deallocate(seq) 完成/抢占时