亚马逊AWS官方博客
基于 Amazon SageMaker 的多模态模型训练、推理及批量表征提取
背景
随着大语言模型(LLM)的发展,视觉语言模型(VLM)的应用及落地也在越来越多的场景中被关注及提出。相比于传统的检测类的图像单一模态模型,图文多模态模型对于图像信息有着更好的理解,主要体现在其与人类理解的对齐能力上,本文通过 Amazon SageMaker 展示其对多模态大模型的训练及推理基础设施所带来的简化。同时我们发现在较多线上业务场景(如内容、电商)的实践中,非结构化数据如文本、图像可以使用其向量表征使其信息得到更充分的挖掘和利用,因此本文同时针对跨模态表征在 Amazon SageMaker 的批量抽取提供了示例,使得其可以快速接入不同场景进一步帮助业务提效。
主流跨模态生成的原理概要
以下用实际场景中使用较多的两类模型范式为例来进行简要的介绍。
BLIP2
BLIP2 [1, 2]提出了一种利用已经预训练好的视觉和文本单独模态的基础大模型,来进行多模态对齐训练的方法。其设计了 Querying Transformer (Q-Former),用于桥接视觉和文本两种表征,使得 image encode 得以和 LLM 进行交互,同时也是整个管道中唯一可训练的模块。
针对这种模型管道,BLIP2 提出了两阶段预训练策略:
Stage1,使用 3 个并行任务来跨模态的对齐:
- Image-Text Contrastive Learning – 单图表征及单文表征的对比学习任务
- Image-grounded Text Generation – 单向图生文任务
- Image-Text Matching – 利用交互表征来进行图文匹配的二分类任务
Stage2,引入 LLM。使用 Stage1 中的图生文任务的输出,与输入文本一起送入 LLM。并直接利用 LLM 的目标及 Loss 来进一步更新 Q-Former 参数。
生成(推理)时,BLIP2 整体流程包括三个阶段:
- 使用 ViT 作为图像编码器,生成图片的视觉表征(向量)
- 利用提出的 Querying Transformer (Q-Former),将 1 中的单视觉表征将转化为对齐文本后的交叉表征
- 将 2 中的交叉表征,叠加提示词(Prompt)并送入 LLM,使用 LLM 来生成视觉相关的文本。这可以是单纯 Caption、或是 Prompt 中的 Question
LLaVA
自然语言处理领域的指令微调(Instruction Tuning)可以帮助 LLM 理解多样化的指令并生成比较详细的回答。LLaVA [3]首次尝试构建图文相关的指令微调数据集来将 LLM 拓展到多模态领域。具体方法为:基于 MSCOCO 数据集,每张图有 5 个较简短的基准真相描述(Ground Truth)和包括类别和位置的识别矩形框(Object BBox)序列,并将这些作为 Text-Only GPT4 的输入,通过提示词(Prompt)的形式让 GPT4 生成 3 种类型的文本:1)关于图像中对象的对话;2)针对图片的详细描述;3)和图片相关的复杂的推理过程。注意,这三种类型都是 GPT4 在不看到图片的情况下根据输入的文本生成的,为了让 GPT4 理解这些意图,作者额外人工标注了一些样例用于语境学习(In-Context Learning)。其模型结构上,采用 CLIP 的 ViT-L/14 作为视觉编码器,LLaMA 作为文本解码器,通过一个简单的线性映射层将视觉编码器的输出映射到文本解码器的词嵌入空间,如下图所示:
模型训练分为两阶段:
Stage1,跨模态对齐预训练,从 CC3M 中通过限制 caption 中名词词组的最小频率过滤出 595k 图文数据,冻住视觉编码器和文本解码器,只训练线性映射层;
Stage2,第二阶段进行指令微调,一版针对多模态聊天机器人场景,采用自己构建的 158k 多模态指令数据集进行微调;另一版针对 Science QA 数据集进行微调。微调阶段,线性层和文本解码器(LLaMA)都会进行训练并进行参数更新。
小结
从以上两类模型的设计上可以看出,多模态的模型架构从原理上都是相似的,均是多阶段管道的形式,其中核心的还是将 vis & lan 的表征进行统一,并通过管道中位于下游的 LLM 来进行理解。这个统一表征的过程可以通过 BLIP2 中的 Q-Former 来实现,或者 LLaVA 中的线性投影(Linear Projection)来实现。对于此类模型,其训练也根据不同的管道的构成分为了不同的阶段。比如 BLIP2 的预训练 1/2 Stage,其中差异点主要在于参数冻结(Freeze)的部分,这也是影响训练成本上主要因素之一。
在 Amazon SageMaker 上使用 LAVIS BLIP2
本节以 LAVIS BLIP2 [4]为例,来演示如何基于 Amazon SageMaker 进行模型训练,及交叉表征提取。
模型训练
很多企业具有特有场景定制化调优和私域数据保护的需求,因此相对于模型即服务(MaaS)类型产品更倾向于通过模型微调来提高应用效果。而 Salesforce 对于 BLIP2 的构建是直接整合在其 LAVIS [4, 10]框架中的,其一定程度上封装了训练或推理过程中的相对繁琐的配置过程。其模型中间结果、训练图片以及训练标注的加载,均可在配置文件中统一配置。下例中,我们将为大家展示如何在 SageMaker 上使用 coco [12]数据集进行模型微调,详细代码见引用[5]。
在 LAVIS/lavis/configs/default.yaml
中,将路径修改为训练实例本地 NVMe 路径 [6]
修改 LAVIS/lavis/datasets/builders/caption_builder.py
重载其 build_datasets()
方法,在训练实例当前节点的主进程中,利用 s5cmd [7]加速拷贝,从 Amazon S3 持久化存储将训练数据集(包括 image 及 caption)拷贝至训练实例的 NVMe 存储。这里,Lavis 框架帮助我们封装了 Multi-GPU 训练的进程判断逻辑,因此可以直接使用 is_main_process()
来替代 if 0 == LOCAL_RANK
等条件判断。
同时对训练的全局配置文件 lavis/projects/blip2/train/caption_coco_ft.yaml
进行调整
将训练数据存放位置调整为前序拷贝至的 NVMe 目标路径,并将模型输出存放位置调整为/opt/ml/model/
,SageMaker 在完成训练后将该路径下的文件拷贝至如下 Estimator 启动器中预设的 S3 持久化存储路径。
如果在.yaml
配置文件中开启 evaluation,其所使用到的 pycocoevalcap
库需要依赖 JAVA Runtime(JRE)运行环境,因此需要在预置镜像的基础上新增安装,Dockerfile 的调整如下:
最后通过 SageMaker Estimator 的 estimator.fit()
启动训练。对于垂直域图文训练集来说,BLIP2 官方建议从 Stage2 预训练开始可以获得更好的效果,其实践方式与以上微调配置过程完全一致。
模型推理
LAVIS 框架提供了封装好的推理接口,如 LAVIS/examples/blip2_instructed_generation.ipynb
[4] 中提供的 Captioning/VQA 任务的推理示例。
可以使用 SageMaker Endpoint 及 LMI(Large Model Inference)容器[11]根据 LAVIS 提供的推理接口进行快速部署。基于线上业务的推理,可以根据不同推理延迟以及推理负载大小(payload)的需求,选择 SageMaker 的 Real-time 或 Async Endpoint。离线批量推理任务除了 SageMaker Batch Transform 之外,从灵活性的角度,如下展示如何用 SageMaker Training 所提供的按需集群来进行批量推理。
交叉表征提取
多模态模型除了在构建跨模态的应用上,其中间态的交叉表征也可以作为对图文物料的多模态理解(即经过蒸馏的优势特征)用于其他类型的场景如排序等。对于生产场景中无论是线上的实时推理或离线的预生成,批量(多图)推理的能力是必须具备的。本节以批量特征抽取为例进行说明。
在推理代码中,首先对 payload 的注入形式进行简单的调整,使多个图片编码合并为一个 list tensor。
此时得到一个 tensor shape 是[batch_size, 3, 364, 364]
的 img_batch
。第 0 维为该 batch 的大小,其他维度分别为图片的通道数及图片原始尺寸。同时构建相同 batch size 的 text caption,并传入模型。
推理结束后可以得到一个 tensor shape 是[batch_size, 32, 768]
的推理结果,其中第 0 维仍为该 batch 的大小,而第 1、2 维度则是经过 Q-Former 编码之后得到的一个对应单图 32*768 的 tensor 表征。直接针对该第 1 维进行 mean pooling:
可以得到一个[batch_size, 768]
维的 tensor,此时将每张图片都有了一个包含图文 cross 信息的新 dense 表征。
这里可以将该向量表征直接推送至 SageMaker Online Feature Store 的 InMemory
tier [8],其底层封装了 Amazon ElastiCache for Redis 服务[9],可以提供亚毫秒级延迟并支持集合类型的特征(包括 Vector 类型),适用于线上如 Feeds Ranking 等其他模型的推理。同时,SageMaker Feature Store 天然具备了 online / offline 的同步机制,可以使用 Amazon Athena 对离线特征进行分析、拼接等处理。构建 SageMaker 托管的 ElastiCache Online Feature Store 及特征注入使用示例可以参考[13]。作为更一般的形式,如下展示将以上 vector 进行.csv
格式的存储。
配置及任务启动
如[5],配置常规的 SageMaker Estimator,首先需要将输入图片拷贝到集群,或直接存储于 Amazon FSx for Lustre 并 mount 进行读取。如采用前者传输方式,考虑到输入为图片类型的小文件集合,建议使用 s5cmd [5]的并发能力对传输过程进行加速。此外,需要在主进程上将上游训练任务保存的 checkpoint 文件从 S3 拉取至算力集群。
同时在配置文件中 LAVIS/lavis/configs/models/blip2/blip2_coco.yaml
,写入以上 checkpoint 在 SageMaker 按需实例上的本地路径。
最后通过 estimator.fit()
来启动该批量特征生成任务。
效果评估
根据上述基于 Amazon SageMaker 平台的训练及批量表征提取过程,这里使用 coco2014 [12] 数据集训练 100 个 iteration 并使用 finetune 后的 checkpoint 进行批量表征提取,作为效果展示。对 2000 个样本所获取的[2000, 768]
维数据点使用 t-SNE 降至 2 维进行可视化如下,可以看出 Q-Former 产生的向量表征可以较明显地体现出 finetune 100 个 iteration 后其分布发生的变化。
同时,我们针对以上 2000 个样本对批量抽取的性能进行了评估(使用 Amazon SageMaker Training 所提供的按需临时集群/实例,1 台单卡 A10 GPU ml.g5.4xlarge 实例进行测试)。使用不同的批次大小对该 2000 个样本进行遍历抽取,如下所示:
并通过测量整体耗时最终计算得到单条样本的平均抽取耗时。可以看出通过调整原始推理接口并进行批量推理,使得单个样本的推理耗时及成本有了较明显的降低。
总结
BLIP2 除了作为首个“unlock the capability of zero-shot instructed image-to-text generation” [1]的模型范式之外,该类型的模型 pipeline 设计范式可能与生产场景的一般诉求更加匹配,原因总结如下:
- 本质上是一个通过 Q-Former 链接的松耦合训练范式,可以自由对 Image Encoder 以及 LLM 模块进行插拔
- 预训练和微调阶段均不需要调整 LLM 参数,训练成本相对低
- 训练过程中 LLM 模块冻结,因此得到的中间 Q-Former 直觉上对于跨模态的理解更加充分(对比 LLAVA [3]的 LLM non-frozen 训练范式,一方面其中间的 MLP 表达能力是天然有限的,一方面由于 LLM 参数是可更新的,两方面均会使得 LLM 承载一定的跨模态理解能力),因此其生成的 Embebedding 更适合作为表征,可以更好的迁移至其他场景进而提升其他场景的模型性能
本文以 LAVIS BLIP2 为例,展示了其在 Amazon SageMaker 平台上的训练及推理过程。同时通过对原有推理接口进行简单的调整及适配,使得 LAVIS BLIP2 可以在 Amazon SageMaker 所托管的基础设施之上,快速进行批量的图文对粒度的特征抽取以赋能更多算法场景。更多相关的场景化应用、算法提效示例等将在后续文章中做进一步探讨。
参考链接
[1] Salesforce BLIP2 Blog – https://blog.salesforceairesearch.com/blip-2/
[2] BLIP2 Paper – https://arxiv.org/abs/2301.12597
[3] LLAVA 介绍 – https://llava-vl.github.io/
[4] Saleforce LAVIS Repo – https://github.com/salesforce/LAVIS.git
[5] LAVIS BLIP2 on Amazon SageMaker 示例代码 – https://github.com/haozhx23/Lavis-Blip2-on-SageMaker
[6] Amazon SageMaker Training Instance Storage – https://docs.aws.amazon.com/sagemaker/latest/dg/model-train-storage.html
[7] S5cmd 加速传输工具 – https://github.com/peak/s5cmd
[8] SageMaker Online FeatureStore – https://docs.aws.amazon.com/sagemaker/latest/dg/feature-store-storage-configurations-online-store.html
[9] Amazon ElastiCache – https://docs.aws.amazon.com/AmazonElastiCache/latest/red-ug/WhatIs.html
[10] Saleforce LAVIS 介绍 – https://opensource.salesforce.com/LAVIS/latest/intro.html
[11] SageMaker Large Model Inference – https://docs.aws.amazon.com/sagemaker/latest/dg/large-model-inference-dlc.html
[12] coco 数据集 – https://cocodataset.org/
[13] SageMaker Feature Store Workshop – https://github.com/aws-samples/amazon-sagemaker-feature-store-end-to-end-workshop/tree/main