亚马逊AWS官方博客

多模态大模型应用实践(一)- 利用微调 LLaVA 实现高效酒店图片分类

需求背景

在当今数字化时代,在线旅行预订平台已成为旅游行业的重要组成部分。平台的日常运营常常面临着一个关键挑战:如何高效、准确地对海量酒店图片进行分类。准确的图片分类不仅能提升用户浏览体验,还能显著提高平台的运营效率和产品上架速度。

然而,随着数据量的急剧增加,传统的人工分类方法已经难以应对。面对每天可能新增的数十万张图片,人工处理不仅耗时耗力,还容易出现分类不一致。因此,一个自动化、高精度的图片分类解决方案变得尤为重要。

本文将介绍如何利用 Amazon SageMaker 部署 LLaVA 模型,实现酒店图片的自动化、高精度分类,以应对千万级别图片的处理需求,同时显著降低运营成本。

具体目标:

  1. 准确分类酒店图片(如房间、大堂、泳池、餐厅等几十余种)。
  2. 高效处理千万级别的存量图片,同时控制推理成本。

方案概述

近年来,多模态 AI 模型(能同时处理文本和图像的模型)取得了显著进展。商业模型如 GPT-4o、Claude3.5 的多模态能力已经相当强大,可以直接用于图片分类任务。然而,在大规模应用场景中,这些模型仍存在一些局限性:

  • 模型在自定义标签分类场景精度有上限,需要大量提示词工程的工作;
  • 模型升级可能导致已积累的提示词失效;
  • 推理的成本较高。

考虑到这些因素,我们选择了开源的 LLaVA 作为基础模型,并使用私域数据进行微调。微调是一种将预训练模型适应特定任务的技术,能够在保持模型通用能力的同时,显著提升其在特定领域的表现。这种方法能够实现自主可控、性能达标且具有成本效益的图片处理模型。

同时,我们采用 vllm 推理加速框架,进一步提升模型的吞吐量。vllm 是一个高效的大语言模型推理引擎,能够显著提高模型的处理速度,这对于处理大规模图片数据集尤为重要。

LLaVA 模型简介

LLaVA(Large Language and Vision Assistant)是一个强大的多模态 AI 模型,它结合了预训练的大型语言模型和预训练的视觉编码器。这种结构使 LLaVA 能够同时理解和处理文本和图像信息,使其成为多模态任务(如图像分类、图像描述等)的理想选择。

Figure 1 LLaVA Architecture

本次我们使用 LLaVa-NeXT(也称为 LLaVa-1.6),它是 LLaVA 的最新版本。相较于前代模型 LLaVa-1.5,LLaVa-1.6 通过以下改进显著提升了性能:

  1. 提高了输入图像分辨率,使模型能够捕捉更多图像细节;
  2. 在改进的视觉指令调优数据集上进行训练,增强了模型的理解能力;
  3. 显著提升了 OCR(光学字符识别)能力,使模型更擅长识别图像中的文字;
  4. 增强了常识推理能力,使模型能够更好地理解图像内容的上下文。

本项目使用的具体版本是基于 Mistral 7B 的大语言模型:llava-hf/llava-v1.6-mistral-7b-hf。Mistral 7B 是一个相对轻量级但性能优秀的语言模型,这使得我们的解决方案既高效又经济实惠。值得一提的是,LLava 系列适配多种大语言模型的语言头,这些模型在不同的下游任务的表现各有优劣,读者可以参考各大榜单,进行最新的模型选择,在本方案的基础上快速切换。

数据准备

高质量的训练数据对于模型性能至关重要。因此我们需要精心准备训练数据集。具体步骤如下:

  1. 收集各类酒店场景的图片数据集:确保图片种类和数量尽可能丰富,覆盖各种可能的场景(如不同类型的房间、各种风格的大堂、室内外泳池、各式餐厅等)。
  2. 为每张图片标注相应类别:这一步骤需要专业知识,确保标注的准确性和一致性。
  3. 构建图像-文本对:这是训练数据的核心。每个训练样本应包含一张图片和与之相关的问题-答案对。例如,问题可以是”这张图片展示的是什么类型的酒店设施?”,答案则是相应的类别。

为了高效管理这些训练数据,我们推荐使用 Hugging Face 的 datasets 包。这个强大的工具不仅可以帮助我们下载和使用开源数据集,还能高效地进行数据预处理。使用 datasets,我们可以将数据构造成如下格式:

from datasets import Dataset, DatasetDict, Image, load_dataset, load_from_disk
dataset = load_from_disk('data.hf')
dataset['train'][0]
{
 'id': 133,
 'images': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=550x412>,
 'messages': [
   {
     'content': [
       {
         'index': None,
         'text': 'Fill in the blank: this is a photo of a {}, you should select the appropriate categories from the provided labels: [Pool,Exterior,Bar...]',
         'type': 'text'
       },
       {
         'index': 0, 
         'text': None, 
         'type': 'image'
       }
     ],
     'role': 'user'
   },
   {
     'content': [
       {
         'index': None,
         'text': 'this is a photo of Room.',
         'type': 'text'
       }
     ],
     'role': 'assistant'
   }
 ]
}
Python

小提示:在构造训练数据集的 content.text 时,提示词的格式对下游任务具有很大的影响,我们测试发现,使用接近于预训练 clip 的格式模版:this is a photo of {} ,能够提升下游任务的准确率~5%。

模型训练

数据准备完成后,下一步是进行模型微调。我们使用 TRL(Transformer Reinforcement Learning)训练框架进行模型微调,基于 deepspeed 进行分布式训练。

