Export of Openai Whisper with batched prompts #19854

Merged (9 commits) on Apr 3, 2024
Changes from 3 commits
@@ -506,7 +506,13 @@ def main(argv=None):
         # Wrap parity check in try-except to allow export to continue in case this produces an error
         try:
             with torch.no_grad():
-                max_diff = WhisperHelper.verify_onnx(args.model_name_or_path, cache_dir, ort_session, device)
+                # Verify batched decoding with prompts for whisper openai implementation
+                if args.model_impl == "openai" and args.use_forced_decoder_ids:
+                    max_diff = WhisperHelper.verify_onnx(
+                        args.model_name_or_path, cache_dir, ort_session, device, batch_size=2, prompt_mode=True
+                    )
+                else:
+                    max_diff = WhisperHelper.verify_onnx(args.model_name_or_path, cache_dir, ort_session, device)
             if max_diff > 1e-4:
                 logger.warning("PyTorch and ONNX Runtime results are NOT close")
             else:
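For orientation (not part of the diff): a minimal sketch of driving the new verification path directly, assuming an already-exported Whisper beam-search model. The model id, ONNX path, cache directory, and import path are placeholders; the verify_onnx signature matches the one added in whisper_helper.py further down.

import torch
from onnxruntime import InferenceSession
from whisper_helper import WhisperHelper  # assumed import path; module sits next to the Whisper export script

# Placeholder paths and model id, for illustration only.
ort_session = InferenceSession("whisper_onnx/whisper-tiny_beamsearch.onnx", providers=["CPUExecutionProvider"])
device = torch.device("cpu")

# batch_size=2 with prompt_mode=True exercises the new batched-prompt parity check,
# mirroring the `args.model_impl == "openai" and args.use_forced_decoder_ids` branch above.
max_diff = WhisperHelper.verify_onnx(
    "openai/whisper-tiny",
    "./cache_dir",
    ort_session,
    device,
    batch_size=2,
    prompt_mode=True,
)
print("max diff vs PyTorch:", max_diff)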
@@ -126,6 +126,7 @@ def create_dummy(
         device: torch.device,
         float16: bool = False,
         use_int32_inputs: bool = False,
+        model_impl: str = "hf",
     ):  # -> WhisperDecoderInputs:
         """Create dummy inputs for WhisperDecoder.

@@ -170,7 +171,7 @@ def create_dummy(
         cross_attention_past_shape = [
             batch_size,
             num_attention_heads,
-            encode_sequence_length,
+            encode_sequence_length if model_impl == "hf" else past_decode_sequence_length,
             head_size,
         ]

@@ -228,6 +229,7 @@ def export_onnx(
             past_decode_sequence_length=6 if isinstance(decoder, WhisperDecoder) else 0,
             device=device,
             use_int32_inputs=use_int32_inputs,
+            model_impl=decoder.model_impl,
         )
         input_list = inputs.to_list()

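A small illustration (not from the PR; the dimension values are invented) of what the cross_attention_past_shape change in create_dummy means: the dummy cross-attention KV cache is sized by the encoder sequence length for the HF implementation, but by the past decode sequence length for the OpenAI implementation.

# Illustrative only; the numbers are made up.
batch_size, num_attention_heads, head_size = 2, 8, 64
encode_sequence_length, past_decode_sequence_length = 1500, 6

def cross_attention_past_shape(model_impl: str) -> list:
    # Mirrors the updated create_dummy logic above.
    return [
        batch_size,
        num_attention_heads,
        encode_sequence_length if model_impl == "hf" else past_decode_sequence_length,
        head_size,
    ]

print(cross_attention_past_shape("hf"))      # [2, 8, 1500, 64]
print(cross_attention_past_shape("openai"))  # [2, 8, 6, 64]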
onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py (137 additions, 47 deletions)
@@ -9,6 +9,7 @@
 from pathlib import Path
 from typing import Dict, Tuple, Union

+import datasets
 import numpy as np
 import torch
 from float16 import float_to_float16_max_diff
@@ -313,12 +314,103 @@ def optimize_onnx(

         m.save_model_to_file(optimized_model_path, use_external_data_format, all_tensors_to_one_file=True)

+    @staticmethod
+    def pt_transcription_for_verify_onnx(
+        ds: Union[datasets.DatasetDict, datasets.Dataset, datasets.IterableDatasetDict, datasets.IterableDataset],
+        processor: WhisperProcessor,
+        pt_model: torch.nn.Module,
+        device: torch.device,
+        batch_size: int = 1,
+        prompt_mode: bool = False,
+    ):
+        input_features_ = []
+        if batch_size == 1:
+            input_features = processor([ds[0]["audio"]["array"]], return_tensors="pt").input_features
+        else:
+            input_features_ = [
+                processor([ds[3]["audio"]["array"]], return_tensors="pt").input_features,
+                processor([ds[3]["audio"]["array"]], return_tensors="pt").input_features,
+            ]
+            assert len(input_features_) == batch_size
+            input_features = torch.cat((input_features_[0], input_features_[1])).to(device)
+
+        max_length, min_length, num_beams, num_return_sequences = 30, 0, 1, 1
+        length_penalty, repetition_penalty = 1.0, 1.0
+        inputs = {
+            "input_features": input_features.to(device),
+            "max_length": max_length,
+            "min_length": min_length,
+            "num_beams": num_beams,
+            "num_return_sequences": num_return_sequences,
+            "length_penalty": length_penalty,
+            "repetition_penalty": repetition_penalty,
+            "early_stopping": True,
+            "use_cache": True,
+        }
+
+        if prompt_mode:
+            prompts = ["John has doubts", "Maria has grave doubts"]
+            prompt_ids = [processor.get_prompt_ids(p) for p in prompts]
+            pt_transcription = []
+            pt_outputs = []
+            for i in range(batch_size):
+                inputs["prompt_ids"] = torch.from_numpy(prompt_ids[i])
+                inputs["input_features"] = input_features_[i].to(device)
+                pt_output = pt_model.generate(**inputs).detach().cpu().numpy()
+                pt_outputs.append(pt_output)
+                pt_transcription.append(processor.batch_decode(pt_output, skip_special_tokens=True)[0])
+            inputs["input_features"] = input_features
+            del inputs["prompt_ids"]
+        else:
+            prompt_ids = []
+            pt_outputs = pt_model.generate(**inputs).detach().cpu().numpy()
+            pt_transcription = [processor.batch_decode(pt_outputs, skip_special_tokens=True)[0]]
+            pt_outputs = list(pt_outputs)
+        del inputs["early_stopping"]
+        del inputs["use_cache"]
+        return inputs, pt_transcription, pt_outputs, prompt_ids
+
+    @staticmethod
+    def select_transcription_options(
+        batch_size: int,
+        prompt_mode: bool,
+    ):
+        if batch_size > 1 and prompt_mode is True:
+            expected_transcription_no_comma_prompt1 = " John has doubts whether Sir Frederick Layton's work is really Greek after all and can discover in it but little of Rocky I"
+            expected_transcription_misspelled_prompt1 = " John has doubts whether Sir Frederick Latins work is really Greek after all and can discover in it but little of Rocky I"
+            expected_transcription_no_comma_prompt2 = " Maria has grave doubts whether Sir Frederick Layton's work is really Greek after all and can discover in it but little of Rocky"
+            expected_transcription_misspelled_prompt2 = " Maria has grave doubts whether Sir Frederick Latins work is really Greek after all and can discover in it but little of Rocky I"
+            expected_transcription_options = {
+                expected_transcription_no_comma_prompt1,
+                expected_transcription_no_comma_prompt2,
+                expected_transcription_misspelled_prompt1,
+                expected_transcription_misspelled_prompt2,
+            }
+        else:
+            expected_transcription_no_comma = (
+                " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel."
+            )
+            expected_transcription_with_comma = (
+                " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."
+            )
+            expected_transcription_with_quote_and_comma = (
+                ' "Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
+            )
+            expected_transcription_options = {
+                expected_transcription_no_comma,
+                expected_transcription_with_comma,
+                expected_transcription_with_quote_and_comma,
+            }
+        return expected_transcription_options
+
     @staticmethod
     def verify_onnx(
         model_name_or_path: str,
         cache_dir: str,
         ort_session: InferenceSession,
         device: torch.device,
+        batch_size: int = 1,
+        prompt_mode: bool = False,
     ):
         """Compare the result from PyTorch and ONNX Runtime to verify the ONNX model is good."""
         extra_kwargs = {}
@@ -342,30 +434,20 @@ def verify_onnx(
         from datasets import load_dataset

         ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
-        input_features = processor([ds[0]["audio"]["array"]], return_tensors="pt").input_features
+        inputs, pt_transcription, pt_outputs, decoder_prompt_ids = WhisperHelper.pt_transcription_for_verify_onnx(
+            ds,
+            processor,
+            pt_model,
+            device,
+            batch_size=batch_size,
+            prompt_mode=prompt_mode,
+        )

         start_id = [config.decoder_start_token_id]  # ex: [50258]
         prompt_ids = processor.get_decoder_prompt_ids(language="english", task="transcribe")
         prompt_ids = list(map(lambda token: token[1], prompt_ids))  # ex: [50259, 50358, 50363]
         forced_decoder_ids = start_id + prompt_ids  # ex: [50258, 50259, 50358, 50363]

-        batch_size, max_length, min_length, num_beams, num_return_sequences = 1, 30, 0, 1, 1
-        length_penalty, repetition_penalty = 1.0, 1.0
-        inputs = {
-            "input_features": input_features.to(device),
-            "max_length": max_length,
-            "min_length": min_length,
-            "num_beams": num_beams,
-            "num_return_sequences": num_return_sequences,
-            "length_penalty": length_penalty,
-            "repetition_penalty": repetition_penalty,
-            "early_stopping": True,
-            "use_cache": True,
-        }
-        pt_outputs = pt_model.generate(**inputs).detach().cpu().numpy()
-
-        del inputs["early_stopping"]
-        del inputs["use_cache"]
         ort_names = list(map(lambda entry: entry.name, ort_session.get_inputs()))
         ort_dtypes = list(map(lambda entry: entry.type, ort_session.get_inputs()))
         ort_to_np = {
@@ -386,8 +468,24 @@
             elif name == "prefix_vocab_mask":
                 inputs[name] = np.ones((batch_size, config.vocab_size), dtype=ort_to_np[dtype])
             elif name == "decoder_input_ids":
-                raw_input_ids = [start_id] if use_extra_decoding_ids else [forced_decoder_ids]
-                inputs[name] = np.array(raw_input_ids, dtype=ort_to_np[dtype])
+                if not prompt_mode:
+                    raw_input_ids = [start_id] if use_extra_decoding_ids else [forced_decoder_ids]
+                    inputs[name] = np.array(raw_input_ids, dtype=ort_to_np[dtype])
+                else:
+                    # This logic handles the scenario for when prompts are not of the same size
+                    # For example if our prompt ids are [p1_id_1, p1_id_2] and [p2_id_1]
+                    # The final decoder_input_ids will look as such after padding
+                    # [prev_token, p1_id_1, p1_id_2, start_token, lang_token, transcribe_token]
+                    # [prev_token, p2_id_1, PAD_TOKEN, start_token, lang_token, transcribe_token]
+                    ort_prompts = []
+                    for i in range(batch_size):
+                        ort_prompts.append(decoder_prompt_ids[i].tolist())
+                    max_len = max(len(p) for p in ort_prompts)
+                    padded_prompts = []
+                    for p in ort_prompts:
+                        padded_prompt = [*p, *([config.pad_token_id] * (max_len - len(p)))]
+                        padded_prompts.append(padded_prompt + forced_decoder_ids)
+                    inputs[name] = np.array(padded_prompts, dtype=ort_to_np[dtype])
             elif name == "logits_processor":
                 inputs[name] = np.array([1], dtype=ort_to_np[dtype])
             elif name == "cross_qk_layer_head":
@@ -398,36 +496,28 @@
                 inputs[name] = np.array([1.0], dtype=ort_to_np[dtype])
             else:
                 inputs[name] = np.array([inputs[name]], dtype=ort_to_np[dtype])
-        ort_outputs = ort_session.run(None, inputs)[0][0]
-
-        expected_transcription_no_comma = (
-            " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel."
-        )
-        expected_transcription_with_comma = (
-            " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."
-        )
-        expected_transcription_with_quote_and_comma = (
-            ' "Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
-        )
-        expected_transcription_options = {
-            expected_transcription_no_comma,
-            expected_transcription_with_comma,
-            expected_transcription_with_quote_and_comma,
-        }
-        pt_transcription = processor.batch_decode(pt_outputs, skip_special_tokens=True)[0]
-        ort_transcription = processor.batch_decode(ort_outputs, skip_special_tokens=True)[0]
-
-        parity = (
-            pt_transcription in expected_transcription_options and ort_transcription in expected_transcription_options
-        )
+        ort_outputs = ort_session.run(None, inputs)[0]
+        ort_transcription = []
+        for o in ort_outputs:
+            ort_transcription.append(processor.batch_decode(o, skip_special_tokens=True)[0])
+        expected_transcription_options = WhisperHelper.select_transcription_options(batch_size, prompt_mode)

+        parity = 1
+        for i in range(batch_size):
+            parity *= (
+                pt_transcription[i] in expected_transcription_options
+                and ort_transcription[i] in expected_transcription_options
+            )
         max_diff = 0

         if not parity:
-            if pt_outputs.shape != ort_outputs.shape:
-                diff = pt_outputs - ort_outputs[:, : len(pt_outputs[0])]
-            else:
-                diff = pt_outputs - ort_outputs
-            max_diff = max(diff.min(), diff.max(), key=abs)
+            for i in range(batch_size):
+                if pt_outputs[i].shape != ort_outputs[i].shape:
+                    diff = pt_outputs[i] - ort_outputs[i][:, : len(pt_outputs[i])]
+                else:
+                    diff = pt_outputs[i] - ort_outputs[i]
+                max_diff_i = max(diff.min(), diff.max(), key=abs)
+                max_diff = max(max_diff, max_diff_i)

         if max_diff != 0:
             logger.warning(f"PyTorch outputs: {pt_transcription}")
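To make the decoder_input_ids padding comment concrete, here is a stand-alone sketch with invented prompt ids; forced_decoder_ids reuses the example token values noted in the verify_onnx comments, and 50257 as the pad token id is an assumption for illustration only.

import numpy as np

forced_decoder_ids = [50258, 50259, 50358, 50363]  # example values from the verify_onnx comments
pad_token_id = 50257                               # stand-in for config.pad_token_id (assumed)
decoder_prompt_ids = [[101, 102], [201]]           # two prompts of different lengths (made-up ids)

# Right-pad each prompt to the longest prompt, then append the forced decoder ids.
max_len = max(len(p) for p in decoder_prompt_ids)
padded_prompts = []
for p in decoder_prompt_ids:
    padded_prompt = [*p, *([pad_token_id] * (max_len - len(p)))]
    padded_prompts.append(padded_prompt + forced_decoder_ids)

decoder_input_ids = np.array(padded_prompts, dtype=np.int32)
# Row 0: [101, 102, 50258, 50259, 50358, 50363]
# Row 1: [201, 50257, 50258, 50259, 50358, 50363]
print(decoder_input_ids)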