vLLM Paged Attention¶
目前,vLLM 使用其自己实现的多头查询注意力内核(`csrc/attention/attention_kernels.cu`)。此内核设计用于兼容 vLLM 的分页 KV 缓存,其中键(key)和值(value)缓存存储在单独的块中(请注意,此“块”概念不同于 GPU 线程块。因此,在后续文档中,我将 vLLM 分页注意力块称为“块”,而将 GPU 线程块称为“线程块”)。
为了实现高性能,此内核依赖于专门设计的内存布局和访问方法,尤其是在线程将数据从全局内存读取到共享内存时。本文档旨在逐步提供对内核实现的高级解释,以帮助那些希望了解 vLLM 多头查询注意力内核的人。阅读本文档后,用户可能会更好地理解并更容易地跟随实际的实现。
请注意,本文档可能不会涵盖所有细节,例如如何计算相应数据的正确索引或点乘实现。但是,在阅读本文档并熟悉高级逻辑流程后,您应该更容易阅读实际代码并理解其细节。
输入¶
内核函数接受当前线程执行其分配工作的参数列表。其中最重要的三个参数是输入指针 `q`、`k_cache` 和 `v_cache`,它们指向需要从全局内存读取和处理的查询(query)、键(key)和值(value)数据。输出指针 `out` 指向应写入结果的全局内存。这四个指针实际上引用多维数组,但每个线程只访问分配给它的那部分数据。为了简化,我在此省略了所有其他运行时参数。
template<typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE, int NUM_THREADS, int PARTITION_SIZE = 0>
__device__ void paged_attention_kernel(
... // Other side args.
const scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
... // Other side args.
)
函数签名上方还有一系列模板参数,它们在编译时确定。`scalar_t` 表示查询、键和值数据元素的数据类型,例如 FP16。`HEAD_SIZE` 表示每个头中的元素数量。`BLOCK_SIZE` 指每个块中的 token 数量。`NUM_THREADS` 表示每个线程块中的线程数量。`PARTITION_SIZE` 表示张量并行 GPU 的数量(为简化起见,我们假设此值为 0 且张量并行已禁用)。
有了这些参数,我们需要执行一系列准备工作。这包括计算当前头索引、块索引和其他必要的变量。但是,目前我们可以忽略这些准备工作,直接进行实际计算。一旦我们掌握了整个流程,再理解它们会更容易。
概念¶
在我们深入计算流程之前,我想先描述一些后续部分需要的概念。但是,如果您遇到任何令人困惑的术语,可以跳过此部分并稍后返回。
- 序列(Sequence):序列代表一个客户端请求。例如,`q` 指向的数据形状为 `[num_seqs, num_heads, head_size]`。这表示 `q` 共指向 `num_seqs` 个查询序列数据。由于此内核是单查询注意力内核,每个序列只有一个查询 token。因此,`num_seqs` 等于批处理中处理的 token 总数。
- 上下文(Context):上下文由序列中已生成的 token 组成。例如,`["What", "is", "your"]` 是上下文 token,而输入的查询 token 是 `"name"`。模型可能会生成 token `"?"`。
- 向量(Vec):向量是一组同时获取和计算的元素列表。对于查询和键数据,向量大小(`VEC_SIZE`)的确定是为了使每个线程组可以一次获取和计算 16 字节的数据。对于值数据,向量大小(`V_VEC_SIZE`)的确定是为了使每个线程可以一次获取和计算 16 字节的数据。例如,如果 `scalar_t` 是 FP16(2 字节),`THREAD_GROUP_SIZE` 为 2,则 `VEC_SIZE` 将为 4,而 `V_VEC_SIZE` 将为 8。
- 线程组(Thread group):线程组是一小群线程(`THREAD_GROUP_SIZE`),它们一次获取并计算一个查询 token 和一个键 token。每个线程只处理 token 数据的一部分。一个线程组处理的元素总数称为 `x`。例如,如果线程组包含 2 个线程,头大小为 8,那么线程 0 处理索引为 0, 2, 4, 6 的查询和键元素,而线程 1 处理索引为 1, 3, 5, 7 的元素。
- 块(Block):vLLM 中的键和值缓存数据被分成块。每个块存储一个头中固定数量(`BLOCK_SIZE`)token 的数据。每个块可能只包含整个上下文 token 的一部分。例如,如果块大小为 16 且头大小为 128,那么对于一个头,一个块可以存储 16 * 128 = 2048 个元素。
- Warp:Warp 是 32 个线程(`WARP_SIZE`)的组,它们在流多处理器(SM)上同时执行。在此内核中,每个 warp 一次处理一个查询 token 与一个完整块的键 token 之间的计算(它可能在多次迭代中处理多个块)。例如,如果一个上下文有 4 个 warp 和 6 个块,则分配方式将是:warp 0 处理第 0、第 4 个块,warp 1 处理第 1、第 5 个块,warp 2 处理第 2 个块,warp 3 处理第 3 个块。
- 线程块(Thread block):线程块是一组线程(`NUM_THREADS`),它们可以访问相同的共享内存。每个线程块包含多个 warp(`NUM_WARPS`),在此内核中,每个线程块处理一个查询 token 与整个上下文的键 token 之间的计算。
- 网格(Grid):网格是线程块的集合,并定义了该集合的形状。在此内核中,形状为 `(num_heads, num_seqs, max_num_partitions)`。因此,每个线程块只处理一个头、一个序列和一个分区的计算。
查询(Query)¶
本节将介绍查询数据在内存中如何存储以及如何由每个线程获取。如上所述,每个线程组获取一个查询 token 数据,而每个线程本身只处理一个查询 token 数据的一部分。在每个 warp 内,每个线程组将获取相同的查询 token 数据,但会将其与不同的键 token 数据相乘。

每个线程定义自己的 `q_ptr`,它指向全局内存上分配的查询 token 数据。例如,如果 `VEC_SIZE` 为 4 且 `HEAD_SIZE` 为 128,则 `q_ptr` 指向总共包含 128 个元素(分为 128 / 4 = 32 个向量)的数据。

接下来,我们需要将 `q_ptr` 指向的全局内存数据读取到共享内存中,作为 `q_vecs`。需要注意的是,每个向量都被分配到不同的行。例如,如果 `THREAD_GROUP_SIZE` 为 2,则线程 0 将处理第 0 行的向量,而线程 1 将处理第 1 行的向量。通过这种方式读取查询数据,线程 0 和线程 1 等相邻线程可以读取相邻内存,从而实现内存合并以提高性能。
键(Key)¶
与“查询”部分类似,本节介绍键的内存布局和分配。虽然每个线程组在一次内核运行中只处理一个查询 token,但它可能在多次迭代中处理多个键 token。同时,每个 warp 将在多次迭代中处理多个键 token 块,确保在内核运行后,所有上下文 token 都由整个线程组处理。在此上下文中,“处理”指的是执行查询数据和键数据之间的点乘。
const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride
+ kv_head_idx * kv_head_stride
+ physical_block_offset * x;
与 `q_ptr` 不同,每个线程中的 `k_ptr` 将在不同的迭代中指向不同的键 token。如上所示,`k_ptr` 根据 `k_cache` 在分配的块、分配的头和分配的 token 处指向键 token 数据。

上图说明了键数据的内存布局。它假设 `BLOCK_SIZE` 为 16,`HEAD_SIZE` 为 128,`x` 为 8,`THREAD_GROUP_SIZE` 为 2,并且总共有 4 个 warp。每个矩形表示一个头中一个键 token 的所有元素,这些元素将由一个线程组处理。左半部分显示了 warp 0 的总共 16 个键 token 数据块,而右半部分表示其他 warp 或迭代的剩余键 token 数据。每个矩形内部总共有 32 个向量(一个 token 包含 128 个元素),将由 2 个线程(一个线程组)单独处理。

