diff --git a/transfer_learning.ipynb b/transfer_learning.ipynb new file mode 100644 index 00000000..98369cb7 --- /dev/null +++ b/transfer_learning.ipynb @@ -0,0 +1,506 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "# 通过飞桨高层API实现迁移学习任务(102种鲜花识别)\n", + "\n", + "\n", + "作者:[lwb](https://github.com/lwbmowgli)\n", + "\n", + "\n", + "时间:2020.11\n", + "\n", + "\n", + "本篇将介绍如何通过飞桨高层API实现迁移学习任务,数据集采用VGG出品的[鲜花数据集](https://www.robots.ox.ac.uk/~vgg/data/flowers/),该数据集可通过paddle的API直接下载。\n", + "\n", + "\n", + "\n", + "该案例主要参考了[『跟着雨哥学AI』系列01:初识飞桨框架高层API](https://aistudio.baidu.com/aistudio/projectdetail/1243085?shared=1)\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "import warnings\n", + "warnings.filterwarnings(\"ignore\")" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2.0.0-rc0\n" + ] + } + ], + "source": [ + "import paddle\n", + "print(paddle.__version__)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "### 1 数据集导入及预处理\n", + "\n", + "\n", + "对于数据预处理与数据加载,飞桨框架提供了许多API,列表如下:\n", + "\n", + "\n", + "\n", + "1、 飞桨框架内置数据集:paddle.vision.datasets内置包含了许多CV领域相关的数据集,直接调用API即可使用;\n", + "\n", + "\n", + "2、 飞桨框架数据预处理:paddle.vision.transforms飞桨框架对于图像预处理的方式,可以快速完成常见的图像预处理的方式,如调整色调、对比度,图像大小等;\n", + "\n", + "\n", + "3、 飞桨框架数据加载:paddle.io.Dataset与paddle.io.DataLoader飞桨框架标准数据加载方式,可以”一键”完成数据的批加载与异步加载;\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "\n", + "#### 1.1 飞桨框架内置数据集\n", + "\n", + "\n", + "首先,飞桨框架将常用的数据集作为领域API对用户开放,对应API所在目录为paddle.vision.datasets包含的数据集如下所示。" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "飞桨框架CV领域内置数据集:['DatasetFolder', 'ImageFolder', 'MNIST', 'Flowers', 'Cifar10', 'Cifar100', 'VOC2012']\n" + ] + } + ], + "source": [ + "print(\"飞桨框架CV领域内置数据集:\" + str(paddle.vision.datasets.__all__))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "#### 1.2 飞桨框架预处理方法\n", + "\n", + "\n", + "飞桨框架提供了20多种数据集预处理的接口,方便开发者快速实现数据增强,目前都集中在 paddle.vision.transforms 目录下,具体包含的API如下:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "视觉数据预处理方法:['BaseTransform', 'Compose', 'Resize', 'RandomResizedCrop', 'CenterCrop', 'RandomHorizontalFlip', 'RandomVerticalFlip', 'Transpose', 'Normalize', 'BrightnessTransform', 'SaturationTransform', 'ContrastTransform', 'HueTransform', 'ColorJitter', 'RandomCrop', 'Pad', 'RandomRotation', 'Grayscale', 'ToTensor', 'to_tensor', 'hflip', 'vflip', 'resize', 'pad', 'rotate', 'to_grayscale', 'crop', 'center_crop', 'adjust_brightness', 'adjust_contrast', 'adjust_hue', 'to_grayscale', 'normalize']\n" + ] + } + ], + "source": [ + "print('视觉数据预处理方法:' + str(paddle.vision.transforms.__all__))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "#### 1.3 数据集加载及预处理\n", + "\n", + "\n", + "本应用案例使用飞桨内置的鲜花数据集,共102种。" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "视觉数据预处理方法:['BaseTransform', 'Compose', 'Resize', 'RandomResizedCrop', 'CenterCrop', 'RandomHorizontalFlip', 'RandomVerticalFlip', 'Transpose', 'Normalize', 'BrightnessTransform', 'SaturationTransform', 'ContrastTransform', 'HueTransform', 'ColorJitter', 'RandomCrop', 'Pad', 'RandomRotation', 'Grayscale', 'ToTensor', 'to_tensor', 'hflip', 'vflip', 'resize', 'pad', 'rotate', 'to_grayscale', 'crop', 'center_crop', 'adjust_brightness', 'adjust_contrast', 'adjust_hue', 'to_grayscale', 'normalize']\n" + ] + } + ], + "source": [ + "print('视觉数据预处理方法:' + str(paddle.vision.transforms.__all__))" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from paddle.vision.datasets import Flowers\n", + "import paddle.vision.transforms as T\n", + "\n", + "class FlowerDataset(Flowers):\n", + " def __init__(self, mode, transform):\n", + " super(FlowerDataset, self).__init__(mode=mode, transform=transform)\n", + "\n", + " def __getitem__(self, index):\n", + " image, label = super(FlowerDataset, self).__getitem__(index)\n", + "\n", + " return image, label - 1\n", + "\n", + "transform = T.Compose([\n", + " T.Resize([224,224]),\n", + " T.Transpose(),\n", + " T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])\n", + " ])\n", + " \n", + "train_dataset = FlowerDataset(mode='train', transform=transform)\n", + "test_dataset = FlowerDataset(mode='test', transform=transform)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "### 2 模型构建\n", + "\n", + "\n", + "在网络构建模块,飞桨高层API与基础API保持完全的一致,都使用paddle.nn下的API进行组网。这也是尽可能的减少需要暴露的概念,从而提升框架的易学性。飞桨框架 paddle.nn 目录下包含了所有与模型组网相关的API,如卷积相关的 \n", + "\n", + "\n", + "Conv1D、Conv2D、Conv3D,循环神经网络相关的 RNN、LSTM、GRU 等。\n", + "\n", + "对于组网方式,飞桨框架统一支持 Sequential 或 SubClass 的方式进行模型的组建。我们根据实际的使用场景,来选择最合适的组网方式。如针对顺序的线性网络结构我们可以直接使用 Sequential ,相比于 SubClass ,Sequential 可以\n", + "\n", + "\n", + "快速的完成组网。 如果是一些比较复杂的网络结构,我们可以使用 SubClass 定义的方式来进行模型代码编写,在 init 构造函数中进行 Layer 的声明,在 forward 中使用声明的 Layer 变量进行前向计算。通过这种方式,我们可以组建\n", + "\n", + "\n", + "更灵活的网络结构。\n", + "\n", + "\n", + "除了自定义模型结构外,飞桨框架还”贴心”的内置了许多模型,真正的一行代码实现深度学习模型。目前,飞桨框架内置的模型都是CV领域领域的模型,在paddle.vision.models目录下,具体包含如下的模型:" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "视觉相关模型: ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'VGG', 'vgg11', 'vgg13', 'vgg16', 'vgg19', 'MobileNetV1', 'mobilenet_v1', 'MobileNetV2', 'mobilenet_v2', 'LeNet']\n" + ] + } + ], + "source": [ + "print(\"视觉相关模型: \" + str(paddle.vision.models.__all__))" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from paddle.vision.models import resnet18\n", + "\n", + "# build model\n", + "model = resnet18(pretrained=True, num_classes=102, with_pool=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "### 3 模型训练\n", + "\n", + "\n", + "过去常常困扰深度学习开发者的一个问题是,模型训练的代码过于复杂,常常要写好多步骤,才能正确的使程序运行起来,冗长的代码使许多开发者望而却步。\n", + "\n", + "\n", + "现在,飞桨高层API将训练、评估与预测API都进行了封装,直接使用Model.prepare()、Model.fit()、Model.evaluate()、Model.predict()完成模型的训练、评估与预测。" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "#### 3.1 配置模型" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from paddle.metric import Accuracy\n", + "\n", + "# 调用飞桨框架的VisualDL模块,保存信息到目录中。\n", + "callback = paddle.callbacks.VisualDL(log_dir='visualdl_log_dir')\n", + "\n", + "model = paddle.Model(model)\n", + "optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())\n", + "# model.prepare()\n", + "# 配置模型\n", + "model.prepare(\n", + " optim,\n", + " paddle.nn.CrossEntropyLoss(),\n", + " Accuracy(topk=(1, 2))\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "其实,这里也会用到上文提到的自定义Callback,我们只需要自定义与loss相关的Callback,然后保存loss信息,最后将其转化为图片即可。具体的实现过程如下所示:" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# from paddle.metric import Accuracy\n", + "\n", + "# # 调用飞桨框架的VisualDL模块,保存信息到目录中。\n", + "# # callback = paddle.callbacks.VisualDL(log_dir='visualdl_log_dir')\n", + "\n", + "\n", + "# # 自定义Callback 记录训练过程中的loss信息\n", + "# class LossCallback(paddle.callbacks.Callback):\n", + "\n", + "# def on_train_begin(self, logs={}):\n", + "# # 在fit前 初始化losses,用于保存每个batch的loss结果\n", + "# self.losses = []\n", + "# self.acc = []\n", + "\n", + "# def on_train_batch_end(self, step, logs={}):\n", + "# # 每个batch训练完成后调用,把当前loss添加到losses中\n", + "# self.losses.append(logs.get('loss'))\n", + "# self.acc.append(logs.get('acc_top1'))\n", + "# # 初始化一个loss_log 的实例,然后将其作为参数传递给fit\n", + "# loss_log = LossCallback()\n", + "\n", + "# model = paddle.Model(model)\n", + "# optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())\n", + "# # model.prepare()\n", + "# # 配置模型\n", + "# model.prepare(\n", + "# optim,\n", + "# paddle.nn.CrossEntropyLoss(),\n", + "# Accuracy(topk=(1, 2))\n", + "# )" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "#### 3.2 模型训练" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1/2\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:77: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working\n", + " return (isinstance(seq, collections.Sequence) and\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "step 193/193 [==============================] - loss: 1.3456 - acc_top1: 0.6142 - acc_top2: 0.7214 - 20s/step \n", + "Epoch 2/2\n", + "step 193/193 [==============================] - loss: 1.1202 - acc_top1: 0.8453 - acc_top2: 0.9270 - 20s/step \n" + ] + } + ], + "source": [ + "model.fit(train_dataset,\n", + " epochs=2,\n", + " batch_size=32,\n", + " verbose=1,\n", + " #callbacks=loss_log\n", + " callbacks=callback\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "然后我们调用VisualDL工具,在命令行中输入: `visualdl --logdir ./visualdl_log_dir --port 8080`,打开浏览器,输入网址 http://127.0.0.1:8080 就可以在浏览器中看到相关的训练信息,具体如下:![](https://ai-studio-static-online.cdn.bcebos.com/630361eb3f6348aba2b53b13b05c4c5459124871dec44e72a75128f2074913ef)(attachment:1607334281%281%29.jpg)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "使用自定义回调函数的可视化信息如下:\n", + "\n", + "\n", + "![](https://ai-studio-static-online.cdn.bcebos.com/f258a59bba904ae1a2dc3fc7ea021d6ec1908f92d9ca443d892b7500bc45368c)\n", + "![](https://ai-studio-static-online.cdn.bcebos.com/a098f36042d9419793a1e7e291315d8c3c55fb0138c144899dc63d7da8f373a7)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "#### 3.3 模型验证" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Eval begin...\n", + "step 32/32 [==============================] - loss: 1.9857 - acc_top1: 0.6000 - acc_top2: 0.7255 - 20s/step \n", + "Eval samples: 1020\n" + ] + }, + { + "data": { + "text/plain": [ + "{'loss': [1.9857073], 'acc_top1': 0.6, 'acc_top2': 0.7254901960784313}" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.evaluate(test_dataset, batch_size=32, verbose=1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "PaddlePaddle 2.0.0b0 (Python 3.5)", + "language": "python", + "name": "py35-paddle1.2.0" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.4" + } + }, + "nbformat": 4, + "nbformat_minor": 1 +}