forked from freewym/espresso
-
Notifications
You must be signed in to change notification settings - Fork 0
/
bleu.py
129 lines (105 loc) · 3.86 KB
/
bleu.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import ctypes
import math
import torch
try:
from fairseq import libbleu
except ImportError as e:
import sys
sys.stderr.write('ERROR: missing libbleu.so. run `pip install --editable .`\n')
raise e
# Handle to the native BLEU library: ctypes loads the shared object located at
# the imported ``libbleu`` module's file path. The ``bleu_zero_init``,
# ``bleu_one_init`` and ``bleu_add`` functions used by ``Scorer`` are called
# through this handle.
C = ctypes.cdll.LoadLibrary(libbleu.__file__)
class BleuStat(ctypes.Structure):
    """Block of ``size_t`` counters shared with the native BLEU library.

    An instance is handed to the library by reference, which fills in the
    counters. ``reflen`` and ``predlen`` hold the accumulated reference and
    prediction lengths; each ``matchN`` / ``countN`` pair is the numerator and
    denominator of one precision value (presumably n-gram orders 1 to 4 --
    the native side is not visible here).
    """

    # Field order defines the memory layout of a ctypes.Structure, so the
    # names are listed in exactly the order the structure is laid out.
    _fields_ = [
        (field_name, ctypes.c_size_t)
        for field_name in (
            'reflen', 'predlen',
            'match1', 'count1',
            'match2', 'count2',
            'match3', 'count3',
            'match4', 'count4',
        )
    ]
class SacrebleuScorer:
    """Corpus-level BLEU scorer that delegates the computation to ``sacrebleu``.

    Reference and system strings are collected pairwise with
    :meth:`add_string` and scored together on demand.
    """

    def __init__(self):
        # Imported here rather than at module level, so sacrebleu is only
        # required once a SacrebleuScorer is actually constructed.
        import sacrebleu

        self.sacrebleu = sacrebleu
        self.reset()

    def reset(self, one_init=False):
        """Forget every reference/system pair collected so far.

        Args:
            one_init: must be falsy; this scorer has no "one init" mode.

        Raises:
            NotImplementedError: if ``one_init`` is truthy.
        """
        if one_init:
            raise NotImplementedError
        self.ref, self.sys = [], []

    def add_string(self, ref, pred):
        """Record one reference string and the matching system output."""
        self.ref.append(ref)
        self.sys.append(pred)

    def score(self, order=4):
        """Return the ``score`` attribute of the object built by :meth:`result_string`."""
        bleu = self.result_string(order)
        return bleu.score

    def result_string(self, order=4):
        """Return what ``sacrebleu.corpus_bleu`` produces for the collected pairs.

        Args:
            order: only ``4`` is accepted.

        Raises:
            NotImplementedError: if ``order`` is anything other than 4.
        """
        if order == 4:
            # The collected references are passed as a single reference stream.
            return self.sacrebleu.corpus_bleu(self.sys, [self.ref])
        raise NotImplementedError
class Scorer(object):
    """BLEU scorer over token-id tensors, backed by the native ``libbleu`` library.

    Statistics are accumulated in a ``BleuStat`` structure that the native
    functions update in place; the score itself is computed in Python from
    those counters.

    Args:
        pad: id of the padding symbol, forwarded to the native ``bleu_add``.
        eos: id of the end-of-sentence symbol, forwarded to ``bleu_add``.
        unk: id of the unknown-word symbol; occurrences in a reference are
            masked so that they never count as a match.
    """

    def __init__(self, pad, eos, unk):
        self.stat = BleuStat()
        self.pad = pad
        self.eos = eos
        self.unk = unk
        self.reset()

    def reset(self, one_init=False):
        """Re-initialise the accumulated statistics.

        Args:
            one_init: when truthy the native ``bleu_one_init`` is used,
                otherwise ``bleu_zero_init``.
        """
        if one_init:
            C.bleu_one_init(ctypes.byref(self.stat))
        else:
            C.bleu_zero_init(ctypes.byref(self.stat))

    def add(self, ref, pred):
        """Accumulate the statistics of one reference/prediction pair.

        Args:
            ref: ``torch.IntTensor`` of reference token ids (no negative ids).
            pred: ``torch.IntTensor`` of predicted token ids.

        Raises:
            TypeError: if ``ref`` or ``pred`` is not a ``torch.IntTensor``.
        """
        if not isinstance(ref, torch.IntTensor):
            raise TypeError('ref must be a torch.IntTensor (got {})'
                            .format(type(ref)))
        if not isinstance(pred, torch.IntTensor):
            raise TypeError('pred must be a torch.IntTensor(got {})'
                            .format(type(pred)))

        # don't match unknown words: work on a copy of the reference in which
        # every unk id is replaced by a negative sentinel. The assertion
        # guarantees the sentinel cannot collide with a genuine reference id.
        rref = ref.clone()
        assert not rref.lt(0).any()
        rref[rref.eq(self.unk)] = -999

        # The native code receives raw data pointers plus element counts, so
        # both tensors are flattened into contiguous 1-D buffers first.
        rref = rref.contiguous().view(-1)
        pred = pred.contiguous().view(-1)

        C.bleu_add(
            ctypes.byref(self.stat),
            ctypes.c_size_t(rref.size(0)),
            ctypes.c_void_p(rref.data_ptr()),
            ctypes.c_size_t(pred.size(0)),
            ctypes.c_void_p(pred.data_ptr()),
            ctypes.c_int(self.pad),
            ctypes.c_int(self.eos))

    def score(self, order=4):
        """Return the BLEU score on a 0-100 scale.

        The score is the brevity penalty times the geometric mean of the
        first ``order`` precisions. A precision of zero contributes ``-inf``
        to the log-sum, which makes the whole score 0.0.

        NOTE(review): only four precisions are tracked, so an ``order`` above
        4 still averages over ``order`` terms -- confirm callers never do so.
        """
        psum = sum(math.log(p) if p > 0 else float('-Inf')
                   for p in self.precision()[:order])
        return self.brevity() * math.exp(psum / order) * 100

    def precision(self):
        """Return the four ``matchN / countN`` precisions (0 where ``countN`` is 0)."""
        def ratio(a, b):
            return a / b if b > 0 else 0

        return [
            ratio(self.stat.match1, self.stat.count1),
            ratio(self.stat.match2, self.stat.count2),
            ratio(self.stat.match3, self.stat.count3),
            ratio(self.stat.match4, self.stat.count4),
        ]

    def brevity(self):
        """Return the brevity penalty ``min(1, exp(1 - reflen / predlen))``.

        With no predicted tokens at all the ratio is undefined; the penalty is
        then 0.0, the limit of the formula as the ratio grows without bound,
        instead of a ZeroDivisionError.
        """
        if self.stat.predlen == 0:
            return 0.0
        r = self.stat.reflen / self.stat.predlen
        return min(1, math.exp(1 - r))

    def result_string(self, order=4):
        """Return a one-line, human-readable summary of the current score.

        The line holds the BLEU score, the individual precisions (in percent),
        the brevity penalty, the prediction/reference length ratio and both
        lengths. The ratio is reported as 0.0 while the reference length is
        still zero, rather than raising ZeroDivisionError.

        Args:
            order: number of precisions to include; at most 4.
        """
        assert order <= 4, "BLEU scores for order > 4 aren't supported"
        fmt = 'BLEU{} = {:2.2f}, {:2.1f}'
        for _ in range(1, order):
            fmt += '/{:2.1f}'
        fmt += ' (BP={:.3f}, ratio={:.3f}, syslen={}, reflen={})'
        bleup = [p * 100 for p in self.precision()[:order]]
        predlen, reflen = self.stat.predlen, self.stat.reflen
        length_ratio = predlen / reflen if reflen > 0 else 0.0
        return fmt.format(order, self.score(order=order), *bleup,
                          self.brevity(), length_ratio,
                          predlen, reflen)