-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Holger Kohr
committed
Dec 13, 2017
1 parent
a140457
commit dd943c9
Showing
2 changed files
with
238 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,232 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"/home/hkohr/git/odl/odl/trafos/backends/pyfftw_bindings.py:30: RuntimeWarning: PyFFTW < 0.10.4 is known to cause problems with some ODL functionality, see issue #1002.\n", | ||
" RuntimeWarning)\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"% matplotlib inline\n", | ||
"import matplotlib.pyplot as plt\n", | ||
"import matplotlib.cm\n", | ||
"import numpy as np\n", | ||
"from IPython.display import display, clear_output\n", | ||
"import odl\n", | ||
"import torch\n", | ||
"from torch import nn\n", | ||
"from torch.autograd import Variable\n", | ||
"from torch.nn import functional as F\n", | ||
"from torch import optim\n", | ||
"import torchvision\n", | ||
"from torchvision import datasets, transforms\n", | ||
"np.random.seed(0)\n", | ||
"\n", | ||
"train_loader = torch.utils.data.DataLoader(\n", | ||
" datasets.MNIST('./data', train=True, download=True,\n", | ||
" transform=transforms.Compose([\n", | ||
" transforms.ToTensor(),\n", | ||
" transforms.Normalize((0.1307,), (0.3081,))\n", | ||
" ])),\n", | ||
" batch_size=64, shuffle=True)\n", | ||
"test_loader = torch.utils.data.DataLoader(\n", | ||
" datasets.MNIST('./data', train=False, transform=transforms.Compose([\n", | ||
" transforms.ToTensor(),\n", | ||
" transforms.Normalize((0.1307,), (0.3081,))\n", | ||
" ])),\n", | ||
" batch_size=64, shuffle=True)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
# ODL data structures: a 28x28 float32 reconstruction space on [-14, 14]^2,
# a sparse 5-angle parallel-beam geometry, the corresponding ray transform,
# and its filtered-back-projection pseudo-inverse.
space = odl.uniform_discr(min_pt=[-14, -14], max_pt=[14, 14],
                          shape=[28, 28], dtype='float32')
