当前位置:首页 > 文章列表 > 文章 > python教程 > PyTorchDDP多进程训练在Kaggle的正确启动方式

PyTorchDDP多进程训练在Kaggle的正确启动方式

2026-01-20 14:36:57 0浏览 收藏

小伙伴们对文章编程感兴趣吗?是否正在学习相关知识点?如果是,那么本文《PyTorch DDP多进程训练在Kaggle正确启动方法》,就很适合你,本篇文章讲解的知识点主要包括。在之后的文章中也会多多分享相关知识点,希望对大家的知识积累有所帮助!

PyTorch DDP 多进程训练在 Kaggle 笔记本中的正确启动方式

在 Kaggle 等基于 Jupyter 的环境中直接运行 PyTorch DDP(DistributedDataParallel)多进程代码会因 `__main__` 模块序列化失败而报错;根本解决方案是将 DDP 主逻辑写入独立 `.py` 文件,并通过命令行方式执行,避开 notebook 的模块上下文限制。

PyTorch 的 torch.multiprocessing.spawn 要求被启动的函数(如 main)必须可被子进程通过 pickle 反序列化——这在标准 Python 脚本中自然成立,因为 if __name__ == "__main__": 块内定义的函数属于顶层模块 __main__。但在 Kaggle 或 Jupyter Notebook 中,整个 cell 代码实际运行在 这一动态、不可序列化的内置模块中,导致子进程无法定位 main 函数,从而抛出:

AttributeError: Can't get attribute 'main' on 

✅ 正确做法:分离定义与执行
将 DDP 训练逻辑封装为标准 .py 文件,而非在 notebook cell 中直接调用 mp.spawn()。

✅ 实施步骤(Kaggle 环境)

  1. 使用 %%writefile 魔法命令创建独立脚本
    在 notebook 新建 cell,粘贴并保存完整 DDP 代码(参考 PyTorch 官方示例),顶部添加 %%writefile ddp.py:
%%writefile ddp.py
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
import torch.multiprocessing as mp
from torchvision import datasets, transforms
import os

def main(rank, world_size, epochs=5, batch_size=32, lr=1e-3):
    # 初始化进程组
    dist.init_process_group(
        backend="nccl",
        init_method="env://",
        world_size=world_size,
        rank=rank
    )

    # 设置设备
    torch.cuda.set_device(rank)
    device = torch.device(f"cuda:{rank}")

    # 构建模型、数据集、优化器等(此处省略细节)
    model = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10)).to(device)
    model = DDP(model, device_ids=[rank])

    train_dataset = datasets.MNIST("./data", train=True, download=True,
                                    transform=transforms.ToTensor())
    sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler)

    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        model.train()
        sampler.set_epoch(epoch)  # 关键:确保每个 epoch 数据打乱一致
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data.view(data.size(0), -1))
            loss = nn.functional.cross_entropy(output, target)
            loss.backward()
            optimizer.step()

    dist.destroy_process_group()

if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--world_size", type=int, default=torch.cuda.device_count())
    args = parser.parse_args()

    # 注意:Kaggle 中需显式设置环境变量(spawn 自动读取)
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"

    mp.spawn(main, args=(args.world_size, 5, 32, 1e-3), nprocs=args.world_size, join=True)
  1. 在另一个 cell 中执行脚本
    使用系统命令运行,绕过 notebook 解释器上下文:
!python -W ignore ddp.py

⚠️ 注意事项:

  • 务必设置 MASTER_ADDR 和 MASTER_PORT:spawn 依赖这些环境变量初始化 NCCL 后端,Kaggle 默认未设置。
  • 避免在 notebook 中直接调用 mp.spawn():即使加了 if __name__ == "__main__":,notebook 的 __main__ 仍不可序列化。
  • -W ignore 是可选的:用于抑制 PyTorch 分布式警告(如 UserWarning: ... is deprecated),提升日志可读性。
  • 单节点多卡适用:本方案专为 Kaggle 提供的 2×T4 场景设计;跨节点需额外配置 MASTER_ADDR 和网络互通。

该方法严格遵循 Python 多进程的“spawn”启动方式语义,确保每个子进程从干净的 .py 文件入口重新导入模块,彻底规避 AttributeError。这是在受限 notebook 环境中安全启用 PyTorch DDP 的工业级实践。

