亚马逊AWS官方博客

使用 Rolling Batch 加速 SageMaker LLM 模型推理性能

业务场景 & 背景介绍

对于 LLM 推理的 GenAI 实际生产应用,其推理的时延和吞吐量是非常重要的性能指标。一方面推理输出的响应时间(时延)越短,客户端的体验越好;一方面同样的时间 GenAI 应用能推理生成的 tokens 数量越多,则意味着同样资源开销下吞吐量更大,其性价比更高。

然而这两点在实施落地上却是痛点和难点,因为对于动则几十上百亿参数的 LLM 模型,其推理生成时 GPU 显存和计算的成本非常高,而且很多场景,如文案生成、报告解读等,输入的 token 长度和输出的 token 长度都超过 1K 甚至达到 10K。tokens 的长度直接关系到 LLM 推理计算时资源的开销(光加载一个 token 大概需要 1M 显存,1K token 即消耗 1G 显存),这还没有考虑用户请求量激增,并发量增长的时候,如何在资源有限的情况下,尽可能地增加处理吞吐量,同时控制响应时延。这是摆在 GenAI 生产落地层面的现实问题。

本文介绍了近期业界新的 Rolling Batch(continually batch)的批处理推理优化技术原理,并给出了在 Amazon SageMaker 上使用 vLLM 框架进行 Rolling Batch 推理优化的实践和测试对比,可以帮助客户在实际生产场景中通过简单配置,立竿见影地提升线上部署的 LLM 的推理吞吐量,降低响应时延,节省资源。

Rolling Batch 原理

对于 LLM 推理优化,通常想到的通过量化、自定义 CUDA 内核等方式进行优化的“黑盒”。然而,情况并非完全如此。

LLM 推理是内存 IO 约束,而不是计算约束。换句话说,目前将 1MB 数据加载到 GPU 的计算核心上所花费的时间比这些计算核心对 1MB 数据进行 LLM 计算所花费的时间要多。这意味着 LLM 推理吞吐量在很大程度上取决于您可以在高带宽 GPU 内存中适应多大的批处理量。所以当我们使用批处理优化的时候,在实际工作负载中可产生 10 倍或更多的令人惊讶的性能差异。

批处理优化简单来说,是指您不必每次有输入序列时都加载新的模型参数,而是一次加载模型参数,然后使用它们来处理多个输入序列。这更有效地利用了芯片的内存带宽,从而提高了计算利用率、提高了吞吐量、降低了 LLM 推理的成本。

刚才提到的一次性批量处理,就是我们常说的 batch 方式,传统的批处理方法为静态批处理,因为批大小在推理完成之前保持不变,静态 batch 示意图如下所示:

可以看到 LLM 推理服务器端将多个推理请求打包,并作为一个批次一次性交给模型的 pipeline 或者 tokenization 统一批处理,当所有推理请求完成后,统一输出给客户调用端,这样的批量处理技术利用了 GPU 批量加载 token 的优势,很大程度上提升了吞吐性能,但存在明显的缺陷:

  • 批处理的窗口大小在推理完成之前是保持不变的(比如 batch size 为 16,一次性处理 16 个 prompt 推理请求),但单个请求可能提前“完成”(比如 S3),这时候整个资源是不能释放的,因为需要等待最长 prompt 推理的 output 输出(比如 S2)。这意味着其吞吐是 GPU 在批次跟多条 prompt 中的最大生成长度正相关。
  • 另外输出到客户端的是一个 batch 数组,客户端程序需要处理多个 prompt 的输出,并且需要指定最大的等待时间,比如 2s,这样即便某个 prompt 请求 1s 即可完成,客户端仍然会等待到 2s 的批次打包窗口时间,才能交由 server 端打包该批次的推理处理。

近期业界逐渐转向动态 batch,它对于提高并发,降低响应时延,有更高性价比的场景,相对于静态 batch 数倍的性能提升,也就是本文中重点介绍的 Rolling Batch 技术,Rolling Batch 技术原理图如下所示:

