使用TensorFlow迁移学习进行图像分类

deeplearning.ai第一门课的作业中,猫咪分类器的训练集只有209张图片,测试集只有50张图片,调了很久的参数,训练出来的模型的准确率最高才72%,然后用自己的图片进行预测,预测结果简直就跟随机猜测没啥区别,一直在考虑怎么提高模型的准确率,直到最近才突然想尝试一下进行迁移学习,看看效果。

TensorFlow官方仓库models中提供了一个图像分类的demo库——slim,里面包含了若干预训练的分类模型。这里先学习一下如何用TensorFlow官方的demo进行迁移学习,后面再在这个猫咪数据集上进行迁移学习。

整体思路:

  • 安装slim库
  • 准备数据集
  • 准备预先训练的模型
  • 进行训练
  • 查看训练过程

安装slim库

slim库包含两部分:

  • TensorFlow源码仓库中的库
  • Models仓库中的图像分类模型库

第一部分在我们安装TensorFlow的时候已经自动安装好了,所以我们只需要安装Models仓库中的slim图像分类模型库,具体步骤如下。

  1. 下载models仓库:

    1
    $ git clone https://github.com/tensorflow/models/
  2. slim图像分类模型库位于models的research/slim目录下。我们进入该目录,然后验证一下是否没问题:

    1
    2
    $ cd models/research/slim
    $ python -c "from nets import cifarnet; mynet = cifarnet.cifarnet"

    如果没有提示什么错误信息,则说明一切OK。

准备数据集

这里我们直接使用TensorFlow官方提供的flowers数据集。官方提供了一个脚本来下载并预处理该数据集,我们只需要指定存放该数据集的目录,然后调用该脚本即可:

1
2
3
4
$ DATA_DIR=/tmp/data/flowers
$ python download_and_convert_data.py \
--dataset_name=flowers \
--dataset_dir="${DATA_DIR}"

下载完成后,脚本会自动将该数据集转换为TensorFlow模型所需要的.tfrecord格式:

1
2
3
4
5
6
7
8
$ ls ${DATA_DIR}
flowers_train-00000-of-00005.tfrecord
...
flowers_train-00004-of-00005.tfrecord
flowers_validation-00000-of-00005.tfrecord
...
flowers_validation-00004-of-00005.tfrecord
labels.txt

准备预训练的模型

官方提供的预训练的模型包括:

我们这里尝试Inception V3这个模型。先指定一个目录用于存放该模型,然后下载并解压该模型:

1
2
3
4
5
6
$ CHECKPOINT_DIR=/tmp/checkpoints
$ mkdir -p ${CHECKPOINT_DIR}
$ wget http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz
$ tar -xvf inception_v3_2016_08_28.tar.gz
$ mv inception_v3.ckpt ${CHECKPOINT_DIR}
$ rm inception_v3_2016_08_28.tar.gz

训练模型

准备接续后,我们就可以开始训练了。同样地,我们先指定一下目录。

1
2
3
$ DATASET_DIR=/tmp/flowers
$ TRAIN_DIR=/tmp/flowers-models/inception_v3
$ CHECKPOINT_PATH=/tmp/my_checkpoints/inception_v3.ckpt

其中:

  • DATASET_DIR: flower数据集所在的目录
  • TRAIN_DIR: 用来保存训练log的目录
  • CHECKPOINT_PATH: 用来保存模型checkpoint的目录

然后,如果你的PC包含GPU,则输入如下的命令进行训练:

1
2
3
4
5
6
7
8
9
$ python train_image_classifier.py \
--train_dir=${TRAIN_DIR} \
--dataset_dir=${DATASET_DIR} \
--dataset_name=flowers \
--dataset_split_name=train \
--model_name=inception_v3 \
--checkpoint_path=${CHECKPOINT_PATH} \
--checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \
--trainable_scopes=InceptionV3/Logits,InceptionV3/AuxLogits

如果你的PC不包含GPU,则我们只能用CPU来进行训练,则在上面的命令后再加一个选项--clone_on_cpu=True:

1
2
3
4
5
6
7
8
9
10
$ python train_image_classifier.py \
--train_dir=${TRAIN_DIR} \
--dataset_dir=${DATASET_DIR} \
--dataset_name=flowers \
--dataset_split_name=train \
--model_name=inception_v3 \
--checkpoint_path=${CHECKPOINT_PATH} \
--checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \
--trainable_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \
--clone_on_cpu=True

查看训练的过程

启动TensorBoard:

1
tensorboard --logdir=$TRAIN_DIR

然后在浏览器中打开http://localhost:6006,就能看到训练过程。