当前位置:首页 > 文章列表 > 科技周边 > 人工智能 > PyTorch团队重新实现“分割一切”模型,速度比原始实现提升八倍

PyTorch团队重新实现“分割一切”模型,速度比原始实现提升八倍

来源:51CTO.COM 2023-11-22 17:52:15 0浏览 收藏

本篇文章主要是结合我之前面试的各种经历和实战开发中遇到的问题解决经验整理的,希望这篇《PyTorch团队重新实现“分割一切”模型,速度比原始实现提升八倍》对你有很大帮助!欢迎收藏,分享给更多的需要的朋友学习~

从年初到现在,生成式 AI 发展迅猛。但很多时候,我们又不得不面临一个难题:如何加快生成式 AI 的训练、推理等,尤其是在使用 PyTorch 的情况下。

本文 PyTorch 团队的研究者为我们提供了一个解决方案。文章重点介绍了如何使用纯原生 PyTorch 加速生成式 AI 模型,此外,文章还介绍了 PyTorch 新功能,以及如何组合这些功能的实际示例。

结果如何呢?PyTorch 团队表示,他们重写了 Meta 的「分割一切」 (SAM) 模型,从而使代码比原始实现快 8 倍,并且没有损失准确率,所有这些都是使用原生 PyTorch 进行优化的。 

PyTorch团队重新实现“分割一切”模型,速度比原始实现提升八倍

博客地址:https://pytorch.org/blog/accelerating-generative-ai/

在阅读本文后,你将会获得以下的了解:

  • Torch.compile:PyTorch 模型编译器, PyTorch 2.0 加入了一个新的函数,叫做 torch.compile (),能够通过一行代码对已有的模型进行加速;
  • GPU 量化:通过降低运算精度来加速模型;
  • SDPA(Scaled Dot Product Attention ):内存高效的注意力实现方式;
  • 半结构化 (2:4) 稀疏性:一种针对 GPU 优化的稀疏内存格式;
  • Nested Tensor:Nested Tensor 把 {tensor, mask} 打包在一起,将非均匀大小的数据批处理到单个张量中,例如不同大小的图像;
  • Triton 自定义操作:使用 Triton Python DSL 编写 GPU 操作,并通过自定义操作符注册轻松将其集成到 PyTorch 的各种组件中。

PyTorch团队重新实现“分割一切”模型,速度比原始实现提升八倍

PyTorch 原生特性所带来的吞吐量增加以及减少的内存开销。

有关此研究的更多信息,请参考Meta提出的SAM。详细文章可在「CV不存在了?Meta发布「分割一切」AI模型,CV或迎来GPT-3时刻」中找到

PyTorch团队重新实现“分割一切”模型,速度比原始实现提升八倍

接下来,我们将介绍SAM的优化过程,包括性能分析、瓶颈识别,以及如何将这些新功能整合进PyTorch以解决SAM所面临的问题。此外,我们还会介绍PyTorch的一些新特性,包括torch.compile、SDPA、Triton kernels、Nested Tensor以及semi-structured sparsity(半结构化稀疏)

内容的逐层深入,本文最后将介绍快速版 SAM。对于感兴趣的读者,可以前往 GitHub 下载。此外,通过使用 Perfetto UI 对这些数据进行了可视化,以展示 PyTorch 各项特性的应用价值

GitHub 地址:https://github.com/pytorch-labs/segment-anything-fast 可以找到这个项目的源代码

对分割一切模型 SAM 的重写

该研究指出,本文使用的SAM基线数据类型为float32 dtype,批处理大小为1,并使用PyTorch Profiler来查看核心追踪的结果如下:

PyTorch团队重新实现“分割一切”模型,速度比原始实现提升八倍

本文发现 SAM 有两个地方可以优化:

