forked from freewym/espresso
-
Notifications
You must be signed in to change notification settings - Fork 0
/
search.py
353 lines (284 loc) · 12.8 KB
/
search.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
class Search(object):
    """Base class for one step of a decoding search strategy.

    Subclasses implement :meth:`step`. The base class records the target
    dictionary's special-symbol ids and size, and offers lazily-allocated
    output buffers through :meth:`_init_buffers`.
    """

    def __init__(self, tgt_dict):
        # Special-symbol ids and vocabulary size taken from the target dictionary.
        self.pad, self.unk, self.eos = tgt_dict.pad(), tgt_dict.unk(), tgt_dict.eos()
        self.vocab_size = len(tgt_dict)
        # Output buffers; stay None until _init_buffers() is first called.
        self.scores_buf = self.indices_buf = self.beams_buf = None

    def _init_buffers(self, t):
        """Allocate empty output buffers on ``t``'s device, once.

        The score buffer shares ``t``'s dtype; the index and beam buffers are
        int64. Later calls are no-ops while the buffers already exist.
        """
        if self.scores_buf is not None:
            return
        self.scores_buf = t.new_empty(0)
        self.indices_buf = t.new_empty(0, dtype=torch.long)
        self.beams_buf = t.new_empty(0, dtype=torch.long)

    def step(self, step, lprobs, scores):
        """Advance the search by one step.

        Args:
            step: index of the current step, counted from 0
            lprobs: (bsz x input_beam_size x vocab_size)
                log-probabilities produced by the model at this step
            scores: (bsz x input_beam_size x step)
                accumulated model scores of every hypothesis so far

        Return: a ``(scores, indices, beams)`` tuple in which
            scores: (bsz x output_beam_size)
                scores of the selected candidates; output_beam_size may
                exceed input_beam_size (e.g. 2*input_beam_size so that
                EOS candidates can be set aside)
            indices: (bsz x output_beam_size)
                token ids of the selected candidates
            beams: (bsz x output_beam_size)
                id in [0, input_beam_size) of the hypothesis each candidate
                extends
        """
        raise NotImplementedError

    def set_src_lengths(self, src_lengths):
        """Remember per-sentence source lengths for strategies that use them."""
        self.src_lengths = src_lengths
class BeamSearch(Search):
    """Standard beam search: keep the best-scoring continuations overall."""

    def __init__(self, tgt_dict):
        super().__init__(tgt_dict)

    def step(self, step, lprobs, scores):
        """Select the top ``min(2 * beam_size, beam_size * vocab_size - 1)``
        candidates per sentence by cumulative score.

        Args:
            step: current step, starting at 0
            lprobs: (bsz x beam_size x vocab_size) log-probabilities. For
                step > 0 this tensor is updated IN PLACE to hold cumulative
                scores.
            scores: (bsz x beam_size x step) historical scores; unused at
                step 0.

        Returns:
            (scores, indices, beams), each (bsz x k): cumulative scores, token
            ids in [0, vocab_size), and source hypothesis ids in
            [0, beam_size).
        """
        super()._init_buffers(lprobs)
        bsz, beam_size, vocab_size = lprobs.size()

        if step == 0:
            # at the first step all hypotheses are equally likely, so use
            # only the first beam
            lprobs = lprobs[:, ::beam_size, :].contiguous()
        else:
            # make probs contain cumulative scores for each hypothesis
            lprobs.add_(scores[:, :, step - 1].unsqueeze(-1))

        flat_lprobs = lprobs.view(bsz, -1)
        k = min(
            # Take the best 2 x beam_size predictions. We'll choose the first
            # beam_size of these which don't predict eos to continue with.
            beam_size * 2,
            flat_lprobs.size(1) - 1,  # -1 so we never select pad
        )
        # Fresh result tensors instead of out= buffers: bsz can change between
        # steps, and resizing non-empty out= tensors is deprecated.
        self.scores_buf, flat_indices = torch.topk(flat_lprobs, k=k)
        # Flat index = beam * vocab_size + token. Integer division must be
        # explicit: plain torch.div is true division on modern PyTorch and
        # cannot write a float result into a Long tensor.
        self.beams_buf = torch.div(flat_indices, vocab_size, rounding_mode="floor")
        self.indices_buf = flat_indices.fmod(vocab_size)
        return self.scores_buf, self.indices_buf, self.beams_buf