以下是关键的训练命令及其重要参数:

accelerate launch --config_file=examples/accelerate_configs/deepspeed_zero3.yaml \
    examples/scripts/sft_vlm.py \
    --dataset_name customer \
    --model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 1 \
    --output_dir sft-llava-1.6-7b-hf-customer2batch \
    --bf16 \
    --torch_dtype bfloat16 \
    --gradient_checkpointing \
    --num_train_epochs 20 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 100 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --per_device_eval_batch_size 8
Python

关键参数说明:

–dataset_name:指定使用的数据集

–model_name_or_path:基础模型路径

–save_steps:每 100 步存储一次模型 checkpoint

–num_train_epochs:训练轮数,这里设置为 20 轮

–learning_rate:学习率,这里设置为 2e-5

–per_device_train_batch_size:每个设备的训练批次大小,这里设为1,注意这里由于微调数据量较小,建议使用较小的 batch size 提升精度表现

经测试,在一台配备 Nvidia H100 GPU 的 P5 实例上,训练 1000 张图片大约需要 10 分钟完成训练。这个时间可能会根据具体的硬件配置和数据集大小有所变化。训练结束后,我们将 checkpoint 上传至 S3,为后续的推理部署做准备。

部署推理

模型训练完成后,下一步是将其部署为可用的推理端点,由 Amazon SageMaker 为我们托管。这里我们采用了基于 DJL(Deep Java Library)的推理框架,将微调后的 LLaVA 1.6 模型部署在 g5.xlarge 实例上。

部署过程主要包括以下步骤:

1. 准备 serving.properties文件,这个文件用于指定推理框架和微调模型的位置:

engine = Python
option.rolling_batch=vllm
option.tensor_parallel_degree = max
option.max_rolling_batch_size=64
option.model_loading_timeout = 600
option.max_model_len = 7200
option.model_id = {{s3url}}
Python

这里我们使用 vllm 作为推理引擎,它能够显著提升推理速度。

2. 将配置目录打包上传到 S3,然后使用以下代码完成推理端点的部署:

from sagemaker.model import Model
model = Model(
    image_uri=inference_image_uri,
    model_data=s3_code_artifact,
    role=role,
    name=deploy_model_name,
)
predictor = model.deploy(
    initial_instance_count=1,
    instance_type="ml.g5.xlarge",
    endpoint_name=endpoint_name
)
Python

这段代码使用 SageMaker 的 Model 类来创建和部署模型。我们指定了模型镜像、模型数据位置、IAM 角色等信息,然后调用 deploy 方法来创建推理端点。

3. 部署完成后,我们可以测试推理端点。以下是一个测试示例,我们构造一个包含文本和图片的请求:

请求包含一个文本问题和一张 base64 编码的图片,模型将分析图片并回答问题。

推理结果样例:

inputs = {
    "messages": [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "Fill in the blank: this is a photo of a {}"
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{base64_image}"
                    }
                }
            ]
        }
    ],
    "max_tokens": 256
}
output = predictor.predict(inputs)
{
 'id': 'chatcmpl-140501937231984',
 'object': 'chat.completion',
 'created': 1732716397,
 'choices': [{
   'index': 0,
   'message': {
     'role': 'assistant',
     'content': ' this is a photo of Terrace/Patio.'
   },
   'logprobs': None,
   'finish_reason': 'eos_token'
 }],
 'usage': {
   'prompt_tokens': 2168,
   'completion_tokens': 12,
   'total_tokens': 2180
 }
}
Python

在这个例子中,模型识别出预定义的类别:Terrace/Patio。

成本估算

得益于 vllm 批量推理的特性,每千张图片的推理时间约 674s,结合 g5.xlarge 的实例价格,千张图片的推理成本约为 $0.26,对应 GPT4o 的价格约 $5.54。

ml.g5.xlarge($/hour) GPT4o($/M Tokens)
price 1.408 2.5
Time 674s /
tokens / 2180000
cost 0.26 5.54

Table 1 Cost Comparison

总结与展望

本方案基于 Amazon SageMaker 平台,通过对最新 LLaVA Next 模型的微调,成功探索了多模态大模型在酒店图片分类任务中的应用,实现了一个高效率、低成本的图片处理解决方案。这一经验不仅适用于酒店图片分类,还可推广至其他电商领域,如服装、家具等产品的图片分类。方案的项目代码可从 Git 上获取。

随着业务规模的扩大和有效数据的积累,我们需要持续关注并改进以下方面:

  1. 拓展模型能力,支持更多样化的图片类别和复杂场景。
  2. 利用经市场验证的高质量数据集,持续优化模型性能,提升分类准确率。
  3. 探索先进的批量推理技术和模型压缩方法,进一步降低推理成本,提高系统效率。

*前述特定亚马逊云科技生成式人工智能相关的服务仅在亚马逊云科技海外区域可用,亚马逊云科技中国仅为帮助您了解行业前沿技术和发展海外业务选择推介该服务。

参考资料

本篇作者

林益龙

亚马逊云科技解决方案架构师,专注于在企业中推广云计算与人工智能的最佳实践。曾担任运维经理、解决方案架构师等岗位,拥有多年的企业 IT 运维和架构设计经验。

刘俊逸

亚马逊云科技资深应用科学家,毕业于康奈尔大学数据科学专业,负责基于开源大模型调优构建生成式 AI 解决方案在行业的落地应用,具有十年机器学习领域工作经验,主要研究方向是多模态算法、模型微调、模型小型化等。