亚马逊AWS官方博客

使用 SageMaker 对 Whisper 模型进行微调及部署

本文主要介绍如何使用 SageMaker 对 Hugging Face 上 pytorch 版本的 Whisper 模型进行微调(Fine-tuning)和部署。相比于原版模型而言,HuggingFace 版的实现学习成本更低,上手更容易。

OpenAI Whisper 简介

Whisper 作为 OpenAI 最新开源的自动语音识别(ASR)模型,采用了编码器-解码器(encoder- decoder)transformer架构,并使用了 68 万小时的从互联网收集的多语言、多任务的已标注数据进行训练。根据其论文显示,Whisper 模型在无需微调(zero-shot)的情况下,在多个数据集的测试上鲁棒性更高,错误率更低。关于 Whisper 模型的更多细节,参见其官方网站 https://openai.com/blog/whisper/ 以及 https://github.com/openai/whisper

Figure 1 Whisper 模型架构(来源:https://github.com/openai/whisper

HuggingFace 简介

HuggingFace 是一个数据科学家、机器学习工程师、研究人员分享创意、数据以及模型并进行协作的社区和平台。目前,HuggingFace 上已经拥有了 30K+的模型,5K+的数据集以及各式各样的 Demo 以及帮助使用和分享模型和数据集代码库。本文主要用到的 whisper-small 模型(https://huggingface.co/openai/whisper-small),common-voice 数据集(https://huggingface.co/datasets/mozilla-foundation/common_voice_11_0)都来自 HuggingFace。

数据准备

我们使用 SageMaker Studio Notebook 作为工作环境来处理我们使用 Common Voice 数据集中的 zh-TW 子集。由于数据预处理在 Notebook 中直接执行,我们可以选择性能较好的 CPU 实例例如 ml.c5.4xlarge 或者 GPU 实例例如 ml.g4dn.xlarge 来进行处理。为了较好的兼容 HuggingFace dataset 的一些显示组件,我们选择 Data Science Kernel。

Figure 2 为 SageMaker Studio Notebook 选择 Image,Kernel 以及 Instance Type

数据预处理主要包括几个部分:

  1. 依赖的安装,包括 pytorch audio, git-lfs 以及 transformers 等。
  2. 使用 load_dataset 从 HuggingFace 的 git 中下载 common voice 中 zh-tw 子集(subset),并且把 train, test, validation 三个分片加载到 DatasetDict 中,方便之后的训练直接使用。
  3. 音频的下采样,从 48k 降到 Whisper 使用的 16k,以及其他特征变换。
  4. 使用 log-Mel 算法进行音频特征的提取(Whisper feature_extrator)以及音频标注到 Whisper 中字符集 id 的映射(Whisper tokenizer)。
  5. 处理完的数据保存到本地并上传到 S3。

具体代码参见:https://github.com/aws-samples/amazon-sagemaker-finetune-deploy-whisper-huggingface/blob/main/zhtw-data.ipynb

Figure 3 数据处理流程

通过对比源码我们可以发现,HuggingFace 版本在数据预处理的实现上与原版并不完全一致,例如原版的 log-Mel 算法使用了 pytorch 的函数进行实现(https://github.com/openai/whisper/blob/main/whisper/audio.py),而HuggingFace 版本则使用 numpy 进行了实现(https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/feature_extraction_whisper.py)。

Figure 4 log-Mel Spectrogram 在原版(左)与 HuggingFace 版(右)的实现

由于 HuggingFace 对 ASR 类模型进行了统一的 API 封装,包括音频特征提取的 feature_extractor 和文本到 id 转换的 tokenizer,降低了我们对于不同模型之间差异的学习成本。

在数据准备部分,除了最后需要上传 S3 以外,其他步骤都与本地处理完全相同。借助 SageMaker Studio Notebook,我们可以根据数据集的大小和数据处理代码所需的环境不同选择最合适的实例类型和内核类型。

微调训练(Fine Tuning)

在数据准备完成后,我们就可以开始训练了。由于这一部分主要是对于代码的打包和调用 SageMaker API,所以可以使用较低配置的机型,这里采用了 SageMaker Studio Notebook ml.t3.medium 机型配合 Data Science Kernel。微调训练部分主要涉及以下内容:

  1. train.py:训练脚本,主要是调用了 transformers 的 Seq2SeqTrainer 来进行训练。注意它使用的超参数包括了训练自身的超参数,例如批次大小,最大训练步数等。可以参考官方样例进行设置。另外一部分是 SageMaker 中 HuggingFace Estimator 设置的参数,例如数据路径以及使用多少 gpu 等,这些参数会以环境变量的形式传给 train.py。这部分参数可以参考本文档以及本文档。在训练最后一步需要调用 trainer.save_model(args.model_dir)和 processor.tokenizer.save_pretrained(args.model_dir)把模型保存到训练机本地磁盘指定目录中。在训练完成后该目录内文件会被打包上传到 S3。
  2. requirements.txt:在训练开始前需要安装的 python 依赖。例如包含了 Whisper 的新版本 transformers。
  3. zhtw-finetune.ipynb:这个 notebook 主要是设置了训练机型、是否使用分布式训练等,然后调用 SageMaker SDK 里的 HuggingFace Estimator 进行训练。例如我们可以使用 distribution = {‘smdistributed’:{‘dataparallel’:{ ‘enabled’: True }}} 配合 instance_type = ‘ml.p3.16xlarge’ 进行单机多卡分布式训练;也可以用 distribution = None 配合 instance_type = ‘ml.p3.2xlarge’ 进行单机单卡训练。

具体代码参见:https://github.com/aws-samples/amazon-sagemaker-finetune-deploy-whisper-huggingface/blob/main/zhtw-finetune.ipynb

Figure 5 微调过程

当然我们可能会好奇,HuggingFace Estimator 中的参数是如何让 Seq2SeqTrainer 进行单机或者分布式训练的,这里我们可以从 trainer 代码中(https://github.com/huggingface/transformers/blob/main/src/transformers/trainer.py)找到线索:

Figure 6 Trainer 中 SageMaker 部分相关代码

训练完成后,我们可以在 SageMaker 控制台上找到本次训练的相关信息和输出。

Figure 7 SageMaker 控制台上的训练记录

Figure 8 CloudWatch 中的算法指标 WER

Figure 9 CloudWatch 中的部分实例指标

Figure 10 SageMaker 控制台中的输出模型位置

部署

由于部署时会用到较多的根目录下的磁盘空间,因此这里我们使用 SageMaker Notebook,机型上依然使用 ml.t3.medium 配合 Conda Python Kernel。部署过程主要涉及两个概念,SageMaker Model 和 SageMaker Endpoint。

我们可以认为 SageMaker Model 定义了如何使用模型进行推理,它包括了:

  1. 推理需要的运行环境 Docker 镜像,它可以与训练镜像不同,我们可以在推理镜像上做针对于推理的优化,或者是使用 CPU 推理镜像来降低成本等。可以从此处找到最新的预置的训练和推理镜像列表。
  2. 训练完的模型(例如 TrainingJob 输出的 S3 model artifact)
  3. 用于输入输出处理和模型调用的脚本(通常是 inference.py)
  4. 运行时需要的依赖(requirements.txt)

当这部分准备工作完成以后,我们可以决定用哪种实例来运行 SageMaker Model 以及流量如何分配等,然后启动 SageMaker Endpoint 对外提供 RESTful Endpoint 提供推理能力。SageMaker Model 和 SageMaker Endpoint 在 SageMaker Python SDK 中的封装分别是 sagemaker.model.Model 和 sagemaker.predictor.Predictor。在示例 Notebook 中我们使用了 Predictor 来对音频文件进行推理,得到了转录后的文本。

具体代码参见:https://github.com/aws-samples/amazon-sagemaker-finetune-deploy-whisper-huggingface/blob/main/zhtw-deploy.ipynb

Figure 11 SageMaker Model


Figure 12 SageMaker Endpoint

Figure 13 使用 Predictor 调用 SageMaker Endpoint 完成音频到文字的转录

到这边为止,我们已经完成了 Whisper 在 SageMaker 上的微调和部署工作。

总结

本文简单介绍了如何使用 SageMaker 平台对 Whisper 模型进行微调和部署。可以看到 HuggingFace 不仅仅提供了模型和数据集的共享,在对于它们的使用上也提供了封装,特别是针对在 SageMaker 上的训练调优也提供了相应的支持,在很大程度上简化了我们使用这些模型的难度。而 SageMaker 平台不仅仅提供了多种多样的算力帮助我们训练和部署模型,也在机器学习的各个环节提供了集成的开发运行环境,使得数据科学家可以更轻松的完成从数据准备、模型训练到部署运行的机器学习全过程,提高了工程效率。

参考资料

https://github.com/aws-samples/amazon-sagemaker-fine-tune-and-deploy-wav2vec2-huggingface

https://github.com/huggingface/blog/blob/main/fine-tune-whisper.md

本篇作者

施俊

AWS 解决方案架构师,主要负责数字金融客户和企业级客户在 AWS 上的架构设计与实施。10+年金融软件研发和机器学习经验。