Skip to content

Commit

Permalink
Completely replace concept "pairwise" with TaskType
Browse files: browse the repository at this point in the history
  • Loading branch information
anselmwang committed Sep 23, 2019
1 parent 0b7f73c commit 717fa2a
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 27 deletions.
19 changes: 9 additions & 10 deletions mt_dnn/batcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ def __init__(self, data, batch_size=32, gpu=True, is_train=True,
maxlen=128, dropout_w=0.005,
do_batch=True, weighted_on=False,
task_id=0,
pairwise=False,
task=None,
task_type=TaskType.Classification,
data_type=DataFormat.PremiseOnly,
Expand All @@ -29,7 +28,6 @@ def __init__(self, data, batch_size=32, gpu=True, is_train=True,
self.weighted_on = weighted_on
self.data = data
self.task_id = task_id
self.pairwise = pairwise
self.pairwise_size = 1
self.data_type = data_type
self.task_type=task_type
Expand All @@ -50,7 +48,8 @@ def make_baches(data, batch_size=32):
return [data[i:i + batch_size] for i in range(0, len(data), batch_size)]

@staticmethod
def load(path, is_train=True, maxlen=128, factor=1.0, pairwise=False):
def load(path, is_train=True, maxlen=128, factor=1.0, task_type=None):
assert task_type is not None
with open(path, 'r', encoding='utf-8') as reader:
data = []
cnt = 0
Expand All @@ -59,9 +58,9 @@ def load(path, is_train=True, maxlen=128, factor=1.0, pairwise=False):
sample['factor'] = factor
cnt += 1
if is_train:
if pairwise and (len(sample['token_id'][0]) > maxlen or len(sample['token_id'][1]) > maxlen):
if (task_type == TaskType.Ranking) and (len(sample['token_id'][0]) > maxlen or len(sample['token_id'][1]) > maxlen):
continue
if (not pairwise) and (len(sample['token_id']) > maxlen):
if (task_type != TaskType.Ranking) and (len(sample['token_id']) > maxlen):
continue
data.append(sample)
print('Loaded {} samples out of {}'.format(len(data), cnt))
Expand Down Expand Up @@ -111,10 +110,10 @@ def __if_pair__(self, data_type):
def __iter__(self):
while self.offset < len(self):
batch = self.data[self.offset]
if self.pairwise:
if self.task_type == TaskType.Ranking:
batch = self.rebacth(batch)

batch_size = len(batch)
batch_dict = {}
tok_len = max(len(x['token_id']) for x in batch)
hypothesis_len = max(len(x['type_id']) - sum(x['type_id']) for x in batch)
if self.encoder_type == EncoderModelType.ROBERTA:
Expand All @@ -128,7 +127,6 @@ def __iter__(self):
if self.__if_pair__(self.data_type):
premise_masks = torch.ByteTensor(batch_size, tok_len).fill_(1)
hypothesis_masks = torch.ByteTensor(batch_size, hypothesis_len).fill_(1)

for i, sample in enumerate(batch):
select_len = min(len(sample['token_id']), tok_len)
tok = sample['token_id']
Expand All @@ -142,6 +140,7 @@ def __iter__(self):
hypothesis_masks[i, :hlen] = torch.LongTensor([0] * hlen)
for j in range(hlen, select_len):
premise_masks[i, j] = 0

if self.__if_pair__(self.data_type):
batch_info = {
'token_id': 0,
Expand Down Expand Up @@ -185,13 +184,13 @@ def __iter__(self):
batch_info['uids'] = [sample['uid'] for sample in batch]
batch_info['task_id'] = self.task_id
batch_info['input_len'] = valid_input_len
batch_info['pairwise'] = self.pairwise
batch_info['task_type'] = self.task_type
batch_info['pairwise_size'] = self.pairwise_size
batch_info['task_type'] = self.task_type
if not self.is_train:
labels = [sample['label'] for sample in batch]
batch_info['label'] = labels
if self.pairwise:
if self.task_type == TaskType.Ranking:
batch_info['true_label'] = [sample['true_label'] for sample in batch]
self.offset += 1
yield batch_info, batch_data
6 changes: 3 additions & 3 deletions mt_dnn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def update(self, batch_meta, batch_data):
if self.config.get('mkd_opt', 0) > 0 and ('soft_label' in batch_meta):
soft_labels = batch_meta['soft_label']

if batch_meta['pairwise']:
if batch_meta['task_type'] == TaskType.Ranking:
labels = labels.contiguous().view(-1, batch_meta['pairwise_size'])[:, 0]
if self.config['cuda']:
y = labels.cuda(non_blocking=True)
Expand All @@ -162,7 +162,7 @@ def update(self, batch_meta, batch_data):
inputs.append(None)
inputs.append(task_id)
logits = self.mnetwork(*inputs)
if batch_meta['pairwise']:
if batch_meta['task_type'] == TaskType.Ranking:
logits = logits.view(-1, batch_meta['pairwise_size'])

if self.config.get('weighted_on', False):
Expand Down Expand Up @@ -225,7 +225,7 @@ def predict(self, batch_meta, batch_data):
inputs.append(None)
inputs.append(task_id)
score = self.mnetwork(*inputs)
if batch_meta['pairwise']:
if batch_meta['task_type'] == TaskType.Ranking:
score = score.contiguous().view(-1, batch_meta['pairwise_size'])
assert task_type == TaskType.Ranking
score = F.softmax(score, dim=1)
Expand Down
6 changes: 1 addition & 5 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,17 +39,13 @@ def dump(path, data):
data_type = task_defs.data_type_map[args.task]
task_type = task_defs.task_type_map[args.task]
metric_meta = task_defs.metric_meta_map[args.task]
pw_task = False
if task_type == TaskType.Ranking:
pw_task = True

# load data
test_data = BatchGen(BatchGen.load(args.prep_input, False, pairwise=pw_task, maxlen=args.max_seq_len),
test_data = BatchGen(BatchGen.load(args.prep_input, False, task_type=task_type, maxlen=args.max_seq_len),
batch_size=args.batch_size_eval,
gpu=args.cuda, is_train=False,
task_id=args.task_id,
maxlen=args.max_seq_len,
pairwise=pw_task,
data_type=data_type,
task_type=task_type)

Expand Down
12 changes: 3 additions & 9 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,9 +180,6 @@ def main():
task_id = tasks_class[nclass] if nclass in tasks_class else len(tasks_class)

task_type = task_defs.task_type_map[prefix]
pw_task = False
if task_type == TaskType.Ranking:
pw_task = True

dopt = generate_decoder_opt(task_defs.enable_san_map[prefix], opt['answer_opt'])
if task_id < len(decoder_opts):
Expand All @@ -203,13 +200,12 @@ def main():

train_path = os.path.join(data_dir, '{}_train.json'.format(dataset))
logger.info('Loading {} as task {}'.format(train_path, task_id))
train_data = BatchGen(BatchGen.load(train_path, True, pairwise=pw_task, maxlen=args.max_seq_len),
train_data = BatchGen(BatchGen.load(train_path, True, task_type=task_type, maxlen=args.max_seq_len),
batch_size=batch_size,
dropout_w=args.dropout_w,
gpu=args.cuda,
task_id=task_id,
maxlen=args.max_seq_len,
pairwise=pw_task,
data_type=data_type,
task_type=task_type,
encoder_type=encoder_type)
Expand Down Expand Up @@ -237,12 +233,11 @@ def main():
dev_path = os.path.join(data_dir, '{}_dev.json'.format(dataset))
dev_data = None
if os.path.exists(dev_path):
dev_data = BatchGen(BatchGen.load(dev_path, False, pairwise=pw_task, maxlen=args.max_seq_len),
dev_data = BatchGen(BatchGen.load(dev_path, False, task_type=task_type, maxlen=args.max_seq_len),
batch_size=args.batch_size_eval,
gpu=args.cuda, is_train=False,
task_id=task_id,
maxlen=args.max_seq_len,
pairwise=pw_task,
data_type=data_type,
task_type=task_type,
encoder_type=encoder_type)
Expand All @@ -251,12 +246,11 @@ def main():
test_path = os.path.join(data_dir, '{}_test.json'.format(dataset))
test_data = None
if os.path.exists(test_path):
test_data = BatchGen(BatchGen.load(test_path, False, pairwise=pw_task, maxlen=args.max_seq_len),
test_data = BatchGen(BatchGen.load(test_path, False, task_type=task_type, maxlen=args.max_seq_len),
batch_size=args.batch_size_eval,
gpu=args.cuda, is_train=False,
task_id=task_id,
maxlen=args.max_seq_len,
pairwise=pw_task,
data_type=data_type,
task_type=task_type,
encoder_type=encoder_type)
Expand Down

0 comments on commit 717fa2a

Please sign in to comment.