如上图所示,动态 batch(Rolling Batch,也叫 Continually Batch)不是等待批次中的每个序列完成生成,而是实现迭代级调度,其中批量大小由每次迭代决定。结果是,一旦批处理中的序列完成生成,就可以在其位置插入新序列,从而产生比静态批处理更高的 GPU 利用率。

当多个 prompt 请求到达时,LLM 推理服务器端执行 Iteration 级别 batch,每当一个 output token 生成,则会检测整个处理队列,当发现有 prompt 请求已经有结束生成(EOS token id),则自动调度,拉取新的 prompt 请求的 input token 填充,不会等待其他 prompt 未完成的生成过程,整个过程不断循环,也没有静态 batch 中的打包等待时间。

Rolling Batch 框架

实现 Rolling Batch 的框架有如下关键的功能:

  • 服务端需要实现网络监听,持续接收客户端请求并不断加入到请求队列;并且检测推理的每个 step,当发现某个 prompt 请求结束的 EOS token id 时,调度并填充新的 prompt 请求填充;另外在计算层面,Rolling Batch 在一个 sequence 里面要同时计算不同 prompt 请求的 prefilling tokens(输入 token)和 completion tokens(输出 token),这二者计算方式是不一样的。
  • 除此之外,continually batch 还能够解锁新的 paged attention 的功能,因为连续的 batch,在每次 iteration 的时候可以分配非连续的显存给下一个填充的 input sequence token,意味着不需要像常规 batch 时固定分配连续显存,提高显存利用率。
  • 最后,continually batch 还能和 stream 流式输出结合,因为每个 step 都是 iteration 级别的,即每次处理都可以输出一个 prompt 请求的 token,则在服务端可以把每次生成的 token 都放到输出队列并向客户端推送,即使其没有生成完成,只需要标记 uncompletion 的 request id 对应关系。

目前支持 Rolling Batch 的服务端推理框架主要有 HF 的 text generation inference 和 vLLM,本文重点介绍业界知名度较高的 vLLM 框架的 Rolling Batch 实现。

vLLM 是 UC 伯克利团队开发的一个开源的 LLM 推理和服务引擎,它实现了上文提到的 Rolling Batch 批处理以及 PagedAttention 的全新的注意力算法,相对于静态 batch,vLLM 提供了高达数十倍的吞吐量,而无需进行任何模型架构更改,详细内容可以参考 vLLM 官方站点

本文简单讲解 vLLM 框架的核心代码,帮助小伙伴们理解其 Rolling Batch 的具体实现。

vLLM 提供了两类推理的实现,一类是 offline inference,类似于 HF pipeline 的 batch 推理接口,用于离线批量的推理生成;一类是和 openai api 类似的实时在线推理,用于服务端接收并发推理请求的应用部署,其本身也可以通过命令行拉起一个 web 服务端进行部署。

vLLM 框架 offline inference

vLLM 的 offline inference 首先使用 LLM 的接口初始化离线推理的 LLMEngine 引擎类型,使用 model 路径参数加载模型,使用 SamplingParams 传递推理参数,再通过 generate 推理接口,对传入的多条 prompt 请求进行批量推理,其调用代码示例如下:

predictor = LLM(model=model_location, tensor_parallel_size=int(tensor_parallel_degree))
tokenizer = LlamaTokenizer.from_pretrained(model_location, torch_dtype=torch.float16)
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=params["max_tokens"])
result = predictor.generate(data["inputs"], sampling_params)
result_json = []
for output in result:
    prompt = output.prompt
    generated_text = output.outputs[0].text

其中 generate 的推理,即调用 LLMEngine 的 add_request 方法,step 步骤方法,该步骤方法会触发_run_engine 的核心实现:

def _add_request(
        self,
        prompt: Optional[str],
        sampling_params: SamplingParams,
        prompt_token_ids: Optional[List[int]],
    ) -> None:
        request_id = str(next(self.request_counter))
        self.llm_engine.add_request(request_id, prompt, sampling_params,
                                    prompt_token_ids)

