【PyTorch】dataset & dataloader对象设计

dataset

定位:数据的定义 / 封装 / 预处理。

  • Dataset(MapDataset)

    • 映射式数据集,位于 torch.utils.data.Dataset
    • 数据样本与整数索引之间存在一一对应的映射关系,支持按指定索引随机访问单条样本
    • 自定义子类必须实现两个方法:
      • __len__(self):返回数据集的总样本数量,为 DataLoader 提供批量划分、进度计算、索引范围界定的依据。
      • __getitem__(self, index):接收一个整数索引 index(也可以是切片、列表),返回该索引对应的单条样本,在此方法中实现数据读取、预处理逻辑。
    • 当 num_workers > 0 时,DataLoader 会自动完成「索引分片」
  • IterableDataset:

    • 迭代式数据集。位于 torch.utils.data.IterableDataset
    • 继承自 Dataset,但重写了数据访问的核心逻辑,自定义子类仅需实现 __iter__方法
    • 适用于大规模数据集,惰性生成数据,节省内存。工程上常使用,但实现逻辑通常又较复杂,需要多关注
    • 当 DataLoader 的 num_workers > 0时,需要额外处理数据分片,避免多个进程重复读取同一部分数据。

dataloader

定位:数据的batch加载 / shuffle / 多进程读取 / 整理(collate_fn)

  • 一个可迭代对象
  • __len__ 实现时是除以batch size的
  • 多进程读取数据时,IterableDataset额外处理数据分片例子
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import torch
from torch.utils.data import IterableDataset, get_worker_info


class MyIterableDataset(IterableDataset):
"""自定义迭代式数据集(IterableDataset)"""
def __init__(self, data_generator, transform=None):
# 接收数据生成器(或数据流来源),不提前存储所有样本
self.data_generator = data_generator
self.transform = transform

def __iter__(self):
"""必须实现:返回迭代器/生成器,顺序生成样本"""
# 多进程处理:获取当前进程信息,实现数据分片
worker_info = get_worker_info()
worker_id = worker_info.id if worker_info is not None else 0
num_workers = worker_info.num_workers if worker_info is not None else 1

# 模拟数据分片:每个进程处理不同部分的数据流
for idx, (data, label) in enumerate(self.data_generator):
# 仅处理当前进程负责的样本,避免重复
if idx % num_workers == worker_id:
# 预处理
if self.transform is not None:
data = self.transform(data)

# 转换为张量并返回
data = torch.tensor(data, dtype=torch.float32)
label = torch.tensor(label, dtype=torch.long)
yield (data, label)


# 定义一个数据生成器(模拟有序数据流,支持无限生成)
def data_generator():
"""模拟数据流:生成 (数据, 标签) 元组,此处生成 4 条样本(可改为无限循环)"""
for i in range(4):
data = [i+1, i+2, i+3]
label = 0 if i % 2 == 0 else 1
yield (data, label)


# 测试迭代式数据集
iter_dataset = MyIterableDataset(data_generator())
# 仅支持顺序迭代(核心特性)
for sample in iter_dataset:
print(f"数据:{sample[0]}, 标签:{sample[1]}")

# 与 DataLoader 协作(支持多进程分片)
from torch.utils.data import DataLoader
dataloader = DataLoader(
iter_dataset,
batch_size=2,
num_workers=0, # 如需多进程,设为 >0 即可(已实现分片逻辑)
shuffle=False # 原生不支持 shuffle=True
)

print("\n迭代式数据集 + DataLoader 批量加载")
for batch_data, batch_label in dataloader:
print(f"批量数据形状:{batch_data.shape}")
print(f"批量数据:\n{batch_data}")

可迭代对象与迭代器

IterableDataset 与 DataLoader 的协作流程遵循「迭代器协议」,因此想实现好这两个类必须充分理解Iterable和Iterator。

  • Iterable

    • 实现 __iter__ 方法的对象
    • __iter__方法返回值必须是一个迭代器对象。例外情况:__iter__中含yield语句
    • 容器类可迭代对象可可选实现 __len__ 增强功能
    • iter(可迭代对象)会返回一个新的迭代器
    • 持有数据或数据规则,不维护迭代状态
  • Iterator

    • 实现__iter____next__方法
    • __iter__经常return self,不会重置游标状态
    • 通常不实现 __len__
    • 维护迭代状态,带有 “单向移动游标” 的对象

为什么需要同时有可迭代对象和迭代器两个概念?

  • 可迭代对象无法满足「惰性迭代」需求。惰性迭代需求是必要的、不可替代的,如处理无限序列、处理超大数据、优化性能开销
  • 迭代器的「一次性遍历」缺陷无法满足「多次遍历」的核心需求

可迭代对象遍历时访问的是什么方法?

  • 第一步:访问可迭代对象的 __iter__() 方法,获取一个迭代器
  • 第二步:反复访问迭代器的 __next__() 方法,获取每个元素
  • 当没有更多元素可返回时,__next__() 会抛出 StopIteration 异常,Python 会自动捕获这个异常并终止遍历,不会暴露给开发者。

为什么 yield是例外情况?
函数中使用 yield 关键字是实现迭代器的一种简洁高效的方式。在 Python 中,包含 yield 关键字的函数被称为「生成器函数」,调用生成器函数不会执行函数体,而是直接返回一个「生成器迭代器(Generator Iterator)」。生成器迭代器的 __next__ 逻辑与 yield 的执行流程强绑定,对应关系如下:

  1. 第一次调用 next(生成器迭代器):执行 __iter__ 方法体,直到遇到第一个 yield 语句,返回 yield 后的值,暂停方法执行(保留当前变量状态,如循环计数器);
  2. 后续每次调用 next(生成器迭代器):从上次暂停的位置继续执行方法体,直到遇到下一个 yield 语句,返回对应值并再次暂停;
  3. __iter__ 方法体执行完毕(无更多 yield 语句),自动抛出 StopIteration 异常,终止迭代,这与 __next__ 方法的终止要求完全一致。

动态批处理

将token拼到一起训练,避免padding降低训练效率,依赖可变长attention。

多源数据集

将多个Dataset(或IterableDataset)作为参数传给MultiSourceDataset,实现自定义采样逻辑,确保在固定的buffer内最大程度拼到设定的最大token数。

------ 本文结束------
赞赏此文?求鼓励,求支持!
0%