Skip to content

Commit

Permalink
vgg in gluon
Browse files Browse the repository at this point in the history
  • Loading branch information
DragonFive authored Dec 22, 2017
1 parent cee69be commit a9b8c20
Showing 1 changed file with 137 additions and 0 deletions.
137 changes: 137 additions & 0 deletions mxnet/vggnet.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# VGGNet\n",
"\n",
"重复网络结构的构建\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from mxnet.gluon import nn\n",
"def vgg_block3(num_convs,channels):\n",
" out = nn.Sequential()\n",
" for _ in range(num_convs):\n",
" out.add(nn.Conv2D(channels,kernel_size=3,padding=1,activation='relu'))\n",
" out.add(nn.MaxPool2D(pool_size=2,strides=2))\n",
" return out\n",
"def vgg_block3_1(num_convs,channels):\n",
" out = nn.Sequential()\n",
" for _ in range(num_convs):\n",
" out.add(nn.Conv2D(channels,kernel_size=3,padding=1,activation='relu'))\n",
" out.add(nn.Conv2D(channels,kernel_size=1,padding=1,activation='relu'),\n",
" nn.MaxPool2D(pool_size=2,strides=2))\n",
" return out\n",
"def vgg_fc(num_outputs):\n",
" out = nn.Sequential()\n",
" out.add(nn.Flatten(),\n",
" nn.Dense(4096,activation='relu'),\n",
" nn.Dropout(.5),\n",
" nn.Dense(4096,activation='relu'),\n",
" nn.Dropout(.5),\n",
" nn.Dense(num_outputs)\n",
" )\n",
" return out"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def get_vgg(architecture,num_outputs):\n",
" net = nn.Sequential()\n",
" with net.name_scope():\n",
" for num_convs,channels in architecture:\n",
" net.add(vgg_block3(num_convs,channels))\n",
" net.add(vgg_fc(num_outputs))\n",
" return net\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"获取数据,初始化网络,设置loss和solver"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"('Start training on ', gpu(0))\n",
"Epoch 0. Loss: 1.011, Train acc 0.62, Test acc 0.81, Time 93.4 sec\n"
]
}
],
"source": [
"import sys\n",
"sys.path.append('..')\n",
"import utils\n",
"from mxnet import init\n",
"from mxnet import gluon\n",
"\n",
"batch_size = 128\n",
"resize=128\n",
"train_data, test_data = utils.load_data_fashion_mnist(batch_size,resize=resize)\n",
"ctx = utils.try_gpu()\n",
"\n",
"num_outputs = 10\n",
"architecture = ((1,64),(1,128),(2,256),(2,512),(2,512))\n",
"net=get_vgg(architecture,num_outputs)\n",
"net.initialize(ctx=ctx,init=init.Xavier())\n",
"\n",
"loss = gluon.loss.SoftmaxCrossEntropyLoss()\n",
"#train(train_data, test_data, net, loss, trainer, ctx, num_epochs, print_batches=None):\n",
"trainer = gluon.Trainer(net.collect_params(),'sgd',{'learning_rate':0.05})\n",
"utils.train(train_data,test_data,net,loss,trainer,ctx,num_epochs=5)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [conda env:gluon]",
"language": "python",
"name": "conda-env-gluon-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

0 comments on commit a9b8c20

Please sign in to comment.