输入输出
数据加载(Dataloader)
Dataloader
是数据集读取的接口,该接口的目的是将自定义的Dataset根据 batch_size
大小、是否shuffle等封装成一个 batch_size
大小的数组,用于网络的训练。
Dataloader
由数据集和采样器组成,初始化参数如下:
dataset(Dataset) – 传入的数据集
batch_size(int, optional) – 每个batch的样本数, 默认为1
shuffle(bool, optional) – 在每个epoch开始的时候,对数据进行重新排序,默认为False
sampler(Sampler, optional) – 自定义从数据集中取样本的策略
batch_sampler(Sampler, optional) – 与sampler类似,但是一次只返回一个batch的索引
collate_fn(callable, optional) – 将一个list的sample组成一个mini-batch的函数
drop_last(bool, optional) – 如果设置为True,对于最后一个batch,如果样本数小于batch_size就会被扔掉,比如batch_size设置为64,而数据集只有100个样本,那么训练的时候后面的36个就会被扔掉。如果为False(默认),那么会继续正常执行,只是最后的batch_size会小一点。
以导入MNIST数据集为例:
root = './Datasets/MNIST' # 数据集的地址
train_set = dataset(root, is_train=True) # 训练集
test_set = dataset(root, is_train=False) # 测试集
bat_size = 20
# 创建DataLoader
train_loader = spaic.Dataloader(train_set, batch_size=bat_size, shuffle=True)
test_loader = spaic.Dataloader(test_set, batch_size=bat_size, shuffle=False)
Note
- 需要注意的是:
1、创建
Dataloader
时如果指定了sampler
这个参数,那么shuffle
必须为False2、如果指定
batch_sampler
这个参数,那么batch_size
,shuffle
,sampler
,drop_last
就不能再指定了