接下来,我们需要从 `k_ptr` 读取键 token 数据并将其存储在寄存器内存中作为 `k_vecs`。我们使用寄存器内存存储 `k_vecs`,因为它只会被一个线程访问一次,而 `q_vecs` 会被多个线程访问多次。每个 `k_vecs` 将包含多个向量,用于后续计算。每个向量将在每次内部迭代中设置。向量的分配允许 warp 中的相邻线程一起读取相邻内存,这再次促进了内存合并。例如,线程 0 将读取向量 0,而线程 1 将读取向量 1。在下一个内循环中,线程 0 将读取向量 2,而线程 1 将读取向量 3,依此类推。
您可能仍然对整个流程感到有些困惑。不用担心,请继续阅读下一节“QK”。它将以更清晰、更高级的方式说明查询和键的计算流程。
QK¶
如下图伪代码所示,在整个 for 循环块之前,我们获取一个 token 的查询数据并将其存储在 `q_vecs` 中。然后,在外层 for 循环中,我们迭代指向不同 token 的不同 `k_ptrs`,并在内层 for 循环中准备 `k_vecs`。最后,我们对 `q_vecs` 和每个 `k_vecs` 执行点乘。
q_vecs = ...
for ... {
k_ptr = ...
for ... {
k_vecs[i] = ...
}
...
float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
}
如前所述,对于每个线程,它一次只获取部分查询和键 token 数据。然而,在 `Qk_dot<>::dot` 中会发生跨线程组的规约(reduction)。因此,这里返回的 `qk` 不仅仅是部分查询和键 token 的点乘结果,而实际上是整个查询和键 token 数据之间的完整结果。
例如,如果 `HEAD_SIZE` 的值为 128 且 `THREAD_GROUP_SIZE` 为 2,则每个线程的 `k_vecs` 将总共包含 64 个元素。然而,返回的 `qk` 实际上是 128 个查询元素和 128 个键元素之间点乘的结果。如果您想了解更多关于点乘和规约的细节,可以参考 `Qk_dot<>::dot` 的实现。但是,为简化起见,本文档中将不予介绍。
Softmax¶
接下来,我们需要计算所有 `qk` 的归一化 Softmax,如上所示,其中每个代表一个 `qk`。为此,我们必须获得 `qk_max` 的规约值()和 `exp_sum`()的所有 `qk` 值。规约应在整个线程块中执行,涵盖查询 token 和所有上下文键 token 之间的结果。
`qk_max` 和 `logits`¶
在我们获得 `qk` 结果后,我们可以立即用 `qk` 设置临时的 `logits` 结果(最终,`logits` 应该存储归一化 Softmax 结果)。同时,我们还可以比较并收集当前线程组计算出的所有 `qk` 中的 `qk_max`。
if (thread_group_offset == 0) {
const bool mask = token_idx >= context_len;
logits[token_idx - start_token_idx] = mask ? 0.f : qk;
qk_max = mask ? qk_max : fmaxf(qk_max, qk);
}
请注意,这里的 `logits` 位于共享内存中,因此每个线程组将为其自己分配的上下文 token 设置字段。总的来说,logits 的大小应为上下文 token 的数量。
for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
}
if (lane == 0) {
red_smem[warp_idx] = qk_max;
}
然后我们需要获取每个 warp 的规约 `qk_max`。主要思想是让 warp 中的线程相互通信并获得最终的 `qk` 最大值。
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
}
qk_max = VLLM_SHFL_SYNC(qk_max, 0);
最后,通过比较此线程块中所有 warp 的 `qk_max`,我们可以从整个线程块中获得规约后的 `qk_max`。然后我们需要将最终结果广播给每个线程。
`exp_sum`¶
与 `qk_max` 类似,我们还需要从整个线程块中获取规约后的求和值。
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
float val = __expf(logits[i] - qk_max);
logits[i] = val;
exp_sum += val;
}
...
exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);
首先,对每个线程组的所有 exp 值求和,同时将 `logits` 的每个条目从 `qk` 转换为 `exp(qk - qk_max)`。请注意,这里的 `qk_max` 已经是整个线程块中的最大 `qk`。然后,我们可以像处理 `qk_max` 一样,对整个线程块进行 `exp_sum` 的规约。
const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
logits[i] *= inv_sum;
}
最后,通过规约后的 `qk_max` 和 `exp_sum`,我们可以得到最终的归一化 Softmax 结果作为 `logits`。这个 `logits` 变量将在后续步骤中用于与值数据进行点乘。现在,它应该存储所有分配的上下文 token 的 `qk` 的归一化 Softmax 结果。
值(Value)¶



