forked from DragonFive/python_cv_AI_ML
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
cee69be
commit a9b8c20
Showing
1 changed file
with
137 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# VGGNet\n", | ||
"\n", | ||
"重复网络结构的构建\n", | ||
"\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"from mxnet.gluon import nn\n", | ||
"def vgg_block3(num_convs,channels):\n", | ||
" out = nn.Sequential()\n", | ||
" for _ in range(num_convs):\n", | ||
" out.add(nn.Conv2D(channels,kernel_size=3,padding=1,activation='relu'))\n", | ||
" out.add(nn.MaxPool2D(pool_size=2,strides=2))\n", | ||
" return out\n", | ||
"def vgg_block3_1(num_convs,channels):\n", | ||
" out = nn.Sequential()\n", | ||
" for _ in range(num_convs):\n", | ||
" out.add(nn.Conv2D(channels,kernel_size=3,padding=1,activation='relu'))\n", | ||
" out.add(nn.Conv2D(channels,kernel_size=1,padding=1,activation='relu'),\n", | ||
" nn.MaxPool2D(pool_size=2,strides=2))\n", | ||
" return out\n", | ||
"def vgg_fc(num_outputs):\n", | ||
" out = nn.Sequential()\n", | ||
" out.add(nn.Flatten(),\n", | ||
" nn.Dense(4096,activation='relu'),\n", | ||
" nn.Dropout(.5),\n", | ||
" nn.Dense(4096,activation='relu'),\n", | ||
" nn.Dropout(.5),\n", | ||
" nn.Dense(num_outputs)\n", | ||
" )\n", | ||
" return out" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def get_vgg(architecture,num_outputs):\n", | ||
" net = nn.Sequential()\n", | ||
" with net.name_scope():\n", | ||
" for num_convs,channels in architecture:\n", | ||
" net.add(vgg_block3(num_convs,channels))\n", | ||
" net.add(vgg_fc(num_outputs))\n", | ||
" return net\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"获取数据,初始化网络,设置loss和solver" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"('Start training on ', gpu(0))\n", | ||
"Epoch 0. Loss: 1.011, Train acc 0.62, Test acc 0.81, Time 93.4 sec\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"import sys\n", | ||
"sys.path.append('..')\n", | ||
"import utils\n", | ||
"from mxnet import init\n", | ||
"from mxnet import gluon\n", | ||
"\n", | ||
"batch_size = 128\n", | ||
"resize=128\n", | ||
"train_data, test_data = utils.load_data_fashion_mnist(batch_size,resize=resize)\n", | ||
"ctx = utils.try_gpu()\n", | ||
"\n", | ||
"num_outputs = 10\n", | ||
"architecture = ((1,64),(1,128),(2,256),(2,512),(2,512))\n", | ||
"net=get_vgg(architecture,num_outputs)\n", | ||
"net.initialize(ctx=ctx,init=init.Xavier())\n", | ||
"\n", | ||
"loss = gluon.loss.SoftmaxCrossEntropyLoss()\n", | ||
"#train(train_data, test_data, net, loss, trainer, ctx, num_epochs, print_batches=None):\n", | ||
"trainer = gluon.Trainer(net.collect_params(),'sgd',{'learning_rate':0.05})\n", | ||
"utils.train(train_data,test_data,net,loss,trainer,ctx,num_epochs=5)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python [conda env:gluon]", | ||
"language": "python", | ||
"name": "conda-env-gluon-py" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 2 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython2", | ||
"version": "2.7.13" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |