Export of Openai Whisper with batched prompts #19854

Merged
9 commits merged on Apr 3, 2024
Changes from 1 commit
Modularize verify_onnx
shubhambhokare1 committed Mar 15, 2024
commit b34a04af6f79598ee8968c5ac0d82155c55de7ce
@@ -508,8 +508,8 @@ def main(argv=None):
with torch.no_grad():
# Verify batched decoding with prompts for whisper openai implementation
if args.model_impl == "openai" and args.use_forced_decoder_ids:
max_diff = WhisperHelper.verify_onnx_multi_batch(
args.model_name_or_path, cache_dir, ort_session, device
max_diff = WhisperHelper.verify_onnx(
args.model_name_or_path, cache_dir, ort_session, device, batch_size=2, prompt_mode=True
)
else:
max_diff = WhisperHelper.verify_onnx(args.model_name_or_path, cache_dir, ort_session, device)
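To make the new entry point concrete, here is a minimal sketch (not part of the diff) of how main() in convert_to_onnx.py drives verification after this change. The `args`, `cache_dir`, `ort_session`, and `device` objects are assumed to be prepared earlier in the script, and the import path assumes the helper sits alongside the script as `whisper_helper.py`.

```python
import torch

from whisper_helper import WhisperHelper  # assumed local import, as in the Whisper tooling scripts


def run_parity_check(args, cache_dir, ort_session, device):
    """Sketch of the verification dispatch after this PR."""
    with torch.no_grad():
        if args.model_impl == "openai" and args.use_forced_decoder_ids:
            # Batched decoding with prompts now goes through the unified helper,
            # replacing the separate verify_onnx_multi_batch path.
            max_diff = WhisperHelper.verify_onnx(
                args.model_name_or_path, cache_dir, ort_session, device, batch_size=2, prompt_mode=True
            )
        else:
            # Default single-sample verification keeps the original call shape.
            max_diff = WhisperHelper.verify_onnx(args.model_name_or_path, cache_dir, ort_session, device)
    return max_diff
```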
272 changes: 108 additions & 164 deletions onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py
@@ -9,6 +9,7 @@
from pathlib import Path
from typing import Dict, Tuple, Union

import datasets
import numpy as np
import torch
from float16 import float_to_float16_max_diff
@@ -314,42 +315,26 @@ def optimize_onnx(
m.save_model_to_file(optimized_model_path, use_external_data_format, all_tensors_to_one_file=True)

@staticmethod
def verify_onnx(
model_name_or_path: str,
cache_dir: str,
ort_session: InferenceSession,
def pt_transcription_for_verify_onnx(
ds: Union[datasets.DatasetDict, datasets.Dataset, datasets.IterableDatasetDict, datasets.IterableDataset],
processor: WhisperProcessor,
pt_model: torch.nn.Module,
device: torch.device,
batch_size: int = 1,
prompt_mode: bool = False,
):
"""Compare the result from PyTorch and ONNX Runtime to verify the ONNX model is good."""
extra_kwargs = {}
if version.parse(transformers_version) >= version.parse("4.36.0"):
extra_kwargs["attn_implementation"] = "eager"
pt_model = WhisperForConditionalGeneration.from_pretrained(
model_name_or_path, cache_dir=cache_dir, **extra_kwargs
).to(device)
processor = WhisperProcessor.from_pretrained(model_name_or_path)
config = WhisperConfig.from_pretrained(model_name_or_path)

# Try to import `datasets` pip package
try:
from datasets import load_dataset
except Exception as e:
logger.error(f"An error occurred while importing `datasets`: {e}", exc_info=True)
install_cmd = "pip install datasets"
logger.warning(f"Could not import `datasets`. Attempting to install `datasets` via `{install_cmd}`.")
os.system(install_cmd)

from datasets import load_dataset

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
input_features = processor([ds[0]["audio"]["array"]], return_tensors="pt").input_features

start_id = [config.decoder_start_token_id] # ex: [50258]
prompt_ids = processor.get_decoder_prompt_ids(language="english", task="transcribe")
prompt_ids = list(map(lambda token: token[1], prompt_ids)) # ex: [50259, 50358, 50363]
forced_decoder_ids = start_id + prompt_ids # ex: [50258, 50259, 50358, 50363]

batch_size, max_length, min_length, num_beams, num_return_sequences = 1, 30, 0, 1, 1
input_features_ = []
if batch_size == 1:
input_features = processor([ds[0]["audio"]["array"]], return_tensors="pt").input_features
else:
input_features_ = [
processor([ds[3]["audio"]["array"]], return_tensors="pt").input_features,
processor([ds[3]["audio"]["array"]], return_tensors="pt").input_features,
]
assert len(input_features_) == batch_size
input_features = torch.cat((input_features_[0], input_features_[1])).to(device)

max_length, min_length, num_beams, num_return_sequences = 30, 0, 1, 1
length_penalty, repetition_penalty = 1.0, 1.0
inputs = {
"input_features": input_features.to(device),
@@ -362,85 +347,70 @@ def verify_onnx(
"early_stopping": True,
"use_cache": True,
}
pt_outputs = pt_model.generate(**inputs).detach().cpu().numpy()

if prompt_mode:
prompts = ["John has doubts", "Maria has grave doubts"]
prompt_ids = [processor.get_prompt_ids(p) for p in prompts]
pt_transcription = []
pt_outputs = []
for i in range(batch_size):
inputs["prompt_ids"] = torch.from_numpy(prompt_ids[i])
inputs["input_features"] = input_features_[i].to(device)
pt_output = pt_model.generate(**inputs).detach().cpu().numpy()
pt_outputs.append(pt_output)
pt_transcription.append(processor.batch_decode(pt_output, skip_special_tokens=True)[0])
inputs["input_features"] = input_features
del inputs["prompt_ids"]
else:
prompt_ids = []
pt_outputs = pt_model.generate(**inputs).detach().cpu().numpy()
pt_transcription = [processor.batch_decode(pt_outputs, skip_special_tokens=True)[0]]
pt_outputs = list(pt_outputs)
del inputs["early_stopping"]
del inputs["use_cache"]
ort_names = list(map(lambda entry: entry.name, ort_session.get_inputs()))
ort_dtypes = list(map(lambda entry: entry.type, ort_session.get_inputs()))
ort_to_np = {
"tensor(float)": np.float32,
"tensor(float16)": np.float16,
"tensor(int64)": np.int64,
"tensor(int32)": np.int32,
"tensor(int8)": np.int8,
"tensor(uint8)": np.uint8,
}

use_extra_decoding_ids = "extra_decoding_ids" in ort_names
for name, dtype in zip(ort_names, ort_dtypes):
if name == "input_features":
inputs[name] = inputs[name].detach().cpu().numpy()
elif name == "vocab_mask":
inputs[name] = np.ones(config.vocab_size, dtype=ort_to_np[dtype])
elif name == "prefix_vocab_mask":
inputs[name] = np.ones((batch_size, config.vocab_size), dtype=ort_to_np[dtype])
elif name == "decoder_input_ids":
raw_input_ids = [start_id] if use_extra_decoding_ids else [forced_decoder_ids]
inputs[name] = np.array(raw_input_ids, dtype=ort_to_np[dtype])
elif name == "logits_processor":
inputs[name] = np.array([1], dtype=ort_to_np[dtype])
elif name == "cross_qk_layer_head":
inputs[name] = np.array([[0, 0]], dtype=ort_to_np[dtype])
elif name == "extra_decoding_ids":
inputs[name] = np.repeat(np.array([prompt_ids], dtype=ort_to_np[dtype]), batch_size, 0)
elif name == "temperature":
inputs[name] = np.array([1.0], dtype=ort_to_np[dtype])
else:
inputs[name] = np.array([inputs[name]], dtype=ort_to_np[dtype])
ort_outputs = ort_session.run(None, inputs)[0][0]

expected_transcription_no_comma = (
" Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel."
)
expected_transcription_with_comma = (
" Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."
)
expected_transcription_with_quote_and_comma = (
' "Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
)
expected_transcription_options = {
expected_transcription_no_comma,
expected_transcription_with_comma,
expected_transcription_with_quote_and_comma,
}
pt_transcription = processor.batch_decode(pt_outputs, skip_special_tokens=True)[0]
ort_transcription = processor.batch_decode(ort_outputs, skip_special_tokens=True)[0]
return inputs, pt_transcription, pt_outputs, prompt_ids

parity = (
pt_transcription in expected_transcription_options and ort_transcription in expected_transcription_options
)
max_diff = 0

if not parity:
if pt_outputs.shape != ort_outputs.shape:
diff = pt_outputs - ort_outputs[:, : len(pt_outputs[0])]
else:
diff = pt_outputs - ort_outputs
max_diff = max(diff.min(), diff.max(), key=abs)

if max_diff != 0:
logger.warning(f"PyTorch outputs: {pt_transcription}")
logger.warning(f"ONNX Runtime outputs: {ort_transcription}")

return max_diff
@staticmethod
def select_transcription_options(
batch_size: int,
prompt_mode: bool,
):
if batch_size > 1 and prompt_mode is True:
expected_transcription_no_comma_prompt1 = " John has doubts whether Sir Frederick Layton's work is really Greek after all and can discover in it but little of Rocky I"
expected_transcription_misspelled_prompt1 = " John has doubts whether Sir Frederick Latins work is really Greek after all and can discover in it but little of Rocky I"
expected_transcription_no_comma_prompt2 = " Maria has grave doubts whether Sir Frederick Layton's work is really Greek after all and can discover in it but little of Rocky"
expected_transcription_misspelled_prompt2 = " Maria has grave doubts whether Sir Frederick Latins work is really Greek after all and can discover in it but little of Rocky I"
expected_transcription_options = {
expected_transcription_no_comma_prompt1,
expected_transcription_no_comma_prompt2,
expected_transcription_misspelled_prompt1,
expected_transcription_misspelled_prompt2,
}
else:
expected_transcription_no_comma = (
" Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel."
)
expected_transcription_with_comma = (
" Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."
)
expected_transcription_with_quote_and_comma = (
' "Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
)
expected_transcription_options = {
expected_transcription_no_comma,
expected_transcription_with_comma,
expected_transcription_with_quote_and_comma,
}
return expected_transcription_options

@staticmethod
def verify_onnx_multi_batch(
def verify_onnx(
model_name_or_path: str,
cache_dir: str,
ort_session: InferenceSession,
device: torch.device,
batch_size: int = 1,
prompt_mode: bool = False,
):
"""Compare the result from PyTorch and ONNX Runtime to verify the ONNX model is good."""
extra_kwargs = {}
@@ -464,42 +434,20 @@ def verify_onnx_multi_batch(
from datasets import load_dataset

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
input_features_ = [
processor([ds[3]["audio"]["array"]], return_tensors="pt").input_features,
processor([ds[3]["audio"]["array"]], return_tensors="pt").input_features,
]
input_features = torch.cat((input_features_[0], input_features_[1])).to(device)
inputs, pt_transcription, pt_outputs, decoder_prompt_ids = WhisperHelper.pt_transcription_for_verify_onnx(
ds,
processor,
pt_model,
device,
batch_size=batch_size,
prompt_mode=prompt_mode,
)

start_id = [config.decoder_start_token_id] # ex: [50258]
prompt_ids = processor.get_decoder_prompt_ids(language="english", task="transcribe")
prompt_ids = list(map(lambda token: token[1], prompt_ids)) # ex: [50259, 50358, 50363]
forced_decoder_ids = start_id + prompt_ids # ex: [50258, 50259, 50358, 50363]

batch_size, max_length, min_length, num_beams, num_return_sequences = 2, 30, 0, 1, 1
length_penalty, repetition_penalty = 1.0, 1.0
inputs = {
"input_features": input_features.to(device),
"max_length": max_length,
"min_length": min_length,
"num_beams": num_beams,
"num_return_sequences": num_return_sequences,
"length_penalty": length_penalty,
"repetition_penalty": repetition_penalty,
"early_stopping": True,
"use_cache": True,
}
prompts = ["John has doubts", "Maria has grave doubts"]
prompt_ids = [processor.get_prompt_ids(p) for p in prompts]
pt_transcription = []
for i in range(batch_size):
inputs["prompt_ids"] = torch.from_numpy(prompt_ids[i])
inputs["input_features"] = input_features_[i].to(device)
pt_outputs = pt_model.generate(**inputs).detach().cpu().numpy()
pt_transcription.append(processor.batch_decode(pt_outputs, skip_special_tokens=True)[0])
inputs["input_features"] = input_features
del inputs["prompt_ids"]
del inputs["early_stopping"]
del inputs["use_cache"]
ort_names = list(map(lambda entry: entry.name, ort_session.get_inputs()))
ort_dtypes = list(map(lambda entry: entry.type, ort_session.get_inputs()))
ort_to_np = {
@@ -511,6 +459,7 @@
"tensor(uint8)": np.uint8,
}

use_extra_decoding_ids = "extra_decoding_ids" in ort_names
for name, dtype in zip(ort_names, ort_dtypes):
if name == "input_features":
inputs[name] = inputs[name].detach().cpu().numpy()
@@ -519,20 +468,24 @@
elif name == "prefix_vocab_mask":
inputs[name] = np.ones((batch_size, config.vocab_size), dtype=ort_to_np[dtype])
elif name == "decoder_input_ids":
# This logic handles the scenario where prompts are not of the same size
# For example if our prompt ids are [p1_id_1, p1_id_2] and [p2_id_1]
# The final decoder_input_ids will look as such after padding
# [prev_token, p1_id_1, p1_id_2, start_token, lang_token, transcribe_token]
# [prev_token, p2_id_1, PAD_TOKEN, start_token, lang_token, transcribe_token]
ort_prompts = []
for i in range(batch_size):
ort_prompts.append(prompt_ids[i].tolist())
max_len = max(len(p) for p in ort_prompts)
padded_prompts = []
for p in ort_prompts:
padded_prompt = [*p, *([config.pad_token_id] * (max_len - len(p)))]
padded_prompts.append(padded_prompt + forced_decoder_ids)
inputs[name] = np.array(padded_prompts, dtype=ort_to_np[dtype])
if not prompt_mode:
raw_input_ids = [start_id] if use_extra_decoding_ids else [forced_decoder_ids]
inputs[name] = np.array(raw_input_ids, dtype=ort_to_np[dtype])
else:
# This logic handles the scenario where prompts are not of the same size
# For example if our prompt ids are [p1_id_1, p1_id_2] and [p2_id_1]
# The final decoder_input_ids will look as such after padding
# [prev_token, p1_id_1, p1_id_2, start_token, lang_token, transcribe_token]
# [prev_token, p2_id_1, PAD_TOKEN, start_token, lang_token, transcribe_token]
ort_prompts = []
for i in range(batch_size):
ort_prompts.append(decoder_prompt_ids[i].tolist())
max_len = max(len(p) for p in ort_prompts)
padded_prompts = []
for p in ort_prompts:
padded_prompt = [*p, *([config.pad_token_id] * (max_len - len(p)))]
padded_prompts.append(padded_prompt + forced_decoder_ids)
inputs[name] = np.array(padded_prompts, dtype=ort_to_np[dtype])
elif name == "logits_processor":
inputs[name] = np.array([1], dtype=ort_to_np[dtype])
elif name == "cross_qk_layer_head":
@@ -544,21 +497,10 @@
else:
inputs[name] = np.array([inputs[name]], dtype=ort_to_np[dtype])
ort_outputs = ort_session.run(None, inputs)[0]

expected_transcription_no_comma_prompt1 = " John has doubts whether Sir Frederick Layton's work is really Greek after all and can discover in it but little of Rocky I"
expected_transcription_misspelled_prompt1 = " John has doubts whether Sir Frederick Latins work is really Greek after all and can discover in it but little of Rocky I"
expected_transcription_no_comma_prompt2 = " Maria has grave doubts whether Sir Frederick Layton's work is really Greek after all and can discover in it but little of Rocky"
expected_transcription_misspelled_prompt2 = " Maria has grave doubts whether Sir Frederick Latins work is really Greek after all and can discover in it but little of Rocky I"
expected_transcription_options = {
expected_transcription_no_comma_prompt1,
expected_transcription_no_comma_prompt2,
expected_transcription_misspelled_prompt1,
expected_transcription_misspelled_prompt2,
}
ort_outputs = ort_session.run(None, inputs)[0]
ort_transcription = []
for o in ort_outputs:
ort_transcription.append(processor.batch_decode(o, skip_special_tokens=True)[0])
expected_transcription_options = WhisperHelper.select_transcription_options(batch_size, prompt_mode)

parity = 1
for i in range(batch_size):
@@ -569,11 +511,13 @@
max_diff = 0

if not parity:
if pt_outputs.shape != ort_outputs.shape:
diff = pt_outputs - ort_outputs[:, : len(pt_outputs[0])]
else:
diff = pt_outputs - ort_outputs
max_diff = max(diff.min(), diff.max(), key=abs)
for i in range(batch_size):
if pt_outputs[i].shape != ort_outputs[i].shape:
diff = pt_outputs[i] - ort_outputs[i][:, : len(pt_outputs[i])]
else:
diff = pt_outputs[i] - ort_outputs[i]
max_diff_i = max(diff.min(), diff.max(), key=abs)
max_diff = max(max_diff, max_diff_i)

if max_diff != 0:
logger.warning(f"PyTorch outputs: {pt_transcription}")
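For readers skimming the decoder_input_ids branch above: when prompt_mode is enabled, per-sample prompts of unequal length are right-padded before the forced decoder ids are appended. A standalone sketch of that padding step follows; the token ids are illustrative placeholders (only forced_decoder_ids mirrors the example values quoted in the diff comments), and the real helper takes pad_token_id and the forced ids from WhisperConfig and WhisperProcessor.

```python
import numpy as np

# Illustrative placeholders; the real code derives these from WhisperConfig/WhisperProcessor.
pad_token_id = 0
forced_decoder_ids = [50258, 50259, 50358, 50363]  # start, language, task, no-timestamps (example from the diff)
prompt_ids = [
    [90, 101, 102],  # e.g. [prev_token, p1_id_1, p1_id_2]
    [90, 201],       # e.g. [prev_token, p2_id_1] -- shorter prompt, needs padding
]

max_len = max(len(p) for p in prompt_ids)
padded_prompts = []
for p in prompt_ids:
    # Right-pad the shorter prompt, then append the forced decoder ids.
    padded_prompts.append([*p, *([pad_token_id] * (max_len - len(p)))] + forced_decoder_ids)

decoder_input_ids = np.array(padded_prompts, dtype=np.int64)
# Resulting rows:
#   [90, 101, 102, 50258, 50259, 50358, 50363]
#   [90, 201,   0, 50258, 50259, 50358, 50363]
```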
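Similarly, the parity fallback near the end of verify_onnx compares token ids per sample when no expected transcription matches. A simplified sketch of that comparison follows; the PR keeps signed differences and the 2-D shapes returned by generate, while this version collapses to 1-D id sequences and absolute differences for brevity.

```python
import numpy as np


def batched_max_diff(pt_outputs, ort_outputs):
    """Largest absolute token-id difference across a batch of decoded sequences."""
    max_diff = 0
    for pt_ids, ort_ids in zip(pt_outputs, ort_outputs):
        n = min(len(pt_ids), len(ort_ids))  # compare only the overlapping prefix
        diff = np.abs(np.asarray(pt_ids[:n], dtype=np.int64) - np.asarray(ort_ids[:n], dtype=np.int64))
        if diff.size:
            max_diff = max(max_diff, int(diff.max()))
    return max_diff


# Identical sequences give 0; a single mismatched token reports its id distance.
print(batched_max_diff([[1, 2, 3]], [[1, 2, 3]]))     # 0
print(batched_max_diff([[1, 2, 3]], [[1, 2, 7, 9]]))  # 4
```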