Merge internal changes (facebookresearch#483)
Summary:
Changelog:
- `4889802`: can now detokenize sentencepiece output with `--remove-bpe=sentencepiece` (fixes facebookresearch#331). Also added `--sacrebleu` for computing detokenized BLEU.
- `0d76427`: fix assertion error when training language model with dataset containing empty sentences
- minor bug and style fixes
Pull Request resolved: facebookresearch#483

Differential Revision: D13867899

Pulled By: myleott

fbshipit-source-id: 25c940b847fe270262ac8f5ac838407b3977fdda
myleott authored and facebook-github-bot committed Jan 30, 2019
1 parent 66ce217 commit 42be3eb
Showing 12 changed files with 159 additions and 34 deletions.
25 changes: 25 additions & 0 deletions fairseq/bleu.py
@@ -35,6 +35,31 @@ class BleuStat(ctypes.Structure):
]


class SacrebleuScorer(object):
def __init__(self):
import sacrebleu
self.sacrebleu = sacrebleu
self.reset()

def reset(self, one_init=False):
if one_init:
raise NotImplementedError
self.ref = []
self.sys = []

def add_string(self, ref, pred):
self.ref.append(ref)
self.sys.append(pred)

def score(self, order=4):
return self.result_string(order).bleu

def result_string(self, order=4):
if order != 4:
raise NotImplementedError
return self.sacrebleu.corpus_bleu(self.sys, [self.ref])


class Scorer(object):
def __init__(self, pad, eos, unk):
self.stat = BleuStat()
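For reference, a minimal sketch of how the new scorer might be driven (the sentence strings are made up; only `add_string`, `score`, and `result_string` come from the class above, and the `sacrebleu` package must be installed):

from fairseq import bleu

scorer = bleu.SacrebleuScorer()
# references and hypotheses are plain, detokenized strings
scorer.add_string('the cat sat on the mat .', 'the cat sat on a mat .')
scorer.add_string('hello world .', 'hello world .')
print(scorer.result_string())  # prints the sacrebleu corpus BLEU over all added pairs
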
7 changes: 4 additions & 3 deletions fairseq/data/data_utils.py
@@ -92,9 +92,10 @@ def check_size(idx):
            assert isinstance(idx_size, dict)
            intersect_keys = set(max_positions.keys()) & set(idx_size.keys())
            return all(
-               all(a is None or b is None or a <= b
-                   for a, b in zip(idx_size[key], max_positions[key]))
-               for key in intersect_keys)
+               all(a is None or b is None or a <= b
+                   for a, b in zip(idx_size[key], max_positions[key]))
+               for key in intersect_keys
+           )
        else:
            return all(a is None or b is None or a <= b
                       for a, b in zip(size_fn(idx), max_positions))
8 changes: 6 additions & 2 deletions fairseq/data/dictionary.py
@@ -57,8 +57,12 @@ def token_string(i):
            else:
                return self[i]

-       sent = ' '.join(token_string(i) for i in tensor if i != self.eos())
-       if bpe_symbol is not None:
+       if bpe_symbol == 'sentencepiece':
+           sent = ''.join(token_string(i) for i in tensor if i != self.eos())
+           sent = sent.replace('\u2581', ' ').strip()
+       else:
+           sent = ' '.join(token_string(i) for i in tensor if i != self.eos())
+       if bpe_symbol is not None and bpe_symbol != 'sentencepiece':
            sent = (sent + ' ').replace(bpe_symbol, '').rstrip()
        return sent

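The sentencepiece branch above joins the pieces with no separator and then maps the U+2581 word-boundary marker back to spaces. A standalone sketch of that detokenization step on a made-up piece sequence:

# hypothetical sentencepiece output pieces; '\u2581' marks word boundaries
pieces = ['\u2581the', '\u2581quick', '\u2581fo', 'x']
sent = ''.join(pieces)
sent = sent.replace('\u2581', ' ').strip()
print(sent)  # -> 'the quick fox'
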
7 changes: 3 additions & 4 deletions fairseq/data/token_block_dataset.py
@@ -66,18 +66,17 @@ def block_at(i):
            if curr_size > 0:
                self.slice_indices.append((tok_idx, tok_idx + curr_size))
        elif break_mode == 'eos':
-           self.slice_indices = np.empty((sum(sizes > 1), 2), dtype=int)
+           self.slice_indices = np.empty((len(sizes), 2), dtype=int)
            curr = 0
            for i, sz in enumerate(sizes):
-               # skip samples with just 1 example (which would be just the eos token)
-               if sz > 1:
-                   self.slice_indices[i] = (curr, curr + sz)
+               self.slice_indices[i] = (curr, curr + sz)
                curr += sz
        else:
            raise ValueError('Invalid break_mode: ' + break_mode)

        self.sizes = np.array([e - s for s, e in self.slice_indices])
+       self.slice_indices = np.array(self.slice_indices, dtype=int)

        # build index mapping block indices to the underlying dataset indices
        self.block_to_dataset_index = np.empty((len(self.slice_indices), 3), dtype=int)
        ds_idx, ds_remaining = -1, 0
2 changes: 1 addition & 1 deletion fairseq/models/fairseq_model.py
@@ -29,7 +29,7 @@ def add_args(parser):
    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""
-       raise NotImplementedError
+       raise NotImplementedError('FairseqModels must implement the build_model method')

    def get_targets(self, sample, net_output):
        """Get targets from either the sample or the net's output."""
4 changes: 3 additions & 1 deletion fairseq/options.py
@@ -298,7 +298,7 @@ def add_common_eval_args(group):
    group.add_argument('--path', metavar='FILE',
                       help='path(s) to model file(s), colon separated')
    group.add_argument('--remove-bpe', nargs='?', const='@@ ', default=None,
-                      help='remove BPE tokens before scoring')
+                      help='remove BPE tokens before scoring (can be set to sentencepiece)')
    group.add_argument('--quiet', action='store_true',
                       help='only print final scores')
    group.add_argument('--model-overrides', default="{}", type=str, metavar='DICT',
@@ -350,6 +350,8 @@ def add_generation_args(parser):
                       help='unknown word penalty: <0 produces more unks, >0 produces fewer')
    group.add_argument('--replace-unk', nargs='?', const=True, default=None,
                       help='perform unknown replacement (optionally with alignment dictionary)')
+   group.add_argument('--sacrebleu', action='store_true',
+                      help='score with sacrebleu')
    group.add_argument('--score-reference', action='store_true',
                       help='just score the reference translation')
    group.add_argument('--prefix-size', default=0, type=int, metavar='PS',
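Because `--remove-bpe` is declared with `nargs='?'` and `const='@@ '`, the bare flag keeps the default subword-nmt marker while an explicit value such as `sentencepiece` selects the new detokenization path. A small sketch of that argparse behavior in isolation:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--remove-bpe', nargs='?', const='@@ ', default=None)

print(parser.parse_args([]).remove_bpe)                              # None
print(parser.parse_args(['--remove-bpe']).remove_bpe)                # '@@ '
print(parser.parse_args(['--remove-bpe=sentencepiece']).remove_bpe)  # 'sentencepiece'
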
3 changes: 2 additions & 1 deletion fairseq/tasks/multilingual_translation.py
@@ -246,7 +246,8 @@ def sum_over_languages(key):
            for k, v in agg_logging_output.items()
        }
        flat_logging_output['loss'] = sum_over_languages('loss')
-       flat_logging_output['nll_loss'] = sum_over_languages('nll_loss')
+       if any('nll_loss' in logging_output for logging_output in agg_logging_outputs.values()):
+           flat_logging_output['nll_loss'] = sum_over_languages('nll_loss')
        flat_logging_output['sample_size'] = sum_over_languages('sample_size')
        flat_logging_output['nsentences'] = sum_over_languages('nsentences')
        flat_logging_output['ntokens'] = sum_over_languages('ntokens')
20 changes: 9 additions & 11 deletions fairseq/utils.py
@@ -438,14 +438,12 @@ def nullsafe_min(l):


def import_user_module(args):
-   if hasattr(args, 'user_dir'):
-       module_path = args.user_dir
-
-       if module_path is not None:
-           module_path = os.path.abspath(args.user_dir)
-           module_parent, module_name = os.path.split(module_path)
-
-           if module_name not in sys.modules:
-               sys.path.insert(0, module_parent)
-               importlib.import_module(module_name)
-               sys.path.pop(0)
+   module_path = getattr(args, 'user_dir', None)
+   if module_path is not None:
+       module_path = os.path.abspath(args.user_dir)
+       module_parent, module_name = os.path.split(module_path)
+
+       if module_name not in sys.modules:
+           sys.path.insert(0, module_parent)
+           importlib.import_module(module_name)
+           sys.path.pop(0)
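
The rewrite collapses the nested `hasattr`/`None` checks into a single `getattr(args, 'user_dir', None)`; the import behavior itself is unchanged. A hedged sketch of how the helper is typically called (the plugin directory name is hypothetical):

from argparse import Namespace
from fairseq import utils

# e.g. what `--user-dir my_plugins` would produce; 'my_plugins' is a made-up
# directory whose __init__.py registers custom models/tasks
utils.import_user_module(Namespace(user_dir='my_plugins'))

# argument namespaces without a user_dir attribute are now simply a no-op
utils.import_user_module(Namespace())
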
10 changes: 8 additions & 2 deletions generate.py
@@ -95,7 +95,10 @@ def main(args):
        translator.cuda()

    # Generate and compute BLEU score
-   scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
+   if args.sacrebleu:
+       scorer = bleu.SacrebleuScorer()
+   else:
+       scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
    num_sentences = 0
    has_target = True
    with progress_bar.build_progress_bar(args, itr) as t:
@@ -160,7 +163,10 @@ def main(args):
                    # Convert back to tokens for evaluation with unk replacement and/or without BPE
                    target_tokens = tokenizer.Tokenizer.tokenize(
                        target_str, tgt_dict, add_if_not_exist=True)
-                   scorer.add(target_tokens, hypo_tokens)
+                   if hasattr(scorer, 'add_string'):
+                       scorer.add_string(target_str, hypo_str)
+                   else:
+                       scorer.add(target_tokens, hypo_tokens)

            wps_meter.update(src_tokens.size(0))
            t.log({'wps': round(wps_meter.avg)})
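The scorer handling in generate.py is duck-typed: a scorer that exposes `add_string` (the sacrebleu path) is fed detokenized strings, while the default scorer still receives token id tensors. A condensed sketch of that dispatch with placeholder variable names:

def add_to_scorer(scorer, target_tokens, hypo_tokens, target_str, hypo_str):
    # SacrebleuScorer consumes detokenized strings; the default C-backed
    # Scorer consumes token id tensors
    if hasattr(scorer, 'add_string'):
        scorer.add_string(target_str, hypo_str)
    else:
        scorer.add(target_tokens, hypo_tokens)
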
25 changes: 17 additions & 8 deletions score.py
@@ -26,6 +26,8 @@ def get_parser():
                        type=int, help='consider ngrams up to this order')
    parser.add_argument('--ignore-case', action='store_true',
                        help='case-insensitive scoring')
+   parser.add_argument('--sacrebleu', action='store_true',
+                       help='score with sacrebleu')
    # fmt: on
    return parser

@@ -49,14 +51,21 @@ def readlines(fd):
else:
yield line

-   def score(fdsys):
-       with open(args.ref) as fdref:
-           scorer = bleu.Scorer(dict.pad(), dict.eos(), dict.unk())
-           for sys_tok, ref_tok in zip(readlines(fdsys), readlines(fdref)):
-               sys_tok = tokenizer.Tokenizer.tokenize(sys_tok, dict)
-               ref_tok = tokenizer.Tokenizer.tokenize(ref_tok, dict)
-               scorer.add(ref_tok, sys_tok)
-           print(scorer.result_string(args.order))
+   if args.sacrebleu:
+       import sacrebleu
+
+       def score(fdsys):
+           with open(args.ref) as fdref:
+               print(sacrebleu.corpus_bleu(fdsys, [fdref]))
+   else:
+       def score(fdsys):
+           with open(args.ref) as fdref:
+               scorer = bleu.Scorer(dict.pad(), dict.eos(), dict.unk())
+               for sys_tok, ref_tok in zip(readlines(fdsys), readlines(fdref)):
+                   sys_tok = tokenizer.Tokenizer.tokenize(sys_tok, dict)
+                   ref_tok = tokenizer.Tokenizer.tokenize(ref_tok, dict)
+                   scorer.add(ref_tok, sys_tok)
+               print(scorer.result_string(args.order))

if args.sys == '-':
score(sys.stdin)
2 changes: 1 addition & 1 deletion scripts/sacrebleu_pregen.sh
@@ -15,7 +15,7 @@ echo 'Cloning Moses github repository (for tokenization scripts)...'
git clone https://github.com/moses-smt/mosesdecoder.git

SCRIPTS=mosesdecoder/scripts
-DETOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl
+DETOKENIZER=$SCRIPTS/tokenizer/detokenizer.perl

grep ^H $GEN \
| sed 's/^H\-//' \
80 changes: 80 additions & 0 deletions tests/test_token_block_dataset.py
@@ -0,0 +1,80 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import unittest

import torch

from fairseq.data import TokenBlockDataset

import tests.utils as test_utils


class TestTokenBlockDataset(unittest.TestCase):

def _build_dataset(self, data, **kwargs):
sizes = [len(x) for x in data]
underlying_ds = test_utils.TestDataset(data)
return TokenBlockDataset(underlying_ds, sizes, **kwargs)

def test_eos_break_mode(self):
data = [
torch.LongTensor([5, 4, 3, 2, 1]),
torch.LongTensor([1]), # this should be filtered
torch.LongTensor([8, 7, 6, 1]),
]
ds = self._build_dataset(data, block_size=None, pad=0, eos=1, break_mode='eos')
self.assertEqual(ds[0].tolist(), [5, 4, 3, 2, 1])
self.assertEqual(ds[1].tolist(), [1])
self.assertEqual(ds[2].tolist(), [8, 7, 6, 1])

data = [
torch.LongTensor([5, 4, 3, 2, 1]),
torch.LongTensor([8, 7, 6, 1]),
torch.LongTensor([1]), # this should be filtered
]
ds = self._build_dataset(data, block_size=None, pad=0, eos=1, break_mode='eos')
self.assertEqual(ds[0].tolist(), [5, 4, 3, 2, 1])
self.assertEqual(ds[1].tolist(), [8, 7, 6, 1])
self.assertEqual(ds[2].tolist(), [1])

def test_block_break_mode(self):
data = [
torch.LongTensor([5, 4, 3, 2, 1]),
torch.LongTensor([8, 7, 6, 1]),
torch.LongTensor([9, 1]),
]
ds = self._build_dataset(data, block_size=3, pad=0, eos=1, break_mode='none')
self.assertEqual(ds[0].tolist(), [5, 4, 3])
self.assertEqual(ds[1].tolist(), [2, 1, 8])
self.assertEqual(ds[2].tolist(), [7, 6, 1])
self.assertEqual(ds[3].tolist(), [9, 1])

def test_complete_break_mode(self):
data = [
torch.LongTensor([5, 4, 3, 2, 1]),
torch.LongTensor([8, 7, 6, 1]),
torch.LongTensor([9, 1]),
]
ds = self._build_dataset(data, block_size=6, pad=0, eos=1, break_mode='complete')
self.assertEqual(ds[0].tolist(), [5, 4, 3, 2, 1])
self.assertEqual(ds[1].tolist(), [8, 7, 6, 1, 9, 1])

data = [
torch.LongTensor([4, 3, 2, 1]),
torch.LongTensor([5, 1]),
torch.LongTensor([1]),
torch.LongTensor([6, 1]),
]
ds = self._build_dataset(data, block_size=3, pad=0, eos=1, break_mode='complete')
self.assertEqual(ds[0].tolist(), [4, 3, 2, 1])
self.assertEqual(ds[1].tolist(), [5, 1, 1])
self.assertEqual(ds[2].tolist(), [6, 1])


if __name__ == "__main__":
unittest.main()