class LengthConstrainedBeamSearch(Search):
    """Beam search that constrains when EOS may be produced.

    Per-sentence bounds are linear in the source length:
        min_len = min_len_a * src_len + min_len_b
        max_len = max_len_a * src_len + max_len_b
    ``set_src_lengths`` must be called before ``step``.
    """

    def __init__(self, tgt_dict, min_len_a, min_len_b, max_len_a, max_len_b):
        super().__init__(tgt_dict)
        self.min_len_a = min_len_a
        self.min_len_b = min_len_b
        self.max_len_a = max_len_a
        self.max_len_b = max_len_b
        self.beam = BeamSearch(tgt_dict)

    def step(self, step, lprobs, scores):
        """Adjust EOS log-probs in place (in ``lprobs``) by length bounds, then run
        a normal beam-search step.

        Assumes ``self.src_lengths`` is a (bsz,) tensor, so the comparisons below
        give a per-sentence boolean mask. TODO confirm against caller.
        """
        min_lens = self.min_len_a * self.src_lengths + self.min_len_b
        max_lens = self.max_len_a * self.src_lengths + self.max_len_b
        # Too short: EOS is forbidden.
        lprobs[step < min_lens, :, self.eos] = -math.inf
        # Reached (or passed) the maximum: give EOS log-prob 0 so it is strongly
        # preferred. Uses >= rather than ==, because max_lens is non-integral
        # whenever max_len_a is fractional. With == those sentences would never
        # be pushed to stop, and any hypothesis surviving past the bound could
        # never end.
        lprobs[step >= max_lens, :, self.eos] = 0
        return self.beam.step(step, lprobs, scores)
class DiverseBeamSearch(Search):
    """Diverse Beam Search.

    See "Diverse Beam Search: Decoding Diverse Solutions from Neural Sequence
    Models" for details.

    We only implement the Hamming Diversity penalty here, which performed best
    in the original paper.
    """

    def __init__(self, tgt_dict, num_groups, diversity_strength):
        super().__init__(tgt_dict)
        self.num_groups = num_groups
        # Stored negated so it can be used directly as the scale factor on the
        # token-usage counts added to lprobs (more uses -> lower score).
        self.diversity_strength = -diversity_strength
        self.diversity_buf = None
        self.beam = BeamSearch(tgt_dict)

    def step(self, step, lprobs, scores):
        """Run one beam-search step per group and interleave the results.

        Group ``g`` owns hypotheses ``g, g + num_groups, ...``. Every group after
        the first is penalised on tokens already chosen by earlier groups at
        this step.

        Raises:
            ValueError: if beam_size is not divisible by num_groups.
        """
        super()._init_buffers(lprobs)
        bsz, beam_size, vocab_size = lprobs.size()
        if beam_size % self.num_groups != 0:
            raise ValueError(
                'DiverseBeamSearch requires --beam to be divisible by the number of groups'
            )

        # Per-sentence count of how often each token was selected by earlier
        # groups; reset at every step.
        self.diversity_buf = lprobs.new_zeros(bsz, vocab_size)

        scores_G, indices_G, beams_G = [], [], []
        for g in range(self.num_groups):
            lprobs_g = lprobs[:, g::self.num_groups, :]
            scores_g = scores[:, g::self.num_groups, :] if step > 0 else None

            # apply diversity penalty
            if g > 0:
                lprobs_g = torch.add(
                    lprobs_g,
                    self.diversity_buf.unsqueeze(1),
                    alpha=self.diversity_strength,
                )
            else:
                lprobs_g = lprobs_g.contiguous()

            scores_buf, indices_buf, beams_buf = self.beam.step(step, lprobs_g, scores_g)
            # Map group-local hypothesis ids back to ids in the full beam.
            beams_G.append(beams_buf * self.num_groups + g)
            # Clone: self.beam may reuse these tensors on the next call.
            scores_G.append(scores_buf.clone())
            indices_G.append(indices_buf.clone())

            # update diversity penalty
            self.diversity_buf.scatter_add_(
                1,
                indices_buf,
                self.diversity_buf.new_ones(indices_buf.size()),
            )

        # Interleave results from different groups. Stack into fresh tensors:
        # the previous buffers are views, and resizing them via out= is unsafe.
        self.scores_buf = torch.stack(scores_G, dim=2).view(bsz, -1)
        self.indices_buf = torch.stack(indices_G, dim=2).view(bsz, -1)
        self.beams_buf = torch.stack(beams_G, dim=2).view(bsz, -1)
        return self.scores_buf, self.indices_buf, self.beams_buf
