Skip to content

Commit

Permalink
graykode#22 edit comment shape mistake
Browse files Browse the repository at this point in the history
  • Loading branch information
graykode committed Sep 26, 2019
1 parent 6e171b9 commit cb4881e
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions 5-2.BERT/BERT-Torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,10 +197,11 @@ def forward(self, input_ids, segment_ids, masked_pos):
h_pooled = self.activ1(self.fc(output[:, 0])) # [batch_size, d_model]
logits_clsf = self.classifier(h_pooled) # [batch_size, 2]

masked_pos = masked_pos[:, :, None].expand(-1, -1, output.size(-1)) # [batch_size, maxlen, d_model]
h_masked = torch.gather(output, 1, masked_pos) # masking position [batch_size, len, d_model]
masked_pos = masked_pos[:, :, None].expand(-1, -1, output.size(-1)) # [batch_size, max_pred, d_model]
# get masked position from final output of transformer.
h_masked = torch.gather(output, 1, masked_pos) # masking position [batch_size, max_pred, d_model]
h_masked = self.norm(self.activ2(self.linear(h_masked)))
logits_lm = self.decoder(h_masked) + self.decoder_bias # [batch_size, maxlen, n_vocab]
logits_lm = self.decoder(h_masked) + self.decoder_bias # [batch_size, max_pred, n_vocab]

return logits_lm, logits_clsf

Expand Down Expand Up @@ -239,4 +240,4 @@ def forward(self, input_ids, segment_ids, masked_pos):

logits_clsf = logits_clsf.data.max(1)[1].data.numpy()[0]
print('isNext : ', True if isNext else False)
print('predict isNext : ',True if logits_clsf else False)
print('predict isNext : ',True if logits_clsf else False)

0 comments on commit cb4881e

Please sign in to comment.