3.5 图像分类数据集(Fashion-MNIST)

获取数据集

通过torchvision的`torchvision.datasets`来下载这个数据集。第一次调用时会自动从网上获取数据。通过参数`train`来指定获取训练数据集或测试数据集。参数`transform = transforms.ToTensor()`使所有数据转换为`Tensor`,如果不进行转换则返回的是PIL图片。`transforms.ToTensor()`将尺寸为 (H x W x C) 且数据位于[0, 255]的PIL图片或者数据类型为`np.uint8`的NumPy数组转换为尺寸为(C x H x W)且数据类型为`torch.float32`且位于[0.0, 1.0]的`Tensor`。

mnist_train = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=True, download=True, transform=transforms.ToTensor())
mnist_test = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=False, download=True, transform=transforms.ToTensor())

print(len(mnist_train), len(mnist_test)) #获取该数据集的大小
#output: 60000 10000

feature, label = mnist_train[0]  # 可以通过下标来访问任意一个样本
print(feature.shape, label)  # Channel x Height x Width
#output: torch.Size([1, 28, 28]) tensor(9)

读取小批量

数据读取经常是训练的性能瓶颈,特别当模型较简单或者计算硬件性能较高时。PyTorch的`DataLoader`中一个很方便的功能是允许使用多进程来加速数据读取。

batch_size = 256
if sys.platform.startswith('win'):
    num_workers = 0  # 0表示不用额外的进程来加速读取数据
else:
    num_workers = 4
train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)
知识共享署名-非商业性使用-相同方式共享 4.0 国际许可协议
上一篇
下一篇