Skip to content

Commit

Permalink
Temporarily revert the fill-mask improvements.
Browse files (browse the repository at this point in the history)
  • Loading branch information
LysandreJik committed Jun 23, 2021
1 parent 4bdff2c commit 941b444
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 81 deletions.
78 changes: 26 additions & 52 deletions src/transformers/pipelines/fill_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,9 @@ def __call__(self, *args, targets=None, top_k: Optional[int] = None, **kwargs):
args (:obj:`str` or :obj:`List[str]`):
One or several texts (or one list of prompts) with masked tokens.
targets (:obj:`str` or :obj:`List[str]`, `optional`):
When passed, the model will limit the scores to the passed targets instead of looking up in the whole
vocab. If the provided targets are not in the model vocab, they will be tokenized and the first
resulting token will be used (with a warning, and that might be slower).
When passed, the model will return the scores for the passed token or tokens rather than the top k
predictions in the entire vocabulary. If the provided targets are not in the model vocab, they will be
tokenized and the first resulting token will be used (with a warning).
top_k (:obj:`int`, `optional`):
When passed, overrides the number of predictions to return.
Expand All @@ -115,56 +115,25 @@ def __call__(self, *args, targets=None, top_k: Optional[int] = None, **kwargs):
inputs = self._parse_and_tokenize(*args, **kwargs)
outputs = self._forward(inputs, return_tensors=True)

# top_k must be defined
if top_k is None:
top_k = self.top_k

results = []
batch_size = outputs.shape[0] if self.framework == "tf" else outputs.size(0)

if targets is not None:
if len(targets) == 0 or len(targets[0]) == 0:
raise ValueError("At least one target must be provided when passed.")
if isinstance(targets, str):
targets = [targets]

try:
vocab = self.tokenizer.get_vocab()
except Exception:
vocab = {}
target_ids = []
targets_proc = []
for target in targets:
id_ = vocab.get(target, None)
if id_ is None:
input_ids = self.tokenizer(
target,
add_special_tokens=False,
return_attention_mask=False,
return_token_type_ids=False,
max_length=1,
truncation=True,
)["input_ids"]
if len(input_ids) == 0:
logger.warning(
f"The specified target token `{target}` does not exist in the model vocabulary. "
f"We cannot replace it with anything meaningful, ignoring it"
)
continue
id_ = input_ids[0]
# XXX: If users encounter this pass
# it becomes pretty slow, so let's make sure
# The warning enables them to fix the input to
# get faster performance.
target_enc = self.tokenizer.tokenize(target)
if len(target_enc) > 1 or target_enc[0] == self.tokenizer.unk_token:
logger.warning(
f"The specified target token `{target}` does not exist in the model vocabulary. "
f"Replacing with `{self.tokenizer.convert_ids_to_tokens(id_)}`."
f"Replacing with `{target_enc[0]}`."
)
target_ids.append(id_)
target_ids = list(set(target_ids))
if len(target_ids) == 0:
raise ValueError("At least one target must be provided when passed.")
target_ids = np.array(target_ids)
# Cap top_k if there are targets
if top_k > target_ids.shape[0]:
top_k = target_ids.shape[0]
targets_proc.append(target_enc[0])
target_inds = np.array(self.tokenizer.convert_tokens_to_ids(targets_proc))

for i in range(batch_size):
input_ids = inputs["input_ids"][i]
Expand All @@ -178,11 +147,14 @@ def __call__(self, *args, targets=None, top_k: Optional[int] = None, **kwargs):

logits = outputs[i, masked_index.item(), :]
probs = tf.nn.softmax(logits)
if targets is not None:
probs = tf.gather_nd(probs, tf.reshape(target_ids, (-1, 1)))

topk = tf.math.top_k(probs, k=top_k)
values, predictions = topk.values.numpy(), topk.indices.numpy()
if targets is None:
topk = tf.math.top_k(probs, k=top_k if top_k is not None else self.top_k)
values, predictions = topk.values.numpy(), topk.indices.numpy()
else:
values = tf.gather_nd(probs, tf.reshape(target_inds, (-1, 1)))
sort_inds = tf.reverse(tf.argsort(values), [0])
values = tf.gather_nd(values, tf.reshape(sort_inds, (-1, 1))).numpy()
predictions = target_inds[sort_inds.numpy()]
else:
masked_index = torch.nonzero(input_ids == self.tokenizer.mask_token_id, as_tuple=False)

Expand All @@ -191,11 +163,13 @@ def __call__(self, *args, targets=None, top_k: Optional[int] = None, **kwargs):

logits = outputs[i, masked_index.item(), :]
probs = logits.softmax(dim=0)

if targets is not None:
probs = probs[..., target_ids]

values, predictions = probs.topk(top_k)
if targets is None:
values, predictions = probs.topk(top_k if top_k is not None else self.top_k)
else:
values = probs[..., target_inds]
sort_inds = list(reversed(values.argsort(dim=-1)))
values = values[..., sort_inds]
predictions = target_inds[sort_inds]

for v, p in zip(values.tolist(), predictions.tolist()):
tokens = input_ids.numpy()
Expand Down
33 changes: 4 additions & 29 deletions tests/test_pipelines_fill_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,7 @@ def test_torch_fill_mask(self):
@require_torch
def test_torch_fill_mask_with_targets(self):
valid_inputs = ["My name is <mask>"]
# ' Sam' will yield a warning but work
valid_targets = [[" Teven", "ĠPatrick", "ĠClara"], ["ĠSam"], [" Sam"]]
valid_targets = [[" Teven", " Patrick", " Clara"], [" Sam"]]
invalid_targets = [[], [""], ""]
for model_name in self.small_models:
unmasker = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="pt")
Expand All @@ -90,34 +89,10 @@ def test_torch_fill_mask_with_targets(self):
for targets in invalid_targets:
self.assertRaises(ValueError, unmasker, valid_inputs, targets=targets)

@require_torch
def test_torch_fill_mask_with_targets_and_topk(self):
model_name = self.small_models[0]
unmasker = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="pt")
targets = [" Teven", "ĠPatrick", "ĠClara"]
top_k = 2
outputs = unmasker("My name is <mask>", targets=targets, top_k=top_k)

self.assertEqual(len(outputs), 2)

@require_torch
def test_torch_fill_mask_with_duplicate_targets_and_topk(self):
model_name = self.small_models[0]
unmasker = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="pt")
# String duplicates + id duplicates
targets = [" Teven", "ĠPatrick", "ĠClara", "ĠClara", " Clara"]
top_k = 10
outputs = unmasker("My name is <mask>", targets=targets, top_k=top_k)

# The target list contains duplicates, so we can't output more
# than them
self.assertEqual(len(outputs), 3)

@require_tf
def test_tf_fill_mask_with_targets(self):
valid_inputs = ["My name is <mask>"]
# ' Sam' will yield a warning but work
valid_targets = [[" Teven", "ĠPatrick", "ĠClara"], ["ĠSam"], [" Sam"]]
valid_targets = [[" Teven", " Patrick", " Clara"], [" Sam"]]
invalid_targets = [[], [""], ""]
for model_name in self.small_models:
unmasker = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="tf")
Expand All @@ -136,7 +111,7 @@ def test_torch_fill_mask_results(self):
"My name is <mask>",
"The largest city in France is <mask>",
]
valid_targets = ["ĠPatrick", "ĠClara"]
valid_targets = [" Patrick", " Clara"]
for model_name in self.large_models:
unmasker = pipeline(
task="fill-mask",
Expand Down Expand Up @@ -209,7 +184,7 @@ def test_tf_fill_mask_results(self):
"My name is <mask>",
"The largest city in France is <mask>",
]
valid_targets = ["ĠPatrick", "ĠClara"]
valid_targets = [" Patrick", " Clara"]
for model_name in self.large_models:
unmasker = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="tf", top_k=2)

Expand Down

0 comments on commit 941b444

Please sign in to comment.