这里面 llm_engine 的 add_request,会调用 step(),也就是 Rolling Batch iteration 级别推理核心操作,即调用 model,对每一个 prompt sequence 请求推理生成 output token 及其它相关操作的过程,也就是下面代码中的 run_workers 函数操作:

def step(self) -> List[RequestOutput]:
        (seq_group_metadata_list, scheduler_outputs,
         early_return) = self._schedule()
        if early_return is not None:
            return early_return

        # Execute the model.
        output = self._run_workers(
            "execute_model",
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
            blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
            blocks_to_copy=scheduler_outputs.blocks_to_copy,
        )

_run_workers 执行 Rolling Batch 推理,首先它会调度要在下一次迭代中执行的 prompt 请求序列,然后它执行模型 generate,并用模型输出更新调度器,最后解码并返回新生成的结果,这些步骤中都会更新请求,处理及结束等不同的 prompt 队列状态。

vLLM offline inference 接口执行_run_engine 方法,每一次 iteration 迭代 step 操作完成后检测是否所有 prompt 请求完成,如果已经完成则一并输出,其代码逻辑如下:

def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]:
        # Initialize tqdm.
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            pbar = tqdm(total=num_requests, desc="Processed prompts")
        # Run the engine.
        outputs: List[RequestOutput] = []
        while self.llm_engine.has_unfinished_requests():
            step_outputs = self.llm_engine.step()
            for output in step_outputs:
                if output.finished:
                    outputs.append(output)
                    if use_tqdm:
                        pbar.update(1)
        if use_tqdm:
            pbar.close()
        outputs = sorted(outputs, key=lambda x: int(x.request_id))
        return outputs

vLLM 框架 online inference

vLLM 的在线推理接口与刚才的离线推理接口不同,是面向 web 应用服务器端的 api 推理接口,用户通常使用 vLLM 提供的 api server 命令行工具拉起应用服务器,该服务器会循环监听在线实时请求的到达,对到达的请求不断进行 Rolling Batch 的推理迭代,每次 iteration 迭代后,检测 sequences 队列是否有完成的 sequence(eos),如有则输出,并继续填充新的 prompt inputs。

这部分 web 服务端的实现是典型的异步调用方式,NIO 非阻塞多路复用,vLLM 是使用 python coroutine 协程技术来处理的。

vLLM 启动异步处理的核心代码如下所示:

def start_background_loop(self) -> None:
        """Start the background loop."""
        if self.is_running:
            raise RuntimeError("Background loop is already running.")
        self.background_loop = asyncio.get_event_loop().create_task(
            self.run_engine_loop())
        self.background_loop.add_done_callback(_raise_exception_on_finish)

如上 vLLM 的在线推理接口,vLLM api server 使用 async_llm_engine,会启动 event loop 循环事件监听,在 event loop 中启动 run_engine_loop 的协程函数,而在协程函数中,会 await 异步等待每一 model generation 的 iteration 及相关步骤的 engine_step。

 async def run_engine_loop(self):
        while True:
            await self.engine_step()
            await asyncio.sleep(0)

engine_step()的步骤和上文看到的 vLLM 离线的 LLMEngine 引擎的 step()操作是一致的,不同在该步骤是通过 coroutine 的 await 异步执行的方式,核心代码如下:

async def engine_step(self):
        """Kick the engine to process the waiting requests."""
        if self.engine_use_ray:
            request_outputs = await self.engine.step.remote()
        else:
            request_outputs = await self.engine.step_async()
            
async def step_async(self) -> List[RequestOutput]:
        (seq_group_metadata_list, scheduler_outputs,
         early_return) = self._schedule()
        if early_return is not None:
            return early_return

        # Execute the model.
        output = await self._run_workers_async(
            "execute_model",
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
            blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
            blocks_to_copy=scheduler_outputs.blocks_to_copy,
        )

        return self._process_model_outputs(output, scheduler_outputs)

