From d26b37e744ea980977e266adf48736451b73c583 Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Wed, 10 Mar 2021 21:42:04 +0530 Subject: [PATCH] Speech2TextTransformer (#10175) * s2t * fix config * conversion script * fix import * add tokenizer * fix tok init * fix tokenizer * first version working * fix embeds * fix lm head * remove extra heads * fix convert script * handle encoder attn mask * style * better enc attn mask * override _prepare_attention_mask_for_generation * handle attn_maks in encoder and decoder * input_ids => input_features * enable use_cache * remove old code * expand embeddings if needed * remove logits bias * masked_lm_loss => loss * hack tokenizer to support feature processing * fix model_input_names * style * fix error message * doc * remove inputs_embeds * remove input_embeds * remove unnecessary docstring * quality * SpeechToText => Speech2Text * style * remove shared_embeds * subsample => conv * remove Speech2TextTransformerDecoderWrapper * update output_lengths formula * fix table * remove max_position_embeddings * update conversion scripts * add possibility to do upper case for now * add FeatureExtractor and Processor * add tests for extractor * require_torch_audio => require_torchaudio * add processor test * update import * remove classification head * attention mask is now 1D * update docstrings * attention mask should be of type long * handle attention mask from generate * alwyas return attention_mask * fix test * style * doc * Speech2TextTransformer => Speech2Text * Speech2TextTransformerConfig => Speech2TextConfig * remove dummy_inputs * nit * style * multilinguial tok * fix tokenizer * add tgt_lang setter * save lang_codes * fix tokenizer * add forced_bos_token_id to tokenizer * apply review suggestions * add torchaudio to extra deps * add speech deps to CI * fix dep * add libsndfile to ci * libsndfile1 * add speech to extras all * libsndfile1 -> libsndfile1 * libsndfile * libsndfile1-dev * apt update * add sudo to install * update deps table * install libsndfile1-dev on CI * tuple to list * init conv layer * add model tests * quality * add integration tests * skip_special_tokens * add speech_to_text_transformer in toctree * fix tokenizer * fix fp16 tests * add tokenizer tests * fix copyright * input_values => input_features * doc * add model in readme * doc * change checkpoint names * fix copyright * fix code example * add max_model_input_sizes in tokenizer * fix integration tests * add do_lower_case to tokenizer * remove clamp trick * fix "Add modeling imports here" * fix copyrights * fix tests * SpeechToTextTransformer => SpeechToText * fix naming * fix table formatting * fix typo * style * fix typos * remove speech dep from extras[testing] * fix copies * rename doc file, * put imports under is_torch_available * run feat extract tests when torch is available * dummy objects for processor and extractor * fix imports in tests * fix import in modeling test * fxi imports * fix torch import * fix imports again * fix positional embeddings * fix typo in import * adapt new extractor refactor * style * fix torchscript test * doc * doc * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Patrick von Platen * fix docs, copied from, style * fix docstring * handle imports * remove speech from all extra deps * remove s2t from seq2seq lm mapping * better names * skip training tests * add install instructions * List => Tuple * doc * fix conversion script * fix urls * add instruction for libsndfile * fix fp16 test Co-authored-by: Patrick von Platen Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- .circleci/config.yml | 13 +- README.md | 1 + docs/source/index.rst | 24 +- docs/source/model_doc/speech_to_text.rst | 152 ++ setup.cfg | 1 + setup.py | 4 +- src/transformers/__init__.py | 26 + src/transformers/dependency_versions_table.py | 1 + src/transformers/file_utils.py | 11 + src/transformers/generation_utils.py | 5 +- src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 7 + src/transformers/models/auto/modeling_auto.py | 6 +- .../models/speech_to_text/__init__.py | 77 + .../configuration_speech_to_text.py | 200 +++ .../convert_s2t_fairseq_to_tfms.py | 112 ++ .../feature_extraction_speech_to_text.py | 225 +++ .../speech_to_text/modeling_speech_to_text.py | 1353 +++++++++++++++++ .../processing_speech_to_text.py | 144 ++ .../tokenization_speech_to_text.py | 259 ++++ src/transformers/testing_utils.py | 14 + src/transformers/utils/dummy_pt_objects.py | 21 + .../utils/dummy_sentencepiece_objects.py | 14 + .../test_feature_extraction_speech_to_text.py | 146 ++ tests/test_generation_utils.py | 5 +- tests/test_modeling_speech_to_text.py | 754 +++++++++ tests/test_processor_speech_to_text.py | 144 ++ ...test_sequence_feature_extraction_common.py | 4 +- tests/test_tokenization_speech_to_text.py | 129 ++ utils/check_repo.py | 4 + 30 files changed, 3833 insertions(+), 24 deletions(-) create mode 100644 docs/source/model_doc/speech_to_text.rst create mode 100644 src/transformers/models/speech_to_text/__init__.py create mode 100644 src/transformers/models/speech_to_text/configuration_speech_to_text.py create mode 100644 src/transformers/models/speech_to_text/convert_s2t_fairseq_to_tfms.py create mode 100644 src/transformers/models/speech_to_text/feature_extraction_speech_to_text.py create mode 100755 src/transformers/models/speech_to_text/modeling_speech_to_text.py create mode 100644 src/transformers/models/speech_to_text/processing_speech_to_text.py create mode 100644 src/transformers/models/speech_to_text/tokenization_speech_to_text.py create mode 100644 tests/test_feature_extraction_speech_to_text.py create mode 100644 tests/test_modeling_speech_to_text.py create mode 100644 tests/test_processor_speech_to_text.py create mode 100644 tests/test_tokenization_speech_to_text.py diff --git a/.circleci/config.yml b/.circleci/config.yml index fe85b7aaa2bdc6..e67fdaa0263708 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -77,8 +77,9 @@ jobs: keys: - v0.4-torch_and_tf-{{ checksum "setup.py" }} - v0.4-{{ checksum "setup.py" }} + - run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev - run: pip install --upgrade pip - - run: pip install .[sklearn,tf-cpu,torch,testing,sentencepiece] + - run: pip install .[sklearn,tf-cpu,torch,testing,sentencepiece,speech] - run: pip install tapas torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cpu.html - save_cache: key: v0.4-{{ checksum "setup.py" }} @@ -104,8 +105,9 @@ jobs: keys: - v0.4-torch-{{ checksum "setup.py" }} - v0.4-{{ checksum "setup.py" }} + - run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev - run: pip install --upgrade pip - - run: pip install .[sklearn,torch,testing,sentencepiece] + - run: pip install .[sklearn,torch,testing,sentencepiece,speech] - run: pip install tapas torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cpu.html - save_cache: key: v0.4-torch-{{ checksum "setup.py" }} @@ -157,8 +159,9 @@ jobs: keys: - v0.4-flax-{{ checksum "setup.py" }} - v0.4-{{ checksum "setup.py" }} + - run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev - run: pip install --upgrade pip - - run: sudo pip install .[flax,sklearn,torch,testing,sentencepiece] + - run: sudo pip install .[flax,sklearn,torch,testing,sentencepiece,speech] - save_cache: key: v0.4-flax-{{ checksum "setup.py" }} paths: @@ -183,8 +186,9 @@ jobs: keys: - v0.4-torch-{{ checksum "setup.py" }} - v0.4-{{ checksum "setup.py" }} + - run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev - run: pip install --upgrade pip - - run: pip install .[sklearn,torch,testing,sentencepiece] + - run: pip install .[sklearn,torch,testing,sentencepiece,speech] - run: pip install tapas torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cpu.html - save_cache: key: v0.4-torch-{{ checksum "setup.py" }} @@ -300,6 +304,7 @@ jobs: keys: - v0.4-build_doc-{{ checksum "setup.py" }} - v0.4-{{ checksum "setup.py" }} + - run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev - run: pip install --upgrade pip - run: pip install ."[all, docs]" - save_cache: diff --git a/README.md b/README.md index f6e503896b67a4..944c4fdc3ccfaa 100644 --- a/README.md +++ b/README.md @@ -227,6 +227,7 @@ Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih. 1. **[ProphetNet](https://huggingface.co/transformers/model_doc/prophetnet.html)** (from Microsoft Research) released with the paper [ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training](https://arxiv.org/abs/2001.04063) by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou. 1. **[Reformer](https://huggingface.co/transformers/model_doc/reformer.html)** (from Google Research) released with the paper [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya. 1. **[RoBERTa](https://huggingface.co/transformers/model_doc/roberta.html)** (from Facebook), released together with the paper a [Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov. +1. **[SpeechToTextTransformer](https://huggingface.co/transformers/model_doc/speech_to_text.html)** (from Facebook), released together with the paper [fairseq S2T: Fast Speech-to-Text Modeling with fairseq](https://arxiv.org/abs/2010.05171) by Changhan Wang, Yun Tang, Xutai Ma, Anne Wu, Dmytro Okhonko, Juan Pino. 1. **[SqueezeBert](https://huggingface.co/transformers/model_doc/squeezebert.html)** released with the paper [SqueezeBERT: What can computer vision teach NLP about efficient neural networks?](https://arxiv.org/abs/2006.11316) by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer. 1. **[T5](https://huggingface.co/transformers/model_doc/t5.html)** (from Google AI) released with the paper [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. 1. **[TAPAS](https://huggingface.co/transformers/model_doc/tapas.html)** (from Google AI) released with the paper [TAPAS: Weakly Supervised Table Parsing via Pre-training](https://arxiv.org/abs/2004.02349) by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos. diff --git a/docs/source/index.rst b/docs/source/index.rst index 1485e9b5bc9387..392f66c99aab6b 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -191,31 +191,34 @@ and conversion utilities for the following models: 36. :doc:`RoBERTa ` (from Facebook), released together with the paper a `Robustly Optimized BERT Pretraining Approach `__ by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov. -37. :doc:`SqueezeBert ` released with the paper `SqueezeBERT: What can computer vision teach NLP +37. :doc:`SpeechToTextTransformer ` (from Facebook), released together with the paper + `fairseq S2T: Fast Speech-to-Text Modeling with fairseq `__ by Changhan Wang, Yun + Tang, Xutai Ma, Anne Wu, Dmytro Okhonko, Juan Pino. +38. :doc:`SqueezeBert ` released with the paper `SqueezeBERT: What can computer vision teach NLP about efficient neural networks? `__ by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer. -38. :doc:`T5 ` (from Google AI) released with the paper `Exploring the Limits of Transfer Learning with a +39. :doc:`T5 ` (from Google AI) released with the paper `Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer `__ by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. -39. :doc:`TAPAS ` (from Google AI) released with the paper `TAPAS: Weakly Supervised Table Parsing via +40. :doc:`TAPAS ` (from Google AI) released with the paper `TAPAS: Weakly Supervised Table Parsing via Pre-training `__ by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos. -40. :doc:`Transformer-XL ` (from Google/CMU) released with the paper `Transformer-XL: +41. :doc:`Transformer-XL ` (from Google/CMU) released with the paper `Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context `__ by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov. -41. :doc:`Wav2Vec2 ` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for +42. :doc:`Wav2Vec2 ` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations `__ by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli. -42. :doc:`XLM ` (from Facebook) released together with the paper `Cross-lingual Language Model +43. :doc:`XLM ` (from Facebook) released together with the paper `Cross-lingual Language Model Pretraining `__ by Guillaume Lample and Alexis Conneau. -43. :doc:`XLM-ProphetNet ` (from Microsoft Research) released with the paper `ProphetNet: +44. :doc:`XLM-ProphetNet ` (from Microsoft Research) released with the paper `ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training `__ by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou. -44. :doc:`XLM-RoBERTa ` (from Facebook AI), released together with the paper `Unsupervised +45. :doc:`XLM-RoBERTa ` (from Facebook AI), released together with the paper `Unsupervised Cross-lingual Representation Learning at Scale `__ by Alexis Conneau*, Kartikay Khandelwal*, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke Zettlemoyer and Veselin Stoyanov. -45. :doc:`XLNet ` (from Google/CMU) released with the paper `​XLNet: Generalized Autoregressive +46. :doc:`XLNet ` (from Google/CMU) released with the paper `​XLNet: Generalized Autoregressive Pretraining for Language Understanding `__ by Zhilin Yang*, Zihang Dai*, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le. @@ -304,6 +307,8 @@ TensorFlow and/or Flax. +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ | RoBERTa | ✅ | ✅ | ✅ | ✅ | ✅ | +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +| Speech2Text | ✅ | ❌ | ✅ | ❌ | ❌ | ++-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ | SqueezeBERT | ✅ | ✅ | ✅ | ❌ | ❌ | +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ | T5 | ✅ | ✅ | ✅ | ✅ | ❌ | @@ -436,6 +441,7 @@ TensorFlow and/or Flax. model_doc/reformer model_doc/retribert model_doc/roberta + model_doc/speech_to_text model_doc/squeezebert model_doc/t5 model_doc/tapas diff --git a/docs/source/model_doc/speech_to_text.rst b/docs/source/model_doc/speech_to_text.rst new file mode 100644 index 00000000000000..7ebccb1dce7cda --- /dev/null +++ b/docs/source/model_doc/speech_to_text.rst @@ -0,0 +1,152 @@ +.. + Copyright 2021 The HuggingFace Team. All rights reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + specific language governing permissions and limitations under the License. + +Speech2Text +----------------------------------------------------------------------------------------------------------------------- + +Overview +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The Speech2Text model was proposed in `fairseq S2T: Fast Speech-to-Text Modeling with fairseq +`__ by Changhan Wang, Yun Tang, Xutai Ma, Anne Wu, Dmytro Okhonko, Juan Pino. It's a +transformer-based seq2seq (encoder-decoder) model designed for end-to-end Automatic Speech Recognition (ASR) and Speech +Translation (ST). It uses a convolutional downsampler to reduce the length of speech inputs by 3/4th before they are +fed into the encoder. The model is trained with standard autoregressive cross-entropy loss and generates the +transcripts/translations autoregressively. Speech2Text has been fine-tuned on several datasets for ASR and ST: +`LibriSpeech `__, `CoVoST 2 `__, `MuST-C +`__. + +The original code can be found `here `__. + + +Inference +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Speech2Text is a speech model that accepts a float tensor of log-mel filter-bank features extracted from the speech +signal. It's a transformer-based seq2seq model, so the transcripts/translations are generated autoregressively. The +:obj:`generate()` method can be used for inference. + +The :class:`~transformers.Speech2TextFeatureExtractor` class is responsible for extracting the log-mel filter-bank +features. The :class:`~transformers.Speech2TextProcessor` wraps :class:`~transformers.Speech2TextFeatureExtractor` and +:class:`~transformers.Speech2TextTokenizer` into a single instance to both extract the input features and decode the +predicted token ids. + +The feature extractor depends on :obj:`torchaudio` and the tokenizer depends on :obj:`sentencepiece` so be sure to +install those packages before running the examples. You could either install those as extra speech dependancies with +``pip install transformers"[speech, sentencepiece]"`` or install the packages seperatly with ``pip install torchaudio +sentencepiece``. Also ``torchaudio`` requires the development version of the `libsndfile +`__ package which can be installed via a system package manager. On Ubuntu it can +be installed as follows: ``apt install libsndfile1-dev`` + + +- ASR and Speech Translation + +.. code-block:: + + >>> import torch + >>> from transformers import Speech2TextProcessor, Speech2TextForConditionalGeneration + >>> from datasets import load_dataset + >>> import soundfile as sf + + >>> model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-small-librispeech-asr") + >>> processor = Speech2Textprocessor.from_pretrained("facebook/s2t-small-librispeech-asr") + + >>> def map_to_array(batch): + ... speech, _ = sf.read(batch["file"]) + ... batch["speech"] = speech + ... return batch + + >>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation") + >>> ds = ds.map(map_to_array) + + >>> input_features = processor(ds["speech"][0], sampling_rate=16_000, return_tensors="pt").input_features # Batch size 1 + >>> generated_ids = model.generate(input_ids=input_features) + + >>> transcription = processor.batch_decode(generated_ids) + + +- Multilingual speech translation + + For multilingual speech translation models, :obj:`eos_token_id` is used as the :obj:`decoder_start_token_id` and + the target language id is forced as the first generated token. To force the target language id as the first + generated token, pass the :obj:`forced_bos_token_id` parameter to the :obj:`generate()` method. The following + example shows how to transate English speech to French text using the `facebook/s2t-medium-mustc-multilingual-st` + checkpoint. + +.. code-block:: + + >>> import torch + >>> from transformers import Speech2TextProcessor, Speech2TextForConditionalGeneration + >>> from datasets import load_dataset + >>> import soundfile as sf + + >>> model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-medium-mustc-multilingual-st") + >>> processor = Speech2Textprocessor.from_pretrained("facebook/s2t-medium-mustc-multilingual-st") + + >>> def map_to_array(batch): + ... speech, _ = sf.read(batch["file"]) + ... batch["speech"] = speech + ... return batch + + >>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation") + >>> ds = ds.map(map_to_array) + + >>> input_features = processor(ds["speech"][0], sampling_rate=16_000, return_tensors="pt").input_features # Batch size 1 + >>> generated_ids = model.generate(input_ids=input_features, forced_bos_token_id=processor.tokenizer.lang_code_to_id["fr"]) + + >>> translation = processor.batch_decode(generated_ids) + + +See the `model hub `__ to look for Speech2Text checkpoints. + + +Speech2TextConfig +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.Speech2TextConfig + :members: + + +Speech2TextTokenizer +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.Speech2TextTokenizer + :members: build_inputs_with_special_tokens, get_special_tokens_mask, + create_token_type_ids_from_sequences, save_vocabulary + + +Speech2TextFeatureExtractor +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.Speech2TextFeatureExtractor + :members: __call__ + + +Speech2TextProcessor +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.Speech2TextProcessor + :members: __call__, from_pretrained, save_pretrained, batch_decode, decode, as_target_processor + + +Speech2TextModel +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.Speech2TextModel + :members: forward + + +Speech2TextForConditionalGeneration +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.Speech2TextForConditionalGeneration + :members: forward diff --git a/setup.cfg b/setup.cfg index a4f685aaa6fefe..5f0f0afb412042 100644 --- a/setup.cfg +++ b/setup.cfg @@ -35,6 +35,7 @@ known_third_party = tensorflow_datasets timeout_decorator torch + torchaudio torchtext torchvision torch_xla diff --git a/setup.py b/setup.py index 87c18390fd06f6..7903198180dd83 100644 --- a/setup.py +++ b/setup.py @@ -134,6 +134,7 @@ "timeout-decorator", "tokenizers>=0.10.1,<0.11", "torch>=1.0", + "torchaudio", "tqdm>=4.27", "unidic>=1.0.2", "unidic_lite>=1.0.7", @@ -227,14 +228,13 @@ def run(self): extras["modelcreation"] = deps_list("cookiecutter") extras["serving"] = deps_list("pydantic", "uvicorn", "fastapi", "starlette") -extras["speech"] = deps_list("soundfile") +extras["speech"] = deps_list("soundfile", "torchaudio") extras["sentencepiece"] = deps_list("sentencepiece", "protobuf") extras["testing"] = ( deps_list("pytest", "pytest-xdist", "timeout-decorator", "parameterized", "psutil", "datasets") + extras["retrieval"] + extras["modelcreation"] - + extras["speech"] ) extras["docs"] = deps_list("recommonmark", "sphinx", "sphinx-markdown-tables", "sphinx-rtd-theme", "sphinx-copybutton") extras["quality"] = deps_list("black", "isort", "flake8") diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index a61d279fbcdbf5..383dd7682f68f4 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -135,6 +135,11 @@ "Wav2Vec2Processor", ], "models.m2m_100": ["M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP", "M2M100Config", "M2M100Tokenizer"], + "models.speech_to_text": [ + "SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP", + "Speech2TextConfig", + "Speech2TextFeatureExtractor", + ], "models.convbert": ["CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ConvBertConfig", "ConvBertTokenizer"], "models.albert": ["ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "AlbertConfig"], "models.auto": [ @@ -275,6 +280,8 @@ _import_structure["models.mt5"].append("MT5Tokenizer") _import_structure["models.pegasus"].append("PegasusTokenizer") _import_structure["models.reformer"].append("ReformerTokenizer") + _import_structure["models.speech_to_text"].append("Speech2TextTokenizer") + _import_structure["models.speech_to_text"].append("Speech2TextProcessor") _import_structure["models.t5"].append("T5Tokenizer") _import_structure["models.xlm_prophetnet"].append("XLMProphetNetTokenizer") _import_structure["models.xlm_roberta"].append("XLMRobertaTokenizer") @@ -377,6 +384,14 @@ _import_structure["modeling_utils"] = ["Conv1D", "PreTrainedModel", "apply_chunking_to_forward", "prune_layer"] # PyTorch models structure + _import_structure["models.speech_to_text"].extend( + [ + "SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST", + "Speech2TextForConditionalGeneration", + "Speech2TextModel", + ] + ) + _import_structure["models.wav2vec2"].extend( [ "WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -1379,6 +1394,11 @@ from .models.reformer import REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, ReformerConfig from .models.retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig, RetriBertTokenizer from .models.roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig, RobertaTokenizer + from .models.speech_to_text import ( + SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP, + Speech2TextConfig, + Speech2TextFeatureExtractor, + ) from .models.squeezebert import SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, SqueezeBertConfig, SqueezeBertTokenizer from .models.t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config from .models.tapas import TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP, TapasConfig, TapasTokenizer @@ -1461,6 +1481,7 @@ from .models.mt5 import MT5Tokenizer from .models.pegasus import PegasusTokenizer from .models.reformer import ReformerTokenizer + from .models.speech_to_text import Speech2TextProcessor, Speech2TextTokenizer from .models.t5 import T5Tokenizer from .models.xlm_prophetnet import XLMProphetNetTokenizer from .models.xlm_roberta import XLMRobertaTokenizer @@ -1862,6 +1883,11 @@ RobertaForTokenClassification, RobertaModel, ) + from .models.speech_to_text import ( + SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST, + Speech2TextForConditionalGeneration, + Speech2TextModel, + ) from .models.squeezebert import ( SQUEEZEBERT_PRETRAINED_MODEL_ARCHIVE_LIST, SqueezeBertForMaskedLM, diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py index 9a7b078b8c6cae..6022ac220bc9c3 100644 --- a/src/transformers/dependency_versions_table.py +++ b/src/transformers/dependency_versions_table.py @@ -47,6 +47,7 @@ "timeout-decorator": "timeout-decorator", "tokenizers": "tokenizers>=0.10.1,<0.11", "torch": "torch>=1.0", + "torchaudio": "torchaudio", "tqdm": "tqdm>=4.27", "unidic": "unidic>=1.0.2", "unidic_lite": "unidic_lite>=1.0.7", diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index a99d5900b18685..09470bd3dd28e2 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -177,6 +177,13 @@ except importlib_metadata.PackageNotFoundError: _soundfile_available = False +_torchaudio_available = importlib.util.find_spec("torchaudio") +try: + _torchaudio_version = importlib_metadata.version("torchaudio") + logger.debug(f"Successfully imported soundfile version {_torchaudio_version}") +except importlib_metadata.PackageNotFoundError: + _torchaudio_available = False + torch_cache_home = os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch")) old_default_cache_path = os.path.join(torch_cache_home, "transformers") @@ -364,6 +371,10 @@ def is_soundfile_availble(): return _soundfile_available +def is_torchaudio_available(): + return _torchaudio_available + + def torch_only_method(fn): def wrapper(*args, **kwargs): if not _torch_available: diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index 3a2d56d87cbde5..b1a2b807537f1b 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -384,7 +384,7 @@ def _prepare_attention_mask_for_generation( ) if is_pad_token_in_inputs_ids and is_pad_token_not_equal_to_eos_token_id: return input_ids.ne(pad_token_id).long() - return input_ids.new_ones(input_ids.shape) + return input_ids.new_ones(input_ids.shape, dtype=torch.long) def _prepare_encoder_decoder_kwargs_for_generation( self, input_ids: torch.LongTensor, model_kwargs @@ -402,8 +402,7 @@ def _prepare_decoder_input_ids_for_generation( ) -> torch.LongTensor: decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id) decoder_input_ids = ( - torch.ones((input_ids.shape[0], 1), dtype=input_ids.dtype, device=input_ids.device) - * decoder_start_token_id + torch.ones((input_ids.shape[0], 1), dtype=torch.long, device=input_ids.device) * decoder_start_token_id ) return decoder_input_ids diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index d4957cb76cb501..ca371d804ca389 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -60,6 +60,7 @@ reformer, retribert, roberta, + speech_to_text, squeezebert, t5, tapas, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 4a9be13e52b2e9..c28d3190dce2ce 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -58,6 +58,10 @@ from ..reformer.configuration_reformer import ReformerConfig from ..retribert.configuration_retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig from ..roberta.configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig +from ..speech_to_text.configuration_speech_to_text import ( + SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP, + Speech2TextConfig, +) from ..squeezebert.configuration_squeezebert import SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, SqueezeBertConfig from ..t5.configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config from ..tapas.configuration_tapas import TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP, TapasConfig @@ -76,6 +80,7 @@ (key, value) for pretrained_map in [ # Add archive maps here + SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP, WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP, M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP, CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, @@ -122,6 +127,7 @@ CONFIG_MAPPING = OrderedDict( [ # Add configs here + ("speech_to_text", Speech2TextConfig), ("wav2vec2", Wav2Vec2Config), ("m2m_100", M2M100Config), ("convbert", ConvBertConfig), @@ -174,6 +180,7 @@ MODEL_NAMES_MAPPING = OrderedDict( [ # Add full (and cased) model names here + ("speech_to_text", "Speech2Text"), ("wav2vec2", "Wav2Vec2"), ("m2m_100", "M2M100"), ("convbert", "ConvBERT"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 99a72320e3a58d..b5b85f8c1b2382 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -66,8 +66,6 @@ CamembertForTokenClassification, CamembertModel, ) - -# Add modeling imports here from ..convbert.modeling_convbert import ( ConvBertForMaskedLM, ConvBertForMultipleChoice, @@ -211,6 +209,7 @@ RobertaForTokenClassification, RobertaModel, ) +from ..speech_to_text.modeling_speech_to_text import Speech2TextForConditionalGeneration, Speech2TextModel from ..squeezebert.modeling_squeezebert import ( SqueezeBertForMaskedLM, SqueezeBertForMultipleChoice, @@ -296,6 +295,7 @@ ReformerConfig, RetriBertConfig, RobertaConfig, + Speech2TextConfig, SqueezeBertConfig, T5Config, TapasConfig, @@ -315,6 +315,7 @@ MODEL_MAPPING = OrderedDict( [ # Base model mapping + (Speech2TextConfig, Speech2TextModel), (Wav2Vec2Config, Wav2Vec2Model), (M2M100Config, M2M100Model), (ConvBertConfig, ConvBertModel), @@ -399,6 +400,7 @@ MODEL_WITH_LM_HEAD_MAPPING = OrderedDict( [ # Model with LM heads mapping + (Speech2TextConfig, Speech2TextForConditionalGeneration), (Wav2Vec2Config, Wav2Vec2ForMaskedLM), (M2M100Config, M2M100ForConditionalGeneration), (ConvBertConfig, ConvBertForMaskedLM), diff --git a/src/transformers/models/speech_to_text/__init__.py b/src/transformers/models/speech_to_text/__init__.py new file mode 100644 index 00000000000000..d431ce4fa6d698 --- /dev/null +++ b/src/transformers/models/speech_to_text/__init__.py @@ -0,0 +1,77 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...file_utils import _BaseLazyModule, is_sentencepiece_available, is_torch_available + + +_import_structure = { + "configuration_speech_to_text": [ + "SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP", + "Speech2TextConfig", + ], + "feature_extraction_speech_to_text": ["Speech2TextFeatureExtractor"], +} + +if is_sentencepiece_available(): + _import_structure["tokenization_speech_to_text"] = ["Speech2TextTokenizer"] + _import_structure["processing_speech_to_text"] = ["Speech2TextProcessor"] + + +if is_torch_available(): + _import_structure["modeling_speech_to_text"] = [ + "SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST", + "Speech2TextForConditionalGeneration", + "Speech2TextModel", + "Speech2TextPreTrainedModel", + ] + + +if TYPE_CHECKING: + from .configuration_speech_to_text import SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP, Speech2TextConfig + from .feature_extraction_speech_to_text import Speech2TextFeatureExtractor + + if is_sentencepiece_available(): + from .processing_speech_to_text import Speech2TextProcessor + from .tokenization_speech_to_text import Speech2TextTokenizer + + if is_torch_available(): + from .modeling_speech_to_text import ( + SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST, + Speech2TextForConditionalGeneration, + Speech2TextModel, + Speech2TextPreTrainedModel, + ) + +else: + import importlib + import os + import sys + + class _LazyModule(_BaseLazyModule): + """ + Module class that surfaces all objects but only performs associated imports when the objects are requested. + """ + + __file__ = globals()["__file__"] + __path__ = [os.path.dirname(__file__)] + + def _get_module(self, module_name: str): + return importlib.import_module("." + module_name, self.__name__) + + sys.modules[__name__] = _LazyModule(__name__, _import_structure) diff --git a/src/transformers/models/speech_to_text/configuration_speech_to_text.py b/src/transformers/models/speech_to_text/configuration_speech_to_text.py new file mode 100644 index 00000000000000..ceaebec98dab9e --- /dev/null +++ b/src/transformers/models/speech_to_text/configuration_speech_to_text.py @@ -0,0 +1,200 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Speech2Text model configuration """ + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + +SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "facebook/s2t-small-librispeech-asr": "https://huggingface.co/facebook/s2t-small-librispeech-asr/resolve/main/config.json", + # See all Speech2Text models at https://huggingface.co/models?filter=speech_to_text +} + + +class Speech2TextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a :class:`~transformers.Speech2TextModel`. It is used + to instantiate an Speech2Text model according to the specified arguments, defining the model architecture. + Instantiating a configuration with the defaults will yield a similar configuration to that of the Speech2Text + `facebook/s2t-small-librispeech-asr `__ architecture. + + Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model + outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information. + + + Args: + vocab_size (:obj:`int`, `optional`, defaults to 50265): + Vocabulary size of the Speech2Text model. Defines the number of different tokens that can be represented by + the :obj:`inputs_ids` passed when calling :class:`~transformers.Speech2TextModel` + d_model (:obj:`int`, `optional`, defaults to 1024): + Dimensionality of the layers and the pooler layer. + encoder_layers (:obj:`int`, `optional`, defaults to 12): + Number of encoder layers. + decoder_layers (:obj:`int`, `optional`, defaults to 12): + Number of decoder layers. + encoder_attention_heads (:obj:`int`, `optional`, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder. + decoder_attention_heads (:obj:`int`, `optional`, defaults to 16): + Number of attention heads for each attention layer in the Transformer decoder. + decoder_ffn_dim (:obj:`int`, `optional`, defaults to 4096): + Dimensionality of the "intermediate" (often named feed-forward) layer in decoder. + encoder_ffn_dim (:obj:`int`, `optional`, defaults to 4096): + Dimensionality of the "intermediate" (often named feed-forward) layer in decoder. + activation_function (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, + :obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported. + dropout (:obj:`float`, `optional`, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_dropout (:obj:`float`, `optional`, defaults to 0.0): + The dropout ratio for the attention probabilities. + activation_dropout (:obj:`float`, `optional`, defaults to 0.0): + The dropout ratio for activations inside the fully connected layer. + classifier_dropout (:obj:`float`, `optional`, defaults to 0.0): + The dropout ratio for classifier. + init_std (:obj:`float`, `optional`, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + encoder_layerdrop: (:obj:`float`, `optional`, defaults to 0.0): + The LayerDrop probability for the encoder. See the `LayerDrop paper `__ for more details. + decoder_layerdrop: (:obj:`float`, `optional`, defaults to 0.0): + The LayerDrop probability for the decoder. See the `LayerDrop paper `__ for more details. + use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not the model should return the last key/values attentions (not used by all models). + max_source_positions (:obj:`int`, `optional`, defaults to 6000): + The maximum sequence length of log-mel filter-bank features that this model might ever be used with. + max_target_positions: (:obj:`int`, `optional`, defaults to 1024): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + num_conv_layers (:obj:`int`, `optional`, defaults to 2): + Number of 1D convolutional layers in the conv module. + conv_kernel_sizes (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(5, 5)`): + A tuple of integers defining the kernel size of each 1D convolutional layer in the conv module. The length + of :obj:`conv_kernel_sizes` has to match :obj:`num_conv_layers`. + conv_channels (:obj:`int`, `optional`, defaults to 1024): + An integer defining the number of output channels of each convolution layers except the final one in the + conv module. + input_feat_per_channel (:obj:`int`, `optional`, defaults to 80): + An integer specifying the size of feature vector. This is also the dimentions of log-mel filter-bank + features. + input_channels (:obj:`int`, `optional`, defaults to 1): + An integer specifying number of input channels of the input feature vector. + + Example:: + + >>> from transformers import Speech2TextModel, Speech2TextConfig + + >>> # Initializing a Speech2Text s2t_transformer_s style configuration + >>> configuration = Speech2TextConfig() + + >>> # Initializing a model from the s2t_transformer_s style configuration + >>> model = Speech2TextModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + """ + model_type = "speech_to_text" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=10000, + encoder_layers=12, + encoder_ffn_dim=2048, + encoder_attention_heads=4, + decoder_layers=6, + decoder_ffn_dim=2048, + decoder_attention_heads=4, + encoder_layerdrop=0.0, + decoder_layerdrop=0.0, + use_cache=True, + is_encoder_decoder=True, + activation_function="relu", + d_model=256, + dropout=0.1, + attention_dropout=0.0, + activation_dropout=0.0, + init_std=0.02, + decoder_start_token_id=2, + classifier_dropout=0.0, + scale_embedding=True, + gradient_checkpointing=False, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + max_source_positions=6000, + max_target_positions=1024, + num_conv_layers=2, + conv_kernel_sizes=(5, 5), + conv_channels=1024, + input_feat_per_channel=80, + input_channels=1, + **kwargs + ): + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + is_encoder_decoder=is_encoder_decoder, + decoder_start_token_id=decoder_start_token_id, + **kwargs, + ) + + self.vocab_size = vocab_size + self.d_model = d_model + self.encoder_ffn_dim = encoder_ffn_dim + self.encoder_layers = encoder_layers + self.encoder_attention_heads = encoder_attention_heads + self.decoder_ffn_dim = decoder_ffn_dim + self.decoder_layers = decoder_layers + self.decoder_attention_heads = decoder_attention_heads + self.dropout = dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.activation_function = activation_function + self.init_std = init_std + self.encoder_layerdrop = encoder_layerdrop + self.decoder_layerdrop = decoder_layerdrop + self.classifier_dropout = classifier_dropout + self.use_cache = use_cache + self.num_hidden_layers = encoder_layers + self.gradient_checkpointing = gradient_checkpointing + self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True + self.max_source_positions = max_source_positions + self.max_target_positions = max_target_positions + self.num_conv_layers = num_conv_layers + self.conv_kernel_sizes = list(conv_kernel_sizes) + self.conv_channels = conv_channels + self.input_feat_per_channel = input_feat_per_channel + self.input_channels = input_channels + + if len(self.conv_kernel_sizes) != self.num_conv_layers: + raise ValueError( + "Configuration for convolutional module is incorrect." + "It is required that `len(config.conv_kernel_sizes)` == `config.num_conv_layers`" + f"but is `len(config.conv_kernel_sizes) = {len(self.conv_kernel_sizes)}`," + f"`config.num_conv_layers = {self.num_conv_layers}`." + ) + + @property + def num_attention_heads(self) -> int: + return self.encoder_attention_heads + + @property + def hidden_size(self) -> int: + return self.d_model diff --git a/src/transformers/models/speech_to_text/convert_s2t_fairseq_to_tfms.py b/src/transformers/models/speech_to_text/convert_s2t_fairseq_to_tfms.py new file mode 100644 index 00000000000000..2f57d1e34038fd --- /dev/null +++ b/src/transformers/models/speech_to_text/convert_s2t_fairseq_to_tfms.py @@ -0,0 +1,112 @@ +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + +import torch +from torch import nn + +from transformers import Speech2TextConfig, Speech2TextForConditionalGeneration + + +def remove_ignore_keys_(state_dict): + ignore_keys = [ + "encoder.version", + "decoder.version", + "model.encoder.version", + "model.decoder.version", + "decoder.output_projection.weight", + "_float_tensor", + "encoder.embed_positions._float_tensor", + "decoder.embed_positions._float_tensor", + ] + for k in ignore_keys: + state_dict.pop(k, None) + + +def rename_keys(s_dict): + keys = list(s_dict.keys()) + for key in keys: + if "transformer_layers" in key: + s_dict[key.replace("transformer_layers", "layers")] = s_dict.pop(key) + elif "subsample" in key: + s_dict[key.replace("subsample", "conv")] = s_dict.pop(key) + + +def make_linear_from_emb(emb): + vocab_size, emb_size = emb.weight.shape + lin_layer = nn.Linear(vocab_size, emb_size, bias=False) + lin_layer.weight.data = emb.weight.data + return lin_layer + + +def convert_fairseq_s2t_checkpoint_to_tfms(checkpoint_path, pytorch_dump_folder_path): + m2m_100 = torch.load(checkpoint_path, map_location="cpu") + args = m2m_100["args"] + state_dict = m2m_100["model"] + lm_head_weights = state_dict["decoder.output_projection.weight"] + + remove_ignore_keys_(state_dict) + rename_keys(state_dict) + + vocab_size = state_dict["decoder.embed_tokens.weight"].shape[0] + + tie_embeds = args.share_decoder_input_output_embed + + conv_kernel_sizes = [int(i) for i in args.conv_kernel_sizes.split(",")] + config = Speech2TextConfig( + vocab_size=vocab_size, + max_source_positions=args.max_source_positions, + max_target_positions=args.max_target_positions, + encoder_layers=args.encoder_layers, + decoder_layers=args.decoder_layers, + encoder_attention_heads=args.encoder_attention_heads, + decoder_attention_heads=args.decoder_attention_heads, + encoder_ffn_dim=args.encoder_ffn_embed_dim, + decoder_ffn_dim=args.decoder_ffn_embed_dim, + d_model=args.encoder_embed_dim, + dropout=args.dropout, + attention_dropout=args.attention_dropout, + activation_dropout=args.activation_dropout, + activation_function="relu", + num_conv_layers=len(conv_kernel_sizes), + conv_channels=args.conv_channels, + conv_kernel_sizes=conv_kernel_sizes, + input_feat_per_channel=args.input_feat_per_channel, + input_channels=args.input_channels, + tie_word_embeddings=tie_embeds, + num_beams=5, + max_length=200, + use_cache=True, + decoder_start_token_id=2, + early_stopping=True, + ) + + model = Speech2TextForConditionalGeneration(config) + model.model.load_state_dict(state_dict) + if tie_embeds: + model.lm_head = make_linear_from_emb(model.model.decoder.embed_tokens) + else: + model.lm_head.weight.data = lm_head_weights + + model.save_pretrained(pytorch_dump_folder_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument("fairseq_path", type=str, help="Path to the fairseq model (.pt) file.") + parser.add_argument("pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.") + args = parser.parse_args() + convert_fairseq_s2t_checkpoint_to_tfms(args.fairseq_path, args.pytorch_dump_folder_path) diff --git a/src/transformers/models/speech_to_text/feature_extraction_speech_to_text.py b/src/transformers/models/speech_to_text/feature_extraction_speech_to_text.py new file mode 100644 index 00000000000000..e7fdb44aefe40b --- /dev/null +++ b/src/transformers/models/speech_to_text/feature_extraction_speech_to_text.py @@ -0,0 +1,225 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Feature extractor class for Speech2Text +""" + +from typing import List, Optional, Union + +import numpy as np + +from ...feature_extraction_sequence_utils import SequenceFeatureExtractor +from ...feature_extraction_utils import BatchFeature +from ...file_utils import PaddingStrategy, TensorType, is_torch_available, is_torchaudio_available +from ...utils import logging + + +if is_torch_available(): + import torch + +if is_torchaudio_available(): + import torchaudio.compliance.kaldi as ta_kaldi + +logger = logging.get_logger(__name__) + + +class Speech2TextFeatureExtractor(SequenceFeatureExtractor): + r""" + Constructs a Speech2Text feature extractor. + + This feature extractor inherits from :class:`~transformers.Speech2TextFeatureExtractor` which contains most of the + main methods. Users should refer to this superclass for more information regarding those methods. + + This class extracts mel-filter bank features from raw speech using TorchAudio and applies utterance-level cepstral + mean and variance normalization to the extracted features. + + Args: + feature_size (:obj:`int`, defaults to 80): + The feature dimension of the extracted features. + sampling_rate (:obj:`int`, defaults to 16000): + The sampling rate at which the audio files should be digitalized expressed in Hertz per second (Hz). + num_mel_bins (:obj:`int`, defaults to 80): + Number of Mel-frequency bins. + padding_value (:obj:`float`, defaults to 0.0): + The value that is used to fill the padding vectors. + do_ceptral_normalize (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not to apply utterance-level cepstral mean and variance normalization to extracted features. + normalize_means (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not to zero-mean normalize the extracted features. + normalize_vars (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not to unit-variance normalize the extracted features. + """ + + model_input_names = ["input_features", "attention_mask"] + + def __init__( + self, + feature_size=80, + sampling_rate=16000, + num_mel_bins=80, + padding_value=0.0, + do_ceptral_normalize=True, + normalize_means=True, + normalize_vars=True, + **kwargs + ): + if not is_torchaudio_available(): + raise ImportError("`Speech2TextFeatureExtractor` requires torchaudio: `pip install torchaudio`.") + super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs) + self.num_mel_bins = num_mel_bins + self.do_ceptral_normalize = do_ceptral_normalize + self.normalize_means = normalize_means + self.normalize_vars = normalize_vars + self.return_attention_mask = True + + def _extract_fbank_features( + self, + waveform: np.ndarray, + ) -> np.ndarray: + """ + Get mel-filter bank features using TorchAudio. Note that TorchAudio requires 16-bit signed integers as inputs + and hence the waveform should not be normalized before feature extraction. + """ + waveform = waveform * (2 ** 15) # Kaldi compliance: 16-bit signed integers + waveform = torch.from_numpy(waveform).unsqueeze(0) + features = ta_kaldi.fbank(waveform, num_mel_bins=self.num_mel_bins, sample_frequency=self.sampling_rate) + return features.numpy() + + @staticmethod + def utterance_cmvn( + x: np.ndarray, normalize_means: Optional[bool] = True, normalize_vars: Optional[bool] = True + ) -> np.ndarray: + mean = x.mean(axis=0) + square_sums = (x ** 2).sum(axis=0) + + if normalize_means: + x = np.subtract(x, mean) + if normalize_vars: + var = square_sums / x.shape[0] - mean ** 2 + std = np.sqrt(np.maximum(var, 1e-10)) + x = np.divide(x, std) + + return x + + def normalize(self, input_values: List[np.ndarray]) -> List[np.ndarray]: + return [self.utterance_cmvn(x, self.normalize_means, self.normalize_vars) for x in input_values] + + def __call__( + self, + raw_speech: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]], + padding: Union[bool, str, PaddingStrategy] = False, + max_length: Optional[int] = None, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + sampling_rate: Optional[int] = None, + return_attention_mask: Optional[bool] = None, + **kwargs + ) -> BatchFeature: + """ + Main method to featurize and prepare for the model one or several sequence(s). sequences. + + Args: + raw_speech (:obj:`np.ndarray`, :obj:`List[float]`, :obj:`List[np.ndarray]`, :obj:`List[List[float]]`): + The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float + values, a list of numpy arrays or a list of list of float values. + padding (:obj:`bool`, :obj:`str` or :class:`~transformers.file_utils.PaddingStrategy`, `optional`, defaults to :obj:`True`): + Select a strategy to pad the returned sequences (according to the model's padding side and padding + index) among: + + * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a + single sequence if provided). + * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the + maximum acceptable input length for the model if that argument is not provided. + * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of + different lengths). + max_length (:obj:`int`, `optional`): + Maximum length of the returned list and optionally padding length (see above). + pad_to_multiple_of (:obj:`int`, `optional`): + If set will pad the sequence to a multiple of the provided value. + + This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability + >= 7.5 (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. + return_attention_mask (:obj:`bool`, `optional`): + Whether to return the attention mask. If left to the default, will return the attention mask according + to the specific feature_extractor's default. + + `What are attention masks? <../glossary.html#attention-mask>`__ + + .. note:: + + For Speech2TextTransoformer models, :obj:`attention_mask` should alwys be passed for batched + inference, to avoid subtle bugs. + + return_tensors (:obj:`str` or :class:`~transformers.file_utils.TensorType`, `optional`): + If set, will return tensors instead of list of python integers. Acceptable values are: + + * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects. + * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects. + * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects. + sampling_rate (:obj:`int`, `optional`): + The sampling rate at which the :obj:`raw_speech` input was sampled. It is strongly recommended to pass + :obj:`sampling_rate` at the forward call to prevent silent errors. + padding_value (:obj:`float`, defaults to 0.0): + The value that is used to fill the padding values / vectors. + """ + + if sampling_rate is not None: + if sampling_rate != self.sampling_rate: + raise ValueError( + f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of {self.sampling_rate}." + f"Please make sure that the provided `raw_speech` input was sampled with {self.sampling_rate} and not {sampling_rate}." + ) + else: + logger.warning( + "It is strongly recommended to pass the `sampling_rate` argument to this function." + "Failing to do so can result in silent errors that might be hard to debug." + ) + + is_batched = bool( + isinstance(raw_speech, (list, tuple)) + and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list))) + ) + + # make sure input is in list format + if is_batched and not isinstance(raw_speech[0], np.ndarray): + raw_speech = [np.asarray(speech) for speech in raw_speech] + elif not is_batched and not isinstance(raw_speech, np.ndarray): + raw_speech = np.asarray(raw_speech) + + # always return batch + if not is_batched: + raw_speech = [raw_speech] + + # extract fbank features + features = [self._extract_fbank_features(waveform) for waveform in raw_speech] + + # Utterance-level cepstral mean and variance normalization + if self.do_ceptral_normalize: + features = self.normalize(features) + + # convert into correct format for padding + encoded_inputs = BatchFeature({"input_features": features}) + + padded_inputs = self.pad( + encoded_inputs, + padding=padding, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + return_attention_mask=return_attention_mask, + return_tensors=return_tensors, + **kwargs, + ) + + return padded_inputs diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py new file mode 100755 index 00000000000000..5c82896b9e59fd --- /dev/null +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -0,0 +1,1353 @@ +# coding=utf-8 +# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch Speech2Text model. """ + + +import math +import random +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn import CrossEntropyLoss + +from ...activations import ACT2FN +from ...file_utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + replace_return_docstrings, +) +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, +) +from ...modeling_utils import PreTrainedModel +from ...utils import logging +from .configuration_speech_to_text import Speech2TextConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "Speech2TextConfig" +_TOKENIZER_FOR_DOC = "Speech2TextTokenizer" + + +SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "facebook/s2t-small-librispeech-asr", + # See all Speech2Text models at https://huggingface.co/models?filter=speech_to_text +] + + +# Copied from transformers.models.bart.modeling_bart.shift_tokens_right +def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): + """ + Shift input ids one token to the right. + """ + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() + shifted_input_ids[:, 0] = decoder_start_token_id + + assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined." + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +# Copied from transformers.models.bart.modeling_bart._make_causal_mask +def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), float("-inf")) + mask_cond = torch.arange(mask.size(-1)) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + +# Copied from transformers.models.bart.modeling_bart._expand_mask +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) + + +class Conv1dSubsampler(nn.Module): + """ + Convolutional subsampler: a stack of 1D convolution (along temporal dimension) followed by non-linear activation + via gated linear units (https://arxiv.org/abs/1911.08460) + """ + + def __init__(self, config): + super(Conv1dSubsampler, self).__init__() + self.config = config + self.num_layers = config.num_conv_layers + self.in_channels = config.input_feat_per_channel * config.input_channels + self.mid_channels = config.conv_channels + self.out_channels = config.d_model + self.kernel_sizes = config.conv_kernel_sizes + + self.conv_layers = nn.ModuleList( + nn.Conv1d( + self.in_channels if i == 0 else self.mid_channels // 2, + self.mid_channels if i < self.num_layers - 1 else self.out_channels * 2, + kernel_size=k, + stride=2, + padding=k // 2, + ) + for i, k in enumerate(self.kernel_sizes) + ) + + def forward(self, input_features): + hidden_states = input_features.transpose(1, 2).contiguous() # -> B x (C x D) x T + for conv in self.conv_layers: + hidden_states = conv(hidden_states) + hidden_states = nn.functional.glu(hidden_states, dim=1) + hidden_states = hidden_states.transpose(1, 2).contiguous() # -> T x B x (C x D) + return hidden_states + + +class Speech2TextSinusoidalPositionalEmbedding(nn.Module): + """This module produces sinusoidal positional embeddings of any length.""" + + def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None): + super().__init__() + self.offset = 2 + self.embedding_dim = embedding_dim + self.padding_idx = padding_idx + self.make_weights(num_positions + self.offset, embedding_dim, padding_idx) + + def make_weights(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None): + emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx) + if hasattr(self, "weights"): + # in forward, put the weights on correct device + emb_weights = emb_weights.to(self.weights.device) + + self.weights = nn.Parameter(emb_weights) + self.weights.requires_grad = False + self.weights.detach_() + + @staticmethod + def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None): + """ + Build sinusoidal embeddings. This matches the implementation in tensor2tensor, but differs slightly from the + description in Section 3.5 of "Attention Is All You Need". + """ + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb) + emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0) + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1) + if embedding_dim % 2 == 1: + # zero pad + emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1) + if padding_idx is not None: + emb[padding_idx, :] = 0 + return emb + + @torch.no_grad() + def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0): + bsz, seq_len = input_ids.size() + # Create the position ids from the input token ids. Any padded tokens remain padded. + position_ids = self.create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length).to( + input_ids.device + ) + + # expand embeddings if needed + max_pos = self.padding_idx + 1 + seq_len + if max_pos > self.weights.size(0): + self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx) + + return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, -1).detach() + + def create_position_ids_from_input_ids( + self, input_ids: torch.Tensor, padding_idx: int, past_key_values_length: Optional[int] = 0 + ): + """ + Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding + symbols are ignored. This is modified from fairseq's `utils.make_positions`. + + Args: + x: torch.Tensor x: + Returns: torch.Tensor + """ + # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. + mask = input_ids.ne(padding_idx).int() + incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask + return incremental_indices.long() + padding_idx + + +# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Speech2Text +class Speech2TextAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {num_heads})." + self.scaling = self.head_dim ** -0.5 + self.is_decoder = is_decoder + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + bsz, tgt_len, embed_dim = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + assert attn_weights.size() == ( + bsz * self.num_heads, + tgt_len, + src_len, + ), f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}" + + if attention_mask is not None: + assert attention_mask.size() == ( + bsz, + 1, + tgt_len, + src_len, + ), f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = F.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + assert layer_head_mask.size() == ( + self.num_heads, + ), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}" + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if output_attentions: + # this operation is a bit akward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + assert attn_output.size() == ( + bsz * self.num_heads, + tgt_len, + self.head_dim, + ), f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}" + + attn_output = ( + attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + .transpose(1, 2) + .reshape(bsz, tgt_len, embed_dim) + ) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value + + +class Speech2TextEncoderLayer(nn.Module): + def __init__(self, config: Speech2TextConfig): + super().__init__() + self.embed_dim = config.d_model + self.self_attn = Speech2TextAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + ) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + layer_head_mask: torch.Tensor, + output_attentions: bool = False, + ): + """ + Args: + hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape :obj:`(seq_len, batch, embed_dim)` + attention_mask (:obj:`torch.FloatTensor`): attention mask of size + :obj:`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size + :obj:`(config.encoder_attention_heads,)`. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under + returned tensors for more detail. + """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states, attn_weights, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + if hidden_states.dtype == torch.float16 and ( + torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() + ): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class Speech2TextDecoderLayer(nn.Module): + def __init__(self, config: Speech2TextConfig): + super().__init__() + self.embed_dim = config.d_model + + self.self_attn = Speech2TextAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + ) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.encoder_attn = Speech2TextAttention( + self.embed_dim, + config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + ) + self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) + self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + encoder_layer_head_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = True, + ): + """ + Args: + hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape :obj:`(seq_len, batch, embed_dim)` + attention_mask (:obj:`torch.FloatTensor`): attention mask of size + :obj:`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + encoder_hidden_states (:obj:`torch.FloatTensor`): cross attention input to the layer of shape :obj:`(seq_len, batch, embed_dim)` + encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size + :obj:`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size + :obj:`(config.encoder_attention_heads,)`. + encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of + size :obj:`(config.encoder_attention_heads,)`. + past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under + returned tensors for more detail. + """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + # add present self-attn cache to positions 1,2 of present_key_value tuple + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=self_attn_past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + # Cross-Attention Block + cross_attn_present_key_value = None + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + layer_head_mask=layer_head_mask, + past_key_value=cross_attn_past_key_value, + output_attentions=output_attentions, + ) + hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + # add cross-attn to positions 3,4 of present_key_value tuple + present_key_value = present_key_value + cross_attn_present_key_value + + # Fully Connected + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class Speech2TextPreTrainedModel(PreTrainedModel): + config_class = Speech2TextConfig + base_model_prefix = "model" + + def _init_weights(self, module): + std = self.config.init_std + if isinstance(module, (nn.Linear, nn.Conv1d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def _get_subsampled_output_lengths(self, input_lengths: torch.LongTensor): + """ + Computes the output length of the convolutional layers + """ + + for i in range(self.config.num_conv_layers): + input_lengths = (input_lengths - 1) // 2 + 1 + + return input_lengths + + def _get_subsampled_encoder_attn_mask(self, attention_mask): + # generate creates 3D attention mask, becuase of the shape of input_features + # convert it to 2D if thats the case + if len(attention_mask.shape) > 2: + attention_mask = attention_mask[:, :, -1] + + subsampled_lengths = self._get_subsampled_output_lengths(attention_mask.sum(-1)) + max_len = subsampled_lengths.max().item() + bsz = attention_mask.size()[0] + attention_mask = torch.zeros((bsz, max_len), dtype=attention_mask.dtype, device=attention_mask.device) + + # these two operations makes sure that all values + # before the output lengths indices are attended to + attention_mask[(torch.arange(bsz, device=attention_mask.device), subsampled_lengths - 1)] = 1 + attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).long() + return attention_mask + + +SPEECH_TO_TEXT_START_DOCSTRING = r""" + This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic + methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, + pruning heads etc.) + + This model is also a PyTorch `torch.nn.Module `__ + subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to + general usage and behavior. + + Parameters: + config (:class:`~transformers.Speech2TextConfig`): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights. +""" + +SPEECH_TO_TEXT_INPUTS_DOCSTRING = r""" + Args: + input_features (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length, feature_size)`): + Float values of fbank features extracted from the raw speech waveform. Raw speech waveform can be obtained + by loading a ``.flac`` or ``.wav`` audio file into an array of type :obj:`List[float]` or a + :obj:`numpy.ndarray`, *e.g.* via the soundfile library (``pip install soundfile``). To prepare the array + into :obj:`input_features`, the :class:`~transformers.Speech2TextTokenizer` should be used for extracting + the fbank features, padding and conversion into a tensor of type :obj:`torch.FloatTensor`. See + :meth:`~transformers.Speech2TextTokenizer.__call__` + attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing convolution and attention on padding token indices. Mask values selected in ``[0, + 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): + Provide for translation and summarization training. By default, the model will create this tensor by + shifting the :obj:`input_ids` to the right, following the paper. + decoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): + Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will + also be used by default. + + If you want to change padding behavior, you should read + :func:`modeling_speech_to_text._prepare_decoder_inputs` and modify to your needs. See diagram 1 in `the + paper `__ for more information on the default strategy. + head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the heas is **masked**. + + decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`): + Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: + :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, + `optional`) is a sequence of hidden-states at the output of the last layer of the encoder. Used in the + cross-attention of the decoder. + past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding. + + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids`` of shape :obj:`(batch_size, sequence_length)`. + decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded + representation. If :obj:`past_key_values` is used, optionally only the last :obj:`decoder_inputs_embeds` + have to be input (see :obj:`past_key_values`). This is useful if you want more control over how to convert + :obj:`decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. + + If :obj:`decoder_input_ids` and :obj:`decoder_inputs_embeds` are both unset, :obj:`decoder_inputs_embeds` + takes the value of :obj:`inputs_embeds`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. +""" + + +class Speech2TextEncoder(Speech2TextPreTrainedModel): + """ + Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a + :class:`Speech2TextEncoderLayer`. + + Args: + config: Speech2TextConfig + embed_tokens (torch.nn.Embedding): output embedding + """ + + def __init__(self, config: Speech2TextConfig): + super().__init__(config) + + self.dropout = config.dropout + self.layerdrop = config.encoder_layerdrop + + embed_dim = config.d_model + self.padding_idx = config.pad_token_id + self.max_source_positions = config.max_source_positions + self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + + self.conv = Conv1dSubsampler(config) + + self.embed_positions = Speech2TextSinusoidalPositionalEmbedding( + self.max_source_positions, + embed_dim, + self.padding_idx, + ) + self.layers = nn.ModuleList([Speech2TextEncoderLayer(config) for _ in range(config.encoder_layers)]) + self.layer_norm = nn.LayerNorm(config.d_model) + + self.init_weights() + + def forward( + self, + input_features, + attention_mask=None, + head_mask=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + input_features (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length, feature_size)`): + Float values of fbank features extracted from the raw speech waveform. Raw speech waveform can be + obtained by loading a ``.flac`` or ``.wav`` audio file into an array of type :obj:`List[float]` or a + :obj:`numpy.ndarray`, *e.g.* via the soundfile library (``pip install soundfile``). To prepare the + array into :obj:`input_features`, the :class:`~transformers.Speech2TextTokenizer` should be used for + extracting the fbank features, padding and conversion into a tensor of type :obj:`torch.FloatTensor`. + See :meth:`~transformers.Speech2TextTokenizer.__call__` + attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing convolution and attention on padding token indices. Mask values selected in + ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the heas is **masked**. + + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under + returned tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors + for more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if attention_mask is not None: + attention_mask = self._get_subsampled_encoder_attn_mask(attention_mask) + + inputs_embeds = self.conv(input_features) + inputs_embeds = self.embed_scale * inputs_embeds + + if attention_mask is None: + padding_mask = torch.zeros_like(inputs_embeds, dtype=torch.long) + else: + padding_mask = attention_mask.ne(1).long() + embed_pos = self.embed_positions(padding_mask) + + hidden_states = inputs_embeds + embed_pos + hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + + # expand attention_mask + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + assert head_mask.size()[0] == ( + len(self.layers) + ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." + + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = random.uniform(0, 1) + if self.training and (dropout_probability < self.layerdrop): # skip the layer + layer_outputs = (None, None) + else: + if getattr(self.config, "gradient_checkpointing", False) and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(encoder_layer), + hidden_states, + attention_mask, + (head_mask[idx] if head_mask is not None else None), + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + hidden_states = self.layer_norm(hidden_states) + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +class Speech2TextDecoder(Speech2TextPreTrainedModel): + """ + Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a :class:`Speech2TextDecoderLayer` + + Args: + config: Speech2TextConfig + embed_tokens (torch.nn.Embedding): output embedding + """ + + def __init__(self, config: Speech2TextConfig): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.decoder_layerdrop + self.padding_idx = config.pad_token_id + self.max_target_positions = config.max_target_positions + self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + + self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx) + + self.embed_positions = Speech2TextSinusoidalPositionalEmbedding( + self.max_target_positions, + config.d_model, + self.padding_idx, + ) + self.layers = nn.ModuleList([Speech2TextDecoderLayer(config) for _ in range(config.decoder_layers)]) + self.layer_norm = nn.LayerNorm(config.d_model) + + self.init_weights() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length + ).to(self.device) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + + # Copied from transformers.models.mbart.modeling_mbart.MBartDecoder.forward with MBart->Speech2Text + def forward( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + head_mask=None, + encoder_head_mask=None, + past_key_values=None, + inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using :class:`~transformers.Speech2TextTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` + for details. + + `What are input IDs? <../glossary.html#input-ids>`__ + attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, encoder_sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + encoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, encoder_sequence_length)`, `optional`): + Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values + selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the heas is **masked**. + + encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention + on hidden heads. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the heas is **masked**. + + past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up + decoding. + + If :obj:`past_key_values` are used, the user can optionally input only the last + :obj:`decoder_input_ids` (those that don't have their past key value states given to this model) of + shape :obj:`(batch_size, 1)` instead of all :obj:`decoder_input_ids`` of shape :obj:`(batch_size, + sequence_length)`. + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded + representation. This is useful if you want more control over how to convert :obj:`input_ids` indices + into associated vectors than the model's internal embedding lookup matrix. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under + returned tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors + for more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + encoder_attention_mask = self._get_subsampled_encoder_attn_mask(encoder_attention_mask) + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) + + # embed positions + positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length) + + hidden_states = inputs_embeds + positions + hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + next_decoder_cache = () if use_cache else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + assert head_mask.size()[0] == ( + len(self.layers) + ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." + + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + dropout_probability = random.uniform(0, 1) + if self.training and (dropout_probability < self.layerdrop): + continue + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if getattr(self.config, "gradient_checkpointing", False) and self.training: + + if use_cache: + logger.warn( + "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..." + ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, use_cache) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + head_mask[idx] if head_mask is not None else None, + encoder_head_mask[idx] if encoder_head_mask is not None else None, + None, + ) + else: + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + hidden_states = self.layer_norm(hidden_states) + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + +@add_start_docstrings( + "The bare Speech2Text Model outputting raw hidden-states without any specific head on top.", + SPEECH_TO_TEXT_START_DOCSTRING, +) +class Speech2TextModel(Speech2TextPreTrainedModel): + def __init__(self, config: Speech2TextConfig): + super().__init__(config) + + self.encoder = Speech2TextEncoder(config) + self.decoder = Speech2TextDecoder(config) + + self.init_weights() + + def get_input_embeddings(self): + return self.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.decoder.embed_tokens = value + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + @add_start_docstrings_to_model_forward(SPEECH_TO_TEXT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + tokenizer_class=_TOKENIZER_FOR_DOC, + checkpoint="s2t_transformer_s", + output_type=Seq2SeqModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_features=None, + attention_mask=None, + decoder_input_ids=None, + decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, + encoder_outputs=None, + past_key_values=None, + decoder_inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_features, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + encoder_head_mask=head_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + "The Speech2Text Model with a language modeling head. Can be used for summarization.", + SPEECH_TO_TEXT_START_DOCSTRING, +) +class Speech2TextForConditionalGeneration(Speech2TextPreTrainedModel): + base_model_prefix = "model" + _keys_to_ignore_on_load_missing = [ + r"encoder\.version", + r"decoder\.version", + r"model.encoder.embed_positions.weights", + r"model.decoder.embed_positions.weights", + ] + _keys_to_ignore_on_save = [ + r"model.encoder.embed_positions.weights", + r"model.decoder.embed_positions.weights", + ] + + def __init__(self, config: Speech2TextConfig): + super().__init__(config) + self.model = Speech2TextModel(config) + self.lm_head = nn.Linear(config.d_model, self.config.vocab_size, bias=False) + + self.init_weights() + + def get_encoder(self): + return self.model.get_encoder() + + def get_decoder(self): + return self.model.get_decoder() + + def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding: + new_embeddings = super().resize_token_embeddings(new_num_tokens) + return new_embeddings + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + @add_start_docstrings_to_model_forward(SPEECH_TO_TEXT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_features=None, + attention_mask=None, + decoder_input_ids=None, + decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, + encoder_outputs=None, + past_key_values=None, + decoder_inputs_embeds=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the language modeling loss. Indices should either be in ``[0, ..., + config.vocab_size]`` or -100 (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are ignored + (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``. + + Returns: + + Example:: + + >>> import torch + >>> from transformers import Speech2TextProcessor, Speech2TextForConditionalGeneration + >>> from datasets import load_dataset + >>> import soundfile as sf + + >>> model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-small-librispeech-asr") + >>> processor = Speech2Textprocessor.from_pretrained("facebook/s2t-small-librispeech-asr") + + >>> def map_to_array(batch): + >>> speech, _ = sf.read(batch["file"]) + >>> batch["speech"] = speech + >>> return batch + + >>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation") + >>> ds = ds.map(map_to_array) + + >>> input_features = processor(ds["speech"][0], sampling_rate=16_000, return_tensors="pt").input_features # Batch size 1 + >>> generated_ids = model.generate(input_ids=input_features) + + >>> transcription = processor.batch_decode(generated_ids) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None: + if decoder_input_ids is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + outputs = self.model( + input_features, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + encoder_outputs=encoder_outputs, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + past_key_values=past_key_values, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + lm_logits = self.lm_head(outputs[0]) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (lm_logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return Seq2SeqLMOutput( + loss=loss, + logits=lm_logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + def prepare_inputs_for_generation( + self, + decoder_input_ids, + past=None, + attention_mask=None, + head_mask=None, + use_cache=None, + encoder_outputs=None, + **kwargs + ): + # cut decoder_input_ids if past is used + if past is not None: + decoder_input_ids = decoder_input_ids[:, -1:] + + return { + "encoder_outputs": encoder_outputs, + "past_key_values": past, + "decoder_input_ids": decoder_input_ids, + "attention_mask": attention_mask, + "head_mask": head_mask, + "use_cache": use_cache, # change this to avoid caching (presumably for debugging) + } + + @staticmethod + def _reorder_cache(past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) + return reordered_past diff --git a/src/transformers/models/speech_to_text/processing_speech_to_text.py b/src/transformers/models/speech_to_text/processing_speech_to_text.py new file mode 100644 index 00000000000000..af79e9c64ac924 --- /dev/null +++ b/src/transformers/models/speech_to_text/processing_speech_to_text.py @@ -0,0 +1,144 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Speech processor class for Speech2Text +""" +from contextlib import contextmanager + +from .feature_extraction_speech_to_text import Speech2TextFeatureExtractor +from .tokenization_speech_to_text import Speech2TextTokenizer + + +class Speech2TextProcessor: + r""" + Constructs a Speech2Text processor which wraps a Speech2Text feature extractor and a Speech2Text tokenizer into a + single processor. + + :class:`~transformers.Speech2TextProcessor` offers all the functionalities of + :class:`~transformers.Speech2TextFeatureExtractor` and :class:`~transformers.Speech2TextTokenizer`. See the + :meth:`~transformers.Speech2TextProcessor.__call__` and :meth:`~transformers.Speech2TextProcessor.decode` for more + information. + + Args: + feature_extractor (:obj:`Speech2TextFeatureExtractor`): + An instance of :class:`~transformers.Speech2TextFeatureExtractor`. The feature extractor is a required + input. + tokenizer (:obj:`Speech2TextTokenizer`): + An instance of :class:`~transformers.Speech2TextTokenizer`. The tokenizer is a required input. + """ + + def __init__(self, feature_extractor, tokenizer): + if not isinstance(feature_extractor, Speech2TextFeatureExtractor): + raise ValueError( + f"`feature_extractor` has to be of type {Speech2TextFeatureExtractor.__class__}, but is {type(feature_extractor)}" + ) + if not isinstance(tokenizer, Speech2TextTokenizer): + raise ValueError( + f"`tokenizer` has to be of type {Speech2TextTokenizer.__class__}, but is {type(tokenizer)}" + ) + + self.feature_extractor = feature_extractor + self.tokenizer = tokenizer + self.current_processor = self.feature_extractor + + def save_pretrained(self, save_directory): + """ + Save a Speech2Text feature extractor object and Speech2Text tokenizer object to the directory + ``save_directory``, so that it can be re-loaded using the + :func:`~transformers.Speech2TextProcessor.from_pretrained` class method. + + .. note:: + + This class method is simply calling :meth:`~transformers.PreTrainedFeatureExtractor.save_pretrained` and + :meth:`~transformers.tokenization_utils_base.PreTrainedTokenizer.save_pretrained`. Please refer to the + docstrings of the methods above for more information. + + Args: + save_directory (:obj:`str` or :obj:`os.PathLike`): + Directory where the feature extractor JSON file and the tokenizer files will be saved (directory will + be created if it does not exist). + """ + + self.feature_extractor.save_pretrained(save_directory) + self.tokenizer.save_pretrained(save_directory) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + r""" + Instantiate a :class:`~transformers.Speech2TextProcessor` from a pretrained Speech2Text processor. + + .. note:: + + This class method is simply calling Speech2TextFeatureExtractor's + :meth:`~transformers.PreTrainedFeatureExtractor.from_pretrained` and Speech2TextTokenizer's + :meth:`~transformers.tokenization_utils_base.PreTrainedTokenizer.from_pretrained`. Please refer to the + docstrings of the methods above for more information. + + Args: + pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`): + This can be either: + + - a string, the `model id` of a pretrained feature_extractor hosted inside a model repo on + huggingface.co. Valid model ids can be located at the root-level, like ``bert-base-uncased``, or + namespaced under a user or organization name, like ``dbmdz/bert-base-german-cased``. + - a path to a `directory` containing a feature extractor file saved using the + :meth:`~transformers.PreTrainedFeatureExtractor.save_pretrained` method, e.g., + ``./my_model_directory/``. + - a path or url to a saved feature extractor JSON `file`, e.g., + ``./my_model_directory/feature_extraction_config.json``. + **kwargs + Additional keyword arguments passed along to both :class:`~transformers.PreTrainedFeatureExtractor` and + :class:`~transformers.PreTrainedTokenizer` + """ + feature_extractor = Speech2TextFeatureExtractor.from_pretrained(pretrained_model_name_or_path, **kwargs) + tokenizer = Speech2TextTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs) + + return cls(feature_extractor=feature_extractor, tokenizer=tokenizer) + + def __call__(self, *args, **kwargs): + """ + When used in normal mode, this method forwards all its arguments to Speech2TextFeatureExtractor's + :meth:`~transformers.Speech2TextFeatureExtractor.__call__` and returns its output. If used in the context + :meth:`~transformers.Speech2TextProcessor.as_target_processor` this method forwards all its arguments to + Speech2TextTokenizer's :meth:`~transformers.Speech2TextTokenizer.__call__`. Please refer to the doctsring of + the above two methods for more information. + """ + return self.current_processor(*args, **kwargs) + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to Speech2TextTokenizer's + :meth:`~transformers.PreTrainedTokenizer.batch_decode`. Please refer to the docstring of this method for more + information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to Speech2TextTokenizer's + :meth:`~transformers.PreTrainedTokenizer.decode`. Please refer to the docstring of this method for more + information. + """ + return self.tokenizer.decode(*args, **kwargs) + + @contextmanager + def as_target_processor(self): + """ + Temporarily sets the tokenizer for processing the input. Useful for encoding the labels when fine-tuning + Speech2Text. + """ + self.current_processor = self.tokenizer + yield + self.current_processor = self.feature_extractor diff --git a/src/transformers/models/speech_to_text/tokenization_speech_to_text.py b/src/transformers/models/speech_to_text/tokenization_speech_to_text.py new file mode 100644 index 00000000000000..bf3402295aa337 --- /dev/null +++ b/src/transformers/models/speech_to_text/tokenization_speech_to_text.py @@ -0,0 +1,259 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for Speech2Text.""" + +import json +from pathlib import Path +from shutil import copyfile +from typing import Dict, List, Optional, Tuple, Union + +import sentencepiece + +from ...tokenization_utils import PreTrainedTokenizer +from ...utils import logging + + +logger = logging.get_logger(__name__) + +SPIECE_UNDERLINE = "▁" + +VOCAB_FILES_NAMES = { + "vocab_file": "vocab.json", + "spm_file": "sentencepiece.bpe.model", +} + +PRETRAINED_VOCAB_FILES_MAP = { + "vocab_file": { + "facebook/s2t-small-librispeech-asr": "https://huggingface.co/facebook/s2t-small-librispeech-asr/resolve/main/vocab.json", + }, + "spm_file": { + "facebook/s2t-small-librispeech-asr": "https://huggingface.co/facebook/s2t-small-librispeech-asr/resolve/main/sentencepiece.bpe.model" + }, +} + +MAX_MODEL_INPUT_SIZES = { + "facebook/s2t-small-librispeech-asr": 1024, +} + +MUSTC_LANGS = ["pt", "fr", "ru", "nl", "ro", "it", "es", "de"] + +LANGUAGES = {"mustc": MUSTC_LANGS} + + +class Speech2TextTokenizer(PreTrainedTokenizer): + """ + Construct an Speech2Text tokenizer. + + This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains some of the main methods. + Users should refer to the superclass for more information regarding such methods. + + Args: + vocab_file (:obj:`str`): + File containing the vocabulary. + spm_file (:obj:`str`): + Path to the `SentencePiece `__ model file + bos_token (:obj:`str`, `optional`, defaults to :obj:`""`): + The beginning of sentence token. + eos_token (:obj:`str`, `optional`, defaults to :obj:`""`): + The end of sentence token. + unk_token (:obj:`str`, `optional`, defaults to :obj:`""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + pad_token (:obj:`str`, `optional`, defaults to :obj:`""`): + The token used for padding, for example when batching sequences of different lengths. + do_upper_case (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to uppercase the output when decoding. + do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to lowercase the input when tokenizing. + tgt_lang (:obj:`str`, `optional`): + A string representing the target language. + **kwargs + Additional keyword arguments passed along to :class:`~transformers.PreTrainedTokenizer` + """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + max_model_input_sizes = MAX_MODEL_INPUT_SIZES + model_input_names = ["input_ids", "attention_mask"] + + prefix_tokens: List[int] = [] + + def __init__( + self, + vocab_file, + spm_file, + bos_token="", + eos_token="", + pad_token="", + unk_token="", + do_upper_case=False, + do_lower_case=False, + tgt_lang=None, + lang_codes=None, + **kwargs, + ): + super().__init__( + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + pad_token=pad_token, + do_upper_case=do_upper_case, + do_lower_case=do_lower_case, + tgt_lang=tgt_lang, + lang_codes=lang_codes, + **kwargs, + ) + self.do_upper_case = do_upper_case + self.do_lower_case = do_lower_case + + self.encoder = load_json(vocab_file) + self.decoder = {v: k for k, v in self.encoder.items()} + self.spm_file = spm_file + self.sp_model = load_spm(spm_file) + + if lang_codes is not None: + self.lang_codes = lang_codes + self.langs = LANGUAGES[lang_codes] + self.lang_tokens = [f"" for lang in self.langs] + self.lang_code_to_id = {lang: self.sp_model.PieceToId(f"") for lang in self.langs} + + self._additional_special_tokens = self.lang_tokens + self._tgt_lang = tgt_lang if tgt_lang is not None else self.langs[0] + + self.set_tgt_lang_special_tokens(self._tgt_lang) + else: + self.lang_code_to_id = {} + + @property + def vocab_size(self) -> int: + return len(self.encoder) + + @property + def tgt_lang(self) -> str: + return self._tgt_lang + + @tgt_lang.setter + def tgt_lang(self, new_tgt_lang) -> None: + self._tgt_lang = new_tgt_lang + self.set_tgt_lang_special_tokens(new_tgt_lang) + + def set_tgt_lang_special_tokens(self, tgt_lang: str) -> None: + """Reset the special tokens to the target language setting. prefix=[eos, tgt_lang_code] and suffix=[eos].""" + lang_code_id = self.lang_code_to_id[tgt_lang] + self.prefix_tokens = [lang_code_id] + + def _tokenize(self, text: str) -> List[str]: + return self.sp_model.EncodeAsPieces(text) + + def _convert_token_to_id(self, token): + return self.encoder.get(token, self.encoder[self.unk_token]) + + def _convert_id_to_token(self, index: int) -> str: + """Converts an index (integer) in a token (str) using the decoder.""" + return self.decoder.get(index, self.unk_token) + + def convert_tokens_to_string(self, tokens: List[str]) -> str: + """Converts a sequence of tokens (strings for sub-words) in a single string.""" + out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip() + + if self.do_upper_case: + out_string = out_string.upper() + return out_string + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]: + """Build model inputs from a sequence by appending eos_token_id.""" + if token_ids_1 is None: + return self.prefix_tokens + token_ids_0 + [self.eos_token_id] + # We don't expect to process pairs, but leave the pair logic for API consistency + return self.prefix_tokens + token_ids_0 + token_ids_1 + [self.eos_token_id] + + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer ``prepare_for_model`` method. + + Args: + token_ids_0 (:obj:`List[int]`): + List of IDs. + token_ids_1 (:obj:`List[int]`, `optional`): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + + if already_has_special_tokens: + if token_ids_1 is not None: + raise ValueError( + "You should not supply a second sequence if the provided sequence of " + "ids is already formatted with special tokens for the model." + ) + return list(map(lambda x: 1 if x in [self.bos_token_id, self.eos_token_id] else 0, token_ids_0)) + prefix_ones = [1] * len(self.prefix_tokens) + suffix_ones = [1] + if token_ids_1 is None: + return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones + return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones + + def get_vocab(self) -> Dict: + vocab = self.encoder.copy() + vocab.update(self.added_tokens_encoder) + return vocab + + def __getstate__(self) -> Dict: + state = self.__dict__.copy() + state["sp_model"] = None + return state + + def __setstate__(self, d: Dict) -> None: + self.__dict__ = d + self.sp_model = load_spm(self.spm_file) + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + save_dir = Path(save_directory) + assert save_dir.is_dir(), f"{save_directory} should be a directory" + vocab_save_path = save_dir / ( + (filename_prefix + "-" if filename_prefix else "") + self.vocab_files_names["vocab_file"] + ) + spm_save_path = save_dir / ( + (filename_prefix + "-" if filename_prefix else "") + self.vocab_files_names["spm_file"] + ) + + save_json(self.encoder, vocab_save_path) + + if not spm_save_path.exists(): + copyfile(self.spm_file, spm_save_path) + + return (str(vocab_save_path), str(spm_save_path)) + + +def load_spm(path: str) -> sentencepiece.SentencePieceProcessor: + spm = sentencepiece.SentencePieceProcessor() + spm.Load(str(path)) + return spm + + +def load_json(path: str) -> Union[Dict, List]: + with open(path, "r") as f: + return json.load(f) + + +def save_json(data, path: str) -> None: + with open(path, "w") as f: + json.dump(data, f, indent=2) diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 10a67953cf5323..13838fab406dea 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -38,6 +38,7 @@ is_tokenizers_available, is_torch_available, is_torch_tpu_available, + is_torchaudio_available, ) from .integrations import is_optuna_available, is_ray_available @@ -195,6 +196,19 @@ def require_torch_scatter(test_case): return test_case +def require_torchaudio(test_case): + """ + Decorator marking a test that requires torchaudio. + + These tests are skipped when torchaudio isn't installed. + + """ + if not is_torchaudio_available: + return unittest.skip("test requires torchaudio")(test_case) + else: + return test_case + + def require_tf(test_case): """ Decorator marking a test that requires TensorFlow. diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index fb782d65059022..d5ddcd2e3c769c 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -2160,6 +2160,27 @@ def from_pretrained(self, *args, **kwargs): requires_pytorch(self) +SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST = None + + +class Speech2TextForConditionalGeneration: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_pytorch(self) + + +class Speech2TextModel: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_pytorch(self) + + SQUEEZEBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None diff --git a/src/transformers/utils/dummy_sentencepiece_objects.py b/src/transformers/utils/dummy_sentencepiece_objects.py index 4c3c3c2abd99e2..d9611dd2513685 100644 --- a/src/transformers/utils/dummy_sentencepiece_objects.py +++ b/src/transformers/utils/dummy_sentencepiece_objects.py @@ -92,6 +92,20 @@ def from_pretrained(self, *args, **kwargs): requires_sentencepiece(self) +class Speech2TextProcessor: + def __init__(self, *args, **kwargs): + requires_sentencepiece(self) + + +class Speech2TextTokenizer: + def __init__(self, *args, **kwargs): + requires_sentencepiece(self) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_sentencepiece(self) + + class T5Tokenizer: def __init__(self, *args, **kwargs): requires_sentencepiece(self) diff --git a/tests/test_feature_extraction_speech_to_text.py b/tests/test_feature_extraction_speech_to_text.py new file mode 100644 index 00000000000000..5cd2f67f457d5f --- /dev/null +++ b/tests/test_feature_extraction_speech_to_text.py @@ -0,0 +1,146 @@ +# coding=utf-8 +# Copyright 2021 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import itertools +import random +import unittest + +import numpy as np + +from transformers import Speech2TextFeatureExtractor +from transformers.testing_utils import require_torch, require_torchaudio + +from .test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin + + +global_rng = random.Random() + + +def floats_list(shape, scale=1.0, rng=None, name=None): + """Creates a random float32 tensor""" + if rng is None: + rng = global_rng + + values = [] + for batch_idx in range(shape[0]): + values.append([]) + for _ in range(shape[1]): + values[-1].append(rng.random() * scale) + + return values + + +@require_torch +@require_torchaudio +class Speech2TextFeatureExtractionTester(unittest.TestCase): + def __init__( + self, + parent, + batch_size=7, + min_seq_length=400, + max_seq_length=2000, + feature_size=24, + num_mel_bins=24, + padding_value=0.0, + sampling_rate=16_000, + return_attention_mask=True, + do_normalize=True, + ): + self.parent = parent + self.batch_size = batch_size + self.min_seq_length = min_seq_length + self.max_seq_length = max_seq_length + self.seq_length_diff = (self.max_seq_length - self.min_seq_length) // (self.batch_size - 1) + self.feature_size = feature_size + self.num_mel_bins = num_mel_bins + self.padding_value = padding_value + self.sampling_rate = sampling_rate + self.return_attention_mask = return_attention_mask + self.do_normalize = do_normalize + + def prepare_feat_extract_dict(self): + return { + "feature_size": self.feature_size, + "num_mel_bins": self.num_mel_bins, + "padding_value": self.padding_value, + "sampling_rate": self.sampling_rate, + "return_attention_mask": self.return_attention_mask, + "do_normalize": self.do_normalize, + } + + def prepare_inputs_for_common(self, equal_length=False, numpify=False): + def _flatten(list_of_lists): + return list(itertools.chain(*list_of_lists)) + + if equal_length: + speech_inputs = [floats_list((self.max_seq_length, self.feature_size)) for _ in range(self.batch_size)] + else: + speech_inputs = [ + floats_list((x, self.feature_size)) + for x in range(self.min_seq_length, self.max_seq_length, self.seq_length_diff) + ] + if numpify: + speech_inputs = [np.asarray(x) for x in speech_inputs] + return speech_inputs + + +@require_torch +@require_torchaudio +class Speech2TextFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase): + + feature_extraction_class = Speech2TextFeatureExtractor + + def setUp(self): + self.feat_extract_tester = Speech2TextFeatureExtractionTester(self) + + def test_call(self): + # Tests that all call wrap to encode_plus and batch_encode_plus + feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict()) + # create three inputs of length 800, 1000, and 1200 + speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)] + np_speech_inputs = [np.asarray(speech_input) for speech_input in speech_inputs] + + # Test feature size + input_features = feature_extractor(np_speech_inputs, padding=True, return_tensors="np").input_features + self.assertTrue(input_features.ndim == 3) + self.assertTrue(input_features.shape[-1] == feature_extractor.feature_size) + + # Test not batched input + encoded_sequences_1 = feature_extractor(speech_inputs[0], return_tensors="np").input_features + encoded_sequences_2 = feature_extractor(np_speech_inputs[0], return_tensors="np").input_features + self.assertTrue(np.allclose(encoded_sequences_1, encoded_sequences_2, atol=1e-3)) + + # Test batched + encoded_sequences_1 = feature_extractor(speech_inputs, return_tensors="np").input_features + encoded_sequences_2 = feature_extractor(np_speech_inputs, return_tensors="np").input_features + for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2): + self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3)) + + def test_cepstral_mean_and_variance_normalization(self): + feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict()) + speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)] + inputs = feature_extractor(speech_inputs, padding=True, return_tensors="np", return_attention_mask=True) + input_features = inputs.input_features + attention_mask = inputs.attention_mask + fbank_feat_lengths = np.sum(attention_mask == 1, axis=1) + + def _check_zero_mean_unit_variance(input_vector): + self.assertTrue(np.all(np.mean(input_vector, axis=0) < 1e-3)) + self.assertTrue(np.all(np.abs(np.var(input_vector, axis=0) - 1) < 1e-3)) + + _check_zero_mean_unit_variance(input_features[0, : fbank_feat_lengths[0]]) + _check_zero_mean_unit_variance(input_features[1, : fbank_feat_lengths[1]]) + _check_zero_mean_unit_variance(input_features[2, : fbank_feat_lengths[2]]) diff --git a/tests/test_generation_utils.py b/tests/test_generation_utils.py index 2c9669306940cc..77a2abeed3d6b7 100644 --- a/tests/test_generation_utils.py +++ b/tests/test_generation_utils.py @@ -53,12 +53,13 @@ class GenerationTesterMixin: model_tester = None all_generative_model_classes = () + input_name = "input_ids" def _get_input_ids_and_config(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - input_ids = inputs_dict["input_ids"] - attention_mask = torch.ones_like(input_ids) + input_ids = inputs_dict[self.input_name] + attention_mask = torch.ones_like(input_ids, dtype=torch.long) # cut to half length & take max batch_size 3 max_batch_size = 2 diff --git a/tests/test_modeling_speech_to_text.py b/tests/test_modeling_speech_to_text.py new file mode 100644 index 00000000000000..c5b7db53c85498 --- /dev/null +++ b/tests/test_modeling_speech_to_text.py @@ -0,0 +1,754 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Testing suite for the PyTorch Speech2Text model. """ + + +import copy +import inspect +import os +import tempfile +import unittest + +from transformers.file_utils import cached_property +from transformers.testing_utils import ( + is_torch_available, + require_sentencepiece, + require_tokenizers, + require_torch, + require_torchaudio, + slow, + torch_device, +) + +from .test_configuration_common import ConfigTester +from .test_generation_utils import GenerationTesterMixin +from .test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor + + +if is_torch_available(): + import torch + + from transformers import ( + Speech2TextConfig, + Speech2TextForConditionalGeneration, + Speech2TextModel, + Speech2TextProcessor, + ) + from transformers.models.speech_to_text.modeling_speech_to_text import Speech2TextDecoder, Speech2TextEncoder + + +def prepare_speech_to_text_inputs_dict( + config, + input_features, + decoder_input_ids, + attention_mask=None, + decoder_attention_mask=None, +): + if attention_mask is None: + attention_mask = input_features.ne(0) + if decoder_attention_mask is None: + decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id) + return { + # "input_ids": input_features, + "input_features": input_features, + "decoder_input_ids": decoder_input_ids, + "attention_mask": attention_mask, + "decoder_attention_mask": attention_mask, + } + + +@require_torch +class Speech2TextModelTester: + def __init__( + self, + parent, + batch_size=13, + seq_length=7, + is_training=True, + use_labels=False, + vocab_size=99, + hidden_size=16, + num_hidden_layers=2, + num_attention_heads=4, + intermediate_size=4, + num_conv_layers=2, + conv_kernel_sizes=(5, 5), + conv_channels=32, + input_feat_per_channel=24, + input_channels=1, + hidden_act="relu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=20, + max_source_positions=20, + max_target_positions=20, + eos_token_id=2, + pad_token_id=1, + bos_token_id=0, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.num_conv_layers = num_conv_layers + self.conv_kernel_sizes = conv_kernel_sizes + self.conv_channels = conv_channels + self.input_feat_per_channel = input_feat_per_channel + self.input_channels = input_channels + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.max_source_positions = max_source_positions + self.max_target_positions = max_target_positions + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + + def prepare_config_and_inputs(self): + input_features = floats_tensor( + [self.batch_size, self.seq_length, self.input_feat_per_channel], self.vocab_size + ) + attention_mask = torch.ones([self.batch_size, self.seq_length], dtype=torch.long, device=torch_device) + decoder_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(2) + + config = Speech2TextConfig( + vocab_size=self.vocab_size, + d_model=self.hidden_size, + encoder_layers=self.num_hidden_layers, + decoder_layers=self.num_hidden_layers, + encoder_attention_heads=self.num_attention_heads, + decoder_attention_heads=self.num_attention_heads, + encoder_ffn_dim=self.intermediate_size, + decoder_ffn_dim=self.intermediate_size, + num_conv_layers=self.num_conv_layers, + conv_kernel_sizes=self.conv_kernel_sizes, + conv_channels=self.conv_channels, + input_feat_per_channel=self.input_feat_per_channel, + input_channels=self.input_channels, + dropout=self.hidden_dropout_prob, + attention_dropout=self.attention_probs_dropout_prob, + max_position_embeddings=self.max_position_embeddings, + max_source_positions=self.max_source_positions, + max_target_positions=self.max_target_positions, + eos_token_id=self.eos_token_id, + bos_token_id=self.bos_token_id, + pad_token_id=self.pad_token_id, + ) + inputs_dict = prepare_speech_to_text_inputs_dict( + config, + input_features=input_features, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + ) + return config, inputs_dict + + def prepare_config_and_inputs_for_common(self): + config, inputs_dict = self.prepare_config_and_inputs() + return config, inputs_dict + + def get_subsampled_output_lengths(self, input_lengths): + """ + Computes the output length of the convolutional layers + """ + + for i in range(self.num_conv_layers): + input_lengths = (input_lengths - 1) // 2 + 1 + + return input_lengths + + def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict): + model = Speech2TextModel(config=config).get_decoder().to(torch_device).eval() + input_ids = inputs_dict["decoder_input_ids"] + attention_mask = inputs_dict["decoder_attention_mask"] + + # first forward pass + outputs = model(input_ids, attention_mask=attention_mask, use_cache=True) + + output, past_key_values = outputs.to_tuple() + + # create hypothetical multiple next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size).clamp(2) + next_attn_mask = ids_tensor((self.batch_size, 3), 2) + + # append to next input_ids and + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + next_attention_mask = torch.cat([attention_mask, next_attn_mask], dim=-1) + + output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)["last_hidden_state"] + output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)[ + "last_hidden_state" + ] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, :, random_slice_idx].detach() + + self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1]) + + # test that outputs are equal for slice + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-2)) + + def check_encoder_decoder_model_standalone(self, config, inputs_dict): + model = Speech2TextModel(config=config).to(torch_device).eval() + outputs = model(**inputs_dict) + + encoder_last_hidden_state = outputs.encoder_last_hidden_state + last_hidden_state = outputs.last_hidden_state + + with tempfile.TemporaryDirectory() as tmpdirname: + encoder = model.get_encoder() + encoder.save_pretrained(tmpdirname) + encoder = Speech2TextEncoder.from_pretrained(tmpdirname).to(torch_device) + + encoder_last_hidden_state_2 = encoder( + inputs_dict["input_features"], attention_mask=inputs_dict["attention_mask"] + )[0] + + self.parent.assertTrue((encoder_last_hidden_state_2 - encoder_last_hidden_state).abs().max().item() < 1e-3) + + with tempfile.TemporaryDirectory() as tmpdirname: + decoder = model.get_decoder() + decoder.save_pretrained(tmpdirname) + decoder = Speech2TextDecoder.from_pretrained(tmpdirname).to(torch_device) + + last_hidden_state_2 = decoder( + input_ids=inputs_dict["decoder_input_ids"], + attention_mask=inputs_dict["decoder_attention_mask"], + encoder_hidden_states=encoder_last_hidden_state, + encoder_attention_mask=inputs_dict["attention_mask"], + )[0] + + self.parent.assertTrue((last_hidden_state_2 - last_hidden_state).abs().max().item() < 1e-3) + + +@require_torch +class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): + all_model_classes = (Speech2TextModel, Speech2TextForConditionalGeneration) if is_torch_available() else () + all_generative_model_classes = (Speech2TextForConditionalGeneration,) if is_torch_available() else () + is_encoder_decoder = True + test_pruning = False + test_head_masking = False + test_missing_keys = False + test_torchscript = True + + input_name = "input_features" + + def setUp(self): + self.model_tester = Speech2TextModelTester(self) + self.config_tester = ConfigTester(self, config_class=Speech2TextConfig) + self.maxDiff = 3000 + + def test_config(self): + self.config_tester.run_common_tests() + + def test_save_load_strict(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs() + for model_class in self.all_model_classes: + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) + self.assertEqual(info["missing_keys"], []) + + def test_decoder_model_past_with_large_inputs(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs) + + def test_encoder_decoder_model_standalone(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common() + self.model_tester.check_encoder_decoder_model_standalone(*config_and_inputs) + + def test_inputs_embeds(self): + pass + + # training is not supported yet + def test_training(self): + pass + + def test_training_gradient_checkpointing(self): + pass + + def test_generate_fp16(self): + config, input_dict = self.model_tester.prepare_config_and_inputs() + input_features = input_dict["input_features"] + attention_mask = input_dict["attention_mask"] + model = Speech2TextForConditionalGeneration(config).eval().to(torch_device) + if torch_device == "cuda": + input_features = input_features.half() + model.half() + model.generate(input_features, attention_mask=attention_mask) + model.generate(input_features, num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3) + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.forward) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + expected_arg_names = [ + "input_features", + "attention_mask", + "decoder_input_ids", + "decoder_attention_mask", + ] + expected_arg_names.extend( + ["head_mask", "decoder_head_mask", "encoder_outputs"] + if "head_mask" and "decoder_head_mask" in arg_names + else ["encoder_outputs"] + ) + self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names) + + def test_hidden_states_output(self): + def check_hidden_states_output(inputs_dict, config, model_class): + model = model_class(config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states + + expected_num_layers = getattr( + self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1 + ) + self.assertEqual(len(hidden_states), expected_num_layers) + + if hasattr(self.model_tester, "encoder_seq_length"): + seq_length = self.model_tester.encoder_seq_length + else: + seq_length = self.model_tester.seq_length + + subsampled_seq_length = model._get_subsampled_output_lengths(seq_length) + + self.assertListEqual( + list(hidden_states[0].shape[-2:]), + [subsampled_seq_length, self.model_tester.hidden_size], + ) + + if config.is_encoder_decoder: + hidden_states = outputs.decoder_hidden_states + + self.assertIsInstance(hidden_states, (list, tuple)) + self.assertEqual(len(hidden_states), expected_num_layers) + seq_len = getattr(self.model_tester, "seq_length", None) + decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len) + + self.assertListEqual( + list(hidden_states[0].shape[-2:]), + [decoder_seq_length, self.model_tester.hidden_size], + ) + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + inputs_dict["output_hidden_states"] = True + check_hidden_states_output(inputs_dict, config, model_class) + + # check that output_hidden_states also work using config + del inputs_dict["output_hidden_states"] + config.output_hidden_states = True + + check_hidden_states_output(inputs_dict, config, model_class) + + def test_attention_outputs(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + + seq_len = getattr(self.model_tester, "seq_length", None) + decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len) + encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len) + decoder_key_length = getattr(self.model_tester, "decoder_key_length", decoder_seq_length) + encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length) + + for model_class in self.all_model_classes: + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = False + config.return_dict = True + model = model_class(config) + model.to(torch_device) + model.eval() + + subsampled_encoder_seq_length = model._get_subsampled_output_lengths(encoder_seq_length) + subsampled_encoder_key_length = model._get_subsampled_output_lengths(encoder_key_length) + + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions + self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) + + # check that output_attentions also work using config + del inputs_dict["output_attentions"] + config.output_attentions = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions + self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) + + self.assertListEqual( + list(attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, subsampled_encoder_seq_length, subsampled_encoder_key_length], + ) + out_len = len(outputs) + + correct_outlen = 5 + + # loss is at first position + if "labels" in inputs_dict: + correct_outlen += 1 # loss is added to beginning + if "past_key_values" in outputs: + correct_outlen += 1 # past_key_values have been returned + + self.assertEqual(out_len, correct_outlen) + + # decoder attentions + decoder_attentions = outputs.decoder_attentions + self.assertIsInstance(decoder_attentions, (list, tuple)) + self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers) + self.assertListEqual( + list(decoder_attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length], + ) + + # cross attentions + cross_attentions = outputs.cross_attentions + self.assertIsInstance(cross_attentions, (list, tuple)) + self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers) + self.assertListEqual( + list(cross_attentions[0].shape[-3:]), + [ + self.model_tester.num_attention_heads, + decoder_seq_length, + subsampled_encoder_key_length, + ], + ) + + # Check attention is always last and order is fine + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + added_hidden_states = 2 + self.assertEqual(out_len + added_hidden_states, len(outputs)) + + self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions + + self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers) + self.assertListEqual( + list(self_attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, subsampled_encoder_seq_length, subsampled_encoder_key_length], + ) + + def test_resize_tokens_embeddings(self): + ( + original_config, + inputs_dict, + ) = self.model_tester.prepare_config_and_inputs_for_common() + if not self.test_resize_embeddings: + return + + for model_class in self.all_model_classes: + config = copy.deepcopy(original_config) + model = model_class(config) + model.to(torch_device) + + if self.model_tester.is_training is False: + model.eval() + + model_vocab_size = config.vocab_size + # Retrieve the embeddings and clone theme + model_embed = model.resize_token_embeddings(model_vocab_size) + cloned_embeddings = model_embed.weight.clone() + + # Check that resizing the token embeddings with a larger vocab size increases the model's vocab size + model_embed = model.resize_token_embeddings(model_vocab_size + 10) + self.assertEqual(model.config.vocab_size, model_vocab_size + 10) + # Check that it actually resizes the embeddings matrix + self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] + 10) + # Check that the model can still do a forward pass successfully (every parameter should be resized) + model(**self._prepare_for_class(inputs_dict, model_class)) + + # Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size + model_embed = model.resize_token_embeddings(model_vocab_size - 15) + self.assertEqual(model.config.vocab_size, model_vocab_size - 15) + # Check that it actually resizes the embeddings matrix + self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] - 15) + + # make sure that decoder_input_ids are resized + if "decoder_input_ids" in inputs_dict: + inputs_dict["decoder_input_ids"].clamp_(max=model_vocab_size - 15 - 1) + model(**self._prepare_for_class(inputs_dict, model_class)) + + # Check that adding and removing tokens has not modified the first part of the embedding matrix. + models_equal = True + for p1, p2 in zip(cloned_embeddings, model_embed.weight): + if p1.data.ne(p2.data).sum() > 0: + models_equal = False + + self.assertTrue(models_equal) + + def test_resize_embeddings_untied(self): + ( + original_config, + inputs_dict, + ) = self.model_tester.prepare_config_and_inputs_for_common() + if not self.test_resize_embeddings: + return + + original_config.tie_word_embeddings = False + + # if model cannot untied embeddings -> leave test + if original_config.tie_word_embeddings: + return + + for model_class in self.all_model_classes: + config = copy.deepcopy(original_config) + model = model_class(config).to(torch_device) + + # if no output embeddings -> leave test + if model.get_output_embeddings() is None: + continue + + # Check that resizing the token embeddings with a larger vocab size increases the model's vocab size + model_vocab_size = config.vocab_size + model.resize_token_embeddings(model_vocab_size + 10) + self.assertEqual(model.config.vocab_size, model_vocab_size + 10) + output_embeds = model.get_output_embeddings() + self.assertEqual(output_embeds.weight.shape[0], model_vocab_size + 10) + # Check bias if present + if output_embeds.bias is not None: + self.assertEqual(output_embeds.bias.shape[0], model_vocab_size + 10) + # Check that the model can still do a forward pass successfully (every parameter should be resized) + model(**self._prepare_for_class(inputs_dict, model_class)) + + # Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size + model.resize_token_embeddings(model_vocab_size - 15) + self.assertEqual(model.config.vocab_size, model_vocab_size - 15) + # Check that it actually resizes the embeddings matrix + output_embeds = model.get_output_embeddings() + self.assertEqual(output_embeds.weight.shape[0], model_vocab_size - 15) + # Check bias if present + if output_embeds.bias is not None: + self.assertEqual(output_embeds.bias.shape[0], model_vocab_size - 15) + # Check that the model can still do a forward pass successfully (every parameter should be resized) + if "decoder_input_ids" in inputs_dict: + inputs_dict["decoder_input_ids"].clamp_(max=model_vocab_size - 15 - 1) + # Check that the model can still do a forward pass successfully (every parameter should be resized) + model(**self._prepare_for_class(inputs_dict, model_class)) + + def test_generate_without_input_ids(self): + pass + + @staticmethod + def _get_encoder_outputs( + model, input_ids, attention_mask, output_attentions=None, output_hidden_states=None, num_interleave=1 + ): + encoder = model.get_encoder() + encoder_outputs = encoder( + input_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.repeat_interleave( + num_interleave, dim=0 + ) + input_ids = input_ids[:, :, 0] + input_ids = torch.zeros_like(input_ids[:, :1], dtype=torch.long) + model._get_decoder_start_token_id() + attention_mask = None + return encoder_outputs, input_ids, attention_mask + + def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1): + batch_size, seq_length = input_ids.shape[:2] + subsampled_seq_length = self.model_tester.get_subsampled_output_lengths(seq_length) + num_sequences_in_output = batch_size * num_return_sequences + gen_len = ( + output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - seq_length + ) + + # scores + self._check_scores(num_sequences_in_output, output.scores, length=gen_len, config=config) + + # Attentions + # encoder + self._check_encoder_attention_for_generate( + output.encoder_attentions, batch_size, config, subsampled_seq_length + ) + # decoder + self._check_attentions_for_generate( + num_sequences_in_output, + output.decoder_attentions, + min_length=1, + max_length=output.sequences.shape[-1], + config=config, + use_cache=use_cache, + ) + + # Hidden States + # encoder + self._check_encoder_hidden_states_for_generate( + output.encoder_hidden_states, batch_size, config, subsampled_seq_length + ) + + # decoder + self._check_hidden_states_for_generate( + num_sequences_in_output, + output.decoder_hidden_states, + min_length=1, + max_length=output.sequences.shape[-1], + config=config, + use_cache=use_cache, + ) + + def _create_and_check_torchscript(self, config, inputs_dict): + if not self.test_torchscript: + return + + configs_no_init = _config_zero_init(config) # To be sure we have no Nan + configs_no_init.torchscript = True + for model_class in self.all_model_classes: + model = model_class(config=configs_no_init) + model.to(torch_device) + model.eval() + inputs = self._prepare_for_class(inputs_dict, model_class) + + try: + model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward + input_features = inputs["input_features"] + attention_mask = inputs["attention_mask"] + decoder_input_ids = inputs["decoder_input_ids"] + decoder_attention_mask = inputs["decoder_attention_mask"] + traced_model = torch.jit.trace( + model, (input_features, attention_mask, decoder_input_ids, decoder_attention_mask) + ) + except RuntimeError: + self.fail("Couldn't trace module.") + + with tempfile.TemporaryDirectory() as tmp_dir_name: + pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt") + + try: + torch.jit.save(traced_model, pt_file_name) + except Exception: + self.fail("Couldn't save module.") + + try: + loaded_model = torch.jit.load(pt_file_name) + except Exception: + self.fail("Couldn't load module.") + + model.to(torch_device) + model.eval() + + loaded_model.to(torch_device) + loaded_model.eval() + + model_state_dict = model.state_dict() + loaded_model_state_dict = loaded_model.state_dict() + + self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys())) + + models_equal = True + for layer_name, p1 in model_state_dict.items(): + p2 = loaded_model_state_dict[layer_name] + if p1.data.ne(p2.data).sum() > 0: + models_equal = False + + self.assertTrue(models_equal) + + +@require_torch +@require_torchaudio +@require_sentencepiece +@require_tokenizers +@slow +class Speech2TextModelIntegrationTests(unittest.TestCase): + @cached_property + def default_processor(self): + return Speech2TextProcessor.from_pretrained("facebook/s2t-small-librispeech-asr") + + def _load_datasamples(self, num_samples): + from datasets import load_dataset + + import soundfile as sf + + # map files to raw + def map_to_array(batch): + speech, _ = sf.read(batch["file"]) + batch["speech"] = speech + return batch + + ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation") + ds = ds.select(range(num_samples)).map(map_to_array) + + return ds["speech"][:num_samples] + + def test_generation_librispeech(self): + model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-small-librispeech-asr") + model.to(torch_device) + processor = self.default_processor + + input_speech = self._load_datasamples(1) + + input_features = processor(input_speech, return_tensors="pt").input_features.to(torch_device) + + generated_ids = model.generate(input_features) + generated_transcript = processor.batch_decode(generated_ids, skip_special_tokens=True) + + EXPECTED_TRANSCRIPTIONS = ["a man said to the universe sir i exist"] + self.assertListEqual(generated_transcript, EXPECTED_TRANSCRIPTIONS) + + def test_generation_librispeech_batched(self): + model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-small-librispeech-asr") + model.to(torch_device) + processor = self.default_processor + + input_speech = self._load_datasamples(4) + + inputs = processor(input_speech, return_tensors="pt", padding=True) + + input_features = inputs.input_features.to(torch_device) + attention_mask = inputs.attention_mask.to(torch_device) + + generated_ids = model.generate(input_features, attention_mask=attention_mask) + generated_transcripts = processor.batch_decode(generated_ids, skip_special_tokens=True) + + EXPECTED_TRANSCRIPTIONS = [ + "a man said to the universe sir i exist", + "sweat covered brion's body trickling into the titleing cloth that was the only garment he wore", + "the cut on his chest still dripping blood the ache of his overstrained eyes even the soaring arena around him with the thousands of spectators were trivialities not worth thinking about", + "his instant of panic was followed by a small sharp blow high on his chest", + ] + + self.assertListEqual(generated_transcripts, EXPECTED_TRANSCRIPTIONS) diff --git a/tests/test_processor_speech_to_text.py b/tests/test_processor_speech_to_text.py new file mode 100644 index 00000000000000..cf26e32c1db4bf --- /dev/null +++ b/tests/test_processor_speech_to_text.py @@ -0,0 +1,144 @@ +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import shutil +import tempfile +import unittest +from pathlib import Path +from shutil import copyfile + +from transformers import Speech2TextFeatureExtractor, Speech2TextProcessor, Speech2TextTokenizer +from transformers.file_utils import FEATURE_EXTRACTOR_NAME +from transformers.models.speech_to_text.tokenization_speech_to_text import VOCAB_FILES_NAMES, save_json +from transformers.testing_utils import require_sentencepiece, require_torch, require_torchaudio + +from .test_feature_extraction_speech_to_text import floats_list + + +SAMPLE_SP = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model") + + +@require_torch +@require_torchaudio +@require_sentencepiece +class Speech2TextProcessorTest(unittest.TestCase): + def setUp(self): + self.tmpdirname = tempfile.mkdtemp() + + vocab = ["", "", "", "", "▁This", "▁is", "▁a", "▁t", "est"] + vocab_tokens = dict(zip(vocab, range(len(vocab)))) + save_dir = Path(self.tmpdirname) + save_json(vocab_tokens, save_dir / VOCAB_FILES_NAMES["vocab_file"]) + if not (save_dir / VOCAB_FILES_NAMES["spm_file"]).exists(): + copyfile(SAMPLE_SP, save_dir / VOCAB_FILES_NAMES["spm_file"]) + + tokenizer = Speech2TextTokenizer.from_pretrained(self.tmpdirname) + tokenizer.save_pretrained(self.tmpdirname) + + feature_extractor_map = { + "feature_size": 24, + "num_mel_bins": 24, + "padding_value": 0.0, + "sampling_rate": 16000, + "return_attention_mask": False, + "do_normalize": True, + } + save_json(feature_extractor_map, save_dir / FEATURE_EXTRACTOR_NAME) + + def get_tokenizer(self, **kwargs): + return Speech2TextTokenizer.from_pretrained(self.tmpdirname, **kwargs) + + def get_feature_extractor(self, **kwargs): + return Speech2TextFeatureExtractor.from_pretrained(self.tmpdirname, **kwargs) + + def tearDown(self): + shutil.rmtree(self.tmpdirname) + + def test_save_load_pretrained_default(self): + tokenizer = self.get_tokenizer() + feature_extractor = self.get_feature_extractor() + + processor = Speech2TextProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor) + + processor.save_pretrained(self.tmpdirname) + processor = Speech2TextProcessor.from_pretrained(self.tmpdirname) + + self.assertEqual(processor.tokenizer.get_vocab(), tokenizer.get_vocab()) + self.assertIsInstance(processor.tokenizer, Speech2TextTokenizer) + + self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor.to_json_string()) + self.assertIsInstance(processor.feature_extractor, Speech2TextFeatureExtractor) + + def test_save_load_pretrained_additional_features(self): + processor = Speech2TextProcessor( + tokenizer=self.get_tokenizer(), feature_extractor=self.get_feature_extractor() + ) + processor.save_pretrained(self.tmpdirname) + + tokenizer_add_kwargs = self.get_tokenizer(bos_token="(BOS)", eos_token="(EOS)") + feature_extractor_add_kwargs = self.get_feature_extractor(do_normalize=False, padding_value=1.0) + + processor = Speech2TextProcessor.from_pretrained( + self.tmpdirname, bos_token="(BOS)", eos_token="(EOS)", do_normalize=False, padding_value=1.0 + ) + + self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab()) + self.assertIsInstance(processor.tokenizer, Speech2TextTokenizer) + + self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string()) + self.assertIsInstance(processor.feature_extractor, Speech2TextFeatureExtractor) + + def test_feature_extractor(self): + feature_extractor = self.get_feature_extractor() + tokenizer = self.get_tokenizer() + + processor = Speech2TextProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor) + + raw_speech = floats_list((3, 1000)) + + input_feat_extract = feature_extractor(raw_speech, return_tensors="np") + input_processor = processor(raw_speech, return_tensors="np") + + for key in input_feat_extract.keys(): + self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2) + + def test_tokenizer(self): + feature_extractor = self.get_feature_extractor() + tokenizer = self.get_tokenizer() + + processor = Speech2TextProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor) + + input_str = "This is a test string" + + with processor.as_target_processor(): + encoded_processor = processor(input_str) + + encoded_tok = tokenizer(input_str) + + for key in encoded_tok.keys(): + self.assertListEqual(encoded_tok[key], encoded_processor[key]) + + def test_tokenizer_decode(self): + feature_extractor = self.get_feature_extractor() + tokenizer = self.get_tokenizer() + + processor = Speech2TextProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor) + + predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]] + + decoded_processor = processor.batch_decode(predicted_ids) + decoded_tok = tokenizer.batch_decode(predicted_ids) + + self.assertListEqual(decoded_tok, decoded_processor) diff --git a/tests/test_sequence_feature_extraction_common.py b/tests/test_sequence_feature_extraction_common.py index 8c1777553ac6bd..f375e10e19fb64 100644 --- a/tests/test_sequence_feature_extraction_common.py +++ b/tests/test_sequence_feature_extraction_common.py @@ -222,7 +222,7 @@ def test_padding_accepts_tensors_pt(self): input_np = feat_extract.pad(processed_features, padding="longest", return_tensors="np")[input_name] input_pt = feat_extract.pad(processed_features, padding="longest", return_tensors="pt")[input_name] - self.assertTrue(abs(input_np.sum() - input_pt.numpy().sum()) < 1e-2) + self.assertTrue(abs(input_np.astype(np.float32).sum() - input_pt.numpy().sum()) < 1e-2) @require_tf def test_padding_accepts_tensors_tf(self): @@ -235,7 +235,7 @@ def test_padding_accepts_tensors_tf(self): input_np = feat_extract.pad(processed_features, padding="longest", return_tensors="np")[input_name] input_tf = feat_extract.pad(processed_features, padding="longest", return_tensors="tf")[input_name] - self.assertTrue(abs(input_np.sum() - input_tf.numpy().sum()) < 1e-2) + self.assertTrue(abs(input_np.astype(np.float32).sum() - input_tf.numpy().sum()) < 1e-2) def test_attention_mask(self): feat_dict = self.feat_extract_dict diff --git a/tests/test_tokenization_speech_to_text.py b/tests/test_tokenization_speech_to_text.py new file mode 100644 index 00000000000000..2a42b04a5059c4 --- /dev/null +++ b/tests/test_tokenization_speech_to_text.py @@ -0,0 +1,129 @@ +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import unittest +from pathlib import Path +from shutil import copyfile + +from transformers import SPIECE_UNDERLINE, is_sentencepiece_available +from transformers.models.speech_to_text import Speech2TextTokenizer +from transformers.models.speech_to_text.tokenization_speech_to_text import VOCAB_FILES_NAMES, save_json +from transformers.testing_utils import require_sentencepiece, require_tokenizers + +from .test_tokenization_common import TokenizerTesterMixin + + +SAMPLE_SP = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model") + +if is_sentencepiece_available(): + import sentencepiece as sp + + +FR_CODE = 5 +ES_CODE = 10 + + +@require_sentencepiece +@require_tokenizers +class SpeechToTextTokenizerTest(TokenizerTesterMixin, unittest.TestCase): + tokenizer_class = Speech2TextTokenizer + test_rust_tokenizer = False + + def setUp(self): + super().setUp() + + spm_model = sp.SentencePieceProcessor() + spm_model.Load(SAMPLE_SP) + vocab = ["", "", "", ""] + + vocab += [spm_model.IdToPiece(id_) for id_ in range(len(spm_model))] + vocab_tokens = dict(zip(vocab, range(len(vocab)))) + + save_dir = Path(self.tmpdirname) + save_json(vocab_tokens, save_dir / VOCAB_FILES_NAMES["vocab_file"]) + if not (save_dir / VOCAB_FILES_NAMES["spm_file"]).exists(): + copyfile(SAMPLE_SP, save_dir / VOCAB_FILES_NAMES["spm_file"]) + + tokenizer = Speech2TextTokenizer.from_pretrained(self.tmpdirname) + tokenizer.save_pretrained(self.tmpdirname) + + def test_full_tokenizer(self): + tokenizer = Speech2TextTokenizer.from_pretrained(self.tmpdirname) + + tokens = tokenizer.tokenize("This is a test") + self.assertListEqual(tokens, ["▁This", "▁is", "▁a", "▁t", "est"]) + + self.assertListEqual( + tokenizer.convert_tokens_to_ids(tokens), + [289, 50, 14, 174, 386], + ) + + tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.") + self.assertListEqual( + tokens, + # fmt: off + [SPIECE_UNDERLINE + "I", SPIECE_UNDERLINE + "was", SPIECE_UNDERLINE + "b", "or", "n", SPIECE_UNDERLINE + "in", SPIECE_UNDERLINE + "", "9", "2", "0", "0", "0", ",", SPIECE_UNDERLINE + "and", SPIECE_UNDERLINE + "this", SPIECE_UNDERLINE + "is", SPIECE_UNDERLINE + "f", "al", "s", "é", "."], + # fmt: on + ) + ids = tokenizer.convert_tokens_to_ids(tokens) + self.assertListEqual(ids, [12, 25, 88, 59, 28, 23, 11, 4, 606, 351, 351, 351, 7, 16, 70, 50, 76, 84, 10, 4, 8]) + + back_tokens = tokenizer.convert_ids_to_tokens(ids) + self.assertListEqual( + back_tokens, + # fmt: off + [SPIECE_UNDERLINE + "I", SPIECE_UNDERLINE + "was", SPIECE_UNDERLINE + "b", "or", "n", SPIECE_UNDERLINE + "in", SPIECE_UNDERLINE + "", "", "2", "0", "0", "0", ",", SPIECE_UNDERLINE + "and", SPIECE_UNDERLINE + "this", SPIECE_UNDERLINE + "is", SPIECE_UNDERLINE + "f", "al", "s", "", "."], + # fmt: on + ) + + +@require_sentencepiece +class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase): + checkpoint_name = "valhalla/s2t_mustc_multilinguial_medium" + + french_text = "C'est trop cool" + spanish_text = "Esto es genial" + + @classmethod + def setUpClass(cls): + cls.tokenizer: Speech2TextTokenizer = Speech2TextTokenizer.from_pretrained(cls.checkpoint_name) + return cls + + def check_language_codes(self): + self.assertEqual(self.tokenizer.lang_code_to_id["pt"], 4) + self.assertEqual(self.tokenizer.lang_code_to_id["ru"], 6) + self.assertEqual(self.tokenizer.lang_code_to_id["it"], 9) + self.assertEqual(self.tokenizer.lang_code_to_id["de"], 11) + + def test_tokenizer_decode_ignores_language_codes(self): + self.assertIn(ES_CODE, self.tokenizer.all_special_ids) + generated_ids = [ES_CODE, 4, 1601, 47, 7647, 2] + result = self.tokenizer.decode(generated_ids, skip_special_tokens=True) + expected_spanish = self.tokenizer.decode(generated_ids[1:], skip_special_tokens=True) + self.assertEqual(result, expected_spanish) + self.assertNotIn(self.tokenizer.eos_token, result) + + def test_tokenizer_adds_special_tokens(self): + self.tokenizer.tgt_lang = "fr" + encoded = self.tokenizer(self.french_text).input_ids + self.assertEqual(encoded[0], FR_CODE) + self.assertEqual(encoded[-1], self.tokenizer.eos_token_id) + + def test_tgt_lang_setter(self): + self.tokenizer.tgt_lang = "fr" + self.assertListEqual(self.tokenizer.prefix_tokens, [FR_CODE]) + + self.tokenizer.tgt_lang = "es" + self.assertListEqual(self.tokenizer.prefix_tokens, [ES_CODE]) diff --git a/utils/check_repo.py b/utils/check_repo.py index afcc4cbd73fcbe..b64f5ae2c761b8 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -32,6 +32,8 @@ # models to ignore for not tested "M2M100Encoder", # Building part of bigger (tested) model. "M2M100Decoder", # Building part of bigger (tested) model. + "Speech2TextEncoder", # Building part of bigger (tested) model. + "Speech2TextDecoder", # Building part of bigger (tested) model. "LEDEncoder", # Building part of bigger (tested) model. "LEDDecoder", # Building part of bigger (tested) model. "BartDecoderWrapper", # Building part of bigger (tested) model. @@ -79,6 +81,8 @@ # models to ignore for model xxx mapping "M2M100Encoder", "M2M100Decoder", + "Speech2TextEncoder", + "Speech2TextDecoder", "LEDEncoder", "LEDDecoder", "BartDecoder",