Pytorch 深度学习框架和 ImageNet 数据集深受科研工作者的喜爱。本文使用 Pytorch 1.0.1 版本对 ImageNet 数据集进行图像分类实战,包括训练、测试、验证等。
ImageNet 数据集下载及预处理
数据集选择常用的 ISLVRC2012
(ImageNet Large Scale Visual Recognition Challenge)
下载地址:
- 测试集 http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_img_test.tar(12.7GB)
- 验证集http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_img_val.tar(6.3GB)
- 训练集http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_img_train.tar(138GB)
预处理:
为了使用 Pytorch 自带的 DataLoader 函数进行数据集加载,我们需要将每一个相同类的图片放到相同的文件夹。
训练集只需要解压缩即可:
mkdir train && mv ILSVRC2012_img_train.tar train/ && cd train |
但是验证集图片都在一个文件夹,需要重新分类:
mkdir val && mv ILSVRC2012_img_val.tar val/ && cd val && tar -xvf ILSVRC2012_img_val.tar |
代码详解
参数设置
参数设置的方式有很多种,有的人喜欢直接在主文件中进行设置;有的人喜欢用 argparse 这个模块;也有人喜欢用 json 格式的文件,我个人喜欢单独创建个 Python 类,以类属性的形式定义参数,详情见下:
class DefaultConfigs(object): |
评价指标
当我们需要评价一个模型的准确率时,需要输出 top1、top5 等准确率,使用下面函数进行封装。其中 AverageMeter
类可快速计算多个值的平均值等。
class AverageMeter(object): |
验证模型准确率
当验证模型和训练模型时都需要使用验证集验证模型准确率,来指导下一步操作。注意需要将 model
切换为 evaluate
模式。其中 torch.no_grad()
表示计算时不会改变模型梯度。
def validate(val_loader, model, criterion): |
训练模型
注意需要将 model
切换为 train
模式。
def train(train_loader, model, criterion, optimizer): |
主体函数
注意在数据集加载时,train_loader
的 shuffle
为 True
。
def main(): |
总结
本文使用的 Pytorch 版本为 1.0.1,且暂时只适用于 ImageNet 数据集,其他数据集需要一定地修改,完整代码地址如下:https://gist.github.com/xunge/d7be591bc1b41350273a61722c0d398a