亚马逊AWS官方博客

结合 HSDP 及模型并行加速 Llama3 训练

Efficient Large-Scale Training with Pytorch FSDP and AWS 中,Meta 首次展示了 FSDP(Fully Sharded Data Parallel) 如何利用云上基础设施( A100 GPU P4de 实例集群)来实现大规模训练的提效。

FSDP 作为 ZeRO 的一种实现形式,其通过消除 DDP(分布式数据并行)中存在的优化器计算和状态存储、梯度和模型参数内存存储的冗余,有效扩展了在固定资源下可训练的模型量级。这种冗余减少,使 FSDP 相比于朴素数据并行能够在相同的资源下训练更大的模型(参考 Maximizing training throughput using PyTorch FSDP,使用 FSDP 及 A100 GPU 集群 Llama2 7B 上达到 上训练 Llama2 7B 达到了57%的较高水位 MFU)。

在 PyTorch 的 2.1 及近期的 2.4 版本中,分别正式支持了 FSDP 的 Hybrid Sharded Data Parallel 以及 DeviceMesh

混合分片数据并行

全量分片的挑战

基于全量分片如 FSDP 或 DeepSpeed ZeRO 3 的训练范式,包括对于参数、梯度、优化器状态等的切分,能够带来明显的显存节省。但其所引入的额外通信开销,以及可能的 CPU 卸载所需的主机及 GPU 间内存拷贝开销等,所带来的挑战有:

额外通信开销:由于参数分片,在训练前向时,每个 GPU 都需要额外的通信操作(All-Gather)来从其他所有 GPU 汇集参数,并通过逐层汇集及用完丢弃的形式, 虽然保证了内存的节省,但引入了参数量规模的通信开销,叠加训练反向时的参数的逐层汇集及梯度分发汇总,导致累计 3 倍于参数量的通信,是标准的数据并行通信量的 1.5 倍,参考 ZeRO Ch5.3

进程间通信时延瓶颈的累积:由于节点内和跨节点 GPU 进程通信时延的不对等,同时跨节点进程通信受集群拓扑的影响较大。由于全量分片涉及到集群中的所有 GPU 进程,因此当集群规模扩大时,通信路径上的时延瓶颈可能出现叠加及积累,导致其带来的影响更加显著,制约了集群规模及训练性能的进一步提升。

HSDP 混合分片并行

不同于 FSDP 中直接在全集群上进行训练状态的分片,HSDP(Hybrid Shard Data Parallel)使用混合分片策略,可以根据集群的拓扑形态进行分片,比如在节点内完全分片,并在节点之间使用不同模型副本进行数据并行。使得较大开销的 AllGather 及 ReduceScatter 集合操作仅在节点内完成,因此可以更好的利用 GPU 间 NVLink 带宽,对于中等大小的模型训练,能够带来较显著的性能收益。

Process Group 及 Device Mesh

FSDP 和 HSDP 都依赖于进程组(Process Group)进行通信。进程组是用于模型分片的通信组,FSDP 默认自动构造进程组,来自动进行 AllGather 及 ReduceScatter 等集合通信操作。对于 HSDP,可以通过传入一个 ProcessGroup 的描述元组,来分别表征分片及模型副本所使用的组,用于描述模型状态分片、多副本间并行的组合形式。

较新的 DeviceMesh 是一种更高级别的抽象,用于管理多个进程组(ProcessGroup)。其简化了在节点内和节点间创建进程组的过程,无需手动设置子进程组的 Ranks。此外,DeviceMesh 也可以对多维并行场景下的底层进程组和设备进行管理。因此,在较新版本的 PyTorch 中,DeviceMesh 成为了进程组的互斥替代形式。比如,在混合分片(HSDP)训练时,可以通过指定一个2维的 DeviceMesh 来取代相对更复杂的 ProcessGroup 定义。

以 2 节点的 8 GPU 计算实例上的配置为例,通过 DeviceMesh 定义两维的 ProcessGroup 可以使用如下的形式。

