微信基于 PyTorch 的大规模推荐系统训练实践
从现在开始,我们要努力学习啦!今天我给大家带来《微信基于 PyTorch 的大规模推荐系统训练实践》,感兴趣的朋友请继续看下去吧!下文中的内容我们主要会涉及到等等知识点,如果在阅读本文过程中有遇到不清楚的地方,欢迎留言呀!我们一起讨论,一起学习!
本文将介绍微信基于 PyTorch 进行的大规模推荐系统训练。推荐系统和其它一些深度学习领域不同,仍在使用 Tensorflow 作为训练框架,被广大开发者诟病。虽然也有使用 PyTorch 进行推荐训练的一些实践,但规模较小,也没有实际的业务验证,很难推动业务尝鲜。
2022 年 2 月,PyTorch 团队推出了官方推荐库 TorchRec。我们团队在 5 月开始在内部业务上尝试 TorchRec,并且与 TorchRec 团队展开了一系列的合作。在几个月的试用过程中,我们体会到 TorchRec 非常多的优点,也感受到 TorchRec 在超大规模模型上仍存在一些不足。针对这些不足,我们设计了扩展功能来填补它的问题。在 2022 年 9 月,我们设计的扩展功能 dynamic embedding 已经正式合进了 TorchRec 的主分支,目前仍在与官方团队持续优化。
一、TorchRec 可以为我们带来什么
我们先来聊聊 TorchRec 可以给我们带来什么?我们都知道推荐系统往往和公司的现金流直接挂钩,试错成本非常高,所以大家需要的是一个经过了业务测试的框架。这也是为什么之前的一些基于 PyTorch 的推荐框架都未曾被广泛应用过。而 TorchRec 作为一个官方的推荐框架,在 2022 年 1 月份推出之时,Meta就已经利用它在 Instagram Reels 业务上顺利训练并上线了一个 1250 亿参数的模型,成为了一个经过业务测试的 PyTorch 框架。有了 Instagram 这样一个大业务的支撑,让我们有了更多信心,终于可以去理性地考量一个基于 PyTorch 的推荐框架有什么样的优势了。
对于团队中的不同成员,TorchRec 有不同的好处。首先,对于团队中占绝大多数的算法工程师而言,PyTorch 推荐框架让大家终于可以享受到像 CV、NLP 工程师体会到的那种更人性化的动态图和调试的体验。
另外,PyTorch 极好的兼容性——一个基于 PyTorch1.8 做的模型,不需要改一行代码就可以在最新版本 1.13 上运行——让算法工程师终于可以放心地升级框架,从而享受到最新的框架功能和更优秀的性能。而反观一些基于 TensorFlow 的推荐框架,往往被卡在 TensorFlow 的某一个版本上,例如很多团队可能还在使用基于 TensorFlow 1.x 的内部框架。TensorFlow 1.x 在 2021 年 1 月份就已经停止维护了,这就意味着在近两年的时间内,所有新出的 bug、新出的特性都无法得到很好的支持。使用过程中遇到的问题,也只能靠内部维护团队去修复,增加了额外的成本。及时的框架升级还可以带来免费的速度提升,高版本的 PyTorch 往往匹配更高版本的 CUDA,以及像 CUDA graph 等的一些新特性,可以进一步提升训练速度,提升训练效率。
除了算法工程师,框架团队也是推荐团队的重要组成部分。公司中的框架团队会在选取开源框架之后基于内部需求进行二次开发。对于他们来说,一个 PyTorch 的推荐框架会带来更简化的开发体验。很多传统的 TensorFlow 推荐框架会模仿 TF serving 来做一个基于 C++ session 的扩展——这样的设计方案在当时算是非常先进的方案——但这使得只改一行代码也需要完整地编译整个 TensorFlow,耗时很长,甚至还要在解决在内网下载外部的依赖之类的琐碎问题,开发体验不太好。
使用 PyTorch 不会遇到这样的问题,因为 PyTorch 以 Python 哲学为核心,它希望大家能够自如地进行扩展。我们在进行二次开发的时候,只需要用 pybind11 这样比较成熟的 Python 库封装一下,把我们的库打包成一个动态链接库,就可以加载了。这样自然整体编译速度会快很多,同时学习成本也会低不少。
前面提到 PyTorch 是一个向后兼容性非常好的框架,这让维护团队不需要去维护多个版本,很多共性的问题都可以得到官方的解决,大家就可以专注于特化需求,团队人员效率就会有明显提升。
上面介绍的都是 TorchRec 作为一个 PyTorch 推荐框架的优势,让我们感到非常开心的是,TorchRec 团队没有止步于做一个 PyTorch 推荐框架。他们观察了现有推荐模型以及硬件的特点,在框架中加入了许多的新特性,使得 TorchRec 相比于传统的推荐框架有明显的性能优势。接下来我会选择其中的几个来进行介绍,分别是 GPU embedding,TorchRec 里面优秀的 GPU kernel,还有 TorchRec 能够根据网络通信进行的 embedding 划分。
首先是 GPU embedding。我们先来回顾一下传统的推荐系统 GPU 训练流程,我们会把具体的模型放在 GPU worker 上,embedding 存在远端 PS 上。每个迭代步会先从远端 PS 拉取参数,之后在 GPU 上进行模型的前向和反向计算,把梯度传回给 PS,在 PS 上进行参数更新。
图中绿色的部分是在 GPU 上进行的操作,红色的部分是网络或者 CPU 上进行的。可以看到虽然 GPU 是系统中最昂贵的部分,很多操作却都没有放在 GPU 上。
传统流程并没有充分利用好 GPU。同时,从硬件层面来说,GPU 单卡显存越来越大,dense 部分模型远远没有充分利用 GPU;在英伟达的不断优化下,NV link 以及 GPU direct RDMA 还让卡间通信速度越来越快。
GPU embedding 是一个非常简单的方案。他直接把 embedding 切分放在 GPU 上——比如单机上有 8 张卡,我们把 embedding 直接切分为 8 份,每份放在一张卡上——从而保证所有的操作全都留在卡上。GPU 的利用效率就会有明显提升,训练速度也会有质的飞跃。如果担心 GPU 上面的显存空间不足,TorchRec 还做了 UVM 的支持,可以提前划分一部分主机上的内存作为显存的补充,从而提升单机内部能放下的 embedding 大小。
除去 GPU embedding 以外,TorchRec 还实现了非常优秀的 GPU kernel。这些 kernel 充分利用了最新的硬件特性和 CUDA feature。
举例来说,假如果要实现一个 embedding lookup kernel,也就是要从一个大的 embedding 里面找到一堆 ID 对应的 embedding vector,那么普通的实现里,会给每个 GPU thread 分配一个 ID,让他们分别去找对应的 embedding。这个时候我们要考虑到,GPU 底层是按 warp 进行调度的,一个 warp 里的 32 个 thread 会一起进行显存读写。这意味着,在上述样流程里,虽然在读取 ID 时连续地访问了显存,但后续的拷贝变成了一个随机读写的状态。对于硬件来说,随机读写无法充分利用显存带宽,运行效率也就不够高。
TorchRec 则是在每个 thread 读到 ID 后,利用 shuffle_sync 这样的 warp primitive,将 ID 广播至 warp 内的所有thread 上,从而让一个 wrap 里 32 个 thread 去同时处理同一个 embedding,从而可以进行连续的内存读写,使得显存的带宽利用效率有明显的提升,让 kernel 的速度得到数倍提升。
这个表是官方测试的 embedding lookup 性能提升。这里 Fused EBC 是优化后的kernel,可以看到,不同的设置情况下 TorchRec 相较于原生的 PyTorch 有数十倍的性能提升。在 TorchRec 的基础之上,我们发现对于 embedding 比较小的情况(小于128),可能有半数甚至更多的 thread 空闲,所以进一步把 warp 内的 thread 分组,让他们同时去处理多条 embedding。
在我们的改进下,小 embedding dim 上 kernel 又有了 10% 到 30% 的提升。这一优化也已经合入官方 repo。要特别指出的是,TorchRec 的 kernel 放在了 FBGEMM 库里,有兴趣朋友可以去看一看。
最后想介绍一下 TorchRec 的 embedding 划分机制。前面提到,GPU embedding 就是把 embedding 切分一下放在卡上,那么怎么分就成了一个需要考虑的问题。传统来说有两种划分思路,Row wise 和 Column wise。Row wise 是指假如有 2 万个 feature, 0 号到第 10000 号放在卡 1 上,10000 号到 20000 号放在卡 2 上,这样我们在训练的时候,如果 ID 对应卡 1,我们就从卡 1 上拿,对应卡 2,就从卡 2 上拿。Row wise 的问题在于,因为我们不清楚前 10000 号的通信量和后 10000 号的是不是差距很大,通信都是不均衡的,无法充分利用网络硬件。
Column wise 则是从 embedding 长度角度去划分。例如 embedding 总长是128,可以前 64 维和后 64 维放在不同的位置,这样通信会比较均衡,但是在读取的时候,需要和所有的卡或者 PS 通信。
划分模式上的差别带来了选型中的 trade-off。传统的推荐框架会在设计中固定 embedding 的划分方式,而 TorchRec 则在支持了多种划分方式——比如 row wise、column wise,甚至 table wise,data parallel——的基础上,在内部提供了如 Planner、Estimator、PerfModel 等可以根据使用场景的带宽、显存、内存、模型大小等等参数自动地去计算划分的方式的模块。这样就可以根据我们实际硬件情况去最高效地划分 embedding,最高效地利用硬件。这些功能大都是在 Python 里面实现的。方便我们针对内部环境进行客制化,从而不费力地构建出一套最适合于我们内部环境的推荐系统。
二、在百亿模型上的实验效果
在我们的实验中,对于 DeepFM、DCN 这样的在标准模,TorchRec 相对于之前的基准的推荐框架会有惊人的 10 至 15 倍的性能提升。拿到了这样的性能收益,让我们有信心把 TorchRec 推到了业务上。
对于微信读书精排模型,在对齐精度的基础上,我们发现在真实数据上有 3 倍左右的性能提升,在假数据上甚至有 10 倍左右提升。这里的差别是因为训练读取数据变成瓶颈了,这方面我们还在做进一步的优化。
03
原始方案在千亿及更大模型上的不足
前面介绍的基本是百亿级别或者以下的模型,也就是单机就可以放得下的模型。在把 TorchRec 推到更大的模型的时候,我们观察到 TorchRec 的原生设计的一些问题。对于大模型来说,TorchRec 的纯 GPU embedding 方案需要更多的卡——可能原本 8 张卡的训练速度就可以吃进全部数据,但是我们要用 16 张卡放下 embedding,这使得好不容易提升上去的 GPU 硬件利用效率又被拖了下来。
而且对于大模型的场景,算法团队往往会提出 embedding 的动态增删需求,例如删除一周没有访问过的 ID。TorchRec 的方案是不支持这样特性的。还有,超大模型的业务一般都会涉及诸多团队,迁移基层框架会遇到很大的阻力。我们需要的支持逐步地渐进迁移,而不能让大家一起放下手头的工作,那样的成本过高,风险太大。
根据上述的需求,我们考虑如何去修改 TorchRec,使得它能够适应超大规模模型的场景。我们认为在超大规模训练中,仍然需要支持连接远程的 PS,因为远端 CPU PS 已经非常成熟了,非常容易支持 embedding 的动态增添。同时,对于跨团队的合作,可以用 PS 来隔离开训练和推理,实现渐进的迁移。
那么接下来一个问题就是该如何引入 PS。如果把 PS 直接连到 GPU embedding 上,每个迭代步还是要去访问远端的 PS,会重新使网络和 CPU 整体操作的占比提升,GPU 利用率又被拉下来了。
04
微信团队的 dynamic embedding 如何解决问题
这个时候我们发现单位时间内数据中的新 ID 实际上只占总数据中很少的一部分, HugeCTR 发表论文中也提到相似的结论:只有一小部分的 ID 会被频繁访问。由此,我们想到先正常使用 GPU embedding 进行训练,在显存放满时再将 ID 批量驱逐至 PS。
根据这样的一个思路,假如 GPU embedding 里面只能存下 n 个 ID,而总 ID 有 N 个,甚至无穷多个。可以将全局的 ID 按顺序映射到 0、1、2、3…,并把把映射关系存在一个叫 ID transform 的结构中,让 GPU embedding 利用映射的结果进行正常的训练。当 GPU embedding 放满了,也就是 ID transformer 中 n 对映射的时候,再批量驱逐 ID 至 PS。
在这种设计下,可以使得 PS 很少介入,只有在驱逐时才需要 GPU worker 和 PS 通信。
除此之外,这样的设计中 PS 只需要作为 KV,不需要支持参数更新,也就不需要实现优化器相关的操作,从而让 PS 团队专注于存储相关的工作。我们也支持实现了任意 KV 存储的插件,在开源版本中更是内置了 Redis 插件,让 Redis 也可以作为一个 PS 来使用。
下面介绍一些 dynamic embedding 中的设计细节。我们实现的最简基础的 ID Transformer,其实也就是用一个哈希表,使用的是 PyTorch 里高性能的 ska::flat_hash_map。
ID Transformer 作为流程中仅有的 CPU 操作,对性能要求可能会比较高,所以我们还实现了一个高性能的版本,以 L1 cacheline 为单位存储,从而进一步提升内存的访存效率。
另外,对于驱逐方案,我们希望在不增加内存缓存压力的情况下,高效地融合 LRU 和 LFU。受到 Redis 的 LFU 方案的启发,我们设计了一种概率的算法:只存储 ID 访问频数的指数。比如访问了 32 次即存储 5。在更新频数时,如果又访问到这个 ID,就生成 5 位的随机数,如果在 5 位全为 0,也就是发生了概率为 1/ 32 的事件,我们就增加频数指数为 6。通过这样的概率算法,就可以把 LRU 和LFU 的频数放到 uint32 里面,在不提高访存压力的情况下融合了 LRU 和 LFU。
最后来简单介绍一下我们的多卡方案。我们目前是将所有卡的数据都先 gather 到卡一的 ID Transformer 上,之后再 broadcast 回去。因为我们实现的 ID Transformer 的性能非常高,而且可以和 GPU 计算 Pipeline 起来,不会成为具体的性能瓶颈。
以上就是 dynamic embedding 在设计上一些想法。在我们内部的一个万亿级的业务上,在对齐精度情况下,dynamic embedding 方案相对于我们内部原始的 GPU Tensorflow 框架有 3 倍左右的性能提升。相比于 TF 优化版也仍然有 50% 以上的性能优势。
最后推荐大家去试用一下 Torchrec。对于相对较小的业务,比如百亿下的业务,推荐大家直接使用原生的 TorchRec:即插即用,不需要任何的二次开发,性能可以得到成倍的提升。对于极大的业务,则推荐大家尝试配合我们合进 TorchRec 的 dynamic embedding,一方面方便连接内部的 PS,另一方面也支持 embedding 的扩展和渐进迁移,同时还是可以获得一定的性能提升。
这里是我们已经对齐的一些精度的模型和已有的应用场景,有兴趣的朋友可以去试一下。
今天关于《微信基于 PyTorch 的大规模推荐系统训练实践》的内容介绍就到此结束,如果有什么疑问或者建议,可以在golang学习网公众号下多多回复交流;文中若有不正之处,也希望回复留言以告知!

- 上一篇
- 司机路上的“神助攻”!北理工研发混合脑机接口驾驶辅助系统,提高驾驶安全性

- 下一篇
- arXiv正式规定:预印本不允许以ChatGPT等工具为作者
-
- 科技周边 · 人工智能 | 5小时前 |
- 小米SU7订单18万未交付,月产能暴增6倍
- 361浏览 收藏
-
- 科技周边 · 人工智能 | 5小时前 | iPhone17Pro 天蓝色 M4MacBookAir
- iPhone17Pro/ProMax弃钛金属,拥抱天蓝色
- 272浏览 收藏
-
- 科技周边 · 人工智能 | 8小时前 |
- 问界M8快报:MAX+版最火,BAL车主热捧
- 335浏览 收藏
-
- 科技周边 · 人工智能 | 10小时前 |
- 港大与Adobe联手推出PixelFlow图像生成模型
- 135浏览 收藏
-
- 科技周边 · 人工智能 | 12小时前 | 摩尔线程 招聘诈骗 @mthreads.com 官方客服 法律责任
- 摩尔线程重磅声明发布
- 406浏览 收藏
-
- 科技周边 · 人工智能 | 15小时前 |
- 玛莎拉蒂GT2Stradale国内首秀售414.5万
- 226浏览 收藏
-
- 科技周边 · 人工智能 | 17小时前 |
- 美股反弹艰难,三大指数涨跌不一,英伟达跌3%
- 301浏览 收藏
-
- 前端进阶之JavaScript设计模式
- 设计模式是开发人员在软件开发过程中面临一般问题时的解决方案,代表了最佳的实践。本课程的主打内容包括JS常见设计模式以及具体应用场景,打造一站式知识长龙服务,适合有JS基础的同学学习。
- 542次学习
-
- GO语言核心编程课程
- 本课程采用真实案例,全面具体可落地,从理论到实践,一步一步将GO核心编程技术、编程思想、底层实现融会贯通,使学习者贴近时代脉搏,做IT互联网时代的弄潮儿。
- 508次学习
-
- 简单聊聊mysql8与网络通信
- 如有问题加微信:Le-studyg;在课程中,我们将首先介绍MySQL8的新特性,包括性能优化、安全增强、新数据类型等,帮助学生快速熟悉MySQL8的最新功能。接着,我们将深入解析MySQL的网络通信机制,包括协议、连接管理、数据传输等,让
- 497次学习
-
- JavaScript正则表达式基础与实战
- 在任何一门编程语言中,正则表达式,都是一项重要的知识,它提供了高效的字符串匹配与捕获机制,可以极大的简化程序设计。
- 487次学习
-
- 从零制作响应式网站—Grid布局
- 本系列教程将展示从零制作一个假想的网络科技公司官网,分为导航,轮播,关于我们,成功案例,服务流程,团队介绍,数据部分,公司动态,底部信息等内容区块。网站整体采用CSSGrid布局,支持响应式,有流畅过渡和展现动画。
- 484次学习
-
- 笔灵AI生成答辩PPT
- 探索笔灵AI生成答辩PPT的强大功能,快速制作高质量答辩PPT。精准内容提取、多样模板匹配、数据可视化、配套自述稿生成,让您的学术和职场展示更加专业与高效。
- 30次使用
-
- 知网AIGC检测服务系统
- 知网AIGC检测服务系统,专注于检测学术文本中的疑似AI生成内容。依托知网海量高质量文献资源,结合先进的“知识增强AIGC检测技术”,系统能够从语言模式和语义逻辑两方面精准识别AI生成内容,适用于学术研究、教育和企业领域,确保文本的真实性和原创性。
- 44次使用
-
- AIGC检测-Aibiye
- AIbiye官网推出的AIGC检测服务,专注于检测ChatGPT、Gemini、Claude等AIGC工具生成的文本,帮助用户确保论文的原创性和学术规范。支持txt和doc(x)格式,检测范围为论文正文,提供高准确性和便捷的用户体验。
- 40次使用
-
- 易笔AI论文
- 易笔AI论文平台提供自动写作、格式校对、查重检测等功能,支持多种学术领域的论文生成。价格优惠,界面友好,操作简便,适用于学术研究者、学生及论文辅导机构。
- 53次使用
-
- 笔启AI论文写作平台
- 笔启AI论文写作平台提供多类型论文生成服务,支持多语言写作,满足学术研究者、学生和职场人士的需求。平台采用AI 4.0版本,确保论文质量和原创性,并提供查重保障和隐私保护。
- 43次使用
-
- GPT-4王者加冕!读图做题性能炸天,凭自己就能考上斯坦福
- 2023-04-25 501浏览
-
- 单块V100训练模型提速72倍!尤洋团队新成果获AAAI 2023杰出论文奖
- 2023-04-24 501浏览
-
- ChatGPT 真的会接管世界吗?
- 2023-04-13 501浏览
-
- VR的终极形态是「假眼」?Neuralink前联合创始人掏出新产品:科学之眼!
- 2023-04-30 501浏览
-
- 实现实时制造可视性优势有哪些?
- 2023-04-15 501浏览