当前位置:首页 > 文章列表 > 文章 > python教程 > PyTorch 无循环张量多对一求和方法

PyTorch 无循环张量多对一求和方法

2026-04-09 18:45:44 0浏览 收藏
本文揭秘了如何利用 PyTorch 的 `scatter_add_` 原语,结合 `repeat_interleave` 和索引展平技巧,以完全向量化、零 Python 循环的方式高效实现一维张量到另一维张量的“一对多”映射累加(如多源值聚合至目标位置),不仅大幅提升 GPU 并行计算效率、保持梯度可导性,还显著简化代码逻辑——告别慢速循环与手动索引遍历,让复杂映射操作变得简洁、健壮且生产就绪。

如何在 PyTorch 中高效实现张量的一对多映射求和(无显式循环)

本文介绍使用 torch.Tensor.scatter_add_ 配合索引展开与值重复,高效完成一维张量到另一维张量的一对多映射累加操作,避免 Python 循环,完全基于向量化运算。

本文介绍使用 `torch.Tensor.scatter_add_` 配合索引展开与值重复,高效完成一维张量到另一维张量的一对多映射累加操作,避免 Python 循环,完全基于向量化运算。

在 PyTorch 中处理「一对多」映射关系(即每个输入元素贡献至多个输出位置)并执行聚合(如求和)时,若采用 Python 循环或列表推导,不仅代码冗长,更会严重拖慢训练速度、破坏计算图完整性,且无法充分利用 GPU 并行能力。幸运的是,PyTorch 提供了高度优化的原语——scatter_add,它专为这类“按索引分散累加”场景设计,可一次性完成全部映射与聚合。

核心思想是将不规则映射结构(如嵌套列表 mapping)转化为两个齐次一维张量:

  • src:待累加的源值序列,其中每个 input[i] 根据其映射目标数量被重复;
  • index:对应的目标位置索引序列,与 src 严格对齐;
  • out:初始化为零的输出张量,长度由最大目标索引决定。

以下为完整实现示例:

import torch

# 输入定义
input = torch.tensor([0, 1, 2, 3], dtype=torch.float32)
mapping = [[1], [0, 2, 4], [0, 3], [1, 2]]

# 步骤 1:计算各输入项的重复次数(即每个 input[i] 映射到多少个 output 位置)
reps = torch.tensor([len(x) for x in mapping])

# 步骤 2:构建 src —— 按 reps 重复 input 中每个元素
src = input.repeat_interleave(reps)  # tensor([0, 1, 1, 1, 2, 2, 3, 3])

# 步骤 3:构建 index —— 展平 mapping,得到所有 (src[i] → output[j]) 的 j 序列
index = torch.tensor([j for sublist in mapping for j in sublist])  # tensor([1, 0, 2, 4, 0, 3, 1, 2])

# 步骤 4:初始化输出张量(长度 = max(index) + 1)
out = torch.zeros(max(index) + 1, dtype=src.dtype)

# 步骤 5:执行向量化累加:out[j] += src[i] for each (i,j) pair
result = out.scatter_add(dim=0, index=index, src=src)

print(result)  # tensor([3., 3., 4., 2., 1.])

关键优势

  • 全程无 Python 循环,100% 张量操作,支持 CUDA 加速;
  • 时间复杂度为 O(∑|mapping[i]|),空间复杂度为 O(len(output)),理论最优;
  • 自动兼容梯度传播(scatter_add 是可微分操作),适用于模型中间层。

⚠️ 注意事项

  • index 中的索引必须是非负整数,且严格小于 out.size(dim),否则抛出 RuntimeError;
  • 若 mapping 可能为空(如 []),需提前过滤或用 max(index, default=0) 防御;
  • 当 output 维度极大但稀疏时,该方法仍会分配全量内存;如需极致稀疏支持,可考虑结合 torch.sparse 或自定义 CUDA kernel,但绝大多数场景 scatter_add 已足够高效。

总结而言,scatter_add 是解决 PyTorch 中「一对多映射+聚合」问题的标准、简洁且高性能方案。掌握其与 repeat_interleave、索引展平等组合技巧,能显著提升数据预处理与自定义层的表达力与执行效率。

好了,本文到此结束,带大家了解了《PyTorch 无循环张量多对一求和方法》,希望本文对你有所帮助!关注golang学习网公众号,给大家分享更多文章知识!