from torch.distributed.device_mesh import init_device_mesh
mesh_2d = init_device_mesh("cuda", (2, 8)), mesh_dim_names=("replicate", "shard"))

# Users can access the underlying process group thru `get_group` API.
replicate_group = mesh_2d.get_group(mesh_dim="replicate")
shard_group = mesh_2d.get_group(mesh_dim="shard")

FSDP 对基于 DeviceMesh 进行混合分片(HSDP)的过程进行了封装,可以在指定 sharding_strategy 的同时,直接传入以上的二维 DeviceMesh 定义,示例如下。

mesh_2d = init_device_mesh("cuda", (2, 8))
model = FSDP(
<model>, device_mesh=mesh_2d, sharding_strategy=ShardingStrategy.HYBRID_SHARD
)

高效构建基于 HSDP 的张量并行

FSDP 叠加 DeviceMesh 进行训练资源的定义后,即可以使用 HSDP 快速进行层次化混合分片的并行训练。但当模型参数尺寸,窗口长度,全局批次大小等条件固定后,仍旧容易出现的情况是,混合分片(Hybrid Shard)无法满足显存的约束,导致不得不降级为全量分片(Full Shard)或引入CPU 卸载,来进一步降低显存占用,从而无法利用混合分片(Hybrid Shard)提供预期的性能优势。因此,需要引入其他分片形式(如张量并行对模型参数作分片)进一步降低显存占用,来避免使用 CPU 卸载等可能带来较大性能损失的操作,从而使得一个完整模型副本保持在单个节点内。

张量并行是模型并行的一种类型,其中特定的模型权重(包含激活值)、梯度和优化器状态被划分到多个设备上,与 ZeRO 的计算前进行参数分片的汇聚(ZeRO Ch5.3)所不同的是,张量并行在各参数分片独立计算后汇聚结果(Megatron-LM Ch3)。当单个张量消耗大部分 GPU 内存时(例如具有大词汇量的大型嵌入或具有大量类别的大型 softmax 层),在这种情况下,将整个大型张量作为原子单元进行操作是低效的,同时会导致 GPU 内存负载的不平衡。

构建张量并行的挑战

PyTorch 的张量并行参考了 Nvidia Megatron-LM 的实现,成为其并行计算栈的重要组成部分。然而,基于 PyTorch 的模型并行构建仍存在的挑战有:

模型分片管理:开发者需要明确地对模型的参数进行分片,意味着在代码实现过程中,开发者必须基于不同模型结构将其分割成多个子模块。比如线性层基于行向或列向的切分,输入的嵌入层和输出的投影层,以及自注意力层等的切分等,增加了实现和调试的复杂度。

基于集群规模的灵活扩展:基于模型构建张量并行的另一个挑战在于,当集群的规模扩展扩缩减时,需要基于集群或任务需求的规模进行并行模式的重新调整。这增加了扩缩容时任务适配的复杂性。此外,如在扩展过程中,通信负载的增加并不是线性的,需要根据基础设施的情况调整不同的并行方案以实现较高性能。

MiCS

Amazon Science 针对 ZeRO 的训练范式(包括 FSDP、DeepSpeed 等)进行了优化(optimizing the communication efficiency),并早在 2022 提出了 MiCS(Minimizes Communication Scale)来降低训练中的通信开销。

MiCS 可以使模型训练在云端数百个 GPU 上进行高效扩展,其最小化了通信规模,从而降低通信开销。具体来说,分布式训练框架(如 FSDP 及 DeepSpeed)会将模型状态划分到所有 GPU 上,而 MiCS 则创建多个完整的模型状态副本,并将每个完整状态副本划分到一个GPU子集。MiCS 的 1)Scale-aware Model Partitioning 根据模型大小,来确定一个完整的模型副本需要单个或多个计算节点来进行承载,而非使用整个集群中的全量 GPU。