geometry = odl.tomo.parallel_beam_geometry(space, num_angles=5)
operator = odl.tomo.RayTransform(space, geometry)
fbp_op = odl.tomo.fbp_op(operator)
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
def generate_data(images, noise_level=1.0):
    """Generate noisy tomographic data from images.

    Parameters
    ----------
    images : np.ndarray, shape ``[batch, 28, 28, 1]``
        The images (in reconstruction space) to create data for.
    noise_level : float, optional
        Standard deviation of the additive white Gaussian noise.
        The default of 1.0 reproduces the previous hard-coded behavior.

    Returns
    -------
    sinograms : np.ndarray, shape ``[batch, 5, 41, 1]``
        Noisy sinograms corresponding to ``images``.
    """
    # squeeze() drops the trailing channel axis so the image matches the
    # 2D domain of the ray transform ``operator`` (module-level global).
    data = [operator(image.squeeze()).asarray() +
            noise_level * np.random.randn(*operator.range.shape)
            for image in images]
    # Re-append a channel axis for the network's [batch, H, W, 1] convention.
    return np.array(data)[..., None]
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"?? " | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 8, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"ename": "ValueError", | ||
"evalue": "Too many dimensions: 3 > 2.", | ||
"output_type": "error", | ||
"traceback": [ | ||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | ||
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", | ||
"\u001b[0;32m<ipython-input-8-67e798522bac>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mbatch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain_loader\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;36m1000\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", | ||
"\u001b[0;32m~/git/torchvision/torchvision/datasets/mnist.py\u001b[0m in \u001b[0;36m__getitem__\u001b[0;34m(self, index)\u001b[0m\n\u001b[1;32m 72\u001b[0m \u001b[0;31m# doing this so that it is consistent with all other datasets\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 73\u001b[0m \u001b[0;31m# to return a PIL Image\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 74\u001b[0;31m \u001b[0mimg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mImage\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfromarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'L'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 75\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 76\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtransform\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | ||
"\u001b[0;32m~/miniconda/envs/odl_py3/lib/python3.6/site-packages/PIL/Image.py\u001b[0m in \u001b[0;36mfromarray\u001b[0;34m(obj, mode)\u001b[0m\n\u001b[1;32m 2427\u001b[0m \u001b[0mndmax\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m4\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2428\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mndim\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0mndmax\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2429\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Too many dimensions: %d > %d.\"\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mndim\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mndmax\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2430\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2431\u001b[0m \u001b[0msize\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | ||
"\u001b[0;31mValueError\u001b[0m: Too many dimensions: 3 > 2." | ||
] | ||
} | ||
], | ||
"source": [ | ||
"batch = train_loader.dataset[:1000]" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
class MLP(nn.Module):
    """Fully connected 784 -> 128 -> 32 -> 10 MNIST classifier.

    Input is flattened to 784 features per sample; output is per-class
    log-probabilities (log-softmax over the 10 digit classes).
    """

    def __init__(self):
        super(MLP, self).__init__()
        self.lin1 = nn.Linear(784, 128)
        self.lin2 = nn.Linear(128, 32)
        self.lin3 = nn.Linear(32, 10)

    def forward(self, x):
        x = F.relu(self.lin1(x.view(-1, 784)))
        x = F.relu(self.lin2(x))
        # Bug fix: the original applied ReLU to the final layer's output
        # before log_softmax, clamping all negative logits to zero and
        # destroying the class scores. Logits must go in unmodified.
        x = self.lin3(x)
        # Explicit dim: implicit-dim log_softmax is ambiguous/deprecated.
        return F.log_softmax(x, dim=1)
"\n", | ||
"\n", | ||
"\n", | ||
class ConvNet(nn.Module):
    """Small CNN for 1x28x28 MNIST images: two strided 3x3 convs + linear head.

    Spatial size shrinks 28 -> 13 -> 6, so the flattened feature vector has
    32 * 6 * 6 = 32 * 36 elements. Output is log-softmax over 10 classes.
    """

    def __init__(self):
        super(ConvNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=2)
        self.conv2 = nn.Conv2d(32, 32, kernel_size=3, stride=2)
        self.fc = nn.Linear(32 * 36, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.view(-1, 32 * 36)
        x = self.fc(x)
        # Explicit dim: implicit-dim log_softmax is ambiguous/deprecated.
        return F.log_softmax(x, dim=1)
"\n", | ||
# Training configuration.
# Bug fix: use_cuda was hard-coded to True, which makes model.cuda() (and
# every .cuda() call in train/test) crash on machines without a GPU.
use_cuda = torch.cuda.is_available()
learning_rate = 1e-2
log_interval = 500  # batches between progress printouts
epochs = 20

model = ConvNet()
if use_cuda:
    model.cuda()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)
"\n", | ||
"\n", | ||
def train(epoch):
    """Run one epoch of SGD over ``train_loader``.

    Logs loss and progress every ``log_interval`` batches. Relies on the
    module-level ``model``, ``optimizer``, ``use_cuda`` and ``log_interval``.
    """
    model.train()
    for step, (inputs, labels) in enumerate(train_loader):
        if use_cuda:
            inputs = inputs.cuda()
            labels = labels.cuda()
        inputs, labels = Variable(inputs), Variable(labels)

        optimizer.zero_grad()
        loss = F.nll_loss(model(inputs), labels)
        loss.backward()
        optimizer.step()

        if step % log_interval == 0:
            n_seen = step * len(inputs)
            n_total = len(train_loader.dataset)
            pct = 100. * step / len(train_loader)
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, n_seen, n_total, pct, loss.data[0]))
"\n", | ||
"\n", | ||
def test():
    """Evaluate ``model`` on ``test_loader``; print mean NLL loss and accuracy."""
    model.eval()
    total_loss = 0
    n_correct = 0
    for inputs, labels in test_loader:
        if use_cuda:
            inputs, labels = inputs.cuda(), labels.cuda()
        # volatile=True builds an inference-only graph (pre-0.4 torch idiom,
        # the equivalent of torch.no_grad() in later versions).
        inputs = Variable(inputs, volatile=True)
        labels = Variable(labels)
        log_probs = model(inputs)
        # Sum (not average) batch losses so dividing by the dataset size
        # below yields the true per-sample mean.
        total_loss += F.nll_loss(log_probs, labels, size_average=False).data[0]
        # argmax over classes = predicted digit
        predictions = log_probs.data.max(1, keepdim=True)[1]
        n_correct += predictions.eq(labels.data.view_as(predictions)).cpu().sum()

    total_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        total_loss, n_correct, len(test_loader.dataset),
        100. * n_correct / len(test_loader.dataset)))
"\n", | ||
"\n", | ||
"\n", | ||
"\n", | ||
# Main loop: alternate a full training epoch with a full evaluation pass.
for epoch in range(1, epochs + 1):
    train(epoch)
    test()
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.6.3" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |