⚠️⚠️[T5Tokenize] Fix T5 family tokenizers⚠️⚠️ #24565

Merged · 12 commits · Jun 30, 2023
13 changes: 12 additions & 1 deletion src/transformers/models/t5/tokenization_t5.py
@@ -24,6 +24,7 @@
import sentencepiece as spm

from ...tokenization_utils import PreTrainedTokenizer
from ...tokenization_utils_base import TextInput
from ...utils import logging


@@ -51,6 +52,8 @@
"t5-11b": 512,
}

SPIECE_UNDERLINE = "▁"


class T5Tokenizer(PreTrainedTokenizer):
"""
@@ -294,9 +297,17 @@ def __setstate__(self, d):
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
self.sp_model.Load(self.vocab_file)

def tokenize(self, text: TextInput, **kwargs) -> List[str]:
if not text.startswith(" "):
text = " " + text
return super().tokenize(text, **kwargs)

def _tokenize(self, text: str) -> List[str]:
"""Take as input a string and return a list of strings (tokens) for words/sub-words"""
return self.sp_model.encode(text, out_type=str)
tokens = self.sp_model.encode(text, out_type=str)
if not text.startswith(" ") and tokens[0] == SPIECE_UNDERLINE:
tokens = tokens[1:]
return tokens

def _convert_token_to_id(self, token):
"""Converts a token (str) in an id using the vocab."""
@@ -1149,7 +1149,7 @@ def test_small_generate(self):
model = SwitchTransformersForConditionalGeneration.from_pretrained(
"google/switch-base-8", torch_dtype=torch.bfloat16
).eval()
tokenizer = AutoTokenizer.from_pretrained("t5-small")
tokenizer = AutoTokenizer.from_pretrained("t5-small", use_fast=False)
model = model.to(torch_device)

input_ids = tokenizer(
@@ -1160,24 +1160,24 @@
self.assertEqual(output_str, "drink.")

input_ids = tokenizer(
"A <extra_id_0> walks into a bar a orders a <extra_id_1> with <extra_id_2> pinch of <extra_id_3>.",
"A <extra_id_0> walks into a bar and orders a <extra_id_1> with <extra_id_2> pinch of <extra_id_3>.",
return_tensors="pt",
).input_ids.to(torch_device)
sequences = model.generate(input_ids)
output_str = tokenizer.batch_decode(sequences, skip_special_tokens=False)[0]

EXPECTED_OUTPUT = "<pad><extra_id_0> man<extra_id_1> beer<extra_id_2> a<extra_id_3> salt<extra_id_4>.</s>"
EXPECTED_OUTPUT = "<pad><extra_id_0> man<extra_id_1> beer<extra_id_2> a<extra_id_3> whiskey<extra_id_4>.</s>"
self.assertEqual(output_str, EXPECTED_OUTPUT)

def test_small_batch_generate(self):
BATCH_SIZE = 4
model = SwitchTransformersForConditionalGeneration.from_pretrained(
"google/switch-base-8", torch_dtype=torch.bfloat16
).eval()
tokenizer = AutoTokenizer.from_pretrained("t5-small")
tokenizer = AutoTokenizer.from_pretrained("t5-small", use_fast=False)

inputs = [
"A <extra_id_0> walks into a bar a orders a <extra_id_1> with <extra_id_2> pinch of <extra_id_3>."
"A <extra_id_0> walks into a bar and orders a <extra_id_1> with <extra_id_2> pinch of <extra_id_3>."
] * BATCH_SIZE
encoded_input = tokenizer.batch_encode_plus(inputs, return_tensors="pt")
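
Both generation tests now pin use_fast=False, presumably because only the slow (SentencePiece) tokenizer is patched in this PR. A quick hedged sketch of how the remaining slow/fast difference could be eyeballed (checkpoint and outputs are illustrative, not asserted by the PR):

from transformers import AutoTokenizer

slow = AutoTokenizer.from_pretrained("t5-small", use_fast=False)
fast = AutoTokenizer.from_pretrained("t5-small", use_fast=True)

text = "A <extra_id_0> walks into a bar and orders a <extra_id_1> with <extra_id_2> pinch of <extra_id_3>."
print(slow.tokenize(text))
print(fast.tokenize(text))  # may differ right after each <extra_id_*> sentinel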

32 changes: 32 additions & 0 deletions tests/models/t5/test_tokenization_t5.py
@@ -399,3 +399,35 @@ def test_get_sentinel_tokens_for_fasttokenizer(self):
def test_get_sentinel_token_ids_for_fasttokenizer(self):
tokenizer = T5TokenizerFast(SAMPLE_VOCAB, extra_ids=10)
self.assertListEqual(sorted(tokenizer.get_sentinel_token_ids()), sorted(range(1000, 1010)))

def test_encode_extra_ids(self):
tokenizer = T5Tokenizer(SAMPLE_VOCAB, extra_ids=0)
tokenizer.add_special_tokens({"additional_special_tokens": ["<extra_id_0>"]})
tokenizer._create_trie(tokenizer.all_special_tokens)
# TODO ArthurZ the above is necessary as addedTokens / initialization sucks. Trie is not correctly created
# So the extra ids are split....

input_ids = tokenizer.encode(". Hello")
self.assertEquals(input_ids, [7, 4, 156, 86, 20, 2])
tokens = tokenizer.tokenize(". Hello")
self.assertEquals(tokens, ["▁", ".", "▁He", "ll", "o"])

input_ids = tokenizer.encode(" . Hello")
self.assertEquals(input_ids, [7, 4, 156, 86, 20, 2])
tokens = tokenizer.tokenize(" . Hello")
self.assertEquals(tokens, ["▁", ".", "▁He", "ll", "o"])

input_ids = tokenizer.encode("Hello, <extra_id_0>I")
self.assertEquals(input_ids, [156, 86, 20, 3, 999, 8, 2])
tokens = tokenizer.tokenize("Hello, <extra_id_0>I")
self.assertEquals(tokens, ["▁He", "ll", "o", ",", "<extra_id_0>", "▁I"])

input_ids = tokenizer.encode("Hello, <extra_id_0>,")
self.assertEquals(input_ids, [156, 86, 20, 3, 999, 3, 2])
tokens = tokenizer.tokenize("Hello, <extra_id_0>,")
self.assertEquals(tokens, ["▁He", "ll", "o", ",", "<extra_id_0>", ","])

input_ids = tokenizer.encode(" <extra_id_0> ,")
self.assertEquals(input_ids, [999, 3, 2])
tokens = tokenizer.tokenize(" <extra_id_0> ,")
self.assertEquals(tokens, ["<extra_id_0>", ","]) # spaces are eaten by rstrip / lstrip