理论要掌握,实操不能落!以上关于《PyTorchDDP多进程训练在Kaggle的正确启动方式》的详细介绍,大家都掌握了吧!如果想要继续提升自己的能力,那么就来关注golang学习网公众号吧!

盐神居改作品信息及简介教程盐神居改作品信息及简介教程
上一篇
盐神居改作品信息及简介教程
Python系统部署原理与实战教程
下一篇
Python系统部署原理与实战教程
查看更多
最新文章
查看更多
课程推荐
  • 前端进阶之JavaScript设计模式
    前端进阶之JavaScript设计模式
    设计模式是开发人员在软件开发过程中面临一般问题时的解决方案,代表了最佳的实践。本课程的主打内容包括JS常见设计模式以及具体应用场景,打造一站式知识长龙服务,适合有JS基础的同学学习。
    543次学习
  • GO语言核心编程课程
    GO语言核心编程课程
    本课程采用真实案例,全面具体可落地,从理论到实践,一步一步将GO核心编程技术、编程思想、底层实现融会贯通,使学习者贴近时代脉搏,做IT互联网时代的弄潮儿。
    516次学习
  • 简单聊聊mysql8与网络通信
    简单聊聊mysql8与网络通信
    如有问题加微信:Le-studyg;在课程中,我们将首先介绍MySQL8的新特性,包括性能优化、安全增强、新数据类型等,帮助学生快速熟悉MySQL8的最新功能。接着,我们将深入解析MySQL的网络通信机制,包括协议、连接管理、数据传输等,让
    500次学习
  • JavaScript正则表达式基础与实战
    JavaScript正则表达式基础与实战
    在任何一门编程语言中,正则表达式,都是一项重要的知识,它提供了高效的字符串匹配与捕获机制,可以极大的简化程序设计。
    487次学习
  • 从零制作响应式网站—Grid布局
    从零制作响应式网站—Grid布局
    本系列教程将展示从零制作一个假想的网络科技公司官网,分为导航,轮播,关于我们,成功案例,服务流程,团队介绍,数据部分,公司动态,底部信息等内容区块。网站整体采用CSSGrid布局,支持响应式,有流畅过渡和展现动画。
    485次学习
查看更多
AI推荐
  • ljg-skills -
    ljg-skills
    ljg-skills 是李继刚开源的 AI 技能与提示词集合,面向大模型使用者整理了一批可复用的 prompt、角色设定和任务技能模板,适合用于学习提示词设计、搭建个人 AI 工作流和沉淀团队常用智能体能力。
    871次使用
  • MELO音乐 - AI 音乐生成平台,支持多模态创作能力
    MELO音乐
    MELO音乐是一站式AI视频与音乐制作助手,对标suno, udio的高品质体验。提供伴奏生成、原创写词、无损导出、哼唱识曲、混音变声等全套音频与短视频编辑工具。无论是流行Kpop、电音说唱、民谣古风、摇滚儿歌还是商用轻音乐,MELO为你免费谱曲,轻松做同款!
    847次使用
  • UniScribe - AI 免费在线音视频转文字平台
    UniScribe
    UniScribe 是一款 AI 音视频转文字与内容整理工具,支持上传音频、视频文件或粘贴 YouTube 链接,自动生成转写文本、摘要、思维导图和关键问题,并支持多格式导出,适合会议记录、课程学习、访谈整理和内容创作复盘。
    784次使用
  • 剧云 - 免费 AI 智能中文剧本创作平台
    剧云
    剧云是专业中文剧本创作平台,安全稳定运行十余年,集成AI编剧、剧本医生审核、人物小传、剧情关系图、大纲编写、多人协作、Word导入导出、版权管控功能,数据安全防护,轻松高效创作剧本。
    976次使用
  • 万象有声 - AI 一站式有声内容创作平台
    万象有声
    万象有声,一个专为有声创作者打造的新一代智能有声内容创作平台。平台提供专业的智能拆章、智能画本编辑、AI配音、AI生成音效、后期制作、智能对轨、智能审听等有声创作全流程工具,可以帮助创作者高效、低成本创作出引人入胜的有声作品。立即体验,让有声书制作更简单!
    950次使用
微信登录更方便
  • 密码登录
  • 注册账号
登录即同意 用户协议隐私政策
返回登录
  • 重置密码