{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 3. Chapter 3 - Solving a Real Problem\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Contents:\n",
"* 3.1 Loading the MNIST dataset\n",
"* 3.2 Defining the model\n",
"* 3.3 Choosing a loss function and an optimizer\n",
"* 3.4 Training and validating the model\n",
"* 3.5 Visual verification"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's use Jittor to solve your first real problem!"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[38;5;2m[i 0202 23:00:55.471709 24 compiler.py:847] Jittor(1.2.2.27) src: /home/llt/.local/lib/python3.7/site-packages/jittor\u001b[m\n",
"\u001b[38;5;2m[i 0202 23:00:55.473366 24 compiler.py:848] g++ at /usr/bin/g++\u001b[m\n",
"\u001b[38;5;2m[i 0202 23:00:55.474573 24 compiler.py:849] cache_path: /home/llt/.cache/jittor/default/g++\u001b[m\n",
"\u001b[38;5;2m[i 0202 23:00:55.489276 24 __init__.py:257] Found /usr/local/cuda/bin/nvcc(10.2.89) at /usr/local/cuda/bin/nvcc.\u001b[m\n",
"\u001b[38;5;2m[i 0202 23:00:55.567521 24 __init__.py:257] Found gdb(8.1.0) at /usr/bin/gdb.\u001b[m\n",
"\u001b[38;5;2m[i 0202 23:00:55.583515 24 __init__.py:257] Found addr2line(2.30) at /usr/bin/addr2line.\u001b[m\n",
"\u001b[38;5;2m[i 0202 23:00:55.632399 24 compiler.py:889] pybind_include: -I/usr/include/python3.7m -I/usr/local/lib/python3.7/dist-packages/pybind11/include\u001b[m\n",
"\u001b[38;5;2m[i 0202 23:00:55.657575 24 compiler.py:891] extension_suffix: .cpython-37m-x86_64-linux-gnu.so\u001b[m\n",
"\u001b[38;5;2m[i 0202 23:00:55.868891 24 __init__.py:169] Total mem: 62.78GB, using 16 procs for compiling.\u001b[m\n",
"\u001b[38;5;2m[i 0202 23:00:56.064987 24 jit_compiler.cc:21] Load cc_path: /usr/bin/g++\u001b[m\n",
"\u001b[38;5;2m[i 0202 23:00:56.261122 24 init.cc:54] Found cuda archs: [75,]\u001b[m\n",
"\u001b[38;5;2m[i 0202 23:00:56.443019 24 __init__.py:257] Found mpicc(2.1.1) at /usr/bin/mpicc.\u001b[m\n",
"\u001b[38;5;2m[i 0202 23:00:56.511110 24 compiler.py:654] handle pyjt_include/home/llt/.local/lib/python3.7/site-packages/jittor/extern/mpi/inc/mpi_warper.h\u001b[m\n",
"\u001b[38;5;2m[i 0202 23:00:56.556598 24 compile_extern.py:287] Downloading nccl...\u001b[m\n",
"\u001b[38;5;2m[i 0202 23:00:56.632489 24 compile_extern.py:16] found /usr/local/cuda/include/cublas.h\u001b[m\n",
"\u001b[38;5;2m[i 0202 23:00:56.634748 24 compile_extern.py:16] found /usr/lib/x86_64-linux-gnu/libcublas.so\u001b[m\n",
"\u001b[38;5;2m[i 0202 23:00:57.221814 24 compile_extern.py:16] found /usr/local/cuda/include/cudnn.h\u001b[m\n",
"\u001b[38;5;2m[i 0202 23:00:57.223031 24 compile_extern.py:16] found /usr/local/cuda/lib64/libcudnn.so\u001b[m\n",
"\u001b[38;5;2m[i 0202 23:00:57.255155 24 compiler.py:654] handle pyjt_include/home/llt/.local/lib/python3.7/site-packages/jittor/extern/cuda/cudnn/inc/cudnn_warper.h\u001b[m\n",
"\u001b[38;5;2m[i 0202 23:00:59.305098 24 compile_extern.py:16] found /usr/local/cuda/include/curand.h\u001b[m\n",
"\u001b[38;5;2m[i 0202 23:00:59.306353 24 compile_extern.py:16] found /usr/local/cuda/lib64/libcurand.so\u001b[m\n",
"\u001b[38;5;2m[i 0202 23:00:59.338311 24 cuda_flags.cc:26] CUDA enabled.\u001b[m\n"
]
}
],
"source": [
"# Load Jittor\n",
"import jittor as jt\n",
"\n",
"# Enable GPU acceleration\n",
"jt.flags.use_cuda = 1"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Task: Recognizing MNIST Handwritten Digits with Jittor"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Task description:** \n",
"* The MNIST handwritten digit database collects real handwritten digit samples from many different writers, covering the ten digits 0 through 9. It contains 60,000 training examples and 10,000 test examples. For details, see http://yann.lecun.com/exdb/mnist/ \n",
"\n",
"\n",
"* Sample MNIST handwritten digits: \n",
"\n",
"<!-- ![avatar](mnist.png) -->\n",
"<!-- ![mnist.png](attachment:mnist.png) -->\n",
"<img src=\"mnist.png\" width=600 height=600>\n",
"\n",
"* Goal: use Jittor to recognize MNIST handwritten digits, that is, to identify the digit that each handwritten image represents.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Steps:**\n",
"1. Load the MNIST handwritten digit **dataset** with Jittor.\n",
"\n",
"\n",
"2. Define the **model**: design a convolutional neural network.\n",
"\n",
"\n",
"3. Choose a suitable **loss function** and **optimizer**.\n",
"\n",
"\n",
"4. Write the main code block that **trains** and **validates** the model."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3.1 Loading the MNIST Dataset"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"MNIST is a widely used dataset, so for convenience it is already packaged with Jittor. We can load the data we need directly through the MNIST class."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from jittor.dataset.mnist import MNIST\n",
"import jittor.transform as trans\n",
"\n",
"# Hyperparameter batch_size: the number of samples in each batch.\n",
"batch_size = 64\n",
"\n",
"# Create the data loader for the MNIST training set\n",
"train_loader = MNIST(train=True, transform=trans.Resize(28)).set_attrs(batch_size=batch_size, shuffle=True)\n",
"# Create the data loader for the MNIST test set\n",
"val_loader = MNIST(train=False, transform=trans.Resize(28)).set_attrs(batch_size=batch_size, shuffle=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To get a better feel for the MNIST handwritten digit dataset, let's visualize some of the data. \n",
"We use the test set as an example:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"inputs.shape: [64,3,28,28,]\n",
"targets.shape: [64,]\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAM3UlEQVR4nO3db6wU9b3H8c/nUpoY6QPwDzmhKL3EJ+QmwhWJCXhzDGmD+AAbiSkPGm7SePoATBsbco33AT5sTP/k+sAmp9GUNlwbEqoSYywUG0mjNh4NygECIkEB+WODScHEIPbbB2dsjrg7u+7M7ix836/kZHfnuzPzzcjHmZ3Z2Z8jQgCufv/WdAMABoOwA0kQdiAJwg4kQdiBJL42yJXZ5tQ/0GcR4VbTK+3Zba+yfcj2EdsPV1kWgP5yr9fZbc+QdFjStyWdkPS6pHURcaBkHvbsQJ/1Y8++TNKRiDgaERcl/V7SmgrLA9BHVcI+T9Lxaa9PFNO+wPaY7QnbExXWBaCivp+gi4hxSeMSh/FAk6rs2U9Kmj/t9TeLaQCGUJWwvy7pFtvfsv11Sd+TtKOetgDUrefD+Ii4ZHujpD9KmiHpqYjYX1tnAGrV86W3nlbGZ3ag7/rypRoAVw7CDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBI9j88uSbaPSTov6TNJlyJiaR1NAahfpbAX7oqIv9WwHAB9xGE8kETVsIeknbbfsD3W6g22x2xP2J6ouC4AFTgiep/ZnhcRJ23fKGmXpAcjYk/J+3tfGYCuRIRbTa+0Z4+Ik8XjWUnPSFpWZXkA+qfnsNu+1vY3Pn8u6TuSJutqDEC9qpyNnyvpGdufL+f/I+LFWroCULtKn9m/8sr4zA70XV8+swO4chB2IAnCDiRB2IEkCDuQRB03wqSwdu3atrUHHnigdN4PPvigtP7JJ5+U1rdu3VpaP336dNvakSNHSudFHuzZgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJ7nrr0tGjR9vWFixYMLhGWjh//nzb2v79+wfYyXA5ceJE29pjjz1WOu/ExJX7K2rc9QYkR9iBJAg7kARhB5Ig7EAShB1IgrADSXA/e5fK7lm/9dZbS+c9cOBAaX3RokWl9SVLlpTWR0dH29buuOOO0nmPHz9eWp8/f35pvYpLly6V1j/88MPS+sjISM/rfv/990vrV/J19nbYswNJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEtzPfhWYPXt221qna/SdrifffvvtPfXUjU6/l3/48OHS+sGDB0vrc+bMaVvbuHFj6bxPPPFEaX2Y9Xw/u+2nbJ+1PTlt2hzbu2y/Uzy2/9cGYCh0cxj/G0mrLpv2sKTdEXGLpN3FawBDrGPYI2KPpHOXTV4jaUvxfIuke2vuC0DNev1u/NyIOFU8Py1pbrs32h6TNNbjegDUpPKNMBERZSfeImJc0rjECTqgSb1eejtje0SSisez9bUEoB96DfsOSeuL5+slPVdPOwD6peN1dttPSxqVdL2kM5I2S3pW0jZJN0l6T9L9EXH5SbxWy+IwHl277777Suvbtm0rrU9OTrat3XXXXaXznjvX8Z/z0Gp3nb3jZ/aIWNemtLJSRwAGiq/LAkkQdiAJwg4kQdiBJAg7kAS3uKIxN954Y2l93759leZfu3Zt29r27dtL572SMWQzkBxhB5Ig7EAShB1IgrADSRB2IAnCDiTBkM1ozIYNG0rrN9xwQ2n9o48+Kq0fOnToK/d0NWPPDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJcD87+mr58uVtay+99FLpvDNnziytj46Oltb37NlTWr9acT87kBxhB5Ig7EAShB
1IgrADSRB2IAnCDiTB/ezoq9WrV7etdbqOvnv37tL6q6++2lNPWXXcs9t+yvZZ25PTpj1q+6TtvcVf+/+iAIZCN4fxv5G0qsX0X0bE4uLvhXrbAlC3jmGPiD2Szg2gFwB9VOUE3UbbbxeH+bPbvcn2mO0J2xMV1gWgol7D/itJCyUtlnRK0s/bvTEixiNiaUQs7XFdAGrQU9gj4kxEfBYR/5D0a0nL6m0LQN16CrvtkWkvvytpst17AQyHjtfZbT8taVTS9bZPSNosadT2Ykkh6ZikH/axRwyxa665prS+alWrCzlTLl68WDrv5s2bS+uffvppaR1f1DHsEbGuxeQn+9ALgD7i67JAEoQdSIKwA0kQdiAJwg4kwS2uqGTTpk2l9SVLlrStvfjii6XzvvLKKz31hNbYswNJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEgzZjFL33HNPaf3ZZ58trX/88cdta3fffXfpvPxUdG8YshlIjrADSRB2IAnCDiRB2IEkCDuQBGEHkuB+9uSuu+660vrjjz9eWp8xY0Zp/YUX2o/5yXX0wWLPDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJcD/7Va7TdfDXXnuttH7bbbeV1t99993SetmQzZ3mRW96vp/d9nzbf7Z9wPZ+2z8qps+xvcv2O8Xj7LqbBlCfbg7jL0n6SUQsknSHpA22F0l6WNLuiLhF0u7iNYAh1THsEXEqIt4snp+XdFDSPElrJG0p3rZF0r39ahJAdV/pu/G2F0haIumvkuZGxKmidFrS3DbzjEka671FAHXo+my87VmStkv6cUT8fXotps7ytTz5FhHjEbE0IpZW6hRAJV2F3fZMTQV9a0T8oZh8xvZIUR+RdLY/LQKoQ8fDeNuW9KSkgxHxi2mlHZLWS/pp8fhcXzpEJQsXLiytd7q01slDDz1UWufy2vDo5jP7cknfl7TP9t5i2iOaCvk22z+Q9J6k+/vTIoA6dAx7RPxFUsuL9JJW1tsOgH7h67JAEoQdSIKwA0kQdiAJwg4kwU9JXwVuvvnmtrWdO3dWWvamTZtK688//3yl5WNw2LMDSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBJcZ78KjI21/9Wvm266qdKyX3755dL6IH+KHNWwZweSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJLjOfgW48847S+sPPvjggDrBlYw9O5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4k0c347PMl/VbSXEkhaTwi/s/2o5IekPRh8dZHIuKFfjWa2YoVK0rrs2bN6nnZncZPv3DhQs/LxnDp5ks1lyT9JCLetP0NSW/Y3lXUfhkRP+tfewDq0s347KcknSqen7d9UNK8fjcGoF5f6TO77QWSlkj6azFpo+23bT9le3abecZsT9ieqNQpgEq6DrvtWZK2S/pxRPxd0q8kLZS0WFN7/p+3mi8ixiNiaUQsraFfAD3qKuy2Z2oq6Fsj4g+SFBFnIuKziPiHpF9LWta/NgFU1THsti3pSUkHI+IX06aPTHvbdyVN1t8egLp0czZ+uaTvS9pne28x7RFJ62wv1tTluGOSftiXDlHJW2+9VVpfuXJlaf3cuXN1toMGdXM2/i+S3KLENXXgCsI36IAkCDuQBGEHkiDsQBKEHUiCsANJeJBD7tpmfF+gzyKi1aVy9uxAFoQdSIKwA0kQdiAJwg4kQdiBJAg7kMSgh2z+m6T3pr2+vpg2jIa1t2HtS6K3XtXZ283tCgP9Us2XVm5PDOtv0w1rb8Pal0RvvRpUbxzGA0kQdiCJpsM+3vD6ywxrb8Pal0RvvRpIb41+ZgcwOE3v2QEMCGEHkmgk7LZX2T5k+4jth5vooR3bx2zvs7236fHpijH0ztqenDZtju1dtt8pHluOsddQb4/aPllsu722VzfU23zbf7Z9wPZ+2z8qpje67Ur6Gsh2G/hndtszJB2W9G1JJyS9LmldRBwYaCNt2D4maWlENP4FDN
v/JemCpN9GxH8U0x6TdC4iflr8j3J2RPzPkPT2qKQLTQ/jXYxWNDJ9mHFJ90r6bzW47Ur6ul8D2G5N7NmXSToSEUcj4qKk30ta00AfQy8i9ki6fEiWNZK2FM+3aOofy8C16W0oRMSpiHizeH5e0ufDjDe67Ur6Gogmwj5P0vFpr09ouMZ7D0k7bb9he6zpZlqYGxGniuenJc1tspkWOg7jPUiXDTM+NNuul+HPq+IE3ZetiIj/lHS3pA3F4epQiqnPYMN07bSrYbwHpcUw4//S5LbrdfjzqpoI+0lJ86e9/mYxbShExMni8aykZzR8Q1Gf+XwE3eLxbMP9/MswDePdaphxDcG2a3L48ybC/rqkW2x/y/bXJX1P0o4G+vgS29cWJ05k+1pJ39HwDUW9Q9L64vl6Sc812MsXDMsw3u2GGVfD267x4c8jYuB/klZr6oz8u5L+t4ke2vT175LeKv72N92bpKc1dVj3qabObfxA0nWSdkt6R9KfJM0Zot5+J2mfpLc1FayRhnpboalD9Lcl7S3+Vje97Ur6Gsh24+uyQBKcoAOSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJP4JQ6Ub9g3W1GYAAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"target: 7\n"
]
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"num = 0  # index of the input sample to display\n",
"for inputs, targets in val_loader:  # iterate over the test loader batch by batch\n",
"    print(\"inputs.shape:\", inputs.shape)  # inspect the shape of inputs\n",
"    print(\"targets.shape:\", targets.shape)  # inspect the shape of targets\n",
"\n",
"    plt.imshow(inputs[num].numpy().transpose(1, 2, 0))  # draw the selected input with matplotlib (channels moved to the last axis)\n",
"    plt.show()  # display the image\n",
"    print(\"target:\", targets[num].numpy()[0])  # print the true label of the selected input, i.e. the digit the image represents\n",
"    break"
]
},
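{
"cell_type": "markdown",
"metadata": {},
"source": [
"From the printed output above, we can see that:\n",
"* In every batch the loader yields, both inputs and targets contain 64 samples, which matches the batch_size we set;\n",
"\n",
"\n",
"* Each sample in inputs has 3 channels, and each channel is a 28\\*28 grid of pixel values. We can draw the corresponding handwritten digit with the imshow() function;\n",
"\n",
"\n",
"* Each sample in targets is a single number: the true digit that the handwritten image represents."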
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3.2 Defining the Model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As before, we subclass Module and implement the \\_\\_init__ and execute functions:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"from jittor import nn, Module\n",
"\n",
"class Model(Module):\n",
"    def __init__(self):\n",
"        super(Model, self).__init__()\n",
"        self.conv1 = nn.Conv(3, 32, 3, 1)  # conv layer 1: 3 input channels, 32 output channels, 3*3 kernel, stride 1\n",
"        self.conv2 = nn.Conv(32, 64, 3, 1)  # conv layer 2: 32 input channels, 64 output channels, 3*3 kernel, stride 1\n",
"        self.bn = nn.BatchNorm(64)  # batch normalization layer over 64 input channels\n",
"\n",
"        self.max_pool = nn.Pool(2, 2)  # pooling layer: window size 2, stride 2\n",
"        self.relu = nn.Relu()  # non-linear activation function Relu\n",
"        self.fc1 = nn.Linear(64 * 12 * 12, 256)  # fully connected layer 1: 64*12*12 input features (from the reshape in execute), 256 output features\n",
"        self.fc2 = nn.Linear(256, 10)  # fully connected layer 2: 256 input features, 10 output features\n",
"\n",
"    def execute(self, x):\n",
"        x = self.conv1(x)  # first conv layer: batch_size*3*28*28 -> batch_size*32*26*26\n",
"        x = self.relu(x)  # apply the Relu activation\n",
"\n",
"        x = self.conv2(x)  # second conv layer: batch_size*32*26*26 -> batch_size*64*24*24\n",
"        x = self.bn(x)  # batch normalization\n",
"        x = self.relu(x)  # apply the Relu activation\n",
"\n",
"        x = self.max_pool(x)  # pooling: batch_size*64*24*24 -> batch_size*64*12*12\n",
"        x = jt.reshape(x, [x.shape[0], -1])  # flatten everything except the batch dimension: batch_size*64*12*12 -> batch_size*(64*12*12)\n",
"        x = self.fc1(x)  # first fully connected layer: batch_size*(64*12*12) -> batch_size*256\n",
"        x = self.relu(x)  # apply the Relu activation\n",
"        x = self.fc2(x)  # second fully connected layer: output is batch_size*10; the 10 components of each sample score its similarity to each of the ten digits\n",
"        return x\n"
]
},
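{
"cell_type": "markdown",
"metadata": {},
"source": [
"The layer shapes noted in the model's comments can be checked by hand. This is a sketch that assumes no padding and no dilation, since the Model above passes neither to nn.Conv or nn.Pool. Under that assumption, a layer with kernel size $k$ and stride $s$ maps an input side length $n$ to:\n",
"\n",
"$$n_{\\text{out}} = \\left\\lfloor \\frac{n - k}{s} \\right\\rfloor + 1$$\n",
"\n",
"Applied to each layer in turn:\n",
"\n",
"* conv1 ($k = 3$, $s = 1$): $(28 - 3)/1 + 1 = 26$\n",
"* conv2 ($k = 3$, $s = 1$): $(26 - 3)/1 + 1 = 24$\n",
"* max_pool ($k = 2$, $s = 2$): $(24 - 2)/2 + 1 = 12$\n",
"\n",
"The flattened feature vector therefore has $64 \\times 12 \\times 12 = 9216$ entries, which is the input size given to fc1."
]
},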
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, we create an instance of the model:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"model = Model()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3.3 Choosing a Loss Function and an Optimizer"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this exercise we use the cross-entropy loss, CrossEntropyLoss(), and stochastic gradient descent (SGD) as the parameter optimizer. \n",
"These common loss functions and optimizers are available from Jittor's nn module."
]
},
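{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, the two ingredients can be written out in their textbook form. This is only a sketch: how Jittor's implementation combines momentum and weight decay internally is an assumption here and may differ in detail. For a batch of $N$ samples with model outputs $z_i \\in \\mathbb{R}^{10}$ and true labels $y_i$, the cross-entropy loss is:\n",
"\n",
"$$L = -\\frac{1}{N} \\sum_{i=1}^{N} \\log \\frac{\\exp(z_{i, y_i})}{\\sum_{c=0}^{9} \\exp(z_{i, c})}$$\n",
"\n",
"A model that assigns equal probability to all ten digits gives $L = \\ln 10 \\approx 2.303$, which is comparable to the loss values printed at the very start of training in section 3.4.\n",
"\n",
"SGD with momentum $\\mu$, weight decay $\\lambda$ and learning rate $\\eta$ then updates each parameter $\\theta$, with gradient $g = \\nabla_\\theta L$ and velocity $v$ (initially 0), as:\n",
"\n",
"$$v \\leftarrow \\mu v + (g + \\lambda \\theta), \\qquad \\theta \\leftarrow \\theta - \\eta v$$"
]
},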
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Set up the loss function\n",
"loss_function = nn.CrossEntropyLoss()\n",
"\n",
"# Set up the optimizer\n",
"learning_rate = 0.1\n",
"momentum = 0.9\n",
"weight_decay = 1e-4\n",
"optimizer = nn.SGD(model.parameters(), learning_rate, momentum, weight_decay)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3.4 Training and Validating the Model"
]
},
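{
"cell_type": "markdown",
"metadata": {},
"source": [
"The progress counters in the output of the next cell count batches, not individual samples. Assuming the loader keeps the final, smaller batch (an inference from the totals in the log, not something the tutorial states), the number of batches per epoch with batch_size = 64 is:\n",
"\n",
"$$\\left\\lceil \\frac{60000}{64} \\right\\rceil = 938, \\qquad \\left\\lceil \\frac{10000}{64} \\right\\rceil = 157$$\n",
"\n",
"These match the totals shown in the Train Epoch and Test Epoch lines."
]
},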
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train Epoch: 0 [0/938 (0%)]\tLoss: 2.367674\n",
"Train Epoch: 0 [10/938 (1%)]\tLoss: 2.302495\n",
"Train Epoch: 0 [20/938 (2%)]\tLoss: 2.215536\n",
"Train Epoch: 0 [30/938 (3%)]\tLoss: 0.997462\n",
"Train Epoch: 0 [40/938 (4%)]\tLoss: 0.957385\n",
"Train Epoch: 0 [50/938 (5%)]\tLoss: 0.932777\n",
"Train Epoch: 0 [60/938 (6%)]\tLoss: 0.419664\n",
"Train Epoch: 0 [70/938 (7%)]\tLoss: 0.452631\n",
"Train Epoch: 0 [80/938 (9%)]\tLoss: 0.402917\n",
"Train Epoch: 0 [90/938 (10%)]\tLoss: 0.428482\n",
"Train Epoch: 0 [100/938 (11%)]\tLoss: 0.604908\n",
"Train Epoch: 0 [110/938 (12%)]\tLoss: 0.391322\n",
"Train Epoch: 0 [120/938 (13%)]\tLoss: 0.407544\n",
"Train Epoch: 0 [130/938 (14%)]\tLoss: 0.389958\n",
"Train Epoch: 0 [140/938 (15%)]\tLoss: 0.894348\n",
"Train Epoch: 0 [150/938 (16%)]\tLoss: 0.264137\n",
"Train Epoch: 0 [160/938 (17%)]\tLoss: 0.304531\n",
"Train Epoch: 0 [170/938 (18%)]\tLoss: 0.083782\n",
"Train Epoch: 0 [180/938 (19%)]\tLoss: 0.360765\n",
"Train Epoch: 0 [190/938 (20%)]\tLoss: 0.162649\n",
"Train Epoch: 0 [200/938 (21%)]\tLoss: 0.390720\n",
"Train Epoch: 0 [210/938 (22%)]\tLoss: 0.625486\n",
"Train Epoch: 0 [220/938 (23%)]\tLoss: 0.324941\n",
"Train Epoch: 0 [230/938 (25%)]\tLoss: 0.438149\n",
"Train Epoch: 0 [240/938 (26%)]\tLoss: 0.323761\n",
"Train Epoch: 0 [250/938 (27%)]\tLoss: 0.349010\n",
"Train Epoch: 0 [260/938 (28%)]\tLoss: 0.289034\n",
"Train Epoch: 0 [270/938 (29%)]\tLoss: 0.306782\n",
"Train Epoch: 0 [280/938 (30%)]\tLoss: 0.572093\n",
"Train Epoch: 0 [290/938 (31%)]\tLoss: 0.184500\n",
"Train Epoch: 0 [300/938 (32%)]\tLoss: 0.260680\n",
"Train Epoch: 0 [310/938 (33%)]\tLoss: 0.285141\n",
"Train Epoch: 0 [320/938 (34%)]\tLoss: 0.440616\n",
"Train Epoch: 0 [330/938 (35%)]\tLoss: 0.283369\n",
"Train Epoch: 0 [340/938 (36%)]\tLoss: 0.209447\n",
"Train Epoch: 0 [350/938 (37%)]\tLoss: 0.192667\n",
"Train Epoch: 0 [360/938 (38%)]\tLoss: 0.228336\n",
"Train Epoch: 0 [370/938 (39%)]\tLoss: 0.172963\n",
"Train Epoch: 0 [380/938 (41%)]\tLoss: 0.314350\n",
"Train Epoch: 0 [390/938 (42%)]\tLoss: 0.260803\n",
"Train Epoch: 0 [400/938 (43%)]\tLoss: 0.250771\n",
"Train Epoch: 0 [410/938 (44%)]\tLoss: 0.398289\n",
"Train Epoch: 0 [420/938 (45%)]\tLoss: 0.184596\n",
"Train Epoch: 0 [430/938 (46%)]\tLoss: 0.610948\n",
"Train Epoch: 0 [440/938 (47%)]\tLoss: 0.138442\n",
"Train Epoch: 0 [450/938 (48%)]\tLoss: 0.213884\n",
"Train Epoch: 0 [460/938 (49%)]\tLoss: 0.236248\n",
"Train Epoch: 0 [470/938 (50%)]\tLoss: 0.155862\n",
"Train Epoch: 0 [480/938 (51%)]\tLoss: 0.189624\n",
"Train Epoch: 0 [490/938 (52%)]\tLoss: 0.110163\n",
"Train Epoch: 0 [500/938 (53%)]\tLoss: 0.299565\n",
"Train Epoch: 0 [510/938 (54%)]\tLoss: 0.194797\n",
"Train Epoch: 0 [520/938 (55%)]\tLoss: 0.331710\n",
"Train Epoch: 0 [530/938 (57%)]\tLoss: 0.051663\n",
"Train Epoch: 0 [540/938 (58%)]\tLoss: 0.226978\n",
"Train Epoch: 0 [550/938 (59%)]\tLoss: 0.227077\n",
"Train Epoch: 0 [560/938 (60%)]\tLoss: 0.143331\n",
"Train Epoch: 0 [570/938 (61%)]\tLoss: 0.322037\n",
"Train Epoch: 0 [580/938 (62%)]\tLoss: 0.298587\n",
"Train Epoch: 0 [590/938 (63%)]\tLoss: 0.145827\n",
"Train Epoch: 0 [600/938 (64%)]\tLoss: 0.164992\n",
"Train Epoch: 0 [610/938 (65%)]\tLoss: 0.181216\n",
"Train Epoch: 0 [620/938 (66%)]\tLoss: 0.143517\n",
"Train Epoch: 0 [630/938 (67%)]\tLoss: 0.344943\n",
"Train Epoch: 0 [640/938 (68%)]\tLoss: 0.408126\n",
"Train Epoch: 0 [650/938 (69%)]\tLoss: 0.236928\n",
"Train Epoch: 0 [660/938 (70%)]\tLoss: 0.203909\n",
"Train Epoch: 0 [670/938 (71%)]\tLoss: 0.256777\n",
"Train Epoch: 0 [680/938 (72%)]\tLoss: 0.381356\n",
"Train Epoch: 0 [690/938 (74%)]\tLoss: 0.249726\n",
"Train Epoch: 0 [700/938 (75%)]\tLoss: 0.156209\n",
"Train Epoch: 0 [710/938 (76%)]\tLoss: 0.207083\n",
"Train Epoch: 0 [720/938 (77%)]\tLoss: 0.264972\n",
"Train Epoch: 0 [730/938 (78%)]\tLoss: 0.089759\n",
"Train Epoch: 0 [740/938 (79%)]\tLoss: 0.287047\n",
"Train Epoch: 0 [750/938 (80%)]\tLoss: 0.253490\n",
"Train Epoch: 0 [760/938 (81%)]\tLoss: 0.188711\n",
"Train Epoch: 0 [770/938 (82%)]\tLoss: 0.182997\n",
"Train Epoch: 0 [780/938 (83%)]\tLoss: 0.173947\n",
"Train Epoch: 0 [790/938 (84%)]\tLoss: 0.298086\n",
"Train Epoch: 0 [800/938 (85%)]\tLoss: 0.225113\n",
"Train Epoch: 0 [810/938 (86%)]\tLoss: 0.036622\n",
"Train Epoch: 0 [820/938 (87%)]\tLoss: 0.214803\n",
"Train Epoch: 0 [830/938 (88%)]\tLoss: 0.303559\n",
"Train Epoch: 0 [840/938 (90%)]\tLoss: 0.143316\n",
"Train Epoch: 0 [850/938 (91%)]\tLoss: 0.078837\n",
"Train Epoch: 0 [860/938 (92%)]\tLoss: 0.272784\n",
"Train Epoch: 0 [870/938 (93%)]\tLoss: 0.040230\n",
"Train Epoch: 0 [880/938 (94%)]\tLoss: 0.255706\n",
"Train Epoch: 0 [890/938 (95%)]\tLoss: 0.257530\n",
"Train Epoch: 0 [900/938 (96%)]\tLoss: 0.265747\n",
"Train Epoch: 0 [910/938 (97%)]\tLoss: 0.164142\n",
"Train Epoch: 0 [920/938 (98%)]\tLoss: 0.184068\n",
"Train Epoch: 0 [930/938 (99%)]\tLoss: 0.103193\n",
"Test Epoch: 0 [0/157 (0%)]\tAcc: 1.000000\n",
"Test Epoch: 0 [10/157 (6%)]\tAcc: 0.953125\n",
"Test Epoch: 0 [20/157 (13%)]\tAcc: 0.906250\n",
"Test Epoch: 0 [30/157 (19%)]\tAcc: 0.906250\n",
"Test Epoch: 0 [40/157 (25%)]\tAcc: 0.953125\n",
"Test Epoch: 0 [50/157 (32%)]\tAcc: 0.968750\n",
"Test Epoch: 0 [60/157 (38%)]\tAcc: 0.953125\n",
"Test Epoch: 0 [70/157 (45%)]\tAcc: 0.937500\n",
"Test Epoch: 0 [80/157 (51%)]\tAcc: 0.984375\n",
"Test Epoch: 0 [90/157 (57%)]\tAcc: 1.000000\n",
"Test Epoch: 0 [100/157 (64%)]\tAcc: 1.000000\n",
"Test Epoch: 0 [110/157 (70%)]\tAcc: 0.984375\n",
"Test Epoch: 0 [120/157 (76%)]\tAcc: 0.984375\n",
"Test Epoch: 0 [130/157 (83%)]\tAcc: 0.984375\n",
"Test Epoch: 0 [140/157 (89%)]\tAcc: 0.937500\n",
"Test Epoch: 0 [150/157 (96%)]\tAcc: 0.953125\n",
"Total test acc = 0.9617\n",
"Train Epoch: 1 [0/938 (0%)]\tLoss: 0.037812\n",
"Train Epoch: 1 [10/938 (1%)]\tLoss: 0.191765\n",
"Train Epoch: 1 [20/938 (2%)]\tLoss: 0.118993\n",
"Train Epoch: 1 [30/938 (3%)]\tLoss: 0.091133\n",
"Train Epoch: 1 [40/938 (4%)]\tLoss: 0.259714\n",
"Train Epoch: 1 [50/938 (5%)]\tLoss: 0.129380\n",
"Train Epoch: 1 [60/938 (6%)]\tLoss: 0.126883\n",
"Train Epoch: 1 [70/938 (7%)]\tLoss: 0.048264\n",
"Train Epoch: 1 [80/938 (9%)]\tLoss: 0.208334\n",
"Train Epoch: 1 [90/938 (10%)]\tLoss: 0.092803\n",
"Train Epoch: 1 [100/938 (11%)]\tLoss: 0.009891\n",
"Train Epoch: 1 [110/938 (12%)]\tLoss: 0.153850\n",
"Train Epoch: 1 [120/938 (13%)]\tLoss: 0.062122\n",
"Train Epoch: 1 [130/938 (14%)]\tLoss: 0.055835\n",
"Train Epoch: 1 [140/938 (15%)]\tLoss: 0.114874\n",
"Train Epoch: 1 [150/938 (16%)]\tLoss: 0.135059\n",
"Train Epoch: 1 [160/938 (17%)]\tLoss: 0.209164\n",
"Train Epoch: 1 [170/938 (18%)]\tLoss: 0.200797\n",
"Train Epoch: 1 [180/938 (19%)]\tLoss: 0.118919\n",
"Train Epoch: 1 [190/938 (20%)]\tLoss: 0.196783\n",
"Train Epoch: 1 [200/938 (21%)]\tLoss: 0.242684\n",
"Train Epoch: 1 [210/938 (22%)]\tLoss: 0.413652\n",
"Train Epoch: 1 [220/938 (23%)]\tLoss: 0.108963\n",
"Train Epoch: 1 [230/938 (25%)]\tLoss: 0.530779\n",
"Train Epoch: 1 [240/938 (26%)]\tLoss: 0.111462\n",
"Train Epoch: 1 [250/938 (27%)]\tLoss: 0.233948\n",
"Train Epoch: 1 [260/938 (28%)]\tLoss: 0.096871\n",
"Train Epoch: 1 [270/938 (29%)]\tLoss: 0.005476\n",
"Train Epoch: 1 [280/938 (30%)]\tLoss: 0.136239\n",
"Train Epoch: 1 [290/938 (31%)]\tLoss: 0.086216\n",
"Train Epoch: 1 [300/938 (32%)]\tLoss: 0.320376\n",
"Train Epoch: 1 [310/938 (33%)]\tLoss: 0.056435\n",
"Train Epoch: 1 [320/938 (34%)]\tLoss: 0.088900\n",
"Train Epoch: 1 [330/938 (35%)]\tLoss: 0.077269\n",
"Train Epoch: 1 [340/938 (36%)]\tLoss: 0.222857\n",
"Train Epoch: 1 [350/938 (37%)]\tLoss: 0.090470\n",
"Train Epoch: 1 [360/938 (38%)]\tLoss: 0.181398\n",
"Train Epoch: 1 [370/938 (39%)]\tLoss: 0.258413\n",
"Train Epoch: 1 [380/938 (41%)]\tLoss: 0.044568\n",
"Train Epoch: 1 [390/938 (42%)]\tLoss: 0.208491\n",
"Train Epoch: 1 [400/938 (43%)]\tLoss: 0.024349\n",
"Train Epoch: 1 [410/938 (44%)]\tLoss: 0.097962\n",
"Train Epoch: 1 [420/938 (45%)]\tLoss: 0.035660\n",
"Train Epoch: 1 [430/938 (46%)]\tLoss: 0.319142\n",
"Train Epoch: 1 [440/938 (47%)]\tLoss: 0.112544\n",
"Train Epoch: 1 [450/938 (48%)]\tLoss: 0.284870\n",
"Train Epoch: 1 [460/938 (49%)]\tLoss: 0.283764\n",
"Train Epoch: 1 [470/938 (50%)]\tLoss: 0.250175\n",
"Train Epoch: 1 [480/938 (51%)]\tLoss: 0.157756\n",
"Train Epoch: 1 [490/938 (52%)]\tLoss: 0.071510\n",
"Train Epoch: 1 [500/938 (53%)]\tLoss: 0.296052\n",
"Train Epoch: 1 [510/938 (54%)]\tLoss: 0.063261\n",
"Train Epoch: 1 [520/938 (55%)]\tLoss: 0.087280\n",
"Train Epoch: 1 [530/938 (57%)]\tLoss: 0.102561\n",
"Train Epoch: 1 [540/938 (58%)]\tLoss: 0.083404\n",
"Train Epoch: 1 [550/938 (59%)]\tLoss: 0.052055\n",
"Train Epoch: 1 [560/938 (60%)]\tLoss: 0.044225\n",
"Train Epoch: 1 [570/938 (61%)]\tLoss: 0.080570\n",
"Train Epoch: 1 [580/938 (62%)]\tLoss: 0.167150\n",
"Train Epoch: 1 [590/938 (63%)]\tLoss: 0.137994\n",
"Train Epoch: 1 [600/938 (64%)]\tLoss: 0.314702\n",
"Train Epoch: 1 [610/938 (65%)]\tLoss: 0.386576\n",
"Train Epoch: 1 [620/938 (66%)]\tLoss: 0.154651\n",
"Train Epoch: 1 [630/938 (67%)]\tLoss: 0.211713\n",
"Train Epoch: 1 [640/938 (68%)]\tLoss: 0.201374\n",
"Train Epoch: 1 [650/938 (69%)]\tLoss: 0.070127\n",
"Train Epoch: 1 [660/938 (70%)]\tLoss: 0.095240\n",
"Train Epoch: 1 [670/938 (71%)]\tLoss: 0.051987\n",
"Train Epoch: 1 [680/938 (72%)]\tLoss: 0.142315\n",
"Train Epoch: 1 [690/938 (74%)]\tLoss: 0.090122\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train Epoch: 1 [700/938 (75%)]\tLoss: 0.104628\n",
"Train Epoch: 1 [710/938 (76%)]\tLoss: 0.116404\n",
"Train Epoch: 1 [720/938 (77%)]\tLoss: 0.059856\n",
"Train Epoch: 1 [730/938 (78%)]\tLoss: 0.346445\n",
"Train Epoch: 1 [740/938 (79%)]\tLoss: 0.057055\n",
"Train Epoch: 1 [750/938 (80%)]\tLoss: 0.070142\n",
"Train Epoch: 1 [760/938 (81%)]\tLoss: 0.148698\n",
"Train Epoch: 1 [770/938 (82%)]\tLoss: 0.042575\n",
"Train Epoch: 1 [780/938 (83%)]\tLoss: 0.131458\n",
"Train Epoch: 1 [790/938 (84%)]\tLoss: 0.023700\n",
"Train Epoch: 1 [800/938 (85%)]\tLoss: 0.268284\n",
"Train Epoch: 1 [810/938 (86%)]\tLoss: 0.011392\n",
"Train Epoch: 1 [820/938 (87%)]\tLoss: 0.170202\n",
"Train Epoch: 1 [830/938 (88%)]\tLoss: 0.137273\n",
"Train Epoch: 1 [840/938 (90%)]\tLoss: 0.100842\n",
"Train Epoch: 1 [850/938 (91%)]\tLoss: 0.057559\n",
"Train Epoch: 1 [860/938 (92%)]\tLoss: 0.163903\n",
"Train Epoch: 1 [870/938 (93%)]\tLoss: 0.270894\n",
"Train Epoch: 1 [880/938 (94%)]\tLoss: 0.110774\n",
"Train Epoch: 1 [890/938 (95%)]\tLoss: 0.142618\n",
"Train Epoch: 1 [900/938 (96%)]\tLoss: 0.187053\n",
"Train Epoch: 1 [910/938 (97%)]\tLoss: 0.151790\n",
"Train Epoch: 1 [920/938 (98%)]\tLoss: 0.018912\n",
"Train Epoch: 1 [930/938 (99%)]\tLoss: 0.183618\n",
"Test Epoch: 1 [0/157 (0%)]\tAcc: 1.000000\n",
"Test Epoch: 1 [10/157 (6%)]\tAcc: 0.921875\n",
"Test Epoch: 1 [20/157 (13%)]\tAcc: 0.921875\n",
"Test Epoch: 1 [30/157 (19%)]\tAcc: 0.953125\n",
"Test Epoch: 1 [40/157 (25%)]\tAcc: 0.921875\n",
"Test Epoch: 1 [50/157 (32%)]\tAcc: 0.968750\n",
"Test Epoch: 1 [60/157 (38%)]\tAcc: 0.937500\n",
"Test Epoch: 1 [70/157 (45%)]\tAcc: 0.953125\n",
"Test Epoch: 1 [80/157 (51%)]\tAcc: 1.000000\n",
"Test Epoch: 1 [90/157 (57%)]\tAcc: 1.000000\n",
"Test Epoch: 1 [100/157 (64%)]\tAcc: 1.000000\n",
"Test Epoch: 1 [110/157 (70%)]\tAcc: 1.000000\n",
"Test Epoch: 1 [120/157 (76%)]\tAcc: 1.000000\n",
"Test Epoch: 1 [130/157 (83%)]\tAcc: 0.984375\n",
"Test Epoch: 1 [140/157 (89%)]\tAcc: 0.953125\n",
"Test Epoch: 1 [150/157 (96%)]\tAcc: 0.953125\n",
"Total test acc = 0.9671\n",
"Train Epoch: 2 [0/938 (0%)]\tLoss: 0.085735\n",
"Train Epoch: 2 [10/938 (1%)]\tLoss: 0.130465\n",
"Train Epoch: 2 [20/938 (2%)]\tLoss: 0.027238\n",
"Train Epoch: 2 [30/938 (3%)]\tLoss: 0.116073\n",
"Train Epoch: 2 [40/938 (4%)]\tLoss: 0.187502\n",
"Train Epoch: 2 [50/938 (5%)]\tLoss: 0.065617\n",
"Train Epoch: 2 [60/938 (6%)]\tLoss: 0.093609\n",
"Train Epoch: 2 [70/938 (7%)]\tLoss: 0.104410\n",
"Train Epoch: 2 [80/938 (9%)]\tLoss: 0.095942\n",
"Train Epoch: 2 [90/938 (10%)]\tLoss: 0.282522\n",
"Train Epoch: 2 [100/938 (11%)]\tLoss: 0.091899\n",
"Train Epoch: 2 [110/938 (12%)]\tLoss: 0.142710\n",
"Train Epoch: 2 [120/938 (13%)]\tLoss: 0.051769\n",
"Train Epoch: 2 [130/938 (14%)]\tLoss: 0.071588\n",
"Train Epoch: 2 [140/938 (15%)]\tLoss: 0.021449\n",
"Train Epoch: 2 [150/938 (16%)]\tLoss: 0.006790\n",
"Train Epoch: 2 [160/938 (17%)]\tLoss: 0.055043\n",
"Train Epoch: 2 [170/938 (18%)]\tLoss: 0.359066\n",
"Train Epoch: 2 [180/938 (19%)]\tLoss: 0.086634\n",
"Train Epoch: 2 [190/938 (20%)]\tLoss: 0.030564\n",
"Train Epoch: 2 [200/938 (21%)]\tLoss: 0.155594\n",
"Train Epoch: 2 [210/938 (22%)]\tLoss: 0.446819\n",
"Train Epoch: 2 [220/938 (23%)]\tLoss: 0.040492\n",
"Train Epoch: 2 [230/938 (25%)]\tLoss: 0.122431\n",
"Train Epoch: 2 [240/938 (26%)]\tLoss: 0.127685\n",
"Train Epoch: 2 [250/938 (27%)]\tLoss: 0.066049\n",
"Train Epoch: 2 [260/938 (28%)]\tLoss: 0.092976\n",
"Train Epoch: 2 [270/938 (29%)]\tLoss: 0.057877\n",
"Train Epoch: 2 [280/938 (30%)]\tLoss: 0.124391\n",
"Train Epoch: 2 [290/938 (31%)]\tLoss: 0.054755\n",
"Train Epoch: 2 [300/938 (32%)]\tLoss: 0.097547\n",
"Train Epoch: 2 [310/938 (33%)]\tLoss: 0.151855\n",
"Train Epoch: 2 [320/938 (34%)]\tLoss: 0.093891\n",
"Train Epoch: 2 [330/938 (35%)]\tLoss: 0.104686\n",
"Train Epoch: 2 [340/938 (36%)]\tLoss: 0.256665\n",
"Train Epoch: 2 [350/938 (37%)]\tLoss: 0.290896\n",
"Train Epoch: 2 [360/938 (38%)]\tLoss: 0.207254\n",
"Train Epoch: 2 [370/938 (39%)]\tLoss: 0.167920\n",
"Train Epoch: 2 [380/938 (41%)]\tLoss: 0.025714\n",
"Train Epoch: 2 [390/938 (42%)]\tLoss: 0.049858\n",
"Train Epoch: 2 [400/938 (43%)]\tLoss: 0.038995\n",
"Train Epoch: 2 [410/938 (44%)]\tLoss: 0.033156\n",
"Train Epoch: 2 [420/938 (45%)]\tLoss: 0.063976\n",
"Train Epoch: 2 [430/938 (46%)]\tLoss: 0.295632\n",
"Train Epoch: 2 [440/938 (47%)]\tLoss: 0.028926\n",
"Train Epoch: 2 [450/938 (48%)]\tLoss: 0.026876\n",
"Train Epoch: 2 [460/938 (49%)]\tLoss: 0.035151\n",
"Train Epoch: 2 [470/938 (50%)]\tLoss: 0.043150\n",
"Train Epoch: 2 [480/938 (51%)]\tLoss: 0.099015\n",
"Train Epoch: 2 [490/938 (52%)]\tLoss: 0.156179\n",
"Train Epoch: 2 [500/938 (53%)]\tLoss: 0.157619\n",
"Train Epoch: 2 [510/938 (54%)]\tLoss: 0.140651\n",
"Train Epoch: 2 [520/938 (55%)]\tLoss: 0.176397\n",
"Train Epoch: 2 [530/938 (57%)]\tLoss: 0.099402\n",
"Train Epoch: 2 [540/938 (58%)]\tLoss: 0.035010\n",
"Train Epoch: 2 [550/938 (59%)]\tLoss: 0.287024\n",
"Train Epoch: 2 [560/938 (60%)]\tLoss: 0.135259\n",
"Train Epoch: 2 [570/938 (61%)]\tLoss: 0.020481\n",
"Train Epoch: 2 [580/938 (62%)]\tLoss: 0.071563\n",
"Train Epoch: 2 [590/938 (63%)]\tLoss: 0.041979\n",
"Train Epoch: 2 [600/938 (64%)]\tLoss: 0.053910\n",
"Train Epoch: 2 [610/938 (65%)]\tLoss: 0.105802\n",
"Train Epoch: 2 [620/938 (66%)]\tLoss: 0.077079\n",
"Train Epoch: 2 [630/938 (67%)]\tLoss: 0.072032\n",
"Train Epoch: 2 [640/938 (68%)]\tLoss: 0.164553\n",
"Train Epoch: 2 [650/938 (69%)]\tLoss: 0.020185\n",
"Train Epoch: 2 [660/938 (70%)]\tLoss: 0.164906\n",
"Train Epoch: 2 [670/938 (71%)]\tLoss: 0.160516\n",
"Train Epoch: 2 [680/938 (72%)]\tLoss: 0.185206\n",
"Train Epoch: 2 [690/938 (74%)]\tLoss: 0.252789\n",
"Train Epoch: 2 [700/938 (75%)]\tLoss: 0.058496\n",
"Train Epoch: 2 [710/938 (76%)]\tLoss: 0.082438\n",
"Train Epoch: 2 [720/938 (77%)]\tLoss: 0.200533\n",
"Train Epoch: 2 [730/938 (78%)]\tLoss: 0.011665\n",
"Train Epoch: 2 [740/938 (79%)]\tLoss: 0.093606\n",
"Train Epoch: 2 [750/938 (80%)]\tLoss: 0.093081\n",
"Train Epoch: 2 [760/938 (81%)]\tLoss: 0.036109\n",
"Train Epoch: 2 [770/938 (82%)]\tLoss: 0.038048\n",
"Train Epoch: 2 [780/938 (83%)]\tLoss: 0.078773\n",
"Train Epoch: 2 [790/938 (84%)]\tLoss: 0.054961\n",
"Train Epoch: 2 [800/938 (85%)]\tLoss: 0.035109\n",
"Train Epoch: 2 [810/938 (86%)]\tLoss: 0.150152\n",
"Train Epoch: 2 [820/938 (87%)]\tLoss: 0.025836\n",
"Train Epoch: 2 [830/938 (88%)]\tLoss: 0.077902\n",
"Train Epoch: 2 [840/938 (90%)]\tLoss: 0.154891\n",
"Train Epoch: 2 [850/938 (91%)]\tLoss: 0.181755\n",
"Train Epoch: 2 [860/938 (92%)]\tLoss: 0.068052\n",
"Train Epoch: 2 [870/938 (93%)]\tLoss: 0.223926\n",
"Train Epoch: 2 [880/938 (94%)]\tLoss: 0.024089\n",
"Train Epoch: 2 [890/938 (95%)]\tLoss: 0.268356\n",
"Train Epoch: 2 [900/938 (96%)]\tLoss: 0.029194\n",
"Train Epoch: 2 [910/938 (97%)]\tLoss: 0.030891\n",
"Train Epoch: 2 [920/938 (98%)]\tLoss: 0.037380\n",
"Train Epoch: 2 [930/938 (99%)]\tLoss: 0.146263\n",
"Test Epoch: 2 [0/157 (0%)]\tAcc: 1.000000\n",
"Test Epoch: 2 [10/157 (6%)]\tAcc: 0.953125\n",
"Test Epoch: 2 [20/157 (13%)]\tAcc: 0.968750\n",
"Test Epoch: 2 [30/157 (19%)]\tAcc: 0.953125\n",
"Test Epoch: 2 [40/157 (25%)]\tAcc: 0.984375\n",
"Test Epoch: 2 [50/157 (32%)]\tAcc: 0.953125\n",
"Test Epoch: 2 [60/157 (38%)]\tAcc: 0.984375\n",
"Test Epoch: 2 [70/157 (45%)]\tAcc: 0.937500\n",
"Test Epoch: 2 [80/157 (51%)]\tAcc: 1.000000\n",
"Test Epoch: 2 [90/157 (57%)]\tAcc: 0.984375\n",
"Test Epoch: 2 [100/157 (64%)]\tAcc: 0.953125\n",
"Test Epoch: 2 [110/157 (70%)]\tAcc: 0.984375\n",
"Test Epoch: 2 [120/157 (76%)]\tAcc: 1.000000\n",
"Test Epoch: 2 [130/157 (83%)]\tAcc: 0.984375\n",
"Test Epoch: 2 [140/157 (89%)]\tAcc: 0.968750\n",
"Test Epoch: 2 [150/157 (96%)]\tAcc: 0.968750\n",
"Total test acc = 0.9747\n",
"Train Epoch: 3 [0/938 (0%)]\tLoss: 0.029276\n",
"Train Epoch: 3 [10/938 (1%)]\tLoss: 0.128718\n",
"Train Epoch: 3 [20/938 (2%)]\tLoss: 0.114136\n",
"Train Epoch: 3 [30/938 (3%)]\tLoss: 0.175863\n",
"Train Epoch: 3 [40/938 (4%)]\tLoss: 0.119407\n",
"Train Epoch: 3 [50/938 (5%)]\tLoss: 0.023124\n",
"Train Epoch: 3 [60/938 (6%)]\tLoss: 0.069409\n",
"Train Epoch: 3 [70/938 (7%)]\tLoss: 0.175537\n",
"Train Epoch: 3 [80/938 (9%)]\tLoss: 0.044068\n",
"Train Epoch: 3 [90/938 (10%)]\tLoss: 0.039270\n",
"Train Epoch: 3 [100/938 (11%)]\tLoss: 0.041589\n",
"Train Epoch: 3 [110/938 (12%)]\tLoss: 0.311189\n",
"Train Epoch: 3 [120/938 (13%)]\tLoss: 0.048966\n",
"Train Epoch: 3 [130/938 (14%)]\tLoss: 0.021020\n",
"Train Epoch: 3 [140/938 (15%)]\tLoss: 0.123300\n",
"Train Epoch: 3 [150/938 (16%)]\tLoss: 0.015196\n",
"Train Epoch: 3 [160/938 (17%)]\tLoss: 0.086319\n",
"Train Epoch: 3 [170/938 (18%)]\tLoss: 0.025925\n",
"Train Epoch: 3 [180/938 (19%)]\tLoss: 0.026906\n",
"Train Epoch: 3 [190/938 (20%)]\tLoss: 0.035199\n",
"Train Epoch: 3 [200/938 (21%)]\tLoss: 0.156299\n",
"Train Epoch: 3 [210/938 (22%)]\tLoss: 0.099553\n",
"Train Epoch: 3 [220/938 (23%)]\tLoss: 0.092651\n",
"Train Epoch: 3 [230/938 (25%)]\tLoss: 0.096245\n",
"Train Epoch: 3 [240/938 (26%)]\tLoss: 0.075171\n",
"Train Epoch: 3 [250/938 (27%)]\tLoss: 0.042659\n",
"Train Epoch: 3 [260/938 (28%)]\tLoss: 0.346644\n",
"Train Epoch: 3 [270/938 (29%)]\tLoss: 0.025780\n",
"Train Epoch: 3 [280/938 (30%)]\tLoss: 0.065534\n",
"Train Epoch: 3 [290/938 (31%)]\tLoss: 0.055503\n",
"Train Epoch: 3 [300/938 (32%)]\tLoss: 0.265311\n",
"Train Epoch: 3 [310/938 (33%)]\tLoss: 0.254793\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train Epoch: 3 [320/938 (34%)]\tLoss: 0.045063\n",
"Train Epoch: 3 [330/938 (35%)]\tLoss: 0.058299\n",
"Train Epoch: 3 [340/938 (36%)]\tLoss: 0.205273\n",
"Train Epoch: 3 [350/938 (37%)]\tLoss: 0.016032\n",
"Train Epoch: 3 [360/938 (38%)]\tLoss: 0.037333\n",
"Train Epoch: 3 [370/938 (39%)]\tLoss: 0.089765\n",
"Train Epoch: 3 [380/938 (41%)]\tLoss: 0.062916\n",
"Train Epoch: 3 [390/938 (42%)]\tLoss: 0.140819\n",
"Train Epoch: 3 [400/938 (43%)]\tLoss: 0.084308\n",
"Train Epoch: 3 [410/938 (44%)]\tLoss: 0.223716\n",
"Train Epoch: 3 [420/938 (45%)]\tLoss: 0.018608\n",
"Train Epoch: 3 [430/938 (46%)]\tLoss: 0.040922\n",
"Train Epoch: 3 [440/938 (47%)]\tLoss: 0.091088\n",
"Train Epoch: 3 [450/938 (48%)]\tLoss: 0.100574\n",
"Train Epoch: 3 [460/938 (49%)]\tLoss: 0.031961\n",
"Train Epoch: 3 [470/938 (50%)]\tLoss: 0.044045\n",
"Train Epoch: 3 [480/938 (51%)]\tLoss: 0.019030\n",
"Train Epoch: 3 [490/938 (52%)]\tLoss: 0.056948\n",
"Train Epoch: 3 [500/938 (53%)]\tLoss: 0.066515\n",
"Train Epoch: 3 [510/938 (54%)]\tLoss: 0.127204\n",
"Train Epoch: 3 [520/938 (55%)]\tLoss: 0.068011\n",
"Train Epoch: 3 [530/938 (57%)]\tLoss: 0.170617\n",
"Train Epoch: 3 [540/938 (58%)]\tLoss: 0.085911\n",
"Train Epoch: 3 [550/938 (59%)]\tLoss: 0.040213\n",
"Train Epoch: 3 [560/938 (60%)]\tLoss: 0.071809\n",
"Train Epoch: 3 [570/938 (61%)]\tLoss: 0.100353\n",
"Train Epoch: 3 [580/938 (62%)]\tLoss: 0.088990\n",
"Train Epoch: 3 [590/938 (63%)]\tLoss: 0.017604\n",
"Train Epoch: 3 [600/938 (64%)]\tLoss: 0.069384\n",
"Train Epoch: 3 [610/938 (65%)]\tLoss: 0.050279\n",
"Train Epoch: 3 [620/938 (66%)]\tLoss: 0.099423\n",
"Train Epoch: 3 [630/938 (67%)]\tLoss: 0.087007\n",
"Train Epoch: 3 [640/938 (68%)]\tLoss: 0.123801\n",
"Train Epoch: 3 [650/938 (69%)]\tLoss: 0.108569\n",
"Train Epoch: 3 [660/938 (70%)]\tLoss: 0.015272\n",
"Train Epoch: 3 [670/938 (71%)]\tLoss: 0.058114\n",
"Train Epoch: 3 [680/938 (72%)]\tLoss: 0.051281\n",
"Train Epoch: 3 [690/938 (74%)]\tLoss: 0.056625\n",
"Train Epoch: 3 [700/938 (75%)]\tLoss: 0.024394\n",
"Train Epoch: 3 [710/938 (76%)]\tLoss: 0.229214\n",
"Train Epoch: 3 [720/938 (77%)]\tLoss: 0.157146\n",
"Train Epoch: 3 [730/938 (78%)]\tLoss: 0.057152\n",
"Train Epoch: 3 [740/938 (79%)]\tLoss: 0.170796\n",
"Train Epoch: 3 [750/938 (80%)]\tLoss: 0.005900\n",
"Train Epoch: 3 [760/938 (81%)]\tLoss: 0.206662\n",
"Train Epoch: 3 [770/938 (82%)]\tLoss: 0.020888\n",
"Train Epoch: 3 [780/938 (83%)]\tLoss: 0.102198\n",
"Train Epoch: 3 [790/938 (84%)]\tLoss: 0.156167\n",
"Train Epoch: 3 [800/938 (85%)]\tLoss: 0.118774\n",
"Train Epoch: 3 [810/938 (86%)]\tLoss: 0.094498\n",
"Train Epoch: 3 [820/938 (87%)]\tLoss: 0.084717\n",
"Train Epoch: 3 [830/938 (88%)]\tLoss: 0.113925\n",
"Train Epoch: 3 [840/938 (90%)]\tLoss: 0.054973\n",
"Train Epoch: 3 [850/938 (91%)]\tLoss: 0.109440\n",
"Train Epoch: 3 [860/938 (92%)]\tLoss: 0.052542\n",
"Train Epoch: 3 [870/938 (93%)]\tLoss: 0.081913\n",
"Train Epoch: 3 [880/938 (94%)]\tLoss: 0.034008\n",
"Train Epoch: 3 [890/938 (95%)]\tLoss: 0.055243\n",
"Train Epoch: 3 [900/938 (96%)]\tLoss: 0.026066\n",
"Train Epoch: 3 [910/938 (97%)]\tLoss: 0.081119\n",
"Train Epoch: 3 [920/938 (98%)]\tLoss: 0.155930\n",
"Train Epoch: 3 [930/938 (99%)]\tLoss: 0.056044\n",
"Test Epoch: 3 [0/157 (0%)]\tAcc: 1.000000\n",
"Test Epoch: 3 [10/157 (6%)]\tAcc: 0.953125\n",
"Test Epoch: 3 [20/157 (13%)]\tAcc: 0.968750\n",
"Test Epoch: 3 [30/157 (19%)]\tAcc: 0.968750\n",
"Test Epoch: 3 [40/157 (25%)]\tAcc: 0.937500\n",
"Test Epoch: 3 [50/157 (32%)]\tAcc: 0.953125\n",
"Test Epoch: 3 [60/157 (38%)]\tAcc: 0.953125\n",
"Test Epoch: 3 [70/157 (45%)]\tAcc: 0.937500\n",
"Test Epoch: 3 [80/157 (51%)]\tAcc: 0.968750\n",
"Test Epoch: 3 [90/157 (57%)]\tAcc: 1.000000\n",
"Test Epoch: 3 [100/157 (64%)]\tAcc: 1.000000\n",
"Test Epoch: 3 [110/157 (70%)]\tAcc: 0.984375\n",
"Test Epoch: 3 [120/157 (76%)]\tAcc: 0.984375\n",
"Test Epoch: 3 [130/157 (83%)]\tAcc: 1.000000\n",
"Test Epoch: 3 [140/157 (89%)]\tAcc: 0.937500\n",
"Test Epoch: 3 [150/157 (96%)]\tAcc: 0.921875\n",
"Total test acc = 0.9661\n",
"Train Epoch: 4 [0/938 (0%)]\tLoss: 0.033561\n",
"Train Epoch: 4 [10/938 (1%)]\tLoss: 0.023106\n",
"Train Epoch: 4 [20/938 (2%)]\tLoss: 0.232753\n",
"Train Epoch: 4 [30/938 (3%)]\tLoss: 0.014634\n",
"Train Epoch: 4 [40/938 (4%)]\tLoss: 0.068259\n",
"Train Epoch: 4 [50/938 (5%)]\tLoss: 0.036399\n",
"Train Epoch: 4 [60/938 (6%)]\tLoss: 0.011522\n",
"Train Epoch: 4 [70/938 (7%)]\tLoss: 0.017262\n",
"Train Epoch: 4 [80/938 (9%)]\tLoss: 0.017650\n",
"Train Epoch: 4 [90/938 (10%)]\tLoss: 0.252838\n",
"Train Epoch: 4 [100/938 (11%)]\tLoss: 0.067233\n",
"Train Epoch: 4 [110/938 (12%)]\tLoss: 0.010737\n",
"Train Epoch: 4 [120/938 (13%)]\tLoss: 0.027090\n",
"Train Epoch: 4 [130/938 (14%)]\tLoss: 0.003570\n",
"Train Epoch: 4 [140/938 (15%)]\tLoss: 0.001152\n",
"Train Epoch: 4 [150/938 (16%)]\tLoss: 0.004174\n",
"Train Epoch: 4 [160/938 (17%)]\tLoss: 0.044346\n",
"Train Epoch: 4 [170/938 (18%)]\tLoss: 0.017931\n",
"Train Epoch: 4 [180/938 (19%)]\tLoss: 0.069598\n",
"Train Epoch: 4 [190/938 (20%)]\tLoss: 0.071746\n",
"Train Epoch: 4 [200/938 (21%)]\tLoss: 0.076535\n",
"Train Epoch: 4 [210/938 (22%)]\tLoss: 0.100941\n",
"Train Epoch: 4 [220/938 (23%)]\tLoss: 0.021138\n",
"Train Epoch: 4 [230/938 (25%)]\tLoss: 0.088135\n",
"Train Epoch: 4 [240/938 (26%)]\tLoss: 0.117837\n",
"Train Epoch: 4 [250/938 (27%)]\tLoss: 0.032977\n",
"Train Epoch: 4 [260/938 (28%)]\tLoss: 0.167049\n",
"Train Epoch: 4 [270/938 (29%)]\tLoss: 0.013822\n",
"Train Epoch: 4 [280/938 (30%)]\tLoss: 0.045409\n",
"Train Epoch: 4 [290/938 (31%)]\tLoss: 0.119697\n",
"Train Epoch: 4 [300/938 (32%)]\tLoss: 0.067969\n",
"Train Epoch: 4 [310/938 (33%)]\tLoss: 0.131902\n",
"Train Epoch: 4 [320/938 (34%)]\tLoss: 0.066060\n",
"Train Epoch: 4 [330/938 (35%)]\tLoss: 0.066008\n",
"Train Epoch: 4 [340/938 (36%)]\tLoss: 0.184514\n",
"Train Epoch: 4 [350/938 (37%)]\tLoss: 0.063161\n",
"Train Epoch: 4 [360/938 (38%)]\tLoss: 0.093329\n",
"Train Epoch: 4 [370/938 (39%)]\tLoss: 0.066812\n",
"Train Epoch: 4 [380/938 (41%)]\tLoss: 0.059804\n",
"Train Epoch: 4 [390/938 (42%)]\tLoss: 0.030450\n",
"Train Epoch: 4 [400/938 (43%)]\tLoss: 0.008455\n",
"Train Epoch: 4 [410/938 (44%)]\tLoss: 0.062510\n",
"Train Epoch: 4 [420/938 (45%)]\tLoss: 0.144027\n",
"Train Epoch: 4 [430/938 (46%)]\tLoss: 0.031752\n",
"Train Epoch: 4 [440/938 (47%)]\tLoss: 0.098326\n",
"Train Epoch: 4 [450/938 (48%)]\tLoss: 0.073496\n",
"Train Epoch: 4 [460/938 (49%)]\tLoss: 0.059776\n",
"Train Epoch: 4 [470/938 (50%)]\tLoss: 0.029075\n",
"Train Epoch: 4 [480/938 (51%)]\tLoss: 0.065083\n",
"Train Epoch: 4 [490/938 (52%)]\tLoss: 0.072848\n",
"Train Epoch: 4 [500/938 (53%)]\tLoss: 0.209735\n",
"Train Epoch: 4 [510/938 (54%)]\tLoss: 0.059197\n",
"Train Epoch: 4 [520/938 (55%)]\tLoss: 0.078309\n",
"Train Epoch: 4 [530/938 (57%)]\tLoss: 0.046048\n",
"Train Epoch: 4 [540/938 (58%)]\tLoss: 0.015106\n",
"Train Epoch: 4 [550/938 (59%)]\tLoss: 0.119855\n",
"Train Epoch: 4 [560/938 (60%)]\tLoss: 0.209653\n",
"Train Epoch: 4 [570/938 (61%)]\tLoss: 0.016287\n",
"Train Epoch: 4 [580/938 (62%)]\tLoss: 0.169566\n",
"Train Epoch: 4 [590/938 (63%)]\tLoss: 0.022472\n",
"Train Epoch: 4 [600/938 (64%)]\tLoss: 0.065485\n",
"Train Epoch: 4 [610/938 (65%)]\tLoss: 0.036357\n",
"Train Epoch: 4 [620/938 (66%)]\tLoss: 0.008508\n",
"Train Epoch: 4 [630/938 (67%)]\tLoss: 0.058652\n",
"Train Epoch: 4 [640/938 (68%)]\tLoss: 0.042606\n",
"Train Epoch: 4 [650/938 (69%)]\tLoss: 0.012600\n",
"Train Epoch: 4 [660/938 (70%)]\tLoss: 0.051655\n",
"Train Epoch: 4 [670/938 (71%)]\tLoss: 0.016207\n",
"Train Epoch: 4 [680/938 (72%)]\tLoss: 0.023592\n",
"Train Epoch: 4 [690/938 (74%)]\tLoss: 0.046827\n",
"Train Epoch: 4 [700/938 (75%)]\tLoss: 0.169964\n",
"Train Epoch: 4 [710/938 (76%)]\tLoss: 0.107680\n",
"Train Epoch: 4 [720/938 (77%)]\tLoss: 0.054638\n",
"Train Epoch: 4 [730/938 (78%)]\tLoss: 0.048630\n",
"Train Epoch: 4 [740/938 (79%)]\tLoss: 0.003170\n",
"Train Epoch: 4 [750/938 (80%)]\tLoss: 0.042371\n",
"Train Epoch: 4 [760/938 (81%)]\tLoss: 0.108998\n",
"Train Epoch: 4 [770/938 (82%)]\tLoss: 0.029650\n",
"Train Epoch: 4 [780/938 (83%)]\tLoss: 0.147029\n",
"Train Epoch: 4 [790/938 (84%)]\tLoss: 0.031367\n",
"Train Epoch: 4 [800/938 (85%)]\tLoss: 0.016677\n",
"Train Epoch: 4 [810/938 (86%)]\tLoss: 0.194396\n",
"Train Epoch: 4 [820/938 (87%)]\tLoss: 0.060913\n",
"Train Epoch: 4 [830/938 (88%)]\tLoss: 0.061074\n",
"Train Epoch: 4 [840/938 (90%)]\tLoss: 0.038237\n",
"Train Epoch: 4 [850/938 (91%)]\tLoss: 0.177125\n",
"Train Epoch: 4 [860/938 (92%)]\tLoss: 0.093419\n",
"Train Epoch: 4 [870/938 (93%)]\tLoss: 0.003814\n",
"Train Epoch: 4 [880/938 (94%)]\tLoss: 0.016418\n",
"Train Epoch: 4 [890/938 (95%)]\tLoss: 0.053389\n",
"Train Epoch: 4 [900/938 (96%)]\tLoss: 0.053131\n",
"Train Epoch: 4 [910/938 (97%)]\tLoss: 0.020287\n",
"Train Epoch: 4 [920/938 (98%)]\tLoss: 0.010422\n",
"Train Epoch: 4 [930/938 (99%)]\tLoss: 0.084738\n",
"Test Epoch: 4 [0/157 (0%)]\tAcc: 1.000000\n",
"Test Epoch: 4 [10/157 (6%)]\tAcc: 0.937500\n",
"Test Epoch: 4 [20/157 (13%)]\tAcc: 0.953125\n",
"Test Epoch: 4 [30/157 (19%)]\tAcc: 0.968750\n",
"Test Epoch: 4 [40/157 (25%)]\tAcc: 0.953125\n",
"Test Epoch: 4 [50/157 (32%)]\tAcc: 0.953125\n",
"Test Epoch: 4 [60/157 (38%)]\tAcc: 0.968750\n",
"Test Epoch: 4 [70/157 (45%)]\tAcc: 0.937500\n",
"Test Epoch: 4 [80/157 (51%)]\tAcc: 1.000000\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test Epoch: 4 [90/157 (57%)]\tAcc: 1.000000\n",
"Test Epoch: 4 [100/157 (64%)]\tAcc: 1.000000\n",
"Test Epoch: 4 [110/157 (70%)]\tAcc: 1.000000\n",
"Test Epoch: 4 [120/157 (76%)]\tAcc: 1.000000\n",
"Test Epoch: 4 [130/157 (83%)]\tAcc: 0.984375\n",
"Test Epoch: 4 [140/157 (89%)]\tAcc: 0.953125\n",
"Test Epoch: 4 [150/157 (96%)]\tAcc: 0.968750\n",
"Total test acc = 0.98\n"
]
}
],
"source": [
"import numpy as np\n",
"\n",
"# 训练函数\n",
"def train(model, train_loader, loss_function, optimizer, epoch):\n",
" model.train() # 开启训练模式\n",
" train_losses = list() # 初始化 Loss 容器,用于记录每一批次的 Loss\n",
" for batch_idx, (inputs, targets) in enumerate(train_loader): # 通过训练集加载器,按批次迭代数据\n",
" outputs = model(inputs) # 通过模型预测手写数字。outputs 中每个数据输出有 10 个分量,对应十个数字的相似度\n",
" loss = loss_function(outputs, targets) # 计算损失函数\n",
" optimizer.step(loss) # 根据损失函数,对模型参数进行优化、更新\n",
" train_losses.append(loss.item()) # 记录该批次的 Loss\n",
" \n",
" if batch_idx % 10 == 0: # 每十个批次,打印一次训练集上的 Loss \n",
" print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n",
" epoch, batch_idx, len(train_loader),\n",
" 100. * batch_idx / len(train_loader), loss.item()))\n",
" return train_losses # 返回本纪元的 Loss\n",
"\n",
"\n",
"# 测试函数\n",
"def test(model, val_loader, loss_function, epoch):\n",
" model.eval() # 开启训练模式\n",
" total_correct = 0 # 本纪元预测正确总次数 \n",
" total_num = 0 # 本纪元数据总数\n",
" for batch_idx, (inputs, targets) in enumerate(val_loader): # 通过测试集加载器,按批次迭代数据\n",
" outputs = model(inputs) # 通过模型预测手写数字。outputs 中每个数据输出有 10 个分量,对应十个数字的相似度\n",
" pred = np.argmax(outputs.numpy(), axis=1) # 根据 10 个分量,选择最大相似度的为预测的数字值\n",
" correct = np.sum(targets.numpy()==pred) # 计算本批次中,正确预测的次数,即数据标签等于预测值的数目\n",
" batch_size = inputs.shape[0] # 计算本批次中,数据的总数目\n",
" acc = correct / batch_size # 计算本批次的正确率\n",
" \n",
" total_correct += correct # 将本批次的正确预测次数记录到总数中\n",
" total_num += batch_size # 将本批次的数据数目记录到总数中\n",
" \n",
" if batch_idx % 10 == 0: # 每十个批次,打印一次测试集上的准确率\n",
" print('Test Epoch: {} [{}/{} ({:.0f}%)]\\tAcc: {:.6f}'.format(epoch, \\\n",
" batch_idx, len(val_loader),100. * float(batch_idx) / len(val_loader), acc))\n",
" test_acc = total_correct / total_num # 计算本纪元的正确率\n",
" print ('Total test acc =', test_acc) \n",
" return test_acc\n",
"\n",
"\n",
"# 设置纪元数,并开始训练和测试模型\n",
"epochs = 5\n",
"train_losses = list()\n",
"test_acc = list()\n",
"for epoch in range(epochs):\n",
" loss = train(model, train_loader, loss_function, optimizer, epoch) # 训练模型,并返回该纪元的 Loss 列表\n",
" train_losses += loss\n",
" acc = test(model, val_loader, loss_function, epoch) # 测试模型,并返回该纪元的正确率\n",
" test_acc.append(acc)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3.5 可视化验证"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"模型训练完毕。最后,我们利用可视化工具,直观感受下我们的训练结果吧!"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"* **训练集: Loss 下降趋势**"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.plot(train_losses, label=\"Training Loss\")\n",
"plt.xlabel(\"Iterations\")\n",
"plt.ylabel(\"Loss\")\n",
"plt.legend()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"* **测试集: 正确率上升状况**"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.plot(test_acc, label=\"Test Accuracy\")\n",
"plt.xlabel(\"Epochs\")\n",
"plt.ylabel(\"Accuracy\")\n",
"plt.legend()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"* **模型预测效果** \n",
"\n",
"我们在测试集上观看一下预测效果:"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAOF0lEQVR4nO3da6xV9ZnH8d9PxqpRjBzJIFK0F020qUonBE1sJowNjZJ4jfGS2DiJDJXoqKGRIcyLkugLM0yp80o8TYloqk2TltQXximDJt6SKhjk4qXYeql45IjGFFRE4JkXZ2FO9ez/Puy1b5zn+0lOzt7r2WuvJxt+Z+29/nutvyNCACa+o3rdAIDuIOxAEoQdSIKwA0kQdiCJf+jmxmxz6B/osIjwWMtr7dltX2z7Nduv215a57kAdJZbHWe3PUnSnyTNk/SOpBckXR8RLxfWYc8OdFgn9uxzJL0eEX+JiH2Sfi3p8hrPB6CD6oR9hqS/jrr/TrXs79heaHuD7Q01tgWgpo4foIuIQUmDEm/jgV6qs2ffIWnmqPtfr5YB6EN1wv6CpDNtf9P21yRdJ+nR9rQFoN1afhsfEftt3yrpfyVNkrQ6Ira1rTMAbdXy0FtLG+MzO9BxHflSDYAjB2EHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgiZbnZ5ck229K2i3pgKT9ETG7HU0BaL9aYa/8S0TsasPzAOgg3sYDSdQNe0j6g+2NtheO9QDbC21vsL2h5rYA1OCIaH1le0ZE7LD9j5LWSfr3iHiq8PjWNwZgXCLCYy2vtWePiB3V72FJayXNqfN8ADqn5bDbPt725EO3Jf1Q0tZ2NQagveocjZ8maa3tQ8/zcEQ83pauALRdrc/sh70xPrMDHdeRz+wAjhyEHUiCsANJEHYgCcIOJNGOE2FwBDvjjDOK9alTpxbrV155ZbE+d+7chrWDBw8W1121alWx/txzzxXr27dvL9azYc8OJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0lw1tsEcM455zSs3XLLLcV1r7rqqmK92Th7L+3fv79Yf+211xrWnnnmmeK6t99+e7G+b9++Yr2XOOsNSI6wA0kQdiAJwg4kQdiBJAg7kARhB5LgfPY+cO655xbrzcbKr7322oa1E088saWeDtmxY0ex/vTTTxfrb7zxRsPakiVLiutu3LixWJ8zpzwnycDAQMPa/Pnzi+u+9NJLxXqzc+37EXt2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiC89m74P777y/Wm117vc455evXry/Wt2zZUqwvW7asWN+7d+9h93TIk08+WawvWrSoWF+9enWxPmvWrIa1nTt3Ftc97bTTivVTTjmlWH///feL9U5q+Xx226ttD9veOmrZgO11trdXv6e0s1kA7Teet/EPSLr4S8uWSlofEWdKWl/dB9DHmoY9Ip6S9OGXFl8uaU11e42kK9rcF4A2a/W78dMiYqi6/Z6kaY0eaHuhpIUtbgdAm9Q+ESYionTgLSIGJQ1KeQ/QAf2g1aG3nbanS1L1e7h9LQHohFbD/qikG6vbN0r6fXvaAdApTd/G235E0lxJU22/I+mnku6R9BvbN0l6S9I1nWyyHxx77LENa83Oy16wYEGxbo85LPqFZmO29913X8PaihUriut+/PHHxXonnXzyycX6pEmTivXly5cX648//njD2umnn15cdyJqGvaIuL5B6Qdt7gVAB/F1WSAJwg4kQdiBJAg7kARhB5LgUtLjNHfu3Ia1O++8s7hus6G1d999t1hvNq3y888/X6x3UrPhsZkzZzasPfjgg8V1H3vssWJ9ypTWT7Zs9m/y0EMPFesfffRRy9vuFf
bsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AE4+zjVBpPPnDgQK3n/vzzz4v1888/v1i/+uqrG9bOOuuslno65NNPPy3Wzz777Jbru3btKq47bVrDq53V1uxS0nfffXex3uzfrB+xZweSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJJiyeZyOO+64hrWHH364uO68efNafm5JOuqo8t/kOv+Gzb4j0Ox89V46ePBgsb527dqGtdtuu6247tDQULHez1qeshnAxEDYgSQIO5AEYQeSIOxAEoQdSIKwA0kwzt4FJ510UrG+dOnSYv3CCy8s1j/44IOGtbfffru47jHHHFOsn3feecX6nDlzivVOWrVqVbG+bNmyhrUj8brv49XyOLvt1baHbW8dtWy57R22N1U/89vZLID2G8/b+AckXTzG8p9HxKzqpzx1B4Ceaxr2iHhK0odd6AVAB9U5QHer7c3V2/yGk27ZXmh7g+0NNbYFoKZWw36fpG9LmiVpSNLPGj0wIgYjYnZEzG5xWwDaoKWwR8TOiDgQEQcl/UJS7w7JAhiXlsJue/qou1dK2trosQD6Q9NxdtuPSJoraaqknZJ+Wt2fJSkkvSnpxxHR9ATgrOPsR7Jmc6jfcMMNLT/37t27i/XFixcX6w888ECxXvd6/keqRuPsTSeJiIjrx1j8y9odAegqvi4LJEHYgSQIO5AEYQeSIOxAEkzZnNySJUuK9euuu65j2160aFGx3uwS3Tg87NmBJAg7kARhB5Ig7EAShB1IgrADSRB2IAkuJT3BLViwoFhfuXJlsX7CCSfU2v62bdsa1mbPLl+86LPPPqu17ayYshlIjrADSRB2IAnCDiRB2IEkCDuQBGEHkmCcfQIoTZu8bt264rqTJ0+ute09e/YU65dccknD2rPPPltr2xgb4+xAcoQdSIKwA0kQdiAJwg4kQdiBJAg7kATXjZ8ALr300oa1uuPon3zySbF+2WWXFeuMpfePpnt22zNtP2n7ZdvbbN9eLR+wvc729ur3lM63C6BV43kbv1/STyLiO5IukHSL7e9IWippfUScKWl9dR9An2oa9ogYiogXq9u7Jb0iaYakyyWtqR62RtIVnWoSQH2H9Znd9jckfU/SHyVNi4ihqvSepGkN1lkoaWHrLQJoh3Efjbd9gqTfSrojIv42uhYjZ9OMeZJLRAxGxOyIKF9dEEBHjSvsto/WSNB/FRG/qxbvtD29qk+XNNyZFgG0Q9NTXG1bI5/JP4yIO0YtXyHpg4i4x/ZSSQMRUZz/l1NcW9Ns+GzXrl0Na0cffXStbQ8ODhbrN998c63nR/s1OsV1PJ/ZL5T0I0lbbG+qli2TdI+k39i+SdJbkq5pR6MAOqNp2CPiGUlj/qWQ9IP2tgOgU/i6LJAEYQeSIOxAEoQdSIKwA0lwKek+0Gxa5FdffbVYP/XUU1ve9ubNm4v1Cy64oFjfu3dvy9tGZ3ApaSA5wg4kQdiBJAg7kARhB5Ig7EAShB1IgktJ94GLLrqoWJ8xY0axXue7EosXLy7WGUefONizA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EASjLP3gbvuuqtYrzOOvmLFimL9iSeeaPm5cWRhzw4kQdiBJAg7kARhB5Ig7EAShB1IgrADSTQdZ7c9U9KDkqZJCkmDEfE/tpdL+jdJ71cPXRYRj3Wq0YlsYGCgWLcbTaI7Ynh4uGHt3nvvbaknTDzj+VLNfkk/iYgXbU+WtNH2uqr284j47861B6BdxjM/+5Ckoer2btuvSCpfOgVA3zmsz+y2vyHpe5L+WC261fZm26ttT2mwzkLbG2xvqNUpgFrGHXbbJ0j6raQ7IuJvku6T9G1JszSy5//ZWOtFxGBEzI6I2W3oF0CLxhV220drJOi/iojfSVJE7IyIAxFxUNIvJM3pXJsA6moado8cCv6lpFciYuWo5dNHPexKSVvb3x6AdhnP0fgLJf1I0hbbm6plyyRdb3uWRobj3pT04450mM
DKlStr1UunyA4NDbXUEyae8RyNf0bSWAO9jKkDRxC+QQckQdiBJAg7kARhB5Ig7EAShB1IwnUuU3zYG7O7tzEgqYgY85xo9uxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kES3p2zeJemtUfenVsv6Ub/21q99SfTWqnb2dnqjQle/VPOVjdsb+vXadP3aW7/2JdFbq7rVG2/jgSQIO5BEr8M+2OPtl/Rrb/3al0RvrepKbz39zA6ge3q9ZwfQJYQdSKInYbd9se3XbL9ue2kvemjE9pu2t9je1Ov56ao59IZtbx21bMD2Otvbq99jzrHXo96W295RvXabbM/vUW8zbT9p+2Xb22zfXi3v6WtX6Ksrr1vXP7PbniTpT5LmSXpH0guSro+Il7vaSAO235Q0OyJ6/gUM2/8saY+kByPiu9Wy/5L0YUTcU/2hnBIR/9EnvS2XtKfX03hXsxVNHz3NuKQrJP2revjaFfq6Rl143XqxZ58j6fWI+EtE7JP0a0mX96CPvhcRT0n68EuLL5e0prq9RiP/WbquQW99ISKGIuLF6vZuSYemGe/pa1foqyt6EfYZkv466v476q/53kPSH2xvtL2w182MYVpEHJrT6T1J03rZzBiaTuPdTV+aZrxvXrtWpj+viwN0X/X9iPgnSZdIuqV6u9qXYuQzWD+NnY5rGu9uGWOa8S/08rVrdfrzunoR9h2SZo66//VqWV+IiB3V72FJa9V/U1HvPDSDbvV7uMf9fKGfpvEea5px9cFr18vpz3sR9hcknWn7m7a/Juk6SY/2oI+vsH18deBEto+X9EP131TUj0q6sbp9o6Tf97CXv9Mv03g3mmZcPX7tej79eUR0/UfSfI0ckf+zpP/sRQ8N+vqWpJeqn2297k3SIxp5W/e5Ro5t3CTpZEnrJW2X9H+SBvqot4ckbZG0WSPBmt6j3r6vkbfomyVtqn7m9/q1K/TVldeNr8sCSXCADkiCsANJEHYgCcIOJEHYgSQIO5AEYQeS+H9C3nyYchtG2QAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"target: 9\n",
"prediction: 9\n"
]
}
],
"source": [
"num = 9 # 选择查看第几个数据的验证效果\n",
"for inputs, targets in val_loader: \n",
" \n",
" plt.imshow(inputs[num].numpy().transpose(1, 2, 0)) # 绘制该数据的手写数字图像\n",
" plt.show() \n",
" \n",
" print(\"target:\", targets[num].numpy()[0]) # 打印该数据的真实标签值\n",
" \n",
" outputs = model(inputs) # 模型根据输入数据进行预测\n",
" pred = np.argmax(outputs.numpy(), axis=1) # 根据最大相似度得到预测值\n",
" print(\"prediction:\", pred[num]) # 打印该数据的预测值\n",
" break"
]
},
{
"attachments": {
"image.png": {
"image/png": ""
}
},
"cell_type": "markdown",
"metadata": {},
"source": [
"# 尾声 📣\n",
"\n",
"恭喜您,已经完成了计图入门教程的所有内容。 \n",
"现在,您已经是一名合格的计图使用者了。 \n",
"计图官方时常会举办一些 “人工智能算法挑战赛” ,并附赠丰厚的奖金回报。作为一名合格的计图使用者,请来大赛中斩获一席之地吧!🎉🎊🎈\n",
"\n",
"更多学习资料,可以参考:\n",
"\n",
"* [在线PyTorch转Jittor工具](https://cg.cs.tsinghua.edu.cn/jittor/news/2020-12-13-20-40-pt_converter/)\n",
"* [PyTorch模型转换指南](https://cg.cs.tsinghua.edu.cn/jittor/tutorial/2020-5-2-16-43-pytorchconvert/)\n",
"* [Jittor文档](https://cg.cs.tsinghua.edu.cn/jittor/assets/docs/index.html)\n",
"* [Jittor模型库](https://cg.cs.tsinghua.edu.cn/jittor/resources/)\n",
"* [Jittor教程](https://cg.cs.tsinghua.edu.cn/jittor/tutorial/)\n",
"\n",
"Jittor还很年轻。 它可能存在错误和问题。 请在我们的错误跟踪系统QQ群761222083中报告它们。 我们欢迎您为Jittor做出贡献。\n",
"\n",
"您可以用以下方式帮助Jittor\n",
"\n",
"* 在论文中引用 Jittor\n",
"* 向身边的好朋友推荐 Jittor\n",
"* 贡献代码\n",
"* 贡献教程和文档\n",
"* 提出issue\n",
"* 回答 jittor 相关问题\n",
"* 点亮小星星\n",
"* 持续关注 jittor\n",
"\n",
"[Jittor Github 地址](https://github.com/jittor/jittor), \n",
"[Jittor Gitee 地址](https://gitee.com/jittor/jittor)\n",
"\n",
"求star您的支持是对我们最大的鼓励\n",
"\n",
"![image.png](attachment:image.png)\n",
"\n",
"特别感谢本教程作者llt"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}