今天的主要成果是写了一个pseudo的Continous Batching框架。pseudo的意思是,engine内部的调用其实仍然是按照loop的模式一个个进行,但scheduler的功能本身是完善的,step模式的完整流程也已经写完。这么做主要是因为paged Attention本身作为框架的必须组件还没有实现——在实现paged Attention的过程中,从GQAAttention算子的模式本身到engine到KV cache的cache类和BlockManager全部都得重构或者从头开始写——因此在一个必然要推翻的框架内写一版必然要在接下来几天内全部重构的代码没什么意义,不如接下来直接实现完整的既有paged又有batching的版本。
毕竟我又没有交付压力🤓。
Continous Batching是什么
是一种可以增加推理框架调度效率的方法。在这个技术普及之前,Batching作为提高推理效率的技术已经被广泛应用。因此,要首先理解Batching在推理当中为什么可以提高推理效率。
一些博客或者技术分享可能会说Batching能提高算术强度什么的,但这其实是一种误解,提高算术强度的技术是speculative decoding,而不是Batching技术。这是因为Batching产生的矩阵其实是对角线分块的:互相并不产生计算交叠,反而会产生大量0的空洞,这一点(在naive Kernel里)通过Attention mask遮盖。同时,KV cache本身的访问也是不会减少的,每一个之前算过的依然要完整读一遍。它最大的意义其实在于减少了权重的读取次数——原本计算一个sequence,吐一个token就要从主存(低层次的存储层次)读一遍完整的权重到高层次当中,现在计算$N$个sequence才需要读一遍,中间的过程均驻留在高层次存储里不离开。对于权重很大的大模型来说,这个效率提升其实是非常重要的。从这个角度来说,Continous Batching计的优化思想根基和GEMM矩阵乘法的优化也没有根本性的差别——计算机基本原理,加速大概率事件和有效利用时空局部性的又一个良好体现。
Batching产生了另一个问题:不同的sequence不太可能是同等长度的,而且不太可能同时吐出eos,到达停止状态。将类似长度的sequence打包在一起可以部分地解决第一个问题,但不能解决第二个问题,这个问题同样被后面要做的paged Attention大部分地解决。这就导致一个结果:在一个batch中的一部分sequence已经吐出了eos,变成停止状态了,而另外一些则没有。只要一个sequence还在跑,整个batch是不能停止的。这个时候,batch可能反而变成负收益:许多读写已经是完全空转的了,只是为了陪跑还在自回归过程的sequence,但内存读写消耗反而一点也没有减少。Continous Batching解决的是第二个问题。
所以实际上,Continous Batching和paged Attention本就是深度耦合的两个技术:如果没有paged Attention,Continous Batching每次都要完整的重新分配和整理内存,那么Continous Batching本身就毫无意义了。反过来说,没有Batching和Continous Batching,在单序列请求下,内存碎片这个问题可能还没有那么严重,paged Attention的威力也完全无法有效发挥。试想:在有了paged Attention之后,Continous Batching其实完全不需要进行任何内存操作:只需要把换出的sequence的KV cache页表索引换掉,变成下一个sequence的就可以了,也就是说对sequence本身其实无感知。这也难怪这两件事被同一时间提出,作者在研究的过程中很可能遇到了这个问题然后把它想明白了。
Continous Batching的设计、实现和调度策略
其实这个技术的基本实现也在思路上比较顺畅。它的基本思想是把所有sequence分成waiting、prefilling、decoding和finished四种状态(也有的实现是waiting、running和finished三种,取决于你如何写你的推理engine和kernel,两者是耦合的),以及engine的step模式设计。通过一个scheduler的调度,scheduler管理所有请求、赋予每一个请求一个unique的id编号,并且决定下一个step处理哪些,如何改变序列的状态,并且维持每个step内的负载大致均衡。engine则负责进行prefilling、decoding,并且检查decoding序列中的情况,将已经达到max_len的和已经生成出eos的序列中的sequence移出,进入finished,而finished则负责记录生成的结果、等待外部请求获取,并且提供将finished队列清空的方法,实际上是为了方便外部请求的设计。
在这个过程中有意思的反而是scheduler的调度策略设计:由于prefilling和decoding的计算资源需求并不相同(其实这里已经初步开始产生PD分离的思想了:异质化的负载调度总是比较难处理的),scheduler如何同时确保效率和公平性就是一个问题。prefilling一次应该多少?decoding一次应该多少?这个问题相当open,理论上来说应该是相当值得研究的(尤其是对于生产环境、ToC环境来说,因为它会极大的影响TTFT指标)。在目前的版本中,为了简单起见,实现的是先到先服务(FCFS)和Prefilling+Decoding中的sequence数量总和固定。这是一个粗糙的调度,不过对于框架的可运行性没有影响。也许后面可以考虑把策略拿出来单独设计。
Continous Batching的相关代码
主要的类有两个,分别是Request和Scheduler。Request记录一个请求的基本信息,Scheduler基于前者的基本信息进行调度。需要注意的是,在我的框架里我本人选择了把KVcache和Request放在一起管理。其他框架似乎不是这样,但按这个模式似乎也没什么问题。
Request类:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
| class RequestStatus(Enum):
WAITING = "waiting" # 在 waiting 队列,还没做 prefill
PREFILL = "prefill" # 本步正在做 prefill
DECODING = "decoding" # 已经 prefill 完,正在 decode
FINISHED = "finished" # 已完成
''' 请求对象,包含请求的所有信息和状态
- request_id: 请求 ID,唯一标识一个请求
- input_ids: 输入的 token ids,shape (1, init_seq_len),包含整个 prompt
- generated_ids: 已经生成的 token ids,shape (1, ),初始为 空
- max_new_tokens: 最多生成多少个 token
- temperature: 采样温度,传递给 sampler
- top_p: top-p 截断,传递给 sampler
- kv_cache: 每个请求独享一个 KV cache 实例,存储生成过程中的 KV 状态
'''
class Request:
request_id: int
input_ids: List[int] # (1, init_seq_len),包含整个 prompt
generated_ids: List[int] # (1, ),包含已经生成的 token ids,初始为 空
max_new_tokens: int
temperature: float
top_p: float
kv_cache: Optional[KVCache] # 每个请求独享一个 KV cache 实例
request_status: RequestStatus
has_eos_token: bool # 是否已经生成 eos_token,scheduler 不直接接触 tokenizer 和 eos_token_id,这个由 engine 在 decode_step 后更新
def __init__(self, request_id: int, input_ids: List[int], max_new_tokens: int, temperature: float, top_p: float):
self.request_id = request_id
self.input_ids = input_ids
self.generated_ids = [] # 初始为 空
self.request_status = RequestStatus.WAITING
self.max_new_tokens = max_new_tokens
self.temperature = temperature
self.top_p = top_p
self.kv_cache = None
self.has_finished_notification = False # engine改变这个状态,scheduler根据这个状态改变 request_status和移出队列
def is_max_len_finished(self) -> bool:
return len(self.generated_ids) >= self.max_new_tokens
@property
def prompt_len(self) -> int:
return len(self.input_ids)
@property
def total_len(self) -> int:
return len(self.input_ids) + len(self.generated_ids)
|
Scheduler类:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
| class Scheduler:
_next_request_id: int # 用于生成唯一的 request_id
max_num_seqs: int # 同时处理的最大请求数,超过这个数的新请求会排队等待
kv_cache_cls : type[KVCache] # KVCache 的类,用于创建请求的 KV cache 实例
kv_cache_kwargs = {} # 创建 KV cache 实例的参数字典
waiting: List[Request] # 还没开始处理的请求
prefilling: List[Request] # 正在 prefill 的请求
decoding: List[Request] # 正在 decode 的请求
finished: List[Request] # 已完成的请求
def __init__(self, kv_cache_cls: type[KVCache], kv_cache_kwargs={}, _next_request_id: int = 0, max_num_seqs: int = 4):
self._next_request_id = _next_request_id
self.max_num_seqs = max_num_seqs
self.kv_cache_cls = kv_cache_cls
self.kv_cache_kwargs = kv_cache_kwargs
self.waiting = []
self.prefilling = []
self.decoding = []
self.finished = []
''' 插入新请求
- input_ids: 输入的 token ids,shape (1, seq_len),包含整个 prompt
- max_new_tokens: 最多生成多少个 token
- temperature: 采样温度,传递给 sampler
- top_p: top-p 截断,传递给 sampler
- cache_cls: KVCache 的类,用于创建请求的 KV cache 实例
- cache_kwargs: 创建 KV cache 实例的参数字典
return: request_id,唯一标识一个请求'''
def insert_request(self,
input_ids: List[int],
max_new_tokens: int,
temperature: float,
top_p: float, ) -> int:
request = self.create_request(input_ids, max_new_tokens, temperature, top_p, self.kv_cache_cls, self.kv_cache_kwargs)
self.add_request(request)
return request.request_id
'''' 创建新的请求对象
- input_ids: 输入的 token ids,shape (1, seq_len),包含整个 prompt
- max_new_tokens: 最多生成多少个 token
- temperature: 采样温度,传递给 sampler
- top_p: top-p 截断,传递给 sampler
- cache_cls: KVCache 的类,用于创建请求的 KV cache 实例
- cache_kwargs: 创建 KV cache 实例的参数字典
return: request_id,唯一标识一个请求'''
def create_request(self,
input_ids: List[int],
max_new_tokens: int,
temperature: float,
top_p: float,
cache_cls,
cache_kwargs) -> Request:
request_id = self._next_request_id
self._next_request_id += 1
return Request(request_id, input_ids, max_new_tokens, temperature, top_p)
'''' 添加新请求到等待队列
- request: 新的请求对象
'''
def add_request(self, request: Request):
self.waiting.append(request)
''' 调度器主循环
return: 2个列表,分别是当前处于 prefilling、decoding 状态的请求列表
- prefilling: 正在 prefill 的请求
- decoding: 正在 decode 的请求
'''
def schedule(self) -> tuple[List[Request], List[Request]]:
# 1. 对 decoding 队列中的请求进行 decode,完成后移动到 finished 队列
for request in self.decoding[:]:
# is_max_len_finished 由 scheduler判断
# has_finished_notification 由 engine 在 decode_step 后更新,解耦两者的逻辑,scheduler 不直接接触 tokenizer 和 eos_token_id
if request.has_finished_notification == True:
request.request_status = RequestStatus.FINISHED
self.decoding.remove(request)
self.finished.append(request)
# 2. 对 prefilling 队列中的请求进行 prefill,完成后移动到 decoding 队列
for request in self.prefilling[:]:
# prefill 完成后:
request.request_status = RequestStatus.DECODING
self.prefilling.remove(request)
self.decoding.append(request)
# 3. 从 waiting 队列中取出请求,放入 prefilling 队列,直到达到 max_batch_size
# 不进行kv_cache的创建和初始化
while self.waiting and self.num_in_progress < self.max_num_seqs:
# 先到先服务策略
request = self.waiting.pop(0)
request.request_status = RequestStatus.PREFILL
self.prefilling.append(request)
# kv cache必须定长,否则报错
# request.kv_cache = self.kv_cache_cls(**self.kv_cache_kwargs)
return self.prefilling, self.decoding
'''' 清空 finished 队列,释放资源'''
def clear_finished(self):
self.finished.clear()
def get_running_requests(self) -> List[Request]:
return self.prefilling + self.decoding
@property
def num_waiting(self) -> int:
return len(self.waiting)
@property
def num_prefilling(self) -> int:
return len(self.prefilling)
@property
def num_decoding(self) -> int:
return len(self.decoding)
@property
def num_in_progress(self) -> int:
return len(self.prefilling) + len(self.decoding)
@property
def num_finished(self) -> int:
return len(self.finished)
|
值得指出的是id用int是有明确缺陷的:str和uuid作为id很可能是更好的选项。在这里用int只是因为它简单,而且我的卡大概也不会有超过2147483647个请求。
测试一遍通过,这里就不贴了。接下来就是实现真正的paged Attention了,个人认为它的意义可能不太亚于KV cache本身,直接开启了一小个时代级别的想法。