Skip to content

Commit

Permalink
Merge pull request #18 from kohr-h/pytorch
Browse files Browse the repository at this point in the history
Add classification example in Pytorch
  • Loading branch information
adler-j committed Dec 13, 2017
2 parents 56e8b88 + dd943c9 commit b3c5867
Show file tree
Hide file tree
Showing 3 changed files with 476 additions and 1 deletion.
7 changes: 6 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -92,4 +92,9 @@ target/
# Documentation build files

# Course specific
*.gz
*.gz

# Downloaded data (pytorch)
code/data/raw/
code/data/processed/

238 changes: 238 additions & 0 deletions code/part2_classification_pytorch.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,238 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## MNIST classification\n",
"\n",
"In this notebook we tackle the perhaps most well known problem in all of machine learning, classifying hand-written digits.\n",
"\n",
"The particular dataset we will use is the MNIST (Modified National Institute of Standards and Technology)\n",
"The digits are 28x28 pixel images that look somewhat like this:\n",
"\n",
"![](https://user-images.githubusercontent.com/2202312/32365318-b0ccc44a-c079-11e7-8fb1-6b1566c0bdc4.png)\n",
"\n",
"Each digit has been hand classified, e.g. for the above 9-7-0-9-0-...\n",
"\n",
"Our task is to teach a machine to perform this classification, i.e. we want to find a function $\\mathcal{T}_\\theta$ such that\n",
"\n",
"| | |\n",
"|-|-|\n",
"|$\\mathcal{T}_\\theta$(|<img align=\"center\" src=\"https://user-images.githubusercontent.com/2202312/33177374-b134e572-d062-11e7-87c7-0574c6f5bee9.png\" width=\"28\"/>|) = 4|"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Import dependencies\n",
"\n",
"This should run without errors if all dependencies are installed properly."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"% matplotlib inline\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.cm\n",
"import numpy as np\n",
"np.random.seed(0)\n",
"from IPython.display import display, clear_output"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"from torch.autograd import Variable\n",
"from torch.nn import functional as F\n",
"from torch import optim\n",
"import torchvision\n",
"from torchvision import datasets, transforms\n",
"\n",
"train_loader = torch.utils.data.DataLoader(\n",
" datasets.MNIST('./data', train=True, download=True,\n",
" transform=transforms.Compose([\n",
" transforms.ToTensor(),\n",
" transforms.Normalize((0.1307,), (0.3081,))\n",
" ])),\n",
" batch_size=64, shuffle=True)\n",
"test_loader = torch.utils.data.DataLoader(\n",
" datasets.MNIST('./data', train=False, transform=transforms.Compose([\n",
" transforms.ToTensor(),\n",
" transforms.Normalize((0.1307,), (0.3081,))\n",
" ])),\n",
" batch_size=64, shuffle=True)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x7f44b0c58a20>"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4wLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvpW3flQAADfZJREFUeJzt3X+o1fUdx/HXO7cZdAUV0ZmzWaNCDWzjcpO2RiPuciOwQV1WQa7Frn8oJUQU/eOtWI2x3CRhcIemZcuErLRGc8hYDiLUfm/OrdL0Tr1OHKlEret974/7NW52z+d7POd7zvfc+34+QM6P9/d7vm8Ovu73+z2f8z0fc3cBiOecshsAUA7CDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gqC81c2NmxtcJgQZzd6tmubr2/Ga2wMz2mNm7ZnZvPa8FoLms1u/2m9k4Sf+U1CmpT9IOSTe5+98T67DnBxqsGXv+Dknvuvv77v4/SRskLazj9QA0UT3hnyHpwLDHfdlzn2Nm3Wa208x21rEtAAWr5wO/kQ4tvnBY7+69knolDvuBVlLPnr9P0sxhj78m6WB97QBolnrCv0PSxWZ2oZl9RdKPJW0upi0AjVbzYb+7D5jZUkl/lDRO0hp3/1thnQFoqJqH+mraGOf8QMM15Us+AEYvwg8ERfiBoAg/EBThB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8ERfiBoAg/EBThB4KqeYpuSTKzfZJOSDolacDd24toCsUZN25csj5p0qSGbr+np6dira2tLbnunDlzkvUbbrghWV+/fn3F2lVXXZVcd2BgIFnv7e1N1pcsWZKst4K6wp/5nrsfLeB1ADQRh/1AUPWG3yVtNbNdZtZdREMAmqPew/5vu/tBM5sq6U9m9g93f3n4AtkfBf4wAC2mrj2/ux/Mbo9IelZSxwjL9Lp7Ox8GAq2l5vCb2XlmNuH0fUnfl/ROUY0BaKx6DvunSXrWzE6/zu/d/aVCugLQcDWH393flzSvwF7GrIsuuihZP/fcc5P1a6+9Nlnv7OysWJs4cWJy3fnz5yfrZTp+/HiyvnHjxmS9o+MLZ6Gf+eSTT5LrHjhwIFnftm1bsj4aMNQHBEX4gaAIPxAU4QeCIvxAUIQfCMrcvXkbM2vexpoo7/LQrVu3Juvjx48vsp1RI+//3l133ZWsnzx5suZt5w3lHT58OFl/8803a952o7m7VbMce34gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCIpx/gJMmTIlWd+zZ0+y3uifz67H3r17k/UTJ04k63Pnzq1YO3XqVHLdvEudMTLG+QEkEX4gKMIPBEX4gaAIPxAU4QeCIvxAUEXM0hve0aPpSYrvvvvuZL2rqytZf+WVV5L15cuXJ+spfX19yfq8eelfZ8+7pr69vfJETQ888EByXTQWe34gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCCr3en4zWyPpOklH3P2y7LnJkp6WNEvSPkld7v7f3I2N0ev565U3jfaHH36YrL/44osVawsWLEiue+eddybrjz76aLKO1lPk9fxrJZ35P+heSdvc/WJJ27LHAEaR3PC7+8uSjp3x9EJJ67L76yRdX3BfABqs1nP+ae5+SJKy26nFtQSgGRr+3X4z65bU3ejtADg7te75+81suiRlt0cqLejuve7e7u6Vr/AA0HS1hn+zpEXZ/UWSni+mHQDNkht+M3tK0iuSLjWzPjO7XdIvJHWa2b8kdWaPAYwi/G7/GLB+/fqKtZtvvjm5bt6cAqnf3ZekwcHBZB3Nx+/2A0gi/EBQhB8IivADQRF+ICjCDwTFUN8Y0NbWVrG2Y8eO5LqXXnppsp43VLhhw4ZkHc3HUB+AJMIPBEX4gaAIPxAU4QeCIvxAUIQfCIpx/jFu9uzZyfrrr7+erH/88cfJ+q5du5L17du3V6zdf//9yXWb+X9zLGGcH0AS4QeCIvxAUIQfCIrwA0ERfiAowg8ExTh/cLfffnuyvmrVqmR9/PjxNW97xYoVyfrKlSuT9QMHDtS87bGMcX4ASYQfCIrwA0ERfiAowg8ERfiBoAg/EFTuOL+ZrZF0naQj7n5Z9lyPpJ9J+k+22H3u/ofcjTHOP+pcccUVyfrq1auT9Tlz5tS87S1btiTrd9xxR7L+wQcf1Lzt0azIcf61khaM8Pyv3f3y7F9u8AG0ltzwu/vLko41oRcATVTPOf9SM3vLzNaY2aTCOgLQFLWG/7eSviHpckmHJD1SaUEz6zaznWa2s8ZtAWiAmsLv7v3ufsrdByX9TlJHYtled2939/ZamwRQvJrCb2bThz38kaR3imkHQLN8KW8BM3tK0tWSpphZn6Tlkq42s8sluaR9khY3sEcADcD1/KjL5MmTk/Vbb721Yu2RRyp+VCRJMksPV+/evTtZnzt3brI+VnE9P4Akwg8ERfiBoAg/EBThB4Ii/EBQDPWhNAMDA8n6Oeek902Dg4PJeldXV8Xapk2bkuuOZgz1AUgi/EBQhB8IivADQRF+ICjCDwRF+IGgcq/nR2zz589P1m+77baa188bx89z+PDhZP25556r6/XHOvb8QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxAU4/xj3Lx585L1np6eZP2aa65J1tva2s62parlXa9/9OjRutaPjj0/EBThB4Ii/EBQhB8IivADQRF+ICjCDwSVO85vZjMlPS7pq5IGJfW6+0ozmyzpaUmzJO2T1OXu/21cq3HNmDEjWV+6dGnF2uLFi5PrTpw4saaeirB///5kPe87CGvXri2umYCq2fMPSLrL3WdLmi9piZnNkXSvpG3ufrGkbdljAKNEbvjd/ZC7v5bdPyFpt6QZkhZKWpcttk7S9Y1qEkDxzuqc38xmSfqmpFclTXP3Q9LQHwhJU4tuDkDjVP3dfjNrk/SMpGXuftysqunAZGbdkrpraw9Ao1S15zezL2so+E+6++kZDvvNbHpWny7pyEjrunuvu7e7e3sRDQMoRm74bWgXv1rSbndfMay0WdKi7P4iSc8X3x6ARsmdotvMviNpu6S3NTTUJ0n3aei8f6OkCyTtl3Sjux/Lea2QU3Sff/75yfqVV16ZrK9atSpZnzq1vI9b9u7dm6w/9NBDFWuPPfZYcl0uya1NtVN0557zu/tfJVV6sfTF3gBaFt/wA4Ii/EBQhB8IivADQRF+ICjCDwTFT3dXacqUKRVrW7ZsSa57ySWXJOuTJk2qqacivPfee8n6ww8/nKxv2LAhWf/oo4/Ouic0B3t+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwgqzDh/Z2dnsv7ggw8m67Nnz65YmzBhQk09FeXTTz+tWHviiSeS6y5btixZP3nyZE09ofWx5weCIvxAUIQfCIrwA0ERfiAowg8ERfiBoMKM899yyy3JekdHR8O23d/fn6y/9NJLyfrAwECyfs8991SsHTuWnEoBgbHnB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgzN3TC5jNlPS4pK9KGpTU6+4rzaxH0s8k/Sdb9D53/0POa6U3BqBu7m7VLFdN+KdLmu7ur5nZBEm7JF0vqUvSSXf/VbVNEX6g8aoNf+43/Nz9kKRD2f0TZrZb0oz62gNQtrM65zezWZK+KenV7KmlZvaWma0xsxHnnDKzbjPbaWY76+oUQKFyD/s/W9CsTdJfJP3c3TeZ2TRJRyW5pAc1dGrw05zX4LAfaLDCzvklycy+LOkFSX909xUj1GdJesHdL8t5HcIPNFi14c897Dczk7Ra0u7hwc8+CDztR5LeOdsmAZSnmk/7vyNpu6S3NTTUJ0n3SbpJ0uUaOuzfJ2lx9uFg6rXY8wMNVuhhf1EIP9B4hR32AxibCD8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0E1e4ruo5I+GPZ4SvZcK2rV3lq1L4nealVkb1+vdsGmXs//hY2b7XT39tIaSGjV3lq1L4nealVWbxz2A0ERfiCossPfW/L2U1q1t1btS6K3WpXSW6nn/ADKU/aeH0BJSgm/mS0wsz1m9q6Z3VtGD5WY2T4ze9vM3ih7irFsGrQjZvbOsOcmm9mfzOxf2e2I06SV1FuPmf07e+/eMLMfltTbTDP7s5ntNrO/mdmd2fOlvneJvkp535p+2G9m4yT9U1KnpD5JOyTd5O5/b2ojFZjZPknt7l76mLCZfVfSSUmPn54Nycx+KemYu/8i+8M5yd3vaZHeenSWMzc3qLdKM0v/RCW+d0XOeF2EMvb8HZLedff33f1/kjZIWlhCHy3P3V+WdOyMpxdKWpfdX6eh/zxNV6G3luDuh9z9tez+CUmnZ5Yu9b1L9FWKMsI/Q9KBYY/71FpTfrukrWa2y8y6y25mBNNOz4yU3U4tuZ8z5c7c3ExnzCzdMu9dLTNeF62M8I80m0grDTl8292/JekHkpZkh7eozm8lfUND07gdkvRImc1kM0s/I2mZux8vs5fhRuirlPetjPD3SZo57PHXJB0soY8RufvB7PaIpGc1dJrSSvpPT5Ka3R4puZ/PuHu/u59y90FJv1OJ7102s/Qzkp50903Z06W/dyP1Vdb7Vkb4d0i62MwuNLOvSPqxpM0l9PEFZnZe9kGMzOw8Sd9X680+vFnSouz+IknPl9jL57TKzM2VZpZWye9dq814XcqXfLKhjN9IGidpjbv/vOlNjMDMLtLQ3l4auuLx92X2ZmZPSbpaQ1d99UtaLuk5SRslXSBpv6Qb3b3pH7xV6O1qneXMzQ3qrdLM0q+qxPeuyBmvC+mHb/gBMfENPyAowg8ERfiBoAg/EBThB4Ii/EBQhB8IivADQf0fPcdHNxTQ528AAAAASUVORK5CYII=\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x7f448e942e10>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"img = train_loader.dataset[0][0]\n",
"plt.imshow(img.squeeze(), cmap='Greys_r')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Multilayer Perceptron\n",
"\n",
"The first \"deep\" neural networks were [multilayer perceptrons](https://en.wikipedia.org/wiki/Multilayer_perceptron), in these we have a function of the following form\n",
"\n",
"$$\n",
"\\rho(W_3\\rho(W_2\\rho(W_1 x + b_1) + b_2) + b_3)\n",
"$$\n",
"\n",
"Where $W_i$ are matrices and $b_i$ vectors. Note that the logistic regression can be cast into this form (how?)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class MLP(nn.Module):\n",
" def __init__(self):\n",
" super(MLP, self).__init__()\n",
" self.lin1 = nn.Linear(784, 128)\n",
" self.lin2 = nn.Linear(128, 32)\n",
" self.lin3 = nn.Linear(32, 10)\n",
"\n",
"\n",
" def forward(self, x):\n",
" x = F.relu(self.lin1(x.view(-1, 784)))\n",
" x = F.relu(self.lin2(x))\n",
" x = F.relu(self.lin3(x))\n",
" return F.log_softmax(x)\n",
"\n",
"\n",
"\n",
"class ConvNet(nn.Module):\n",
" def __init__(self):\n",
" super(ConvNet, self).__init__()\n",
" self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=2)\n",
" self.conv2 = nn.Conv2d(32, 32, kernel_size=3, stride=2)\n",
" self.fc = nn.Linear(32 * 36, 10)\n",
"\n",
"\n",
" def forward(self, x):\n",
" x = F.relu(self.conv1(x))\n",
" x = F.relu(self.conv2(x))\n",
" x = x.view(-1, 32 * 36)\n",
" x = self.fc(x)\n",
" return F.log_softmax(x)\n",
"\n",
"use_cuda = True\n",
"learning_rate = 1e-2\n",
"log_interval = 500\n",
"epochs = 20\n",
"model = ConvNet()\n",
"if use_cuda:\n",
" model.cuda()\n",
"optimizer = optim.SGD(model.parameters(), lr=learning_rate)\n",
"\n",
"\n",
"def train(epoch):\n",
" model.train()\n",
" for batch_idx, (data, target) in enumerate(train_loader):\n",
" if use_cuda:\n",
" data, target = data.cuda(), target.cuda()\n",
" data, target = Variable(data), Variable(target)\n",
" optimizer.zero_grad()\n",
" output = model(data)\n",
" loss = F.nll_loss(output, target)\n",
" loss.backward()\n",
" optimizer.step()\n",
" if batch_idx % log_interval == 0:\n",
" print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n",
" epoch, batch_idx * len(data), len(train_loader.dataset),\n",
" 100. * batch_idx / len(train_loader), loss.data[0]))\n",
"\n",
"\n",
"def test():\n",
" model.eval()\n",
" test_loss = 0\n",
" correct = 0\n",
" for data, target in test_loader:\n",
" if use_cuda:\n",
" data, target = data.cuda(), target.cuda()\n",
" data, target = Variable(data, volatile=True), Variable(target)\n",
" output = model(data)\n",
" test_loss += F.nll_loss(output, target, size_average=False).data[0] # sum up batch loss\n",
" pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability\n",
" correct += pred.eq(target.data.view_as(pred)).cpu().sum()\n",
"\n",
"\n",
" test_loss /= len(test_loader.dataset)\n",
" print('\\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n",
" test_loss, correct, len(test_loader.dataset),\n",
" 100. * correct / len(test_loader.dataset)))\n",
"\n",
"\n",
"\n",
"\n",
"for epoch in range(1, epochs + 1):\n",
" train(epoch)\n",
" test()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Loading

0 comments on commit b3c5867

Please sign in to comment.