现在我们需要检索值数据并与 `logits` 进行点乘。与查询和键不同,值数据没有线程组的概念。如图所示,与键 token 的内存布局不同,来自同一列的元素对应于相同的值 token。对于一个值数据块,有 `HEAD_SIZE` 行和 `BLOCK_SIZE` 列,它们被分成多个 `v_vecs`。
每个线程总是从相同数量的 `V_VEC_SIZE` 个 token 中一次获取 `V_VEC_SIZE` 个元素。因此,单个线程通过多次内部迭代从不同行和相同列中检索多个 `v_vec`。对于每个 `v_vec`,它需要与相应的 `logits_vec` 进行点乘,后者也是 `logits` 中的 `V_VEC_SIZE` 个元素。总的来说,通过多次内部迭代,每个 warp 将处理一个值 token 块。通过多次外部迭代,将处理整个上下文的值 token。
float accs[NUM_ROWS_PER_THREAD];
for ... { // Iteration over different blocks.
logits_vec = ...
for ... { // Iteration over different rows.
v_vec = ...
...
accs[i] += dot(logits_vec, v_vec);
}
}
如以上伪代码所示,在外层循环中,与 `k_ptr` 类似,`logits_vec` 遍历不同的块并从 `logits` 中读取 `V_VEC_SIZE` 个元素。在内层循环中,每个线程从相同的 token 中读取 `V_VEC_SIZE` 个元素作为 `v_vec` 并执行点乘。重要的是要注意,在每次内部迭代中,线程为相同的 token 获取不同的头位置元素。点乘结果随后累积在 `accs` 中。因此,`accs` 的每个条目都映射到分配给当前线程的头位置。
例如,如果 `BLOCK_SIZE` 为 16 且 `V_VEC_SIZE` 为 8,则每个线程一次为 8 个 token 获取 8 个值元素。每个元素都来自不同 token 的相同头位置。如果 `HEAD_SIZE` 为 128 且 `WARP_SIZE` 为 32,则在每个内循环中,一个 warp 需要获取 `WARP_SIZE * V_VEC_SIZE = 256` 个元素。这意味着一个 warp 总共需要进行 128 * 16 / 256 = 8 次内部迭代来处理一个完整的值 token 块。每个线程中的 `accs` 包含 8 个元素,这些元素累积在 8 个不同的头位置。对于线程 0,`accs` 变量将包含 8 个元素,它们是值头中的第 0、第 32 … 第 224 个元素,这些元素从所有分配的 8 个 token 中累积而来。
LV¶
现在,我们需要在每个 warp 内对 `accs` 执行规约。此过程允许每个线程累积一个块中所有 token 的分配头位置的 `accs`。
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
float acc = accs[i];
for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
acc += VLLM_SHFL_XOR_SYNC(acc, mask);
}
accs[i] = acc;
}
接下来,我们对所有 warp 的 `accs` 执行规约,从而使每个线程都能累积所有上下文 token 的分配头位置的 `accs`。请注意,每个线程中的每个 `accs` 只存储整个头在所有上下文 token 中的一部分元素的累积。但是,总的来说,所有输出结果都已计算出来,只是存储在不同的线程寄存器内存中。
代码
float* out_smem = reinterpret_cast<float*>(shared_mem);
for (int i = NUM_WARPS; i > 1; i /= 2) {
// Upper warps write to shared memory.
...
float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
...
dst[row_idx] = accs[i];
}
// Lower warps update the output.
const float* src = &out_smem[warp_idx * HEAD_SIZE];
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
...
accs[i] += src[row_idx];
}
// Write out the accs.
}
输出¶
现在我们可以将所有计算结果从局部寄存器内存写入最终的输出全局内存。
scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
+ head_idx * max_num_partitions * HEAD_SIZE
+ partition_idx * HEAD_SIZE;
首先,我们需要定义 `out_ptr` 变量,它指向分配的序列和分配的头的起始地址。
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
from_float(*(out_ptr + row_idx), accs[i]);
}
}
最后,我们需要遍历不同的分配头位置,并根据 `out_ptr` 写入相应的累积结果。