Merge pull request #18 from kohr-h/pytorch

Add classification example in Pytorch
Showing 3 changed files with 476 additions and 1 deletion.
@@ -0,0 +1,238 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# MNIST classification\n",
"\n",
"In this notebook we tackle perhaps the best-known problem in all of machine learning: classifying hand-written digits.\n",
"\n",
"The particular dataset we will use is MNIST (Modified National Institute of Standards and Technology).\n",
"The digits are 28x28 pixel images that look somewhat like this:\n",
"\n",
"![](https://user-images.githubusercontent.com/2202312/32365318-b0ccc44a-c079-11e7-8fb1-6b1566c0bdc4.png)\n",
"\n",
"Each digit has been classified by hand, e.g. 9-7-0-9-0-... for the row above.\n",
"\n",
"Our task is to teach a machine to perform this classification, i.e. we want to find a function $\\mathcal{T}_\\theta$ such that\n",
"\n",
"| | |\n",
"|-|-|\n",
"|$\\mathcal{T}_\\theta$(|<img align=\"center\" src=\"https://user-images.githubusercontent.com/2202312/33177374-b134e572-d062-11e7-87c7-0574c6f5bee9.png\" width=\"28\"/>|) = 4|"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Import dependencies\n",
"\n",
"This should run without errors if all dependencies are installed properly."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.cm\n",
"import numpy as np\n",
"np.random.seed(0)\n",
"from IPython.display import display, clear_output"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"from torch.autograd import Variable\n",
"from torch.nn import functional as F\n",
"from torch import optim\n",
"import torchvision\n",
"from torchvision import datasets, transforms\n",
"\n",
"train_loader = torch.utils.data.DataLoader(\n",
"    datasets.MNIST('./data', train=True, download=True,\n",
"                   transform=transforms.Compose([\n",
"                       transforms.ToTensor(),\n",
"                       transforms.Normalize((0.1307,), (0.3081,))\n",
"                   ])),\n",
"    batch_size=64, shuffle=True)\n",
"test_loader = torch.utils.data.DataLoader(\n",
"    datasets.MNIST('./data', train=False, transform=transforms.Compose([\n",
"        transforms.ToTensor(),\n",
"        transforms.Normalize((0.1307,), (0.3081,))\n",
"    ])),\n",
"    batch_size=64, shuffle=True)"
]
},
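{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check (an added aside, not part of the original example): the constants `0.1307` and `0.3081` passed to `Normalize` are the mean and standard deviation of the raw MNIST training pixels. The cell below recomputes them."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Recompute the normalization constants from the raw (unnormalized) images.\n",
"raw = datasets.MNIST('./data', train=True, download=True,\n",
"                     transform=transforms.ToTensor())\n",
"pixels = torch.cat([raw[i][0].view(-1) for i in range(len(raw))])\n",
"print(pixels.mean(), pixels.std())  # expect roughly 0.1307 and 0.3081"
]
},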
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x7f44b0c58a20>"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "<base64-encoded PNG of the first MNIST training digit, omitted here>",
"text/plain": [
"<matplotlib.figure.Figure at 0x7f448e942e10>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"img = train_loader.dataset[0][0]\n",
"plt.imshow(img.squeeze(), cmap='Greys_r')"
]
},
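{
"cell_type": "markdown",
"metadata": {},
"source": [
"To also see the hand-assigned labels (a small added check, not part of the original notebook), we can plot a few training digits together with their targets:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Plot the first five training digits with their labels as titles.\n",
"fig, axes = plt.subplots(1, 5)\n",
"for i, ax in enumerate(axes):\n",
"    img, label = train_loader.dataset[i]\n",
"    ax.imshow(img.squeeze(), cmap='Greys_r')\n",
"    ax.set_title(str(label))\n",
"    ax.axis('off')"
]
},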
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Multilayer Perceptron\n",
"\n",
"The first \"deep\" neural networks were [multilayer perceptrons](https://en.wikipedia.org/wiki/Multilayer_perceptron). These compute a function of the form\n",
"\n",
"$$\n",
"\\rho(W_3\\rho(W_2\\rho(W_1 x + b_1) + b_2) + b_3),\n",
"$$\n",
"\n",
"where the $W_i$ are matrices, the $b_i$ are vectors, and $\\rho$ is a componentwise nonlinearity. Note that logistic regression can be cast into this form (how? see the sketch below)."
]
},
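{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a hint (an added sketch, not part of the original notebook): multiclass logistic regression is the single-layer special case $\\rho(W_1 x + b_1)$ with $\\rho = \\mathrm{softmax}$. In code it would look like this:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class LogisticRegression(nn.Module):\n",
"    \"\"\"One affine layer + softmax: the single-layer case of the MLP form.\"\"\"\n",
"    def __init__(self):\n",
"        super(LogisticRegression, self).__init__()\n",
"        self.lin = nn.Linear(784, 10)\n",
"\n",
"    def forward(self, x):\n",
"        # log-softmax instead of softmax, to pair with nll_loss later on\n",
"        return F.log_softmax(self.lin(x.view(-1, 784)), dim=1)"
]
},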
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class MLP(nn.Module):\n",
"    def __init__(self):\n",
"        super(MLP, self).__init__()\n",
"        self.lin1 = nn.Linear(784, 128)\n",
"        self.lin2 = nn.Linear(128, 32)\n",
"        self.lin3 = nn.Linear(32, 10)\n",
"\n",
"    def forward(self, x):\n",
"        x = F.relu(self.lin1(x.view(-1, 784)))\n",
"        x = F.relu(self.lin2(x))\n",
"        # No ReLU on the output layer: it would clip the logits\n",
"        # that log_softmax expects.\n",
"        x = self.lin3(x)\n",
"        return F.log_softmax(x, dim=1)\n",
"\n",
"\n",
"class ConvNet(nn.Module):\n",
"    def __init__(self):\n",
"        super(ConvNet, self).__init__()\n",
"        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=2)\n",
"        self.conv2 = nn.Conv2d(32, 32, kernel_size=3, stride=2)\n",
"        self.fc = nn.Linear(32 * 36, 10)\n",
"\n",
"    def forward(self, x):\n",
"        x = F.relu(self.conv1(x))\n",
"        x = F.relu(self.conv2(x))\n",
"        x = x.view(-1, 32 * 36)\n",
"        x = self.fc(x)\n",
"        return F.log_softmax(x, dim=1)\n",
"\n",
"\n",
"# Note: the code below uses the pre-0.4 PyTorch API\n",
"# (Variable, volatile=True, loss.data[0], size_average=False).\n",
"use_cuda = torch.cuda.is_available()\n",
"learning_rate = 1e-2\n",
"log_interval = 500\n",
"epochs = 20\n",
"model = ConvNet()\n",
"if use_cuda:\n",
"    model.cuda()\n",
"optimizer = optim.SGD(model.parameters(), lr=learning_rate)\n",
"\n",
"\n",
"def train(epoch):\n",
"    model.train()\n",
"    for batch_idx, (data, target) in enumerate(train_loader):\n",
"        if use_cuda:\n",
"            data, target = data.cuda(), target.cuda()\n",
"        data, target = Variable(data), Variable(target)\n",
"        optimizer.zero_grad()\n",
"        output = model(data)\n",
"        loss = F.nll_loss(output, target)\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"        if batch_idx % log_interval == 0:\n",
"            print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n",
"                epoch, batch_idx * len(data), len(train_loader.dataset),\n",
"                100. * batch_idx / len(train_loader), loss.data[0]))\n",
"\n",
"\n",
"def test():\n",
"    model.eval()\n",
"    test_loss = 0\n",
"    correct = 0\n",
"    for data, target in test_loader:\n",
"        if use_cuda:\n",
"            data, target = data.cuda(), target.cuda()\n",
"        data, target = Variable(data, volatile=True), Variable(target)\n",
"        output = model(data)\n",
"        test_loss += F.nll_loss(output, target, size_average=False).data[0]  # sum up batch loss\n",
"        pred = output.data.max(1, keepdim=True)[1]  # index of the max log-probability\n",
"        correct += pred.eq(target.data.view_as(pred)).cpu().sum()\n",
"\n",
"    test_loss /= len(test_loader.dataset)\n",
"    print('\\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n",
"        test_loss, correct, len(test_loader.dataset),\n",
"        100. * correct / len(test_loader.dataset)))\n",
"\n",
"\n",
"for epoch in range(1, epochs + 1):\n",
"    train(epoch)\n",
"    test()"
]
}
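,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, a small added usage example (not in the original notebook, reusing the pre-0.4 `Variable` API from above): run one test batch through the trained network and display the first digit together with its predicted and true label."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Classify one batch of test digits with the trained network.\n",
"model.eval()\n",
"images, labels = next(iter(test_loader))\n",
"if use_cuda:\n",
"    images = images.cuda()\n",
"output = model(Variable(images, volatile=True))\n",
"pred = output.data.max(1)[1]  # most likely class per image\n",
"plt.imshow(images[0].cpu().squeeze(), cmap='Greys_r')\n",
"plt.title('predicted: {}, true: {}'.format(pred[0], labels[0]))"
]
}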
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}