因此,在 MiCS 中,消耗较大的集合通信操作仅限于集群中的的一个 GPU 子集。这样,当我们在训练集群中新增新节点及模型副本来进行集群扩展时,计算节点或 GPU 子集间的通信开销将保持不变,同时维持在与一个朴素数据并行的节点间通信量相当的较低水位。

如果模型状态副本无法装入单个节点, MiCS 将利用 2)Hierarchical Communication Strategy 来减少节点间传输的数据量。最后,MiCS 提出了 3)2-hop Gradient Synchronization 的梯度同步调度机制,可在所有工作节点之间分摊昂贵的梯度同步开销。

* Figure – near-linear-scaling-of-gigantic-model-training

MiCS 能够实现近乎线性可扩展性,并且与 DeepSpeed v0.5.6 中内置的三阶段零冗余优化器(ZeRO)的第二和第三阶段相比,吞吐量提高了2.82 倍。在 p4d.24xlarge(40GB A100)和 p4de.24xlarge(80GB A100)实例上部署 MiCS,用于训练高达 175B 参数的专有模型。当在 16 个 p4de.24xlarge 实例上训练序列长度为 2048 的 175B 参数模型时,其能够在每个 GPU 上实现 54.2% 的 MFU。在 64 个 p4d.24xlarge 实例(512 个 A100 GPU)上训练千亿参数模型时,MiCS 在每个 GPU 上保持超过 54.5% 的 MFU。当集群规模从 128 个 GPU 扩展到 512 时,MiCS 实现了 99.4% 的线性扩展效率(以“Weak Scaling”指标衡量)。相比之下,DeepSpeed ZeRO 的第三阶段只实现了 72% 的扩展效率,并在 MFU 仅为 19.9% 时即到达饱和状态。

可以看出,MiCS 从不同角度对大规模集群训练中的通信进行了基于拓扑感知的优化,极大程度降低了 ZeRO 范式在训练过程中的通信负载。

SMPv2 (SageMaker Model Parallel Library v2)

SMP(SageMaker Model Parallel Library)于 2020 年以 Library 的形式首次集成至 Amazon SageMaker SDK,并被包括 BloomburgGPT 等模型的训练任务所采用。SMPv2 在 SMP 及 MiCS 优化理念的基础上,进一步结合了最新的 PyTorch FSDP( 包括 HSDP),并与 Nvidia Transformer Engine 实现集成,得以在 FSDP 或 HSDP 等 ZeRO 范式的数据并行基础上,进一步叠加张量并行并快速构建训练任务,以获得最佳集群性能。除此之外,SMPv2 的几个核心优化点如下:

激活状态卸载及预加载优化

通常情况下,前向传递会在每一层计算激活值,并将它们保存在 GPU 内存中,直到相应层的反向传播完成。在 GPU 显存有限的情况下,在前向传递后将这些张量卸载到 CPU 内存,并在需要时将它们取回 GPU,以节省大量 GPU 显存使用。PyTorch 支持卸载激活值,但实现会导致 GPU 在反向传播期间从 CPU 获取激活值时处于空闲状态。这会导致使用激活值卸载时性能严重下降。

SMPv2 改进了这种激活值卸载方式。通过参数 activation_loading_horizon 的控制,会预先获取激活值,在 GPU 需要开始这些激活值的反向传播之前就将它们取回。预取功能有助于训练更高效地进行,降低 GPU 空闲占比。这样可以在降低内存使用的同时,缓解性能下降。

延迟参数初始化

当模型参数规模较大时,在有限的 GPU 内存上直接进行大型模型初始化并不总是可行的。为解决更大参数规模模型初始化 GPU 内存不足所导致的问题,SMPv2 可实现在 CPU 内存上初始化模型。但对于超过 20B 或 40B 参数量的更大模型,即使使用 CPU 内存也可能不足。对于这种情况,可以在 PyTorch 中的元设备(meta device)上初始化模型,它允许创建没有任何实际数据的空张量。元设备上的张量只需要形状信息,这允许在元设备上创建具有大量参数的大型模型。Hugging Face Accelerate 提供了 init_empty_weights 上下文管理器,可帮助在元设备上创建这样的模型,同时在常规设备上初始化缓冲区。在训练开始之前,PyTorch FSDP 会初始化模型参数。SMPv2 的延迟参数初始化功能将模型参数的创建延迟到 PyTorch FSDP 执行参数分片之后。当对 PyTorch Module 进行分片时,PyTorch FSDP 接受一个参数初始化函数 param_init_fn,并供每个 Module 调用。值得注意的是,SMPv2 在 PyTorch v2.0.1 中对 param_init_fn 相比原生版本同时进行了优化。