第一个是对 aten::index 的长调用,这是由张量索引操作(例如 [])产生的底层调用导致的。然而实际上 GPU 花费在 aten::index 上的时间相对较低,原因在于 aten::index 在启动两个内核的过程中,两者之间发生了阻塞 cudaStreamSynchronize。这意味着 CPU 会等待 GPU 完成处理,直到启动第二个内核。因而为了优化 SAM,本文认为应该致力于消除导致空闲时间的阻塞 GPU 同步。

第二个问题是在矩阵乘法中,SAM花费了大量的GPU时间(如图所示的深绿色部分),这在Transformers模型中非常普遍。如果我们能够减少SAM模型在矩阵乘法上的GPU时间,那么我们就能够显著提高SAM的速度

接下来,我们将以SAM的吞吐量(img/s)和内存开销(GiB)来建立基准。然后就是优化过程

PyTorch团队重新实现“分割一切”模型,速度比原始实现提升八倍

需要进行改写的句子是:Bfloat16 半精度(加上 GPU 同步和批处理)

为了解决上述问题,即减少矩阵乘法所需的时间,本文转向bfloat16。bfloat16是常用的半精度类型,通过降低每个参数和激活的精度,能够节省大量的计算时间和内存

PyTorch团队重新实现“分割一切”模型,速度比原始实现提升八倍


将填充类型替换为 bfloat16

此外,本文发现有两个位置可以进行优化,以移除 GPU 同步

PyTorch团队重新实现“分割一切”模型,速度比原始实现提升八倍


PyTorch团队重新实现“分割一切”模型,速度比原始实现提升八倍

具体来说,根据上图更容易理解,该研究发现在SAM的图像编码器中,有两个变量q_coords和k_coords充当坐标缩放器,这些变量都在CPU上进行分配和处理。然而,一旦这些变量用于在rel_pos_resized中建立索引,索引操作会自动将这些变量移动到GPU上,从而导致GPU同步的问题。为了解决这个问题,该研究指出可以使用torch.where函数重写这部分内容来解决问题,具体如上所示

核心追踪

在对这些更改进行应用之后,我们注意到单个内核调用之间存在明显的时间间隔,特别是在小批量(这里为1)的情况下更为明显。为了更深入地了解这一现象,我们开始对批大小为8的SAM推理进行性能分析

PyTorch团队重新实现“分割一切”模型,速度比原始实现提升八倍

在分析每个内核所花费的时间时,我们注意到 SAM 的大部分 GPU 时间都用于逐元素内核和 softmax 操作

现在可以看到矩阵乘法的相对开销小了很多。

PyTorch团队重新实现“分割一切”模型,速度比原始实现提升八倍

将 GPU 同步和 bfloat16 优化结合在一起,SAM 性能提高了 3 倍。

PyTorch团队重新实现“分割一切”模型,速度比原始实现提升八倍

Torch.compile(+graph breaks 和 CUDA graphs)

在研究SAM的过程中发现了许多细小的操作。研究人员认为使用编译器来整合这些操作非常有益,因此PyTorch对torch.compile进行了以下优化

  • 将 nn.LayerNorm 或 nn.GELU 等操作序列融合成一个单一的 GPU 内核;
  • 融合紧跟在矩阵乘法内核之后的操作,以减少 GPU 内核调用的数量。

通过这些优化,该研究减少了 GPU 全局内存往返次数(roundtrips),从而加快了推理速度。我们现在可以在 SAM 的图像编码器上尝试 torch.compile。为了最大限度地提高性能,本文使用了一些高级编译技术:

PyTorch团队重新实现“分割一切”模型,速度比原始实现提升八倍

核心追踪

PyTorch团队重新实现“分割一切”模型,速度比原始实现提升八倍

根据结果显示,torch.compile 的表现非常出色

PyTorch团队重新实现“分割一切”模型,速度比原始实现提升八倍

可以观察到 softmax 占了很大一部分时间,然后是各种 GEMM 变体。以下测量的是批大小为 8 及以上的变化。

PyTorch团队重新实现“分割一切”模型,速度比原始实现提升八倍

SDPA: scaled_dot_product_attention

