dataset
定位:数据的定义 / 封装 / 预处理。
Dataset(MapDataset)
- 映射式数据集,位于 torch.utils.data.Dataset
- 数据样本与整数索引之间存在一一对应的映射关系,支持按指定索引随机访问单条样本
- 自定义子类必须实现两个方法:
__len__(self):返回数据集的总样本数量,为 DataLoader 提供批量划分、进度计算、索引范围界定的依据。__getitem__(self, index):接收一个整数索引 index(也可以是切片、列表),返回该索引对应的单条样本,在此方法中实现数据读取、预处理逻辑。
- 当 num_workers > 0 时,DataLoader 会自动完成「索引分片」
IterableDataset:
- 迭代式数据集。位于 torch.utils.data.IterableDataset
- 继承自 Dataset,但重写了数据访问的核心逻辑,自定义子类仅需实现
__iter__方法 - 适用于大规模数据集,惰性生成数据,节省内存。工程上常使用,但实现逻辑通常又较复杂,需要多关注
- 当 DataLoader 的 num_workers > 0时,需要额外处理数据分片,避免多个进程重复读取同一部分数据。
dataloader
定位:数据的batch加载 / shuffle / 多进程读取 / 整理(collate_fn)
- 一个可迭代对象
__len__实现时是除以batch size的- 多进程读取数据时,IterableDataset额外处理数据分片例子
1 | import torch |
可迭代对象与迭代器
IterableDataset 与 DataLoader 的协作流程遵循「迭代器协议」,因此想实现好这两个类必须充分理解Iterable和Iterator。
Iterable
- 实现
__iter__方法的对象 __iter__方法返回值必须是一个迭代器对象。例外情况:__iter__中含yield语句- 容器类可迭代对象可可选实现
__len__增强功能 iter(可迭代对象)会返回一个新的迭代器- 持有数据或数据规则,不维护迭代状态
- 实现
Iterator
- 实现
__iter__、__next__方法 __iter__经常return self,不会重置游标状态- 通常不实现
__len__ - 维护迭代状态,带有 “单向移动游标” 的对象
- 实现
为什么需要同时有可迭代对象和迭代器两个概念?
- 可迭代对象无法满足「惰性迭代」需求。惰性迭代需求是必要的、不可替代的,如处理无限序列、处理超大数据、优化性能开销
- 迭代器的「一次性遍历」缺陷无法满足「多次遍历」的核心需求
可迭代对象遍历时访问的是什么方法?
- 第一步:访问可迭代对象的
__iter__()方法,获取一个迭代器 - 第二步:反复访问迭代器的
__next__()方法,获取每个元素 - 当没有更多元素可返回时,
__next__()会抛出 StopIteration 异常,Python 会自动捕获这个异常并终止遍历,不会暴露给开发者。
为什么 yield是例外情况?
函数中使用 yield 关键字是实现迭代器的一种简洁高效的方式。在 Python 中,包含 yield 关键字的函数被称为「生成器函数」,调用生成器函数不会执行函数体,而是直接返回一个「生成器迭代器(Generator Iterator)」。生成器迭代器的 __next__ 逻辑与 yield 的执行流程强绑定,对应关系如下:
- 第一次调用 next(生成器迭代器):执行
__iter__方法体,直到遇到第一个 yield 语句,返回 yield 后的值,暂停方法执行(保留当前变量状态,如循环计数器); - 后续每次调用 next(生成器迭代器):从上次暂停的位置继续执行方法体,直到遇到下一个 yield 语句,返回对应值并再次暂停;
- 当
__iter__方法体执行完毕(无更多 yield 语句),自动抛出 StopIteration 异常,终止迭代,这与__next__方法的终止要求完全一致。
动态批处理
将token拼到一起训练,避免padding降低训练效率,依赖可变长attention。
多源数据集
将多个Dataset(或IterableDataset)作为参数传给MultiSourceDataset,实现自定义采样逻辑,确保在固定的buffer内最大程度拼到设定的最大token数。