张量并行

SMPv2 与 Transformer Engine 集成,实现了张量并行,并直接运行在 PyTorch FSDP API 之上。可以同时启用 PyTorch HSDP 和 SMPv2 的张量并行,并确定最佳模型并行方式以获得最佳性能。实际应用中,张量并行在以下场景有明显优势:

  • 当使用长上下文长度进行训练时,由于 FSDP 未对激活值(Activation)进行分片,因此其会导致较高的显存占用。
  • 当全局批量大小固定时(如根据损失下降情况、泛化性要求等原因确定了该超参数),在较大的集群上进行数据并行训练时,全局批量大小有可能超过期望限制,比如 Llama 2 70B 的全局批大小为 1K(4M Tokens),因此其仅用数据并行的训练范式无法使用 1K 以上的进程进行并行训练。

* 目前 SMPv2 同时支持专家并行

基于SMPv2 构建张量并行的最佳实践

初始化配置

SageMaker Estimator 配置中新增初始化启动参数如下:

distribution={
    "torch_distributed": { "enabled": True },
    "smdistributed": {
        "modelparallel": {
            "enabled": True,
            "parameters": {
                "hybrid_shard_degree": Integer,
                "sm_activation_offloading": Boolean,
                "tensor_parallel_degree": Integer,
            }
        }
    }
}

在训练脚本中添加 tsm.init(),其他仍沿用 FSDP 的原生启动形式。

import torch.sagemaker as tsm
tsm.init()

# Set up a PyTorch model
model = ...

# Wrap the PyTorch model using the PyTorch FSDP module
model = FSDP(model,...)

# Optimizer needs to be created after FSDP wrapper
optimizer = ...

在使用 tsm.init() 完成 SMPv2 的启动后,使用 tsm.transform() wrapper 来对 torch.modules 形式的模型进行转换。

from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_config(..)
model = tsm.transform(model) 

完整过程参考示例。训练及检查点(Checkpointing)过程与 FSDP 过程一致。检查点基于 PyTorch state_dict,可参考检查点(Checkpointing)相关的存储选型及性能对比

训练参数配置优化

训练时的吞吐性能与众多因素,包括模型大小、批次大小、并行形式、状态分片或 CPU 卸载形式等强相关,同时以上因素又涉及到基础设施的性能,包括 GPU 及卡间互联类型、实例及节点间互联类型、HBM 带宽等。在较多的实践中,所观察到的相对通用的性能优化建议有:

  • 使用最新的计算实例,如 Amazon P5 实例,其包括 8 x H100 Tensor Core GPU 并包含由 EFAUltraClusters 提供的 3.2Tbps 的节点间互联带宽,能够较好消除较重的跨节点通信所导致的性能瓶颈。
  • 使用延迟参数初始化加速训练任务的初始化过程,尤其是在 60B 及以上参数量的模型训练场景。
  • 根据参数量大小,确定层次化分片(Hybrid Shard)及张量并行所涉及的 GPU 及节点拓扑,保证高消耗的集合通信操作在仅节点内或较少的节点间完成。
  • 使用激活状态检查点或卸载会一定程度降低训练吞吐,但其带来的显存节省可以使得单个训练轮次容纳更大的批次大小。因此建议基于实际的训练任务调整相关参数进行性能调优。
  • 针对长上下文窗口的模型训练任务,通过增加 SMPv2 的张量并行度,在 FSDP 或 HSDP 的基础上进一步地对参数及其激活值(Activation)进行分片,进行高效训练。

