Skip to content

Commit

Permalink
add MRC to BatchGen & fix unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
anselmwang committed Sep 23, 2019
1 parent b8c2c56 commit df1b5f0
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 3 deletions.
2 changes: 1 addition & 1 deletion experiments/glue/glue_prepro.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
import random
from sys import path

path.append(os.getcwd())
from experiments.common_utils import dump_rows

path.append(os.getcwd())
from data_utils.log_wrapper import create_logger
from experiments.glue.glue_utils import *

Expand Down
13 changes: 11 additions & 2 deletions mt_dnn/batcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,11 +130,20 @@ def __iter__(self):
# in training model, label is used by Pytorch, so would be tensor
if self.task_type == TaskType.Regression:
batch_data.append(torch.FloatTensor(labels))
else:
batch_info['label'] = len(batch_data) - 1
elif self.task_type in (TaskType.Classification, TaskType.Ranking):
batch_data.append(torch.LongTensor(labels))
batch_info['label'] = len(batch_data) - 1
batch_info['label'] = len(batch_data) - 1
elif self.task_type == TaskType.Span:
start = [sample['token_start'] for sample in batch]
end = [sample['token_end'] for sample in batch]
batch_data.extend([torch.LongTensor(start), torch.LongTensor(end)])
batch_info['start'] = len(batch_data) - 2
batch_info['end'] = len(batch_data) - 1

# soft label generated by ensemble models for knowledge distillation
if self.soft_label_on and (batch[0].get('softlabel', None) is not None):
assert self.task_type != TaskType.Span # Span task doesn't support soft label yet.
sortlabels = [sample['softlabel'] for sample in batch]
sortlabels = torch.FloatTensor(sortlabels)
batch_info['soft_label'] = self.patch(sortlabels.pin_memory()) if self.gpu else sortlabels
Expand Down

0 comments on commit df1b5f0

Please sign in to comment.