Rolling batch on SageMaker 的使用

SageMaker 推理容器同时支持 HuggingFace 的 TGI 和 vLLM 两种动态 batch 框架,本文着重介绍 vLLM 框架在 SageMaker 上的使用。

SageMaker 使用 Large Model Inference(LMI)容器 inference 时,直接调用了 vLLM engine 的 step api,每次 iteration 迭代逐个输出 token 到输出队列,并调用 vLLM 状态 api 判断单条 request 请求是否结束,如果结束则标记并清理。

def inference(self, input_data, parameters):
        batch_size = len(input_data)
        new_requests = self.get_new_requests(input_data, parameters,
                                             batch_size)
        # step 0: register new requests to engine
        for request in new_requests:
            request_id = random_uuid()
            request.parameters.pop('seed', None)
            if "max_new_tokens" in request.parameters.keys():
                request.parameters["max_tokens"] = request.parameters.pop(
                    "max_new_tokens")
            sampling_params = SamplingParams(**request.parameters)
            self.engine.add_request(request_id, request.input_text,
                                    sampling_params)
            self.request_cache[request_id] = {
                "curr_length": 0,
                "text": "",
                "finished": False
            }
        request_outputs = self.engine.step()
        # step 1: put result to cache
        for request_output in request_outputs:
            req_id = request_output.request_id
            self.request_cache[req_id]["text"] = request_output.outputs[0].text
            if len(request_output.outputs) > 1:
                logging.warning(
                    f"Finding more than 1 output for single request {len(request_output.outputs)}"
                    f"Beam search is not supported yet, use first output by default"
                )
            self.request_cache[req_id]["finished"] = request_output.finished
        # step 2: send result back
        finished_id = []
        for (key, cache), request in zip(self.request_cache.items(),
                                         self.pending_requests):
            request.set_next_token(cache["text"][cache["curr_length"]:],
                                   self.output_formatter, cache["finished"])
            cache["curr_length"] = len(cache["text"])
            if cache["finished"]:
                finished_id.append(key)
        # step 3: clean finished requests
        for key in finished_id:
            self.request_cache.pop(key)

用户可以通过配置化方式,轻松实现推理服务器上的 Rolling Batch 推理调用,不需要安装及部署 vLLM 库及开发封装 vLLM engine 的服务器端应用,且自动 enable vLLM 的 paging attention 关注度优化等功能。

SageMaker LMI 容器镜像配置及使用

LMI vLLM 的容器镜像配置在 SageMaker 上很简洁方便,SageMaker 上通过  image_uris SDK 指定 LMI 对应的镜像版本,所示如下:

image_uri = image_uris.retrieve(
framework="djl-deepspeed",
region=sess.boto_session.region_name,
version="0.23.0"
)

如上我们指明 LMI 容器镜像为 djl-deepspeed 0.23.0 的版本,该版本即为支持 vLLM rolling_batch 的容器镜像。

然后我们再通过 serving.properties 配置文件设置对应 vLLM 部署参数,示例如下:

%%writefile serving.properties
engine=Python
option.s3url=s3://YOUR_SAGEMAKER_BUCKET/LLM-RAG/workshop/LLM_llama2_model/
option.task=text-generation
option.trust_remote_code=true
option.tensor_parallel_degree=4
option.rolling_batch=vLLM
option.dtype=fp16
option.enable_streaming=true

如上我们指明 rolling batch 的框架为 vLLM(option.rolling_batch),并指定以 fp16 半精度加载(option.dtype),且以 4 的并行度加载模型(option.tensor_parallel_degree)。

关于 LMI 镜像部署模型的详细参数配置及 sdk 使用这里不再赘述,感兴趣的小伙伴可以参考 SageMaker LMI 推理镜像官方文档

客户端调用

如上文所述,Rolling Batch 下的客户端推理调用不用等待静态 batch 打包,及处理输出的数组,正常单次调用,单次返回结果即可,如下示例代码:

