diff --git a/fairseq/bleu.py b/fairseq/bleu.py
index a211d7741c..6045ec6e26 100644
--- a/fairseq/bleu.py
+++ b/fairseq/bleu.py
@@ -35,6 +35,31 @@ class BleuStat(ctypes.Structure):
     ]
 
 
+class SacrebleuScorer(object):
+    def __init__(self):
+        import sacrebleu
+        self.sacrebleu = sacrebleu
+        self.reset()
+
+    def reset(self, one_init=False):
+        if one_init:
+            raise NotImplementedError
+        self.ref = []
+        self.sys = []
+
+    def add_string(self, ref, pred):
+        self.ref.append(ref)
+        self.sys.append(pred)
+
+    def score(self, order=4):
+        return self.result_string(order).score
+
+    def result_string(self, order=4):
+        if order != 4:
+            raise NotImplementedError
+        return self.sacrebleu.corpus_bleu(self.sys, [self.ref])
+
+
 class Scorer(object):
     def __init__(self, pad, eos, unk):
         self.stat = BleuStat()
diff --git a/fairseq/data/data_utils.py b/fairseq/data/data_utils.py
index 8e060184b7..8b99f8add0 100644
--- a/fairseq/data/data_utils.py
+++ b/fairseq/data/data_utils.py
@@ -92,9 +92,10 @@ def check_size(idx):
             assert isinstance(idx_size, dict)
             intersect_keys = set(max_positions.keys()) & set(idx_size.keys())
             return all(
-                all(a is None or b is None or a <= b
-                    for a, b in zip(idx_size[key], max_positions[key]))
-                for key in intersect_keys)
+                all(a is None or b is None or a <= b
+                    for a, b in zip(idx_size[key], max_positions[key]))
+                for key in intersect_keys
+            )
         else:
             return all(a is None or b is None or a <= b
                        for a, b in zip(size_fn(idx), max_positions))
diff --git a/fairseq/data/dictionary.py b/fairseq/data/dictionary.py
index 4d60daaa53..265be5a84f 100644
--- a/fairseq/data/dictionary.py
+++ b/fairseq/data/dictionary.py
@@ -57,8 +57,12 @@ def token_string(i):
             else:
                 return self[i]
 
-        sent = ' '.join(token_string(i) for i in tensor if i != self.eos())
-        if bpe_symbol is not None:
+        if bpe_symbol == 'sentencepiece':
+            sent = ''.join(token_string(i) for i in tensor if i != self.eos())
+            sent = sent.replace('\u2581', ' ').strip()
+        else:
+            sent = ' '.join(token_string(i) for i in tensor if i != self.eos())
+        if bpe_symbol is not None and bpe_symbol != 'sentencepiece':
             sent = (sent + ' ').replace(bpe_symbol, '').rstrip()
         return sent
 
diff --git a/fairseq/data/token_block_dataset.py b/fairseq/data/token_block_dataset.py
index a793e0cbb6..f52ebd4afe 100644
--- a/fairseq/data/token_block_dataset.py
+++ b/fairseq/data/token_block_dataset.py
@@ -66,18 +66,17 @@ def block_at(i):
             if curr_size > 0:
                 self.slice_indices.append((tok_idx, tok_idx + curr_size))
         elif break_mode == 'eos':
-            self.slice_indices = np.empty((sum(sizes > 1), 2), dtype=int)
+            self.slice_indices = np.empty((len(sizes), 2), dtype=int)
             curr = 0
             for i, sz in enumerate(sizes):
-                # skip samples with just 1 example (which would be just the eos token)
-                if sz > 1:
-                    self.slice_indices[i] = (curr, curr + sz)
+                self.slice_indices[i] = (curr, curr + sz)
                 curr += sz
         else:
             raise ValueError('Invalid break_mode: ' + break_mode)
 
         self.sizes = np.array([e - s for s, e in self.slice_indices])
         self.slice_indices = np.array(self.slice_indices, dtype=int)
 
+        # build index mapping block indices to the underlying dataset indices
         self.block_to_dataset_index = np.empty((len(self.slice_indices), 3), dtype=int)
         ds_idx, ds_remaining = -1, 0
diff --git a/fairseq/models/fairseq_model.py b/fairseq/models/fairseq_model.py
index 66e68d157e..e17c887aba 100644
--- a/fairseq/models/fairseq_model.py
+++ b/fairseq/models/fairseq_model.py
@@ -29,7 +29,7 @@ def add_args(parser):
     @classmethod
     def build_model(cls, args, task):
         """Build a new model instance."""
-        raise NotImplementedError
+        raise NotImplementedError('FairseqModels must implement the build_model method')
 
     def get_targets(self, sample, net_output):
         """Get targets from either the sample or the net's output."""
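Taken together, the bleu.py and dictionary.py hunks above work as in this minimal sketch; it is not part of the patch, it assumes sacrebleu is installed, and the example strings are invented:

    # SacrebleuScorer accumulates detokenized strings, then defers to sacrebleu
    from fairseq import bleu

    scorer = bleu.SacrebleuScorer()
    scorer.add_string('the cat sat on the mat .', 'the cat sat on a mat .')  # (ref, hyp)
    scorer.add_string('hello world .', 'hello world .')
    print(scorer.result_string())  # sacrebleu's corpus_bleu result
    print(scorer.score())          # corpus BLEU via the result's .score attribute

    # dictionary.py's new sentencepiece branch detokenizes by joining pieces and
    # mapping U+2581 (the sentencepiece word boundary) back to spaces, e.g.:
    ''.join(['\u2581the', '\u2581cat']).replace('\u2581', ' ').strip()  # -> 'the cat'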
diff --git a/fairseq/options.py b/fairseq/options.py
index 9cd2becc14..c08101e609 100644
--- a/fairseq/options.py
+++ b/fairseq/options.py
@@ -298,7 +298,7 @@ def add_common_eval_args(group):
     group.add_argument('--path', metavar='FILE',
                        help='path(s) to model file(s), colon separated')
     group.add_argument('--remove-bpe', nargs='?', const='@@ ', default=None,
-                       help='remove BPE tokens before scoring')
+                       help='remove BPE tokens before scoring (can be set to sentencepiece)')
     group.add_argument('--quiet', action='store_true',
                        help='only print final scores')
     group.add_argument('--model-overrides', default="{}", type=str, metavar='DICT',
@@ -350,6 +350,8 @@ def add_generation_args(parser):
                        help='unknown word penalty: <0 produces more unks, >0 produces fewer')
     group.add_argument('--replace-unk', nargs='?', const=True, default=None,
                        help='perform unknown replacement (optionally with alignment dictionary)')
+    group.add_argument('--sacrebleu', action='store_true',
+                       help='score with sacrebleu')
     group.add_argument('--score-reference', action='store_true',
                        help='just score the reference translation')
     group.add_argument('--prefix-size', default=0, type=int, metavar='PS',
diff --git a/fairseq/tasks/multilingual_translation.py b/fairseq/tasks/multilingual_translation.py
index 69fcb12446..15c27b4ef8 100644
--- a/fairseq/tasks/multilingual_translation.py
+++ b/fairseq/tasks/multilingual_translation.py
@@ -246,7 +246,8 @@ def sum_over_languages(key):
             for k, v in agg_logging_output.items()
         }
         flat_logging_output['loss'] = sum_over_languages('loss')
-        flat_logging_output['nll_loss'] = sum_over_languages('nll_loss')
+        if any('nll_loss' in logging_output for logging_output in agg_logging_outputs.values()):
+            flat_logging_output['nll_loss'] = sum_over_languages('nll_loss')
         flat_logging_output['sample_size'] = sum_over_languages('sample_size')
         flat_logging_output['nsentences'] = sum_over_languages('nsentences')
         flat_logging_output['ntokens'] = sum_over_languages('ntokens')
diff --git a/fairseq/utils.py b/fairseq/utils.py
index e3ba7fb08a..8df2670359 100644
--- a/fairseq/utils.py
+++ b/fairseq/utils.py
@@ -438,14 +438,12 @@ def nullsafe_min(l):
 
 
 def import_user_module(args):
-    if hasattr(args, 'user_dir'):
-        module_path = args.user_dir
-
-        if module_path is not None:
-            module_path = os.path.abspath(args.user_dir)
-            module_parent, module_name = os.path.split(module_path)
-
-            if module_name not in sys.modules:
-                sys.path.insert(0, module_parent)
-                importlib.import_module(module_name)
-                sys.path.pop(0)
+    module_path = getattr(args, 'user_dir', None)
+    if module_path is not None:
+        module_path = os.path.abspath(args.user_dir)
+        module_parent, module_name = os.path.split(module_path)
+
+        if module_name not in sys.modules:
+            sys.path.insert(0, module_parent)
+            importlib.import_module(module_name)
+            sys.path.pop(0)
diff --git a/generate.py b/generate.py
index e4955a2abd..cc258a48f0 100644
--- a/generate.py
+++ b/generate.py
@@ -95,7 +95,10 @@ def main(args):
         translator.cuda()
 
     # Generate and compute BLEU score
-    scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
+    if args.sacrebleu:
+        scorer = bleu.SacrebleuScorer()
+    else:
+        scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
     num_sentences = 0
     has_target = True
     with progress_bar.build_progress_bar(args, itr) as t:
@@ -160,7 +163,10 @@ def main(args):
                         # Convert back to tokens for evaluation with unk replacement and/or without BPE
                         target_tokens = tokenizer.Tokenizer.tokenize(
                             target_str, tgt_dict, add_if_not_exist=True)
-                    scorer.add(target_tokens, hypo_tokens)
+                    if hasattr(scorer, 'add_string'):
+                        scorer.add_string(target_str, hypo_str)
+                    else:
+                        scorer.add(target_tokens, hypo_tokens)
 
             wps_meter.update(src_tokens.size(0))
             t.log({'wps': round(wps_meter.avg)})
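A hypothetical end-to-end invocation of the pieces above (the dataset and checkpoint paths are made up); --remove-bpe sentencepiece produces the detokenized strings that SacrebleuScorer.add_string expects:

    # shell: both flags come from this patch
    python generate.py data-bin/wmt17_en_de \
        --path checkpoints/checkpoint_best.pt \
        --remove-bpe sentencepiece --sacrebleu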
diff --git a/score.py b/score.py
index 3ffc222f9a..184b431ffd 100644
--- a/score.py
+++ b/score.py
@@ -26,6 +26,8 @@ def get_parser():
                         type=int, help='consider ngrams up to this order')
     parser.add_argument('--ignore-case', action='store_true',
                         help='case-insensitive scoring')
+    parser.add_argument('--sacrebleu', action='store_true',
+                        help='score with sacrebleu')
     # fmt: on
     return parser
 
@@ -49,14 +51,21 @@ def readlines(fd):
             else:
                 yield line
 
-    def score(fdsys):
-        with open(args.ref) as fdref:
-            scorer = bleu.Scorer(dict.pad(), dict.eos(), dict.unk())
-            for sys_tok, ref_tok in zip(readlines(fdsys), readlines(fdref)):
-                sys_tok = tokenizer.Tokenizer.tokenize(sys_tok, dict)
-                ref_tok = tokenizer.Tokenizer.tokenize(ref_tok, dict)
-                scorer.add(ref_tok, sys_tok)
-            print(scorer.result_string(args.order))
+    if args.sacrebleu:
+        import sacrebleu
+
+        def score(fdsys):
+            with open(args.ref) as fdref:
+                print(sacrebleu.corpus_bleu(fdsys, [fdref]))
+    else:
+        def score(fdsys):
+            with open(args.ref) as fdref:
+                scorer = bleu.Scorer(dict.pad(), dict.eos(), dict.unk())
+                for sys_tok, ref_tok in zip(readlines(fdsys), readlines(fdref)):
+                    sys_tok = tokenizer.Tokenizer.tokenize(sys_tok, dict)
+                    ref_tok = tokenizer.Tokenizer.tokenize(ref_tok, dict)
+                    scorer.add(ref_tok, sys_tok)
+                print(scorer.result_string(args.order))
 
     if args.sys == '-':
         score(sys.stdin)
diff --git a/scripts/sacrebleu_pregen.sh b/scripts/sacrebleu_pregen.sh
index 2599b94d6a..6fd3dd3c04 100755
--- a/scripts/sacrebleu_pregen.sh
+++ b/scripts/sacrebleu_pregen.sh
@@ -15,7 +15,7 @@ echo 'Cloning Moses github repository (for tokenization scripts)...'
 git clone https://github.com/moses-smt/mosesdecoder.git
 
 SCRIPTS=mosesdecoder/scripts
-DETOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl
+DETOKENIZER=$SCRIPTS/tokenizer/detokenizer.perl
 
 grep ^H $GEN \
 | sed 's/^H\-//' \
diff --git a/tests/test_token_block_dataset.py b/tests/test_token_block_dataset.py
new file mode 100644
index 0000000000..d3fa9f3967
--- /dev/null
+++ b/tests/test_token_block_dataset.py
@@ -0,0 +1,80 @@
+# Copyright (c) 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the LICENSE file in
+# the root directory of this source tree. An additional grant of patent rights
+# can be found in the PATENTS file in the same directory.
+
+import unittest
+
+import torch
+
+from fairseq.data import TokenBlockDataset
+
+import tests.utils as test_utils
+
+
+class TestTokenBlockDataset(unittest.TestCase):
+
+    def _build_dataset(self, data, **kwargs):
+        sizes = [len(x) for x in data]
+        underlying_ds = test_utils.TestDataset(data)
+        return TokenBlockDataset(underlying_ds, sizes, **kwargs)
+
+    def test_eos_break_mode(self):
+        data = [
+            torch.LongTensor([5, 4, 3, 2, 1]),
+            torch.LongTensor([1]),  # a lone eos token is kept as its own block
+            torch.LongTensor([8, 7, 6, 1]),
+        ]
+        ds = self._build_dataset(data, block_size=None, pad=0, eos=1, break_mode='eos')
+        self.assertEqual(ds[0].tolist(), [5, 4, 3, 2, 1])
+        self.assertEqual(ds[1].tolist(), [1])
+        self.assertEqual(ds[2].tolist(), [8, 7, 6, 1])
+
+        data = [
+            torch.LongTensor([5, 4, 3, 2, 1]),
+            torch.LongTensor([8, 7, 6, 1]),
+            torch.LongTensor([1]),  # a lone eos token is kept as its own block
+        ]
+        ds = self._build_dataset(data, block_size=None, pad=0, eos=1, break_mode='eos')
+        self.assertEqual(ds[0].tolist(), [5, 4, 3, 2, 1])
+        self.assertEqual(ds[1].tolist(), [8, 7, 6, 1])
+        self.assertEqual(ds[2].tolist(), [1])
+
+    def test_block_break_mode(self):
+        data = [
+            torch.LongTensor([5, 4, 3, 2, 1]),
+            torch.LongTensor([8, 7, 6, 1]),
+            torch.LongTensor([9, 1]),
+        ]
+        ds = self._build_dataset(data, block_size=3, pad=0, eos=1, break_mode='none')
+        self.assertEqual(ds[0].tolist(), [5, 4, 3])
+        self.assertEqual(ds[1].tolist(), [2, 1, 8])
+        self.assertEqual(ds[2].tolist(), [7, 6, 1])
+        self.assertEqual(ds[3].tolist(), [9, 1])
+
+    def test_complete_break_mode(self):
+        data = [
+            torch.LongTensor([5, 4, 3, 2, 1]),
+            torch.LongTensor([8, 7, 6, 1]),
+            torch.LongTensor([9, 1]),
+        ]
+        ds = self._build_dataset(data, block_size=6, pad=0, eos=1, break_mode='complete')
+        self.assertEqual(ds[0].tolist(), [5, 4, 3, 2, 1])
+        self.assertEqual(ds[1].tolist(), [8, 7, 6, 1, 9, 1])
+
+        data = [
+            torch.LongTensor([4, 3, 2, 1]),
+            torch.LongTensor([5, 1]),
+            torch.LongTensor([1]),
+            torch.LongTensor([6, 1]),
+        ]
+        ds = self._build_dataset(data, block_size=3, pad=0, eos=1, break_mode='complete')
+        self.assertEqual(ds[0].tolist(), [4, 3, 2, 1])
+        self.assertEqual(ds[1].tolist(), [5, 1, 1])
+        self.assertEqual(ds[2].tolist(), [6, 1])
+
+
+if __name__ == "__main__":
+    unittest.main()
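These tests pin down the token_block_dataset.py change above: under break_mode='eos' every sentence, including a lone eos, now keeps its own block, so block indices stay aligned with dataset indices. A minimal numpy sketch of the patched slicing loop, with an invented sizes array:

    import numpy as np

    sizes = np.array([5, 1, 4])  # tokens per sentence; the 1 is a lone eos
    slice_indices = np.empty((len(sizes), 2), dtype=int)
    curr = 0
    for i, sz in enumerate(sizes):
        slice_indices[i] = (curr, curr + sz)  # size-1 samples are no longer skipped
        curr += sz
    # slice_indices -> [[0, 5], [5, 6], [6, 10]]: one block per sentence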