除了在显存消耗密集的训练任务中引入 SMPv2 的张量并行来进行性能优化外,对于在大型集群上进行训练的场景,使用数据并行类的训练形式(如 FSDP 或 DeepSpeed),容易导致全局的批次大小过大,其可能对于收敛性造成影响。因此,在预设固定的全局批次大小或全局 token 量的约束下,可以结合使用 SMPv2 的张量并行来控制数据并行度及全局批次大小,在保证收敛性的同时,通过调整其他控制参数进一步优化训练性能。

FSDP、HSDP 及 SMPv2 性能对比

测试用例选取 Meta Llama-3-8B,8K 上下文窗口。同时为保证任一样本的显存占用及注意力计算消耗的一致性,将任一样本扩展至 8K 长度且不使用填充 Token 及掩码,排除由于样本采样差异所导致的填充占比差异,从而避免底层稀疏优化不均匀所引入的性能差异。以上三种训练形式在 1~2 x P5 (8 x H100 GPU) 实例上的训练吞吐对比如下:

其分别对应的主要配置为:

并行形式 激活状态检查点 激活状态卸载 张量并行度 微批次大小 OOM
SMPv2 N N 2 2 N
HSDP N N N.A. 2 CUDA OOM
HSDP Y N N.A. 4 CUDA OOM
HSDP Y N N.A. 2 N
FSDP Y N N.A. 2 N

以上数据可以看出:

  1. SMPv2 通过提升张量并行程度降低显存消耗,可以无需使用激活状态检查点(Activation Checkpointing)或卸载(Offloading)等涉及重算或内存间拷贝等以速度换取显存的低效操作,而仍确保不发生 OOM。
  2. 在开启张量并行,使得全局的批次大小降低的同时,仍能达到更高的训练吞吐。
  3. 该测试用例的训练场景,由于 P5 实例的 3.2Tbps 的 EFA 节点间带宽,导致 HSDP 相比于 FSDP 所带来的性能差异并不显著;同时 SMPv2 的整体性能更好,导致训练吞吐的提升更加显著。
  4. P5 实例的高速实例间带宽,结合 SMPv2 层次化的模型分片及张量并行,使得中小尺寸模型的训练性能可以更好地随集群规模进行扩展。

总结

SMPv2 在原有 SMP 和 MiCS 的基础之上,结合了PyTorch HSDP 和 Nvidia Transformer Engine,实现了高效的基于张量并行的大型模型训练。SMPv2 在 SageMaker 体系上简化了繁琐的训练初始化配置,并支持与 PyTorch FSDP 的无缝集成,极大降低了开发复杂性。通过配置层次化分片、张量并行度等参数,SMPv2 可以在 Amazon P5 等高性能实例上实现高效的集群扩展性能,同时有效降低显存消耗,支持高效的大型基础模型训练。


*前述特定亚马逊云科技生成式人工智能相关的服务仅在亚马逊云科技海外区域可用,亚马逊云科技中国仅为帮助您了解行业前沿技术和发展海外业务选择推介该服务。

本篇作者

郑昊

亚马逊云科技 AI/ML 解决方案架构师。主要专注于基础模型的训练、推理及性能优化;广告、排序算法等及其基于亚马逊云科技 AI/ML 技术栈的相关优化及方案构建。在阿里、平安有多年排序、定价及竞价机制等算法研发经验。

詹健宇

亚马逊云技术客户经理,在内核安全、电商推荐、机器学习领域均有经验。目前专注于 AI/ML 领域,致力于结合客户场景的 AI 解决方案的落地。

梁宇辉

亚马逊云科技机器学习产品技术专家,负责基于亚马逊云科技的机器学习方案的咨询与设计,专注于机器学习的推广与应用,深度参与了很多真实客户的机器学习项目的构建以及优化。对于深度学习模型分布式训练,推荐系统和计算广告等领域具有丰富经验。