接下来,本文又对 SDPA(scaled_dot_product_attention)进行了实验,研究的重点是注意力机制。一般来讲,原生注意力机制在时间和内存上随序列长度呈二次方扩展。PyTorch 的 SDPA 操作基于 Flash Attention、FlashAttentionV2 和 xFormer 的内存高效注意力原理构建,可以显着加快 GPU 注意力。与 torch.compile 相结合,这个操作允许在 MultiheadAttention 的变体中表达和融合一个共同的模式。经过一小部分更改后,现在模型可以使用 scaled_dot_product_attention。

PyTorch团队重新实现“分割一切”模型,速度比原始实现提升八倍

核心追踪

现在可以看到内存高效的注意力内核占用了 GPU 上大量的计算时间:

PyTorch团队重新实现“分割一切”模型,速度比原始实现提升八倍

使用 PyTorch 的原生 scaled_dot_product_attention,可以显著增加批处理大小。下图为批大小为 32 及以上的变化。

PyTorch团队重新实现“分割一切”模型,速度比原始实现提升八倍

接下来,该研究进行了对 Triton、NestedTensor、批处理 Predict_torch、int8 量化、半结构化 (2:4) 稀疏性等操作的实验

例如本文使用自定义 positional Triton 内核,观察到批大小为 32 的测量结果。

PyTorch团队重新实现“分割一切”模型,速度比原始实现提升八倍

采用 Nested Tensor 技术,并调整批大小为 32 及以上

PyTorch团队重新实现“分割一切”模型,速度比原始实现提升八倍

添加量化后,批大小为 32 及以上变化的测量结果。

PyTorch团队重新实现“分割一切”模型,速度比原始实现提升八倍

文章的最后是半结构化稀疏性。该研究表示,矩阵乘法仍然是需要面对的一个瓶颈。解决的办法是使用稀疏化来近似矩阵乘法。通过稀疏矩阵(即将值归零)可以使用更少的位来存储权重和激活张量。该研究将张量中哪些权重设置为零的过程称为剪枝。剪枝掉较小的权重可以潜在地减小模型大小,而不会显着损失准确率。

剪枝的方法有很多种,从完全非结构化到高度结构化都有。虽然理论上来说非结构化剪枝对精度的影响最小,但是在稀疏情况下,GPU可能会遇到显著的性能下降,尽管在进行大型密集矩阵乘法时非常高效。最近PyTorch支持的一种剪枝方法是半结构化(或2:4)稀疏性,旨在寻求平衡。这种稀疏存储方式将原始张量减少了50%,同时产生密集张量的输出。请参考下图进行说明

PyTorch团队重新实现“分割一切”模型,速度比原始实现提升八倍

为了使用这种稀疏存储格式和相关的快速内核,接下来要做的是剪枝权重。本文在 2:4 的稀疏度下选择最小的两个权重进行剪枝,将权重从默认的 PyTorch(“strided”)布局更改为这种新的半结构化稀疏布局很容易。要实现 apply_sparse (model),只需要 32 行 Python 代码:

PyTorch团队重新实现“分割一切”模型,速度比原始实现提升八倍

在稀疏度为2:4的情况下,我们观察到vit_b和批大小为32时的SAM峰值性能

PyTorch团队重新实现“分割一切”模型,速度比原始实现提升八倍

最终,对这篇文章的概括如下:本文介绍了截至目前在PyTorch上实现Segment Anything的最快方法,借助官方发布的一系列新功能,本文在纯PyTorch中重新编写了原始的SAM,并且没有损失准确度

对于感兴趣的读者,可以查看原博客以获取更多信息

好了,本文到此结束,带大家了解了《PyTorch团队重新实现“分割一切”模型,速度比原始实现提升八倍》,希望本文对你有所帮助!关注golang学习网公众号,给大家分享更多科技周边知识!