prompts = [prompt_live,prompt_live,prompt_live]
def call_endpoint(prompt):
    input = {"inputs": prompt, "parameters": parameters}
    input = json.dumps(input)
    start = time.time()

    response = smr_client.invoke_endpoint(EndpointName=endpoint_name,
                                       ContentType='application/json',
                                       Accept='application/json',
                                       Body=input)
    results = response['Body'].read().decode("utf-8")
    end = time.time()
    process_time=end-start
    print("process time:"+str(int(process_time)))
    ouputJson=json.loads(results)
    print(ouputJson)


results = Parallel(n_jobs=3, prefer='threads', verbose=1)(
    delayed(call_endpoint)(prompt)
    for prompt in prompts
)

如上我们用了 3 个客户端并发线程,对三个 prompt 并行请求 vLLM 的 SageMaker Endpoint 服务端,结果显示 3 并发几乎同时返回,平均响应时间在 7,8s 左右,验证了 Rolling Batch 不需要客户端等待静态 batch 打包输出及处理数组返回:

[Parallel(n_jobs=3)]: Using backend ThreadingBackend with 3 concurrent workers.
process time:7
{'generated_text': 'The script should be delivered in a Microsoft Word document.\nThe script should be written in Chinese and should be concise, interesting, and attractive, ensuring to contain all the elements mentioned above.\nThe script should be delivered in a Microsoft Word document with the formatting and layout as per the instructions.\nThe script should be delivered in a Microsoft Word document with the formatting and layout as per the instructions. The script should be written in Chinese and should be concise, interesting, and attractive, ensuring to contain all the elements mentioned above.\nThe script should be written in Chinese and should be concise, interesting, and attractive, ensuring to contain all the elements mentioned above. The script should be delivered in a Microsoft Word document with the formatting and layout as per the instructions.\nThe script should be delivered in a Microsoft Word document with the formatting and layout as per the instructions. The script should be written in Chinese and should be concise, interesting, and attractive, ensuring to contain all the elements mentioned above. The script should be delivered in a Microsoft Word document with the formatting and layout as per the instructions. The script should be written in Chinese and should be concise, interesting, and attractive, ensuring to contain all the elements mentioned above.'}
process time:8
{'generated_text': 'I am a professional script writer and have been writing scripts for live broadcasts for a long time. I can help you with your script.\n( 1 review ) Changsha, China\nHello I am a professional script writer and have been writing scripts for live broadcasts for a long time. I can help you with your script.\nHello, I am a professional script writer and have been writing scripts for live broadcasts for a long time. I can help you with your script.\nHello! I am a professional script writer and have been writing scripts for live broadcasts for a long time. I can help you with your script.\nHello, I am a professional script writer and have been writing scripts for live broadcasts for a long time. I can help you with your script. I have a great experience in this field.\nHi, I am a professional script writer and have been writing scripts for live broadcasts for a long time. I can help you with your script.\nHello, I am a professional script writer and have been writing scripts for live broadcasts for a long time. I can help you with your script. I have a great experience in this field. Thank you for your time and consideration.\nHello, I am a professional script writer and have been writing scripts for live broadcasts for a long time. I can help you with your script. I have a great experience in this field. I can help you with your script.\nI'}
process time:8
{'generated_text': 'The script should be between 100-120 words, and the delivery time is 24 hours.\nThe script should be delivered in a format that can be easily imported into the live broadcast software.\nThe script should be written by a native Chinese speaker with excellent writing skills and a deep understanding of the target audience and the product.\nThe script should be well-structured, with clear transitions and a logical flow, to ensure that the audience can follow along easily.\nThe script should be engaging and entertaining, with a focus on creating a fun and interactive atmosphere for the audience.\nThe script should be creative and original, with a unique perspective that sets it apart from other scripts.\nThe script should be well-researched, with accurate information about the product and its features.\nThe script should be free of errors, with proper grammar, spelling, and punctuation.\nThe script should be delivered on time, within the specified deadline.\nThe script should be delivered in a format that can be easily imported into the live broadcast software, such as a .doc or .pdf file.\nThe script should be delivered with a brief summary of the product and its features, to help the host prepare for the live broadcast.\nThe script should be delivered with a brief introduction of the host, to help the audience get to know the host better.\nThe script should be delivered with a brief introduction of the product'}
[Parallel(n_jobs=3)]: Done   3 out of   3 | elapsed:    8.9s finished

