PyTorch动态批处理技巧分享
哈喽!今天心血来潮给大家带来了《PyTorch动态批次大小管理技巧》,想必大家应该对文章都不陌生吧,那么阅读本文就都不会很困难,以下内容主要涉及到,若是你正在学习文章,千万别错过这篇文章~希望能帮助到你!
PyTorch DataLoader与固定批次大小
在深度学习模型训练中,torch.utils.data.DataLoader是PyTorch提供的一个强大工具,用于高效地加载数据。它通常与Dataset结合使用,负责数据的批处理、打乱和多进程加载等任务。最常见的用法是指定一个固定的batch_size参数:
import torch from torch.utils.data import TensorDataset, DataLoader # 示例数据 x_train = torch.randn(8400, 4) y_train = torch.randint(0, 2, (8400,)) train_dataset = TensorDataset(x_train, y_train) # 使用固定批次大小的DataLoader dataloader_train = DataLoader(train_dataset, batch_size=64, shuffle=True) # 迭代DataLoader for batch_idx, (data, target) in enumerate(dataloader_train): print(f"Batch {batch_idx}: data shape {data.shape}, target shape {target.shape}") if batch_idx == 2: # 仅打印前3个批次 break
这种方法简单直接,适用于大多数场景。然而,在某些特定的训练策略中,我们可能需要根据训练阶段、模型状态或数据特性来动态调整批次大小,例如:
- 课程学习(Curriculum Learning):从简单样本的小批次开始,逐渐增加批次大小。
- 内存优化:处理包含不同大小元素的批次(如变长序列),以最大化GPU利用率。
- 梯度累积的变体:虽然梯度累积本身不改变DataLoader的批次大小,但在某些复杂策略下,可能需要更精细的控制。
解决方案:自定义采样器(Sampler)
PyTorch的DataLoader支持通过sampler或batch_sampler参数来完全控制批次中样本的索引选择。这是实现动态批次大小的关键。
- Sampler:一个Sampler子类负责生成单个样本的索引序列。如果使用自定义Sampler,DataLoader会根据这些索引和指定的batch_size(如果batch_size大于1)来创建批次。
- BatchSampler:一个BatchSampler子类直接生成批次索引的列表。这意味着它直接定义了每个批次包含哪些样本的索引。当使用BatchSampler时,DataLoader的batch_size参数将被忽略。
对于动态批次大小的需求,由于我们希望直接指定每个批次的大小(即每个批次包含多少个样本),因此自定义一个生成批次索引的采样器(更接近BatchSampler的功能)是最佳选择。
实现 VariableBatchSampler
我们将创建一个名为VariableBatchSampler的类,它继承自torch.utils.data.Sampler,但其行为更像一个BatchSampler,直接返回批次的索引列表。
import torch from torch.utils.data import TensorDataset, DataLoader, Sampler class VariableBatchSampler(Sampler): """ 一个自定义采样器,根据预定义的批次大小列表生成批次索引。 它将数据集按顺序切片,形成指定大小的批次。 """ def __init__(self, dataset_len: int, batch_sizes: list): """ 初始化VariableBatchSampler。 Args: dataset_len (int): 数据集的总长度。 batch_sizes (list): 一个整数列表,每个元素代表一个批次的样本数量。 所有批次大小之和应等于或大于dataset_len。 """ if not isinstance(dataset_len, int) or dataset_len <= 0: raise ValueError("dataset_len 必须是正整数。") if not isinstance(batch_sizes, list) or not all(isinstance(bs, int) and bs > 0 for bs in batch_sizes): raise ValueError("batch_sizes 必须是包含正整数的列表。") if sum(batch_sizes) < dataset_len: print(f"警告:批次大小总和 ({sum(batch_sizes)}) 小于数据集长度 ({dataset_len}),部分数据可能不会被加载。") self.dataset_len = dataset_len self.batch_sizes = batch_sizes self.current_batch_idx = 0 # 当前批次在batch_sizes列表中的索引 self.current_start_idx = 0 # 当前批次在数据集中的起始索引 def __iter__(self): """ 使采样器成为一个迭代器。每次新的迭代开始时,重置状态。 """ self.current_batch_idx = 0 self.current_start_idx = 0 return self def __next__(self): """ 生成下一个批次的索引。 """ # 如果已经遍历完所有批次或超出了数据集长度,则停止迭代 if self.current_start_idx >= self.dataset_len or \ self.current_batch_idx >= len(self.batch_sizes): raise StopIteration() # 获取当前批次的大小 current_batch_size = self.batch_sizes[self.current_batch_idx] # 计算当前批次的结束索引 current_end_idx = min(self.current_start_idx + current_batch_size, self.dataset_len) # 生成批次索引 batch_indices = torch.arange(self.current_start_idx, current_end_idx, dtype=torch.long) # 更新状态,为下一个批次做准备 self.current_start_idx = current_end_idx self.current_batch_idx += 1 return batch_indices.tolist() # DataLoader期望的是Python列表
代码解析:
- __init__(self, dataset_len, batch_sizes): 构造函数接收数据集总长度和包含所需批次大小的列表。它会进行一些基本的输入校验。
- __iter__(self): 这个方法使得VariableBatchSampler对象本身可以被迭代。每次新的迭代开始时(例如,每个epoch开始),它会重置内部状态(current_batch_idx和current_start_idx),确保从头开始生成批次。
- __next__(self): 这是迭代器的核心方法。
- 它首先检查是否已生成所有批次或是否已遍历完数据集。如果是,则抛出StopIteration。
- 根据self.batch_sizes列表获取当前批次的目标大小。
- 计算当前批次在数据集中的起始和结束索引。min(..., self.dataset_len)确保不会超出数据集的实际边界。
- 使用torch.arange生成批次的索引张量。
- 更新self.current_start_idx和self.current_batch_idx,指向下一个批次的起始位置和批次大小列表中的下一个元素。
- 返回生成的批次索引列表。
集成到DataLoader中
现在,我们将这个自定义采样器与DataLoader结合使用。
# 示例数据 x_train = torch.randn(8400, 4) y_train = torch.randint(0, 2, (8400,)) train_dataset = TensorDataset(x_train, y_train) # 定义动态批次大小列表 # 注意:这些批次大小的总和不一定需要精确等于数据集长度, # 我们的采样器会处理最后可能不足一个完整批次的情况。 list_batch_size = [30, 60, 110, 200, 50, 150, 90, 120, 70, 180] * 20 # 假设有20个这样的循环 # 确保批次大小总和足够覆盖数据集,或者让DataLoader处理剩余部分 if sum(list_batch_size) < len(train_dataset): print("警告:提供的批次大小总和小于数据集长度,部分数据可能不会被加载。") # 可以选择在末尾添加一个批次以覆盖剩余数据 # list_batch_size.append(len(train_dataset) - sum(list_batch_size)) # 实例化自定义采样器 variable_sampler = VariableBatchSampler(dataset_len=len(train_dataset), batch_sizes=list_batch_size) # 将采样器传递给DataLoader # 推荐使用 batch_sampler 参数 data_loader_dynamic = DataLoader(train_dataset, batch_sampler=variable_sampler, num_workers=0) # num_workers=0 for simplicity print(f"\n使用动态批次大小的DataLoader (通过 batch_sampler):") for batch_idx, (data, target) in enumerate(data_loader_dynamic): print(f"Batch {batch_idx}: data shape {data.shape}, target shape {target.shape}") if batch_idx >= 15: # 仅打印前16个批次 break print(f"总共生成了 {batch_idx + 1} 个批次。")
使用 batch_sampler 的优势:
当你的自定义采样器(如VariableBatchSampler)已经直接返回批次的索引列表时,将其作为DataLoader的batch_sampler参数传递是更推荐的做法。
- 更符合语义:BatchSampler就是用来生成批次索引的。
- 避免额外的维度:如果将VariableBatchSampler作为sampler参数传递,DataLoader会默认将batch_size设置为1(因为你没有显式指定),然后对每个由sampler返回的“批次”再进行一次批处理。这可能导致数据张量多出一个不必要的维度(例如,[1, batch_size, *data_shape])。使用batch_sampler则不会有这个问题。
注意事项
批次大小总和与数据集长度:确保batch_sizes列表中所有元素的总和能够覆盖整个数据集。如果总和小于数据集长度,那么部分数据将不会被模型训练到。如果总和大于数据集长度,VariableBatchSampler会自然地在达到dataset_len时停止。
数据打乱(Shuffling):我们当前的VariableBatchSampler是按顺序生成批次的。如果需要在每个epoch开始时打乱数据,你需要修改采样器:
- 在__init__或__iter__中,首先生成一个打乱的索引列表,例如shuffled_indices = torch.randperm(dataset_len).tolist()。
- 然后,在__next__方法中,从shuffled_indices中按当前批次大小截取子列表作为批次索引。
# 示例:带有打乱功能的VariableBatchSampler (概念性代码) class ShuffledVariableBatchSampler(Sampler): def __init__(self, dataset_len: int, batch_sizes: list): # ... (同上) self.dataset_len = dataset_len self.batch_sizes = batch_sizes self.shuffled_indices = None # 用于存储打乱后的索引 def __iter__(self): self.current_batch_idx = 0 self.current_start_idx = 0 # 在每个epoch开始时打乱索引 self.shuffled_indices = torch.randperm(self.dataset_len).tolist() return self def __next__(self): if self.current_start_idx >= self.dataset_len or \ self.current_batch_idx >= len(self.batch_sizes): raise StopIteration() current_batch_size = self.batch_sizes[self.current_batch_idx] # 从打乱的索引中获取批次 batch_indices_in_shuffled = self.shuffled_indices[self.current_start_idx : self.current_start_idx + current_batch_size] self.current_start_idx += len(batch_indices_in_shuffled) self.current_batch_idx += 1 return batch_indices_in_shuffled
drop_last参数:当使用batch_sampler时,DataLoader的drop_last参数会被忽略,因为批次的构成完全由batch_sampler控制。如果需要丢弃最后一个不完整的批次,你的VariableBatchSampler需要在生成批次索引时自行判断并处理。在我们的实现中,min(..., self.dataset_len)确保了即使最后一个批次不足指定大小,也会包含所有剩余数据。
总结
通过自定义torch.utils.data.Sampler或更具体地使用batch_sampler参数,我们可以灵活地控制PyTorch DataLoader的批次大小,以适应各种复杂的训练策略。VariableBatchSampler提供了一个实现动态、非固定批次大小的有效范例,它通过直接管理批次索引的生成,赋予了用户对数据加载过程的精细控制。在实际应用中,应根据具体需求考虑是否需要结合数据打乱功能。
以上就是本文的全部内容了,是否有顺利帮助你解决问题?若是能给你带来学习上的帮助,请大家多多支持golang学习网!更多关于文章的相关知识,也可关注golang学习网公众号。

