diff --git a/paddle2.0_docs/image_classification/mnist_lenet_classification.ipynb b/paddle2.0_docs/image_classification/mnist_lenet_classification.ipynb new file mode 100644 index 00000000..4e544b82 --- /dev/null +++ b/paddle2.0_docs/image_classification/mnist_lenet_classification.ipynb @@ -0,0 +1,666 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# MNIST数据集使用LeNet进行图像分类\n", + "本示例教程演示如何在MNIST数据集上用LeNet进行图像分类。\n", + "手写数字的MNIST数据集,包含60,000个用于训练的示例和10,000个用于测试的示例。这些数字已经过尺寸标准化并位于图像中心,图像是固定大小(28x28像素),其值为0到1。该数据集的官方地址为:http://yann.lecun.com/exdb/mnist/" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 环境\n", + "本教程基于paddle2.0-alpha编写,如果您的环境不是本版本,请先安装paddle2.0-alpha。" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2.0.0-alpha0\n" + ] + } + ], + "source": [ + "import paddle\n", + "print(paddle.__version__)\n", + "paddle.enable_imperative()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 加载数据集\n", + "我们使用飞桨自带的paddle.dataset完成mnist数据集的加载。" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "download training data and load training data\n", + "load finished\n", + "\n" + ] + } + ], + "source": [ + "print('download training data and load training data')\n", + "train_dataset = paddle.incubate.hapi.datasets.MNIST(mode='train')\n", + "test_dataset = paddle.incubate.hapi.datasets.MNIST(mode='test')\n", + "print('load finished')\n", + "print(type(train_dataset[0][0]))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "取训练集中的一条数据看一下。" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "train_data0 label is: [5]\n" + ] + }, + { + "data": { 
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAI4AAACOCAYAAADn/TAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/d3fzzAAAACXBIWXMAAAsTAAALEwEAmpwYAAAIY0lEQVR4nO3dXWhUZxoH8P/jaPxav7KREtNgiooQFvwg1l1cNOr6sQUN3ixR0VUK9cKPXTBYs17ohReLwl5ovCmuZMU1y+IaWpdC0GIuxCJJMLhJa6oWtSl+FVEXvdDK24s5nc5zapKTZ86cOTPz/4Hk/M8xc17w8Z13zpl5RpxzIBquEbkeAOUnFg6ZsHDIhIVDJiwcMmHhkElGhSMiq0WkT0RuisjesAZF8SfW6zgikgDwFYAVAPoBdABY75z7IrzhUVyNzOB33wVw0zn3NQCIyL8A1AEYsHDKyspcVVVVBqekqHV1dX3nnJvq359J4VQA+CYt9wNYONgvVFVVobOzM4NTUtRE5M6b9md9cSwiH4hIp4h0Pnr0KNuno4hkUjjfAqhMy297+xTn3EfOuRrnXM3UqT+b8ShPZVI4HQBmicg7IlICoB7AJ+EMi+LOvMZxzn0vIjsAtAFIADjhnOsNbWQUa5ksjuGc+xTApyGNhfIIrxyTCQuHTFg4ZMLCIRMWDpmwcMiEhUMmLBwyYeGQCQuHTFg4ZMLCIZOMbnIWk9evX6v89OnTwL/b1NSk8osXL1Tu6+tT+dixYyo3NDSo3NLSovKYMWNU3rv3p88N7N+/P/A4h4MzDpmwcMiEhUMmRbPGuXv3rsovX75U+fLlyypfunRJ5SdPnqh85syZ0MZWWVmp8s6dO1VubW1VecKECSrPmTNH5SVLloQ2toFwxiETFg6ZsHDIpGDXOFevXlV52bJlKg/nOkzYEomEygcPHlR5/PjxKm/cuFHladOmqTxlyhSVZ8+enekQh8QZh0xYOGTCwiGTgl3jTJ8+XeWysjKVw1zjLFyom3T41xwXL15UuaSkROVNmzaFNpaocMYhExYOmbBwyKRg1zilpaUqHz58WOVz586pPG/ePJV37do16OPPnTs3tX3hwgV1zH8dpqenR+UjR44M+tj5gDMOmQxZOCJyQkQeikhP2r5SETkvIje8n1MGewwqPEFmnGYAq3379gL4zDk3C8BnXqYiEqjPsYhUAfivc+5XXu4DUOucuyci5QDanXND3iCpqalxcek6+uzZM5X973HZtm2bysePH1f51KlTqe0NGzaEPLr4EJEu51yNf791jfOWc+6et30fwFvmkVFeynhx7JJT1oDTFtvVFiZr4TzwnqLg/Xw40F9ku9rCZL2O8wmAPwL4q/fz49BGFJGJEycOenzSpEmDHk9f89TX16tjI0YU/lWOIC/HWwB8DmC2iPSLyPtIFswKEbkB4HdepiIy5IzjnFs/wKHlIY+F8kjhz6mUFQV7rypTBw4cULmrq0vl9vb21Lb/XtXKlSuzNazY4IxDJiwcMmHhkIn5Ozkt4nSvarhu3bql8vz581PbkydPVseWLl2qck2NvtWzfft2lUUkhBFmR9j3qqjIsXDIhC/HA5oxY4bKzc3Nqe2tW7eqYydPnhw0P3/+XOXNmzerXF5ebh1mZDjjkAkLh0xYOGTCNY7RunXrUtszZ85Ux3bv3q2y/5ZEY2Ojynfu6O+E37dvn8oVFRXmcWYLZxwyYeGQCQuHTHjLIQv8rW39HzfesmWLyv5/g+XL9Xvkzp8/H9rYhou3HChULBwyYeGQCdc4OTB69GiVX716pfKoUaNUbmtrU7m2tjYr43oTrnEoVCwcMmHhkAnvVYXg2rVrKvu/kqijo0Nl/5rGr7q6WuXFixdnMLrs4IxDJiwcMmHhkAnXOAH5v+L56NGjqe2zZ8+qY/fv3x/WY48cqf8Z/O85jmPblPiNiPJCkP44lSJyUUS+EJFeEfmTt58ta4tYkBnn
ewC7nXPVAH4NYLuIVIMta4takMZK9wDc87b/LyJfAqgAUAeg1vtr/wDQDuDDrIwyAv51yenTp1VuampS+fbt2+ZzLViwQGX/e4zXrl1rfuyoDGuN4/U7ngfgCtiytqgFLhwR+QWA/wD4s3NOdZcerGUt29UWpkCFIyKjkCyafzrnfnztGahlLdvVFqYh1ziS7MHxdwBfOuf+lnYor1rWPnjwQOXe3l6Vd+zYofL169fN5/J/1eKePXtUrqurUzmO12mGEuQC4CIAmwD8T0S6vX1/QbJg/u21r70D4A9ZGSHFUpBXVZcADNT5hy1ri1T+zZEUCwVzr+rx48cq+782qLu7W2V/a7bhWrRoUWrb/1nxVatWqTx27NiMzhVHnHHIhIVDJiwcMsmrNc6VK1dS24cOHVLH/O/r7e/vz+hc48aNU9n/ddLp95f8XxddDDjjkAkLh0zy6qmqtbX1jdtB+D9ysmbNGpUTiYTKDQ0NKvu7pxc7zjhkwsIhExYOmbDNCQ2KbU4oVCwcMmHhkAkLh0xYOGTCwiETFg6ZsHDIhIVDJiwcMmHhkEmk96pE5BGSn/osA/BdZCcenriOLVfjmu6c+9mH/iMtnNRJRTrfdOMsDuI6triNi09VZMLCIZNcFc5HOTpvEHEdW6zGlZM1DuU/PlWRSaSFIyKrRaRPRG6KSE7b24rICRF5KCI9afti0bs5H3pLR1Y4IpIAcAzA7wFUA1jv9UvOlWYAq3374tK7Of69pZ1zkfwB8BsAbWm5EUBjVOcfYExVAHrSch+Acm+7HEBfLseXNq6PAayI0/iifKqqAPBNWu739sVJ7Ho3x7W3NBfHA3DJ/9Y5fclp7S0dhSgL51sAlWn5bW9fnATq3RyFTHpLRyHKwukAMEtE3hGREgD1SPZKjpMfezcDOezdHKC3NJDr3tIRL/LeA/AVgFsA9uV4wdmC5JebvEJyvfU+gF8i+WrlBoALAEpzNLbfIvk0dA1At/fnvbiMzznHK8dkw8UxmbBwyISFQyYsHDJh4ZAJC4dMWDhkwsIhkx8AyyZIbAmqetUAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "train_data0, train_label_0 = train_dataset[0][0],train_dataset[0][1]\n", + "train_data0 = train_data0.transpose(1,2,0)\n", + "plt.figure(figsize=(2,2))\n", + "plt.imshow(train_data0, cmap=plt.cm.binary)\n", + "print('train_data0 label is: ' + str(train_label_0))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 2.组网&训练方案1\n", + "paddle支持用model类,直接完成模型的训练,具体如下。" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 首先需要继承Model来自定义LeNet网络。" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "import paddle\n", + "class LeNet(paddle.incubate.hapi.model.Model):\n", + " def __init__(self):\n", + " super(LeNet, self).__init__()\n", + " self.conv1 = paddle.nn.Conv2D(num_channels=1, num_filters=6, filter_size=5, stride=1, padding=2, act='relu')\n", + " self.max_pool1 = paddle.nn.Pool2D(pool_size=2, pool_type='max', pool_stride=2)\n", + " self.conv2 = paddle.nn.Conv2D(num_channels=6, num_filters=16, filter_size=5, stride=1, act='relu')\n", + " self.max_pool2 = paddle.nn.Pool2D(pool_size=2, pool_type='max', pool_stride=2)\n", + " self.linear1 = paddle.nn.Linear(input_dim=16*5*5, output_dim=120, act='relu')\n", + " self.linear2 = paddle.nn.Linear(input_dim=120, output_dim=84, act='relu')\n", + " self.linear3 = paddle.nn.Linear(input_dim=84, output_dim=10, act='softmax')\n", + "\n", + " def forward(self, x):\n", + " x = self.conv1(x)\n", + " x = self.max_pool1(x)\n", + " x = self.conv2(x)\n", + " x = self.max_pool2(x)\n", + " x = paddle.reshape(x, shape=[-1, 16*5*5])\n", + " x = self.linear1(x)\n", + " x = self.linear2(x)\n", + " x = self.linear3(x)\n", + " return x" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 初始化Model,并定义相关的参数。" + ] + }, + { + 
"cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "from paddle.incubate.hapi.model import Input\n", + "from paddle.incubate.hapi.loss import CrossEntropy\n", + "from paddle.incubate.hapi.metrics import Accuracy\n", + "\n", + "inputs = [Input([None, 1, 28, 28], 'float32', name='image')]\n", + "labels = [Input([None, 1], 'int64', name='label')]\n", + "model = LeNet()\n", + "optim = paddle.optimizer.Adam(learning_rate=0.001, parameter_list=model.parameters())\n", + "\n", + "model.prepare(\n", + " optim,\n", + " CrossEntropy(),\n", + " Accuracy(topk=(1, 2)),\n", + " inputs=inputs,\n", + " labels=labels)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 使用fit来训练模型" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1/2\n", + "step 10/938 - loss: 2.1912 - acc_top1: 0.2719 - acc_top2: 0.4109 - 16ms/step\n", + "step 20/938 - loss: 1.6389 - acc_top1: 0.4109 - acc_top2: 0.5367 - 15ms/step\n", + "step 30/938 - loss: 1.1486 - acc_top1: 0.4797 - acc_top2: 0.6135 - 15ms/step\n", + "step 40/938 - loss: 0.7755 - acc_top1: 0.5484 - acc_top2: 0.6770 - 15ms/step\n", + "step 50/938 - loss: 0.7651 - acc_top1: 0.5975 - acc_top2: 0.7266 - 15ms/step\n", + "step 60/938 - loss: 0.3837 - acc_top1: 0.6393 - acc_top2: 0.7617 - 15ms/step\n", + "step 70/938 - loss: 0.6532 - acc_top1: 0.6712 - acc_top2: 0.7888 - 15ms/step\n", + "step 80/938 - loss: 0.3394 - acc_top1: 0.6969 - acc_top2: 0.8107 - 15ms/step\n", + "step 90/938 - loss: 0.2527 - acc_top1: 0.7189 - acc_top2: 0.8283 - 15ms/step\n", + "step 100/938 - loss: 0.2055 - acc_top1: 0.7389 - acc_top2: 0.8427 - 14ms/step\n", + "step 110/938 - loss: 0.3987 - acc_top1: 0.7531 - acc_top2: 0.8536 - 14ms/step\n", + "step 120/938 - loss: 0.2372 - acc_top1: 0.7660 - acc_top2: 0.8622 - 14ms/step\n", + "step 130/938 - loss: 0.4071 - acc_top1: 0.7780 - acc_top2: 0.8708 
- 14ms/step\n", + "step 140/938 - loss: 0.1315 - acc_top1: 0.7895 - acc_top2: 0.8780 - 14ms/step\n", + "step 150/938 - loss: 0.3168 - acc_top1: 0.7981 - acc_top2: 0.8843 - 15ms/step\n", + "step 160/938 - loss: 0.2782 - acc_top1: 0.8063 - acc_top2: 0.8901 - 15ms/step\n", + "step 170/938 - loss: 0.2030 - acc_top1: 0.8144 - acc_top2: 0.8956 - 15ms/step\n", + "step 180/938 - loss: 0.2336 - acc_top1: 0.8203 - acc_top2: 0.9000 - 15ms/step\n", + "step 190/938 - loss: 0.5915 - acc_top1: 0.8260 - acc_top2: 0.9038 - 15ms/step\n", + "step 200/938 - loss: 0.4995 - acc_top1: 0.8310 - acc_top2: 0.9076 - 15ms/step\n", + "step 210/938 - loss: 0.2190 - acc_top1: 0.8359 - acc_top2: 0.9106 - 15ms/step\n", + "step 220/938 - loss: 0.1835 - acc_top1: 0.8397 - acc_top2: 0.9130 - 15ms/step\n", + "step 230/938 - loss: 0.1321 - acc_top1: 0.8442 - acc_top2: 0.9159 - 15ms/step\n", + "step 240/938 - loss: 0.2406 - acc_top1: 0.8478 - acc_top2: 0.9183 - 15ms/step\n", + "step 250/938 - loss: 0.1245 - acc_top1: 0.8518 - acc_top2: 0.9209 - 15ms/step\n", + "step 260/938 - loss: 0.1570 - acc_top1: 0.8559 - acc_top2: 0.9236 - 15ms/step\n", + "step 270/938 - loss: 0.1647 - acc_top1: 0.8593 - acc_top2: 0.9259 - 15ms/step\n", + "step 280/938 - loss: 0.1876 - acc_top1: 0.8625 - acc_top2: 0.9281 - 14ms/step\n", + "step 290/938 - loss: 0.2247 - acc_top1: 0.8650 - acc_top2: 0.9300 - 15ms/step\n", + "step 300/938 - loss: 0.2070 - acc_top1: 0.8679 - acc_top2: 0.9318 - 15ms/step\n", + "step 310/938 - loss: 0.1122 - acc_top1: 0.8701 - acc_top2: 0.9333 - 14ms/step\n", + "step 320/938 - loss: 0.0857 - acc_top1: 0.8729 - acc_top2: 0.9351 - 14ms/step\n", + "step 330/938 - loss: 0.2414 - acc_top1: 0.8751 - acc_top2: 0.9365 - 14ms/step\n", + "step 340/938 - loss: 0.2631 - acc_top1: 0.8774 - acc_top2: 0.9380 - 14ms/step\n", + "step 350/938 - loss: 0.1347 - acc_top1: 0.8796 - acc_top2: 0.9396 - 14ms/step\n", + "step 360/938 - loss: 0.2295 - acc_top1: 0.8816 - acc_top2: 0.9409 - 14ms/step\n", + "step 370/938 - loss: 
0.2971 - acc_top1: 0.8842 - acc_top2: 0.9423 - 14ms/step\n", + "step 380/938 - loss: 0.1623 - acc_top1: 0.8863 - acc_top2: 0.9436 - 14ms/step\n", + "step 390/938 - loss: 0.1020 - acc_top1: 0.8880 - acc_top2: 0.9448 - 14ms/step\n", + "step 400/938 - loss: 0.0716 - acc_top1: 0.8895 - acc_top2: 0.9459 - 14ms/step\n", + "step 410/938 - loss: 0.0889 - acc_top1: 0.8914 - acc_top2: 0.9469 - 14ms/step\n", + "step 420/938 - loss: 0.1010 - acc_top1: 0.8931 - acc_top2: 0.9478 - 14ms/step\n", + "step 430/938 - loss: 0.0486 - acc_top1: 0.8945 - acc_top2: 0.9487 - 14ms/step\n", + "step 440/938 - loss: 0.1723 - acc_top1: 0.8958 - acc_top2: 0.9495 - 14ms/step\n", + "step 450/938 - loss: 0.2270 - acc_top1: 0.8974 - acc_top2: 0.9503 - 14ms/step\n", + "step 460/938 - loss: 0.1197 - acc_top1: 0.8987 - acc_top2: 0.9512 - 14ms/step\n", + "step 470/938 - loss: 0.2837 - acc_top1: 0.9002 - acc_top2: 0.9519 - 14ms/step\n", + "step 480/938 - loss: 0.1091 - acc_top1: 0.9017 - acc_top2: 0.9528 - 14ms/step\n", + "step 490/938 - loss: 0.1397 - acc_top1: 0.9029 - acc_top2: 0.9535 - 14ms/step\n", + "step 500/938 - loss: 0.1034 - acc_top1: 0.9040 - acc_top2: 0.9543 - 14ms/step\n", + "step 510/938 - loss: 0.0095 - acc_top1: 0.9054 - acc_top2: 0.9550 - 14ms/step\n", + "step 520/938 - loss: 0.0092 - acc_top1: 0.9068 - acc_top2: 0.9558 - 14ms/step\n", + "step 530/938 - loss: 0.0633 - acc_top1: 0.9077 - acc_top2: 0.9565 - 14ms/step\n", + "step 540/938 - loss: 0.0936 - acc_top1: 0.9086 - acc_top2: 0.9571 - 14ms/step\n", + "step 550/938 - loss: 0.1180 - acc_top1: 0.9097 - acc_top2: 0.9577 - 14ms/step\n", + "step 560/938 - loss: 0.1600 - acc_top1: 0.9106 - acc_top2: 0.9583 - 14ms/step\n", + "step 570/938 - loss: 0.1338 - acc_top1: 0.9118 - acc_top2: 0.9590 - 14ms/step\n", + "step 580/938 - loss: 0.0496 - acc_top1: 0.9128 - acc_top2: 0.9595 - 14ms/step\n", + "step 590/938 - loss: 0.0651 - acc_top1: 0.9138 - acc_top2: 0.9600 - 14ms/step\n", + "step 600/938 - loss: 0.1306 - acc_top1: 0.9147 - acc_top2: 0.9605 
- 14ms/step\n", + "step 610/938 - loss: 0.0744 - acc_top1: 0.9157 - acc_top2: 0.9610 - 14ms/step\n", + "step 620/938 - loss: 0.1679 - acc_top1: 0.9166 - acc_top2: 0.9616 - 14ms/step\n", + "step 630/938 - loss: 0.0789 - acc_top1: 0.9173 - acc_top2: 0.9621 - 14ms/step\n", + "step 640/938 - loss: 0.0767 - acc_top1: 0.9182 - acc_top2: 0.9626 - 14ms/step\n", + "step 650/938 - loss: 0.1776 - acc_top1: 0.9188 - acc_top2: 0.9630 - 14ms/step\n", + "step 660/938 - loss: 0.1371 - acc_top1: 0.9196 - acc_top2: 0.9634 - 14ms/step\n", + "step 670/938 - loss: 0.1011 - acc_top1: 0.9204 - acc_top2: 0.9639 - 14ms/step\n", + "step 680/938 - loss: 0.0447 - acc_top1: 0.9209 - acc_top2: 0.9642 - 14ms/step\n", + "step 690/938 - loss: 0.0230 - acc_top1: 0.9217 - acc_top2: 0.9646 - 14ms/step\n", + "step 700/938 - loss: 0.0541 - acc_top1: 0.9224 - acc_top2: 0.9649 - 14ms/step\n", + "step 710/938 - loss: 0.1395 - acc_top1: 0.9231 - acc_top2: 0.9653 - 14ms/step\n", + "step 720/938 - loss: 0.0426 - acc_top1: 0.9238 - acc_top2: 0.9657 - 14ms/step\n", + "step 730/938 - loss: 0.0540 - acc_top1: 0.9247 - acc_top2: 0.9660 - 14ms/step\n", + "step 740/938 - loss: 0.1132 - acc_top1: 0.9253 - acc_top2: 0.9664 - 14ms/step\n", + "step 750/938 - loss: 0.0088 - acc_top1: 0.9261 - acc_top2: 0.9668 - 14ms/step\n", + "step 760/938 - loss: 0.0282 - acc_top1: 0.9266 - acc_top2: 0.9672 - 14ms/step\n", + "step 770/938 - loss: 0.1233 - acc_top1: 0.9272 - acc_top2: 0.9675 - 14ms/step\n", + "step 780/938 - loss: 0.2208 - acc_top1: 0.9275 - acc_top2: 0.9677 - 14ms/step\n", + "step 790/938 - loss: 0.0599 - acc_top1: 0.9281 - acc_top2: 0.9680 - 14ms/step\n", + "step 800/938 - loss: 0.0270 - acc_top1: 0.9287 - acc_top2: 0.9683 - 14ms/step\n", + "step 810/938 - loss: 0.1546 - acc_top1: 0.9291 - acc_top2: 0.9687 - 14ms/step\n", + "step 820/938 - loss: 0.0252 - acc_top1: 0.9297 - acc_top2: 0.9689 - 14ms/step\n", + "step 830/938 - loss: 0.0276 - acc_top1: 0.9304 - acc_top2: 0.9693 - 14ms/step\n", + "step 840/938 - loss: 
0.0620 - acc_top1: 0.9309 - acc_top2: 0.9695 - 14ms/step\n", + "step 850/938 - loss: 0.0505 - acc_top1: 0.9314 - acc_top2: 0.9699 - 14ms/step\n", + "step 860/938 - loss: 0.0156 - acc_top1: 0.9319 - acc_top2: 0.9701 - 14ms/step\n", + "step 870/938 - loss: 0.0229 - acc_top1: 0.9325 - acc_top2: 0.9704 - 14ms/step\n", + "step 880/938 - loss: 0.0498 - acc_top1: 0.9330 - acc_top2: 0.9707 - 14ms/step\n", + "step 890/938 - loss: 0.0183 - acc_top1: 0.9335 - acc_top2: 0.9710 - 14ms/step\n", + "step 900/938 - loss: 0.1282 - acc_top1: 0.9339 - acc_top2: 0.9712 - 14ms/step\n", + "step 910/938 - loss: 0.0426 - acc_top1: 0.9342 - acc_top2: 0.9715 - 14ms/step\n", + "step 920/938 - loss: 0.0641 - acc_top1: 0.9347 - acc_top2: 0.9717 - 14ms/step\n", + "step 930/938 - loss: 0.0745 - acc_top1: 0.9351 - acc_top2: 0.9719 - 14ms/step\n", + "step 938/938 - loss: 0.0118 - acc_top1: 0.9354 - acc_top2: 0.9721 - 14ms/step\n", + "save checkpoint at mnist_checkpoint/0\n", + "Eval begin...\n", + "step 10/157 - loss: 0.1032 - acc_top1: 0.9828 - acc_top2: 0.9969 - 5ms/step\n", + "step 20/157 - loss: 0.2664 - acc_top1: 0.9781 - acc_top2: 0.9953 - 5ms/step\n", + "step 30/157 - loss: 0.1626 - acc_top1: 0.9766 - acc_top2: 0.9943 - 5ms/step\n", + "step 40/157 - loss: 0.0247 - acc_top1: 0.9734 - acc_top2: 0.9926 - 5ms/step\n", + "step 50/157 - loss: 0.0225 - acc_top1: 0.9738 - acc_top2: 0.9925 - 5ms/step\n", + "step 60/157 - loss: 0.2119 - acc_top1: 0.9737 - acc_top2: 0.9927 - 5ms/step\n", + "step 70/157 - loss: 0.0559 - acc_top1: 0.9723 - acc_top2: 0.9920 - 5ms/step\n", + "step 80/157 - loss: 0.0329 - acc_top1: 0.9725 - acc_top2: 0.9918 - 5ms/step\n", + "step 90/157 - loss: 0.1064 - acc_top1: 0.9741 - acc_top2: 0.9925 - 5ms/step\n", + "step 100/157 - loss: 0.0027 - acc_top1: 0.9744 - acc_top2: 0.9923 - 5ms/step\n", + "step 110/157 - loss: 0.0044 - acc_top1: 0.9750 - acc_top2: 0.9925 - 5ms/step\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "step 120/157 - loss: 0.0093 - 
acc_top1: 0.9768 - acc_top2: 0.9931 - 5ms/step\n", + "step 130/157 - loss: 0.1247 - acc_top1: 0.9774 - acc_top2: 0.9935 - 5ms/step\n", + "step 140/157 - loss: 0.0031 - acc_top1: 0.9785 - acc_top2: 0.9940 - 5ms/step\n", + "step 150/157 - loss: 0.0495 - acc_top1: 0.9794 - acc_top2: 0.9944 - 5ms/step\n", + "step 157/157 - loss: 0.0020 - acc_top1: 0.9790 - acc_top2: 0.9944 - 5ms/step\n", + "Eval samples: 10000\n", + "Epoch 2/2\n", + "step 10/938 - loss: 0.1735 - acc_top1: 0.9766 - acc_top2: 0.9938 - 16ms/step\n", + "step 20/938 - loss: 0.0723 - acc_top1: 0.9750 - acc_top2: 0.9922 - 15ms/step\n", + "step 30/938 - loss: 0.0593 - acc_top1: 0.9781 - acc_top2: 0.9927 - 15ms/step\n", + "step 40/938 - loss: 0.1243 - acc_top1: 0.9793 - acc_top2: 0.9938 - 15ms/step\n", + "step 50/938 - loss: 0.0127 - acc_top1: 0.9797 - acc_top2: 0.9944 - 15ms/step\n", + "step 60/938 - loss: 0.0319 - acc_top1: 0.9779 - acc_top2: 0.9938 - 15ms/step\n", + "step 70/938 - loss: 0.0404 - acc_top1: 0.9783 - acc_top2: 0.9946 - 15ms/step\n", + "step 80/938 - loss: 0.1120 - acc_top1: 0.9781 - acc_top2: 0.9943 - 15ms/step\n", + "step 90/938 - loss: 0.0222 - acc_top1: 0.9780 - acc_top2: 0.9944 - 15ms/step\n", + "step 100/938 - loss: 0.0726 - acc_top1: 0.9788 - acc_top2: 0.9948 - 15ms/step\n", + "step 110/938 - loss: 0.0255 - acc_top1: 0.9790 - acc_top2: 0.9952 - 15ms/step\n", + "step 120/938 - loss: 0.2556 - acc_top1: 0.9790 - acc_top2: 0.9948 - 15ms/step\n", + "step 130/938 - loss: 0.0795 - acc_top1: 0.9786 - acc_top2: 0.9945 - 15ms/step\n", + "step 140/938 - loss: 0.1106 - acc_top1: 0.9785 - acc_top2: 0.9944 - 15ms/step\n", + "step 150/938 - loss: 0.0564 - acc_top1: 0.9784 - acc_top2: 0.9946 - 15ms/step\n", + "step 160/938 - loss: 0.1016 - acc_top1: 0.9784 - acc_top2: 0.9947 - 15ms/step\n", + "step 170/938 - loss: 0.0665 - acc_top1: 0.9785 - acc_top2: 0.9946 - 15ms/step\n", + "step 180/938 - loss: 0.0443 - acc_top1: 0.9788 - acc_top2: 0.9946 - 15ms/step\n", + "step 190/938 - loss: 0.0696 - acc_top1: 
0.9789 - acc_top2: 0.9947 - 15ms/step\n", + "step 200/938 - loss: 0.0552 - acc_top1: 0.9791 - acc_top2: 0.9948 - 15ms/step\n", + "step 210/938 - loss: 0.1540 - acc_top1: 0.9789 - acc_top2: 0.9946 - 15ms/step\n", + "step 220/938 - loss: 0.0422 - acc_top1: 0.9791 - acc_top2: 0.9947 - 15ms/step\n", + "step 230/938 - loss: 0.2994 - acc_top1: 0.9791 - acc_top2: 0.9946 - 15ms/step\n", + "step 240/938 - loss: 0.0246 - acc_top1: 0.9791 - acc_top2: 0.9946 - 15ms/step\n", + "step 250/938 - loss: 0.0802 - acc_top1: 0.9788 - acc_top2: 0.9946 - 15ms/step\n", + "step 260/938 - loss: 0.1142 - acc_top1: 0.9787 - acc_top2: 0.9947 - 15ms/step\n", + "step 270/938 - loss: 0.0195 - acc_top1: 0.9785 - acc_top2: 0.9946 - 15ms/step\n", + "step 280/938 - loss: 0.0559 - acc_top1: 0.9785 - acc_top2: 0.9944 - 15ms/step\n", + "step 290/938 - loss: 0.1101 - acc_top1: 0.9786 - acc_top2: 0.9943 - 15ms/step\n", + "step 300/938 - loss: 0.0078 - acc_top1: 0.9786 - acc_top2: 0.9943 - 15ms/step\n", + "step 310/938 - loss: 0.0877 - acc_top1: 0.9789 - acc_top2: 0.9944 - 15ms/step\n", + "step 320/938 - loss: 0.0919 - acc_top1: 0.9790 - acc_top2: 0.9945 - 15ms/step\n", + "step 330/938 - loss: 0.0395 - acc_top1: 0.9789 - acc_top2: 0.9945 - 15ms/step\n", + "step 340/938 - loss: 0.1892 - acc_top1: 0.9787 - acc_top2: 0.9945 - 15ms/step\n", + "step 350/938 - loss: 0.0457 - acc_top1: 0.9784 - acc_top2: 0.9944 - 15ms/step\n", + "step 360/938 - loss: 0.1036 - acc_top1: 0.9786 - acc_top2: 0.9944 - 15ms/step\n", + "step 370/938 - loss: 0.0614 - acc_top1: 0.9785 - acc_top2: 0.9944 - 15ms/step\n", + "step 380/938 - loss: 0.2316 - acc_top1: 0.9787 - acc_top2: 0.9944 - 15ms/step\n", + "step 390/938 - loss: 0.0126 - acc_top1: 0.9788 - acc_top2: 0.9945 - 15ms/step\n", + "step 400/938 - loss: 0.0614 - acc_top1: 0.9789 - acc_top2: 0.9946 - 15ms/step\n", + "step 410/938 - loss: 0.0374 - acc_top1: 0.9788 - acc_top2: 0.9945 - 15ms/step\n", + "step 420/938 - loss: 0.0924 - acc_top1: 0.9788 - acc_top2: 0.9945 - 15ms/step\n", + 
"step 430/938 - loss: 0.0151 - acc_top1: 0.9791 - acc_top2: 0.9946 - 15ms/step\n", + "step 440/938 - loss: 0.0223 - acc_top1: 0.9791 - acc_top2: 0.9947 - 15ms/step\n", + "step 450/938 - loss: 0.0111 - acc_top1: 0.9793 - acc_top2: 0.9947 - 15ms/step\n", + "step 460/938 - loss: 0.0112 - acc_top1: 0.9793 - acc_top2: 0.9947 - 15ms/step\n", + "step 470/938 - loss: 0.0239 - acc_top1: 0.9794 - acc_top2: 0.9947 - 15ms/step\n", + "step 480/938 - loss: 0.0821 - acc_top1: 0.9795 - acc_top2: 0.9948 - 15ms/step\n", + "step 490/938 - loss: 0.0493 - acc_top1: 0.9796 - acc_top2: 0.9948 - 15ms/step\n", + "step 500/938 - loss: 0.0627 - acc_top1: 0.9797 - acc_top2: 0.9949 - 15ms/step\n", + "step 510/938 - loss: 0.0331 - acc_top1: 0.9797 - acc_top2: 0.9949 - 15ms/step\n", + "step 520/938 - loss: 0.0831 - acc_top1: 0.9797 - acc_top2: 0.9949 - 15ms/step\n", + "step 530/938 - loss: 0.0687 - acc_top1: 0.9796 - acc_top2: 0.9949 - 15ms/step\n", + "step 540/938 - loss: 0.1556 - acc_top1: 0.9794 - acc_top2: 0.9949 - 15ms/step\n", + "step 550/938 - loss: 0.2394 - acc_top1: 0.9795 - acc_top2: 0.9950 - 15ms/step\n", + "step 560/938 - loss: 0.0353 - acc_top1: 0.9794 - acc_top2: 0.9950 - 15ms/step\n", + "step 570/938 - loss: 0.0179 - acc_top1: 0.9794 - acc_top2: 0.9951 - 15ms/step\n", + "step 580/938 - loss: 0.0307 - acc_top1: 0.9796 - acc_top2: 0.9951 - 15ms/step\n", + "step 590/938 - loss: 0.0806 - acc_top1: 0.9796 - acc_top2: 0.9952 - 15ms/step\n", + "step 600/938 - loss: 0.0320 - acc_top1: 0.9796 - acc_top2: 0.9953 - 15ms/step\n", + "step 610/938 - loss: 0.0201 - acc_top1: 0.9798 - acc_top2: 0.9953 - 15ms/step\n", + "step 620/938 - loss: 0.1524 - acc_top1: 0.9797 - acc_top2: 0.9953 - 15ms/step\n", + "step 630/938 - loss: 0.0062 - acc_top1: 0.9797 - acc_top2: 0.9953 - 15ms/step\n", + "step 640/938 - loss: 0.0908 - acc_top1: 0.9798 - acc_top2: 0.9953 - 15ms/step\n", + "step 650/938 - loss: 0.0467 - acc_top1: 0.9799 - acc_top2: 0.9954 - 15ms/step\n", + "step 660/938 - loss: 0.0156 - acc_top1: 
0.9801 - acc_top2: 0.9954 - 15ms/step\n", + "step 670/938 - loss: 0.0318 - acc_top1: 0.9802 - acc_top2: 0.9955 - 15ms/step\n", + "step 680/938 - loss: 0.0133 - acc_top1: 0.9804 - acc_top2: 0.9955 - 15ms/step\n", + "step 690/938 - loss: 0.0651 - acc_top1: 0.9805 - acc_top2: 0.9956 - 15ms/step\n", + "step 700/938 - loss: 0.0052 - acc_top1: 0.9806 - acc_top2: 0.9956 - 15ms/step\n", + "step 710/938 - loss: 0.1208 - acc_top1: 0.9806 - acc_top2: 0.9956 - 15ms/step\n", + "step 720/938 - loss: 0.1519 - acc_top1: 0.9805 - acc_top2: 0.9956 - 15ms/step\n", + "step 730/938 - loss: 0.0954 - acc_top1: 0.9805 - acc_top2: 0.9955 - 15ms/step\n", + "step 740/938 - loss: 0.0059 - acc_top1: 0.9806 - acc_top2: 0.9955 - 15ms/step\n", + "step 750/938 - loss: 0.1000 - acc_top1: 0.9805 - acc_top2: 0.9955 - 15ms/step\n", + "step 760/938 - loss: 0.0629 - acc_top1: 0.9805 - acc_top2: 0.9955 - 15ms/step\n", + "step 770/938 - loss: 0.0182 - acc_top1: 0.9804 - acc_top2: 0.9955 - 15ms/step\n", + "step 780/938 - loss: 0.0215 - acc_top1: 0.9804 - acc_top2: 0.9955 - 15ms/step\n", + "step 790/938 - loss: 0.0418 - acc_top1: 0.9804 - acc_top2: 0.9956 - 15ms/step\n", + "step 800/938 - loss: 0.0132 - acc_top1: 0.9805 - acc_top2: 0.9956 - 15ms/step\n", + "step 810/938 - loss: 0.0546 - acc_top1: 0.9806 - acc_top2: 0.9956 - 15ms/step\n", + "step 820/938 - loss: 0.0373 - acc_top1: 0.9806 - acc_top2: 0.9956 - 15ms/step\n", + "step 830/938 - loss: 0.0965 - acc_top1: 0.9806 - acc_top2: 0.9956 - 15ms/step\n", + "step 840/938 - loss: 0.0143 - acc_top1: 0.9807 - acc_top2: 0.9956 - 15ms/step\n", + "step 850/938 - loss: 0.0578 - acc_top1: 0.9806 - acc_top2: 0.9956 - 15ms/step\n", + "step 860/938 - loss: 0.0205 - acc_top1: 0.9807 - acc_top2: 0.9956 - 15ms/step\n", + "step 870/938 - loss: 0.0384 - acc_top1: 0.9808 - acc_top2: 0.9956 - 15ms/step\n", + "step 880/938 - loss: 0.0157 - acc_top1: 0.9807 - acc_top2: 0.9956 - 15ms/step\n", + "step 890/938 - loss: 0.0457 - acc_top1: 0.9807 - acc_top2: 0.9956 - 15ms/step\n", + 
"step 900/938 - loss: 0.0202 - acc_top1: 0.9808 - acc_top2: 0.9956 - 15ms/step\n", + "step 910/938 - loss: 0.0240 - acc_top1: 0.9807 - acc_top2: 0.9956 - 15ms/step\n", + "step 920/938 - loss: 0.0585 - acc_top1: 0.9808 - acc_top2: 0.9956 - 15ms/step\n", + "step 930/938 - loss: 0.0414 - acc_top1: 0.9809 - acc_top2: 0.9956 - 15ms/step\n", + "step 938/938 - loss: 0.0180 - acc_top1: 0.9809 - acc_top2: 0.9956 - 15ms/step\n", + "save checkpoint at mnist_checkpoint/1\n", + "Eval begin...\n", + "step 10/157 - loss: 0.1093 - acc_top1: 0.9828 - acc_top2: 0.9984 - 5ms/step\n", + "step 20/157 - loss: 0.2292 - acc_top1: 0.9789 - acc_top2: 0.9969 - 5ms/step\n", + "step 30/157 - loss: 0.1203 - acc_top1: 0.9797 - acc_top2: 0.9969 - 5ms/step\n", + "step 40/157 - loss: 0.0068 - acc_top1: 0.9773 - acc_top2: 0.9961 - 5ms/step\n", + "step 50/157 - loss: 0.0049 - acc_top1: 0.9775 - acc_top2: 0.9959 - 5ms/step\n", + "step 60/157 - loss: 0.0399 - acc_top1: 0.9779 - acc_top2: 0.9956 - 5ms/step\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "step 70/157 - loss: 0.0299 - acc_top1: 0.9768 - acc_top2: 0.9953 - 5ms/step\n", + "step 80/157 - loss: 0.0108 - acc_top1: 0.9771 - acc_top2: 0.9955 - 5ms/step\n", + "step 90/157 - loss: 0.0209 - acc_top1: 0.9793 - acc_top2: 0.9958 - 5ms/step\n", + "step 100/157 - loss: 0.0031 - acc_top1: 0.9806 - acc_top2: 0.9962 - 5ms/step\n", + "step 110/157 - loss: 4.0509e-04 - acc_top1: 0.9808 - acc_top2: 0.9962 - 5ms/step\n", + "step 120/157 - loss: 8.9143e-04 - acc_top1: 0.9820 - acc_top2: 0.9965 - 5ms/step\n", + "step 130/157 - loss: 0.0119 - acc_top1: 0.9833 - acc_top2: 0.9968 - 5ms/step\n", + "step 140/157 - loss: 6.7999e-04 - acc_top1: 0.9844 - acc_top2: 0.9970 - 5ms/step\n", + "step 150/157 - loss: 0.0047 - acc_top1: 0.9853 - acc_top2: 0.9972 - 5ms/step\n", + "step 157/157 - loss: 1.6522e-04 - acc_top1: 0.9847 - acc_top2: 0.9973 - 5ms/step\n", + "Eval samples: 10000\n", + "save checkpoint at mnist_checkpoint/final\n" + ] + } + ], 
+ "source": [ + "model.fit(train_dataset,\n", + " test_dataset,\n", + " epochs=2,\n", + " batch_size=64,\n", + " save_dir='mnist_checkpoint')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 组网&训练方式1结束\n", + "以上就是组网&训练方式1,可以非常快速的完成网络模型的构建与训练。此外,paddle还可以用下面的方式来完成模型的训练。" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 3.组网&训练方式2\n", + "方式1可以快速便捷的完成组网&训练,将细节都隐藏了起来。而方式2则可以用最基本的方式,完成模型的组网与训练。具体如下。" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 通过继承Layer的方式来构建模型" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "import paddle\n", + "class LeNet(paddle.nn.Layer):\n", + " def __init__(self):\n", + " super(LeNet, self).__init__()\n", + " self.conv1 = paddle.nn.Conv2D(num_channels=1, num_filters=6, filter_size=5, stride=1, padding=2, act='relu')\n", + " self.max_pool1 = paddle.nn.Pool2D(pool_size=2, pool_type='max', pool_stride=2)\n", + " self.conv2 = paddle.nn.Conv2D(num_channels=6, num_filters=16, filter_size=5, stride=1, act='relu')\n", + " self.max_pool2 = paddle.nn.Pool2D(pool_size=2, pool_type='max', pool_stride=2)\n", + " self.linear1 = paddle.nn.Linear(input_dim=16*5*5, output_dim=120, act='relu')\n", + " self.linear2 = paddle.nn.Linear(input_dim=120, output_dim=84, act='relu')\n", + " self.linear3 = paddle.nn.Linear(input_dim=84, output_dim=10,act='softmax')\n", + "\n", + " def forward(self, x):\n", + " x = self.conv1(x)\n", + " x = self.max_pool1(x)\n", + " x = self.conv2(x)\n", + " x = self.max_pool2(x)\n", + " x = paddle.reshape(x, shape=[-1, 16*5*5])\n", + " x = self.linear1(x)\n", + " x = self.linear2(x)\n", + " x = self.linear3(x)\n", + " return x" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 训练模型" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch: 0, 
# --- Training cell ---------------------------------------------------------
import paddle

# DataLoader over the training split; batch size 64, CPU place (paddle 2.0-alpha API).
train_loader = paddle.io.DataLoader(train_dataset, places=paddle.CPUPlace(), batch_size=64)


def train(model):
    """Train ``model`` on the MNIST training set.

    Runs 2 epochs of Adam (lr=0.001) with cross-entropy loss, printing the
    batch-average loss and top-2 accuracy every 100 batches.

    Args:
        model: a ``LeNet`` instance (paddle.nn.Layer) in imperative mode.
    """
    model.train()
    epochs = 2
    optim = paddle.optimizer.Adam(learning_rate=0.001, parameter_list=model.parameters())
    for epoch in range(epochs):
        for batch_id, data in enumerate(train_loader()):
            x_data, y_data = data[0], data[1]
            predicts = model(x_data)
            loss = paddle.nn.functional.cross_entropy(predicts, y_data)
            acc = paddle.metric.accuracy(predicts, y_data, k=2)
            avg_loss = paddle.mean(loss)
            avg_acc = paddle.mean(acc)
            avg_loss.backward()
            if batch_id % 100 == 0:
                print("epoch: {}, batch_id: {}, loss is: {}, acc is: {}".format(
                    epoch, batch_id, avg_loss.numpy(), avg_acc.numpy()))
            # paddle 2.0-alpha imperative-mode update: apply gradients, then clear.
            optim.minimize(avg_loss)
            model.clear_gradients()


model = LeNet()
train(model)


# --- Evaluation cell -------------------------------------------------------
# DataLoader over the *test* split — previously constructed but never used.
test_loader = paddle.io.DataLoader(test_dataset, places=paddle.CPUPlace(), batch_size=64)


def test(model):
    """Evaluate ``model`` on the MNIST test set.

    Fixes vs. the original cell:
      * iterates ``test_loader`` (the original looped over ``train_loader``,
        so "validation" metrics were computed on training data);
      * no ``backward()`` call — gradients are not needed in evaluation.

    Prints batch-average loss and top-2 accuracy every 100 batches.
    """
    model.eval()
    for batch_id, data in enumerate(test_loader()):
        x_data, y_data = data[0], data[1]
        predicts = model(x_data)
        loss = paddle.nn.functional.cross_entropy(predicts, y_data)
        acc = paddle.metric.accuracy(predicts, y_data, k=2)
        avg_loss = paddle.mean(loss)
        avg_acc = paddle.mean(acc)
        if batch_id % 100 == 0:
            print("batch_id: {}, loss is: {}, acc is: {}".format(
                batch_id, avg_loss.numpy(), avg_acc.numpy()))


test(model)
"cell_type": "markdown", + "metadata": {}, + "source": [ + "### 组网&训练方式2结束\n", + "以上就是组网&训练方式2,通过这种方式,可以清楚的看到训练和测试中的每一步过程。" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 总结\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "以上就是用LeNet对手写数字数据及MNIST进行分类。本示例提供了两种训练模型的方式,一种可以快速完成模型的组建与预测,非常适合新手用户上手。另一种则需要多个步骤来完成模型的训练,适合进阶用户使用。" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/paddle2.0_docs/n_gram_model/n_gram_model.ipynb b/paddle2.0_docs/n_gram_model/n_gram_model.ipynb new file mode 100644 index 00000000..d46e601f --- /dev/null +++ b/paddle2.0_docs/n_gram_model/n_gram_model.ipynb @@ -0,0 +1,430 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 用N-Gram模型在莎士比亚文集中训练word embedding\n", + "N-gram 是计算机语言学和概率论范畴内的概念,是指给定的一段文本中N个项目的序列。\n", + "N=1 时 N-gram 又称为 unigram,N=2 称为 bigram,N=3 称为 trigram,以此类推。实际应用通常采用 bigram 和 trigram 进行计算。\n", + "本示例在莎士比亚文集上实现了trigram。" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 环境\n", + "本教程基于paddle2.0-alpha编写,如果您的环境不是本版本,请先安装paddle2.0-alpha。" + ] + }, + { + "cell_type": "code", + "execution_count": 189, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'2.0.0-alpha0'" + ] + }, + "execution_count": 189, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import paddle\n", + "paddle.__version__" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 数据集&&相关参数\n", + "训练数据集采用了莎士比亚文集,[下载](https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt),保存为txt格式即可。
\n", + "context_size设为2,意味着是trigram。embedding_dim设为256。" + ] + }, + { + "cell_type": "code", + "execution_count": 190, + "metadata": {}, + "outputs": [], + "source": [ + "embedding_dim = 256\n", + "context_size = 2" + ] + }, + { + "cell_type": "code", + "execution_count": 191, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Length of text: 1115394 characters\n" + ] + } + ], + "source": [ + "# 文件路径\n", + "path_to_file = './shakespeare.txt'\n", + "test_sentence = open(path_to_file, 'rb').read().decode(encoding='utf-8')\n", + "\n", + "# 文本长度是指文本中的字符个数\n", + "print ('Length of text: {} characters'.format(len(test_sentence)))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 去除标点符号\n", + "用`string`库中的punctuation,完成英文符号的替换。" + ] + }, + { + "cell_type": "code", + "execution_count": 192, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'!': '', '\"': '', '#': '', '$': '', '%': '', '&': '', \"'\": '', '(': '', ')': '', '*': '', '+': '', ',': '', '-': '', '.': '', '/': '', ':': '', ';': '', '<': '', '=': '', '>': '', '?': '', '@': '', '[': '', '\\\\': '', ']': '', '^': '', '_': '', '`': '', '{': '', '|': '', '}': '', '~': ''}\n" + ] + } + ], + "source": [ + "from string import punctuation\n", + "process_dicts={i:'' for i in punctuation}\n", + "print(process_dicts)" + ] + }, + { + "cell_type": "code", + "execution_count": 193, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "12848\n" + ] + } + ], + "source": [ + "punc_table = str.maketrans(process_dicts)\n", + "test_sentence = test_sentence.translate(punc_table)\n", + "test_sentence = test_sentence.lower().split()\n", + "vocab = set(test_sentence)\n", + "print(len(vocab))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 数据预处理\n", + "将文本拆成了元组的形式,格式为(('第一个词', '第二个词'), '第三个词');其中,第三个词就是我们的目标。" + ] + }, + { + "cell_type": "code", 
+ "execution_count": 194, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[['first', 'citizen'], 'before'], [['citizen', 'before'], 'we'], [['before', 'we'], 'proceed']]\n" + ] + } + ], + "source": [ + "trigram = [[[test_sentence[i], test_sentence[i + 1]], test_sentence[i + 2]]\n", + " for i in range(len(test_sentence) - 2)]\n", + "\n", + "word_to_idx = {word: i for i, word in enumerate(vocab)}\n", + "idx_to_word = {word_to_idx[word]: word for word in word_to_idx}\n", + "# 看一下数据集\n", + "print(trigram[:3])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 构建`Dataset`类 加载数据\n", + "用`paddle.io.Dataset`构建数据集,然后作为参数传入到`paddle.io.DataLoader`,完成数据集的加载。" + ] + }, + { + "cell_type": "code", + "execution_count": 184, + "metadata": {}, + "outputs": [], + "source": [ + "import paddle\n", + "batch_size = 256\n", + "paddle.enable_imperative()\n", + "class TrainDataset(paddle.io.Dataset):\n", + " def __init__(self, tuple_data):\n", + " self.tuple_data = tuple_data\n", + "\n", + " def __getitem__(self, idx):\n", + " data = self.tuple_data[idx][0]\n", + " label = self.tuple_data[idx][1]\n", + " data = list(map(lambda w: word_to_idx[w], data))\n", + " label = word_to_idx[label]\n", + " \n", + " return data, label\n", + " \n", + " def __len__(self):\n", + " return len(self.tuple_data)\n", + " \n", + "train_dataset = TrainDataset(trigram)\n", + "train_loader = paddle.io.DataLoader(train_dataset, places=paddle.fluid.cpu_places(), return_list=True,\\\n", + " shuffle=True, batch_size=batch_size, drop_last=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 组网&训练\n", + "这里用paddle动态图的方式组网。为了构建Trigram模型,用一层 `Embedding` 与两层 `Linear` 完成构建。`Embedding` 层对输入的前两个单词embedding,然后输入到后面的两个`Linear`层中,完成特征提取。" + ] + }, + { + "cell_type": "code", + "execution_count": 185, + "metadata": {}, + "outputs": [], + "source": [ + "import paddle\n", + "import numpy as np\n", + "class 
NGramModel(paddle.nn.Layer):\n", + " def __init__(self, vocab_size, embedding_dim, context_size):\n", + " super(NGramModel, self).__init__()\n", + " self.embedding = paddle.nn.Embedding(size=[vocab_size, embedding_dim])\n", + " self.linear1 = paddle.nn.Linear(context_size * embedding_dim, 1024)\n", + " self.linear2 = paddle.nn.Linear(1024, vocab_size)\n", + "\n", + " def forward(self, x):\n", + " x = self.embedding(x)\n", + " x = paddle.reshape(x, [-1, context_size * embedding_dim])\n", + " x = self.linear1(x)\n", + " x = paddle.nn.functional.relu(x)\n", + " x = self.linear2(x)\n", + " x = paddle.nn.functional.softmax(x)\n", + " return x" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 定义`train()`函数,对模型进行训练。" + ] + }, + { + "cell_type": "code", + "execution_count": 195, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch: 0, batch_id: 0, loss is: [9.4609375]\n", + "epoch: 0, batch_id: 100, loss is: [6.8079987]\n", + "epoch: 0, batch_id: 200, loss is: [7.15846]\n", + "epoch: 0, batch_id: 300, loss is: [6.536172]\n", + "epoch: 0, batch_id: 400, loss is: [6.847218]\n", + "epoch: 0, batch_id: 500, loss is: [6.6856213]\n", + "epoch: 0, batch_id: 600, loss is: [7.023284]\n", + "epoch: 0, batch_id: 700, loss is: [6.668572]\n", + "epoch: 1, batch_id: 0, loss is: [6.251506]\n", + "epoch: 1, batch_id: 100, loss is: [6.3487673]\n", + "epoch: 1, batch_id: 200, loss is: [6.3545203]\n", + "epoch: 1, batch_id: 300, loss is: [6.2171617]\n", + "epoch: 1, batch_id: 400, loss is: [6.0640473]\n", + "epoch: 1, batch_id: 500, loss is: [6.469482]\n", + "epoch: 1, batch_id: 600, loss is: [6.342491]\n", + "epoch: 1, batch_id: 700, loss is: [6.3045797]\n", + "epoch: 2, batch_id: 0, loss is: [5.993376]\n", + "epoch: 2, batch_id: 100, loss is: [5.8828726]\n", + "epoch: 2, batch_id: 200, loss is: [5.9760466]\n", + "epoch: 2, batch_id: 300, loss is: [5.756211]\n", + "epoch: 2, batch_id: 400, loss is: 
[6.112233]\n", + "epoch: 2, batch_id: 500, loss is: [5.8354917]\n", + "epoch: 2, batch_id: 600, loss is: [5.8915625]\n", + "epoch: 2, batch_id: 700, loss is: [5.820907]\n", + "epoch: 3, batch_id: 0, loss is: [5.9928036]\n", + "epoch: 3, batch_id: 100, loss is: [5.8935637]\n", + "epoch: 3, batch_id: 200, loss is: [6.1485195]\n", + "epoch: 3, batch_id: 300, loss is: [5.932296]\n", + "epoch: 3, batch_id: 400, loss is: [5.8619576]\n", + "epoch: 3, batch_id: 500, loss is: [6.1057215]\n", + "epoch: 3, batch_id: 600, loss is: [5.9520254]\n", + "epoch: 3, batch_id: 700, loss is: [5.6877956]\n", + "epoch: 4, batch_id: 0, loss is: [5.550747]\n", + "epoch: 4, batch_id: 100, loss is: [5.757982]\n", + "epoch: 4, batch_id: 200, loss is: [6.2809753]\n", + "epoch: 4, batch_id: 300, loss is: [5.860643]\n", + "epoch: 4, batch_id: 400, loss is: [5.9789114]\n", + "epoch: 4, batch_id: 500, loss is: [5.763079]\n", + "epoch: 4, batch_id: 600, loss is: [5.85236]\n", + "epoch: 4, batch_id: 700, loss is: [6.244775]\n" + ] + } + ], + "source": [ + "vocab_size = len(vocab)\n", + "epochs = 5\n", + "losses = []\n", + "def train(model):\n", + " model.train()\n", + " optim = paddle.optimizer.Adam(learning_rate=0.01, parameter_list=model.parameters())\n", + " for epoch in range(epochs):\n", + " for batch_id, data in enumerate(train_loader()):\n", + " x_data = data[0]\n", + " y_data = data[1]\n", + " predicts = model(x_data)\n", + " loss = paddle.nn.functional.cross_entropy(predicts, y_data)\n", + " avg_loss = paddle.mean(loss)\n", + " avg_loss.backward()\n", + " if batch_id % 100 == 0:\n", + " losses.append(avg_loss.numpy())\n", + " print(\"epoch: {}, batch_id: {}, loss is: {}\".format(epoch, batch_id, avg_loss.numpy()))\n", + " optim.minimize(avg_loss)\n", + " model.clear_gradients()\n", + "model = NGramModel(vocab_size, embedding_dim, context_size)\n", + "train(model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 打印loss下降曲线\n", + "通过可视化loss的曲线,可以看到模型训练的效果。" + ] + 
}, + { + "cell_type": "code", + "execution_count": 187, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[]" + ] + }, + "execution_count": 187, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy86wFpkAAAACXBIWXMAAAsTAAALEwEAmpwYAAA7AklEQVR4nO3dd3ic1ZX48e+ZGfVerWZJtmXcuzE2YIqppoQQSH6QkEpCCCSBZJ/dhd0sKZvNbnazaZDgOCFsCnEKHWLABAjFgG25d1uSJVmy1ZvVy9zfH1M0M5qRRvLI0ojzeR49Hr3vq3eupdHRnXPPvVeMMSillAp/lolugFJKqdDQgK6UUlOEBnSllJoiNKArpdQUoQFdKaWmCNtEPXF6eropLCycqKdXSqmwtHPnzgZjTIa/cxMW0AsLCykuLp6op1dKqbAkIhWBzmnKRSmlpggN6EopNUVoQFdKqSkiqIAuIveJyAEROSgi9/s5f5mItIrIHufHQyFvqVJKqWGNOCgqIguBLwCrgF7gZRF50RhT4nPp28aYG8ahjUoppYIQTA99HrDNGNNpjOkH3gQ+Mr7NUkopNVrBBPQDwFoRSRORWOA6YLqf69aIyF4ReUlEFvi7kYjcJSLFIlJcX19/Fs1WSinla8SAbow5DHwf2AK8DOwBBnwu2wUUGGOWAA8Dzwa410ZjzEpjzMqMDL918SM6WnOG/91ylMb2njF9vVJKTVVBDYoaYx4zxqwwxlwCNAPHfM63GWPanY83AxEikh7y1gKl9e08/HoJ9RrQlVLKS7BVLpnOf/Nx5M//4HM+S0TE+XiV876NoW2qQ6TV0eTefvt43F4ppcJWsFP/nxKRNKAPuNcY0yIidwMYYzYAtwJfEpF+oAu4zYzTVkhREY6A3qMBXSmlvAQV0I0xa/0c2+Dx+BHgkRC2K6AomxXQHrpSSvkKu5mikTZXD913XFYppT7Ywi6gR7kCep/20JVSylPYBXRXD713QAO6Ukp5CruArj10pZTyLwwDumNQtEd76Eop5SXsArp7ULRPB0WVUspT2AV0d8pFyxaVUspL2AV0nSmqlFL+hV1At1iESKtFe+hKKeUj7AI6ONIu2kNXSilvYRnQI20WnSmqlFI+wjKgR9k05aKUUr7CMqBHaspFKaWGCMuAHmWzaspFKaV8hGdAj9AeulJK+QrLgK5li0opNVSwW9DdJyIHROSgiNzv57yIyE9FpERE9onI8pC31ENUhAZ0pZTyNWJAF5GFwBeAVcAS4AYRKfK5bD0w2/lxF/BoiNvpJdKqKRellPIVTA99HrDNGNNpjOkH3sSxUbSnm4DfGof3gWQRyQ5xW910UFQppYYKJqAfANaKSJqIxALXAdN9rskFTnp8XuU85kVE7hKRYhEprq+vH2ubdVBUKaX8GDGgG2MOA98HtgAvA3uAMXWPjTEbjTErjTErMzIyxnILQAdFlVLKn6AGRY0xjxljVhhjLgGagWM+l1Tj3WvPcx4bFzooqpRSQwVb5ZLp/DcfR/78Dz6XPA98ylntshpoNcacDmlLPURarZpyUUopH7Ygr3tKRNKAPuBeY0yLiNwNYIzZAGzGkVsvATqBz45HY10cPXQdFFVKKU9BBXRjzFo/xzZ4PDbAvSFs17CibBb6Bgx2u8FikXP1tEopNamF50xR5zZ0vbpRtFJKuYVlQI+yWQHo6dOArpRSLmEZ0F099J4BzaMrpZRLWAb0KFdA1x66Ukq5hX
VA1xy6UkoNCuuArj10pZQaFKYB3TkoqrXoSinlFpYB3V22qLNFlVLKLSwDujvlogFdKaXcwjSgO1Iu2kNXSqlBYRnQI7WHrpRSQ4RlQB9MueigqFJKuYRnQI/QQVGllPIVlgE90qopF6WU8hWWAT0qQuvQlVLKV7A7Fn1NRA6KyAER2SQi0T7nPyMi9SKyx/nx+fFproOrh64pF6WUGjRiQBeRXOCrwEpjzELACtzm59I/GWOWOj9+FeJ2eomwCiKaclFKKU/BplxsQIyI2IBY4NT4NWlkIkKUzaI9dKWU8jBiQDfGVAM/ACqB0zg2gN7i59JbRGSfiDwpItP93UtE7hKRYhEprq+vP6uGR1ot2kNXSikPwaRcUoCbgBlADhAnInf4XPYCUGiMWQy8CvzG372MMRuNMSuNMSszMjLOquFREVYdFFVKKQ/BpFyuBE4YY+qNMX3A08CFnhcYYxqNMT3OT38FrAhtM4fSHrpSSnkLJqBXAqtFJFZEBLgCOOx5gYhke3z6Id/z4yEqQgO6Ukp5so10gTFmm4g8CewC+oHdwEYR+Q5QbIx5HviqiHzIeb4J+Mz4NdkhymbVQVGllPIwYkAHMMZ8E/imz+GHPM4/CDwYwnaNKNKmPXSllPIUljNFwbFAV0+fDooqpZRLWAd03SRaKaUGhXVA102ilVJqUBgHdKv20JVSykPYBnTHoKjm0JVSyiVsA7qmXJRSylvYBvRIHRRVSikvYRvQtYeulFLewjig66CoUkp5CtuAHmmzMGA39GtQV0opIIwDepRNN4pWSilPYRvQI226r6hSSnkK24AeZbMC2kNXSimXMA7o2kNXSilPYRvQI905dJ0tqpRSEGRAF5GvichBETkgIptEJNrnfJSI/ElESkRkm4gUjktrPeigqFJKeQtmk+hc4KvASmPMQsAK3OZz2Z1AszGmCPgR8P1QN9RXpAZ0pZTyEmzKxQbEiIgNiAVO+Zy/CfiN8/GTwBXO/UfHzeCgqKZclFIKggjoxphq4Ac4Nos+DbQaY7b4XJYLnHRe3w+0Amm+9xKRu0SkWESK6+vrz6rhURE6KKqUUp6CSbmk4OiBzwBygDgRuWMsT2aM2WiMWWmMWZmRkTGWW7hFWjXlopRSnoJJuVwJnDDG1Btj+oCngQt9rqkGpgM40zJJQGMoG+orOkIDulJKeQomoFcCq0Uk1pkXvwI47HPN88CnnY9vBV43xpjQNXOoSKsjh64pF6WUcggmh74Nx0DnLmC/82s2ish3RORDzsseA9JEpAT4OvDAOLXXLSpC69CVUsqTLZiLjDHfBL7pc/ghj/PdwEdD2K4R6UxRpZTyNgVmimpAV0opCOeA7qpy0V2LlFIKCOOAbrNasFqE3gHNoSulFIRxQAfdV1QppTyFfUDXfUWVUsohrAN6pPbQlVLKLawDepTNqnXoSinlFNYBPVJTLkop5RbWAV0HRZVSalD4B3SdWKSUUkCYB/RIm0Wn/iullFNYB3QdFFVKqUFhHdAjNeWilFJuYR3QozTlopRSbmEe0K3aQ1dKKadg9hSdIyJ7PD7aROR+n2suE5FWj2seCnC7kNKUi1JKDRpxgwtjzFFgKYCIWHHsH/qMn0vfNsbcENLWjcBRtqiDokopBaNPuVwBlBpjKsajMaOlOXSllBo02oB+G7ApwLk1IrJXRF4SkQVn2a6guCYWjfN+1EopFRaCDugiEgl8CPiLn9O7gAJjzBLgYeDZAPe4S0SKRaS4vr5+DM31FhVhBdD1XJRSitH10NcDu4wxtb4njDFtxph25+PNQISIpPu5bqMxZqUxZmVGRsaYG+3i2oZO0y5KKTW6gH47AdItIpIlIuJ8vMp538azb97woiJ0o2illHIZscoFQETigKuAL3ocuxvAGLMBuBX4koj0A13AbeYcJLa1h66UUoOCCujGmA4gzefYBo/HjwCPhLZpI9MeulJKDQr7maKA1qIrpRRhHtA15aKUUoPCOqBryk
UppQaFd0B3ply0h66UUmEe0CNtrh665tCVUiqsA3qUK6DrRtFKKRXeAd3VQ9ep/0opFeYBXXvoSik1KMwDurMOXXvoSikV3gHdPSjap4OiSikV1gHdnXLRskWllArvgK4zRZVSalBYB3SLRYi06kbRSikFYR7QQfcVVUopl7AP6JE2i84UVUoppkBAd20UrZRSH3QjBnQRmSMiezw+2kTkfp9rRER+KiIlIrJPRJaPW4t9RGrKRSmlgCB2LDLGHAWWAoiIFagGnvG5bD0w2/lxAfCo899xF2WzaspFKaUYfcrlCqDUGFPhc/wm4LfG4X0gWUSyQ9LCEURFaA9dKaVg9AH9NmCTn+O5wEmPz6ucx7yIyF0iUiwixfX19aN8av9GW7bY1TtAvy4VoJSagoIO6CISCXwI+MtYn8wYs9EYs9IYszIjI2Ost/ESFTG6gH79T9/m4ddLQvLcSik1mYymh74e2GWMqfVzrhqY7vF5nvPYuIu0Bp9yae7opayhg0On28a5VUopde6NJqDfjv90C8DzwKec1S6rgVZjzOmzbl0QRjMoWlrfDsCplq6g7z9gN2Nql1JKnWtBBXQRiQOuAp72OHa3iNzt/HQzUAaUAL8E7glxOwMazaDoaAP6n3ecZOm3t9DY3jPm9iml1LkyYtkigDGmA0jzObbB47EB7g1t04IzmkHR0voOAJo7++jqHSAm0hrw2vKGDr75/EG6+gY4VtvOmviokLRXKaXGS/jPFB3FoGhJXbv78anWwL30/gE7X//zHuzGkW6pbOo4u0YqpdQ5EPYBPdJqHVXKJTPB0dMeLu3yi7fK2FXZwn9+ZBE2i1DZ1BmStiql1HgK+4Du6KGPPCja3TfAyaZOLp6dDgQO6AeqW/nRq8e4fnE2Ny/LJTclhopGDehKqckv/AO6zULfgKFvhMlC5Y0d2A1cXJSOCFS3dA+5pqd/gK//eQ+pcZH8x4cXIiLkp8ZqD10pFRbCPqAvzEkCYPP+4askS+scefA5WQlMS4jmtJ8e+o4TzRyrbeehG+eTHBsJoAFdKRU2wj6gr5ubSVFmPBveLMOYwDXjrpLFmenx5CRH+x0ULak7A8CqwlT3sYK0WFo6+2jt6gtxy5VSKrTCPqBbLMIXL5nJ4dNtvHks8PowpfXt5CbHEBNpJTs5hlN+Ui6l9R0kRNvISBgsUcxPjQXgpPbSlVKTXNgHdICbluaSnRTNo38vDXhNSV07RZnxAOQmx1Dd0jWkR19S186sjHhExH0sPzUOYNwHRt86Vs/F33+d9p7+cX0epdTUNSUCeqTNwufXzmTbiSZ2VjQPOW+3G8rqO5iV4QjoOUnR9Pbbaezo9bqutL7dfY1Lfpqjhz7eefQX9p6iqrmLsvr2kS9WSik/pkRAB7jt/OkkxUSw4c2hvfTTbd109Q0wK9PR285JjgG8SxfbuvuoO9Pj7sW7xEfZSIuLHNfJRcYY3i1tHNImpZQajSkT0OOibHz6wkJePVTL8dozXudcM0TdPXR3QB/Mo5e6r4kbcu/pqbHjmnKpbOqk2hnIq5o1oCulxmbKBHSAz1xYSHSEhQ1vlnkddwVrzxw6ePeGXeu8+PbQwVHpMp4pF1fv3FEfrwFdKTU2Uyqgp8ZFctv5+Ty3p9qrKqW0vp2kmAjS4hy15cmxEURHWHwCejsRVmG6s6rFU0FqLKdausZtq7utJQ1kJkRRlBGvKRel1JhNqYAO8MVLZ2IR4VGPXLpjsDPOXb0iIuQkx3jVopfUtVOQFkeEdei3ZHpqLHYzPvltYwzvlTZy4aw0clNitIeulBqzKRfQs5NiuHVlHk8WV3HaGbBL6jqGVK84Shc9cuj17RRlDE23ABSkOUsXxyHtcqy2ncaOXi4sSicnOYZqzaErpcYo2A0ukkXkSRE5IiKHRWSNz/nLRKRVRPY4Px4an+YG50uXzsJuDL94s4zWzj4a2odWr+Qkxbin//f226lo7HRXwfhyTS4ajzz6u6UNAI
4eenIMzZ19dPZqLbpSavSC2uAC+AnwsjHmVudm0UMTzfC2MeaG0DVt7KanxnLzslw2ba9kzSzHvhy+PfSc5BjqzvTQ0+9YhXHAbvwOiAJkJkQRZbNQ2Rj60sWtJY3kp8aSlxJLXsrgYG1RZkLIn0spNbWN2EMXkSTgEuAxAGNMrzGmZZzbddbuvbyIvgE733nhEACzfHvoydEA1Lb2UOJcuMs36LtYLI7B0lD30PsH7Gwra+SiojRnmxwBXUsXlVJjEUzKZQZQDzwuIrtF5FfOPUZ9rRGRvSLykogs8HcjEblLRIpFpLi+PvC6K6FQmB7Hh5bkUN3SRaTVwnRn79fFFTyrW7oGF+4KENDBUekS6lr0g6faONPTz5pZjjXac/3UxyulVLCCCeg2YDnwqDFmGdABPOBzzS6gwBizBHgYeNbfjYwxG40xK40xKzMyMsbe6iB9eV0RIlCYHovNp3rFc7ZoaV072UnRxEcFzkC5eujDreg4Wlud+fM1Mx099MyEKKwWobpFFwJTSo1eMAG9Cqgyxmxzfv4kjgDvZoxpM8a0Ox9vBiJEJD2kLR2DoswE7rlsFh9dMX3IuewkR8rlVEsXJX7WcPFVkBZLZ+/AkPVfzsZ7pY3MmZbgXt3RZrWQlRitPXSl1JiMGNCNMTXASRGZ4zx0BXDI8xoRyRJnkbeIrHLetzHEbR2Tf7xmLl+4ZOaQ49ERVtLjIznV6uihBxoQdXFVunimXY7WnGH7iSa/1xtj+Mqm3bwUYOONnv4BdpQ3uQdtXXJTtHRRKTU2wVa5fAV4wlnhUgZ8VkTuBjDGbABuBb4kIv1AF3CbCWVuYpzkJMewu7KFjt4Bv2u4eCpIG1wXfUVBCqdbu/j4L9+n324o/saVQyYkHT59hhf2nqKzp5/1i7KH3G/vyVa6++xc6BvQk2MC/pFQSqnhBBXQjTF7gJU+hzd4nH8EeCR0zTo3spOieeVgLRC4wsUlL2Wwh97bb+eeJ3bR1NmLMY7UySXneY8JvHKwBoDdJ1swxnitsQ5QXOEI2ud77I4EjoBe09ZN/4B9SN5fKaWGE2wPfUpyDYyC/0W5PEVHWMlKjKayqZP/+Oshdle28MOPLeEbzx7gpQM1QwL6lkO1iEBTRy+VTZ3u2aYuuyqamZURR4pzfRmX3JQYBuyG2jM97qoXpYbzu/creP1wLa1djq0SO3oGeOjG+Vzn552hmto+0F1AV8BMiPLedi6Q/NRYthyq4TfvVfD5i2fwkeV5rJubyauHahiwD2aYTjZ1cvh0GzcvzQVgz8kWr/sYY9hZ0cyKgpQhz+Eup9Q8ugrST/52nP3VbcRG2piblUhrVx9vH2+Y6GaNaGtJA49vPTHRzZhSPtAB3RU8Z2XGD0mJ+JOfFsuZ7n5WFabyz+vnArB+YTYN7b3sKB/Me7vSLV9eV0RspJXdlS1e9ylr6KC5s89vQPe3tK9SgXT3DdDQ3sOn1hTw+89fwM8+sZzZ0+Kpap78pa8/e6OE720+THffwEQ3ZcrQgM7I+XOX1TPTmJkRxyMfX+YeBL1sTgZRNgsvH6hxX7flUC1zsxKYmRHP4rwkdld6b4vn2ibPfw/dUU6pqy6qYLj+8Od5TJzLC4NVO3v77eyqbKZvwAzp8Kix+0AHdNcvwXnTggvot67I4/V/uIzMxGj3sbgoG5eel8FLB05jtxsa23soLm/i6gVZACzLT+HgqTavXsiuimaSYiKYmT70eWMjbaTGRer0fxUU1+vEc7wlLyWW6uahm6D7U9nYyU9fOx7SCXPBOHDKUeUFUFyuVV2h8oEO6OnxUfzuzlV8YnXBWd1n/aIsatt62H2yhdcO12E3cPX8aQAsm55Mv91w8FSr+/piZ/7cYvGf5slNjjmrlEvfgJ2fvnacujadoDTVuXrieR4bs+SlxNDTb6e+vWfEr//D9kp++Oqxc96B2OEszc1Oima7BvSQ+UAHdIC1szOGnfIfjHVzpxFhFV4+cJoth2rITY
5hQU4iAEvzkwHcbytbOnspqWv3m25xyUmOPqu3zM/tOcUPXz3GI2+UBHV9a1cfR2raxvx8Z8NuN+O2E9QHQVVzJ1aLMM1jUN/1zjOYIH2gujXoa0NpR3kTM9PjuHLeNHZVNNM/oK+BUPjAB/RQSIqJ4KKidF7cd5q3jjdw9YJp7kHWzIRocp0TmGAwsC/PDxzQc5MdW955vg3+12f2c+f/7RixLXa7YeNbjt2antpZxZnuvhG/5h/+vIf1P3n7nFccVDV3cuWP3uSeJ3ad0+edSqqbu8hOivaas5Cb7OitjxSkjTHsdwb0k+dwENVuN+wob+b8wlTOn5FKR+8Ah0+fGfkL1Yg0oIfIdQuzOd3aTW+/navnZ3mdW5af7B4Y3VnRjNUiLJmeFPBeuSkxdPYO0NLpCMblDR1s2l7Ja0fqhpRA+vr7sTqO1bbz2YsK6egd4KmdVcNef/BUK387XEdWYjTffuEQ//7iIez28c+nltSd4dZH36OsvoPXj9TSFMI1cj5Iqpq7hsxXyE0JrvT1ZFMXrV2O11jVOG6C7utY3Rlau/o4f0Yq5xc6OjYjpV16++3c+PA7/P79ijE95x+2VfK9zYfH9LXhRAN6iFw1fxpWi5ASG+F+kbosy0/hVGs3tW3d7KxoZkFOIrGRgdM8uT6VLhvfLsNmsZAQZeOxd4bvRW94s4ycpGj+5bp5LMtP5rfvVQwboH/+RikJUTY2f3Utn7mwkMfeOcE9T+wa11Ky/VWtfOwXjmUT/vvWxdgNvHa4dtyebyqrbulyz2J2iY+ykRIbMWLpoqt3LnJuUy6u/PmqwlSyk2LIS4lxHwvkREMH+6tb+bfnDvDivlOjer72nn7+66XDPPF+xTkf/D3XNKCHSEpcJHdckM+dF88YMmV/mTOPvqO8iT0nW4ZNt8DgW+bqli7q2rp5sriKW1fmcduq6WzefzrggOmuyma2n2jizrUzibBa+PSaQsoaOni7xP8kk5K6M2w+cJpPXVhASlwk3/rQAv7thvm8cqiGf3xy3yi/A4HZ7YaTTZ28cbSOX7xZyu2/fJ+YCCt/uXsNH12RR47HEgwqeL39dmrbut09ck95KbEjBun91a1EWIWl05PPacple3kzWYnRTE91tHtVYSrFFU3DBtuSOseeBfmpsXz9T3vdWzcG4887TtLW3U9H7wAN7VP7naAG9BD69k0L+fK62UOOL8hJJNJqYdP2Srr6BoYdEAWPWvTmLn69tZx+u50vXjKTT19YiDGG37xX7vfrNr5ZRlJMBLed71gu+LpF2aTHR/Hbd/1f//O/lxJts/K5i2a4j9158QxuOz+f1w/Xes1+HatXD9Wy8FuvsPa/3+Czj+/gP186Qn5qLE9+aQ0z0uMQEa5ekMXbx+vPei/VUy1dPLu7+qzbHC5qWruxG8jzs0REXkpMED30FuZkJTAzPf6c9dCNMWw/0cj5M1Ld40znz0ilob2XEw2Bt3gsrW9HBP501xoK02O567c7vSrHAukfsPPYOyfchQ/l47CNZEtn74ip0HNFA/o5EGWzMi8nka0ljhWFRwroqXGRREdYOFLTxu/fr+D6xTkUpMWRlxLLtQuz2LStko4e7+BXVt/OK4dq+OTqAuKcL95Im4WPr5rO60frqPTZbamysZPn9pzi4xfkkxbvvezBBc6BqqM1/geqDp5qDTr4btpeSUK0je/dvIg/f3ENu/7tKjbft5bspMEgdPWCafT023nr2NntYvXb9yq4/097eONI3VndZ7wZY9i0vfKsxw2qnBuh5PnpoecmOyYXBer1GmM4UN3Gotxk8lIcC8L19I//jM2TTV3UtvWwyiMt6VqgbscwefTS+nZyk2PISorm/z67ioRoG595fId7kl4gmw/UUN3Sxf1XOjpa5cP80Rirf3lmPx/d8O6kGAfSgH6OLJueDDjqbnNGWHRLRMhNjuGpXdW09/Rz96WD67nfefEM2rr7eXqX92DnL98uc6RZLiz0Ov6J1Q
VYRfjd++Vexze8VYpVhLv8rBXvSgntqhz6y9Lc0ctNj2zlv18+Ouz/ARy5y3dKGrhhcQ4fvyCfVTNSSfVZjAwcb7mTYyPOOu1S4ex9fefFQ5O6FLK0vp0Hn97Pf7985Kzu455U5DflEkN3nz3ghiyuAdFFuUlMT43FmHOz9aFr8HPVjMFlo2dlxJEaF8n2E4GDc0nd4CY0Ockx/PZzq7BZhFsefZd/enIvjX5q7o1xVHzNzIjjk2sKsFok5D3047VneOlADX0DhpcO+N/74FzSgH6OuPLoI/XOXXKSHasuXnpeBgtyBitiluensGR6Mr/eWo7dOWHp87/ZwabtJ/nYyrwhi4xNS4zmmoVZ/GnHSX72Rgm/fucEv3u/gieLq/joyjymecx6dZmeGkN6fKTfgP5eWSP9dsNTO6tG7KW/dazeWfUzbdjrbFYLV8ydxmuHa+k7i3rkisZO0uOjONHQMakXfXLlg5/aVXVWE8iqm7sQwevdjotroDRQKmVfdQsAi3KTPOrWxz+Pvv1EI0kxEcz2WN1URFhZkBKwh263G8rqO7xWRJ09LYG/ff1SvnjJTJ7eVc26/32T379f4VXP/l5ZIweq2/jC2plE2azkpcRQ3hDa/+PP3ighJsLK9NQYXtg7usHa8RBUQBeRZBF5UkSOiMhhEVnjc15E5KciUiIi+0RkeaB7fVCtKEjBInDBzLSRL2bwbfQ9l83yOi4i3HnxDE40dPCxX7zH9T99h+0nmvjHa+bwjevn+73XXWtn0jtg539eOcp3XjzEvz17ABG4+9JZfq8XEZblp7DLz9vZrSUNWATO9PTz4t7heyRbDtaQGhcZ1B+xaxZMo627n21lY5s1aIyhorGDGxZnc+W8zEk9U7a03tFLNMbxzmqsqlu6mJYQTaRt6K9xXurwQXp/dSuRVgvnZcUzPdW1ecv459Ed9edDZ0mvmpFKZVMntX5+Zqdau+jqGxiy5lJclI0Hr5vH5vvWMi87gW88e4CrfvQWz+2pxm43/PKtMtLjI7l5mWPV08K0uJD20MsbOnh+7ynuWF3ALcvz2HaiiZrWoe3/5VtlPLFtbOWWoxVsD/0nwMvGmLnAEsC3oHM9MNv5cRfwaMhaOEXkpcTy0n2XcPv5Q/c39eejK6dz/5WzWTUjdci59QuzyEuJ4fDpNr66roi3/3kd915eRHSE1e+9lkxP5vB3ruXIv1/Lnoeu4r0H1/HuA+vcv8j+rChIobyxc8hb2a0lDVw+J5PzpsUP+yLtG7Dz2pE6rpibGdRGHWtnZxAdYXGvVAmOPP+9f9jFtrKRdzNs7Oilo3eA/NRYvnH9fPoGDP91limN8VJa105OUjQ3Lc1l0/ZKGoKYou9PVXOn3/w5DK7tEqiHfqC6lTlZCUTZHOv82ywy7pUudWe6OdHQMWRTFxg+j+76AxhoV7HzpiWw6Qur+cUnVxBls3DfH/dw1Y/e5I2j9XxqTaH792JGehzlDR0hK1189O+lRFgtfH7tDG5ckoMx8FefLSeP1pzhey8d5l+fOcCjfy8NyfMOZ8TfNBFJAi4BHgMwxvQaY1p8LrsJ+K1xeB9IFhFdXd/HnKyEoHchWp6fwv1Xnud3Wd8Iq4Xn7r2Idx+4gq9fPYekmIgR7yciREdYSY6NJDspZshAqL/nB7xWwqtq7qS8sZOLitL5+Kp89la1uqeO+9pW1sSZ7n73ImUjiYm0cul5Gbx6qBa73fDygdNc//Db/HXfaf7xyX0j1sW79notTI+lMD2OO9fO4Old1SMOmk2E0vp2ZmXG86XLZtHTbx9zeqi6pctv/hwgITqC5AC16MYY9le1sjDXkcqzWoSc5Jhxr3QpLnf8LPx1UubnJBITYfVbj+5KUQ23CY2IcM2CLDZ/dS0P374MYyAx2sYdHus0FaTFhqx0sbqli6d2VXH7qnwyE6KZlRHPgpzEIWmXH716jPhIG+sXZvH9l4+4Z3GPl2CiywygHnhcRHaLyK9ExPdPZS
5w0uPzKucxNU7S4qNIih05kI/V4rwkbBbxyqO/66zSuXh2OjcvzyM6wsIT2yr9fv2WQzXERFhZOzs96Oe8ZkEWNW3d3P37ndz9+13MTI/jv29ZTGVTJ794c/jURGWToxeXn+p4aX758iKmJUbx3b8eGu7LzspYenrGGErrO5iVEU9RZjzrF2bx23cr3DM2gzVgN5xu6Q7YQwdnpYufIF3Z1Elbdz+L8wbHZqanxnDSz2zRjp7+gMtH2O2Gb79wkH1VLUG1eWdFM1E2i/sPiacIq4UVBSm87yflVlrfTnJshN8BdV8Wi3Djkhxe/fqlbH1gndfXFKY7XhuhSLts+HspIngVFdy4JIc9J1vcFWUHqlt5+WANn7t4Bg/fvowbFmfzvc1H+NVZpNlGEkxAtwHLgUeNMcuADuCBsTyZiNwlIsUiUlxff3Ylamp8RUdYmZ+T6NXDfaekgYyEKGZnxpMUE8GNi3N4bk/1kF94YwxbDtZyyXnpAdNA/qybm4nVImw5VMvnLprBX+6+kI+dP53rF2Xz87+X+A04LuUNnYjgnqwSF2Xj3suL2F3Zwt4Q1wgbY/jW8we55sdvjXoQt7ath/aefnf64J7LijjT08/vAswtCHyfbvrtxj0JzR9HLfrQgO6aIbrII7DmJfufiPSVTbv5XIA1hPZXt/L41nL+tOOk3/O+9lW1sCAncciG6i5rZqVxtPbMkBRUaV07RRnBbULjYrUICdHeHZ5C5zaQZ1u6WNfWzZ+KT3LrijyvirUbFjuSEi84Z7L+6NVjJMVEcOdax2TDH/+/pVy/KJvv/vXwuA3aBxPQq4AqY8w25+dP4gjwnqoBz+RwnvOYF2PMRmPMSmPMyoyMDN/TapJZnp/CvqpW+gfsGGN4t7SBi2aluX+xPrG6gM7eAZ7b4/02c391KzVt3UPWtBlJcmwk/3PrYh7/7Pk8dON892Dfv14/D4sI//5i4N52ZVMnOUkxRNkG/4DcvCyX2Eir3/U/7HbDP/5l75gmIv3ob8f5v3fLOVbbPuqUTmm9I33gGuBbmJvE5XMy+PXW8lFNrHItCxEo5QKDs0V930nsr3IOiE5LcB+bnhpDQ3sPXb2Dqa3efjtbSxrYUd7sd/XPVw85ykwDpd089Q/YOVDdxuK85IDXXDjLUTDwvs+YSWl9e9Cb0AwnLyUmJKWLf9heSd+AfUhRQV5KLCsKUnhh7yl2Vzbz2pE67rpkJonOPyw2q4Uf37aUDy/NYUa6//GAszViQDfG1AAnRWSO89AVgO9v1vPAp5zVLquBVmPMxBdlqrOyvCCFrr4BjtSccfacermwaDCFsiQvifnZiTyxrdIraLxysAarRVg3N3PUz/mR5XlcPsf763KSY/jKFUVsOVTL34/6nzRU0dhBvs8gb0J0BDctzeWFfado7fR+F/Hi/tP8ZWcVDzy9b8ikq+H87r1yfvracW5amkOEVUY9iclfPvjL64po6ujlj9uD6+nCYPXKcCmXvJQYuvoGhkx42V/dytzsBK/qGNcAuWfOfX91Cz3Oev5XPHbkcnEF9MM1Z0Z8p3K8rp2uvgGWOudj+LMoN4n4KBvvlg4G9JbOXhrae5mVefYBMMJqcZQujuLn7ctuNzy1q4oLZ6UN2fgd4MbF2RypOcM/P7WP1LhIPuMzLyTCauHHty3jsjmj/90IRrBVLl8BnhCRfcBS4HsicreI3O08vxkoA0qAXwL3hLqh6txb7qyd31nR7J7lepFHQBcRPrE6n8On2/jm8wfZVtbIgN2RbllVmEpKEDnPYN158QxmpMfxrecP+p3RWNHYSUHa0PTDHavz6e6z86THRKy+ATs/3HKUmelxWEX412f3B5UP/+u+0zz0/EGunJfJ/350CecXpvL6KAN6aX37kE3JVxSkcsGMVDa+VRb0bM1qPzsV+fJXi+5aMtc3j+1vDXVXPjsvJcZri0VwVCAdrT3D0unJ9PbbOVY7/PK3rjy7Z97el81q4YIZqbznEd
Bd72iGGxAdjcK0uLNKuewob+JkUxe3LM/ze/66xdlYBI7VtnP3pTPds7bPlaACujFmjzNVstgY82FjTLMxZoMxZoPzvDHG3GuMmWWMWWSMKR7fZqtzITc5hmmJUeyqbGZrSQMz0+OGBJCbl+WyfmEWf9x+kv+38X1WfPdVjte1c/WC4ScTjVaUzcq/3TCP8sZOtvjMKG3v6aexo5d8PwF9QU4Sy/KTvVba+0txFeWNnfzLdfP45/Vzeft4A0/vGj71cvBUK1/70x5W5Kfw8O3LsVktrJubyfG69mFz+75K69uZ6WdT8nsvL6KmrZtnRmiHS1VzF+nxUcOOUbiCtGe6pKKxkzPd/V75c4DpzuDvWbq4/UQT502L59YVeeyoaKL+zGBue8shR4D/2lXnASOnXfacbCUx2ubOYweyZlYaJxo63BOuSutcJYuhCuixVDR2jrl08aldVcRFWrl2of90YmZCNBcVpZOZEMUnVxeeRUvHRmeKqoBEhOX5KRSXN7OtrJELi4ZOioqNtPHoHSvY9dBV/Ozjy7nsvAzmZydy/aLQV61eMjuDmAjrkD0oXVP+AwWLOy4ooKyhg3dLG+nuG+Anrx1jeX4yV8zL5I4LClien8y///XQsPXgfz9aT++AnQ2fXEFMpCOIXu5MKQVKA/lTWtfht5567ex0FuUm8eibpUHt3jNcyaJLrp8ZoK46ad/JXunxUUTaLO4eev+AneLyJlbNSOXahVkYMxjEwZFumZuVwNqidOKjbO6B1kD2VbWwOC854LaLLhfOcrwDdPXSS+rbibRZhiwRPFaF6XG09/SPqXSxq3eAzftruG5R9rDLX//4/y3lmXsvcr9OziUN6GpYy/NTqG7poqN3gIuLApcgxkfZuH5xNj++bRmb71vrtZF2qNisFpZOT6bYZyDSlQP3zaG7XL84m+TYCH7/fgW/ebec2rYe/unauYgIFovw/VsW09HTP+yga0ldO1mJ0aR71O/PTI+jIC026LTLme4+atq6/aYPRIR7Ly+iorGTzT7pjbeP1/PLt8q8epVVzV3D5s8BEqMjSIy2uYN0W3cfG98qY93cTK8BUXCU++WlDJYuHjrdRkfvABfMSGPOtAQK02LdaZfmjl52lDdx1fxpWCzCgpxE9lcH3sKwu8+x0Ntw6RaXuVkJpMZFstW5PG5pXbsjNTbCH4JguUoXK8YwMPrKwRrae/q5ZYX/dItLWnzUsKmw8aQBXQ1reUEy4NgEYXWQyxaMp5WFKRw+3Ua7x2qTrkEufzl0cJRgfnRFHlsO1fKzN0q45LwMr//L7GkJ3HNZEc/tOcW7AdeObx8SiEWEy+dk8m5po1d1iDGGe/+wix9u8V7ArKx++PTB1fOnUZQZz8/fKMFuNxhj+NXbZXzq19v5j82H2eIchLTbjWNjiyCChue66I+9fYLWrj6+7kyT+JqeEutOuWx3bULhXOb22oXZvFfaSGtnH68fcWyEfpVzjZ6FuUkcPt0WcGD04Kk2+u2GJcMMiLpYLMKamWm8V9rorNkPTYWLi+td3HBL9RpjeOd4A20+5bhP7qwiLyWGVX5muk4WGtDVsBbkJBFptbAoN4nk2NANco7VioIU7Ab2eMxgrWzqIC0uckjdsaePX1DAgN3Q1t3PP10zZ8j5L102i0irhTf9LOFrtzsCi7+e9eVzM+npt3uV2v25+CR/3Xeax7eWe81w9S1Z9GWxCPdcNosjNWd45WANDz69n+/+9TDXLshizrQEvvPCITp7+2no6KG33z5iDx0G10Vv7ujlsXdOcO2CLL8TewavdQT/98uaKEyLdS/edu3CLPrthr8druXVQ7VkJUa78/CLcpPo7be7K3h8uQZElwxTsuhpzaw0Trd2c7T2DJVNncwK0YAoDJYuVgSodOns7eerf9zDHY9t4yM/f9f9juVUSxdbSxv4yPK8EdNGE0kDuhpWdISVr6wr4ksBFvI615YXpCACxRWDefSKxk6/A6KeZqTHce
uKPO5Yne83oEVHWCnKjOfQ6aGpg9Nt3XT2DvgN6BfMSCUmwupOu9Sf6eE//nqYrMRozvT0e5U1ltS1Y7NIwHcS4JhtmJcSw5c37eaPO07ylXVF/Ozjy/nuzQupbuni4ddLhl0215erh77x7TI6evvdg5j+TE+NpaWzj9auPnaUN3GBxxK3S/KSyE6K5tk91bx1vJ4r52e6B3Zd389AefS9J1uYlhhFVlJwaThXPfqmbZXYTeA1XMbCVbp4wk/K5WRTJx/5+bu8uO8Un15TQF1bNzf/fCu7K5t5Znc1xsAtyyf3BHgN6GpEX7liNuvHYZBzLBKjI5gzLcFrQk9FYycFwyw05vKDjy7hux9eFPD8vOxEv7vPD7eWSHSElYuK0nn9SB3GGL7z4iG6++z87s5VpMdHeU26Kq1vpyAtNuBMSXAEnK9deR6RzpmF/3D1HCwW4fzCVG5dkccv3yrjzaOOdxHDzRJ1cW04/tg7J7hxcQ5zshICXuvq8b9+pJbWrj6vNVdca6W8fbyBzt4BrvKYNDYzPY64SGvASpd9Va3DTijyNSM9jqzEaJ5yVvyEqmTRxV/p4tvH67nxkXc41dLF4585n2/ftJCn77mI2Egbt218n8e3lrOqMNVv7flkogFdhZ0VBSnsrmxhwG7o6R/gVGtXSH7R5uck0tDe41WeByMvDrVubibVLV1sfKuMF/ae4svripg9LYEbl2Tz+pE69zotrjVcRnLLijz2f+tqPrzMuzf4wPq5xEZaeeSNEiDYHrrjmv4Bu3vXnkBcpYtP7XQE0gtmeueK1ztL9eKjbKz2OOcYGE3y20Nv7eqjrKGDJUEMiLqICBfOSnOPk8xMD3VA9y5dPFpzhjt/U8y0hGhe+MrF7kk/RZnxPHPPhSzMTaKhvYdbVkzu3jloQFdhaGVhCu09/RytOeOc2h54QHQ05mU7eq+HfdIuJXWOxaHSAkyUumyOYxmL/3zpCLMz491Twm9amkvvgJ1XDtTQN2CnorEj6Hywv1U50+Oj+Kdr5zJgNyTHRrj3yRyOK6B/ZHkeM0f4Y+KaLbq1tIHc5JghpYIrC1PJSozminmZXksswODAqG/J5f4qR5APZkDU0xpn2iU3OSbk5X+epYvdfQPc98fdJEbbeOILFwzpGKTFR/HE5y9gwx0ruHVFcEtfTyQN6CrsrCxw9A53VjS5SxZDEdDnZycC/gL6mWEXh8pJjmGuM5XxX7csck+pX5KXRGFaLM/traayqZO+AXPWFRu3r8pn6fRkzssMnDrxNC8rkQfXz+WB9XNHvDYlNoLYSCvGOMYGfFktwrP3XsR3P7xwyLlFeYl099kpqfceGN3rmiGamxxUe11cAT3U6RYYrHSpaOzgB68c5UjNGf7n1iVeJameoiMcE4lCVTo5ns7tvFSlQiAvJYbMhCiKK5rptzveNoci5eJYKz7abw/9mhHWdX/wunmcauliRYF33vlDS3N5+PXj7vVJznaAz2oRnvj8BQwEOdPRYhG+GOSAtogwPSWWo7Vn/K5ZDgQc2HRVvOyvamVuVqL7+N6TLcxIjxv1Us95KbFcPifDPXkrlFy16H/YVsnTu6v55OqCcXmeiaA9dBV2RIQVBY4ZrBWNncRFWgOmQ0bLd2C0sb2H5s6+EXuKl56Xwe2r8occv2mpYyebR51571CU4MVF2dwr+IWaK0UTKKAHMiM9nlg/A6OOAdHg8+eeHv/sKj61pnBMXzscV+ni07urmZURx79cNy/kzzFRNKCrsLSiwDGDdfuJJvLT4ka1VvZw5mUnUFrf7l4kK5jdcoYzKyOeRblJnGrtJjMhatwCcagsnZ5MUWb8qJd3tbpnjA4G9Nq2bmraukdV4XIuuEoXbRbhJ7ctm5Ap+uNFA7oKSyuds/UOnW6jMAT5c5d52Yn02w3Hax2BvCQEq/3dtDQHCN0CU+Ppy+uKeOX+S8b0B3JhbhKHTrfR22/nlYM13P37ncDgqp2TyX1XzOZ/P7Yk4C
SrcKUBXYWlBTmJREc4Xr4jTSoajXk+A6Mlde3ERFjJSRr72hw3LslBZHwG+EJNRMY8+LcoN4nuPjuX/s8bfPF3O2lo7+E/bl7IsvyUkb/4HPvI8jxuWjr5yxBHSwdFVViKsFpYkpfMthNNFKSGbrJHYVoc0REWdx69pK6dWZlxZzXde1piNL/+9PmcN8yknqng/MJUIqxCZkIU37h+PtcsmBb0pugqNDSgq7C1sjDFEdBD2EO3WoQ5WYlePfRQLEo2VaoohjM9NZY9D11NbKQ1ZGMaanSC+vMpIuUisl9E9ojIkM0rROQyEWl1nt8jIg+FvqlKeVu/MJv52YkszAltHnR+dgKHaxwrOp5u9b/crfIvLsqmwXwCjaaHfrkxxv/aog5vG2NuONsGKRWshblJbL5vbcjvOy87kU3bT7LVuZRuOAxmKgU6KKrUEK6B0Rf2OhbW0h66ChfBBnQDbBGRnSJyV4Br1ojIXhF5SUQW+LtARO4SkWIRKa6vH7rutFKTgWsa/98O14643K1Sk0mwAf1iY8xyYD1wr4hc4nN+F1BgjFkCPAw86+8mxpiNzs2mV2ZkZIy1zUqNq4ToCKanxtDdZ6cwPW7Y5W6VmkyCeqUaY6qd/9YBzwCrfM63GWPanY83AxEiEngDSqUmuXnO9Uhma7pFhZERA7qIxIlIgusxcDVwwOeaLHEObYvIKud9G33vpVS4cOXRNX+uwkkwVS7TgGec8doG/MEY87KI3A1gjNkA3Ap8SUT6gS7gNmOCXA5OqUlIA7oKRyMGdGNMGbDEz/ENHo8fAR4JbdOUmjiXnJfO5y+e8YGYEKSmDp0pqpQfsZE2vnHD/IluhlKjosP3Sik1RWhAV0qpKUIDulJKTREa0JVSaorQgK6UUlOEBnSllJoiNKArpdQUoQFdKaWmCJmoGfoiUg9UjPHL04HhNtuYSJO1bZO1XaBtG4vJ2i6YvG2brO2C0bWtwBjjd7naCQvoZ0NEio0xKye6Hf5M1rZN1naBtm0sJmu7YPK2bbK2C0LXNk25KKXUFKEBXSmlpohwDegbJ7oBw5isbZus7QJt21hM1nbB5G3bZG0XhKhtYZlDV0opNVS49tCVUkr50ICulFJTRNgFdBG5VkSOikiJiDwwwW35tYjUicgBj2OpIvKqiBx3/psyAe2aLiJviMghETkoIvdNhraJSLSIbBeRvc52fdt5fIaIbHP+TP8kIpHnsl0+bbSKyG4ReXEytU1EykVkv4jsEZFi57HJ8FpLFpEnReSIiBwWkTWTpF1znN8r10ebiNw/Sdr2Nefr/4CIbHL+XoTkdRZWAV1ErMDPgPXAfOB2EZnIbWX+D7jW59gDwGvGmNnAa87Pz7V+4B+MMfOB1cC9zu/TRLetB1hnjFkCLAWuFZHVwPeBHxljioBm4M5z3C5P9wGHPT6fTG273Biz1KNeeaJ/ngA/AV42xszFsVXl4cnQLmPMUef3aimwAugEnpnotolILvBVYKUxZiFgBW4jVK8zY0zYfABrgFc8Pn8QeHCC21QIHPD4/CiQ7XycDRydBN+354CrJlPbgFhgF3ABjhlyNn8/43Pcpjwcv+TrgBcBmURtKwfSfY5N6M8TSAJO4CyumCzt8tPOq4Gtk6FtQC5wEkjFsQXoi8A1oXqdhVUPncFvhkuV89hkMs0Yc9r5uAaYNpGNEZFCYBmwjUnQNmdKYw9QB7wKlAItxph+5yUT+TP9MfBPgN35eRqTp20G2CIiO0XkLuexif55zgDqgcedaapfiUjcJGiXr9uATc7HE9o2Y0w18AOgEjgNtAI7CdHrLNwCelgxjj+3E1YXKiLxwFPA/caYNs9zE9U2Y8yAcbwNzgNWAXPPdRv8EZEbgDpjzM6JbksAFxtjluNIN94rIpd4npygn6cNWA48aoxZBnTgk8KYBL8DkcCHgL/4npuItjlz9jfh+GOYA8QxNG07ZuEW0KuB6R6f5zmPTSa1IpIN4Py3biIaISIROI
L5E8aYpydT2wCMMS3AGzjeXiaLiM15aqJ+phcBHxKRcuCPONIuP5kkbXP17DDG1OHIBa9i4n+eVUCVMWab8/MncQT4iW6Xp/XALmNMrfPziW7blcAJY0y9MaYPeBrHay8kr7NwC+g7gNnOEeFIHG+lnp/gNvl6Hvi08/GnceSvzykREeAx4LAx5oeTpW0ikiEiyc7HMTjy+odxBPZbJ6pdAMaYB40xecaYQhyvq9eNMZ+YDG0TkTgRSXA9xpETPsAE/zyNMTXASRGZ4zx0BXBootvl43YG0y0w8W2rBFaLSKzz99T1PQvN62wiByvGOKhwHXAMR+71Xye4LZtw5MH6cPRW7sSRd30NOA78DUidgHZdjOOt5D5gj/PjuoluG7AY2O1s1wHgIefxmcB2oATHW+OoCf65Xga8OFna5mzDXufHQdfrfqJ/ns42LAWKnT/TZ4GUydAuZ9vigEYgyePYhLcN+DZwxPk78DsgKlSvM536r5RSU0S4pVyUUkoFoAFdKaWmCA3oSik1RWhAV0qpKUIDulJKTREa0JVSaorQgK6UUlPE/wdnqbx6GpBeZAAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "import matplotlib.ticker as ticker\n", + "%matplotlib inline\n", + "\n", + "plt.figure()\n", + "plt.plot(losses)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 预测\n", + "用训练好的模型进行预测。" + ] + }, + { + "cell_type": "code", + "execution_count": 188, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "the input words is: whiles, thou\n", + "the predict words is: art\n", + "the true words is: art\n" + ] + } + ], + "source": [ + "import random\n", + "def test(model):\n", + " model.eval()\n", + " # 从最后10组数据中随机选取1个\n", + " idx = random.randint(len(trigram)-10, len(trigram)-1)\n", + " print('the input words is: ' + trigram[idx][0][0] + ', ' + trigram[idx][0][1])\n", + " x_data = list(map(lambda w: word_to_idx[w], trigram[idx][0]))\n", + " x_data = paddle.imperative.to_variable(np.array(x_data))\n", + " predicts = model(x_data)\n", + " predicts = predicts.numpy().tolist()[0]\n", + " predicts = predicts.index(max(predicts))\n", + " print('the predict words is: ' + idx_to_word[predicts])\n", + " y_data = trigram[idx][1]\n", + " print('the true words is: ' + y_data)\n", + "test(model)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}