亚马逊AWS官方博客

垂直电商图像搜索再升级:DINO 模型带来精准匹配体验

我们在网上或者线下购物的时候,有时候很难找到合适的词语来描述你正在寻找的东西。俗话说:”一张图片胜过千言万语”,通常情况下,展示实物或图片比用文字描述物品更容易,尤其是在使用 APP 或者小程序寻找产品的时候。

然而,构建一个电商级图片搜索需要面临诸多挑战,例如:

  • 对于私域,或品类电商来说,其部分产品及图片区分度较低
  • 用户输入的图片在拍摄角度,镜头变形,光照,背景干扰等方面和产品图片的区别较大
  • 系统对响应时间的要求比较苛刻等

在这篇文章中,我们将介绍如何从头构建一个鞋服类的垂直模型,从而实现低延迟,高精度的图片搜索解决方案。

业务背景介绍

图片搜索可以提高零售业务和电子商务中的客户参与度,尤其是对服装类(衣服、裤子、鞋、服装饰品等)零售商而言。服装类是在图片搜索中最重要的产品类型。调研报告显示有 36% 的消费者曾经使用过图片搜索,有 74% 的消费者认为传统的文字搜索很难帮助他们找到正确的产品。

由于行业的特性,服装类大多具有非常高的相似度,比如运动鞋和衣服,大多数鞋的形状和风格非常类似,需要通过非常细粒度的特征来进行识别。比如下面不一样型号的鞋子,会非常相似。

在这篇文章中,您将学习如何构建一个类似的产品目录相似度搜索解决方案。该方案主要集成 Amazon SageMaker 和亚马逊关系数据库服务 Amazon Aurora MySQL,向量数据存储 Amazon OpenSearch

业务需求分解

  • 基于对象的高效搜索
    当用户输入的图片中同时存在多个商品或目标时,允许用户在图像中搜索特定的对象或物品,这样他们能够只搜索感兴趣的产品,而不是搜索整个图像。这种功能可以提高搜索效率,让用户更快地找到所需内容。
  • 自动产品识别
    系统能够自动识别图像中的产品。将来,这项功能可以与电子商务平台集成,根据识别出的产品向用户推荐相关商品,促进销售。
  • 搜索准确性
    用户搜索的图片和索引库中待比对的图片在不同角度、不同光线条件下拍摄,系统在万级别品类下,Top5 的召回也能够达到 85% 以上的准确率,将产品与相关图像正确匹配,这是基于对图像视觉特征的分析。高准确度可以确保搜索结果的相关性。
  • 安全和隐私
    系统可以进行私有化部署,并确保符合相关的隐私法规和合规要求。
  • 索引和存储
    系统需要高效地索引和存储超过 100 万张图像数据,以及相关的元数据,如标签、描述和其他相关信息,以支持快速搜索和检索。

整体方案

参考架构图

方案步骤

离线处理(白色线条部分)

  1. 启动一个 Notebook 读取 S3 里面的所有的图片
  2. 调用 Bedrock 进行图片打标处理,用于过滤用来训练的数据
  3. 打完标记的结果放到 Aurora Mysql 里面保存
  4. 启动 Sagemaker 的模型训练节点,使用过滤后的训练数据进行训练。将训练完后的 embedding 模型部署到 Sagemaker
  5. 调用 embedding 模型对现有的所有产品图片进行 embedding,结果存入 OpenSearch

实时处理(黄色线条部分)

  1. 前端通过 Cloudfront 加载页面和产品图片
  2. Cloudfront 读取 S3 中的静态数据
  3. 当上传图片的时候,Cloudfront 会将请求转发到 API Gateway
  4. API Gateway 将请求转发到 EC2
  5. EC2 将图片发送到 Lambda
  6. Lambda 将图片发送到 GroundingDINO 进行目标检测。如果图片中没有任何目标物品,则返回前端;如果有多个目标物品,则将检测到的目标物品的坐标返回给前端,以允许用户进行物品选择;如果只有一个目标物品,或者用户已经选择了目标物品,则根据 GroundingDINO 返回的长方形框剪切出目标图片,进入下一步
  7. 将剪切出目标图片通过 Lambda
  8. Lambda 调用 embedding 模型获取向量
  9. 通过向量查询 OpenSearch 获取 top5 的产品代码
  10. 通过产品代码查询 Aurora 得到产品详细数据并返回前端

技术难点以及解决思路

图像预处理

技术难点:

  • 存在不适合训练的图片:某些图片可能只显示产品的部分视角(如鞋底),这对于训练模型来说可能不太合适。
  • 图片质量不一致,角度不同:由于图像来源的多样性,图像质量和拍摄角度可能存在差异,这会影响模型的训练效果。

解决方案:

参照下图,我们利用大语言模型最新的多模态功能,输入图片,让模型对图片的进行图片标注,在我们的场景中,我们设计了如下标签体系。“是否出现模特”,“模特人数”,“是否真实世界的场景”,“是否穿在模特身上”,“拍摄角度”,“局部还是整体”等,通过这些图片,我们可以过滤掉比如鞋底这类对训练和搜索都没有帮助的图片。

同时,我们也利用这些标签进行训练集和测试集的划分。真实世界场景的图片都被划分到测试集。

目标检测和分割

技术难点:

  • 用户使用的搜索图片无法做预先的限定,会出现不包含任何产品和包含多个产品的情况。如何确定图像中的目标是否为公司销售的产品类别:需要一种方法来识别图像中的目标是否属于公司销售的产品范围。
  • 如果检测到多个产品,需要用户选择:当图像中包含多个产品时,需要提供一种机制让用户选择感兴趣的产品。

解决方案:

使用 Grounding DINO 进行目标检测,针对鞋子,帽子,裤子等。然后直接使用代码剪切出对应的长方形块(这里保留了长方形块里面的所有元素,包括背景。最后没有使用 SAM 切割出不规则的物品,原因是我们发现,仅对目标图片做方框的截取即可,使用 SAM 做像素级分割,反而降低了模型的效果 )。

首先我们先构建模型压缩包,并上传至 S3 存储桶中,如下图所示[JB2] :

import boto3
import sagemaker
from sagemaker import serializers, deserializers
from sagemaker.pytorch.model import PyTorchModel, PyTorchPredictor

role = sagemaker.get_execution_role()  # execution role for the endpoint
sess = sagemaker.session.Session()  # sagemaker session for interacting with different AWS APIs
bucket = sess.default_bucket()  # bucket to house artifacts
region = sess._region_name  # region name of the current SageMaker Studio environment
account_id = sess.account_id()  # account_id of the current SageMaker Studio environment

s3_model_prefix = "east-ai-models/grounded-sam"
!touch dummy
!rm -f model.tar.gz
!tar czvf model.tar.gz dummy
s3_model_artifact = sess.upload_data("model.tar.gz", bucket, s3_model_prefix)
print(f"S3 Code or Model tar uploaded to --- > {s3_model_artifact}")
!rm -f dummy

接下来我们准备创建模型所需要的代码,以下代码均在本地“code”路径下:

endpoint_name ="grounded-sam"
#%%
framework_version = '2.3.0'
py_version = 'py311'
instance_type = "ml.g4dn.xlarge"
endpoint_name ="grounded-sam"

model = PyTorchModel(
    model_data = s3_model_artifact,
    entry_point = 'inference.py',
    source_dir = "./code/",
    role = role,
    framework_version = framework_version, 
    py_version = py_version,
)

print("模型部署过程大约需要 7~8 分钟,请等待" + "."*20)

model.deploy(
    initial_instance_count=1,
    instance_type=instance_type,
    endpoint_name=endpoint_name,
)

print("模型部署已完成,可以继续执行后续步骤" + "."*20)

准备自定义推理脚本 clip_inference.py。我们在 model_fn 中进行模型加载,在 predict_fn 定义推理逻辑,核心代码如下:

import os
import io
from PIL import Image
import numpy as np
import torch
from groundingdino.models import build_model
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import clean_state_dict
from groundingdino.util.inference import predict
import groundingdino.datasets.transforms as T
from huggingface_hub import hf_hub_download
import json
import boto3
import uuid
import math

def get_detection_boxes(image_source: Image, model: dict, prompt: str = "clothes . pants . hats . shoes") -> (
        list, list, list):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    box_treshold = 0.3
    text_treshold = 0.25
    transform = T.Compose(
        [
            T.RandomResize([800], max_size=1333),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    image_transformed, _ = transform(image_source, None)
    try:
        boxes, logits, phrases = predict(
            model=model['dino'],
            image=image_transformed,
            caption=prompt,
            box_threshold=box_treshold,
            text_threshold=text_treshold,
            device='cuda'
        )
    except Exception as e:
        print(e)
        return
    boxes_list = boxes.numpy().tolist()
    logits_list = logits.numpy().tolist()
    return boxes_list, logits_list, phrases


def load_model_hf(repo_id, filename, ckpt_config_filename, device='cpu'):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    cache_config_file = hf_hub_download(repo_id=repo_id, filename=ckpt_config_filename)

    args = SLConfig.fromfile(cache_config_file)
    model = build_model(args)
    args.device = device

    cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
    checkpoint = torch.load(cache_file, map_location=device)
    log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
    model.cuda()
    _ = model.eval()
    return model

def model_fn(model_dir):
    ckpt_repo_id = "ShilongLiu/GroundingDINO"
    ckpt_filenmae = "groundingdino_swint_ogc.pth"
    ckpt_config_filename = "GroundingDINO_SwinT_OGC.cfg.py"
    model = load_model_hf(ckpt_repo_id, ckpt_filenmae, ckpt_config_filename)
    model_dic = {'dino': model, 'sam': ''}
    return model_dic


def save_file_to_s3(mask_image, file_extension, output_mask_image_dir: str):
    # 图片存储到s3
    ......
    return mask_image_output


def crop_images_from_boxes(image_source: Image, boxes: list, scale_factor: float = 1.0, target_size: int = 400) -> list:
    cropped_images = []
    width, height = image_source.size

    for box in boxes:
        cx, cy, w, h = [coord * scale_factor for coord in box]
        # 计算边界框的左上角和右下角坐标
        x1 = max(0, math.floor((cx - w / 2) * width))
        y1 = max(0, math.floor((cy - h / 2) * height))
        x2 = min(width, math.ceil((cx + w / 2) * width))
        y2 = min(height, math.ceil((cy + h / 2) * height))

        # 如果边界框在图像范围内,则裁剪图像
        if x2 > x1 and y2 > y1:
            cropped_image = image_source.crop((x1, y1, x2, y2))
            # 调整裁剪后图像的大小
            cropped_width, cropped_height = cropped_image.size
            # 等比例调整到目标尺寸
            scale = min(target_size / cropped_width, target_size / cropped_height)
            new_width = int(cropped_width * scale)
            new_height = int(cropped_height * scale)
            cropped_image = cropped_image.resize((new_width, new_height), resample=Image.BICUBIC)
            cropped_images.append(cropped_image)

    return cropped_images


def predict_fn(input_data, model):
    print("=================Dino detect start=================")
    try:
        file_extension = os.path.splitext(input_data['input_image'])[1][1:].lower()
        dir_lst = input_data['input_image'].split('/')
        s3_client = boto3.client('s3')
        s3_response_object = s3_client.get_object(Bucket=dir_lst[2], Key='/'.join(dir_lst[3:]))
        img_bytes = s3_response_object['Body'].read()
        image_source = Image.open(io.BytesIO(img_bytes)).convert("RGB")
        if 'boxes' not in input_data:
            prompt = input_data['prompt']
            boxes, logits, phrases = get_detection_boxes(image_source, model, prompt)
            if len(boxes) == 0:
                return {"error_message": "The image does not contain any object needed"}
            elif len(boxes) > 1:
                return {"boxes": boxes, "file_type": file_extension, "logits": logits, "phrases": phrases}

        boxes = [input_data['boxes']] if 'boxes' in input_data else boxes

        cropped_images = crop_images_from_boxes(image_source, boxes)
        mask_image_output = save_file_to_s3(cropped_images[0], file_extension, input_data['output_mask_image_dir'])
        return {"mask_image_output": mask_image_output}
    except Exception as e:
        print(e)

Embedding 模型

技术难点:

传统的图片 Embedding 模型在用作向量召回时往往存在如下问题

  • 缺乏标注的图片:训练模型需要大量已标注的图像数据,但获取这些标注成本过高,可能存在困难。
  • 模型需要高精度以进行细粒度比较:为了准确匹配相似产品,嵌入模型需要具有足够的精度来捕捉细微的差异。
  • 模型输出的 Embedding 的鲁棒性不足:会受到背景,衣物形变,拍摄角度,光线等因素的较大影响。
  • 需要私有部署选项以保证安全和隐私:出于安全和隐私考虑,可能需要在本地私有环境中部署模型。
  • 模型应该可定制和可扩展:为了满足不同的需求,模型应该具有一定的定制和扩展能力。

解决方案:

先用基于 DINO+VIT 的模型在私有产品图片数据上进行预训练,这个阶段无需进行标注,DINO 就可以自行关注到图片中的主体,而不容易受到背景的干扰。在第二阶段,我们采用对比学习或者分类的方式对模型进行 Finetune 从而进一步提升召回能力。下图可视化了 DINO 模型的注意力层,展示其相对于传统模型的优点,我们可以看到 DINO 这一列中展示的模型注意力可以剥离背景的干扰因素,而传统的有监督算法的注意力没有准确的捕捉到图片中的主体。

在具体的算法开发过程中,我们评估了 DINO 和 DINO V2,Triplet Loss 和 Cross Entropy Loss,也对比了 VIT 和 CNN,在大量实验的基础上,得到的最终的结论如下:

  1. Triplet loss,目前看下来经济性远不如 cross entropy loss,同样的训练轮次完全不收敛(个位数的 mAP),原因是 cross entropy loss 训练过程中一次梯度更新优化的是整个样本分布,而 triplet loss 一次梯度更新仅仅是优化采样到的正负样本,训练效率完全不是一个等级,但是 triplet loss 这种直接优化特征的模式其实更加适配向量匹配任务,可能需要更大的 batch size 或者更细致的超参数调节,加上更完备的难负样本挖掘。
  2. DINOv2(即加入了 MAE 损失的 DINO)在此场景下毫无意外地比 DINO 差,甚至 large 和 giant 版本的 VIT-dinov2 都比不过 Base 的 VIT-dino,目前的猜测是由于重建类的损失(MAE 损失)并不适配判别场景,此种场景下还是判别损失(Cross Entorpy Loss)更加合适,关注的特征也更加低频,提取到的特征更加适合做判别任务。
  3. DINOv1 是目前最适合做向量搜索的预训练算法,这种预训练方法甚至可以一定程度上弥补模型参数量的差距。
  4. 有条件的话可以用 DINO 的训练框架预训练更大的 VIT 模型,因为目前 DINO 官方给出的最大的 VIT-dino 只到 base,并没有 large 版本放出,后续可以在 Google Landmarks v2+ImageNet+私有数据集上进行预训练。

将训练好的 DINO 模型部署在 SageMaker 上,需要提供推理脚本文件 inference.py。其中的主要代码如下:

...
def predict_fn(single_data, model):
    """
    Predict a result using a single data

    :param single_data: a single numpy array for an image
    :type single_data: numpy.array
    :param model: the loaded model
    :type model:
    :return:an object with prediction value
    :rtype: object
    """

    imsize = 648

    transform = pth_transforms.Compose(
        [
            pth_transforms.Resize((imsize, imsize), interpolation=3),
            pth_transforms.ToTensor(),
            pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ]
    )
    image = transform(single_data)

    try:
        output = model(image[None].cuda())
        # First, move the tensor to CPU
        cpu_tensor = output.cpu()

        # Then convert to NumPy array
        numpy_array = cpu_tensor.detach().numpy()
        return numpy_array[0]
    except Exception as e:
        raise e


def input_fn(input_data, request_content_type):
    #  The request_body is coming 1 by 1
    """An input_fn that loads a pickled tensor"""

    if request_content_type == "application/json":
        try:

            json_request = json.loads(input_data)

            file_byte_string = s3_client.get_object(
                Bucket=json_request["bucket"], Key=json_request["file_name"]
            )["Body"].read()

            im = Image.open(io.BytesIO(file_byte_string))
            im = im.convert("RGB")

            return im
        except Exception as e:
            raise e
    elif request_content_type == "application/x-image":
        im = Image.open(BytesIO(input_data))
        im = im.convert("RGB")
        return im
    else:
        # Handle other content-types here or raise an Exception
        # if the content type is not supported.
        raise Exception("Unsupported content type")


def model_fn(model_dir):

    pretrained_weights = os.path.join(model_dir, "checkpoint.pth")
    print(os.path.abspath(os.path.join(model_dir, "config.json")))
    # Open the file and load its contents
    config_path = os.path.join(model_dir, "config.json")
    with open(config_path, "r") as config_file:
        model_config = json.load(config_file)

    print("loading model info: %s", model_config)

    # load pretrained weights
    if os.path.isfile(pretrained_weights):
        model = vits.__dict__[model_config["arch"]](
            patch_size=model_config["patch_size"],
            drop_path_rate=model_config["drop_path_rate"],  # stochastic depth
        )
        state_dict = torch.load(pretrained_weights, map_location="cpu")
        state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
        msg = model.load_state_dict(state_dict, strict=False)
        print(
            "Pretrained weights found at {} and loaded with msg: {}".format(
                pretrained_weights, msg
            )
        )
    else:
        print(
            "Since no pretrained weights have been provided, we load pretrained DINO weights on Google Landmark v2."
        )
        model = torch.hub.load(
            "facebookresearch/xcit:main", "vit_small", pretrained=False
        )
        model.load_state_dict(
            torch.hub.load_state_dict_from_url(
                url="https://dl.fbaipublicfiles.com/dino/dino_vitsmall16_googlelandmark_pretrain/dino_vitsmall16_googlelandmark_pretrain.pth"
            )
        )

    model = model.cuda()
    model.eval()
    return model
    ...

向量搜索

技术难点:

  • 用于产品召回,而非图像召回:最终目标是根据图像找到相应的产品,而不是简单地找到相似图像。
  • 需要支持从向量存储中高效检索向量:向量数据库需要能够支撑百万级的快速向量检索,且搜索结果应该能够提供产品的唯一标识符(如产品代码)。

解决方案:

使用 OpenSearch 同时存储图片的向量数据和产品的代码,这样在做向量相似度对比后,可以同时获取产品代码。同时使用 Faiss-HNSW 算法作为检索算法,同时相似度的计算我们使用了和模型 Finetune 阶段相匹配的 Cosine 函数。核心的考虑点如下图:

OpenSearch 提供了多种算法选择,通过下图的对比,我们最终选择了 FAISS-HNSW 作为向量索引算法。

总的来说,这里涉及图像处理、目标检测、图像分割、embedding 和向量搜索等多个方面,需要解决数据、模型精度、部署环境和搜索结果等多个挑战。通过合理的数据预处理、模型选择和系统设计,可以构建一个高效的基于图像的产品检索系统。

实验测试结果

上图是 CMC(Cumulative Match Characteristic)的测试结果,横坐标 rank n 代表检索出的前 n 个产品,纵坐标是检索出的前 n 个产品里面有目标产品的概率。我们的测试产品库中包含 6000 个左右的品类,用户图片都是真实世界场景的图片,可以看到有 75% 的图片在 rank 1 的位置召回,86% 的正确产品图片都在前 5 的位置被召回。这个检索的精度,满足了客户要求的前 5 个产品里面有目标产品的概率达到 85% 的要求。并且经过业务人员的确认,搜索可以自动忽略背景的影响,对于细节的区别和辨认也已经接近或者达到人类水平。

结论

本文通过使用服装鞋类商品进行模型训练,同时通过 GroundingDINO 进行目标物品检测和剪切的方式对图片进行搜索,这种方式满足企业级的,特别是垂直行业的高精度搜索。有助于更好地提升用户的搜索体验。

该方案也可以拓展到其他的垂直行业使用,如电商、游戏、短视频,医疗、制造业等。

如果您有任何相关的问题或需求,都欢迎随时联系我们进一步交流。

参考资料

https://github.com/IDEA-Research/GroundingDINO

https://arxiv.org/pdf/2303.05499

https://github.com/facebookresearch/dino

本篇作者

江炳坤

亚马逊云科技资深解决方案架构师。拥有十余年系统架构设计经验。曾任职高级架构师、企业架构师等岗位,涉及移动互联网、金融、政府等行业。在AI应用,营销系统,微服务,高并发高可靠的系统设计方面具有丰富的实战经验。目前专注于将 AWS 云平台技术应用于实际解决方案,为客户实现技术创新和成功的技术落地。

姬军翔

亚马逊云科技资深解决方案架构师,在快速原型团队负责创新场景的端到端设计与实现。

吕浩然

亚马逊云科技应用科学家,长期从事计算机视觉,自然语言处理等领域的研究和开发工作。支持数据实验室项目,在时序预测,目标检测,OCR,自然语言生成等方向有丰富的算法开发以及落地实践经验。

尹振宇

亚马逊云科技解决方案架构师,负责基于 AWS 云平台的解决方案咨询和设计,尤其在无服务器领域和微服务领域有着丰富的实践经验。

洪丹

亚马逊云科技原型解决方案架构师,负责机器学习应用场景的快速构建,为客户提供高效、精准的解决方案,以满足他们独特的业务需求和挑战。

华成

亚马逊云科技客户解决方案经理,目前在亚马逊云科技主要支持泛零售行业的客户。通过运用云相关解决方案等帮助客户在迁移到亚马逊云和云上运维期间实现自身的业务价值,帮助客户成功。