- 上一篇
- Win8系统备份教程与恢复方法详解

- 下一篇
- 百度畅听上线电台自动播放功能
-
- 文章 · python教程 | 1小时前 |
- PydanticV2解析逗号浮点数技巧
- 262浏览 收藏
-
- 文章 · python教程 | 1小时前 |
- Python实战:个人理财工具开发教程
- 273浏览 收藏
-
- 文章 · python教程 | 2小时前 |
- Python语言种类及特点对比解析
- 465浏览 收藏
-
- 文章 · python教程 | 3小时前 |
- Pythonrequests库使用教程详解
- 381浏览 收藏
-
- 文章 · python教程 | 3小时前 |
- Ubuntu下Python应用的Docker实践
- 277浏览 收藏
-
- 文章 · python教程 | 3小时前 |
- Python连接Redis的实用技巧与操作方法
- 249浏览 收藏
-
- 文章 · python教程 | 3小时前 |
- Python判断文件或文件夹是否存在方法
- 487浏览 收藏
-
- 文章 · python教程 | 4小时前 |
- Scapy混杂模式错误解决方法分享
- 161浏览 收藏
-
- 文章 · python教程 | 5小时前 |
- Python高效筛选CSV关联JSON日志技巧
- 310浏览 收藏
-
- 文章 · python教程 | 5小时前 | Python Pandas
- Pandas处理NaN数据的实用技巧
- 346浏览 收藏
-
- 前端进阶之JavaScript设计模式
- 设计模式是开发人员在软件开发过程中面临一般问题时的解决方案,代表了最佳的实践。本课程的主打内容包括JS常见设计模式以及具体应用场景,打造一站式知识长龙服务,适合有JS基础的同学学习。
- 543次学习
-
- GO语言核心编程课程
- 本课程采用真实案例,全面具体可落地,从理论到实践,一步一步将GO核心编程技术、编程思想、底层实现融会贯通,使学习者贴近时代脉搏,做IT互联网时代的弄潮儿。
- 516次学习
-
- 简单聊聊mysql8与网络通信
- 如有问题加微信:Le-studyg;在课程中,我们将首先介绍MySQL8的新特性,包括性能优化、安全增强、新数据类型等,帮助学生快速熟悉MySQL8的最新功能。接着,我们将深入解析MySQL的网络通信机制,包括协议、连接管理、数据传输等,让
- 499次学习
-
- JavaScript正则表达式基础与实战
- 在任何一门编程语言中,正则表达式,都是一项重要的知识,它提供了高效的字符串匹配与捕获机制,可以极大的简化程序设计。
- 487次学习
-
- 从零制作响应式网站—Grid布局
- 本系列教程将展示从零制作一个假想的网络科技公司官网,分为导航,轮播,关于我们,成功案例,服务流程,团队介绍,数据部分,公司动态,底部信息等内容区块。网站整体采用CSSGrid布局,支持响应式,有流畅过渡和展现动画。
- 484次学习
-
- PandaWiki开源知识库
- PandaWiki是一款AI大模型驱动的开源知识库搭建系统,助您快速构建产品/技术文档、FAQ、博客。提供AI创作、问答、搜索能力,支持富文本编辑、多格式导出,并可轻松集成与多来源内容导入。
- 366次使用
-
- AI Mermaid流程图
- SEO AI Mermaid 流程图工具:基于 Mermaid 语法,AI 辅助,自然语言生成流程图,提升可视化创作效率,适用于开发者、产品经理、教育工作者。
- 1149次使用
-
- 搜获客【笔记生成器】
- 搜获客笔记生成器,国内首个聚焦小红书医美垂类的AI文案工具。1500万爆款文案库,行业专属算法,助您高效创作合规、引流的医美笔记,提升运营效率,引爆小红书流量!
- 1182次使用
-
- iTerms
- iTerms是一款专业的一站式法律AI工作台,提供AI合同审查、AI合同起草及AI法律问答服务。通过智能问答、深度思考与联网检索,助您高效检索法律法规与司法判例,告别传统模板,实现合同一键起草与在线编辑,大幅提升法律事务处理效率。
- 1182次使用
-
- TokenPony
- TokenPony是讯盟科技旗下的AI大模型聚合API平台。通过统一接口接入DeepSeek、Kimi、Qwen等主流模型,支持1024K超长上下文,实现零配置、免部署、极速响应与高性价比的AI应用开发,助力专业用户轻松构建智能服务。
- 1253次使用
-
- Flask框架安装技巧:让你的开发更高效
- 2024-01-03 501浏览
-
- Django框架中的并发处理技巧
- 2024-01-22 501浏览
-
- 提升Python包下载速度的方法——正确配置pip的国内源
- 2024-01-17 501浏览
-
- Python与C++:哪个编程语言更适合初学者?
- 2024-03-25 501浏览
-
- 品牌建设技巧
- 2024-04-06 501浏览