Paged Attention¶
警告
这是一份基于 vLLM 原始论文 的历史文档。它已不再反映 vLLM 当前使用的代码。
目前,vLLM 使用其自定义实现的多头查询注意力内核(csrc/attention/attention_kernels.cu)。该内核旨在兼容 vLLM 的分页 KV 缓存(Paged KV Cache),其中 Key 和 Value 缓存存储在独立的块中(注意,此处的“块”概念不同于 GPU 线程块。因此,在后续文档中,我将 vLLM 的分页注意力块称为“block”,而将 GPU 线程块称为“thread block”)。
为了实现高性能,该内核依赖于专门设计的内存布局和访问方法,特别是在线程从全局内存读取数据到共享内存时。本文件的目的是分步骤地对内核实现进行高层解释,以帮助那些希望了解 vLLM 多头查询注意力内核的人。阅读完本文档后,用户通常会更好地理解并更容易上手跟踪实际代码。
请注意,本文档可能未涵盖所有细节,例如如何计算对应数据的正确索引或点积的具体实现。然而,在阅读本文档并熟悉高层逻辑流程后,您将更容易阅读实际代码并理解其中的细节。
输入¶
内核函数接收一系列参数,供当前线程执行其分配的任务。其中最重要的三个参数是输入指针 q、k_cache 和 v_cache,它们指向需要读取和处理的全局内存中的 Query、Key 和 Value 数据。输出指针 out 指向应写入结果的全局内存。这四个指针实际上引用的是多维数组,但每个线程仅访问分配给它的那部分数据。为了简洁起见,我在此省略了所有其他运行时参数。
template<typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE, int NUM_THREADS, int PARTITION_SIZE = 0>
__device__ void paged_attention_kernel(
... // Other side args.
const scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
... // Other side args.
)
函数签名上方还有一系列在编译时确定的模板参数。scalar_t 代表 Query、Key 和 Value 数据元素的数据类型,例如 FP16。HEAD_SIZE 表示每个头部的元素数量。BLOCK_SIZE 指的是每个块中的 token 数量。NUM_THREADS 表示每个线程块中的线程数量。PARTITION_SIZE 代表张量并行 GPU 的数量(为简单起见,我们假设该值为 0,即禁用了张量并行)。
有了这些参数,我们需要执行一系列准备工作。这包括计算当前的头部索引、块索引以及其他必要的变量。然而,现在我们可以忽略这些准备工作,直接进入实际计算。一旦我们掌握了整个流程,理解这些准备工作就会变得容易得多。
概念¶
在深入计算流程之前,我想先描述一下后续章节中需要的几个概念。不过,如果您遇到不熟悉的术语,也可以跳过此部分,以后再回来看。
- Sequence (序列):一个序列代表一个客户端请求。例如,
q指向的数据形状为[num_seqs, num_heads, head_size]。这表示q指向的数据总共有num_seqs个查询序列。由于该内核是单查询注意力内核,每个序列只有一个查询 token。因此,num_seqs等于批处理中处理的 token 总数。 - Context (上下文):上下文由序列生成的 token 组成。例如,
["What", "is", "your"]是上下文 token,而输入查询 token 是"name"。模型可能会生成 token"?"。 - Vec:Vec 是一组被一起获取和计算的元素列表。对于 Query 和 Key 数据,Vec 大小(
VEC_SIZE)的确定是为了让每个线程组每次能够获取和计算 16 字节的数据。对于 Value 数据,Vec 大小(V_VEC_SIZE)的确定是为了让每个线程每次能够获取和计算 16 字节的数据。例如,如果scalar_t是 FP16(2 字节)且THREAD_GROUP_SIZE为 2,则VEC_SIZE为 4,而V_VEC_SIZE为 8。 - Thread group (线程组):线程组是一小组线程(
THREAD_GROUP_SIZE),它们一次获取并计算一个查询 token 和一个 Key token。每个线程只处理 token 数据的一部分。一个线程组处理的元素总数称为x。例如,如果线程组包含 2 个线程且头部大小为 8,则线程 0 处理索引为 0, 2, 4, 6 的 Query 和 Key 元素,而线程 1 处理索引为 1, 3, 5, 7 的元素。 - Block (块):vLLM 中的 Key 和 Value 缓存数据被分割成块。每个块存储一个头部固定数量(
BLOCK_SIZE)的 token 数据。每个块可能只包含全部上下文 token 的一部分。例如,如果块大小为 16,头部大小为 128,那么对于一个头部,一个块可以存储 16 * 128 = 2048 个元素。 - Warp:Warp 是一组 32 个线程(
WARP_SIZE),它们在流式多处理器(SM)上同时执行。在此内核中,每个 Warp 一次处理一个查询 token 与一个完整块的 Key token 之间的计算(它可能会在多次迭代中处理多个块)。例如,如果一个上下文有 4 个 Warp 和 6 个块,分配方式可能是:Warp 0 处理第 0、4 块,Warp 1 处理第 1、5 块,Warp 2 处理第 2 块,Warp 3 处理第 3 块。 - Thread block (线程块):线程块是一组线程(
NUM_THREADS),它们可以访问相同的共享内存。每个线程块包含多个 Warp(NUM_WARPS),在此内核中,每个线程块处理一个查询 token 与整个上下文 Key token 之间的计算。 - Grid:Grid 是线程块的集合,定义了该集合的形状。在此内核中,形状为
(num_heads, num_seqs, max_num_partitions)。因此,每个线程块仅处理一个头部、一个序列和一个分区的计算。
Query (查询)¶
本节将介绍查询数据如何存储在内存中并被每个线程获取。如上所述,每个线程组获取一个查询 token 的数据,而每个线程本身只处理一个查询 token 数据的一部分。在每个 Warp 中,每个线程组将获取相同的查询 token 数据,但会将其与不同的 Key token 数据相乘。
每个线程定义自己的 q_ptr,指向分配给它的全局内存中的查询 token 数据。例如,如果 VEC_SIZE 为 4 且 HEAD_SIZE 为 128,则 q_ptr 指向总共包含 128 个元素的数据,这些元素被划分为 128 / 4 = 32 个 Vecs。
接下来,我们需要将 q_ptr 指向的全局内存数据读取到共享内存中作为 q_vecs。值得注意的是,每个 Vec 被分配到不同的行。例如,如果 THREAD_GROUP_SIZE 为 2,线程 0 将处理第 0 行 Vecs,而线程 1 处理第 1 行 Vecs。通过这种方式读取查询数据,相邻的线程(如线程 0 和线程 1)可以读取相邻的内存,从而实现内存合并以提高性能。
Key (键)¶
与“Query”部分类似,本节介绍 Key 的内存布局和分配。虽然每个线程组在一次内核运行中只处理一个查询 token,但它可能会在多次迭代中处理多个 Key token。同时,每个 Warp 将在多次迭代中处理多个 Key token 块,确保在内核运行结束后,所有上下文 token 都被整个线程组处理。在此上下文中,“处理”是指执行查询数据和 Key 数据之间的点积。
const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride
+ kv_head_idx * kv_head_stride
+ physical_block_offset * x;
与 q_ptr 不同,每个线程中的 k_ptr 在不同的迭代中指向不同的 Key token。如上所示,k_ptr 根据分配的块、分配的头部和分配的 token 指向 k_cache 中的 Key token 数据。
上图说明了 Key 数据的内存布局。它假设 BLOCK_SIZE 为 16,HEAD_SIZE 为 128,x 为 8,THREAD_GROUP_SIZE 为 2,且总共有 4 个 Warp。每个矩形代表一个头部中一个 Key token 的所有元素,由一个线程组处理。左半部分显示了 Warp 0 的全部 16 块 Key token 数据,右半部分代表其他 Warp 或迭代的剩余 Key token 数据。在每个矩形内,总共有 32 个 Vecs(一个 token 128 个元素),将由 2 个线程(一个线程组)分别处理。
接下来,我们需要从 k_ptr 读取 Key token 数据并将其存储在寄存器内存中作为 k_vecs。我们使用寄存器内存存储 k_vecs,因为它只会被一个线程访问一次,而 q_vecs 会被多个线程多次访问。每个 k_vecs 将包含多个后续计算所需的向量。每个 Vec 将在每次内部迭代中设置。Vec 的分配允许 Warp 中的相邻线程一起读取相邻内存,这再次促进了内存合并。例如,线程 0 读取 Vec 0,而线程 1 读取 Vec 1。在下一个内部循环中,线程 0 读取 Vec 2,而线程 1 读取 Vec 3,依此类推。
您可能仍然对整体流程感到困惑。别担心,请继续阅读下一节“QK”。它将以更清晰、更高层的方式说明查询和键的计算流程。
QK¶
如下方伪代码所示,在整个 for 循环块之前,我们获取一个 token 的查询数据并将其存储在 q_vecs 中。然后,在外部 for 循环中,我们遍历指向不同 token 的不同 k_ptrs,并在内部 for 循环中准备 k_vecs。最后,我们执行 q_vecs 和每个 k_vecs 之间的点积。
q_vecs = ...
for ... {
k_ptr = ...
for ... {
k_vecs[i] = ...
}
...
float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
}
如前所述,每个线程一次只获取部分查询和 Key token 数据。然而,在 Qk_dot<>::dot 中会发生跨线程组的归约(reduction)。因此,这里返回的 qk 不仅仅是查询和 Key token 数据部分点积的结果,实际上是整个查询和 Key token 数据之间点积的完整结果。
例如,如果 HEAD_SIZE 为 128 且 THREAD_GROUP_SIZE 为 2,每个线程的 k_vecs 总共包含 64 个元素。然而,返回的 qk 实际上是 128 个查询元素和 128 个 Key 元素之间点积的结果。如果您想了解关于点积和归约的更多细节,可以参考 Qk_dot<>::dot 的实现。为了简洁起见,我不会在本文档中涵盖这些细节。
Softmax¶
接下来,我们需要计算所有 qk 的归一化 Softmax,如下所示,其中每个 \(x\) 代表一个 qk。为此,我们必须获得所有 qk 的 qk_max 的归约值(\(m(x)\))和 exp_sum(\(\ell(x)\))。归约应在整个线程块中执行,涵盖查询 token 与所有上下文 Key token 之间的结果。
qk_max 和 logits¶
在获得 qk 结果后,我们可以用 qk 设置临时的 logits 结果(最终,logits 应该存储归一化后的 Softmax 结果)。此外,我们可以比较并收集当前线程组计算出的所有 qk 的 qk_max。
if (thread_group_offset == 0) {
const bool mask = token_idx >= context_len;
logits[token_idx - start_token_idx] = mask ? 0.f : qk;
qk_max = mask ? qk_max : fmaxf(qk_max, qk);
}
请注意,此处的 logits 位于共享内存中,因此每个线程组将为其分配的上下文 token 设置字段。总体而言,logits 的大小应为上下文 token 的数量。
for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
}
if (lane == 0) {
red_smem[warp_idx] = qk_max;
}
然后,我们需要获得跨每个 Warp 的归约后的 qk_max。主要思想是让 Warp 中的线程相互通信,并获得最终的最大 qk。
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
}
qk_max = VLLM_SHFL_SYNC(qk_max, 0);
最后,我们可以通过比较该线程块中所有 Warp 的 qk_max,获得整个线程块的归约后 qk_max。然后我们需要将最终结果广播给每个线程。
exp_sum¶
与 qk_max 类似,我们也需要获得整个线程块的归约后求和值。
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
float val = __expf(logits[i] - qk_max);
logits[i] = val;
exp_sum += val;
}
...
exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);
首先,对每个线程组的所有 exp 值求和,同时将 logits 的每个条目从 qk 转换为 exp(qk - qk_max)。请注意,此处的 qk_max 已经是整个线程块的最大 qk。然后,我们可以像 qk_max 一样对整个线程块进行 exp_sum 的归约。
const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
logits[i] *= inv_sum;
}
最后,利用归约后的 qk_max 和 exp_sum,我们可以获得最终的归一化 Softmax 结果作为 logits。此 logits 变量将在后续步骤中用于与 Value 数据进行点积。现在,它应该存储所有已分配上下文 token 的 qk 的归一化 Softmax 结果。
Value (值)¶
现在我们需要检索 Value 数据并与 logits 执行点积。与 Query 和 Key 不同,Value 数据没有线程组的概念。如图所示,与 Key token 内存布局不同,来自同一列的元素对应于同一个 Value token。对于一个 Value 数据块,有 HEAD_SIZE 行和 BLOCK_SIZE 列,被分割成多个 v_vecs。
每个线程总是同时从相同的 V_VEC_SIZE 个 token 中获取 V_VEC_SIZE 个元素。因此,单个线程通过多次内部迭代从不同的行和相同的列检索多个 v_vecs。对于每个 v_vec,它需要与相应的 logits_vec 进行点积,后者也是来自 logits 的 V_VEC_SIZE 个元素。总体而言,通过多次内部迭代,每个 Warp 将处理一个 Value token 块。通过多次外部迭代,处理整个上下文的 Value token。
float accs[NUM_ROWS_PER_THREAD];
for ... { // Iteration over different blocks.
logits_vec = ...
for ... { // Iteration over different rows.
v_vec = ...
...
accs[i] += dot(logits_vec, v_vec);
}
}
如上方的伪代码所示,在外部循环中,类似于 k_ptr,logits_vec 遍历不同的块并从 logits 中读取 V_VEC_SIZE 个元素。在内部循环中,每个线程从相同的 token 读取 V_VEC_SIZE 个元素作为 v_vec 并执行点积。重要的是,在每次内部迭代中,线程都会为相同的 token 获取不同的头部位置元素。点积结果随后累加到 accs 中。因此,accs 的每个条目都映射到分配给当前线程的头部位置。
例如,如果 BLOCK_SIZE 为 16 且 V_VEC_SIZE 为 8,每个线程一次获取 8 个 token 的 8 个 Value 元素。每个元素来自相同头部位置的不同 token。如果 HEAD_SIZE 为 128 且 WARP_SIZE 为 32,则对于每次内部循环,Warp 需要获取 WARP_SIZE * V_VEC_SIZE = 256 个元素。这意味着 Warp 处理整个 Value token 块总共需要 8 次内部迭代。每个线程中的每个 accs 包含 8 个元素,这些元素累加在 8 个不同的头部位置。对于线程 0,accs 变量将有 8 个元素,即来自所有 8 个已分配 token 的 Value 头部第 0, 32, ..., 224 个元素。
LV¶
现在,我们需要在每个 Warp 内对 accs 执行归约。此过程允许每个线程为其在一个块中所有 token 的分配头部位置累加 accs。
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
float acc = accs[i];
for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
acc += VLLM_SHFL_XOR_SYNC(acc, mask);
}
accs[i] = acc;
}
接下来,我们在所有 Warp 之间对 accs 执行归约,使每个线程能够获得所有上下文 token 的分配头部位置的 accs 累加值。请注意,每个线程中的每个 accs 仅存储所有上下文 token 的整个头部部分元素的累加值。然而,总体而言,输出的所有结果都已经计算出来,只是存储在不同的线程寄存器内存中。
代码
float* out_smem = reinterpret_cast<float*>(shared_mem);
for (int i = NUM_WARPS; i > 1; i /= 2) {
// Upper warps write to shared memory.
...
float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
...
dst[row_idx] = accs[i];
}
// Lower warps update the output.
const float* src = &out_smem[warp_idx * HEAD_SIZE];
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
...
accs[i] += src[row_idx];
}
// Write out the accs.
}
Output (输出)¶
现在我们可以将所有计算出的结果从本地寄存器内存写入最终的输出全局内存。
scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
+ head_idx * max_num_partitions * HEAD_SIZE
+ partition_idx * HEAD_SIZE;
首先,我们需要定义 out_ptr 变量,它指向分配的序列和分配的头部的起始地址。
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
from_float(*(out_ptr + row_idx), accs[i]);
}
}
最后,我们需要遍历不同的分配头部位置,并根据 out_ptr 写出相应的累加结果。
引用¶
@inproceedings{kwon2023efficient,
title={Efficient Memory Management for Large Language Model Serving with PagedAttention},
author={Woosuk Kwon and Zhuohan Li and Siyuan Zhuang and Ying Sheng and Lianmin Zheng and Cody Hao Yu and Joseph E. Gonzalez and Hao Zhang and Ion Stoica},
booktitle={Proceedings of the ACM SIGOPS 29th Symposium on Operating Systems Principles},
year={2023}
}