我们再用 8 个 prompt 并发请求,同样的 vLLM SageMaker Endpoint 服务端进行测试:

[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.
process time:12
process time:14
process time:14
process time:14
process time:14
process time:14
process time:14
process time:14
[Parallel(n_jobs=8)]: Done   8 out of   8 | elapsed:   14.7s finished

可以看到请求结果,延迟增长了一点(14.7s),但是吞吐量增加明显(8*300 new tokens vs 4*300 new tokens),发挥出了 Rolling Batch 的优势。

同时,SageMaker 的 LMI 推理镜像封装了流式输出的功能,如果配置为流式输出,则 LMI 容器会向客户端推送单个的 token,客户端调用 SageMaker 的 invoke stream api 即可获得 vLLM rolling batch 的单个 token 结果,客户端可以像 openai stream api 方式迭代接收服务端输出的 chunk 片段序列,并逐个展示,进一步增强客户体验,代码如下所示:

prompts = [prompt1,prompt1,prompt1,
           prompt2,prompt3,prompt4,
           prompt_live,prompt2,prompt4,
           prompt1]

def call_endpoint(prompt):
    current_time = datetime.datetime.now()
    formatted_time = current_time.strftime("%Y-%m-%d %H:%M:%S")
    print("start invoke timestamp:", formatted_time)
    response_model = smr_client.invoke_endpoint_with_response_stream(
            EndpointName=endpoint_name,
            Body=json.dumps(
            {
                "inputs": prompt,
                "parameters": parameters
            }
            ),
            ContentType="application/json",
        )

    event_stream = response_model['Body']
    index=0
    for event in event_stream:
        eventJson=event['PayloadPart']['Bytes'].decode('utf-8')
        #output=extract_unicode_chars(eventJson)
        output=(eventJson)
        if index==3:
            first_ouput_time = datetime.datetime.now()
            formatted_time = first_ouput_time.strftime("%Y-%m-%d %H:%M:%S")
            print("first output token:"+output+" timestamp:"+formatted_time)
        index=index+1

如上只需要调用 invoke_endpoint_with_response_stream 接口,SageMaker vLLM LMI 会自动封装 Rolling Batch 的每次迭代输出 token,并推送到 response 响应输出,该响应输出是一个 python 迭代器,客户端遍历该 event_stream,即可获取每次流式生产的 tokens,效果如下:

[Parallel(n_jobs=3)]: Using backend ThreadingBackend with 3 concurrent workers.
start invoke timestamp: 2023-10-10 00:45:02
first output token: script timestamp:2023-10-10 00:45:02
{"generated_text": "The script should
{"generated_text": "The script should be
{"generated_text": "The script should be written
{"generated_text": "The script should be written in
{"generated_text": "The script should be written in a
{"generated_text": "The script should be written in a professional
{"generated_text": "The script should be written in a professional and
{"generated_text": "The script should be written in a professional and eng
...省略
CPU times: user 194 ms, sys: 17.7 ms, total: 211 ms
Wall time: 10.3 s
[Parallel(n_jobs=3)]: Done   3 out of   3 | elapsed:   10.3s finished

可以看到,流式调用 Rolling Batch 时,第一个 token 生成只用了 2s 即可以在客户端输出,和 openai stream 接口一样的客户体验,并且整个 3 个 prompt 并发推理在 10s 左右,和刚才的压测延迟基本一致,延续了 Rolling Batch 的高吞吐性能,一举两得。

性能压测对比

我们使用 Python Parallel 并发库,模拟生产线上多用户 prompt 请求的并发场景,使用 g5.12xlarge(A10,4 卡,24G 显存)机器,vLLM SageMaker LMI 推理镜像,部署 llama2 13b 模型进行推理请求。

SageMaker vllm 的 LMI 配置如下所示:

engine=Python
option.s3url=s3://sagemaker-us-west-2-687912291502/LLM-RAG/workshop/LLM_llama2_model/
option.task=text-generation
option.trust_remote_code=true
option.tensor_parallel_degree=4
option.rolling_batch=vllm
option.dtype=fp16
option.enable_streaming=true

我们同样使用 Python Parallel 库,并发多线程请求部署 vLLM 的 SageMaker endpoint 终端节点,通过多个线程并行处理多个 prompt 输入,模拟多用户客户端的实时请求,代码示例如下:

prompts = [prompt_live,prompt_live,prompt_live,prompt_live,
           prompt_live,prompt_live,prompt_live,prompt_live]
...省略

results = Parallel(n_jobs=8, prefer='threads', verbose=1)(
    delayed(call_endpoint)(prompt)
    for prompt in prompts
)

没有并发时,响应时延为 7.9s,输出 37tokens/s 左右

[Parallel(n_jobs=1)]| elapsed:   7.9s finished

当并发增加到 4,延迟只增加了 2s 左右,但吞吐量增加到 106 tokens/s

[Parallel(n_jobs=4)]: Done   4 out of   4 | elapsed:   11.3s finished

当并发再增加到 8,延迟为 14.9s,吞吐量为 171 tokens/s

[Parallel(n_jobs=8)]: Done   8 out of   8 | elapsed:   14.9s finished

并发增加到 16,延迟为 22.4 s,吞吐 218 tokens/s

[Parallel(n_jobs=16)]: Done  16 out of  16 | elapsed:   22.4s finished

可以明显看到 Rolling Batch 的吞吐量性能,当并发增长,服务端 batch 越大,优势越明显。

以下为详细的机型,input token length,max new token length,在不同并发请求场景下,benchmark 的响应时延和吞吐量(tokens/s)的情况:

A B C D E
1 max new token length input token length 并发 prompts 请求 响应时延(秒) 吞吐量(tokens/s)
2 300 500 3 7.9s 170
3 300 500 4 10.3s 200
4 300 500 8 12s 230
5 300 500 16 15s 300

总结

本文介绍了近期业界新的 Rolling Batch(Continually Batch)的批处理推理优化技术原理,讲解了 vLLM 框架 Rolling Batch 的具体实现的核心代码,以及在 Amazon SageMaker 上使用 vLLM 框架进行 Rolling Batch 推理优化的部署实践,并给出了不同并发下的 benchmark 测试对比。客户在实际生产场景中可以参考本文中的配置及压测性能,使用 SageMaker vLLM 部署方案显著提升线上部署的 LLM 的推理吞吐量,降低响应时延,降低 TCO。

参考资料

SageMaker SDK 模型推理部署:https://docs.aws.amazon.com/zh_cn/sagemaker/latest/dg/deploy-model.html

SageMaker LMI 推理镜像:https://docs.aws.amazon.com/zh_cn/sagemaker/latest/dg/large-model-inference-dlc.html

vLLM 推理框架:https://vllm.readthedocs.io/en/latest/getting_started/quickstart.html

LMI vLLM 推理代码示例:https://github.com/deepjavalibrary/djl-serving/blob/master/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py

本篇作者

唐清原

亚马逊云科技高级解决方案架构师,负责 Data Analytic & AIML 产品服务架构设计以及解决方案。10+数据领域研发及架构设计经验,历任 IBM 咨询顾问,Oracle 高级咨询顾问,澳新银行数据部领域架构师职务。在大数据 BI,数据湖,推荐系统,MLOps 等平台项目有丰富实战经验。