CIFAR-10 数据集,是为数不多的可以用笔记本跑的 深度学习 数据集,总共有6w张彩色图片,图片大小为32x32,分为10类,其中5w张训练集,1w张测试集,本文主要介绍在 Pytorch 中使用该数据集的方法。
关联文档
https://pytorch.org/vision/stable/datasets.html
https://www.cs.toronto.edu/~kriz/cifar.html
代码示例
1、下载数据集
from torchvision import datasets # 下载数据集到datas目录 file_path = './datas' # 文件存在后,就不需要重复下载了 cifar10 = datasets.CIFAR10(file_path, download=True)
2、使用 matplotlib 展示图片
# 读取并显示一个样本 from matplotlib import pyplot as plt img, label = cifar10[0] plt.imshow(img) plt.show()
3、数据转为 Tensor 格式
from torch.utils.data import DataLoader # 读取数据,并转化为Tensor,默认PIL.image cifar10 = datasets.CIFAR10(file_path, transform=transforms.ToTensor()) # 加载数据,并打乱顺序 data = DataLoader(cifar10, batch_size=10, shuffle=True) for img, label in data: print(img.size()) # 图像数据 print(label.item()) # 图像类别
4、常用transform方法
train_trans = transforms.Compose([ transforms.CenterCrop(), # 中心裁剪 transforms.Grayscale(), # 转灰度图 transforms.Resize((32, 32)), # 缩放 transforms.RandomCrop(32, padding=4), # 随机裁剪 transforms.ToTensor(), # 图片转张量,同时归一化0-255,范围0-1 ]) data = datasets.CIFAR10(file_path, transform=train_trans)
本文为 陈华 原创,欢迎转载,但请注明出处:http://ichenhua.cn/read/235
- 上一篇:
- 手写AI算法之TF-IDF关键词提取
- 下一篇:
- Pytorch自定义数据集Dataset