Win11关闭U盘自动播放方法Win11关闭U盘自动播放方法
上一篇
Win11关闭U盘自动播放方法
Go语言解析JSON文件流程全解析
下一篇
Go语言解析JSON文件流程全解析
查看更多
最新文章
资料下载
查看更多
课程推荐
  • 前端进阶之JavaScript设计模式
    前端进阶之JavaScript设计模式
    设计模式是开发人员在软件开发过程中面临一般问题时的解决方案,代表了最佳的实践。本课程的主打内容包括JS常见设计模式以及具体应用场景,打造一站式知识长龙服务,适合有JS基础的同学学习。
    543次学习
  • GO语言核心编程课程
    GO语言核心编程课程
    本课程采用真实案例,全面具体可落地,从理论到实践,一步一步将GO核心编程技术、编程思想、底层实现融会贯通,使学习者贴近时代脉搏,做IT互联网时代的弄潮儿。
    516次学习
  • 简单聊聊mysql8与网络通信
    简单聊聊mysql8与网络通信
    如有问题加微信:Le-studyg;在课程中,我们将首先介绍MySQL8的新特性,包括性能优化、安全增强、新数据类型等,帮助学生快速熟悉MySQL8的最新功能。接着,我们将深入解析MySQL的网络通信机制,包括协议、连接管理、数据传输等,让
    500次学习
  • JavaScript正则表达式基础与实战
    JavaScript正则表达式基础与实战
    在任何一门编程语言中,正则表达式,都是一项重要的知识,它提供了高效的字符串匹配与捕获机制,可以极大的简化程序设计。
    487次学习
  • 从零制作响应式网站—Grid布局
    从零制作响应式网站—Grid布局
    本系列教程将展示从零制作一个假想的网络科技公司官网,分为导航,轮播,关于我们,成功案例,服务流程,团队介绍,数据部分,公司动态,底部信息等内容区块。网站整体采用CSSGrid布局,支持响应式,有流畅过渡和展现动画。
    485次学习
查看更多
AI推荐
  • ChatExcel酷表:告别Excel难题,北大团队AI助手助您轻松处理数据
    ChatExcel酷表
    ChatExcel酷表是由北京大学团队打造的Excel聊天机器人,用自然语言操控表格,简化数据处理,告别繁琐操作,提升工作效率!适用于学生、上班族及政府人员。
    4257次使用
  • Any绘本:开源免费AI绘本创作工具深度解析
    Any绘本
    探索Any绘本(anypicturebook.com/zh),一款开源免费的AI绘本创作工具,基于Google Gemini与Flux AI模型,让您轻松创作个性化绘本。适用于家庭、教育、创作等多种场景,零门槛,高自由度,技术透明,本地可控。
    4614次使用
  • 可赞AI:AI驱动办公可视化智能工具,一键高效生成文档图表脑图
    可赞AI
    可赞AI,AI驱动的办公可视化智能工具,助您轻松实现文本与可视化元素高效转化。无论是智能文档生成、多格式文本解析,还是一键生成专业图表、脑图、知识卡片,可赞AI都能让信息处理更清晰高效。覆盖数据汇报、会议纪要、内容营销等全场景,大幅提升办公效率,降低专业门槛,是您提升工作效率的得力助手。
    4499次使用
  • 星月写作:AI网文创作神器,助力爆款小说速成
    星月写作
    星月写作是国内首款聚焦中文网络小说创作的AI辅助工具,解决网文作者从构思到变现的全流程痛点。AI扫榜、专属模板、全链路适配,助力新人快速上手,资深作者效率倍增。
    6191次使用
  • MagicLight.ai:叙事驱动AI动画视频创作平台 | 高效生成专业级故事动画
    MagicLight
    MagicLight.ai是全球首款叙事驱动型AI动画视频创作平台,专注于解决从故事想法到完整动画的全流程痛点。它通过自研AI模型,保障角色、风格、场景高度一致性,让零动画经验者也能高效产出专业级叙事内容。广泛适用于独立创作者、动画工作室、教育机构及企业营销,助您轻松实现创意落地与商业化。
    4873次使用
微信登录更方便
  • 密码登录
  • 注册账号
登录即同意 用户协议隐私政策
返回登录
  • 重置密码