Skip to content

Commit

Permalink
Add stub of part 3 pytorch version
Browse files Browse the repository at this point in the history
  • Loading branch information
Holger Kohr committed Dec 13, 2017
1 parent a140457 commit dd943c9
Show file tree
Hide file tree
Showing 2 changed files with 238 additions and 1 deletion.
7 changes: 6 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -92,4 +92,9 @@ target/
# Documentation build files

# Course specific
*.gz
*.gz

# Downloaded data (pytorch)
code/data/raw/
code/data/processed/

232 changes: 232 additions & 0 deletions code/part3_learned_reconstruction_pytorch.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,232 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/hkohr/git/odl/odl/trafos/backends/pyfftw_bindings.py:30: RuntimeWarning: PyFFTW < 0.10.4 is known to cause problems with some ODL functionality, see issue #1002.\n",
" RuntimeWarning)\n"
]
}
],
"source": [
"% matplotlib inline\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.cm\n",
"import numpy as np\n",
"from IPython.display import display, clear_output\n",
"import odl\n",
"import torch\n",
"from torch import nn\n",
"from torch.autograd import Variable\n",
"from torch.nn import functional as F\n",
"from torch import optim\n",
"import torchvision\n",
"from torchvision import datasets, transforms\n",
"np.random.seed(0)\n",
"\n",
"train_loader = torch.utils.data.DataLoader(\n",
" datasets.MNIST('./data', train=True, download=True,\n",
" transform=transforms.Compose([\n",
" transforms.ToTensor(),\n",
" transforms.Normalize((0.1307,), (0.3081,))\n",
" ])),\n",
" batch_size=64, shuffle=True)\n",
"test_loader = torch.utils.data.DataLoader(\n",
" datasets.MNIST('./data', train=False, transform=transforms.Compose([\n",
" transforms.ToTensor(),\n",
" transforms.Normalize((0.1307,), (0.3081,))\n",
" ])),\n",
" batch_size=64, shuffle=True)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Create ODL data structures\n",
"space = odl.uniform_discr([-14, -14], [14, 14], [28, 28],\n",
" dtype='float32')\n",
"\n",
"geometry = odl.tomo.parallel_beam_geometry(space, num_angles=5)\n",
"operator = odl.tomo.RayTransform(space, geometry)\n",
"fbp_op = odl.tomo.fbp_op(operator)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def generate_data(images):\n",
" \"\"\"Generate data from images\n",
" \n",
" Parameters\n",
" ----------\n",
" images : np.array of shape [Batch, 28, 28, 1]\n",
" The images (in reconstruction space) which we should create data for.\n",
" \n",
" Returns\n",
" -------\n",
" sinograms : np.array of shape [Batch, 5, 41, 1]\n",
" Noisy sinograms corresponding to ``images``\n",
" \"\"\"\n",
" data = [operator(image.squeeze()).asarray() +\n",
" np.random.randn(*operator.range.shape) for image in images]\n",
" return np.array(data)[..., None]"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"?? "
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"ename": "ValueError",
"evalue": "Too many dimensions: 3 > 2.",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-8-67e798522bac>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mbatch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain_loader\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;36m1000\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m~/git/torchvision/torchvision/datasets/mnist.py\u001b[0m in \u001b[0;36m__getitem__\u001b[0;34m(self, index)\u001b[0m\n\u001b[1;32m 72\u001b[0m \u001b[0;31m# doing this so that it is consistent with all other datasets\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 73\u001b[0m \u001b[0;31m# to return a PIL Image\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 74\u001b[0;31m \u001b[0mimg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mImage\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfromarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'L'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 75\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 76\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtransform\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/miniconda/envs/odl_py3/lib/python3.6/site-packages/PIL/Image.py\u001b[0m in \u001b[0;36mfromarray\u001b[0;34m(obj, mode)\u001b[0m\n\u001b[1;32m 2427\u001b[0m \u001b[0mndmax\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m4\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2428\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mndim\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0mndmax\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2429\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Too many dimensions: %d > %d.\"\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mndim\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mndmax\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2430\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2431\u001b[0m \u001b[0msize\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mValueError\u001b[0m: Too many dimensions: 3 > 2."
]
}
],
"source": [
"batch = train_loader.dataset[:1000]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class MLP(nn.Module):\n",
" def __init__(self):\n",
" super(MLP, self).__init__()\n",
" self.lin1 = nn.Linear(784, 128)\n",
" self.lin2 = nn.Linear(128, 32)\n",
" self.lin3 = nn.Linear(32, 10)\n",
"\n",
"\n",
" def forward(self, x):\n",
" x = F.relu(self.lin1(x.view(-1, 784)))\n",
" x = F.relu(self.lin2(x))\n",
" x = F.relu(self.lin3(x))\n",
" return F.log_softmax(x)\n",
"\n",
"\n",
"\n",
"class ConvNet(nn.Module):\n",
" def __init__(self):\n",
" super(ConvNet, self).__init__()\n",
" self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=2)\n",
" self.conv2 = nn.Conv2d(32, 32, kernel_size=3, stride=2)\n",
" self.fc = nn.Linear(32 * 36, 10)\n",
"\n",
"\n",
" def forward(self, x):\n",
" x = F.relu(self.conv1(x))\n",
" x = F.relu(self.conv2(x))\n",
" x = x.view(-1, 32 * 36)\n",
" x = self.fc(x)\n",
" return F.log_softmax(x)\n",
"\n",
"use_cuda = True\n",
"learning_rate = 1e-2\n",
"log_interval = 500\n",
"epochs = 20\n",
"model = ConvNet()\n",
"if use_cuda:\n",
" model.cuda()\n",
"optimizer = optim.SGD(model.parameters(), lr=learning_rate)\n",
"\n",
"\n",
"def train(epoch):\n",
" model.train()\n",
" for batch_idx, (data, target) in enumerate(train_loader):\n",
" if use_cuda:\n",
" data, target = data.cuda(), target.cuda()\n",
" data, target = Variable(data), Variable(target)\n",
" optimizer.zero_grad()\n",
" output = model(data)\n",
" loss = F.nll_loss(output, target)\n",
" loss.backward()\n",
" optimizer.step()\n",
" if batch_idx % log_interval == 0:\n",
" print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n",
" epoch, batch_idx * len(data), len(train_loader.dataset),\n",
" 100. * batch_idx / len(train_loader), loss.data[0]))\n",
"\n",
"\n",
"def test():\n",
" model.eval()\n",
" test_loss = 0\n",
" correct = 0\n",
" for data, target in test_loader:\n",
" if use_cuda:\n",
" data, target = data.cuda(), target.cuda()\n",
" data, target = Variable(data, volatile=True), Variable(target)\n",
" output = model(data)\n",
" test_loss += F.nll_loss(output, target, size_average=False).data[0] # sum up batch loss\n",
" pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability\n",
" correct += pred.eq(target.data.view_as(pred)).cpu().sum()\n",
"\n",
"\n",
" test_loss /= len(test_loader.dataset)\n",
" print('\\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n",
" test_loss, correct, len(test_loader.dataset),\n",
" 100. * correct / len(test_loader.dataset)))\n",
"\n",
"\n",
"\n",
"\n",
"for epoch in range(1, epochs + 1):\n",
" train(epoch)\n",
" test()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

0 comments on commit dd943c9

Please sign in to comment.