版本声明
本文转载于:51CTO.COM 如有侵犯,请联系study_golang@163.com删除
用深度催眠诱导LLM「越狱」,香港浸会大学初探可信大语言模型用深度催眠诱导LLM「越狱」,香港浸会大学初探可信大语言模型
上一篇
用深度催眠诱导LLM「越狱」,香港浸会大学初探可信大语言模型
Stable Video Diffusion来了,代码权重已上线
下一篇
Stable Video Diffusion来了,代码权重已上线
查看更多
最新文章
查看更多
课程推荐
  • 前端进阶之JavaScript设计模式
    前端进阶之JavaScript设计模式
    设计模式是开发人员在软件开发过程中面临一般问题时的解决方案,代表了最佳的实践。本课程的主打内容包括JS常见设计模式以及具体应用场景,打造一站式知识长龙服务,适合有JS基础的同学学习。
    542次学习
  • GO语言核心编程课程
    GO语言核心编程课程
    本课程采用真实案例,全面具体可落地,从理论到实践,一步一步将GO核心编程技术、编程思想、底层实现融会贯通,使学习者贴近时代脉搏,做IT互联网时代的弄潮儿。
    511次学习
  • 简单聊聊mysql8与网络通信
    简单聊聊mysql8与网络通信
    如有问题加微信:Le-studyg;在课程中,我们将首先介绍MySQL8的新特性,包括性能优化、安全增强、新数据类型等,帮助学生快速熟悉MySQL8的最新功能。接着,我们将深入解析MySQL的网络通信机制,包括协议、连接管理、数据传输等,让
    498次学习
  • JavaScript正则表达式基础与实战
    JavaScript正则表达式基础与实战
    在任何一门编程语言中,正则表达式,都是一项重要的知识,它提供了高效的字符串匹配与捕获机制,可以极大的简化程序设计。
    487次学习
  • 从零制作响应式网站—Grid布局
    从零制作响应式网站—Grid布局
    本系列教程将展示从零制作一个假想的网络科技公司官网,分为导航,轮播,关于我们,成功案例,服务流程,团队介绍,数据部分,公司动态,底部信息等内容区块。网站整体采用CSSGrid布局,支持响应式,有流畅过渡和展现动画。
    484次学习
查看更多
AI推荐
  • 蛙蛙写作:AI智能写作助手,提升创作效率与质量
    蛙蛙写作
    蛙蛙写作是一款国内领先的AI写作助手,专为内容创作者设计,提供续写、润色、扩写、改写等服务,覆盖小说创作、学术教育、自媒体营销、办公文档等多种场景。
    8次使用
  • AI代码助手:Amazon CodeWhisperer,高效安全的代码生成工具
    CodeWhisperer
    Amazon CodeWhisperer,一款AI代码生成工具,助您高效编写代码。支持多种语言和IDE,提供智能代码建议、安全扫描,加速开发流程。
    20次使用
  • 畅图AI:AI原生智能图表工具 | 零门槛生成与高效团队协作
    畅图AI
    探索畅图AI:领先的AI原生图表工具,告别绘图门槛。AI智能生成思维导图、流程图等多种图表,支持多模态解析、智能转换与高效团队协作。免费试用,提升效率!
    49次使用
  • TextIn智能文字识别:高效文档处理,助力企业数字化转型
    TextIn智能文字识别平台
    TextIn智能文字识别平台,提供OCR、文档解析及NLP技术,实现文档采集、分类、信息抽取及智能审核全流程自动化。降低90%人工审核成本,提升企业效率。
    55次使用
  • SEO  简篇 AI 排版:3 秒生成精美文章,告别排版烦恼
    简篇AI排版
    SEO 简篇 AI 排版,一款强大的 AI 图文排版工具,3 秒生成专业文章。智能排版、AI 对话优化,支持工作汇报、家校通知等数百场景。会员畅享海量素材、专属客服,多格式导出,一键分享。
    53次使用
微信登录更方便
  • 密码登录
  • 注册账号
登录即同意 用户协议隐私政策
返回登录
  • 重置密码