Skip to content

Commit

Permalink
fcn epo5 modified
Browse files Browse the repository at this point in the history
  • Loading branch information
astonzhang committed Jul 20, 2018
1 parent 2f07422 commit 0157f6a
Showing 1 changed file with 26 additions and 120 deletions.
146 changes: 26 additions & 120 deletions chapter_computer-vision/fcn.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@ import sys
sys.path.append('..')
import gluonbook as gb
from mxnet import gluon, init, nd, image
from mxnet.gluon import loss as gloss, model_zoo, nn
from mxnet.gluon import data as gdata, loss as gloss, model_zoo, nn
import numpy as np
```

## 转置卷积层

假设$f$是一个卷积层,给定输入$x$,我们可以计算前向输出$y=f(x)$。在反向求导$z=\frac{\partial\, f(y)}{\partial\,x}$时,我们知道$z$会得到跟$x$一样形状的输出。因为卷积运算的导数的导数是自己本身,我们可以合法定义转置卷积层,记为$g$,为交互了前向和反向求导函数的卷积层。也就是$z=g(y)$。

下面我们构造一个卷积层并打印其输出形状
下面我们构造一个卷积层并打印它的输出形状

```{.python .input n=2}
conv = nn.Conv2D(10, kernel_size=4, padding=1, strides=2)
Expand All @@ -28,7 +28,7 @@ y = conv(x)
y.shape
```

使用用样的卷积窗、填充和步幅的转置卷积层,我们可以得到跟`x`一样的输出。
使用用样的卷积窗、填充和步幅的转置卷积层,我们可以得到和`x`一样的输出。

```{.python .input n=3}
conv_trans = nn.Conv2DTranspose(3, kernel_size=4, padding=1, strides=2)
Expand Down Expand Up @@ -126,131 +126,38 @@ net[-1].initialize(init.Constant(trans_conv_weights))
net[-2].initialize(init=init.Xavier())
```

## 读取数据

我们使用较大的输入图片尺寸,其值选成了32的倍数。数据的读取方法已在上一节描述。

```{.python .input}
###
###
###
import os
from mxnet.gluon import data as gdata, utils as gutils
def load_data_pascal_voc(batch_size, output_shape):
"""Download the pascal voc dataest and then load into memory."""
voc_train = VOCSegDataset(True, output_shape)
voc_test = VOCSegDataset(False, output_shape)
train_iter = gdata.DataLoader(
voc_train, batch_size, shuffle=True, last_batch='discard',
num_workers=4)
test_iter = gdata.DataLoader(
voc_test, batch_size, last_batch='discard', num_workers=4)
return train_iter, test_iter
def _download_voc_pascal(data_dir='../data'):
voc_dir = os.path.join(data_dir, 'VOCdevkit/VOC2012')
url = ('http://host.robots.ox.ac.uk/pascal/VOC/voc2012'
'/VOCtrainval_11-May-2012.tar')
sha1 = '4e443f8a2eca6b1dac8a6c57641b67dd40621a49'
fname = gutils.download(url, data_dir, sha1_hash=sha1)
if not os.path.exists(os.path.join(voc_dir,
'ImageSets/Segmentation/train.txt')):
with tarfile.open(fname, 'r') as f:
f.extractall(data_dir)
return voc_dir
def read_voc_images(root='../data/VOCdevkit/VOC2012', train=True):
"""Read VOC images."""
txt_fname = '%s/ImageSets/Segmentation/%s' % (
root, 'train.txt' if train else 'val.txt')
with open(txt_fname, 'r') as f:
images = f.read().split()
data, label = [None] * len(images), [None] * len(images)
for i, fname in enumerate(images):
data[i] = image.imread('%s/JPEGImages/%s.jpg' % (root, fname))
label[i] = image.imread('%s/SegmentationClass/%s.png' % (root, fname))
return data, label
class VOCSegDataset(gluon.data.Dataset):
"""The Pascal VOC2012 Dataset."""
def __init__(self, train, crop_size):
self.train = train
self.crop_size = crop_size
self.rgb_mean = nd.array([0.485, 0.456, 0.406])
self.rgb_std = nd.array([0.229, 0.224, 0.225])
self.voc_colormap = [[0, 0, 0], [128, 0, 0], [0, 128, 0],
[128, 128, 0],
[0, 0, 128], [128, 0, 128], [0, 128, 128],
[128, 128, 128], [64, 0, 0], [192, 0, 0],
[64, 128, 0], [192, 128, 0], [64, 0, 128],
[192, 0, 128], [64, 128, 128], [192, 128, 128],
[0, 64, 0], [128, 64, 0], [0, 192, 0],
[128, 192, 0], [0, 64, 128]]
self.voc_classes = ['background', 'aeroplane', 'bicycle', 'bird',
'boat', 'bottle', 'bus', 'car', 'cat', 'chair',
'cow', 'diningtable', 'dog', 'horse', 'motorbike',
'person', 'potted plant', 'sheep', 'sofa',
'train', 'tv/monitor']
self.colormap2label = None
self.load_images()
def voc_label_indices(self, img):
if self.colormap2label is None:
self.colormap2label = nd.zeros(256 ** 3)
for i, cm in enumerate(self.voc_colormap):
self.colormap2label[(cm[0] * 256 + cm[1]) * 256 + cm[2]] = i
data = img.astype('int32')
idx = (data[:, :, 0] * 256 + data[:, :, 1]) * 256 + data[:, :, 2]
return self.colormap2label[idx]
def rand_crop(self, data, label, height, width):
data, rect = image.random_crop(data, (width, height))
label = image.fixed_crop(label, *rect)
return data, label
def load_images(self):
voc_dir = _download_voc_pascal()
data, label = read_voc_images(root=voc_dir, train=self.train)
self.data = [self.normalize_image(im) for im in self.filter(data)]
self.label = self.filter(label)
print('read '+ str(len(self.data)) + ' examples')
def normalize_image(self, data):
return (data.astype('float32') / 255 - self.rgb_mean) / self.rgb_std
def filter(self, images):
return [im for im in images if (
im.shape[0] >= self.crop_size[0] and
im.shape[1] >= self.crop_size[1])]
def __getitem__(self, idx):
data, label = self.rand_crop(self.data[idx], self.label[idx],
*self.crop_size)
return data.transpose((2, 0, 1)), self.voc_label_indices(label)
def __len__(self):
return len(self.data)
input_shape = (320, 480)
batch_size = 32
colormap2label = nd.zeros(256**3)
for i, cm in enumerate(gb.voc_colormap):
colormap2label[(cm[0] * 256 + cm[1]) * 256 + cm[2]] = i
voc_dir = gb.download_voc_pascal(data_dir='../data')
99
num_workers = 0 if sys.platform.startswith('win32') else 4
train_iter = gdata.DataLoader(
gb.VOCSegDataset(True, input_shape, voc_dir, colormap2label), batch_size,
shuffle=True, last_batch='discard', num_workers=num_workers)
test_iter = gdata.DataLoader(
gb.VOCSegDataset(False, input_shape, voc_dir, colormap2label), batch_size,
last_batch='discard', num_workers=num_workers)
```

## 训练

这时候我们可以真正开始训练了。我们使用较大的输入图片尺寸,其值选成了32的倍数。因为我们使用转置卷积层的通道来预测像素的类别,所以在做softmax是作用在通道这个维度(维度1),所以在`SoftmaxCrossEntropyLoss`里加入了额外了`axis=1`选项。
这时候我们可以真正开始训练了。因为我们使用转置卷积层的通道来预测像素的类别,所以在做softmax是作用在通道这个维度(维度1),所以在`SoftmaxCrossEntropyLoss`里加入了额外了`axis=1`选项。

```{.python .input n=12}
input_shape = (320, 480)
batch_size = 32
ctx = gb.try_all_gpus()
loss = gloss.SoftmaxCrossEntropyLoss(axis=1)
net.collect_params().reset_ctx(ctx)
trainer = gluon.Trainer(net.collect_params(), 'sgd',
{'learning_rate': 0.1, 'wd': 1e-3})
train_iter, test_iter = load_data_pascal_voc(batch_size, input_shape)
gb.train(train_iter, test_iter, net, loss, trainer, ctx, num_epochs=10)
gb.train(train_iter, test_iter, net, loss, trainer, ctx, num_epochs=5)
```

## 预测
Expand All @@ -270,8 +177,7 @@ def predict(im):

```{.python .input n=14}
def label2image(pred):
colormap = nd.array(
test_iter._dataset.voc_colormap, ctx=ctx[0], dtype='uint8')
colormap = nd.array(gb.voc_colormap, ctx=ctx[0], dtype='uint8')
x = pred.astype('int32')
return colormap[x,:]
```
Expand Down Expand Up @@ -300,7 +206,7 @@ gb.show_images(imgs[::3] + imgs[1::3] + imgs[2::3], 3, n);
* 试着改改最后的转置卷积层的参数设定。
* 看看双线性差值初始化是不是必要的。
* 试着改改训练参数来使得收敛更好些。
* FCN论文[1]中提到了不只是使用主体卷积网络输出,还可以考虑其中间层的输出。试着实现这个想法。
* FCN论文 [1] 中提到了不只是使用主体卷积网络输出,还可以考虑其中间层的输出。试着实现这个想法。

## 扫码直达[讨论区](https://discuss.gluon.ai/t/topic/3041)

Expand All @@ -309,4 +215,4 @@ gb.show_images(imgs[::3] + imgs[1::3] + imgs[2::3], 3, n);

## 参考文献

[1] Long, Jonathan, Evan Shelhamer, and Trevor Darrell. "Fully convolutional networks for semantic segmentation." CVPR. 2015.
[1] Long, J., Shelhamer, E., & Darrell, T. (2015). Fully convolutional networks for semantic segmentation. In Proceedings of the IEEE conference on computer vision and pattern recognition (pp. 3431-3440).

0 comments on commit 0157f6a

Please sign in to comment.