class Sampling(Search):
    """Stochastic decoding: draw the next token instead of taking the top ones.

    ``sampling_topp > 0`` restricts draws to the top-P (nucleus) set and takes
    precedence over ``sampling_topk > 0``, which restricts them to the k most
    likely tokens. With neither set, the full distribution is used.
    """

    def __init__(self, tgt_dict, sampling_topk=-1, sampling_topp=-1.0):
        super().__init__(tgt_dict)
        self.sampling_topk = sampling_topk
        self.sampling_topp = sampling_topp

    def _sample_topp(self, lprobs):
        """Keep the smallest set of tokens whose cumulative probability exceeds p.

        See `"The Curious Case of Neural Text Degeneration"
        (Holtzman et al., 2019) <https://arxiv.org/abs/1904.09751>`_.

        Note: ``lprobs`` is exponentiated IN PLACE.

        Args:
            lprobs: (bsz x input_beam_size x vocab_size)
                log-probabilities over the vocabulary at the current step

        Return: a ``(trimed_probs, truncated_indices)`` tuple in which
            trimed_probs: (bsz x input_beam_size x W)
                probabilities sorted in descending order, zeroed outside the
                top-P set; W is the widest top-P set in the batch
            truncated_indices: (bsz x input_beam_size x W)
                vocabulary ids of those sorted entries
        """
        probs = lprobs.exp_()
        sorted_probs, sorted_indices = probs.sort(descending=True)

        # Entries whose running mass is still strictly below p.
        keep = sorted_probs.cumsum(dim=2).lt(self.sampling_topp)
        # The strict 'lt' leaves out the token that crosses p. Its position is
        # the number of kept entries, clamped to the last valid column.
        last_included = keep.cumsum(dim=2)[:, :, -1:]
        last_included.clamp_(0, keep.size(2) - 1)
        keep = keep.scatter_(2, last_included, 1)

        # Drop columns that no row needs, then zero the entries outside each
        # row's top-P set so they are never drawn.
        width = last_included.max() + 1
        truncated_indices = sorted_indices[:, :, :width]
        trimed_probs = sorted_probs[:, :, :width].masked_fill_(~keep[:, :, :width], 0)
        return trimed_probs, truncated_indices

    def step(self, step, lprobs, scores):
        """Draw one token per hypothesis (``beam_size`` per sentence at step 0).

        Returns (scores, indices, beams), each (bsz x beam_size). Scores are
        cumulative log-probabilities of the drawn tokens.
        """
        super()._init_buffers(lprobs)
        bsz, beam_size, vocab_size = lprobs.size()
        is_first = step == 0

        if is_first:
            # All hypotheses are identical at the start: sample from beam 0 only.
            lprobs = lprobs[:, ::beam_size, :].contiguous()

        # Build the distribution to sample from (and the id remap, if truncated).
        if self.sampling_topp > 0:
            probs, top_indices = self._sample_topp(lprobs)
        elif self.sampling_topk > 0:
            lprobs, top_indices = lprobs.topk(self.sampling_topk)
            probs = lprobs.exp_()
        else:
            probs = lprobs.exp_()

        # Step 0: beam_size draws per sentence from its single distribution.
        # Later steps: one draw per hypothesis.
        if is_first:
            rows, draws = probs.view(bsz, -1), beam_size
        else:
            rows, draws = probs.view(bsz * beam_size, -1), 1
        sampled = torch.multinomial(rows, draws, replacement=True, out=self.indices_buf)
        self.indices_buf = sampled.view(bsz, beam_size)

        if is_first:
            probs = probs.expand(bsz, beam_size, -1)
        picked = self.indices_buf.unsqueeze(-1)
        torch.gather(probs, dim=2, index=picked, out=self.scores_buf)
        self.scores_buf = self.scores_buf.log_().view(bsz, -1)

        # Positions inside a truncated set -> real vocabulary ids.
        if self.sampling_topk > 0 or self.sampling_topp > 0:
            self.indices_buf = torch.gather(
                top_indices.expand(bsz, beam_size, -1),
                dim=2,
                index=picked,
            ).squeeze(2)

        if is_first:
            self.beams_buf = self.indices_buf.new_zeros(bsz, beam_size)
        else:
            # Each hypothesis continues itself; add its running score.
            self.beams_buf = torch.arange(0, beam_size, out=self.beams_buf).repeat(bsz, 1)
            history = torch.gather(scores[:, :, step - 1], dim=1, index=self.beams_buf)
            self.scores_buf.add_(history)
        return self.scores_buf, self.indices_buf, self.beams_buf
class DiverseSiblingsSearch(Search):
    """
    Beam search with diverse siblings.

    See "A Simple, Fast Diverse Decoding Algorithm for Neural Generation" for details.
    https://arxiv.org/abs/1611.08562

    1/ Calculate hypotheses for each beam
    2/ Intra-sibling ordering
    3/ Rewrite scores
    4/ Choose top K hypotheses

    if diversity_rate == 0 is equivalent to BeamSearch
    """

    def __init__(self, tgt_dict, diversity_rate):
        super().__init__(tgt_dict)
        self.diversity_rate = diversity_rate
        self.beam = BeamSearch(tgt_dict)

    def step(self, step, lprobs, scores):
        """Run one sibling-diverse step.

        Returns (scores, indices, beams), each (bsz x k) with
        ``k = min(2 * beam_size, beam_size * vocab_size - 1)``. Step 0 is
        delegated to plain BeamSearch. For step > 0, ``lprobs`` is updated IN
        PLACE to hold cumulative scores.
        """
        super()._init_buffers(lprobs)
        bsz, beam_size, vocab_size = lprobs.size()
        if step == 0:
            # Nothing to diversify yet; avoid building the sibling tensors.
            return self.beam.step(step, lprobs, scores)

        k = min(
            # Take the best 2 x beam_size predictions. We'll choose the first
            # beam_size of these which don't predict eos to continue with.
            beam_size * 2,
            beam_size * vocab_size - 1,  # -1 so we never select pad
        )
        # Each beam can contribute at most vocab_size siblings. Clamping avoids
        # a topk failure on small vocabularies without changing the final top-k.
        num_siblings = min(k, vocab_size)
        sibling_score = torch.arange(
            1, num_siblings + 1, device=lprobs.device
        ).to(lprobs.dtype) * self.diversity_rate

        lprobs.add_(scores[:, :, step - 1].unsqueeze(-1))

        # 1/ Calculate hypotheses for each beam (batched over beams):
        #    (bsz x beam_size x num_siblings)
        sib_scores, sib_indices = torch.topk(lprobs, num_siblings, dim=2)
        # 2/ Intra-sibling ordering comes from topk + 3/ Rewrite scores:
        #    the r-th best sibling is penalised by r * diversity_rate.
        sib_scores = sib_scores - sibling_score

        # 4/ Choose top K hypotheses over all beams' siblings
        final_scores, flat_positions = torch.topk(sib_scores.view(bsz, -1), k)
        final_beams = torch.div(flat_positions, num_siblings, rounding_mode="floor")
        final_indices = torch.gather(sib_indices.view(bsz, -1), 1, flat_positions)
        return final_scores, final_indices, final_beams