亚马逊AWS官方博客
深度解析 TalkingData 使用 DJL 进行大规模深度学习打分应用
前言
TalkingData 是一家总部在北京的数据智能服务提供商,通过提供数据智能产品和服务,来帮助企业获得对消费者行为、偏好和倾向的洞察。TalkingData的一项重要服务是基于机器学习的用户行为分析:通过分析用户信息,可以为用户提供更具有价值的广告。比如有一个汽车经销商想要为想买车的用户投放最近的促销信息,他可以通过这个产品来找到在未来三个月有买车倾向的用户群,进而定向投放广告。最开始,TalkingData的模型是基于XGBoost构建的。后来随着技术演进和精度要求的提升,TalkingData研发部门进一步开发了基于深度学习模型的应用。经过实验以及测试论证,他们的数据科学家成功用PyTorch将模型的recall rate(recall rate是模型在阈值下是否能够提供推理的比例)提升了13%。换句话说,相比于传统机器学习模型,他们的深度学习模型在基于相同的精度情况下可以带来更多的深度学习推理结果。
但是TalkingData在大规模部署深度学习应用中遇到了很大的挑战:模型需要每天对数亿的数据进行深度学习推理。为了能够更高效地进行大规模计算,他们使用了基于Apache Spark的大规模分布式架构来快速批量推理。可是,由于Spark主要是基于JVM的框架,使用Python应用(PySpark) 进行深度学习推理往往会造成内存溢出问题。因为基于JVM本身的内存管理很难去对一个Python的进程产生影响。在过去,因为XGBoost对于Java的支持,TalkingData可以使用XGBoost Java API在Spark平台进行部署。现今使用了PyTorch,由于没有一个很好设计的Java API,以及各种内存溢出问题,他们没有办法在Apache Spark上调用PyTorch模型进行推理任务。这导致了他们被迫转向使用一个GPU的实例来单独进行深度学习推理,这种方案大大增加了后期维护成本。
通过这篇文章,TalkingData发现了AWS基于Java开发的深度学习框架DJL(Deep Java Library)可以很好的解决上述的困境。在这个博客中,我们将带领大家了解TalkingData部署的模型,以及他们是如何利用DJL在Apache Spark上实现生产环境部署深度学习模型。这个解决方案最终将之前的生产架构简化,一切任务都可以在Apache Spark轻松运行,总时间也减少了66%。从长远角度上,这也显著节省了维护成本。
关于模型
该模型为一个用于推断活跃用户是否有可能购买汽车的二分类模型,使用的特征来自于嵌入TalkingData SDK的应用收集的数据。在将原始数据聚合和处理的过程中,特征不可避免地会成为稀疏的分类特征。当TalkingData使用传统的机器学习模型(例如逻辑回归和XGBoost)时,这些简单的模型在从这些稀疏的特征中学习的过程中很容易过拟合。 另外,考虑到数以百万计的训练数据可以支持更复杂,更强大的模型,TalkingData将其模型升级为了DNN(深度神经网络)模型。
在合规性前提下,TalkingData模型将用户信息、用户应用信息和广告事件作为输入。用户信息包括设备名称和设备型号,用户应用信息包含嵌入SDK的应用包名,广告事件信息是用户参与的广告活动信息。这些不同域的输入根据时间聚合,然后预处理成分类特征,预处理包括标记化(tokenization)和规范化(normalization)。 受Wide and Deep learning和 YouTube Deep Neural Networks的启发,首先将分类特征根据预先生成的映射表映射到对应的索引,截断或添补为固定长度,然后再输入到PyTorch DNN模型。模型基于对应每个域的词嵌入进行训练。嵌入 (embedding) 是一种用数值向量表征分类变量的方法:用于降低稀疏的分类变量维度。例如,百万维的类别特征可以用几百维的向量表征,实现模型特征的降维。然后对不同域的嵌入简单的求平均,再拼接成固定长度的向量,喂入前馈神经网络。在训练过程中,最大训练轮数设置为40,提前停止轮数设置为15。与Spark XGBoost 模型相比,DNN模型在测试集上的AUC(Area under the ROC Curve)提升了 6.5%,期望精确度下的召回提升了26%。考虑到TalkingData巨大的数据量,DNN模型的结果很不错 。
生产环境中的困境
虽然模型的效果很令人满意,但是部署深度学习模型成为了很大的困难。由于生产环境中主要的代码都是基于Scala的,直接部署PyTorch在Scala上面临着内存溢出的挑战:JVM的资源回收系统无法看到C++所使用的资源(底层PyTorch API). 为了避免频繁的任务失败问题,最终TalkingData选择使用了单独开启一个GPU的实例来做离线的大数据推理任务。
但是,这个解决方案没有很好的解决下面几个问题
- 性能问题:下载和上传数据需要花费大约30分钟时间
- 单点故障问题:无法使用类似于Spark的多点计算功能。推理任务完全在一个单GPU的机器上进行,如果出现任务失败,那GPU本身没有回退机制可以良好应对。
- 维护问题:生产环境中需要同时维护Scala和Python两个环境
- 扩展问题:如果数据量再增大一些,单GPU的处理性能可能不足。
总体的数据量大约几百GB。总体任务在上述框架下需要6个多小时才可以完成,比TalkingData预期的时间超出两倍。这个设计最终成为了整个生产环节的性能瓶颈。
基于DJL的实现
为了优化这个方案,TalkingData采用DJL重构了他们的推理应用。DJL提供了基于Java的PyTorch引擎库,这使得他们可以直接将这个库部署在Spark上。如下图所示,所有的任务都可以在Spark集群中实现:
这个设计体现了下面几个优势:
- 降低了失败率:相比于单点故障,Spark可以很轻松的调度算力来进行重启。
- 降低算力成本:相比于GPU的解决方案,完全基于Spark的方案可以充分使用Apache Spark本身的算力,从而节约成本。减少使用GPU机器大约节省了20%的总计算成本。
- 降低维护成本:Spark的容错机制可以轻松应对故障,同时单一语言降低了多个语言维护的成本。
- 大幅提升性能:DJL的多线程支持在Apache Spark上提升了性能。
在使用了DJL之后,TalkingData成功的将总体任务运行时间降低到了2小时左右。相比于从前的单GPU解决方案,性能提升了三倍。同时,这套方案无论从短期还是长期都降低了运行成本。
DJL的多线程优势
DJL可以让用户更灵活的调度分布式算力。在PyTorch高级设置中,用户可以选择优化算子的并行数,还有最多线程的并行数(用于提升推理性能)。DJL提供了类似的选项:只需设定num_interop_threads 和 num_threads 便可以轻松调度。这些选项可以同时与Apache Spark每个Executor的核心数目一起改变 –num-executors 这样他们都可以使用相同的CPU数目。这样,DJL可以像PyTorch本身一样帮助用户进一步调优算力。
准确度测试
为了确保基于DJL的PyTorch解决方案与基于Python的PyTorch结果一致。TalkingData数据科学组进行了严谨的比对测试。测试集大约44万条。下面是PyTorch Python与DJL的比对结果:
Column 1 | Column 2 | Column 3 |
项目 | 结果 | 解释 |
总数 | 4387830 | 测试集总数 |
平均值 | 5.27E-08 | 偏差平均值 |
方差 | 1.09E-05 | 方差 |
p25 | 0.00E+00 | 从小到大排列前25%的数据偏差 |
p50 | 2.98E-08 | 从小到大排列前50%的数据偏差 |
p75 | 5.96E-08 | 从小到大排列前75%的数据偏差 |
p99 | 1.79E-07 | 从小到大排列前99%的数据偏差 |
max | 7.22E-03 | 最大偏差值 |
这个实验验证了DJL在推理应用上十分可靠,超过99%的数据都是在10^-7以内的,换句话说,浮点数差别低于0.0000001。我们同时也验证了最大偏差值的产生原因:是因为数据在传递过程中的精度损失导致的。
想了解更多关于DJL在Spark上的应用,可以参考GitHub上DJL Spark的案例,也可以参考这篇博客了解更多关于Spark的推理应用。
总结
TalkingData现今已经在生产环境中使用DJL在Apache Spark上进行大规模的深度学习推理应用。他们选择DJL的几个主要原因 1)显著减少了很大的其他架构的维护成本 2)DJL帮助TalkingData充分利用了已有的Spark算力 3)DJL不局限于深度学习引擎,他们可以在很少改变代码的情况下,在未来新的任务上很轻松的部署任何深度学习模型。
关于DJL
Deep Java Library (DJL) 是一个基于Java的深度学习框架,同时支持训练以及推理。DJL博取众长,构建在多个深度学习框架之上 (TenserFlow, PyTorch, MXNet, etc)也同时具备多个框架的优良特性。你可以轻松使用DJL来进行训练然后部署你的模型。它同时拥有着强大的模型库支持:只需一行便可以轻松读取各种预训练的模型。现在DJL的模型库同时支持高达70个来自GluonCV, HuggingFace, TorchHub以及Keras的模型。
请参考我们的 GitHub, demo repository, Slack channel 以及知乎来获取更多信息!