From 6cbf7fd87baf95116489810be0b4a96ce3b2c7de Mon Sep 17 00:00:00 2001 From: chenlong Date: Wed, 26 Aug 2020 11:12:05 +0800 Subject: [PATCH 1/2] add three docs for paddle2.0 --- .../mnist_lenet_classification.ipynb | 666 ++++++++++++++++++ .../n_gram_model/n_gram_model.ipynb | 344 +++++++++ .../text_generation_paddle.ipynb | 526 ++++++++++++++ 3 files changed, 1536 insertions(+) create mode 100644 paddle2.0_docs/image_classification/mnist_lenet_classification.ipynb create mode 100644 paddle2.0_docs/n_gram_model/n_gram_model.ipynb create mode 100644 paddle2.0_docs/text_generation/text_generation_paddle.ipynb diff --git a/paddle2.0_docs/image_classification/mnist_lenet_classification.ipynb b/paddle2.0_docs/image_classification/mnist_lenet_classification.ipynb new file mode 100644 index 00000000..4e544b82 --- /dev/null +++ b/paddle2.0_docs/image_classification/mnist_lenet_classification.ipynb @@ -0,0 +1,666 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# MNIST数据集使用LeNet进行图像分类\n", + "本示例教程演示如何在MNIST数据集上用LeNet进行图像分类。\n", + "手写数字的MNIST数据集,包含60,000个用于训练的示例和10,000个用于测试的示例。这些数字已经过尺寸标准化并位于图像中心,图像是固定大小(28x28像素),其值为0到1。该数据集的官方地址为:http://yann.lecun.com/exdb/mnist/" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 环境\n", + "本教程基于paddle2.0-alpha编写,如果您的环境不是本版本,请先安装paddle2.0-alpha。" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2.0.0-alpha0\n" + ] + } + ], + "source": [ + "import paddle\n", + "print(paddle.__version__)\n", + "paddle.enable_imperative()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 加载数据集\n", + "我们使用飞桨自带的paddle.dataset完成mnist数据集的加载。" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "download training data and load training data\n", + "load 
finished\n", + "\n" + ] + } + ], + "source": [ + "print('download training data and load training data')\n", + "train_dataset = paddle.incubate.hapi.datasets.MNIST(mode='train')\n", + "test_dataset = paddle.incubate.hapi.datasets.MNIST(mode='test')\n", + "print('load finished')\n", + "print(type(train_dataset[0][0]))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "取训练集中的一条数据看一下。" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "train_data0 label is: [5]\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAI4AAACOCAYAAADn/TAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/d3fzzAAAACXBIWXMAAAsTAAALEwEAmpwYAAAIY0lEQVR4nO3dXWhUZxoH8P/jaPxav7KREtNgiooQFvwg1l1cNOr6sQUN3ixR0VUK9cKPXTBYs17ohReLwl5ovCmuZMU1y+IaWpdC0GIuxCJJMLhJa6oWtSl+FVEXvdDK24s5nc5zapKTZ86cOTPz/4Hk/M8xc17w8Z13zpl5RpxzIBquEbkeAOUnFg6ZsHDIhIVDJiwcMmHhkElGhSMiq0WkT0RuisjesAZF8SfW6zgikgDwFYAVAPoBdABY75z7IrzhUVyNzOB33wVw0zn3NQCIyL8A1AEYsHDKyspcVVVVBqekqHV1dX3nnJvq359J4VQA+CYt9wNYONgvVFVVobOzM4NTUtRE5M6b9md9cSwiH4hIp4h0Pnr0KNuno4hkUjjfAqhMy297+xTn3EfOuRrnXM3UqT+b8ShPZVI4HQBmicg7IlICoB7AJ+EMi+LOvMZxzn0vIjsAtAFIADjhnOsNbWQUa5ksjuGc+xTApyGNhfIIrxyTCQuHTFg4ZMLCIRMWDpmwcMiEhUMmLBwyYeGQCQuHTFg4ZMLCIZOMbnIWk9evX6v89OnTwL/b1NSk8osXL1Tu6+tT+dixYyo3NDSo3NLSovKYMWNU3rv3p88N7N+/P/A4h4MzDpmwcMiEhUMmRbPGuXv3rsovX75U+fLlyypfunRJ5SdPnqh85syZ0MZWWVmp8s6dO1VubW1VecKECSrPmTNH5SVLloQ2toFwxiETFg6ZsHDIpGDXOFevXlV52bJlKg/nOkzYEomEygcPHlR5/PjxKm/cuFHladOmqTxlyhSVZ8+enekQh8QZh0xYOGTCwiGTgl3jTJ8+XeWysjKVw1zjLFyom3T41xwXL15UuaSkROVNmzaFNpaocMYhExYOmbBwyKRg1zilpaUqHz58WOVz586pPG/ePJV37do16OPPnTs3tX3hwgV1zH8dpqenR+UjR44M+tj5gDMOmQxZOCJyQkQeikhP2r5SETkvIje8n1MGewwqPEFmnGYAq3379gL4zDk3C8BnXqYiEqjPsYhUAfivc+5XXu4DUOucuyci5QDanXND3iCpqalxcek6+uzZM5X973HZtm2bysePH1f51KlTqe0NGzaEPLr4EJEu51yNf791jfOWc+6et30fwFvmkVFeynhx7JJT1oDTFtvVFiZr4TzwnqLg/Xw40F9ku9rCZL2
O8wmAPwL4q/fz49BGFJGJEycOenzSpEmDHk9f89TX16tjI0YU/lWOIC/HWwB8DmC2iPSLyPtIFswKEbkB4HdepiIy5IzjnFs/wKHlIY+F8kjhz6mUFQV7rypTBw4cULmrq0vl9vb21Lb/XtXKlSuzNazY4IxDJiwcMmHhkIn5Ozkt4nSvarhu3bql8vz581PbkydPVseWLl2qck2NvtWzfft2lUUkhBFmR9j3qqjIsXDIhC/HA5oxY4bKzc3Nqe2tW7eqYydPnhw0P3/+XOXNmzerXF5ebh1mZDjjkAkLh0xYOGTCNY7RunXrUtszZ85Ux3bv3q2y/5ZEY2Ojynfu6O+E37dvn8oVFRXmcWYLZxwyYeGQCQuHTHjLIQv8rW39HzfesmWLyv5/g+XL9Xvkzp8/H9rYhou3HChULBwyYeGQCdc4OTB69GiVX716pfKoUaNUbmtrU7m2tjYr43oTrnEoVCwcMmHhkAnvVYXg2rVrKvu/kqijo0Nl/5rGr7q6WuXFixdnMLrs4IxDJiwcMmHhkAnXOAH5v+L56NGjqe2zZ8+qY/fv3x/WY48cqf8Z/O85jmPblPiNiPJCkP44lSJyUUS+EJFeEfmTt58ta4tYkBnnewC7nXPVAH4NYLuIVIMta4takMZK9wDc87b/LyJfAqgAUAeg1vtr/wDQDuDDrIwyAv51yenTp1VuampS+fbt2+ZzLViwQGX/e4zXrl1rfuyoDGuN4/U7ngfgCtiytqgFLhwR+QWA/wD4s3NOdZcerGUt29UWpkCFIyKjkCyafzrnfnztGahlLdvVFqYh1ziS7MHxdwBfOuf+lnYor1rWPnjwQOXe3l6Vd+zYofL169fN5/J/1eKePXtUrqurUzmO12mGEuQC4CIAmwD8T0S6vX1/QbJg/u21r70D4A9ZGSHFUpBXVZcADNT5hy1ri1T+zZEUCwVzr+rx48cq+782qLu7W2V/a7bhWrRoUWrb/1nxVatWqTx27NiMzhVHnHHIhIVDJiwcMsmrNc6VK1dS24cOHVLH/O/r7e/vz+hc48aNU9n/ddLp95f8XxddDDjjkAkLh0zy6qmqtbX1jdtB+D9ysmbNGpUTiYTKDQ0NKvu7pxc7zjhkwsIhExYOmbDNCQ2KbU4oVCwcMmHhkAkLh0xYOGTCwiETFg6ZsHDIhIVDJiwcMmHhkEmk96pE5BGSn/osA/BdZCcenriOLVfjmu6c+9mH/iMtnNRJRTrfdOMsDuI6triNi09VZMLCIZNcFc5HOTpvEHEdW6zGlZM1DuU/PlWRSaSFIyKrRaRPRG6KSE7b24rICRF5KCI9afti0bs5H3pLR1Y4IpIAcAzA7wFUA1jv9UvOlWYAq3374tK7Of69pZ1zkfwB8BsAbWm5EUBjVOcfYExVAHrSch+Acm+7HEBfLseXNq6PAayI0/iifKqqAPBNWu739sVJ7Ho3x7W3NBfHA3DJ/9Y5fclp7S0dhSgL51sAlWn5bW9fnATq3RyFTHpLRyHKwukAMEtE3hGREgD1SPZKjpMfezcDOezdHKC3NJDr3tIRL/LeA/AVgFsA9uV4wdmC5JebvEJyvfU+gF8i+WrlBoALAEpzNLbfIvk0dA1At/fnvbiMzznHK8dkw8UxmbBwyISFQyYsHDJh4ZAJC4dMWDhkwsIhkx8AyyZIbAmqetUAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "train_data0, train_label_0 = train_dataset[0][0],train_dataset[0][1]\n", + "train_data0 = train_data0.transpose(1,2,0)\n", + "plt.figure(figsize=(2,2))\n", + "plt.imshow(train_data0, cmap=plt.cm.binary)\n", + "print('train_data0 label is: ' + str(train_label_0))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 2.组网&训练方案1\n", + "paddle支持用model类,直接完成模型的训练,具体如下。" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 首先需要继承Model来自定义LeNet网络。" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "import paddle\n", + "class LeNet(paddle.incubate.hapi.model.Model):\n", + " def __init__(self):\n", + " super(LeNet, self).__init__()\n", + " self.conv1 = paddle.nn.Conv2D(num_channels=1, num_filters=6, filter_size=5, stride=1, padding=2, act='relu')\n", + " self.max_pool1 = paddle.nn.Pool2D(pool_size=2, pool_type='max', pool_stride=2)\n", + " self.conv2 = paddle.nn.Conv2D(num_channels=6, num_filters=16, filter_size=5, stride=1, act='relu')\n", + " self.max_pool2 = paddle.nn.Pool2D(pool_size=2, pool_type='max', pool_stride=2)\n", + " self.linear1 = paddle.nn.Linear(input_dim=16*5*5, output_dim=120, act='relu')\n", + " self.linear2 = paddle.nn.Linear(input_dim=120, output_dim=84, act='relu')\n", + " self.linear3 = paddle.nn.Linear(input_dim=84, output_dim=10, act='softmax')\n", + "\n", + " def forward(self, x):\n", + " x = self.conv1(x)\n", + " x = self.max_pool1(x)\n", + " x = self.conv2(x)\n", + " x = self.max_pool2(x)\n", + " x = paddle.reshape(x, shape=[-1, 16*5*5])\n", + " x = self.linear1(x)\n", + " x = self.linear2(x)\n", + " x = self.linear3(x)\n", + " return x" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 初始化Model,并定义相关的参数。" + ] + }, + { + 
"cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "from paddle.incubate.hapi.model import Input\n", + "from paddle.incubate.hapi.loss import CrossEntropy\n", + "from paddle.incubate.hapi.metrics import Accuracy\n", + "\n", + "inputs = [Input([None, 1, 28, 28], 'float32', name='image')]\n", + "labels = [Input([None, 1], 'int64', name='label')]\n", + "model = LeNet()\n", + "optim = paddle.optimizer.Adam(learning_rate=0.001, parameter_list=model.parameters())\n", + "\n", + "model.prepare(\n", + " optim,\n", + " CrossEntropy(),\n", + " Accuracy(topk=(1, 2)),\n", + " inputs=inputs,\n", + " labels=labels)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 使用fit来训练模型" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1/2\n", + "step 10/938 - loss: 2.1912 - acc_top1: 0.2719 - acc_top2: 0.4109 - 16ms/step\n", + "step 20/938 - loss: 1.6389 - acc_top1: 0.4109 - acc_top2: 0.5367 - 15ms/step\n", + "step 30/938 - loss: 1.1486 - acc_top1: 0.4797 - acc_top2: 0.6135 - 15ms/step\n", + "step 40/938 - loss: 0.7755 - acc_top1: 0.5484 - acc_top2: 0.6770 - 15ms/step\n", + "step 50/938 - loss: 0.7651 - acc_top1: 0.5975 - acc_top2: 0.7266 - 15ms/step\n", + "step 60/938 - loss: 0.3837 - acc_top1: 0.6393 - acc_top2: 0.7617 - 15ms/step\n", + "step 70/938 - loss: 0.6532 - acc_top1: 0.6712 - acc_top2: 0.7888 - 15ms/step\n", + "step 80/938 - loss: 0.3394 - acc_top1: 0.6969 - acc_top2: 0.8107 - 15ms/step\n", + "step 90/938 - loss: 0.2527 - acc_top1: 0.7189 - acc_top2: 0.8283 - 15ms/step\n", + "step 100/938 - loss: 0.2055 - acc_top1: 0.7389 - acc_top2: 0.8427 - 14ms/step\n", + "step 110/938 - loss: 0.3987 - acc_top1: 0.7531 - acc_top2: 0.8536 - 14ms/step\n", + "step 120/938 - loss: 0.2372 - acc_top1: 0.7660 - acc_top2: 0.8622 - 14ms/step\n", + "step 130/938 - loss: 0.4071 - acc_top1: 0.7780 - acc_top2: 0.8708 
- 14ms/step\n", + "step 140/938 - loss: 0.1315 - acc_top1: 0.7895 - acc_top2: 0.8780 - 14ms/step\n", + "step 150/938 - loss: 0.3168 - acc_top1: 0.7981 - acc_top2: 0.8843 - 15ms/step\n", + "step 160/938 - loss: 0.2782 - acc_top1: 0.8063 - acc_top2: 0.8901 - 15ms/step\n", + "step 170/938 - loss: 0.2030 - acc_top1: 0.8144 - acc_top2: 0.8956 - 15ms/step\n", + "step 180/938 - loss: 0.2336 - acc_top1: 0.8203 - acc_top2: 0.9000 - 15ms/step\n", + "step 190/938 - loss: 0.5915 - acc_top1: 0.8260 - acc_top2: 0.9038 - 15ms/step\n", + "step 200/938 - loss: 0.4995 - acc_top1: 0.8310 - acc_top2: 0.9076 - 15ms/step\n", + "step 210/938 - loss: 0.2190 - acc_top1: 0.8359 - acc_top2: 0.9106 - 15ms/step\n", + "step 220/938 - loss: 0.1835 - acc_top1: 0.8397 - acc_top2: 0.9130 - 15ms/step\n", + "step 230/938 - loss: 0.1321 - acc_top1: 0.8442 - acc_top2: 0.9159 - 15ms/step\n", + "step 240/938 - loss: 0.2406 - acc_top1: 0.8478 - acc_top2: 0.9183 - 15ms/step\n", + "step 250/938 - loss: 0.1245 - acc_top1: 0.8518 - acc_top2: 0.9209 - 15ms/step\n", + "step 260/938 - loss: 0.1570 - acc_top1: 0.8559 - acc_top2: 0.9236 - 15ms/step\n", + "step 270/938 - loss: 0.1647 - acc_top1: 0.8593 - acc_top2: 0.9259 - 15ms/step\n", + "step 280/938 - loss: 0.1876 - acc_top1: 0.8625 - acc_top2: 0.9281 - 14ms/step\n", + "step 290/938 - loss: 0.2247 - acc_top1: 0.8650 - acc_top2: 0.9300 - 15ms/step\n", + "step 300/938 - loss: 0.2070 - acc_top1: 0.8679 - acc_top2: 0.9318 - 15ms/step\n", + "step 310/938 - loss: 0.1122 - acc_top1: 0.8701 - acc_top2: 0.9333 - 14ms/step\n", + "step 320/938 - loss: 0.0857 - acc_top1: 0.8729 - acc_top2: 0.9351 - 14ms/step\n", + "step 330/938 - loss: 0.2414 - acc_top1: 0.8751 - acc_top2: 0.9365 - 14ms/step\n", + "step 340/938 - loss: 0.2631 - acc_top1: 0.8774 - acc_top2: 0.9380 - 14ms/step\n", + "step 350/938 - loss: 0.1347 - acc_top1: 0.8796 - acc_top2: 0.9396 - 14ms/step\n", + "step 360/938 - loss: 0.2295 - acc_top1: 0.8816 - acc_top2: 0.9409 - 14ms/step\n", + "step 370/938 - loss: 
0.2971 - acc_top1: 0.8842 - acc_top2: 0.9423 - 14ms/step\n", + "step 380/938 - loss: 0.1623 - acc_top1: 0.8863 - acc_top2: 0.9436 - 14ms/step\n", + "step 390/938 - loss: 0.1020 - acc_top1: 0.8880 - acc_top2: 0.9448 - 14ms/step\n", + "step 400/938 - loss: 0.0716 - acc_top1: 0.8895 - acc_top2: 0.9459 - 14ms/step\n", + "step 410/938 - loss: 0.0889 - acc_top1: 0.8914 - acc_top2: 0.9469 - 14ms/step\n", + "step 420/938 - loss: 0.1010 - acc_top1: 0.8931 - acc_top2: 0.9478 - 14ms/step\n", + "step 430/938 - loss: 0.0486 - acc_top1: 0.8945 - acc_top2: 0.9487 - 14ms/step\n", + "step 440/938 - loss: 0.1723 - acc_top1: 0.8958 - acc_top2: 0.9495 - 14ms/step\n", + "step 450/938 - loss: 0.2270 - acc_top1: 0.8974 - acc_top2: 0.9503 - 14ms/step\n", + "step 460/938 - loss: 0.1197 - acc_top1: 0.8987 - acc_top2: 0.9512 - 14ms/step\n", + "step 470/938 - loss: 0.2837 - acc_top1: 0.9002 - acc_top2: 0.9519 - 14ms/step\n", + "step 480/938 - loss: 0.1091 - acc_top1: 0.9017 - acc_top2: 0.9528 - 14ms/step\n", + "step 490/938 - loss: 0.1397 - acc_top1: 0.9029 - acc_top2: 0.9535 - 14ms/step\n", + "step 500/938 - loss: 0.1034 - acc_top1: 0.9040 - acc_top2: 0.9543 - 14ms/step\n", + "step 510/938 - loss: 0.0095 - acc_top1: 0.9054 - acc_top2: 0.9550 - 14ms/step\n", + "step 520/938 - loss: 0.0092 - acc_top1: 0.9068 - acc_top2: 0.9558 - 14ms/step\n", + "step 530/938 - loss: 0.0633 - acc_top1: 0.9077 - acc_top2: 0.9565 - 14ms/step\n", + "step 540/938 - loss: 0.0936 - acc_top1: 0.9086 - acc_top2: 0.9571 - 14ms/step\n", + "step 550/938 - loss: 0.1180 - acc_top1: 0.9097 - acc_top2: 0.9577 - 14ms/step\n", + "step 560/938 - loss: 0.1600 - acc_top1: 0.9106 - acc_top2: 0.9583 - 14ms/step\n", + "step 570/938 - loss: 0.1338 - acc_top1: 0.9118 - acc_top2: 0.9590 - 14ms/step\n", + "step 580/938 - loss: 0.0496 - acc_top1: 0.9128 - acc_top2: 0.9595 - 14ms/step\n", + "step 590/938 - loss: 0.0651 - acc_top1: 0.9138 - acc_top2: 0.9600 - 14ms/step\n", + "step 600/938 - loss: 0.1306 - acc_top1: 0.9147 - acc_top2: 0.9605 
- 14ms/step\n", + "step 610/938 - loss: 0.0744 - acc_top1: 0.9157 - acc_top2: 0.9610 - 14ms/step\n", + "step 620/938 - loss: 0.1679 - acc_top1: 0.9166 - acc_top2: 0.9616 - 14ms/step\n", + "step 630/938 - loss: 0.0789 - acc_top1: 0.9173 - acc_top2: 0.9621 - 14ms/step\n", + "step 640/938 - loss: 0.0767 - acc_top1: 0.9182 - acc_top2: 0.9626 - 14ms/step\n", + "step 650/938 - loss: 0.1776 - acc_top1: 0.9188 - acc_top2: 0.9630 - 14ms/step\n", + "step 660/938 - loss: 0.1371 - acc_top1: 0.9196 - acc_top2: 0.9634 - 14ms/step\n", + "step 670/938 - loss: 0.1011 - acc_top1: 0.9204 - acc_top2: 0.9639 - 14ms/step\n", + "step 680/938 - loss: 0.0447 - acc_top1: 0.9209 - acc_top2: 0.9642 - 14ms/step\n", + "step 690/938 - loss: 0.0230 - acc_top1: 0.9217 - acc_top2: 0.9646 - 14ms/step\n", + "step 700/938 - loss: 0.0541 - acc_top1: 0.9224 - acc_top2: 0.9649 - 14ms/step\n", + "step 710/938 - loss: 0.1395 - acc_top1: 0.9231 - acc_top2: 0.9653 - 14ms/step\n", + "step 720/938 - loss: 0.0426 - acc_top1: 0.9238 - acc_top2: 0.9657 - 14ms/step\n", + "step 730/938 - loss: 0.0540 - acc_top1: 0.9247 - acc_top2: 0.9660 - 14ms/step\n", + "step 740/938 - loss: 0.1132 - acc_top1: 0.9253 - acc_top2: 0.9664 - 14ms/step\n", + "step 750/938 - loss: 0.0088 - acc_top1: 0.9261 - acc_top2: 0.9668 - 14ms/step\n", + "step 760/938 - loss: 0.0282 - acc_top1: 0.9266 - acc_top2: 0.9672 - 14ms/step\n", + "step 770/938 - loss: 0.1233 - acc_top1: 0.9272 - acc_top2: 0.9675 - 14ms/step\n", + "step 780/938 - loss: 0.2208 - acc_top1: 0.9275 - acc_top2: 0.9677 - 14ms/step\n", + "step 790/938 - loss: 0.0599 - acc_top1: 0.9281 - acc_top2: 0.9680 - 14ms/step\n", + "step 800/938 - loss: 0.0270 - acc_top1: 0.9287 - acc_top2: 0.9683 - 14ms/step\n", + "step 810/938 - loss: 0.1546 - acc_top1: 0.9291 - acc_top2: 0.9687 - 14ms/step\n", + "step 820/938 - loss: 0.0252 - acc_top1: 0.9297 - acc_top2: 0.9689 - 14ms/step\n", + "step 830/938 - loss: 0.0276 - acc_top1: 0.9304 - acc_top2: 0.9693 - 14ms/step\n", + "step 840/938 - loss: 
0.0620 - acc_top1: 0.9309 - acc_top2: 0.9695 - 14ms/step\n", + "step 850/938 - loss: 0.0505 - acc_top1: 0.9314 - acc_top2: 0.9699 - 14ms/step\n", + "step 860/938 - loss: 0.0156 - acc_top1: 0.9319 - acc_top2: 0.9701 - 14ms/step\n", + "step 870/938 - loss: 0.0229 - acc_top1: 0.9325 - acc_top2: 0.9704 - 14ms/step\n", + "step 880/938 - loss: 0.0498 - acc_top1: 0.9330 - acc_top2: 0.9707 - 14ms/step\n", + "step 890/938 - loss: 0.0183 - acc_top1: 0.9335 - acc_top2: 0.9710 - 14ms/step\n", + "step 900/938 - loss: 0.1282 - acc_top1: 0.9339 - acc_top2: 0.9712 - 14ms/step\n", + "step 910/938 - loss: 0.0426 - acc_top1: 0.9342 - acc_top2: 0.9715 - 14ms/step\n", + "step 920/938 - loss: 0.0641 - acc_top1: 0.9347 - acc_top2: 0.9717 - 14ms/step\n", + "step 930/938 - loss: 0.0745 - acc_top1: 0.9351 - acc_top2: 0.9719 - 14ms/step\n", + "step 938/938 - loss: 0.0118 - acc_top1: 0.9354 - acc_top2: 0.9721 - 14ms/step\n", + "save checkpoint at mnist_checkpoint/0\n", + "Eval begin...\n", + "step 10/157 - loss: 0.1032 - acc_top1: 0.9828 - acc_top2: 0.9969 - 5ms/step\n", + "step 20/157 - loss: 0.2664 - acc_top1: 0.9781 - acc_top2: 0.9953 - 5ms/step\n", + "step 30/157 - loss: 0.1626 - acc_top1: 0.9766 - acc_top2: 0.9943 - 5ms/step\n", + "step 40/157 - loss: 0.0247 - acc_top1: 0.9734 - acc_top2: 0.9926 - 5ms/step\n", + "step 50/157 - loss: 0.0225 - acc_top1: 0.9738 - acc_top2: 0.9925 - 5ms/step\n", + "step 60/157 - loss: 0.2119 - acc_top1: 0.9737 - acc_top2: 0.9927 - 5ms/step\n", + "step 70/157 - loss: 0.0559 - acc_top1: 0.9723 - acc_top2: 0.9920 - 5ms/step\n", + "step 80/157 - loss: 0.0329 - acc_top1: 0.9725 - acc_top2: 0.9918 - 5ms/step\n", + "step 90/157 - loss: 0.1064 - acc_top1: 0.9741 - acc_top2: 0.9925 - 5ms/step\n", + "step 100/157 - loss: 0.0027 - acc_top1: 0.9744 - acc_top2: 0.9923 - 5ms/step\n", + "step 110/157 - loss: 0.0044 - acc_top1: 0.9750 - acc_top2: 0.9925 - 5ms/step\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "step 120/157 - loss: 0.0093 - 
acc_top1: 0.9768 - acc_top2: 0.9931 - 5ms/step\n", + "step 130/157 - loss: 0.1247 - acc_top1: 0.9774 - acc_top2: 0.9935 - 5ms/step\n", + "step 140/157 - loss: 0.0031 - acc_top1: 0.9785 - acc_top2: 0.9940 - 5ms/step\n", + "step 150/157 - loss: 0.0495 - acc_top1: 0.9794 - acc_top2: 0.9944 - 5ms/step\n", + "step 157/157 - loss: 0.0020 - acc_top1: 0.9790 - acc_top2: 0.9944 - 5ms/step\n", + "Eval samples: 10000\n", + "Epoch 2/2\n", + "step 10/938 - loss: 0.1735 - acc_top1: 0.9766 - acc_top2: 0.9938 - 16ms/step\n", + "step 20/938 - loss: 0.0723 - acc_top1: 0.9750 - acc_top2: 0.9922 - 15ms/step\n", + "step 30/938 - loss: 0.0593 - acc_top1: 0.9781 - acc_top2: 0.9927 - 15ms/step\n", + "step 40/938 - loss: 0.1243 - acc_top1: 0.9793 - acc_top2: 0.9938 - 15ms/step\n", + "step 50/938 - loss: 0.0127 - acc_top1: 0.9797 - acc_top2: 0.9944 - 15ms/step\n", + "step 60/938 - loss: 0.0319 - acc_top1: 0.9779 - acc_top2: 0.9938 - 15ms/step\n", + "step 70/938 - loss: 0.0404 - acc_top1: 0.9783 - acc_top2: 0.9946 - 15ms/step\n", + "step 80/938 - loss: 0.1120 - acc_top1: 0.9781 - acc_top2: 0.9943 - 15ms/step\n", + "step 90/938 - loss: 0.0222 - acc_top1: 0.9780 - acc_top2: 0.9944 - 15ms/step\n", + "step 100/938 - loss: 0.0726 - acc_top1: 0.9788 - acc_top2: 0.9948 - 15ms/step\n", + "step 110/938 - loss: 0.0255 - acc_top1: 0.9790 - acc_top2: 0.9952 - 15ms/step\n", + "step 120/938 - loss: 0.2556 - acc_top1: 0.9790 - acc_top2: 0.9948 - 15ms/step\n", + "step 130/938 - loss: 0.0795 - acc_top1: 0.9786 - acc_top2: 0.9945 - 15ms/step\n", + "step 140/938 - loss: 0.1106 - acc_top1: 0.9785 - acc_top2: 0.9944 - 15ms/step\n", + "step 150/938 - loss: 0.0564 - acc_top1: 0.9784 - acc_top2: 0.9946 - 15ms/step\n", + "step 160/938 - loss: 0.1016 - acc_top1: 0.9784 - acc_top2: 0.9947 - 15ms/step\n", + "step 170/938 - loss: 0.0665 - acc_top1: 0.9785 - acc_top2: 0.9946 - 15ms/step\n", + "step 180/938 - loss: 0.0443 - acc_top1: 0.9788 - acc_top2: 0.9946 - 15ms/step\n", + "step 190/938 - loss: 0.0696 - acc_top1: 
0.9789 - acc_top2: 0.9947 - 15ms/step\n", + "step 200/938 - loss: 0.0552 - acc_top1: 0.9791 - acc_top2: 0.9948 - 15ms/step\n", + "step 210/938 - loss: 0.1540 - acc_top1: 0.9789 - acc_top2: 0.9946 - 15ms/step\n", + "step 220/938 - loss: 0.0422 - acc_top1: 0.9791 - acc_top2: 0.9947 - 15ms/step\n", + "step 230/938 - loss: 0.2994 - acc_top1: 0.9791 - acc_top2: 0.9946 - 15ms/step\n", + "step 240/938 - loss: 0.0246 - acc_top1: 0.9791 - acc_top2: 0.9946 - 15ms/step\n", + "step 250/938 - loss: 0.0802 - acc_top1: 0.9788 - acc_top2: 0.9946 - 15ms/step\n", + "step 260/938 - loss: 0.1142 - acc_top1: 0.9787 - acc_top2: 0.9947 - 15ms/step\n", + "step 270/938 - loss: 0.0195 - acc_top1: 0.9785 - acc_top2: 0.9946 - 15ms/step\n", + "step 280/938 - loss: 0.0559 - acc_top1: 0.9785 - acc_top2: 0.9944 - 15ms/step\n", + "step 290/938 - loss: 0.1101 - acc_top1: 0.9786 - acc_top2: 0.9943 - 15ms/step\n", + "step 300/938 - loss: 0.0078 - acc_top1: 0.9786 - acc_top2: 0.9943 - 15ms/step\n", + "step 310/938 - loss: 0.0877 - acc_top1: 0.9789 - acc_top2: 0.9944 - 15ms/step\n", + "step 320/938 - loss: 0.0919 - acc_top1: 0.9790 - acc_top2: 0.9945 - 15ms/step\n", + "step 330/938 - loss: 0.0395 - acc_top1: 0.9789 - acc_top2: 0.9945 - 15ms/step\n", + "step 340/938 - loss: 0.1892 - acc_top1: 0.9787 - acc_top2: 0.9945 - 15ms/step\n", + "step 350/938 - loss: 0.0457 - acc_top1: 0.9784 - acc_top2: 0.9944 - 15ms/step\n", + "step 360/938 - loss: 0.1036 - acc_top1: 0.9786 - acc_top2: 0.9944 - 15ms/step\n", + "step 370/938 - loss: 0.0614 - acc_top1: 0.9785 - acc_top2: 0.9944 - 15ms/step\n", + "step 380/938 - loss: 0.2316 - acc_top1: 0.9787 - acc_top2: 0.9944 - 15ms/step\n", + "step 390/938 - loss: 0.0126 - acc_top1: 0.9788 - acc_top2: 0.9945 - 15ms/step\n", + "step 400/938 - loss: 0.0614 - acc_top1: 0.9789 - acc_top2: 0.9946 - 15ms/step\n", + "step 410/938 - loss: 0.0374 - acc_top1: 0.9788 - acc_top2: 0.9945 - 15ms/step\n", + "step 420/938 - loss: 0.0924 - acc_top1: 0.9788 - acc_top2: 0.9945 - 15ms/step\n", + 
"step 430/938 - loss: 0.0151 - acc_top1: 0.9791 - acc_top2: 0.9946 - 15ms/step\n", + "step 440/938 - loss: 0.0223 - acc_top1: 0.9791 - acc_top2: 0.9947 - 15ms/step\n", + "step 450/938 - loss: 0.0111 - acc_top1: 0.9793 - acc_top2: 0.9947 - 15ms/step\n", + "step 460/938 - loss: 0.0112 - acc_top1: 0.9793 - acc_top2: 0.9947 - 15ms/step\n", + "step 470/938 - loss: 0.0239 - acc_top1: 0.9794 - acc_top2: 0.9947 - 15ms/step\n", + "step 480/938 - loss: 0.0821 - acc_top1: 0.9795 - acc_top2: 0.9948 - 15ms/step\n", + "step 490/938 - loss: 0.0493 - acc_top1: 0.9796 - acc_top2: 0.9948 - 15ms/step\n", + "step 500/938 - loss: 0.0627 - acc_top1: 0.9797 - acc_top2: 0.9949 - 15ms/step\n", + "step 510/938 - loss: 0.0331 - acc_top1: 0.9797 - acc_top2: 0.9949 - 15ms/step\n", + "step 520/938 - loss: 0.0831 - acc_top1: 0.9797 - acc_top2: 0.9949 - 15ms/step\n", + "step 530/938 - loss: 0.0687 - acc_top1: 0.9796 - acc_top2: 0.9949 - 15ms/step\n", + "step 540/938 - loss: 0.1556 - acc_top1: 0.9794 - acc_top2: 0.9949 - 15ms/step\n", + "step 550/938 - loss: 0.2394 - acc_top1: 0.9795 - acc_top2: 0.9950 - 15ms/step\n", + "step 560/938 - loss: 0.0353 - acc_top1: 0.9794 - acc_top2: 0.9950 - 15ms/step\n", + "step 570/938 - loss: 0.0179 - acc_top1: 0.9794 - acc_top2: 0.9951 - 15ms/step\n", + "step 580/938 - loss: 0.0307 - acc_top1: 0.9796 - acc_top2: 0.9951 - 15ms/step\n", + "step 590/938 - loss: 0.0806 - acc_top1: 0.9796 - acc_top2: 0.9952 - 15ms/step\n", + "step 600/938 - loss: 0.0320 - acc_top1: 0.9796 - acc_top2: 0.9953 - 15ms/step\n", + "step 610/938 - loss: 0.0201 - acc_top1: 0.9798 - acc_top2: 0.9953 - 15ms/step\n", + "step 620/938 - loss: 0.1524 - acc_top1: 0.9797 - acc_top2: 0.9953 - 15ms/step\n", + "step 630/938 - loss: 0.0062 - acc_top1: 0.9797 - acc_top2: 0.9953 - 15ms/step\n", + "step 640/938 - loss: 0.0908 - acc_top1: 0.9798 - acc_top2: 0.9953 - 15ms/step\n", + "step 650/938 - loss: 0.0467 - acc_top1: 0.9799 - acc_top2: 0.9954 - 15ms/step\n", + "step 660/938 - loss: 0.0156 - acc_top1: 
0.9801 - acc_top2: 0.9954 - 15ms/step\n", + "step 670/938 - loss: 0.0318 - acc_top1: 0.9802 - acc_top2: 0.9955 - 15ms/step\n", + "step 680/938 - loss: 0.0133 - acc_top1: 0.9804 - acc_top2: 0.9955 - 15ms/step\n", + "step 690/938 - loss: 0.0651 - acc_top1: 0.9805 - acc_top2: 0.9956 - 15ms/step\n", + "step 700/938 - loss: 0.0052 - acc_top1: 0.9806 - acc_top2: 0.9956 - 15ms/step\n", + "step 710/938 - loss: 0.1208 - acc_top1: 0.9806 - acc_top2: 0.9956 - 15ms/step\n", + "step 720/938 - loss: 0.1519 - acc_top1: 0.9805 - acc_top2: 0.9956 - 15ms/step\n", + "step 730/938 - loss: 0.0954 - acc_top1: 0.9805 - acc_top2: 0.9955 - 15ms/step\n", + "step 740/938 - loss: 0.0059 - acc_top1: 0.9806 - acc_top2: 0.9955 - 15ms/step\n", + "step 750/938 - loss: 0.1000 - acc_top1: 0.9805 - acc_top2: 0.9955 - 15ms/step\n", + "step 760/938 - loss: 0.0629 - acc_top1: 0.9805 - acc_top2: 0.9955 - 15ms/step\n", + "step 770/938 - loss: 0.0182 - acc_top1: 0.9804 - acc_top2: 0.9955 - 15ms/step\n", + "step 780/938 - loss: 0.0215 - acc_top1: 0.9804 - acc_top2: 0.9955 - 15ms/step\n", + "step 790/938 - loss: 0.0418 - acc_top1: 0.9804 - acc_top2: 0.9956 - 15ms/step\n", + "step 800/938 - loss: 0.0132 - acc_top1: 0.9805 - acc_top2: 0.9956 - 15ms/step\n", + "step 810/938 - loss: 0.0546 - acc_top1: 0.9806 - acc_top2: 0.9956 - 15ms/step\n", + "step 820/938 - loss: 0.0373 - acc_top1: 0.9806 - acc_top2: 0.9956 - 15ms/step\n", + "step 830/938 - loss: 0.0965 - acc_top1: 0.9806 - acc_top2: 0.9956 - 15ms/step\n", + "step 840/938 - loss: 0.0143 - acc_top1: 0.9807 - acc_top2: 0.9956 - 15ms/step\n", + "step 850/938 - loss: 0.0578 - acc_top1: 0.9806 - acc_top2: 0.9956 - 15ms/step\n", + "step 860/938 - loss: 0.0205 - acc_top1: 0.9807 - acc_top2: 0.9956 - 15ms/step\n", + "step 870/938 - loss: 0.0384 - acc_top1: 0.9808 - acc_top2: 0.9956 - 15ms/step\n", + "step 880/938 - loss: 0.0157 - acc_top1: 0.9807 - acc_top2: 0.9956 - 15ms/step\n", + "step 890/938 - loss: 0.0457 - acc_top1: 0.9807 - acc_top2: 0.9956 - 15ms/step\n", + 
"step 900/938 - loss: 0.0202 - acc_top1: 0.9808 - acc_top2: 0.9956 - 15ms/step\n", + "step 910/938 - loss: 0.0240 - acc_top1: 0.9807 - acc_top2: 0.9956 - 15ms/step\n", + "step 920/938 - loss: 0.0585 - acc_top1: 0.9808 - acc_top2: 0.9956 - 15ms/step\n", + "step 930/938 - loss: 0.0414 - acc_top1: 0.9809 - acc_top2: 0.9956 - 15ms/step\n", + "step 938/938 - loss: 0.0180 - acc_top1: 0.9809 - acc_top2: 0.9956 - 15ms/step\n", + "save checkpoint at mnist_checkpoint/1\n", + "Eval begin...\n", + "step 10/157 - loss: 0.1093 - acc_top1: 0.9828 - acc_top2: 0.9984 - 5ms/step\n", + "step 20/157 - loss: 0.2292 - acc_top1: 0.9789 - acc_top2: 0.9969 - 5ms/step\n", + "step 30/157 - loss: 0.1203 - acc_top1: 0.9797 - acc_top2: 0.9969 - 5ms/step\n", + "step 40/157 - loss: 0.0068 - acc_top1: 0.9773 - acc_top2: 0.9961 - 5ms/step\n", + "step 50/157 - loss: 0.0049 - acc_top1: 0.9775 - acc_top2: 0.9959 - 5ms/step\n", + "step 60/157 - loss: 0.0399 - acc_top1: 0.9779 - acc_top2: 0.9956 - 5ms/step\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "step 70/157 - loss: 0.0299 - acc_top1: 0.9768 - acc_top2: 0.9953 - 5ms/step\n", + "step 80/157 - loss: 0.0108 - acc_top1: 0.9771 - acc_top2: 0.9955 - 5ms/step\n", + "step 90/157 - loss: 0.0209 - acc_top1: 0.9793 - acc_top2: 0.9958 - 5ms/step\n", + "step 100/157 - loss: 0.0031 - acc_top1: 0.9806 - acc_top2: 0.9962 - 5ms/step\n", + "step 110/157 - loss: 4.0509e-04 - acc_top1: 0.9808 - acc_top2: 0.9962 - 5ms/step\n", + "step 120/157 - loss: 8.9143e-04 - acc_top1: 0.9820 - acc_top2: 0.9965 - 5ms/step\n", + "step 130/157 - loss: 0.0119 - acc_top1: 0.9833 - acc_top2: 0.9968 - 5ms/step\n", + "step 140/157 - loss: 6.7999e-04 - acc_top1: 0.9844 - acc_top2: 0.9970 - 5ms/step\n", + "step 150/157 - loss: 0.0047 - acc_top1: 0.9853 - acc_top2: 0.9972 - 5ms/step\n", + "step 157/157 - loss: 1.6522e-04 - acc_top1: 0.9847 - acc_top2: 0.9973 - 5ms/step\n", + "Eval samples: 10000\n", + "save checkpoint at mnist_checkpoint/final\n" + ] + } + ], 
+ "source": [ + "model.fit(train_dataset,\n", + " test_dataset,\n", + " epochs=2,\n", + " batch_size=64,\n", + " save_dir='mnist_checkpoint')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 组网&训练方式1结束\n", + "以上就是组网&训练方式1,可以非常快速的完成网络模型的构建与训练。此外,paddle还可以用下面的方式来完成模型的训练。" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 3.组网&训练方式2\n", + "方式1可以快速便捷的完成组网&训练,将细节都隐藏了起来。而方式2则可以用最基本的方式,完成模型的组网与训练。具体如下。" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 通过继承Layer的方式来构建模型" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "import paddle\n", + "class LeNet(paddle.nn.Layer):\n", + " def __init__(self):\n", + " super(LeNet, self).__init__()\n", + " self.conv1 = paddle.nn.Conv2D(num_channels=1, num_filters=6, filter_size=5, stride=1, padding=2, act='relu')\n", + " self.max_pool1 = paddle.nn.Pool2D(pool_size=2, pool_type='max', pool_stride=2)\n", + " self.conv2 = paddle.nn.Conv2D(num_channels=6, num_filters=16, filter_size=5, stride=1, act='relu')\n", + " self.max_pool2 = paddle.nn.Pool2D(pool_size=2, pool_type='max', pool_stride=2)\n", + " self.linear1 = paddle.nn.Linear(input_dim=16*5*5, output_dim=120, act='relu')\n", + " self.linear2 = paddle.nn.Linear(input_dim=120, output_dim=84, act='relu')\n", + " self.linear3 = paddle.nn.Linear(input_dim=84, output_dim=10,act='softmax')\n", + "\n", + " def forward(self, x):\n", + " x = self.conv1(x)\n", + " x = self.max_pool1(x)\n", + " x = self.conv2(x)\n", + " x = self.max_pool2(x)\n", + " x = paddle.reshape(x, shape=[-1, 16*5*5])\n", + " x = self.linear1(x)\n", + " x = self.linear2(x)\n", + " x = self.linear3(x)\n", + " return x" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 训练模型" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch: 0, 
batch_id: 0, loss is: [2.2982373], acc is: [0.15625]\n", + "epoch: 0, batch_id: 100, loss is: [0.25794172], acc is: [0.96875]\n", + "epoch: 0, batch_id: 200, loss is: [0.25025752], acc is: [0.984375]\n", + "epoch: 0, batch_id: 300, loss is: [0.17673397], acc is: [0.984375]\n", + "epoch: 0, batch_id: 400, loss is: [0.09535598], acc is: [1.]\n", + "epoch: 0, batch_id: 500, loss is: [0.08496016], acc is: [1.]\n", + "epoch: 0, batch_id: 600, loss is: [0.14111154], acc is: [0.984375]\n", + "epoch: 0, batch_id: 700, loss is: [0.07322718], acc is: [0.984375]\n", + "epoch: 0, batch_id: 800, loss is: [0.2417614], acc is: [0.984375]\n", + "epoch: 0, batch_id: 900, loss is: [0.10721541], acc is: [1.]\n", + "epoch: 1, batch_id: 0, loss is: [0.02449418], acc is: [1.]\n", + "epoch: 1, batch_id: 100, loss is: [0.151768], acc is: [0.984375]\n", + "epoch: 1, batch_id: 200, loss is: [0.06956144], acc is: [0.984375]\n", + "epoch: 1, batch_id: 300, loss is: [0.2008793], acc is: [1.]\n", + "epoch: 1, batch_id: 400, loss is: [0.03839134], acc is: [1.]\n", + "epoch: 1, batch_id: 500, loss is: [0.0217573], acc is: [1.]\n", + "epoch: 1, batch_id: 600, loss is: [0.10977131], acc is: [0.984375]\n", + "epoch: 1, batch_id: 700, loss is: [0.02774046], acc is: [1.]\n", + "epoch: 1, batch_id: 800, loss is: [0.13530938], acc is: [0.984375]\n", + "epoch: 1, batch_id: 900, loss is: [0.0282761], acc is: [1.]\n" + ] + } + ], + "source": [ + "import paddle\n", + "train_loader = paddle.io.DataLoader(train_dataset, places=paddle.CPUPlace(), batch_size=64)\n", + "def train(model):\n", + " model.train()\n", + " epochs = 2\n", + " batch_size = 64\n", + " optim = paddle.optimizer.Adam(learning_rate=0.001, parameter_list=model.parameters())\n", + " for epoch in range(epochs):\n", + " for batch_id, data in enumerate(train_loader()):\n", + " x_data = data[0]\n", + " y_data = data[1]\n", + " predicts = model(x_data)\n", + " loss = paddle.nn.functional.cross_entropy(predicts, y_data)\n", + " acc = 
paddle.metric.accuracy(predicts, y_data, k=2)\n", + " avg_loss = paddle.mean(loss)\n", + " avg_acc = paddle.mean(acc)\n", + " avg_loss.backward()\n", + " if batch_id % 100 == 0:\n", + " print(\"epoch: {}, batch_id: {}, loss is: {}, acc is: {}\".format(epoch, batch_id, avg_loss.numpy(), avg_acc.numpy()))\n", + " optim.minimize(avg_loss)\n", + " model.clear_gradients()\n", + "model = LeNet()\n", + "train(model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 对模型进行验证" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "batch_id: 0, loss is: [0.0054796], acc is: [1.]\n", + "batch_id: 100, loss is: [0.12248081], acc is: [0.984375]\n", + "batch_id: 200, loss is: [0.06583288], acc is: [1.]\n", + "batch_id: 300, loss is: [0.07927508], acc is: [1.]\n", + "batch_id: 400, loss is: [0.02623187], acc is: [1.]\n", + "batch_id: 500, loss is: [0.02039231], acc is: [1.]\n", + "batch_id: 600, loss is: [0.03374948], acc is: [1.]\n", + "batch_id: 700, loss is: [0.05141395], acc is: [1.]\n", + "batch_id: 800, loss is: [0.1005884], acc is: [1.]\n", + "batch_id: 900, loss is: [0.03581202], acc is: [1.]\n" + ] + } + ], + "source": [ + "import paddle\n", + "test_loader = paddle.io.DataLoader(test_dataset, places=paddle.CPUPlace(), batch_size=64)\n", + "def test(model):\n", + " model.eval()\n", + " batch_size = 64\n", + " for batch_id, data in enumerate(train_loader()):\n", + " x_data = data[0]\n", + " y_data = data[1]\n", + " predicts = model(x_data)\n", + " loss = paddle.nn.functional.cross_entropy(predicts, y_data)\n", + " acc = paddle.metric.accuracy(predicts, y_data, k=2)\n", + " avg_loss = paddle.mean(loss)\n", + " avg_acc = paddle.mean(acc)\n", + " avg_loss.backward()\n", + " if batch_id % 100 == 0:\n", + " print(\"batch_id: {}, loss is: {}, acc is: {}\".format(batch_id, avg_loss.numpy(), avg_acc.numpy()))\n", + "test(model)" + ] + }, + { + 
"cell_type": "markdown", + "metadata": {}, + "source": [ + "### 组网&训练方式2结束\n", + "以上就是组网&训练方式2,通过这种方式,可以清楚地看到训练和测试中的每一步过程。" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 总结\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "以上就是用LeNet对手写数字数据集MNIST进行分类。本示例提供了两种训练模型的方式,一种可以快速完成模型的组建与预测,非常适合新手用户上手。另一种则需要多个步骤来完成模型的训练,适合进阶用户使用。" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/paddle2.0_docs/n_gram_model/n_gram_model.ipynb b/paddle2.0_docs/n_gram_model/n_gram_model.ipynb new file mode 100644 index 00000000..1e8ab412 --- /dev/null +++ b/paddle2.0_docs/n_gram_model/n_gram_model.ipynb @@ -0,0 +1,344 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 用N-Gram模型在莎士比亚诗中训练word embedding\n", + "N-gram 是计算机语言学和概率论范畴内的概念,是指给定的一段文本中N个项目的序列。\n", + "N=1 时 N-gram 又称为 unigram,N=2 称为 bigram,N=3 称为 trigram,以此类推。实际应用通常采用 bigram 和 trigram 进行计算。\n", + "本示例在莎士比亚十四行诗上实现了trigram。" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 环境\n", + "本教程基于paddle2.0-alpha编写,如果您的环境不是本版本,请先安装paddle2.0-alpha。" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'2.0.0-alpha0'" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import paddle\n", + "paddle.__version__" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 数据集&&相关参数\n", + "训练数据集采用了莎士比亚十四行诗,CONTEXT_SIZE设为2,意味着是trigram。EMBEDDING_DIM设为10。" + ] + }, + { + "cell_type": "code", + 
"execution_count": 57, + "metadata": {}, + "outputs": [], + "source": [ + "CONTEXT_SIZE = 2\n", + "EMBEDDING_DIM = 10\n", + "\n", + "test_sentence = \"\"\"When forty winters shall besiege thy brow,\n", + "And dig deep trenches in thy beauty's field,\n", + "Thy youth's proud livery so gazed on now,\n", + "Will be a totter'd weed of small worth held:\n", + "Then being asked, where all thy beauty lies,\n", + "Where all the treasure of thy lusty days;\n", + "To say, within thine own deep sunken eyes,\n", + "Were an all-eating shame, and thriftless praise.\n", + "How much more praise deserv'd thy beauty's use,\n", + "If thou couldst answer 'This fair child of mine\n", + "Shall sum my count, and make my old excuse,'\n", + "Proving his beauty by succession thine!\n", + "This were to be new made when thou art old,\n", + "And see thy blood warm when thou feel'st it cold.\"\"\".split()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 数据预处理\n", + "将文本被拆成了元组的形式,格式为(('第一个词', '第二个词'), '第三个词');其中,第三个词就是我们的目标。" + ] + }, + { + "cell_type": "code", + "execution_count": 58, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[(('When', 'forty'), 'winters'), (('forty', 'winters'), 'shall'), (('winters', 'shall'), 'besiege')]\n" + ] + } + ], + "source": [ + "trigram = [((test_sentence[i], test_sentence[i + 1]), test_sentence[i + 2])\n", + " for i in range(len(test_sentence) - 2)]\n", + "\n", + "vocab = set(test_sentence)\n", + "word_to_idx = {word: i for i, word in enumerate(vocab)}\n", + "idx_to_word = {word_to_idx[word]: word for word in word_to_idx}\n", + "# 看一下数据集\n", + "print(trigram[:3])\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 构建`Dataset`类 加载数据" + ] + }, + { + "cell_type": "code", + "execution_count": 59, + "metadata": {}, + "outputs": [], + "source": [ + "import paddle\n", + "class TrainDataset(paddle.io.Dataset):\n", + " def __init__(self, tuple_data, 
vocab):\n", + " self.tuple_data = tuple_data\n", + " self.vocab = vocab\n", + "\n", + " def __getitem__(self, idx):\n", + " data = list(self.tuple_data[idx][0])\n", + " label = list(self.tuple_data[idx][1])\n", + " return data, label\n", + " \n", + " def __len__(self):\n", + " return len(self.tuple_data)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 组网&训练\n", + "这里用paddle动态图的方式组网,由于是N-Gram模型,只需要一层`Embedding`与两层`Linear`就可以完成网络模型的构建。" + ] + }, + { + "cell_type": "code", + "execution_count": 79, + "metadata": {}, + "outputs": [], + "source": [ + "import paddle\n", + "import numpy as np\n", + "class NGramModel(paddle.nn.Layer):\n", + " def __init__(self, vocab_size, embedding_dim, context_size):\n", + " super(NGramModel, self).__init__()\n", + " self.embedding = paddle.nn.Embedding(size=[vocab_size, embedding_dim])\n", + " self.linear1 = paddle.nn.Linear(context_size * embedding_dim, 128)\n", + " self.linear2 = paddle.nn.Linear(128, vocab_size)\n", + "\n", + " def forward(self, x):\n", + " x = self.embedding(x)\n", + " x = paddle.reshape(x, [1, -1])\n", + " x = self.linear1(x)\n", + " x = paddle.nn.functional.relu(x)\n", + " x = self.linear2(x)\n", + " x = paddle.nn.functional.softmax(x)\n", + " return x" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 初始化Model,并定义相关的参数。" + ] + }, + { + "cell_type": "code", + "execution_count": 121, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch: 0, loss is: [4.631529]\n", + "epoch: 50, loss is: [4.6081576]\n", + "epoch: 100, loss is: [4.600631]\n", + "epoch: 150, loss is: [4.603069]\n", + "epoch: 200, loss is: [4.592647]\n", + "epoch: 250, loss is: [4.5626693]\n", + "epoch: 300, loss is: [4.513106]\n", + "epoch: 350, loss is: [4.4345813]\n", + "epoch: 400, loss is: [4.3238697]\n", + "epoch: 450, loss is: [4.1728854]\n", + "epoch: 500, loss is: [3.9622664]\n", + "epoch: 550, loss is: [3.67673]\n", + "epoch: 600, 
loss is: [3.2998457]\n", + "epoch: 650, loss is: [2.8206367]\n", + "epoch: 700, loss is: [2.2514927]\n", + "epoch: 750, loss is: [1.6479329]\n", + "epoch: 800, loss is: [1.1147357]\n", + "epoch: 850, loss is: [0.73231363]\n", + "epoch: 900, loss is: [0.49481753]\n", + "epoch: 950, loss is: [0.3504072]\n" + ] + } + ], + "source": [ + "vocab_size = len(vocab)\n", + "embedding_dim = 10\n", + "context_size = 2\n", + "\n", + "paddle.enable_imperative()\n", + "losses = []\n", + "def train(model):\n", + " model.train()\n", + " optim = paddle.optimizer.SGD(learning_rate=0.001, parameter_list=model.parameters())\n", + " for epoch in range(1000):\n", + " # 留最后10组作为预测\n", + " for context, target in trigram[:-10]:\n", + " context_idxs = list(map(lambda w: word_to_idx[w], context))\n", + " x_data = paddle.imperative.to_variable(np.array(context_idxs))\n", + " y_data = paddle.imperative.to_variable(np.array([word_to_idx[target]]))\n", + " predicts = model(x_data)\n", + " # print (predicts)\n", + " loss = paddle.nn.functional.cross_entropy(predicts, y_data)\n", + " loss.backward()\n", + " optim.minimize(loss)\n", + " model.clear_gradients()\n", + " if epoch % 50 == 0:\n", + " print(\"epoch: {}, loss is: {}\".format(epoch, loss.numpy()))\n", + " losses.append(loss.numpy())\n", + "model = NGramModel(vocab_size, embedding_dim, context_size)\n", + "train(model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 打印loss下降曲线\n", + "通过可视化loss的曲线,可以看到模型训练的效果。" + ] + }, + { + "cell_type": "code", + "execution_count": 123, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[]" + ] + }, + "execution_count": 123, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAWoAAAD4CAYAAADFAawfAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/d3fzzAAAACXBIWXMAAAsTAAALEwEAmpwYAAAdkklEQVR4nO3dd3xV9f3H8dfn3uwBmUAggRghMmUFDQ5Qq4JC3VqtVmuttL/aqq3WamtdtcPaarWOuq1aR0XrAJVaxKKIIghE9oaAAQIBAiE7398fuVhUJCHk5tzxfj4e95Hcc06S98nJ451zv/cMc84hIiKhy+d1ABER2T8VtYhIiFNRi4iEOBW1iEiIU1GLiIS4mGB806ysLJefnx+Mby0iEpHmzJmzxTmXva95QSnq/Px8Zs+eHYxvLSISkcxs7dfN09CHiEiIU1GLiIQ4FbWISIhTUYuIhDgVtYhIiFNRi4iEOBW1iEiIC8px1G318PSVNDmI8Vnzw+8jPsZHVko82anxZKXEk5kSR6xf/19EJHqEVFHf9fYyauqbWlwuIzmOrJQ4slPjyU5pLvCs1HhS4mNIjPWTGOcnIdZHjM+H32dffJi1OM3s4NclxucjMbY5h7XHNxSRqBVSRV1y8xgampqob3Q0NjkaGpuorm9ky646tuyqpXxn7Vc+zlm3jfKdta0qeK8kxvpJivOTEPiYGOf//B/K3tPTEuPITIkjMyWerOTmj5kpcaQnxeH3qexFolVIFXVcjI+4fQyb98pM3u/XOeeoqmtkd20DNfXN5V5d30hDYxONTY5G11z8X3nsa7pztMdNb5r/yQSy1DVQXd/I7rpGagIfq+saqaiqY8O2wPP6RnZU19PY9NUf7rPmVxGZyfF06RRPbnoiuelJ9EhLJDc9kR7piXRJTVCZi0SokCrqtjIzUuJjSIkP79VpanLsqK5na1Xt568itu6qY+uuWrZU1bFlZy2bKmt4u6ySLbvqvvC1sX4jp3OguNMCRZ6eGCj1RLp3TsSnIhcJS+HdbBHG5zPSk+NIT46jd5f9L1td18iG7dWs37Y78LGaDduan09fXs7mnbVfeGWQGOunsFsqfbum0jcnlcO6pdI/pxNpSXHBXSkROWgq6jCVGOend5cUendJ2ef82oZGyrbXsGF7NaUVu1m2aRdLNlby9uJNvDC79PPl8jOTGJyXxuDcNAbnpTGgeycSYv0dtRoi0goq6ggVH+MnPyuZ/Kwvju875yjfVcuSsp0s+GwHJaU7mLW6glfnfQY0HxrZL6cTg/M6MyI/gxH5GXRPS/RiFUQkwFx7vHP2JUVFRU7Xow4vmyprmFe6nfml25m/fjvzS3ewq7YBgNz0RI7Iz2DEIRkc0zuLvIwkj9OKRB4zm+OcK9rnPBW17Etjk2NxWSWzVlfw8ZoKZq2uYGtV8xuYBdnJjC7MZlRhNiMLMjVUItIOVNRy0JxzrCzfxfRlW/jvsnI+XLWV2oYmEmP9HN83mzEDunFC3y6kJsR6HVUkLKmopd3V1Dfy4aqt/GfxJqYs3ET5zlri/D6O6ZPFqYNyOKl/VzonqrRFWktFLUHV1OSYW7qNtxZs5M0FG1m/rZo4v49RhVmMOzyHE/t11Z62SAtU1NJhnHPMX7+DySWfMbmkjM921BAX42N0YTbjD8/hG/26hv2JSSLBoKIWTzTvaW9nckkZb3xaxsbKGuJjfIwd2I1zhudy1KFZOu1dJEBFLZ5ranLMWbeNV+dt4LV5n1FZ00BO5wTOGtaDc4bncUjW/q/nIhLpVNQSUmrqG5m6eDMT55Ty32XlNDkYVZjNpUfnM7pPtq5JIlFJRS0ha1NlDS98XMozH65l885aCrKSueSofM4enquxbIkqKmoJeXUNTby5oIzHZ6xhful2UhNiuHhkLy47poCMZF04SiKfilrCytx123jkvVW8uWAjCTF+LiruyeWjCuiSmuB1NJGgUVFLWFq+aScPvLuSV+dtIMbv4zvFvbj
i+N7aw5aIpKKWsLZmSxX3TVvBy5+sJzkuhh+MLuB7xxxCUpzGsCVyqKglIizbtJM7pyzl7UWbyE6N5/qxfTlzaA8dJSIRYX9F/dUbFIqEqMKuqTxycRETfziS7mmJXPPifM59aCYLNuzwOppIUKmoJewU5Wfwr/87ij+efThrtlRx2n3vc+Mrn7J9d13LXywShlTUEpZ8PuO8EXm8c+1xXDwyn2c/Wsfxf3qXf35cSjCG80S8pKKWsNY5MZZbThvA5CuPpU+XVK57qYTvPDaL0ordXkcTaTetLmoz85vZXDObFMxAIm3RL6cTz08o5vYzBjJ33TZOvns6T8xYTVOT9q4l/B3IHvVVwOJgBRE5WD6fcVFxL/79s9EccUgGt76+iHMfmsmKzbu8jiZyUFpV1GaWC4wDHg1uHJGD1yMtkScvHcFd5w1mZfkuTr33Pe6ftoL6xiavo4m0SWv3qP8CXAd87V+6mU0ws9lmNru8vLw9som0mZlx1rBc3v7paE7s14U7pyzl7Ac/YPWWKq+jiRywFovazMYDm51zc/a3nHPuYedckXOuKDs7u90CihyM7NR4HrhwOA9cOIy1W3cz7t73mDhnvY4MkbDSmj3qo4HTzGwN8Dxwgpk9E9RUIu3s1EE5vHnVsQzq0ZlrX5zPlc/PY2dNvdexRFqlxaJ2zt3gnMt1zuUD5wPvOOcuCnoykXbWPS2RZy8v5tqTC5lc8hmn3TeDJRsrvY4l0iIdRy1Rxe8zfnxCH569vJhdtQ2ccf8MXpxd6nUskf06oKJ2zr3rnBsfrDAiHaW4IJPJVx7D0Lx0fj6xhF+/skBHhUjI0h61RK0uqQk88/0j+cGoAp7+cC3ffWIWO3Zr3FpCj4paoprfZ9xwaj/uPOdwZq2u4IwHZrCyXCfISGhRUYsA5xbl8dzlxVRW13Pm/TP4cNVWryOJfE5FLRJQlJ/BK1ccTZdOCVz8+CzeXrTJ60gigIpa5AvyMpJ48Qcj6ZfTiR8+M4eJc9Z7HUlERS3yZenJcTz7/SM56tBMrn1xPn//YI3XkSTKqahF9iE5PoZHLyni5P5dufm1hTwxY7XXkSSKqahFvkZ8jJ/7LxzGmAFdufX1RTz+vspavKGiFtmPWL+P+749jFMGduO2SYt49L1VXkeSKKSiFmlBrN/HvRcM5dRB3bh98mKe1DCIdLAYrwOIhINYv497zh9KQ+Mn3PL6IpLjYzi3KM/rWBIltEct0kqxfh9//fZQju2TxS9eKmFySZnXkSRKqKhFDkB8jJ+HvjOcYT3TufqFuUxbstnrSBIFVNQiBygpLobHLx3BYd1S+eEzc3S6uQSdilqkDTolxPL3S48gLyOJy578mHml272OJBFMRS3SRpkp8Txz2ZFkpMRxyeOzWLpxp9eRJEKpqEUOQrfOCTz7/WLiY3xc+sQsNlXWeB1JIpCKWuQg5WUk8fh3R7C9up7vPfkxVbUNXkeSCKOiFmkHA3t05v5vD2NxWSU/eW4uDbqtl7QjFbVIOzm+bxduO30g7yzZzC2vL8Q553UkiRA6M1GkHV1U3IvSbbt56L+r6JmRxIRRh3odSSKAilqknf1iTF/WV1TzuzeW0CszmTEDunkdScKchj5E2pnPZ/z5vMEMzu3MNf+cz4rNOmxPDo6KWiQIEmL9PHjRcOJjfEx4eg47a+q9jiRhTEUtEiTd0xK5/8JhrN26m5/9cz5NTXpzUdpGRS0SRMUFmdw4rh9vL9rEfdNWeB1HwpSKWiTIvntUPmcN7cHd/1nG1MWbvI4jYUhFLRJkZsbvzhpE/5xOXP38PFaV7/I6koQZFbVIB0iIbb6OdYzfuOLZudTUN3odScKIilqkg+SmJ/Hn8wazuKyS2ycv8jqOhBEVtUgHOqFvVyaMKuCZD9cxqeQzr+NImFBRi3Swn485jKE907jhpU9Zu7XK6zgSBlTUIh0s1u/j3vOHYgY/eW4u9brSnrRARS3igbyMJO44+3BK1u/gr+/
o+GrZPxW1iEdOGZTDWcN6cP+0FXyybpvXcSSEqahFPHTLaQPo1imBn70wj911ujOM7FuLRW1mCWY2y8zmm9lCM7u1I4KJRINOCbH8+bzBrK3YzW8nL/Y6joSo1uxR1wInOOcGA0OAsWZWHNRUIlGkuCCTy48t4B8frWPaks1ex5EQ1GJRu2Z7znmNDTx0GTCRdnTNyYX07ZbKzyeWUFFV53UcCTGtGqM2M7+ZzQM2A2875z7axzITzGy2mc0uLy9v55gikS0+xs/d3xpCZXU9N7xcovstyhe0qqidc43OuSFALnCEmQ3cxzIPO+eKnHNF2dnZ7RxTJPL1y+nEtWMKmbJwE/+au8HrOBJCDuioD+fcdmAaMDYoaUSi3GXHFDC8Vzq3TVpE+c5ar+NIiGjNUR/ZZpYW+DwROAlYEuRcIlHJ7zPuOHsQu2sbueW1hV7HkRDRmj3qHGCamZUAH9M8Rj0puLFEolfvLqlcdWIfJn9axlsLNnodR0JATEsLOOdKgKEdkEVEAiaMKmBySRm/fnUBIwsy6ZwU63Uk8ZDOTBQJQbF+H38853Aqqur4ja5dHfVU1CIhamCPzvxwdAET56xn+jId8hrNVNQiIewnJ/Th0Oxkbnj5U6pqdS2QaKWiFglhCbF+7jj7cDZsr+beqcu9jiMeUVGLhLii/AzOH5HHY++vZunGnV7HEQ+oqEXCwHVj+5KSEMOvX1mg08ujkIpaJAxkJMdx/di+zFpTwcuf6PTyaKOiFgkT5xXlMbRnGr9/czE7dtd7HUc6kIpaJEz4fMbtZwykoqqOP/17qddxpAOpqEXCyIDunbl4ZD7PfLSWkvXbvY4jHURFLRJmfnZyIVkp8dz4ygIam/TGYjRQUYuEmU4Jsdw4rh8l63fw7Kx1XseRDqCiFglDpw3uzlGHZnLnW0vYskvXrY50KmqRMGRm3Hb6QKrrG/n9G7o8fKRTUYuEqd5dUrj82AJe+mQ9H63a6nUcCSIVtUgY+8kJfeiRlshNry6kobHJ6zgSJCpqkTCWGOfn1+P7sXTTTp77uNTrOBIkKmqRMDdmQDeKCzK4699LdcZihFJRi4Q5M+Om8QPYUV3PX6Yu8zqOBIGKWiQC9O/eifOP6MlTM9eyYrMuhRppVNQiEeKakwpJivNz26TFuhRqhFFRi0SIzJR4rvpGH6YvK2fa0s1ex5F2pKIWiSAXj8ynIDuZ2yctpq5Bh+tFChW1SASJi/Hx63H9WbWliqc/XOt1HGknKmqRCHN83y4c2yeLe6cu1+F6EUJFLRKBfnlqPypr6rlvmu5cHglU1CIRqF9OJ84dnsvfP1jLuq27vY4jB0lFLRKhrjn5MPw+444purpeuFNRi0Sorp0SmDCqgMklZcxZu83rOHIQVNQiEWzCqAKyU+P53Rs6CSacqahFIlhyfAzXnFTInLXbeGvBRq/jSBupqEUi3LlFeRzWNZU/vLVEJ8GEKRW1SITz+4xfjuvH2q27dRJMmFJRi0SB0YXZOgkmjKmoRaKEToIJXypqkSihk2DCV4tFbWZ5ZjbNzBaZ2UIzu6ojgolI+9NJMOGpNXvUDcA1zrn+QDFwhZn1D24sEQkGnQQTnlosaudcmXPuk8DnO4HFQI9gBxOR4NBJMOHngMaozSwfGAp8tI95E8xstpnNLi8vb6d4ItLe9j4J5k2dBBMWWl3UZpYCvARc7Zyr/PJ859zDzrki51xRdnZ2e2YUkXb2+Ukwb+okmHDQqqI2s1iaS/ofzrmXgxtJRIJtz0kw6yp289TMNV7HkRa05qgPAx4DFjvn7gp+JBHpCKMLsxldmM09U5dTUVXndRzZj9bsUR8NfAc4wczmBR6nBjmXiHSAX43rR1VtA/dO1UkwoSympQWcc+8D1gFZRKSDFXZN5YIjevL0h2u5qLgXvbukeB1J9kFnJopEuZ+eVEhSrJ/fv7HY6yjyNVTUIlEuKyWeK07ozdQlm3l/+Rav48g+qKhFhO8elU9ueiK
3T15EY5NOggk1KmoRISHWz/Wn9GXJxp1MnFPqdRz5EhW1iAAwblAOw3ulc+eUZeyqbfA6juxFRS0iAJgZN47rx5Zdtfzt3ZVex5G9qKhF5HNDe6Zz+pDuPPLeKjZsr/Y6jgSoqEXkC64b2xeAP76la1aHChW1iHxBj7REvn/sIbw67zPmrtM1q0OBilpEvuL/jutNVko8t0/WNatDgYpaRL4iJT6Ga09uvmb15E/LvI4T9VTUIrJP5xbl0bdbKr9/YwnVdY1ex4lqKmoR2Se/z7j1tAFs2F7Ng++u8DpOVFNRi8jXOrIgk9OHdOdv01exdmuV13GilopaRPbrl6f2I9Zn/GbSIq+jRC0VtYjsV9dOCVz5jT78Z/Fm3lmyyes4UUlFLSItuvToQyjITubW1xdRU683FjuailpEWhQX4+OWbw5g7dbdPPb+aq/jRB0VtYi0yqjCbMYO6MZ976zgM10HpEOpqEWk1W4c348m5/itbtvVoVTUItJquelJXHF8byaXlDF9WbnXcaKGilpEDsiEUQUUZCXzq1c+1RmLHURFLSIHJCHWz2/PHERpRTX3TF3udZyooKIWkQM28tBMzh2eyyPvrWJxWaXXcSKeilpE2uSXp/ajc2IsN7z8qe5cHmQqahFpk/TkOG4a3595pdt55sO1XseJaCpqEWmz04d059g+Wdzx1hJKK3Z7HSdiqahFpM3MjN+fNQifGddNLKFJQyBBoaIWkYOSm57EjeP6MXPVVp7WEEhQqKhF5KB9a0Qeowuz+cObS1izRdetbm8qahE5aGbGH84eRIzf+PnE+RoCaWcqahFpFzmdE7n5mwP4eM02Hn1/lddxIoqKWkTazdnDejBmQFfunLKUBRt2eB0nYqioRaTdmBl/OOtwMpPjufK5uVTVNngdKSKoqEWkXaUnx3H3t4awemsVt72u+yy2BxW1iLS7kYdm8qPjDuWF2aVMLinzOk7Ya7GozexxM9tsZgs6IpCIRIarTyxkaM80fvFSCSvLd3kdJ6y1Zo/6SWBskHOISISJ9fu4/9vDiI/x8cOn57BL49Vt1mJRO+emAxUdkEVEIkz3tET+esFQVpbv4rqJ83FOx1e3hcaoRSSojuqdxfWn9OWNTzfy8HQdX90W7VbUZjbBzGab2ezyct1LTUT+5/JjCxg3KIc73lrCjBVbvI4TdtqtqJ1zDzvnipxzRdnZ2e31bUUkApgZfzzncHp3SeFH//hEby4eIA19iEiHSI6P4bFLRhDrNy594mO27qr1OlLYaM3hec8BM4HDzGy9mV0W/FgiEonyMpJ45OIiNlXW8P2nZlNTr7uYt0Zrjvq4wDmX45yLdc7lOuce64hgIhKZhvZM557zhzCvdDtXPT+XhsYmryOFPA19iEiHGzswh5vG92fKwk26M0wrxHgdQESi06VHH0JVbQN/+vcyEuL8/PaMgZiZ17FCkopaRDxzxfG9qapr5MF3V5IY6+fGcf1U1vugohYRz5gZ1405jOq6Rh57fzVNznHT+P4q6y9RUYuIp8yMm7/ZHzN4YsYaqusa+e2Zg/D7VNZ7qKhFxHNmxk3j+5McF8N901awu66RP583mFi/jncAFbWIhAgz49oxh5EU7+ePby1la1UtD1w4nM6JsV5H85z+XYlISPnRcb3507mDmbW6grMf/IDSit1eR/KcilpEQs45w3N56ntHsrmyhjMfmMGctdu8juQpFbWIhKSRh2by8o+OJjk+hvMfnsmTM1ZH7fWsVdQiErJ6d0nhtSuOYXRhNre8vogrn58XlXc2V1GLSEjrnBTLw98p4udjDmNyyWeM/+v7zF0XXUMhKmoRCXk+n3HF8b159vJi6hqaOOdvM7nr7WXUR8kFnVTUIhI2igsyefPqYzl9SHfunbqcsx74gAUbdngdK+hU1CISVjolxHLXeUN48MJhlO2o4bT73ufW1xeys6be62hBo6IWkbB0yqAcpl4zmguP7MWTH6zhxLv+y6vzNkT
kJVNV1CIStjonxvKbMwbyrx8dTVZKPFc9P4/T75/BBxF2A10VtYiEvSF5abz+42O467zBVFTV8e1HP+K7T8yKmPFrC8YB5EVFRW727Nnt/n1FRFpSU9/IUzPXcN87K6isaeD4w7L58Qm9Gd4rw+to+2Vmc5xzRfucp6IWkUhUWVPP0zPX8tj7q6moqqO4IIMJowo4rrALvhC8hKqKWkSi1u66Bp6bVcoj01exsbKGnhlJXFTck/OK8khLivM63udU1CIS9eobm5iycCNPzVzLrNUVxMf4GH94d84e3oPiQzI938tWUYuI7GVxWSVPf7iW1+Z9xq7aBnqkJXLG0O6cOTSX3l1SPMmkohYR2YfqukbeXryJlz9Zz/Rl5TQ5KOyawtgB3RgzsBv9czp12P0bVdQiIi3YXFnD5E/LmLJwI7NWV9DkIC8jkRP7dWVUYTbFh2SSGOcP2s9XUYuIHIAtu2r5z6JNvLVwIzNXbqW2oYm4GB9H5GcwqjCLo3tn0bdbp3a9Aa+KWkSkjWrqG5m1uoLpy8qZvrycZZt2AZCaEMPwXumMyM9gRH4Gh+d2JiG27Xvc+ytq3dxWRGQ/EmL9jCrMZlRhNgAbd9Tw4aqtzFpTwcerK3h36VIA4vw+huSl8fyE4nY/gkRFLSJyALp1TuCMoT04Y2gPALZV1TF77TZmr6lgR3V9UA7zU1GLiByE9OQ4TurflZP6dw3az9BFmUREQpyKWkQkxKmoRURCnIpaRCTEqahFREKcilpEJMSpqEVEQpyKWkQkxAXlWh9mVg6sbeOXZwGRdQvhlmmdo4PWOfIdzPr2cs5l72tGUIr6YJjZ7K+7MEmk0jpHB61z5AvW+mroQ0QkxKmoRURCXCgW9cNeB/CA1jk6aJ0jX1DWN+TGqEVE5ItCcY9aRET2oqIWEQlxIVPUZjbWzJaa2Qozu97rPO3FzPLMbJqZLTKzhWZ2VWB6hpm9bWbLAx/TA9PNzO4N/B5KzGyYt2vQdmbmN7O5ZjYp8PwQM/sosG4vmFlcYHp84PmKwPx8T4O3kZmlmdlEM1tiZovNbGSkb2cz+2ng73qBmT1nZgmRtp3N7HEz22xmC/aadsDb1cwuCSy/3MwuOZAMIVHUZuYH7gdOAfoDF5hZf29TtZsG4BrnXH+gGLgisG7XA1Odc32AqYHn0Pw76BN4TAAe7PjI7eYqYPFez+8A7nbO9Qa2AZcFpl8GbAtMvzuwXDi6B3jLOdcXGEzzukfsdjazHsCVQJFzbiDgB84n8rbzk8DYL007oO1qZhnAzcCRwBHAzXvKvVWcc54/gJHAlL2e3wDc4HWuIK3rq8BJwFIgJzAtB1ga+Pwh4IK9lv98uXB6ALmBP+ATgEmA0XzGVsyXtzkwBRgZ+DwmsJx5vQ4HuL6dgdVfzh3J2xnoAZQCGYHtNgkYE4nbGcgHFrR1uwIXAA/tNf0Ly7X0CIk9av63wfdYH5gWUQIv9YYCHwFdnXNlgVkbgT03XIuU38VfgOuApsDzTGC7c64h8Hzv9fp8nQPzdwSWDyeHAOXAE4HhnkfNLJkI3s7OuQ3An4B1QBnN220Okb2d9zjQ7XpQ2ztUijrimVkK8BJwtXOucu95rvlfbMQcJ2lm44HNzrk5XmfpQDHAMOBB59xQoIr/vRwGInI7pwOn0/xPqjuQzFeHCCJeR2zXUCnqDUDeXs9zA9MigpnF0lzS/3DOvRyYvMnMcgLzc4DNgemR8Ls4GjjNzNYAz9M8/HEPkGZmMYFl9l6vz9c5ML8zsLUjA7eD9cB659xHgecTaS7uSN7OJwKrnXPlzrl64GWat30kb+c9DnS7HtT2DpWi/hjoE3i3OI7mNyRe8zhTuzAzAx4DFjvn7tpr1mvAnnd+L6F57HrP9IsD7x4XAzv2eokVFpxzNzjncp1z+TRvy3eccxcC04BzAot9eZ33/C7OCSwfVnuezrmNQKmZHRaY9A1gERG8nWke8ig2s6TA3/medY7Y7byXA92
uU4CTzSw98Erk5MC01vF6kH6vwfVTgWXASuBXXudpx/U6huaXRSXAvMDjVJrH5qYCy4H/ABmB5Y3mI2BWAp/S/I665+txEOt/HDAp8HkBMAtYAbwIxAemJwSerwjML/A6dxvXdQgwO7CtXwHSI307A7cCS4AFwNNAfKRtZ+A5msfg62l+5XRZW7Yr8L3Auq8ALj2QDDqFXEQkxIXK0IeIiHwNFbWISIhTUYuIhDgVtYhIiFNRi4iEOBW1iEiIU1GLiIS4/wecQTmUPnjbjwAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "import matplotlib.ticker as ticker\n", + "%matplotlib inline\n", + "\n", + "plt.figure()\n", + "plt.plot(losses)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 预测\n", + "用训练好的模型进行预测。" + ] + }, + { + "cell_type": "code", + "execution_count": 127, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "the input words is: praise., How\n", + "the predict words is: much\n", + "the true words is: much\n" + ] + } + ], + "source": [ + "import random\n", + "def test(model):\n", + " model.eval()\n", + " # 从最后10组数据中随机选取1个\n", + " idx = random.randint(len(trigram)-10, len(trigram)-1)\n", + " print('the input words is: ' + trigram[idx][0][0] + ', ' + trigram[idx][0][1])\n", + " x_data = list(map(lambda w: word_to_idx[w], trigram[idx][0]))\n", + " x_data = paddle.imperative.to_variable(np.array(x_data))\n", + " predicts = model(x_data)\n", + " predicts = predicts.numpy().tolist()[0]\n", + " predicts = predicts.index(max(predicts))\n", + " print('the predict words is: ' + idx_to_word[predicts])\n", + " y_data = trigram[idx][1]\n", + " print('the true words is: ' + y_data)\n", + "test(model)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.7.3 64-bit", + "language": "python", + "name": "python_defaultSpec_1598180286976" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.3-final" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/paddle2.0_docs/text_generation/text_generation_paddle.ipynb b/paddle2.0_docs/text_generation/text_generation_paddle.ipynb new file mode 100644 index 00000000..fc47a419 --- /dev/null +++ 
b/paddle2.0_docs/text_generation/text_generation_paddle.ipynb @@ -0,0 +1,526 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 基于GRU的Text Generation\n", + "文本生成是NLP领域中的重要组成部分,基于GRU,我们可以快速构建文本生成模型。" + ] + }, + { + "cell_type": "code", + "execution_count": 74, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'2.0.0-alpha0'" + ] + }, + "execution_count": 74, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import paddle\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "\n", + "paddle.__version__" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 复现过程\n", + "## 1.下载数据\n", + "文件路径:https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt\n", + "保存为txt格式即可" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2.读取数据" + ] + }, + { + "cell_type": "code", + "execution_count": 60, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Length of text: 1115394 characters\n" + ] + } + ], + "source": [ + "# 文件路径\n", + "path_to_file = './shakespeare.txt'\n", + "text = open(path_to_file, 'rb').read().decode(encoding='utf-8')\n", + "\n", + "# 文本长度是指文本中的字符个数\n", + "print ('Length of text: {} characters'.format(len(text)))" + ] + }, + { + "cell_type": "code", + "execution_count": 61, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "First Citizen:\n", + "Before we proceed any further, hear me speak.\n", + "\n", + "All:\n", + "Speak, speak.\n", + "\n", + "First Citizen:\n", + "You are all resolved rather to die than to famish?\n", + "\n", + "All:\n", + "Resolved. 
resolved.\n", + "\n", + "First Citizen:\n", + "First, you know Caius Marcius is chief enemy to the people.\n", + "\n" + ] + } + ], + "source": [ + "# 看一看文本中的前 250 个字符\n", + "print(text[:250])" + ] + }, + { + "cell_type": "code", + "execution_count": 62, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "65 unique characters\n" + ] + } + ], + "source": [ + "# 文本中的非重复字符\n", + "vocab = sorted(set(text))\n", + "print ('{} unique characters'.format(len(vocab)))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3.向量化文本\n", + "在训练之前,我们需要将字符串映射到数字表示值。创建两个查找表格:一个将字符映射到数字,另一个将数字映射到字符。" + ] + }, + { + "cell_type": "code", + "execution_count": 63, + "metadata": {}, + "outputs": [], + "source": [ + "# 创建从非重复字符到索引的映射\n", + "char2idx = {u:i for i, u in enumerate(vocab)}\n", + "idx2char = np.array(vocab)\n", + "# 用index表示文本\n", + "text_as_int = np.array([char2idx[c] for c in text])\n" + ] + }, + { + "cell_type": "code", + "execution_count": 64, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'\\n': 0, ' ': 1, '!': 2, '$': 3, '&': 4, \"'\": 5, ',': 6, '-': 7, '.': 8, '3': 9, ':': 10, ';': 11, '?': 12, 'A': 13, 'B': 14, 'C': 15, 'D': 16, 'E': 17, 'F': 18, 'G': 19, 'H': 20, 'I': 21, 'J': 22, 'K': 23, 'L': 24, 'M': 25, 'N': 26, 'O': 27, 'P': 28, 'Q': 29, 'R': 30, 'S': 31, 'T': 32, 'U': 33, 'V': 34, 'W': 35, 'X': 36, 'Y': 37, 'Z': 38, 'a': 39, 'b': 40, 'c': 41, 'd': 42, 'e': 43, 'f': 44, 'g': 45, 'h': 46, 'i': 47, 'j': 48, 'k': 49, 'l': 50, 'm': 51, 'n': 52, 'o': 53, 'p': 54, 'q': 55, 'r': 56, 's': 57, 't': 58, 'u': 59, 'v': 60, 'w': 61, 'x': 62, 'y': 63, 'z': 64}\n" + ] + } + ], + "source": [ + "print(char2idx)" + ] + }, + { + "cell_type": "code", + "execution_count": 65, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['\\n' ' ' '!' '$' '&' \"'\" ',' '-' '.' 
'3' ':' ';' '?' 'A' 'B' 'C' 'D' 'E'\n", + " 'F' 'G' 'H' 'I' 'J' 'K' 'L' 'M' 'N' 'O' 'P' 'Q' 'R' 'S' 'T' 'U' 'V' 'W'\n", + " 'X' 'Y' 'Z' 'a' 'b' 'c' 'd' 'e' 'f' 'g' 'h' 'i' 'j' 'k' 'l' 'm' 'n' 'o'\n", + " 'p' 'q' 'r' 's' 't' 'u' 'v' 'w' 'x' 'y' 'z']\n" + ] + } + ], + "source": [ + "print(idx2char)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "现在,每个字符都有一个整数表示值。请注意,我们将字符映射至索引 0 至 len(vocab)." + ] + }, + { + "cell_type": "code", + "execution_count": 66, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[18 47 56 ... 45 8 0]\n", + "1115394\n" + ] + } + ], + "source": [ + "print(text_as_int)\n", + "print(len(text_as_int))" + ] + }, + { + "cell_type": "code", + "execution_count": 67, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "'First Citizen' ---- characters mapped to int ---- > [18 47 56 57 58 1 15 47 58 47 64 43 52]\n" + ] + } + ], + "source": [ + "# 显示文本首 13 个字符的整数映射\n", + "print ('{} ---- characters mapped to int ---- > {}'.format(repr(text[:13]), text_as_int[:13]))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 预测任务\n", + "给定一个字符或者一个字符序列,下一个最可能出现的字符是什么?这就是我们训练模型要执行的任务。输入进模型的是一个字符序列,我们训练这个模型来预测输出 -- 每个时间步(time step)预测下一个字符是什么。" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 创建训练样本和目标\n", + "接下来,将文本划分为样本序列。每个输入序列包含文本中的 seq_length 个字符。\n", + "\n", + "对于每个输入序列,其对应的目标包含相同长度的文本,但是向右顺移一个字符。\n", + "\n", + "将文本拆分为长度为 seq_length 的文本块。例如,假设 seq_length 为 4 而且文本为 “Hello”, 那么输入序列将为 “Hell”,目标序列将为 “ello”。" + ] + }, + { + "cell_type": "code", + "execution_count": 68, + "metadata": {}, + "outputs": [], + "source": [ + "seq_length = 100\n", + "def load_data(data, seq_length):\n", + " train_data = []\n", + " train_label = []\n", + " for i in range(len(data)//seq_length):\n", + " train_data.append(data[i*seq_length:(i+1)*seq_length])\n", + " 
train_label.append(data[i*seq_length + 1:(i+1)*seq_length+1])\n", + " return train_data, train_label\n", + "train_data, train_label = load_data(text_as_int, seq_length)" + ] + }, + { + "cell_type": "code", + "execution_count": 69, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "training data is :\n", + "First Citizen:\n", + "Before we proceed any further, hear me speak.\n", + "\n", + "All:\n", + "Speak, speak.\n", + "\n", + "First Citizen:\n", + "You\n", + "------------\n", + "training_label is:\n", + "irst Citizen:\n", + "Before we proceed any further, hear me speak.\n", + "\n", + "All:\n", + "Speak, speak.\n", + "\n", + "First Citizen:\n", + "You \n" + ] + } + ], + "source": [ + "char_list = []\n", + "label_list = []\n", + "for char_id, label_id in zip(train_data[0], train_label[0]):\n", + " char_list.append(idx2char[char_id])\n", + " label_list.append(idx2char[label_id])\n", + "\n", + "print('training data is :')\n", + "print(''.join(char_list))\n", + "print(\"------------\")\n", + "print('training_label is:')\n", + "print(''.join(label_list))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 用`paddle.batch`完成数据的加载" + ] + }, + { + "cell_type": "code", + "execution_count": 70, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "batch_size = 64\n", + "def train_reader():\n", + " for i in range(len(train_data)):\n", + " yield train_data[i], train_label[i]\n", + "batch_reader = paddle.batch(train_reader, batch_size=batch_size) " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 基于GRU构建文本生成模型" + ] + }, + { + "cell_type": "code", + "execution_count": 71, + "metadata": {}, + "outputs": [], + "source": [ + "import paddle\n", + "import numpy as np\n", + "\n", + "vocab_size = len(vocab)\n", + "embedding_dim = 256\n", + "hidden_size = 1024\n", + "class GRUModel(paddle.nn.Layer):\n", + " def __init__(self):\n", + " 
super(GRUModel, self).__init__()\n", + " self.embedding = paddle.nn.Embedding(size=[vocab_size, embedding_dim])\n", + " self.gru = paddle.incubate.hapi.text.GRU(input_size=embedding_dim, hidden_size=hidden_size)\n", + " self.linear1 = paddle.nn.Linear(hidden_size, hidden_size//2)\n", + " self.linear2 = paddle.nn.Linear(hidden_size//2, vocab_size)\n", + " def forward(self, x):\n", + " x = self.embedding(x)\n", + " x = paddle.reshape(x, [-1, 1, embedding_dim])\n", + " x, _ = self.gru(x)\n", + " x = paddle.reshape(x, [-1, hidden_size])\n", + " x = self.linear1(x)\n", + " x = paddle.nn.functional.relu(x)\n", + " x = self.linear2(x)\n", + " x = paddle.nn.functional.softmax(x)\n", + " return x" + ] + }, + { + "cell_type": "code", + "execution_count": 72, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch: 0, batch: 50, loss is: [3.7835407]\n", + "epoch: 0, batch: 100, loss is: [3.2774005]\n", + "epoch: 0, batch: 150, loss is: [3.2576294]\n", + "epoch: 1, batch: 50, loss is: [3.3434656]\n", + "epoch: 1, batch: 100, loss is: [2.9948606]\n", + "epoch: 1, batch: 150, loss is: [3.0285468]\n", + "epoch: 2, batch: 50, loss is: [3.133882]\n", + "epoch: 2, batch: 100, loss is: [2.7811327]\n", + "epoch: 2, batch: 150, loss is: [2.8133557]\n", + "epoch: 3, batch: 50, loss is: [3.000814]\n", + "epoch: 3, batch: 100, loss is: [2.6404488]\n", + "epoch: 3, batch: 150, loss is: [2.7050896]\n", + "epoch: 4, batch: 50, loss is: [2.9289591]\n", + "epoch: 4, batch: 100, loss is: [2.5629177]\n", + "epoch: 4, batch: 150, loss is: [2.6438713]\n", + "epoch: 5, batch: 50, loss is: [2.8832304]\n", + "epoch: 5, batch: 100, loss is: [2.5137548]\n", + "epoch: 5, batch: 150, loss is: [2.5926144]\n", + "epoch: 6, batch: 50, loss is: [2.8562953]\n", + "epoch: 6, batch: 100, loss is: [2.4752126]\n", + "epoch: 6, batch: 150, loss is: [2.5510798]\n", + "epoch: 7, batch: 50, loss is: [2.8426895]\n", + "epoch: 7, batch: 100, loss is: [2.4442513]\n", + 
"epoch: 7, batch: 150, loss is: [2.5187433]\n", + "epoch: 8, batch: 50, loss is: [2.8353484]\n", + "epoch: 8, batch: 100, loss is: [2.4200597]\n", + "epoch: 8, batch: 150, loss is: [2.4956212]\n", + "epoch: 9, batch: 50, loss is: [2.8308532]\n", + "epoch: 9, batch: 100, loss is: [2.4011066]\n", + "epoch: 9, batch: 150, loss is: [2.4787998]\n" + ] + } + ], + "source": [ + "paddle.enable_imperative()\n", + "losses = []\n", + "def train(model):\n", + " model.train()\n", + " optim = paddle.optimizer.SGD(learning_rate=0.001, parameter_list=model.parameters())\n", + " for epoch in range(10):\n", + " batch_id = 0\n", + " for batch_data in batch_reader():\n", + " batch_id += 1\n", + " data = np.array(batch_data)\n", + " x_data = data[:, 0]\n", + " y_data = data[:, 1]\n", + " for i in range(len(x_data[0])):\n", + " x_char = x_data[:, i]\n", + " y_char = y_data[:, i]\n", + " x_char = paddle.imperative.to_variable(x_char)\n", + " y_char = paddle.imperative.to_variable(y_char)\n", + " predicts = model(x_char)\n", + " loss = paddle.nn.functional.cross_entropy(predicts, y_char)\n", + " avg_loss = paddle.mean(loss)\n", + " avg_loss.backward()\n", + " optim.minimize(avg_loss)\n", + " model.clear_gradients()\n", + " if batch_id % 50 == 0:\n", + " print(\"epoch: {}, batch: {}, loss is: {}\".format(epoch, batch_id, avg_loss.numpy()))\n", + " losses.append(loss.numpy())\n", + "model = GRUModel()\n", + "train(model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 模型预测\n", + "利用训练好的模型,输出初始化文本'ROMEO: ',自动生成后续的num_generate个字符。" + ] + }, + { + "cell_type": "code", + "execution_count": 73, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ROMEO:I the the the the the the the the the the the the the the the the the the the the the the the the th\n" + ] + } + ], + "source": [ + "def generate_text(model, start_string):\n", + " \n", + " model.eval()\n", + " num_generate = 100\n", + "\n", + " # Converting our start 
string to numbers (vectorizing)\n", + " input_eval = [char2idx[s] for s in start_string]\n", + " input_data = paddle.imperative.to_variable(np.array(input_eval))\n", + " input_data = paddle.reshape(input_data, [-1, 1])\n", + " text_generated = []\n", + "\n", + " for i in range(num_generate):\n", + " predicts = model(input_data)\n", + " predicts = predicts.numpy().tolist()[0]\n", + " # print(predicts)\n", + " predicts_id = predicts.index(max(predicts))\n", + " # print(predicts_id)\n", + " # using a categorical distribution to predict the character returned by the model\n", + " input_data = paddle.imperative.to_variable(np.array([predicts_id]))\n", + " input_data = paddle.reshape(input_data, [-1, 1])\n", + " text_generated.append(idx2char[predicts_id])\n", + " return (start_string + ''.join(text_generated))\n", + "print(generate_text(model, start_string=u\"ROMEO:\"))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} From 1e9fc024b13c97a675f6e99b3b7475985ae98b40 Mon Sep 17 00:00:00 2001 From: chenlong Date: Mon, 31 Aug 2020 20:32:46 +0800 Subject: [PATCH 2/2] fix n_gram_model --- .../n_gram_model/n_gram_model.ipynb | 264 ++++++--- .../text_generation_paddle.ipynb | 526 ------------------ 2 files changed, 175 insertions(+), 615 deletions(-) delete mode 100644 paddle2.0_docs/text_generation/text_generation_paddle.ipynb diff --git a/paddle2.0_docs/n_gram_model/n_gram_model.ipynb b/paddle2.0_docs/n_gram_model/n_gram_model.ipynb index 1e8ab412..d46e601f 100644 --- a/paddle2.0_docs/n_gram_model/n_gram_model.ipynb +++ b/paddle2.0_docs/n_gram_model/n_gram_model.ipynb @@ -4,10 +4,10 @@ "cell_type": 
"markdown", "metadata": {}, "source": [ - "## 用N-Gram模型在莎士比亚诗中训练word embedding\n", + "## 用N-Gram模型在莎士比亚文集中训练word embedding\n", "N-gram 是计算机语言学和概率论范畴内的概念,是指给定的一段文本中N个项目的序列。\n", "N=1 时 N-gram 又称为 unigram,N=2 称为 bigram,N=3 称为 trigram,以此类推。实际应用通常采用 bigram 和 trigram 进行计算。\n", - "本示例在莎士比亚十四行诗上实现了trigram。" + "本示例在莎士比亚文集上实现了trigram。" ] }, { @@ -20,7 +20,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 189, "metadata": {}, "outputs": [ { @@ -29,7 +29,7 @@ "'2.0.0-alpha0'" ] }, - "execution_count": 1, + "execution_count": 189, "metadata": {}, "output_type": "execute_result" } @@ -44,32 +44,88 @@ "metadata": {}, "source": [ "## 数据集&&相关参数\n", - "训练数据集采用了莎士比亚十四行诗,CONTEXT_SIZE设为2,意味着是trigram。EMBEDDING_DIM设为10。" + "训练数据集采用了莎士比亚文集,[下载](https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt),保存为txt格式即可。
\n", + "context_size设为2,意味着是trigram。embedding_dim设为256。" ] }, { "cell_type": "code", - "execution_count": 57, + "execution_count": 190, "metadata": {}, "outputs": [], "source": [ - "CONTEXT_SIZE = 2\n", - "EMBEDDING_DIM = 10\n", + "embedding_dim = 256\n", + "context_size = 2" + ] + }, + { + "cell_type": "code", + "execution_count": 191, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Length of text: 1115394 characters\n" + ] + } + ], + "source": [ + "# 文件路径\n", + "path_to_file = './shakespeare.txt'\n", + "test_sentence = open(path_to_file, 'rb').read().decode(encoding='utf-8')\n", "\n", - "test_sentence = \"\"\"When forty winters shall besiege thy brow,\n", - "And dig deep trenches in thy beauty's field,\n", - "Thy youth's proud livery so gazed on now,\n", - "Will be a totter'd weed of small worth held:\n", - "Then being asked, where all thy beauty lies,\n", - "Where all the treasure of thy lusty days;\n", - "To say, within thine own deep sunken eyes,\n", - "Were an all-eating shame, and thriftless praise.\n", - "How much more praise deserv'd thy beauty's use,\n", - "If thou couldst answer 'This fair child of mine\n", - "Shall sum my count, and make my old excuse,'\n", - "Proving his beauty by succession thine!\n", - "This were to be new made when thou art old,\n", - "And see thy blood warm when thou feel'st it cold.\"\"\".split()" + "# 文本长度是指文本中的字符个数\n", + "print ('Length of text: {} characters'.format(len(test_sentence)))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 去除标点符号\n", + "用`string`库中的punctuation,完成英文符号的替换。" + ] + }, + { + "cell_type": "code", + "execution_count": 192, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'!': '', '\"': '', '#': '', '$': '', '%': '', '&': '', \"'\": '', '(': '', ')': '', '*': '', '+': '', ',': '', '-': '', '.': '', '/': '', ':': '', ';': '', '<': '', '=': '', '>': '', '?': '', '@': '', '[': '', 
'\\\\': '', ']': '', '^': '', '_': '', '`': '', '{': '', '|': '', '}': '', '~': ''}\n" + ] + } + ], + "source": [ + "from string import punctuation\n", + "process_dicts={i:'' for i in punctuation}\n", + "print(process_dicts)" + ] + }, + { + "cell_type": "code", + "execution_count": 193, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "12848\n" + ] + } + ], + "source": [ + "punc_table = str.maketrans(dicts)\n", + "test_sentence = test_sentence.translate(punc_table)\n", + "test_sentence = test_sentence.lower().split()\n", + "vocab = set(test_sentence)\n", + "print(len(vocab))" ] }, { @@ -82,54 +138,62 @@ }, { "cell_type": "code", - "execution_count": 58, + "execution_count": 194, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "[(('When', 'forty'), 'winters'), (('forty', 'winters'), 'shall'), (('winters', 'shall'), 'besiege')]\n" + "[[['first', 'citizen'], 'before'], [['citizen', 'before'], 'we'], [['before', 'we'], 'proceed']]\n" ] } ], "source": [ - "trigram = [((test_sentence[i], test_sentence[i + 1]), test_sentence[i + 2])\n", + "trigram = [[[test_sentence[i], test_sentence[i + 1]], test_sentence[i + 2]]\n", " for i in range(len(test_sentence) - 2)]\n", "\n", - "vocab = set(test_sentence)\n", "word_to_idx = {word: i for i, word in enumerate(vocab)}\n", "idx_to_word = {word_to_idx[word]: word for word in word_to_idx}\n", "# 看一下数据集\n", - "print(trigram[:3])\n" + "print(trigram[:3])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## 构建`Dataset`类 加载数据" + "## 构建`Dataset`类 加载数据\n", + "用`paddle.io.Dataset`构建数据集,然后作为参数传入到`paddle.io.DataLoader`,完成数据集的加载。" ] }, { "cell_type": "code", - "execution_count": 59, + "execution_count": 184, "metadata": {}, "outputs": [], "source": [ "import paddle\n", + "batch_size = 256\n", + "paddle.enable_imperative()\n", "class TrainDataset(paddle.io.Dataset):\n", - " def __init__(self, tuple_data, vocab):\n", + " def __init__(self, 
tuple_data):\n", " self.tuple_data = tuple_data\n", - " self.vocab = vocab\n", "\n", " def __getitem__(self, idx):\n", - " data = list(self.tuple_data[idx][0])\n", - " label = list(self.tuple_data[idx][1])\n", + " data = self.tuple_data[idx][0]\n", + " label = self.tuple_data[idx][1]\n", + " data = list(map(lambda w: word_to_idx[w], data))\n", + " label = word_to_idx[label]\n", + " \n", " return data, label\n", " \n", " def __len__(self):\n", - " return len(self.tuple_data)" + " return len(self.tuple_data)\n", + " \n", + "train_dataset = TrainDataset(trigram)\n", + "train_loader = paddle.io.DataLoader(train_dataset, places=paddle.fluid.cpu_places(), return_list=True,\\\n", + " shuffle=True, batch_size=batch_size, drop_last=True)" ] }, { @@ -137,12 +201,12 @@ "metadata": {}, "source": [ "## 组网&训练\n", - "这里用paddle动态图的方式组网,由于是N-Gram模型,只需要一层`Embedding`与两层`Linear`就可以完成网络模型的构建。" + "这里用paddle动态图的方式组网。为了构建Trigram模型,用一层 `Embedding` 与两层 `Linear` 完成构建。`Embedding` 层对输入的前两个单词embedding,然后输入到后面的两个`Linear`层中,完成特征提取。" ] }, { "cell_type": "code", - "execution_count": 79, + "execution_count": 185, "metadata": {}, "outputs": [], "source": [ @@ -152,12 +216,12 @@ " def __init__(self, vocab_size, embedding_dim, context_size):\n", " super(NGramModel, self).__init__()\n", " self.embedding = paddle.nn.Embedding(size=[vocab_size, embedding_dim])\n", - " self.linear1 = paddle.nn.Linear(context_size * embedding_dim, 128)\n", - " self.linear2 = paddle.nn.Linear(128, vocab_size)\n", + " self.linear1 = paddle.nn.Linear(context_size * embedding_dim, 1024)\n", + " self.linear2 = paddle.nn.Linear(1024, vocab_size)\n", "\n", " def forward(self, x):\n", " x = self.embedding(x)\n", - " x = paddle.reshape(x, [1, -1])\n", + " x = paddle.reshape(x, [-1, context_size * embedding_dim])\n", " x = self.linear1(x)\n", " x = paddle.nn.functional.relu(x)\n", " x = self.linear2(x)\n", @@ -169,66 +233,81 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### 初始化Model,并定义相关的参数。" + "### 
定义`train()`函数,对模型进行训练。" ] }, { "cell_type": "code", - "execution_count": 121, + "execution_count": 195, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "epoch: 0, loss is: [4.631529]\n", - "epoch: 50, loss is: [4.6081576]\n", - "epoch: 100, loss is: [4.600631]\n", - "epoch: 150, loss is: [4.603069]\n", - "epoch: 200, loss is: [4.592647]\n", - "epoch: 250, loss is: [4.5626693]\n", - "epoch: 300, loss is: [4.513106]\n", - "epoch: 350, loss is: [4.4345813]\n", - "epoch: 400, loss is: [4.3238697]\n", - "epoch: 450, loss is: [4.1728854]\n", - "epoch: 500, loss is: [3.9622664]\n", - "epoch: 550, loss is: [3.67673]\n", - "epoch: 600, loss is: [3.2998457]\n", - "epoch: 650, loss is: [2.8206367]\n", - "epoch: 700, loss is: [2.2514927]\n", - "epoch: 750, loss is: [1.6479329]\n", - "epoch: 800, loss is: [1.1147357]\n", - "epoch: 850, loss is: [0.73231363]\n", - "epoch: 900, loss is: [0.49481753]\n", - "epoch: 950, loss is: [0.3504072]\n" + "epoch: 0, batch_id: 0, loss is: [9.4609375]\n", + "epoch: 0, batch_id: 100, loss is: [6.8079987]\n", + "epoch: 0, batch_id: 200, loss is: [7.15846]\n", + "epoch: 0, batch_id: 300, loss is: [6.536172]\n", + "epoch: 0, batch_id: 400, loss is: [6.847218]\n", + "epoch: 0, batch_id: 500, loss is: [6.6856213]\n", + "epoch: 0, batch_id: 600, loss is: [7.023284]\n", + "epoch: 0, batch_id: 700, loss is: [6.668572]\n", + "epoch: 1, batch_id: 0, loss is: [6.251506]\n", + "epoch: 1, batch_id: 100, loss is: [6.3487673]\n", + "epoch: 1, batch_id: 200, loss is: [6.3545203]\n", + "epoch: 1, batch_id: 300, loss is: [6.2171617]\n", + "epoch: 1, batch_id: 400, loss is: [6.0640473]\n", + "epoch: 1, batch_id: 500, loss is: [6.469482]\n", + "epoch: 1, batch_id: 600, loss is: [6.342491]\n", + "epoch: 1, batch_id: 700, loss is: [6.3045797]\n", + "epoch: 2, batch_id: 0, loss is: [5.993376]\n", + "epoch: 2, batch_id: 100, loss is: [5.8828726]\n", + "epoch: 2, batch_id: 200, loss is: [5.9760466]\n", + "epoch: 2, batch_id: 300, 
loss is: [5.756211]\n", + "epoch: 2, batch_id: 400, loss is: [6.112233]\n", + "epoch: 2, batch_id: 500, loss is: [5.8354917]\n", + "epoch: 2, batch_id: 600, loss is: [5.8915625]\n", + "epoch: 2, batch_id: 700, loss is: [5.820907]\n", + "epoch: 3, batch_id: 0, loss is: [5.9928036]\n", + "epoch: 3, batch_id: 100, loss is: [5.8935637]\n", + "epoch: 3, batch_id: 200, loss is: [6.1485195]\n", + "epoch: 3, batch_id: 300, loss is: [5.932296]\n", + "epoch: 3, batch_id: 400, loss is: [5.8619576]\n", + "epoch: 3, batch_id: 500, loss is: [6.1057215]\n", + "epoch: 3, batch_id: 600, loss is: [5.9520254]\n", + "epoch: 3, batch_id: 700, loss is: [5.6877956]\n", + "epoch: 4, batch_id: 0, loss is: [5.550747]\n", + "epoch: 4, batch_id: 100, loss is: [5.757982]\n", + "epoch: 4, batch_id: 200, loss is: [6.2809753]\n", + "epoch: 4, batch_id: 300, loss is: [5.860643]\n", + "epoch: 4, batch_id: 400, loss is: [5.9789114]\n", + "epoch: 4, batch_id: 500, loss is: [5.763079]\n", + "epoch: 4, batch_id: 600, loss is: [5.85236]\n", + "epoch: 4, batch_id: 700, loss is: [6.244775]\n" ] } ], "source": [ "vocab_size = len(vocab)\n", - "embedding_dim = 10\n", - "context_size = 2\n", - "\n", - "paddle.enable_imperative()\n", + "epochs = 5\n", "losses = []\n", "def train(model):\n", " model.train()\n", - " optim = paddle.optimizer.SGD(learning_rate=0.001, parameter_list=model.parameters())\n", - " for epoch in range(1000):\n", - " # 留最后10组作为预测\n", - " for context, target in trigram[:-10]:\n", - " context_idxs = list(map(lambda w: word_to_idx[w], context))\n", - " x_data = paddle.imperative.to_variable(np.array(context_idxs))\n", - " y_data = paddle.imperative.to_variable(np.array([word_to_idx[target]]))\n", + " optim = paddle.optimizer.Adam(learning_rate=0.01, parameter_list=model.parameters())\n", + " for epoch in range(epochs):\n", + " for batch_id, data in enumerate(train_loader()):\n", + " x_data = data[0]\n", + " y_data = data[1]\n", " predicts = model(x_data)\n", - " # print (predicts)\n", " 
loss = paddle.nn.functional.cross_entropy(predicts, y_data)\n", - " loss.backward()\n", - " optim.minimize(loss)\n", + " avg_loss = paddle.mean(loss)\n", + " avg_loss.backward()\n", + " if batch_id % 100 == 0:\n", + " losses.append(avg_loss.numpy())\n", + " print(\"epoch: {}, batch_id: {}, loss is: {}\".format(epoch, batch_id, avg_loss.numpy()))\n", + " optim.minimize(avg_loss)\n", " model.clear_gradients()\n", - " if epoch % 50 == 0:\n", - " print(\"epoch: {}, loss is: {}\".format(epoch, loss.numpy()))\n", - " losses.append(loss.numpy())\n", "model = NGramModel(vocab_size, embedding_dim, context_size)\n", "train(model)" ] @@ -243,22 +322,22 @@ }, { "cell_type": "code", - "execution_count": 123, + "execution_count": 187, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "[]" + "[]" ] }, - "execution_count": 123, + "execution_count": 187, "metadata": {}, "output_type": "execute_result" }, { "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWoAAAD4CAYAAADFAawfAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/d3fzzAAAACXBIWXMAAAsTAAALEwEAmpwYAAAdkklEQVR4nO3dd3xV9f3H8dfn3uwBmUAggRghMmUFDQ5Qq4JC3VqtVmuttL/aqq3WamtdtcPaarWOuq1aR0XrAJVaxKKIIghE9oaAAQIBAiE7398fuVhUJCHk5tzxfj4e95Hcc06S98nJ451zv/cMc84hIiKhy+d1ABER2T8VtYhIiFNRi4iEOBW1iEiIU1GLiIS4mGB806ysLJefnx+Mby0iEpHmzJmzxTmXva95QSnq/Px8Zs+eHYxvLSISkcxs7dfN09CHiEiIU1GLiIQ4FbWISIhTUYuIhDgVtYhIiFNRi4iEOBW1iEiIC8px1G318PSVNDmI8Vnzw+8jPsZHVko82anxZKXEk5kSR6xf/19EJHqEVFHf9fYyauqbWlwuIzmOrJQ4slPjyU5pLvCs1HhS4mNIjPWTGOcnIdZHjM+H32dffJi1OM3s4NclxucjMbY5h7XHNxSRqBVSRV1y8xgampqob3Q0NjkaGpuorm9ky646tuyqpXxn7Vc+zlm3jfKdta0qeK8kxvpJivOTEPiYGOf//B/K3tPTEuPITIkjMyWerOTmj5kpcaQnxeH3qexFolVIFXVcjI+4fQyb98pM3u/XOeeoqmtkd20DNfXN5V5d30hDYxONTY5G11z8X3nsa7pztMdNb5r/yQSy1DVQXd/I7rpGagIfq+saqaiqY8O2wPP6RnZU19PY9NUf7rPmVxGZyfF06RRPbnoiuelJ9EhLJDc9kR7piXRJTVCZi0SokCrqtjIzUuJjSIkP79VpanLsqK5na1Xt568itu6qY+uuWrZU1bFlZy2bKmt4u6ySLbvqvvC1sX4jp3OguNMCRZ6eGCj1RLp3TsSnIhcJS+HdbBHG5zPSk+NIT46jd5f9L1td18iG7dWs37
Y78LGaDduan09fXs7mnbVfeGWQGOunsFsqfbum0jcnlcO6pdI/pxNpSXHBXSkROWgq6jCVGOend5cUendJ2ef82oZGyrbXsGF7NaUVu1m2aRdLNlby9uJNvDC79PPl8jOTGJyXxuDcNAbnpTGgeycSYv0dtRoi0goq6ggVH+MnPyuZ/Kwvju875yjfVcuSsp0s+GwHJaU7mLW6glfnfQY0HxrZL6cTg/M6MyI/gxH5GXRPS/RiFUQkwFx7vHP2JUVFRU7Xow4vmyprmFe6nfml25m/fjvzS3ewq7YBgNz0RI7Iz2DEIRkc0zuLvIwkj9OKRB4zm+OcK9rnPBW17Etjk2NxWSWzVlfw8ZoKZq2uYGtV8xuYBdnJjC7MZlRhNiMLMjVUItIOVNRy0JxzrCzfxfRlW/jvsnI+XLWV2oYmEmP9HN83mzEDunFC3y6kJsR6HVUkLKmopd3V1Dfy4aqt/GfxJqYs3ET5zlri/D6O6ZPFqYNyOKl/VzonqrRFWktFLUHV1OSYW7qNtxZs5M0FG1m/rZo4v49RhVmMOzyHE/t11Z62SAtU1NJhnHPMX7+DySWfMbmkjM921BAX42N0YTbjD8/hG/26hv2JSSLBoKIWTzTvaW9nckkZb3xaxsbKGuJjfIwd2I1zhudy1KFZOu1dJEBFLZ5ranLMWbeNV+dt4LV5n1FZ00BO5wTOGtaDc4bncUjW/q/nIhLpVNQSUmrqG5m6eDMT55Ty32XlNDkYVZjNpUfnM7pPtq5JIlFJRS0ha1NlDS98XMozH65l885aCrKSueSofM4enquxbIkqKmoJeXUNTby5oIzHZ6xhful2UhNiuHhkLy47poCMZF04SiKfilrCytx123jkvVW8uWAjCTF+LiruyeWjCuiSmuB1NJGgUVFLWFq+aScPvLuSV+dtIMbv4zvFvbji+N7aw5aIpKKWsLZmSxX3TVvBy5+sJzkuhh+MLuB7xxxCUpzGsCVyqKglIizbtJM7pyzl7UWbyE6N5/qxfTlzaA8dJSIRYX9F/dUbFIqEqMKuqTxycRETfziS7mmJXPPifM59aCYLNuzwOppIUKmoJewU5Wfwr/87ij+efThrtlRx2n3vc+Mrn7J9d13LXywShlTUEpZ8PuO8EXm8c+1xXDwyn2c/Wsfxf3qXf35cSjCG80S8pKKWsNY5MZZbThvA5CuPpU+XVK57qYTvPDaL0ordXkcTaTetLmoz85vZXDObFMxAIm3RL6cTz08o5vYzBjJ33TZOvns6T8xYTVOT9q4l/B3IHvVVwOJgBRE5WD6fcVFxL/79s9EccUgGt76+iHMfmsmKzbu8jiZyUFpV1GaWC4wDHg1uHJGD1yMtkScvHcFd5w1mZfkuTr33Pe6ftoL6xiavo4m0SWv3qP8CXAd87V+6mU0ws9lmNru8vLw9som0mZlx1rBc3v7paE7s14U7pyzl7Ac/YPWWKq+jiRywFovazMYDm51zc/a3nHPuYedckXOuKDs7u90CihyM7NR4HrhwOA9cOIy1W3cz7t73mDhnvY4MkbDSmj3qo4HTzGwN8Dxwgpk9E9RUIu3s1EE5vHnVsQzq0ZlrX5zPlc/PY2dNvdexRFqlxaJ2zt3gnMt1zuUD5wPvOOcuCnoykXbWPS2RZy8v5tqTC5lc8hmn3TeDJRsrvY4l0iIdRy1Rxe8zfnxCH569vJhdtQ2ccf8MXpxd6nUskf06oKJ2zr3rnBsfrDAiHaW4IJPJVx7D0Lx0fj6xhF+/skBHhUjI0h61RK0uqQk88/0j+cGoAp7+cC3ffWIWO3Zr3FpCj4paoprfZ9xwaj/uPOdwZq2u4IwHZrCyXCfISGhRUYsA5xbl8dzlxVRW13Pm/TP4cNVWryOJfE5FLRJQlJ/BK1ccTZdOCVz8+CzeXrTJ60gigIpa5AvyMpJ48Qcj6ZfTiR8+M4eJc9Z7HUlERS3yZenJcTz7/SM56tBMrn1xPn//YI3XkSTKqahF9iE5Po
ZHLyni5P5dufm1hTwxY7XXkSSKqahFvkZ8jJ/7LxzGmAFdufX1RTz+vspavKGiFtmPWL+P+749jFMGduO2SYt49L1VXkeSKKSiFmlBrN/HvRcM5dRB3bh98mKe1DCIdLAYrwOIhINYv497zh9KQ+Mn3PL6IpLjYzi3KM/rWBIltEct0kqxfh9//fZQju2TxS9eKmFySZnXkSRKqKhFDkB8jJ+HvjOcYT3TufqFuUxbstnrSBIFVNQiBygpLobHLx3BYd1S+eEzc3S6uQSdilqkDTolxPL3S48gLyOJy578mHml272OJBFMRS3SRpkp8Txz2ZFkpMRxyeOzWLpxp9eRJEKpqEUOQrfOCTz7/WLiY3xc+sQsNlXWeB1JIpCKWuQg5WUk8fh3R7C9up7vPfkxVbUNXkeSCKOiFmkHA3t05v5vD2NxWSU/eW4uDbqtl7QjFbVIOzm+bxduO30g7yzZzC2vL8Q553UkiRA6M1GkHV1U3IvSbbt56L+r6JmRxIRRh3odSSKAilqknf1iTF/WV1TzuzeW0CszmTEDunkdScKchj5E2pnPZ/z5vMEMzu3MNf+cz4rNOmxPDo6KWiQIEmL9PHjRcOJjfEx4eg47a+q9jiRhTEUtEiTd0xK5/8JhrN26m5/9cz5NTXpzUdpGRS0SRMUFmdw4rh9vL9rEfdNWeB1HwpSKWiTIvntUPmcN7cHd/1nG1MWbvI4jYUhFLRJkZsbvzhpE/5xOXP38PFaV7/I6koQZFbVIB0iIbb6OdYzfuOLZudTUN3odScKIilqkg+SmJ/Hn8wazuKyS2ycv8jqOhBEVtUgHOqFvVyaMKuCZD9cxqeQzr+NImFBRi3Swn485jKE907jhpU9Zu7XK6zgSBlTUIh0s1u/j3vOHYgY/eW4u9brSnrRARS3igbyMJO44+3BK1u/gr+/o+GrZPxW1iEdOGZTDWcN6cP+0FXyybpvXcSSEqahFPHTLaQPo1imBn70wj911ujOM7FuLRW1mCWY2y8zmm9lCM7u1I4KJRINOCbH8+bzBrK3YzW8nL/Y6joSo1uxR1wInOOcGA0OAsWZWHNRUIlGkuCCTy48t4B8frWPaks1ex5EQ1GJRu2Z7znmNDTx0GTCRdnTNyYX07ZbKzyeWUFFV53UcCTGtGqM2M7+ZzQM2A2875z7axzITzGy2mc0uLy9v55gikS0+xs/d3xpCZXU9N7xcovstyhe0qqidc43OuSFALnCEmQ3cxzIPO+eKnHNF2dnZ7RxTJPL1y+nEtWMKmbJwE/+au8HrOBJCDuioD+fcdmAaMDYoaUSi3GXHFDC8Vzq3TVpE+c5ar+NIiGjNUR/ZZpYW+DwROAlYEuRcIlHJ7zPuOHsQu2sbueW1hV7HkRDRmj3qHGCamZUAH9M8Rj0puLFEolfvLqlcdWIfJn9axlsLNnodR0JATEsLOOdKgKEdkEVEAiaMKmBySRm/fnUBIwsy6ZwU63Uk8ZDOTBQJQbF+H38853Aqqur4ja5dHfVU1CIhamCPzvxwdAET56xn+jId8hrNVNQiIewnJ/Th0Oxkbnj5U6pqdS2QaKWiFglhCbF+7jj7cDZsr+beqcu9jiMeUVGLhLii/AzOH5HHY++vZunGnV7HEQ+oqEXCwHVj+5KSEMOvX1mg08ujkIpaJAxkJMdx/di+zFpTwcuf6PTyaKOiFgkT5xXlMbRnGr9/czE7dtd7HUc6kIpaJEz4fMbtZwykoqqOP/17qddxpAOpqEXCyIDunbl4ZD7PfLSWkvXbvY4jHURFLRJmfnZyIVkp8dz4ygIam/TGYjRQUYuEmU4Jsdw4rh8l63fw7Kx1XseRDqCiFglDpw3uzlGHZnLnW0vYskvXrY50KmqRMGRm3Hb6QKrrG/n9G7o8fKRTUYuEqd5dUrj82AJe+mQ9H63a6nUcCSIVtUgY+8kJfeiRlshNry6kobHJ6zgSJCpqkTCWGOfn1+P7sXTTTp77uNTrOBIkKmqRMD
dmQDeKCzK4699LdcZihFJRi4Q5M+Om8QPYUV3PX6Yu8zqOBIGKWiQC9O/eifOP6MlTM9eyYrMuhRppVNQiEeKakwpJivNz26TFuhRqhFFRi0SIzJR4rvpGH6YvK2fa0s1ex5F2pKIWiSAXj8ynIDuZ2yctpq5Bh+tFChW1SASJi/Hx63H9WbWliqc/XOt1HGknKmqRCHN83y4c2yeLe6cu1+F6EUJFLRKBfnlqPypr6rlvmu5cHglU1CIRqF9OJ84dnsvfP1jLuq27vY4jB0lFLRKhrjn5MPw+444purpeuFNRi0Sorp0SmDCqgMklZcxZu83rOHIQVNQiEWzCqAKyU+P53Rs6CSacqahFIlhyfAzXnFTInLXbeGvBRq/jSBupqEUi3LlFeRzWNZU/vLVEJ8GEKRW1SITz+4xfjuvH2q27dRJMmFJRi0SB0YXZOgkmjKmoRaKEToIJXypqkSihk2DCV4tFbWZ5ZjbNzBaZ2UIzu6ojgolI+9NJMOGpNXvUDcA1zrn+QDFwhZn1D24sEQkGnQQTnlosaudcmXPuk8DnO4HFQI9gBxOR4NBJMOHngMaozSwfGAp8tI95E8xstpnNLi8vb6d4ItLe9j4J5k2dBBMWWl3UZpYCvARc7Zyr/PJ859zDzrki51xRdnZ2e2YUkXb2+Ukwb+okmHDQqqI2s1iaS/ofzrmXgxtJRIJtz0kw6yp289TMNV7HkRa05qgPAx4DFjvn7gp+JBHpCKMLsxldmM09U5dTUVXndRzZj9bsUR8NfAc4wczmBR6nBjmXiHSAX43rR1VtA/dO1UkwoSympQWcc+8D1gFZRKSDFXZN5YIjevL0h2u5qLgXvbukeB1J9kFnJopEuZ+eVEhSrJ/fv7HY6yjyNVTUIlEuKyWeK07ozdQlm3l/+Rav48g+qKhFhO8elU9ueiK3T15EY5NOggk1KmoRISHWz/Wn9GXJxp1MnFPqdRz5EhW1iAAwblAOw3ulc+eUZeyqbfA6juxFRS0iAJgZN47rx5Zdtfzt3ZVex5G9qKhF5HNDe6Zz+pDuPPLeKjZsr/Y6jgSoqEXkC64b2xeAP76la1aHChW1iHxBj7REvn/sIbw67zPmrtM1q0OBilpEvuL/jutNVko8t0/WNatDgYpaRL4iJT6Ga09uvmb15E/LvI4T9VTUIrJP5xbl0bdbKr9/YwnVdY1ex4lqKmoR2Se/z7j1tAFs2F7Ng++u8DpOVFNRi8jXOrIgk9OHdOdv01exdmuV13GilopaRPbrl6f2I9Zn/GbSIq+jRC0VtYjsV9dOCVz5jT78Z/Fm3lmyyes4UUlFLSItuvToQyjITubW1xdRU683FjuailpEWhQX4+OWbw5g7dbdPPb+aq/jRB0VtYi0yqjCbMYO6MZ976zgM10HpEOpqEWk1W4c348m5/itbtvVoVTUItJquelJXHF8byaXlDF9WbnXcaKGilpEDsiEUQUUZCXzq1c+1RmLHURFLSIHJCHWz2/PHERpRTX3TF3udZyooKIWkQM28tBMzh2eyyPvrWJxWaXXcSKeilpE2uSXp/ajc2IsN7z8qe5cHmQqahFpk/TkOG4a3595pdt55sO1XseJaCpqEWmz04d059g+Wdzx1hJKK3Z7HSdiqahFpM3MjN+fNQifGddNLKFJQyBBoaIWkYOSm57EjeP6MXPVVp7WEEhQqKhF5KB9a0Qeowuz+cObS1izRdetbm8qahE5aGbGH84eRIzf+PnE+RoCaWcqahFpFzmdE7n5mwP4eM02Hn1/lddxIoqKWkTazdnDejBmQFfunLKUBRt2eB0nYqioRaTdmBl/OOtwMpPjufK5uVTVNngdKSKoqEWkXaUnx3H3t4awemsVt72u+yy2BxW1iLS7kYdm8qPjDuWF2aVMLinzOk7Ya7GozexxM9tsZgs6IpCIRIarTyxkaM80fvFSCSvLd3kdJ6y1Zo/6SWBskHOISISJ9fu4/9vDiI/x8cOn57BL49
Vt1mJRO+emAxUdkEVEIkz3tET+esFQVpbv4rqJ83FOx1e3hcaoRSSojuqdxfWn9OWNTzfy8HQdX90W7VbUZjbBzGab2ezyct1LTUT+5/JjCxg3KIc73lrCjBVbvI4TdtqtqJ1zDzvnipxzRdnZ2e31bUUkApgZfzzncHp3SeFH//hEby4eIA19iEiHSI6P4bFLRhDrNy594mO27qr1OlLYaM3hec8BM4HDzGy9mV0W/FgiEonyMpJ45OIiNlXW8P2nZlNTr7uYt0Zrjvq4wDmX45yLdc7lOuce64hgIhKZhvZM557zhzCvdDtXPT+XhsYmryOFPA19iEiHGzswh5vG92fKwk26M0wrxHgdQESi06VHH0JVbQN/+vcyEuL8/PaMgZiZ17FCkopaRDxzxfG9qapr5MF3V5IY6+fGcf1U1vugohYRz5gZ1405jOq6Rh57fzVNznHT+P4q6y9RUYuIp8yMm7/ZHzN4YsYaqusa+e2Zg/D7VNZ7qKhFxHNmxk3j+5McF8N901awu66RP583mFi/jncAFbWIhAgz49oxh5EU7+ePby1la1UtD1w4nM6JsV5H85z+XYlISPnRcb3507mDmbW6grMf/IDSit1eR/KcilpEQs45w3N56ntHsrmyhjMfmMGctdu8juQpFbWIhKSRh2by8o+OJjk+hvMfnsmTM1ZH7fWsVdQiErJ6d0nhtSuOYXRhNre8vogrn58XlXc2V1GLSEjrnBTLw98p4udjDmNyyWeM/+v7zF0XXUMhKmoRCXk+n3HF8b159vJi6hqaOOdvM7nr7WXUR8kFnVTUIhI2igsyefPqYzl9SHfunbqcsx74gAUbdngdK+hU1CISVjolxHLXeUN48MJhlO2o4bT73ufW1xeys6be62hBo6IWkbB0yqAcpl4zmguP7MWTH6zhxLv+y6vzNkTkJVNV1CIStjonxvKbMwbyrx8dTVZKPFc9P4/T75/BBxF2A10VtYiEvSF5abz+42O467zBVFTV8e1HP+K7T8yKmPFrC8YB5EVFRW727Nnt/n1FRFpSU9/IUzPXcN87K6isaeD4w7L58Qm9Gd4rw+to+2Vmc5xzRfucp6IWkUhUWVPP0zPX8tj7q6moqqO4IIMJowo4rrALvhC8hKqKWkSi1u66Bp6bVcoj01exsbKGnhlJXFTck/OK8khLivM63udU1CIS9eobm5iycCNPzVzLrNUVxMf4GH94d84e3oPiQzI938tWUYuI7GVxWSVPf7iW1+Z9xq7aBnqkJXLG0O6cOTSX3l1SPMmkohYR2YfqukbeXryJlz9Zz/Rl5TQ5KOyawtgB3RgzsBv9czp12P0bVdQiIi3YXFnD5E/LmLJwI7NWV9DkIC8jkRP7dWVUYTbFh2SSGOcP2s9XUYuIHIAtu2r5z6JNvLVwIzNXbqW2oYm4GB9H5GcwqjCLo3tn0bdbp3a9Aa+KWkSkjWrqG5m1uoLpy8qZvrycZZt2AZCaEMPwXumMyM9gRH4Gh+d2JiG27Xvc+ytq3dxWRGQ/EmL9jCrMZlRhNgAbd9Tw4aqtzFpTwcerK3h36VIA4vw+huSl8fyE4nY/gkRFLSJyALp1TuCMoT04Y2gPALZV1TF77TZmr6lgR3V9UA7zU1GLiByE9OQ4TurflZP6dw3az9BFmUREQpyKWkQkxKmoRURCnIpaRCTEqahFREKcilpEJMSpqEVEQpyKWkQkxAXlWh9mVg6sbeOXZwGRdQvhlmmdo4PWOfIdzPr2cs5l72tGUIr6YJjZ7K+7MEmk0jpHB61z5AvW+mroQ0QkxKmoRURCXCgW9cNeB/CA1jk6aJ0jX1DWN+TGqEVE5ItCcY9aRET2oqIWEQlxIVPUZjbWzJaa2Qozu97rPO3FzPLMbJqZLTKzhWZ2VWB6hpm9bWbLAx/TA9PNzO4N/B5KzGyYt2vQdmbmN7O5ZjYp8PwQM/sosG4vmFlcYHp84PmKwPx8T4O3kZmlmdlEM1tiZovNbG
Skb2cz+2ng73qBmT1nZgmRtp3N7HEz22xmC/aadsDb1cwuCSy/3MwuOZAMIVHUZuYH7gdOAfoDF5hZf29TtZsG4BrnXH+gGLgisG7XA1Odc32AqYHn0Pw76BN4TAAe7PjI7eYqYPFez+8A7nbO9Qa2AZcFpl8GbAtMvzuwXDi6B3jLOdcXGEzzukfsdjazHsCVQJFzbiDgB84n8rbzk8DYL007oO1qZhnAzcCRwBHAzXvKvVWcc54/gJHAlL2e3wDc4HWuIK3rq8BJwFIgJzAtB1ga+Pwh4IK9lv98uXB6ALmBP+ATgEmA0XzGVsyXtzkwBRgZ+DwmsJx5vQ4HuL6dgdVfzh3J2xnoAZQCGYHtNgkYE4nbGcgHFrR1uwIXAA/tNf0Ly7X0CIk9av63wfdYH5gWUQIv9YYCHwFdnXNlgVkbgT03XIuU38VfgOuApsDzTGC7c64h8Hzv9fp8nQPzdwSWDyeHAOXAE4HhnkfNLJkI3s7OuQ3An4B1QBnN220Okb2d9zjQ7XpQ2ztUijrimVkK8BJwtXOucu95rvlfbMQcJ2lm44HNzrk5XmfpQDHAMOBB59xQoIr/vRwGInI7pwOn0/xPqjuQzFeHCCJeR2zXUCnqDUDeXs9zA9MigpnF0lzS/3DOvRyYvMnMcgLzc4DNgemR8Ls4GjjNzNYAz9M8/HEPkGZmMYFl9l6vz9c5ML8zsLUjA7eD9cB659xHgecTaS7uSN7OJwKrnXPlzrl64GWat30kb+c9DnS7HtT2DpWi/hjoE3i3OI7mNyRe8zhTuzAzAx4DFjvn7tpr1mvAnnd+L6F57HrP9IsD7x4XAzv2eokVFpxzNzjncp1z+TRvy3eccxcC04BzAot9eZ33/C7OCSwfVnuezrmNQKmZHRaY9A1gERG8nWke8ig2s6TA3/medY7Y7byXA92uU4CTzSw98Erk5MC01vF6kH6vwfVTgWXASuBXXudpx/U6huaXRSXAvMDjVJrH5qYCy4H/ABmB5Y3mI2BWAp/S/I665+txEOt/HDAp8HkBMAtYAbwIxAemJwSerwjML/A6dxvXdQgwO7CtXwHSI307A7cCS4AFwNNAfKRtZ+A5msfg62l+5XRZW7Yr8L3Auq8ALj2QDDqFXEQkxIXK0IeIiHwNFbWISIhTUYuIhDgVtYhIiFNRi4iEOBW1iEiIU1GLiIS4/wecQTmUPnjbjwAAAABJRU5ErkJggg==\n", + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy86wFpkAAAACXBIWXMAAAsTAAALEwEAmpwYAAA7AklEQVR4nO3dd3ic1ZX48e+ZGfVerWZJtmXcuzE2YIqppoQQSH6QkEpCCCSBZJ/dhd0sKZvNbnazaZDgOCFsCnEKHWLABAjFgG25d1uSJVmy1ZvVy9zfH1M0M5qRRvLI0ojzeR49Hr3vq3eupdHRnXPPvVeMMSillAp/lolugFJKqdDQgK6UUlOEBnSllJoiNKArpdQUoQFdKaWmCNtEPXF6eropLCycqKdXSqmwtHPnzgZjTIa/cxMW0AsLCykuLp6op1dKqbAkIhWBzmnKRSmlpggN6EopNUVoQFdKqSkiqIAuIveJyAEROSgi9/s5f5mItIrIHufHQyFvqVJKqWGNOCgqIguBLwCrgF7gZRF50RhT4nPp28aYG8ahjUoppYIQTA99HrDNGNNpjOkH3gQ+Mr7NUkopNVrBBPQDwFoRSRORWOA6YLqf69aIyF4ReUlEFvi7kYjcJSLFIlJcX19/Fs1WSinla8SAbow5DHwf2AK8DOwBBnwu2wUUGGOWAA8Dzwa410ZjzEpjzMqMDL918SM6WnOG/91ylMb2njF9vVJKTVVBDYoaYx4zxqwwxlwCNAPHfM63GWPanY83AxEikh7y1gKl9e08/HoJ9RrQlVLKS7BVLpnOf/Nx5M//4HM+S0TE+XiV876NoW2qQ6TV0eTefvt43F4ppcJWsFP/nxKRNKAPuNcY0yIidwMYYzYAtwJfEpF+oAu4zYzTVkhREY6A3qMBXSmlvAQV0I0xa/0c2+Dx+BHgkRC2K6AomxXQHrpSSvkKu5mikTZXD913XFYppT7Ywi6gR7kCep/20JVSylPYBXRXD713QAO6Ukp5CruArj10pZTyLwwDumNQtEd76Eop5SXsArp7ULRPB0WVUspT2AV0d8pFyxaVUspL2AV0nSmqlFL+hV1At1iESKtFe+hKKeUj7AI6ONIu2kNXSilvYRnQI20WnSmqlFI+wjKgR9k05aKUUr7CMqBHaspFKaWGCMuAHmWzaspFKaV8hGdAj9AeulJK+QrLgK5li0opNVSwW9DdJyIHROSgiNzv57yIyE9FpERE9onI8pC31ENUhAZ0pZTyNWJAF5GFwBeAVcAS4AYRKfK5bD0w2/lxF/BoiNvpJdKqKRellPIVTA99HrDNGNNpjOkH3sSxUbSnm4DfGof3gWQRyQ5xW910UFQppYYKJqAfANaKSJqIxALXAdN9rskFTnp8XuU85kVE7hKRYhEprq+vH2ubdVBUKaX8GDGgG2MOA98HtgAvA3uAMXWPjTEbjTErjTErMzIyxnILQAdFlVLKn6AGRY0xjxljVhhjLgGagWM+l1Tj3WvPcx4bFzooqpRSQwVb5ZLp/DcfR/78Dz6XPA98ylntshpoNcacDmlLPURarZpyUUopH7Ygr3tKRNKAPuBeY0yLiNwNYIzZAGzGkVsvATqBz45HY10cPXQdFFVKKU9BBXRjzFo/xzZ4PDbAvSFs17CibBb6Bgx2u8FikXP1tEopNamF50xR5zZ0vbpRtFJKuYVlQI+yWQHo6dOArpRSLmEZ0F099J4BzaMrpZRLWAb0KFdA1x66Ukq5hXVA1xy6UkoNCuuArj10pZQaFKYB3TkoqrXoSinlFpYB3V22qLNFlVLKLSwDujvlogFdKaXcwjSgO1Iu2kNXSqlBYRnQI7WHrpRSQ4RlQB9MueigqFJKuYRnQI/QQVGllPIVlgE90qopF6WU8hWWAT0qQuvQlVLKV7A7Fn1NRA6KyAER2SQi0T7nPyMi9SKyx/nx+fFproOrh64pF6WUGjRiQBeRXOCrwEpjzELACtzm59I/GWOWOj9+FeJ2eom
wCiKaclFKKU/BplxsQIyI2IBY4NT4NWlkIkKUzaI9dKWU8jBiQDfGVAM/ACqB0zg2gN7i59JbRGSfiDwpItP93UtE7hKRYhEprq+vP6uGR1ot2kNXSikPwaRcUoCbgBlADhAnInf4XPYCUGiMWQy8CvzG372MMRuNMSuNMSszMjLOquFREVYdFFVKKQ/BpFyuBE4YY+qNMX3A08CFnhcYYxqNMT3OT38FrAhtM4fSHrpSSnkLJqBXAqtFJFZEBLgCOOx5gYhke3z6Id/z4yEqQgO6Ukp5so10gTFmm4g8CewC+oHdwEYR+Q5QbIx5HviqiHzIeb4J+Mz4NdkhymbVQVGllPIwYkAHMMZ8E/imz+GHPM4/CDwYwnaNKNKmPXSllPIUljNFwbFAV0+fDooqpZRLWAd03SRaKaUGhXVA102ilVJqUBgHdKv20JVSykPYBnTHoKjm0JVSyiVsA7qmXJRSylvYBvRIHRRVSikvYRvQtYeulFLewjig66CoUkp5CtuAHmmzMGA39GtQV0opIIwDepRNN4pWSilPYRvQI226r6hSSnkK24AeZbMC2kNXSimXMA7o2kNXSilPYRvQI905dJ0tqpRSEGRAF5GvichBETkgIptEJNrnfJSI/ElESkRkm4gUjktrPeigqFJKeQtmk+hc4KvASmPMQsAK3OZz2Z1AszGmCPgR8P1QN9RXpAZ0pZTyEmzKxQbEiIgNiAVO+Zy/CfiN8/GTwBXO/UfHzeCgqKZclFIKggjoxphq4Ac4Nos+DbQaY7b4XJYLnHRe3w+0Amm+9xKRu0SkWESK6+vrz6rhURE6KKqUUp6CSbmk4OiBzwBygDgRuWMsT2aM2WiMWWmMWZmRkTGWW7hFWjXlopRSnoJJuVwJnDDG1Btj+oCngQt9rqkGpgM40zJJQGMoG+orOkIDulJKeQomoFcCq0Uk1pkXvwI47HPN88CnnY9vBV43xpjQNXOoSKsjh64pF6WUcggmh74Nx0DnLmC/82s2ish3RORDzsseA9JEpAT4OvDAOLXXLSpC69CVUsqTLZiLjDHfBL7pc/ghj/PdwEdD2K4R6UxRpZTyNgVmimpAV0opCOeA7qpy0V2LlFIKCOOAbrNasFqE3gHNoSulFIRxQAfdV1QppTyFfUDXfUWVUsohrAN6pPbQlVLKLawDepTNqnXoSinlFNYBPVJTLkop5RbWAV0HRZVSalD4B3SdWKSUUkCYB/RIm0Wn/iullFNYB3QdFFVKqUFhHdAjNeWilFJuYR3QozTlopRSbmEe0K3aQ1dKKadg9hSdIyJ7PD7aROR+n2suE5FWj2seCnC7kNKUi1JKDRpxgwtjzFFgKYCIWHHsH/qMn0vfNsbcENLWjcBRtqiDokopBaNPuVwBlBpjKsajMaOlOXSllBo02oB+G7ApwLk1IrJXRF4SkQVn2a6guCYWjfN+1EopFRaCDugiEgl8CPiLn9O7gAJjzBLgYeDZAPe4S0SKRaS4vr5+DM31FhVhBdD1XJRSitH10NcDu4wxtb4njDFtxph25+PNQISIpPu5bqMxZqUxZmVGRsaYG+3i2oZO0y5KKTW6gH47AdItIpIlIuJ8vMp538azb97woiJ0o2illHIZscoFQETigKuAL3ocuxvAGLMBuBX4koj0A13AbeYcJLa1h66UUoOCCujGmA4gzefYBo/HjwCPhLZpI9MeulJKDQr7maKA1qIrpRRhHtA15aKUUoPCOqBrykUppQaFd0B3ply0h66UUmEe0CNtrh665tCVUiqsA3qUK6DrRtFKKRXeAd3VQ9ep/0opFeYBXXvoSik1KMwDurMOXXvoSikV3gHdPSjap4OiSikV1gHdnXLRskWllArvgK4zRZVSalBYB3SLRYi06kbRSikFYR7QQfcVVUopl7AP6JE2i84UVUoppkBAd20UrZRSH3QjBnQRmSMiezw+2kTkfp9rRER+KiIlIrJPRJaPW4t9RGrKRSmlgCB2LDL
GHAWWAoiIFagGnvG5bD0w2/lxAfCo899xF2WzaspFKaUYfcrlCqDUGFPhc/wm4LfG4X0gWUSyQ9LCEURFaA9dKaVg9AH9NmCTn+O5wEmPz6ucx7yIyF0iUiwixfX19aN8av9GW7bY1TtAvy4VoJSagoIO6CISCXwI+MtYn8wYs9EYs9IYszIjI2Ost/ESFTG6gH79T9/m4ddLQvLcSik1mYymh74e2GWMqfVzrhqY7vF5nvPYuIu0Bp9yae7opayhg0On28a5VUopde6NJqDfjv90C8DzwKec1S6rgVZjzOmzbl0QRjMoWlrfDsCplq6g7z9gN2Nql1JKnWtBBXQRiQOuAp72OHa3iNzt/HQzUAaUAL8E7glxOwMazaDoaAP6n3ecZOm3t9DY3jPm9iml1LkyYtkigDGmA0jzObbB47EB7g1t04IzmkHR0voOAJo7++jqHSAm0hrw2vKGDr75/EG6+gY4VtvOmviokLRXKaXGS/jPFB3FoGhJXbv78anWwL30/gE7X//zHuzGkW6pbOo4u0YqpdQ5EPYBPdJqHVXKJTPB0dMeLu3yi7fK2FXZwn9+ZBE2i1DZ1BmStiql1HgK+4Du6KGPPCja3TfAyaZOLp6dDgQO6AeqW/nRq8e4fnE2Ny/LJTclhopGDehKqckv/AO6zULfgKFvhMlC5Y0d2A1cXJSOCFS3dA+5pqd/gK//eQ+pcZH8x4cXIiLkp8ZqD10pFRbCPqAvzEkCYPP+4askS+scefA5WQlMS4jmtJ8e+o4TzRyrbeehG+eTHBsJoAFdKRU2wj6gr5ubSVFmPBveLMOYwDXjrpLFmenx5CRH+x0ULak7A8CqwlT3sYK0WFo6+2jt6gtxy5VSKrTCPqBbLMIXL5nJ4dNtvHks8PowpfXt5CbHEBNpJTs5hlN+Ui6l9R0kRNvISBgsUcxPjQXgpPbSlVKTXNgHdICbluaSnRTNo38vDXhNSV07RZnxAOQmx1Dd0jWkR19S186sjHhExH0sPzUOYNwHRt86Vs/F33+d9p7+cX0epdTUNSUCeqTNwufXzmTbiSZ2VjQPOW+3G8rqO5iV4QjoOUnR9Pbbaezo9bqutL7dfY1Lfpqjhz7eefQX9p6iqrmLsvr2kS9WSik/pkRAB7jt/OkkxUSw4c2hvfTTbd109Q0wK9PR285JjgG8SxfbuvuoO9Pj7sW7xEfZSIuLHNfJRcYY3i1tHNImpZQajSkT0OOibHz6wkJePVTL8dozXudcM0TdPXR3QB/Mo5e6r4kbcu/pqbHjmnKpbOqk2hnIq5o1oCulxmbKBHSAz1xYSHSEhQ1vlnkddwVrzxw6ePeGXeu8+PbQwVHpMp4pF1fv3FEfrwFdKTU2Uyqgp8ZFctv5+Ty3p9qrKqW0vp2kmAjS4hy15cmxEURHWHwCejsRVmG6s6rFU0FqLKdausZtq7utJQ1kJkRRlBGvKRel1JhNqYAO8MVLZ2IR4VGPXLpjsDPOXb0iIuQkx3jVopfUtVOQFkeEdei3ZHpqLHYzPvltYwzvlTZy4aw0clNitIeulBqzKRfQs5NiuHVlHk8WV3HaGbBL6jqGVK84Shc9cuj17RRlDE23ABSkOUsXxyHtcqy2ncaOXi4sSicnOYZqzaErpcYo2A0ukkXkSRE5IiKHRWSNz/nLRKRVRPY4Px4an+YG50uXzsJuDL94s4zWzj4a2odWr+Qkxbin//f226lo7HRXwfhyTS4ajzz6u6UNAI4eenIMzZ19dPZqLbpSavSC2uAC+AnwsjHmVudm0UMTzfC2MeaG0DVt7KanxnLzslw2ba9kzSzHvhy+PfSc5BjqzvTQ0+9YhXHAbvwOiAJkJkQRZbNQ2Rj60sWtJY3kp8aSlxJLXsrgYG1RZkLIn0spNbWN2EMXkSTgEuAxAGNMrzGmZZzbddbuvbyIvgE733nhEACzfHvoydEA1Lb2UOJcuMs36LtYLI7B0lD30PsH7Gwra+SiojRnmxwBXUs
XlVJjEUzKZQZQDzwuIrtF5FfOPUZ9rRGRvSLykogs8HcjEblLRIpFpLi+PvC6K6FQmB7Hh5bkUN3SRaTVwnRn79fFFTyrW7oGF+4KENDBUekS6lr0g6faONPTz5pZjjXac/3UxyulVLCCCeg2YDnwqDFmGdABPOBzzS6gwBizBHgYeNbfjYwxG40xK40xKzMyMsbe6iB9eV0RIlCYHovNp3rFc7ZoaV072UnRxEcFzkC5eujDreg4Wlud+fM1Mx099MyEKKwWobpFFwJTSo1eMAG9Cqgyxmxzfv4kjgDvZoxpM8a0Ox9vBiJEJD2kLR2DoswE7rlsFh9dMX3IuewkR8rlVEsXJX7WcPFVkBZLZ+/AkPVfzsZ7pY3MmZbgXt3RZrWQlRitPXSl1JiMGNCNMTXASRGZ4zx0BXDI8xoRyRJnkbeIrHLetzHEbR2Tf7xmLl+4ZOaQ49ERVtLjIznV6uihBxoQdXFVunimXY7WnGH7iSa/1xtj+Mqm3bwUYOONnv4BdpQ3uQdtXXJTtHRRKTU2wVa5fAV4wlnhUgZ8VkTuBjDGbABuBb4kIv1AF3CbCWVuYpzkJMewu7KFjt4Bv2u4eCpIG1wXfUVBCqdbu/j4L9+n324o/saVQyYkHT59hhf2nqKzp5/1i7KH3G/vyVa6++xc6BvQk2MC/pFQSqnhBBXQjTF7gJU+hzd4nH8EeCR0zTo3spOieeVgLRC4wsUlL2Wwh97bb+eeJ3bR1NmLMY7UySXneY8JvHKwBoDdJ1swxnitsQ5QXOEI2ud77I4EjoBe09ZN/4B9SN5fKaWGE2wPfUpyDYyC/0W5PEVHWMlKjKayqZP/+Oshdle28MOPLeEbzx7gpQM1QwL6lkO1iEBTRy+VTZ3u2aYuuyqamZURR4pzfRmX3JQYBuyG2jM97qoXpYbzu/creP1wLa1djq0SO3oGeOjG+Vzn552hmto+0F1AV8BMiPLedi6Q/NRYthyq4TfvVfD5i2fwkeV5rJubyauHahiwD2aYTjZ1cvh0GzcvzQVgz8kWr/sYY9hZ0cyKgpQhz+Eup9Q8ugrST/52nP3VbcRG2piblUhrVx9vH2+Y6GaNaGtJA49vPTHRzZhSPtAB3RU8Z2XGD0mJ+JOfFsuZ7n5WFabyz+vnArB+YTYN7b3sKB/Me7vSLV9eV0RspJXdlS1e9ylr6KC5s89vQPe3tK9SgXT3DdDQ3sOn1hTw+89fwM8+sZzZ0+Kpap78pa8/e6OE720+THffwEQ3ZcrQgM7I+XOX1TPTmJkRxyMfX+YeBL1sTgZRNgsvH6hxX7flUC1zsxKYmRHP4rwkdld6b4vn2ibPfw/dUU6pqy6qYLj+8Od5TJzLC4NVO3v77eyqbKZvwAzp8Kix+0AHdNcvwXnTggvot67I4/V/uIzMxGj3sbgoG5eel8FLB05jtxsa23soLm/i6gVZACzLT+HgqTavXsiuimaSYiKYmT70eWMjbaTGRer0fxUU1+vEc7wlLyWW6uahm6D7U9nYyU9fOx7SCXPBOHDKUeUFUFyuVV2h8oEO6OnxUfzuzlV8YnXBWd1n/aIsatt62H2yhdcO12E3cPX8aQAsm55Mv91w8FSr+/piZ/7cYvGf5slNjjmrlEvfgJ2fvnacujadoDTVuXrieR4bs+SlxNDTb6e+vWfEr//D9kp++Oqxc96B2OEszc1Oima7BvSQ+UAHdIC1szOGnfIfjHVzpxFhFV4+cJoth2rITY5hQU4iAEvzkwHcbytbOnspqWv3m25xyUmOPqu3zM/tOcUPXz3GI2+UBHV9a1cfR2raxvx8Z8NuN+O2E9QHQVVzJ1aLMM1jUN/1zjOYIH2gujXoa0NpR3kTM9PjuHLeNHZVNNM/oK+BUPjAB/RQSIqJ4KKidF7cd5q3jjdw9YJp7kHWzIRocp0TmGAwsC/PDxzQc5MdW955vg3+12f2c+f/7RixLXa7YeNbjt2antpZxZnuvhG/5h/+vIf1P3n
7nFccVDV3cuWP3uSeJ3ad0+edSqqbu8hOivaas5Cb7OitjxSkjTHsdwb0k+dwENVuN+wob+b8wlTOn5FKR+8Ah0+fGfkL1Yg0oIfIdQuzOd3aTW+/navnZ3mdW5af7B4Y3VnRjNUiLJmeFPBeuSkxdPYO0NLpCMblDR1s2l7Ja0fqhpRA+vr7sTqO1bbz2YsK6egd4KmdVcNef/BUK387XEdWYjTffuEQ//7iIez28c+nltSd4dZH36OsvoPXj9TSFMI1cj5Iqpq7hsxXyE0JrvT1ZFMXrV2O11jVOG6C7utY3Rlau/o4f0Yq5xc6OjYjpV16++3c+PA7/P79ijE95x+2VfK9zYfH9LXhRAN6iFw1fxpWi5ASG+F+kbosy0/hVGs3tW3d7KxoZkFOIrGRgdM8uT6VLhvfLsNmsZAQZeOxd4bvRW94s4ycpGj+5bp5LMtP5rfvVQwboH/+RikJUTY2f3Utn7mwkMfeOcE9T+wa11Ky/VWtfOwXjmUT/vvWxdgNvHa4dtyebyqrbulyz2J2iY+ykRIbMWLpoqt3LnJuUy6u/PmqwlSyk2LIS4lxHwvkREMH+6tb+bfnDvDivlOjer72nn7+66XDPPF+xTkf/D3XNKCHSEpcJHdckM+dF88YMmV/mTOPvqO8iT0nW4ZNt8DgW+bqli7q2rp5sriKW1fmcduq6WzefzrggOmuyma2n2jizrUzibBa+PSaQsoaOni7xP8kk5K6M2w+cJpPXVhASlwk3/rQAv7thvm8cqiGf3xy3yi/A4HZ7YaTTZ28cbSOX7xZyu2/fJ+YCCt/uXsNH12RR47HEgwqeL39dmrbut09ck95KbEjBun91a1EWIWl05PPacple3kzWYnRTE91tHtVYSrFFU3DBtuSOseeBfmpsXz9T3vdWzcG4887TtLW3U9H7wAN7VP7naAG9BD69k0L+fK62UOOL8hJJNJqYdP2Srr6BoYdEAWPWvTmLn69tZx+u50vXjKTT19YiDGG37xX7vfrNr5ZRlJMBLed71gu+LpF2aTHR/Hbd/1f//O/lxJts/K5i2a4j9158QxuOz+f1w/Xes1+HatXD9Wy8FuvsPa/3+Czj+/gP186Qn5qLE9+aQ0z0uMQEa5ekMXbx+vPei/VUy1dPLu7+qzbHC5qWruxG8jzs0REXkpMED30FuZkJTAzPf6c9dCNMWw/0cj5M1Ld40znz0ilob2XEw2Bt3gsrW9HBP501xoK02O567c7vSrHAukfsPPYOyfchQ/l47CNZEtn74ip0HNFA/o5EGWzMi8nka0ljhWFRwroqXGRREdYOFLTxu/fr+D6xTkUpMWRlxLLtQuz2LStko4e7+BXVt/OK4dq+OTqAuKcL95Im4WPr5rO60frqPTZbamysZPn9pzi4xfkkxbvvezBBc6BqqM1/geqDp5qDTr4btpeSUK0je/dvIg/f3ENu/7tKjbft5bspMEgdPWCafT023nr2NntYvXb9yq4/097eONI3VndZ7wZY9i0vfKsxw2qnBuh5PnpoecmOyYXBer1GmM4UN3Gotxk8lIcC8L19I//jM2TTV3UtvWwyiMt6VqgbscwefTS+nZyk2PISorm/z67ioRoG595fId7kl4gmw/UUN3Sxf1XOjpa5cP80Rirf3lmPx/d8O6kGAfSgH6OLJueDDjqbnNGWHRLRMhNjuGpXdW09/Rz96WD67nfefEM2rr7eXqX92DnL98uc6RZLiz0Ov6J1QVYRfjd++Vexze8VYpVhLv8rBXvSgntqhz6y9Lc0ctNj2zlv18+Ouz/ARy5y3dKGrhhcQ4fvyCfVTNSSfVZjAwcb7mTYyPOOu1S4ex9fefFQ5O6FLK0vp0Hn97Pf7985Kzu455U5DflEkN3nz3ghiyuAdFFuUlMT43FmHOz9aFr8HPVjMFlo2dlxJEaF8n2E4GDc0nd4CY0Ockx/PZzq7BZhFsefZd/enIvjX5q7o1xVHzNzIjjk2sKsFok5D3
047VneOlADX0DhpcO+N/74FzSgH6OuPLoI/XOXXKSHasuXnpeBgtyBitiluensGR6Mr/eWo7dOWHp87/ZwabtJ/nYyrwhi4xNS4zmmoVZ/GnHSX72Rgm/fucEv3u/gieLq/joyjymecx6dZmeGkN6fKTfgP5eWSP9dsNTO6tG7KW/dazeWfUzbdjrbFYLV8ydxmuHa+k7i3rkisZO0uOjONHQMakXfXLlg5/aVXVWE8iqm7sQwevdjotroDRQKmVfdQsAi3KTPOrWxz+Pvv1EI0kxEcz2WN1URFhZkBKwh263G8rqO7xWRJ09LYG/ff1SvnjJTJ7eVc26/32T379f4VXP/l5ZIweq2/jC2plE2azkpcRQ3hDa/+PP3ighJsLK9NQYXtg7usHa8RBUQBeRZBF5UkSOiMhhEVnjc15E5KciUiIi+0RkeaB7fVCtKEjBInDBzLSRL2bwbfQ9l83yOi4i3HnxDE40dPCxX7zH9T99h+0nmvjHa+bwjevn+73XXWtn0jtg539eOcp3XjzEvz17ABG4+9JZfq8XEZblp7DLz9vZrSUNWATO9PTz4t7heyRbDtaQGhcZ1B+xaxZMo627n21lY5s1aIyhorGDGxZnc+W8zEk9U7a03tFLNMbxzmqsqlu6mJYQTaRt6K9xXurwQXp/dSuRVgvnZcUzPdW1ecv459Ed9edDZ0mvmpFKZVMntX5+Zqdau+jqGxiy5lJclI0Hr5vH5vvWMi87gW88e4CrfvQWz+2pxm43/PKtMtLjI7l5mWPV08K0uJD20MsbOnh+7ynuWF3ALcvz2HaiiZrWoe3/5VtlPLFtbOWWoxVsD/0nwMvGmLnAEsC3oHM9MNv5cRfwaMhaOEXkpcTy0n2XcPv5Q/c39eejK6dz/5WzWTUjdci59QuzyEuJ4fDpNr66roi3/3kd915eRHSE1e+9lkxP5vB3ruXIv1/Lnoeu4r0H1/HuA+vcv8j+rChIobyxc8hb2a0lDVw+J5PzpsUP+yLtG7Dz2pE6rpibGdRGHWtnZxAdYXGvVAmOPP+9f9jFtrKRdzNs7Oilo3eA/NRYvnH9fPoGDP91limN8VJa105OUjQ3Lc1l0/ZKGoKYou9PVXOn3/w5DK7tEqiHfqC6lTlZCUTZHOv82ywy7pUudWe6OdHQMWRTFxg+j+76AxhoV7HzpiWw6Qur+cUnVxBls3DfH/dw1Y/e5I2j9XxqTaH792JGehzlDR0hK1189O+lRFgtfH7tDG5ckoMx8FefLSeP1pzhey8d5l+fOcCjfy8NyfMOZ8TfNBFJAi4BHgMwxvQaY1p8LrsJ+K1xeB9IFhFdXd/HnKyEoHchWp6fwv1Xnud3Wd8Iq4Xn7r2Idx+4gq9fPYekmIgR7yciREdYSY6NJDspZshAqL/nB7xWwqtq7qS8sZOLitL5+Kp89la1uqeO+9pW1sSZ7n73ImUjiYm0cul5Gbx6qBa73fDygdNc//Db/HXfaf7xyX0j1sW79notTI+lMD2OO9fO4Old1SMOmk2E0vp2ZmXG86XLZtHTbx9zeqi6pctv/hwgITqC5AC16MYY9le1sjDXkcqzWoSc5Jhxr3QpLnf8LPx1UubnJBITYfVbj+5KUQ23CY2IcM2CLDZ/dS0P374MYyAx2sYdHus0FaTFhqx0sbqli6d2VXH7qnwyE6KZlRHPgpzEIWmXH716jPhIG+sXZvH9l4+4Z3GPl2CiywygHnhcRHaLyK9ExPdPZS5w0uPzKucxNU7S4qNIih05kI/V4rwkbBbxyqO/66zSuXh2OjcvzyM6wsIT2yr9fv2WQzXERFhZOzs96Oe8ZkEWNW3d3P37ndz9+13MTI/jv29ZTGVTJ794c/jURGWToxeXn+p4aX758iKmJUbx3b8eGu7LzspYenrGGErrO5iVEU9RZjzrF2bx23cr3DM2gzVgN5xu6Q7YQwdnpYufIF3Z1Elbdz+L8wbHZqanxnDSz2zRjp7+gMtH2O2Gb79
wkH1VLUG1eWdFM1E2i/sPiacIq4UVBSm87yflVlrfTnJshN8BdV8Wi3Djkhxe/fqlbH1gndfXFKY7XhuhSLts+HspIngVFdy4JIc9J1vcFWUHqlt5+WANn7t4Bg/fvowbFmfzvc1H+NVZpNlGEkxAtwHLgUeNMcuADuCBsTyZiNwlIsUiUlxff3Ylamp8RUdYmZ+T6NXDfaekgYyEKGZnxpMUE8GNi3N4bk/1kF94YwxbDtZyyXnpAdNA/qybm4nVImw5VMvnLprBX+6+kI+dP53rF2Xz87+X+A04LuUNnYjgnqwSF2Xj3suL2F3Zwt4Q1wgbY/jW8we55sdvjXoQt7ath/aefnf64J7LijjT08/vAswtCHyfbvrtxj0JzR9HLfrQgO6aIbrII7DmJfufiPSVTbv5XIA1hPZXt/L41nL+tOOk3/O+9lW1sCAncciG6i5rZqVxtPbMkBRUaV07RRnBbULjYrUICdHeHZ5C5zaQZ1u6WNfWzZ+KT3LrijyvirUbFjuSEi84Z7L+6NVjJMVEcOdax2TDH/+/pVy/KJvv/vXwuA3aBxPQq4AqY8w25+dP4gjwnqoBz+RwnvOYF2PMRmPMSmPMyoyMDN/TapJZnp/CvqpW+gfsGGN4t7SBi2aluX+xPrG6gM7eAZ7b4/02c391KzVt3UPWtBlJcmwk/3PrYh7/7Pk8dON892Dfv14/D4sI//5i4N52ZVMnOUkxRNkG/4DcvCyX2Eir3/U/7HbDP/5l75gmIv3ob8f5v3fLOVbbPuqUTmm9I33gGuBbmJvE5XMy+PXW8lFNrHItCxEo5QKDs0V930nsr3IOiE5LcB+bnhpDQ3sPXb2Dqa3efjtbSxrYUd7sd/XPVw85ykwDpd089Q/YOVDdxuK85IDXXDjLUTDwvs+YSWl9e9Cb0AwnLyUmJKWLf9heSd+AfUhRQV5KLCsKUnhh7yl2Vzbz2pE67rpkJonOPyw2q4Uf37aUDy/NYUa6//GAszViQDfG1AAnRWSO89AVgO9v1vPAp5zVLquBVmPMxBdlqrOyvCCFrr4BjtSccfacermwaDCFsiQvifnZiTyxrdIraLxysAarRVg3N3PUz/mR5XlcPsf763KSY/jKFUVsOVTL34/6nzRU0dhBvs8gb0J0BDctzeWFfado7fR+F/Hi/tP8ZWcVDzy9b8ikq+H87r1yfvracW5amkOEVUY9iclfPvjL64po6ujlj9uD6+nCYPXKcCmXvJQYuvoGhkx42V/dytzsBK/qGNcAuWfOfX91Cz3Oev5XPHbkcnEF9MM1Z0Z8p3K8rp2uvgGWOudj+LMoN4n4KBvvlg4G9JbOXhrae5mVefYBMMJqcZQujuLn7ctuNzy1q4oLZ6UN2fgd4MbF2RypOcM/P7WP1LhIPuMzLyTCauHHty3jsjmj/90IRrBVLl8BnhCRfcBS4HsicreI3O08vxkoA0qAXwL3hLqh6txb7qyd31nR7J7lepFHQBcRPrE6n8On2/jm8wfZVtbIgN2RbllVmEpKEDnPYN158QxmpMfxrecP+p3RWNHYSUHa0PTDHavz6e6z86THRKy+ATs/3HKUmelxWEX412f3B5UP/+u+0zz0/EGunJfJ/350CecXpvL6KAN6aX37kE3JVxSkcsGMVDa+VRb0bM1qPzsV+fJXi+5aMtc3j+1vDXVXPjsvJcZri0VwVCAdrT3D0unJ9PbbOVY7/PK3rjy7Z97el81q4YIZqbznEdBd72iGGxAdjcK0uLNKuewob+JkUxe3LM/ze/66xdlYBI7VtnP3pTPds7bPlaACujFmjzNVstgY82FjTLMxZoMxZoPzvDHG3GuMmWWMWWSMKR7fZqtzITc5hmmJUeyqbGZrSQMz0+OGBJCbl+WyfmEWf9x+kv+38X1WfPdVjte1c/WC4ScTjVaUzcq/3TCP8sZOtvjMKG3v6aexo5d8PwF9QU4Sy/KTvVba+0txFeWNnfzLdfP45/Vzeft4A0/
vGj71cvBUK1/70x5W5Kfw8O3LsVktrJubyfG69mFz+75K69uZ6WdT8nsvL6KmrZtnRmiHS1VzF+nxUcOOUbiCtGe6pKKxkzPd/V75c4DpzuDvWbq4/UQT502L59YVeeyoaKL+zGBue8shR4D/2lXnASOnXfacbCUx2ubOYweyZlYaJxo63BOuSutcJYuhCuixVDR2jrl08aldVcRFWrl2of90YmZCNBcVpZOZEMUnVxeeRUvHRmeKqoBEhOX5KRSXN7OtrJELi4ZOioqNtPHoHSvY9dBV/Ozjy7nsvAzmZydy/aLQV61eMjuDmAjrkD0oXVP+AwWLOy4ooKyhg3dLG+nuG+Anrx1jeX4yV8zL5I4LClien8y///XQsPXgfz9aT++AnQ2fXEFMpCOIXu5MKQVKA/lTWtfht5567ex0FuUm8eibpUHt3jNcyaJLrp8ZoK46ad/JXunxUUTaLO4eev+AneLyJlbNSOXahVkYMxjEwZFumZuVwNqidOKjbO6B1kD2VbWwOC854LaLLhfOcrwDdPXSS+rbibRZhiwRPFaF6XG09/SPqXSxq3eAzftruG5R9rDLX//4/y3lmXsvcr9OziUN6GpYy/NTqG7poqN3gIuLApcgxkfZuH5xNj++bRmb71vrtZF2qNisFpZOT6bYZyDSlQP3zaG7XL84m+TYCH7/fgW/ebec2rYe/unauYgIFovw/VsW09HTP+yga0ldO1mJ0aR71O/PTI+jIC026LTLme4+atq6/aYPRIR7Ly+iorGTzT7pjbeP1/PLt8q8epVVzV3D5s8BEqMjSIy2uYN0W3cfG98qY93cTK8BUXCU++WlDJYuHjrdRkfvABfMSGPOtAQK02LdaZfmjl52lDdx1fxpWCzCgpxE9lcH3sKwu8+x0Ntw6RaXuVkJpMZFstW5PG5pXbsjNTbCH4JguUoXK8YwMPrKwRrae/q5ZYX/dItLWnzUsKmw8aQBXQ1reUEy4NgEYXWQyxaMp5WFKRw+3Ua7x2qTrkEufzl0cJRgfnRFHlsO1fKzN0q45LwMr//L7GkJ3HNZEc/tOcW7AdeObx8SiEWEy+dk8m5po1d1iDGGe/+wix9u8V7ArKx++PTB1fOnUZQZz8/fKMFuNxhj+NXbZXzq19v5j82H2eIchLTbjWNjiyCChue66I+9fYLWrj6+7kyT+JqeEutOuWx3bULhXOb22oXZvFfaSGtnH68fcWyEfpVzjZ6FuUkcPt0WcGD04Kk2+u2GJcMMiLpYLMKamWm8V9rorNkPTYWLi+td3HBL9RpjeOd4A20+5bhP7qwiLyWGVX5muk4WGtDVsBbkJBFptbAoN4nk2NANco7VioIU7Ab2eMxgrWzqIC0uckjdsaePX1DAgN3Q1t3PP10zZ8j5L102i0irhTf9LOFrtzsCi7+e9eVzM+npt3uV2v25+CR/3Xeax7eWe81w9S1Z9GWxCPdcNosjNWd45WANDz69n+/+9TDXLshizrQEvvPCITp7+2no6KG33z5iDx0G10Vv7ujlsXdOcO2CLL8TewavdQT/98uaKEyLdS/edu3CLPrthr8druXVQ7VkJUa78/CLcpPo7be7K3h8uQZElwxTsuhpzaw0Trd2c7T2DJVNncwK0YAoDJYuVgSodOns7eerf9zDHY9t4yM/f9f9juVUSxdbSxv4yPK8EdNGE0kDuhpWdISVr6wr4ksBFvI615YXpCACxRWDefSKxk6/A6KeZqTHceuKPO5Yne83oEVHWCnKjOfQ6aGpg9Nt3XT2DvgN6BfMSCUmwupOu9Sf6eE//nqYrMRozvT0e5U1ltS1Y7NIwHcS4JhtmJcSw5c37eaPO07ylXVF/Ozjy/nuzQupbuni4ddLhl0215erh77x7TI6evvdg5j+TE+NpaWzj9auPnaUN3GBxxK3S/KSyE6K5tk91bx1vJ4r52e6B3Zd389AefS9J1uYlhhFVlJwaThXPfqmbZXYTeA1XMbCVbp4wk/
K5WRTJx/5+bu8uO8Un15TQF1bNzf/fCu7K5t5Znc1xsAtyyf3BHgN6GpEX7liNuvHYZBzLBKjI5gzLcFrQk9FYycFwyw05vKDjy7hux9eFPD8vOxEv7vPD7eWSHSElYuK0nn9SB3GGL7z4iG6++z87s5VpMdHeU26Kq1vpyAtNuBMSXAEnK9deR6RzpmF/3D1HCwW4fzCVG5dkccv3yrjzaOOdxHDzRJ1cW04/tg7J7hxcQ5zshICXuvq8b9+pJbWrj6vNVdca6W8fbyBzt4BrvKYNDYzPY64SGvASpd9Va3DTijyNSM9jqzEaJ5yVvyEqmTRxV/p4tvH67nxkXc41dLF4585n2/ftJCn77mI2Egbt218n8e3lrOqMNVv7flkogFdhZ0VBSnsrmxhwG7o6R/gVGtXSH7R5uck0tDe41WeByMvDrVubibVLV1sfKuMF/ae4svripg9LYEbl2Tz+pE69zotrjVcRnLLijz2f+tqPrzMuzf4wPq5xEZaeeSNEiDYHrrjmv4Bu3vXnkBcpYtP7XQE0gtmeueK1ztL9eKjbKz2OOcYGE3y20Nv7eqjrKGDJUEMiLqICBfOSnOPk8xMD3VA9y5dPFpzhjt/U8y0hGhe+MrF7kk/RZnxPHPPhSzMTaKhvYdbVkzu3jloQFdhaGVhCu09/RytOeOc2h54QHQ05mU7eq+HfdIuJXWOxaHSAkyUumyOYxmL/3zpCLMz491Twm9amkvvgJ1XDtTQN2CnorEj6Hywv1U50+Oj+Kdr5zJgNyTHRrj3yRyOK6B/ZHkeM0f4Y+KaLbq1tIHc5JghpYIrC1PJSozminmZXksswODAqG/J5f4qR5APZkDU0xpn2iU3OSbk5X+epYvdfQPc98fdJEbbeOILFwzpGKTFR/HE5y9gwx0ruHVFcEtfTyQN6CrsrCxw9A53VjS5SxZDEdDnZycC/gL6mWEXh8pJjmGuM5XxX7csck+pX5KXRGFaLM/traayqZO+AXPWFRu3r8pn6fRkzssMnDrxNC8rkQfXz+WB9XNHvDYlNoLYSCvGOMYGfFktwrP3XsR3P7xwyLlFeYl099kpqfceGN3rmiGamxxUe11cAT3U6RYYrHSpaOzgB68c5UjNGf7n1iVeJameoiMcE4lCVTo5ns7tvFSlQiAvJYbMhCiKK5rptzveNoci5eJYKz7abw/9mhHWdX/wunmcauliRYF33vlDS3N5+PXj7vVJznaAz2oRnvj8BQwEOdPRYhG+GOSAtogwPSWWo7Vn/K5ZDgQc2HRVvOyvamVuVqL7+N6TLcxIjxv1Us95KbFcPifDPXkrlFy16H/YVsnTu6v55OqCcXmeiaA9dBV2RIQVBY4ZrBWNncRFWgOmQ0bLd2C0sb2H5s6+EXuKl56Xwe2r8occv2mpYyebR51571CU4MVF2dwr+IWaK0UTKKAHMiM9nlg/A6OOAdHg8+eeHv/sKj61pnBMXzscV+ni07urmZURx79cNy/kzzFRNKCrsLSiwDGDdfuJJvLT4ka1VvZw5mUnUFrf7l4kK5jdcoYzKyOeRblJnGrtJjMhatwCcagsnZ5MUWb8qJd3tbpnjA4G9Nq2bmraukdV4XIuuEoXbRbhJ7ctm5Ap+uNFA7oKSyuds/UOnW6jMAT5c5d52Yn02w3Hax2BvCQEq/3dtDQHCN0CU+Ppy+uKeOX+S8b0B3JhbhKHTrfR22/nlYM13P37ncDgqp2TyX1XzOZ/P7Yk4CSrcKUBXYWlBTmJREc4Xr4jTSoajXk+A6Mlde3ERFjJSRr72hw3LslBZHwG+EJNRMY8+LcoN4nuPjuX/s8bfPF3O2lo7+E/bl7IsvyUkb/4HPvI8jxuWjr5yxBHSwdFVViKsFpYkpfMthNNFKSGbrJHYVoc0REWdx69pK6dWZlxZzXde1piNL/+9PmcN8yknqng/MJUIqxCZkIU37h+PtcsmBb0pugqNDSgq7C1sjDFEdBD2EO3WoQ5WYlePfR
QLEo2VaoohjM9NZY9D11NbKQ1ZGMaanSC+vMpIuUisl9E9ojIkM0rROQyEWl1nt8jIg+FvqlKeVu/MJv52YkszAltHnR+dgKHaxwrOp5u9b/crfIvLsqmwXwCjaaHfrkxxv/aog5vG2NuONsGKRWshblJbL5vbcjvOy87kU3bT7LVuZRuOAxmKgU6KKrUEK6B0Rf2OhbW0h66ChfBBnQDbBGRnSJyV4Br1ojIXhF5SUQW+LtARO4SkWIRKa6vH7rutFKTgWsa/98O14643K1Sk0mwAf1iY8xyYD1wr4hc4nN+F1BgjFkCPAw86+8mxpiNzs2mV2ZkZIy1zUqNq4ToCKanxtDdZ6cwPW7Y5W6VmkyCeqUaY6qd/9YBzwCrfM63GWPanY83AxEiEngDSqUmuXnO9Uhma7pFhZERA7qIxIlIgusxcDVwwOeaLHEObYvIKud9G33vpVS4cOXRNX+uwkkwVS7TgGec8doG/MEY87KI3A1gjNkA3Ap8SUT6gS7gNmOCXA5OqUlIA7oKRyMGdGNMGbDEz/ENHo8fAR4JbdOUmjiXnJfO5y+e8YGYEKSmDp0pqpQfsZE2vnHD/IluhlKjosP3Sik1RWhAV0qpKUIDulJKTREa0JVSaorQgK6UUlOEBnSllJoiNKArpdQUoQFdKaWmCJmoGfoiUg9UjPHL04HhNtuYSJO1bZO1XaBtG4vJ2i6YvG2brO2C0bWtwBjjd7naCQvoZ0NEio0xKye6Hf5M1rZN1naBtm0sJmu7YPK2bbK2C0LXNk25KKXUFKEBXSmlpohwDegbJ7oBw5isbZus7QJt21hM1nbB5G3bZG0XhKhtYZlDV0opNVS49tCVUkr50ICulFJTRNgFdBG5VkSOikiJiDwwwW35tYjUicgBj2OpIvKqiBx3/psyAe2aLiJviMghETkoIvdNhraJSLSIbBeRvc52fdt5fIaIbHP+TP8kIpHnsl0+bbSKyG4ReXEytU1EykVkv4jsEZFi57HJ8FpLFpEnReSIiBwWkTWTpF1znN8r10ebiNw/Sdr2Nefr/4CIbHL+XoTkdRZWAV1ErMDPgPXAfOB2EZnIbWX+D7jW59gDwGvGmNnAa87Pz7V+4B+MMfOB1cC9zu/TRLetB1hnjFkCLAWuFZHVwPeBHxljioBm4M5z3C5P9wGHPT6fTG273Biz1KNeeaJ/ngA/AV42xszFsVXl4cnQLmPMUef3aimwAugEnpnotolILvBVYKUxZiFgBW4jVK8zY0zYfABrgFc8Pn8QeHCC21QIHPD4/CiQ7XycDRydBN+354CrJlPbgFhgF3ABjhlyNn8/43Pcpjwcv+TrgBcBmURtKwfSfY5N6M8TSAJO4CyumCzt8tPOq4Gtk6FtQC5wEkjFsQXoi8A1oXqdhVUPncFvhkuV89hkMs0Yc9r5uAaYNpGNEZFCYBmwjUnQNmdKYw9QB7wKlAItxph+5yUT+TP9MfBPgN35eRqTp20G2CIiO0XkLuexif55zgDqgcedaapfiUjcJGiXr9uATc7HE9o2Y0w18AOgEjgNtAI7CdHrLNwCelgxjj+3E1YXKiLxwFPA/caYNs9zE9U2Y8yAcbwNzgNWAXPPdRv8EZEbgDpjzM6JbksAFxtjluNIN94rIpd4npygn6cNWA48aoxZBnTgk8KYBL8DkcCHgL/4npuItjlz9jfh+GOYA8QxNG07ZuEW0KuB6R6f5zmPTSa1IpIN4Py3biIaISIROIL5E8aYpydT2wCMMS3AGzjeXiaLiM15aqJ+phcBHxKRcuCPONIuP5kkbXP17DDG1OHIBa9i4n+eVUCVMWab8/MncQT4iW6Xp/XALmNMrfPziW7blcAJY0y9MaYPeBrHay8kr7NwC+g7gNnOEeFIHG+lnp/gNvl6Hvi08/GnceSvzykREeAx4LAx5oeTpW0ikiEiyc7HMTjy+odxBPZbJ6pdAMaYB40xecaYQhyvq9eNMZ+YDG0TkTgRSXA9xpE
TPsAE/zyNMTXASRGZ4zx0BXBootvl43YG0y0w8W2rBFaLSKzz99T1PQvN62wiByvGOKhwHXAMR+71Xye4LZtw5MH6cPRW7sSRd30NOA78DUidgHZdjOOt5D5gj/PjuoluG7AY2O1s1wHgIefxmcB2oATHW+OoCf65Xga8OFna5mzDXufHQdfrfqJ/ns42LAWKnT/TZ4GUydAuZ9vigEYgyePYhLcN+DZwxPk78DsgKlSvM536r5RSU0S4pVyUUkoFoAFdKaWmCA3oSik1RWhAV0qpKUIDulJKTREa0JVSaorQgK6UUlPE/wdnqbx6GpBeZAAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] @@ -288,16 +367,16 @@ }, { "cell_type": "code", - "execution_count": 127, + "execution_count": 188, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "the input words is: praise., How\n", - "the predict words is: much\n", - "the true words is: much\n" + "the input words is: whiles, thou\n", + "the predict words is: art\n", + "the true words is: art\n" ] } ], @@ -318,13 +397,20 @@ " print('the true words is: ' + y_data)\n", "test(model)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { "kernelspec": { - "display_name": "Python 3.7.3 64-bit", + "display_name": "Python 3", "language": "python", - "name": "python_defaultSpec_1598180286976" + "name": "python3" }, "language_info": { "codemirror_mode": { @@ -336,7 +422,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.3-final" + "version": "3.7.6" } }, "nbformat": 4, diff --git a/paddle2.0_docs/text_generation/text_generation_paddle.ipynb b/paddle2.0_docs/text_generation/text_generation_paddle.ipynb deleted file mode 100644 index fc47a419..00000000 --- a/paddle2.0_docs/text_generation/text_generation_paddle.ipynb +++ /dev/null @@ -1,526 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# 基于GRU的Text Generation\n", - "文本生成是NLP领域中的重要组成部分,基于GRU,我们可以快速构建文本生成模型。" - ] - }, - { - "cell_type": "code", - "execution_count": 74, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'2.0.0-alpha0'" - ] - }, - "execution_count": 74, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import paddle\n", - "import numpy as np\n", - "import matplotlib.pyplot as plt\n", - "\n", - "paddle.__version__" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# 复现过程\n", - "## 1.下载数据\n", - "文件路径:https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt\n", - 
"保存为txt格式即可" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 2.读取数据" - ] - }, - { - "cell_type": "code", - "execution_count": 60, - "metadata": { - "tags": [] - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Length of text: 1115394 characters\n" - ] - } - ], - "source": [ - "# 文件路径\n", - "path_to_file = './shakespeare.txt'\n", - "text = open(path_to_file, 'rb').read().decode(encoding='utf-8')\n", - "\n", - "# 文本长度是指文本中的字符个数\n", - "print ('Length of text: {} characters'.format(len(text)))" - ] - }, - { - "cell_type": "code", - "execution_count": 61, - "metadata": { - "tags": [] - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "First Citizen:\n", - "Before we proceed any further, hear me speak.\n", - "\n", - "All:\n", - "Speak, speak.\n", - "\n", - "First Citizen:\n", - "You are all resolved rather to die than to famish?\n", - "\n", - "All:\n", - "Resolved. resolved.\n", - "\n", - "First Citizen:\n", - "First, you know Caius Marcius is chief enemy to the people.\n", - "\n" - ] - } - ], - "source": [ - "# 看一看文本中的前 250 个字符\n", - "print(text[:250])" - ] - }, - { - "cell_type": "code", - "execution_count": 62, - "metadata": { - "tags": [] - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "65 unique characters\n" - ] - } - ], - "source": [ - "# 文本中的非重复字符\n", - "vocab = sorted(set(text))\n", - "print ('{} unique characters'.format(len(vocab)))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 3.向量化文本\n", - "在训练之前,我们需要将字符串映射到数字表示值。创建两个查找表格:一个将字符映射到数字,另一个将数字映射到字符。" - ] - }, - { - "cell_type": "code", - "execution_count": 63, - "metadata": {}, - "outputs": [], - "source": [ - "# 创建从非重复字符到索引的映射\n", - "char2idx = {u:i for i, u in enumerate(vocab)}\n", - "idx2char = np.array(vocab)\n", - "# 用index表示文本\n", - "text_as_int = np.array([char2idx[c] for c in text])\n" - ] - }, - { - "cell_type": "code", - 
"execution_count": 64, - "metadata": { - "tags": [] - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'\\n': 0, ' ': 1, '!': 2, '$': 3, '&': 4, \"'\": 5, ',': 6, '-': 7, '.': 8, '3': 9, ':': 10, ';': 11, '?': 12, 'A': 13, 'B': 14, 'C': 15, 'D': 16, 'E': 17, 'F': 18, 'G': 19, 'H': 20, 'I': 21, 'J': 22, 'K': 23, 'L': 24, 'M': 25, 'N': 26, 'O': 27, 'P': 28, 'Q': 29, 'R': 30, 'S': 31, 'T': 32, 'U': 33, 'V': 34, 'W': 35, 'X': 36, 'Y': 37, 'Z': 38, 'a': 39, 'b': 40, 'c': 41, 'd': 42, 'e': 43, 'f': 44, 'g': 45, 'h': 46, 'i': 47, 'j': 48, 'k': 49, 'l': 50, 'm': 51, 'n': 52, 'o': 53, 'p': 54, 'q': 55, 'r': 56, 's': 57, 't': 58, 'u': 59, 'v': 60, 'w': 61, 'x': 62, 'y': 63, 'z': 64}\n" - ] - } - ], - "source": [ - "print(char2idx)" - ] - }, - { - "cell_type": "code", - "execution_count": 65, - "metadata": { - "tags": [] - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "['\\n' ' ' '!' '$' '&' \"'\" ',' '-' '.' '3' ':' ';' '?' 'A' 'B' 'C' 'D' 'E'\n", - " 'F' 'G' 'H' 'I' 'J' 'K' 'L' 'M' 'N' 'O' 'P' 'Q' 'R' 'S' 'T' 'U' 'V' 'W'\n", - " 'X' 'Y' 'Z' 'a' 'b' 'c' 'd' 'e' 'f' 'g' 'h' 'i' 'j' 'k' 'l' 'm' 'n' 'o'\n", - " 'p' 'q' 'r' 's' 't' 'u' 'v' 'w' 'x' 'y' 'z']\n" - ] - } - ], - "source": [ - "print(idx2char)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "现在,每个字符都有一个整数表示值。请注意,我们将字符映射至索引 0 至 len(vocab)." - ] - }, - { - "cell_type": "code", - "execution_count": 66, - "metadata": { - "tags": [] - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[18 47 56 ... 
45 8 0]\n", - "1115394\n" - ] - } - ], - "source": [ - "print(text_as_int)\n", - "print(len(text_as_int))" - ] - }, - { - "cell_type": "code", - "execution_count": 67, - "metadata": { - "tags": [] - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "'First Citizen' ---- characters mapped to int ---- > [18 47 56 57 58 1 15 47 58 47 64 43 52]\n" - ] - } - ], - "source": [ - "# 显示文本首 13 个字符的整数映射\n", - "print ('{} ---- characters mapped to int ---- > {}'.format(repr(text[:13]), text_as_int[:13]))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 预测任务\n", - "给定一个字符或者一个字符序列,下一个最可能出现的字符是什么?这就是我们训练模型要执行的任务。输入进模型的是一个字符序列,我们训练这个模型来预测输出 -- 每个时间步(time step)预测下一个字符是什么。" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 创建训练样本和目标\n", - "接下来,将文本划分为样本序列。每个输入序列包含文本中的 seq_length 个字符。\n", - "\n", - "对于每个输入序列,其对应的目标包含相同长度的文本,但是向右顺移一个字符。\n", - "\n", - "将文本拆分为长度为 seq_length 的文本块。例如,假设 seq_length 为 4 而且文本为 “Hello”, 那么输入序列将为 “Hell”,目标序列将为 “ello”。" - ] - }, - { - "cell_type": "code", - "execution_count": 68, - "metadata": {}, - "outputs": [], - "source": [ - "seq_length = 100\n", - "def load_data(data, seq_length):\n", - " train_data = []\n", - " train_label = []\n", - " for i in range(len(data)//seq_length):\n", - " train_data.append(data[i*seq_length:(i+1)*seq_length])\n", - " train_label.append(data[i*seq_length + 1:(i+1)*seq_length+1])\n", - " return train_data, train_label\n", - "train_data, train_label = load_data(text_as_int, seq_length)" - ] - }, - { - "cell_type": "code", - "execution_count": 69, - "metadata": { - "tags": [] - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "training data is :\n", - "First Citizen:\n", - "Before we proceed any further, hear me speak.\n", - "\n", - "All:\n", - "Speak, speak.\n", - "\n", - "First Citizen:\n", - "You\n", - "------------\n", - "training_label is:\n", - "irst Citizen:\n", - "Before we proceed any further, hear me 
speak.\n", - "\n", - "All:\n", - "Speak, speak.\n", - "\n", - "First Citizen:\n", - "You \n" - ] - } - ], - "source": [ - "char_list = []\n", - "label_list = []\n", - "for char_id, label_id in zip(train_data[0], train_label[0]):\n", - " char_list.append(idx2char[char_id])\n", - " label_list.append(idx2char[label_id])\n", - "\n", - "print('training data is :')\n", - "print(''.join(char_list))\n", - "print(\"------------\")\n", - "print('training_label is:')\n", - "print(''.join(label_list))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 用`paddle.batch`完成数据的加载" - ] - }, - { - "cell_type": "code", - "execution_count": 70, - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n", - "batch_size = 64\n", - "def train_reader():\n", - " for i in range(len(train_data)):\n", - " yield train_data[i], train_label[i]\n", - "batch_reader = paddle.batch(train_reader, batch_size=batch_size) " - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 基于GRU构建文本生成模型" - ] - }, - { - "cell_type": "code", - "execution_count": 71, - "metadata": {}, - "outputs": [], - "source": [ - "import paddle\n", - "import numpy as np\n", - "\n", - "vocab_size = len(vocab)\n", - "embedding_dim = 256\n", - "hidden_size = 1024\n", - "class GRUModel(paddle.nn.Layer):\n", - " def __init__(self):\n", - " super(GRUModel, self).__init__()\n", - " self.embedding = paddle.nn.Embedding(size=[vocab_size, embedding_dim])\n", - " self.gru = paddle.incubate.hapi.text.GRU(input_size=embedding_dim, hidden_size=hidden_size)\n", - " self.linear1 = paddle.nn.Linear(hidden_size, hidden_size//2)\n", - " self.linear2 = paddle.nn.Linear(hidden_size//2, vocab_size)\n", - " def forward(self, x):\n", - " x = self.embedding(x)\n", - " x = paddle.reshape(x, [-1, 1, embedding_dim])\n", - " x, _ = self.gru(x)\n", - " x = paddle.reshape(x, [-1, hidden_size])\n", - " x = self.linear1(x)\n", - " x = paddle.nn.functional.relu(x)\n", - " x = self.linear2(x)\n", - " x 
= paddle.nn.functional.softmax(x)\n", - " return x" - ] - }, - { - "cell_type": "code", - "execution_count": 72, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "epoch: 0, batch: 50, loss is: [3.7835407]\n", - "epoch: 0, batch: 100, loss is: [3.2774005]\n", - "epoch: 0, batch: 150, loss is: [3.2576294]\n", - "epoch: 1, batch: 50, loss is: [3.3434656]\n", - "epoch: 1, batch: 100, loss is: [2.9948606]\n", - "epoch: 1, batch: 150, loss is: [3.0285468]\n", - "epoch: 2, batch: 50, loss is: [3.133882]\n", - "epoch: 2, batch: 100, loss is: [2.7811327]\n", - "epoch: 2, batch: 150, loss is: [2.8133557]\n", - "epoch: 3, batch: 50, loss is: [3.000814]\n", - "epoch: 3, batch: 100, loss is: [2.6404488]\n", - "epoch: 3, batch: 150, loss is: [2.7050896]\n", - "epoch: 4, batch: 50, loss is: [2.9289591]\n", - "epoch: 4, batch: 100, loss is: [2.5629177]\n", - "epoch: 4, batch: 150, loss is: [2.6438713]\n", - "epoch: 5, batch: 50, loss is: [2.8832304]\n", - "epoch: 5, batch: 100, loss is: [2.5137548]\n", - "epoch: 5, batch: 150, loss is: [2.5926144]\n", - "epoch: 6, batch: 50, loss is: [2.8562953]\n", - "epoch: 6, batch: 100, loss is: [2.4752126]\n", - "epoch: 6, batch: 150, loss is: [2.5510798]\n", - "epoch: 7, batch: 50, loss is: [2.8426895]\n", - "epoch: 7, batch: 100, loss is: [2.4442513]\n", - "epoch: 7, batch: 150, loss is: [2.5187433]\n", - "epoch: 8, batch: 50, loss is: [2.8353484]\n", - "epoch: 8, batch: 100, loss is: [2.4200597]\n", - "epoch: 8, batch: 150, loss is: [2.4956212]\n", - "epoch: 9, batch: 50, loss is: [2.8308532]\n", - "epoch: 9, batch: 100, loss is: [2.4011066]\n", - "epoch: 9, batch: 150, loss is: [2.4787998]\n" - ] - } - ], - "source": [ - "paddle.enable_imperative()\n", - "losses = []\n", - "def train(model):\n", - " model.train()\n", - " optim = paddle.optimizer.SGD(learning_rate=0.001, parameter_list=model.parameters())\n", - " for epoch in range(10):\n", - " batch_id = 0\n", - " for batch_data in 
batch_reader():\n", - " batch_id += 1\n", - " data = np.array(batch_data)\n", - " x_data = data[:, 0]\n", - " y_data = data[:, 1]\n", - " for i in range(len(x_data[0])):\n", - " x_char = x_data[:, i]\n", - " y_char = y_data[:, i]\n", - " x_char = paddle.imperative.to_variable(x_char)\n", - " y_char = paddle.imperative.to_variable(y_char)\n", - " predicts = model(x_char)\n", - " loss = paddle.nn.functional.cross_entropy(predicts, y_char)\n", - " avg_loss = paddle.mean(loss)\n", - " avg_loss.backward()\n", - " optim.minimize(avg_loss)\n", - " model.clear_gradients()\n", - " if batch_id % 50 == 0:\n", - " print(\"epoch: {}, batch: {}, loss is: {}\".format(epoch, batch_id, avg_loss.numpy()))\n", - " losses.append(loss.numpy())\n", - "model = GRUModel()\n", - "train(model)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 模型预测\n", - "利用训练好的模型,输出初始化文本'ROMEO: ',自动生成后续的num_generate个字符。" - ] - }, - { - "cell_type": "code", - "execution_count": 73, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "ROMEO:I the the the the the the the the the the the the the the the the the the the the the the the the th\n" - ] - } - ], - "source": [ - "def generate_text(model, start_string):\n", - " \n", - " model.eval()\n", - " num_generate = 100\n", - "\n", - " # Converting our start string to numbers (vectorizing)\n", - " input_eval = [char2idx[s] for s in start_string]\n", - " input_data = paddle.imperative.to_variable(np.array(input_eval))\n", - " input_data = paddle.reshape(input_data, [-1, 1])\n", - " text_generated = []\n", - "\n", - " for i in range(num_generate):\n", - " predicts = model(input_data)\n", - " predicts = predicts.numpy().tolist()[0]\n", - " # print(predicts)\n", - " predicts_id = predicts.index(max(predicts))\n", - " # print(predicts_id)\n", - " # using a categorical distribution to predict the character returned by the model\n", - " input_data = 
paddle.imperative.to_variable(np.array([predicts_id]))\n", - " input_data = paddle.reshape(input_data, [-1, 1])\n", - " text_generated.append(idx2char[predicts_id])\n", - " return (start_string + ''.join(text_generated))\n", - "print(generate_text(model, start_string=u\"ROMEO:\"))" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.6" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -}