Skip to content

Commit

Permalink
pr4 : add W parameter in ipynb file
Browse files Browse the repository at this point in the history
  • Loading branch information
graykode committed Feb 3, 2019
1 parent 8f61d60 commit 49790c4
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions 1-1.NNLM/NNLM_Torch.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@
"metadata": {
"id": "mvlw9p3tPJjr",
"colab_type": "code",
"outputId": "a9d7624b-4a3b-4078-9a89-11c2d6d177d5",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 161
},
"outputId": "307285fc-ce69-4fb1-caea-1841ffd0bfd8"
}
},
"cell_type": "code",
"source": [
Expand Down Expand Up @@ -68,14 +68,15 @@
" super(NNLM, self).__init__()\n",
"\n",
" self.H = nn.Parameter(torch.randn(n_step * n_class, n_hidden).type(dtype))\n",
" self.W = nn.Parameter(torch.randn(n_step * n_class, n_class).type(dtype))\n",
" self.d = nn.Parameter(torch.randn(n_hidden).type(dtype))\n",
" self.U = nn.Parameter(torch.randn(n_hidden, n_class).type(dtype))\n",
" self.b = nn.Parameter(torch.randn(n_class).type(dtype))\n",
"\n",
" def forward(self, X):\n",
" input = X.view(-1, n_step * n_class) # [batch_size, n_step * n_class]\n",
" tanh = nn.functional.tanh(self.d + torch.mm(input, self.H)) # [batch_size, n_hidden]\n",
" output = torch.mm(tanh, self.U) + self.b # [batch_size, n_class]\n",
" output = self.b + torch.mm(input, self.W) + torch.mm(tanh, self.U) # [batch_size, n_class]\n",
" return output\n",
"\n",
"model = NNLM()\n",
Expand Down Expand Up @@ -107,7 +108,7 @@
"# Test\n",
"print([sen.split()[:2] for sen in sentences], '->', [number_dict[n.item()] for n in predict.squeeze()])"
],
"execution_count": 1,
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
Expand All @@ -120,11 +121,11 @@
{
"output_type": "stream",
"text": [
"Epoch: 1000 cost = 1.209171\n",
"Epoch: 2000 cost = 1.115136\n",
"Epoch: 3000 cost = 0.219242\n",
"Epoch: 4000 cost = 0.059575\n",
"Epoch: 5000 cost = 0.027686\n",
"Epoch: 1000 cost = 0.283353\n",
"Epoch: 2000 cost = 0.058013\n",
"Epoch: 3000 cost = 0.023128\n",
"Epoch: 4000 cost = 0.011383\n",
"Epoch: 5000 cost = 0.006090\n",
"[['i', 'like'], ['i', 'love'], ['i', 'hate']] -> ['dog', 'coffee', 'milk']\n"
],
"name": "stdout"
Expand Down

0 comments on commit 49790c4

Please sign in to comment.