Fix issues for Whisper export with beam search #15619

Merged

Conversation

@kunal-vaishnavi (Contributor) commented Apr 21, 2023

Description

This PR fixes an issue with calling the ORT transformer optimizer script on the custom export of Whisper with beam search. It also includes a fix for a GPU out-of-memory issue.

Motivation and Context

With this fix, the optimizer runs as described in the Whisper model optimization PR.

@kunal-vaishnavi merged commit 3de33e0 into microsoft:main on Apr 21, 2023
@shub-kris commented May 4, 2023

Hi @kunal-vaishnavi, how can one load the exported ONNX models into Hugging Face Optimum?

I ran the commands mentioned here: https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/python/tools/transformers/models/whisper#exporting-whisper-with-beam-search

$ git clone https://github.com/microsoft/onnxruntime
$ cd onnxruntime/onnxruntime/python/tools/transformers/models/whisper
$ python3 convert_to_onnx.py -m openai/whisper-tiny --output whispertiny --use_external_data_format

which led to these files:

├── whisper-tiny_beamsearch.onnx
├── whisper-tiny_beamsearch.onnx.data
├── whisper-tiny_decoder.onnx
├── whisper-tiny_decoder.onnx.data
├── whisper-tiny_encoder_decoder_init.onnx
└── whisper-tiny_encoder_decoder_init.onnx.data

I am running into an error while running:

model = ORTModelForSpeechSeq2Seq.from_pretrained(
    "whispertiny/openai",
    use_io_binding=(device == 'cuda'),
    provider='CPUExecutionProvider',
).to(device)

The error is:


---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[12], line 1
----> 1 model = ORTModelForSpeechSeq2Seq.from_pretrained(
      2     "whispertiny/openai",
      3     use_io_binding=(device == 'cuda'),
      4     provider='CPUExecutionProvider',
      5 ).to(device)

File /opt/conda/envs/py-39/lib/python3.9/site-packages/optimum/onnxruntime/modeling_ort.py:646, in ORTModel.from_pretrained(cls, model_id, export, force_download, use_auth_token, cache_dir, subfolder, config, local_files_only, provider, session_options, provider_options, **kwargs)
    602 @classmethod
    603 @add_start_docstrings(FROM_PRETRAINED_START_DOCSTRING)
    604 def from_pretrained(
   (...)
    617     **kwargs,
    618 ):
    619     """
    620     provider (`str`, defaults to `"CPUExecutionProvider"`):
    621         ONNX Runtime provider to use for loading the model. See https://onnxruntime.ai/docs/execution-providers/ for
   (...)
    644         `ORTModel`: The loaded ORTModel model.
    645     """
--> 646     return super().from_pretrained(
    647         model_id,
    648         export=export,
    649         force_download=force_download,
    650         use_auth_token=use_auth_token,
    651         cache_dir=cache_dir,
    652         subfolder=subfolder,
    653         config=config,
    654         local_files_only=local_files_only,
    655         provider=provider,
    656         session_options=session_options,
    657         provider_options=provider_options,
    658         **kwargs,
    659     )

File /opt/conda/envs/py-39/lib/python3.9/site-packages/optimum/modeling_base.py:362, in OptimizedModel.from_pretrained(cls, model_id, export, force_download, use_auth_token, cache_dir, subfolder, config, local_files_only, trust_remote_code, revision, **kwargs)
    359     trust_remote_code = False
    361 from_pretrained_method = cls._from_transformers if export else cls._from_pretrained
--> 362 return from_pretrained_method(
    363     model_id=model_id,
    364     config=config,
    365     revision=revision,
    366     cache_dir=cache_dir,
    367     force_download=force_download,
    368     use_auth_token=use_auth_token,
    369     subfolder=subfolder,
    370     local_files_only=local_files_only,
    371     trust_remote_code=trust_remote_code,
    372     **kwargs,
    373 )

File /opt/conda/envs/py-39/lib/python3.9/site-packages/optimum/onnxruntime/modeling_seq2seq.py:1188, in ORTModelForSpeechSeq2Seq._from_pretrained(cls, model_id, config, **kwargs)
   1180 @classmethod
   1181 def _from_pretrained(
   1182     cls,
   (...)
   1185     **kwargs,
   1186 ):
   1187     if "WhisperForConditionalGeneration" in config.architectures:
-> 1188         return _ORTModelForWhisper._from_pretrained(model_id, config, **kwargs)
   1189     else:
   1190         return super()._from_pretrained(model_id, config, **kwargs)

File /opt/conda/envs/py-39/lib/python3.9/site-packages/optimum/onnxruntime/modeling_seq2seq.py:1222, in _ORTModelForWhisper._from_pretrained(cls, model_id, config, **kwargs)
   1215 @classmethod
   1216 def _from_pretrained(
   1217     cls,
   (...)
   1220     **kwargs,
   1221 ):
-> 1222     return super(ORTModelForSpeechSeq2Seq, cls)._from_pretrained(model_id, config, **kwargs)

File /opt/conda/envs/py-39/lib/python3.9/site-packages/optimum/onnxruntime/modeling_seq2seq.py:686, in ORTModelForConditionalGeneration._from_pretrained(cls, model_id, config, use_auth_token, revision, force_download, cache_dir, encoder_file_name, decoder_file_name, decoder_with_past_file_name, subfolder, local_files_only, use_cache, use_merged, provider, session_options, provider_options, use_io_binding, model_save_dir, **kwargs)
    684 if use_merged is False:
    685     if not validate_file_exists(model_id, decoder_file_name, subfolder=subfolder, revision=revision):
--> 686         decoder_without_past_path = ORTModelForConditionalGeneration.infer_onnx_filename(
    687             model_id,
    688             [DECODER_ONNX_FILE_PATTERN],
    689             "decoder_file_name",
    690             subfolder=subfolder,
    691             use_auth_token=use_auth_token,
    692             revision=revision,
    693         )
    694     else:
    695         decoder_without_past_path = model_path / subfolder / decoder_file_name

File /opt/conda/envs/py-39/lib/python3.9/site-packages/optimum/onnxruntime/modeling_ort.py:437, in ORTModel.infer_onnx_filename(model_name_or_path, patterns, argument_name, subfolder, use_auth_token, revision, fail_if_not_found)
    435 elif len(onnx_files) > 1:
    436     if argument_name is not None:
--> 437         raise RuntimeError(
    438             f"Too many ONNX model files were found in {path}, specify which one to load by using the "
    439             f"{argument_name} argument."
    440         )
    441 return onnx_files[0]

RuntimeError: Too many ONNX model files were found in whispertiny/openai, specify which one to load by using the decoder_file_name argument.

I added config.json to whispertiny/openai.

@kunal-vaishnavi (Contributor, Author) commented
Thank you for specifying the steps you followed.

├── whisper-tiny_beamsearch.onnx
├── whisper-tiny_beamsearch.onnx.data
├── whisper-tiny_decoder.onnx
├── whisper-tiny_decoder.onnx.data
├── whisper-tiny_encoder_decoder_init.onnx
└── whisper-tiny_encoder_decoder_init.onnx.data

Optimum exports Whisper as three models: encoder, decoder, and decoder with past. This custom export creates one combined model, whisper-tiny_beamsearch.onnx, with two subgraphs, whisper-tiny_encoder_decoder_init.onnx and whisper-tiny_decoder.onnx, that are chained together with beam search logic. The combined model can then run in ONNX Runtime directly.
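For example, a minimal sketch of loading the combined model directly with ONNX Runtime (assuming the file names above, with the .onnx.data weights kept alongside the model file):

import onnxruntime as ort

# The beam search model embeds the encoder/decoder subgraphs, so only this file needs to be loaded.
sess = ort.InferenceSession("whisper-tiny_beamsearch.onnx", providers=["CPUExecutionProvider"])
print([i.name for i in sess.get_inputs()])  # inspect the expected generation inputs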

model = ORTModelForSpeechSeq2Seq.from_pretrained(
    "whispertiny/openai",
    use_io_binding=(device == 'cuda'),
    provider='CPUExecutionProvider',
).to(device)

Optimum expects separate encoder, decoder, and decoder-with-past models, as well as several JSON files, to be located within whispertiny/openai. This custom export produces one combined model, so it will not load in Optimum.
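For reference, an Optimum export of Whisper typically produces a layout like the following (exact file names can vary across Optimum versions):

├── config.json
├── encoder_model.onnx
├── decoder_model.onnx
└── decoder_with_past_model.onnx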

Optimum + ONNX Runtime

You can export Whisper using Optimum and save the generated models.

from optimum.onnxruntime import ORTModelForSpeechSeq2Seq

# from_transformers=True exports the PyTorch model to ONNX (newer Optimum versions use export=True)
model = ORTModelForSpeechSeq2Seq.from_pretrained('openai/whisper-tiny', from_transformers=True)
model.save_pretrained('whisper_tiny')

Then you can optimize and run Whisper with Optimum + ONNX Runtime by following the example in the PR linked above.
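As a rough sketch (assuming the whisper_tiny directory saved above), unoptimized inference with the Optimum export can look like:

from datasets import load_dataset
from transformers import AutoProcessor
from optimum.onnxruntime import ORTModelForSpeechSeq2Seq

processor = AutoProcessor.from_pretrained("openai/whisper-tiny")
model = ORTModelForSpeechSeq2Seq.from_pretrained("whisper_tiny")

# Transcribe one sample from a small test set
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
inputs = processor(ds[0]["audio"]["array"], sampling_rate=16000, return_tensors="pt")
predicted_ids = model.generate(inputs.input_features)
print(processor.batch_decode(predicted_ids, skip_special_tokens=True))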

@shub-kris commented May 5, 2023

Thanks @kunal-vaishnavi, I'm giving it a try.

Can you please provide an example of how one can use whisper-tiny_beamsearch.onnx for inference?

@shub-kris commented
I am trying this; it works, but I am confused about the attention_mask input:

from datasets import load_dataset
from transformers import AutoProcessor
import onnxruntime as ort
import numpy as np

model_path, model_id = "whisper-tiny_beamsearch.onnx", "openai/whisper-tiny"
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
processor = AutoProcessor.from_pretrained(model_id)
# InferenceSession takes a list of execution providers via the `providers` keyword
ort_sess = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])

## warmup
for _ in range(10):
    inputs = processor.feature_extractor(ds[0]["audio"]["array"], return_tensors="np",
                                         sampling_rate=16000,
                                         return_attention_mask=True)
    # Broadcast the (batch, frames) mask to (batch, mel_bins, frames) to match input_features
    attention_mask = np.tile(inputs['attention_mask'], (1, inputs['input_features'].shape[1], 1)).astype(np.int32)

    # Scalar generation parameters are passed as 1-element numpy arrays with the declared dtypes
    outputs = ort_sess.run(None, {'input_features': inputs['input_features'],
                                  'max_length': np.array([50], dtype=np.int32),
                                  'num_beams': np.array([1], dtype=np.int32),
                                  'min_length': np.array([50], dtype=np.int32),
                                  'num_return_sequences': np.array([1], dtype=np.int32),
                                  'length_penalty': np.array([1.0], dtype=np.float32),
                                  'repetition_penalty': np.array([1.0], dtype=np.float32),
                                  'attention_mask': attention_mask})


@kunal-vaishnavi (Contributor, Author) commented
I am trying this; it works, but I am confused about the attention_mask input

We are planning to remove the attention_mask input since it is not used in Whisper. You can set any values for now until the attention_mask input is removed.
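For example, a placeholder mask is enough (assuming, as in your snippet, that it must match the shape of input_features and be int32):

# The model ignores this input, so any values work; zeros make the intent clear.
attention_mask = np.zeros_like(inputs['input_features'], dtype=np.int32)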

In your example above, you can decode the generated sequences with Hugging Face's processor after obtaining the outputs from ORT.

# `inputs` is the feed dictionary from your snippet above
outputs = ort_sess.run(None, inputs)[0]  # shape: (batch_size, num_return_sequences, sequence_length)

batch_size, num_return_sequences = outputs.shape[:2]
decoded = []
for b in range(batch_size):
    for r in range(num_return_sequences):
        # Each entry is one generated token-id sequence; decode it to text
        decoded.append(processor.decode(outputs[b][r], skip_special_tokens=True))

print(decoded)

@shub-kris commented
Thanks a lot @kunal-vaishnavi for your help.
