Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add LLaMA end-to-end benchmarking #19985

Merged
Merged 18 commits on Mar 22, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add license to files and changes suggested by linter
  • Loading branch information
kunal-vaishnavi committed Mar 18, 2024
commit 3abc6fe27c85a1b849e566ce5e2a2913edac05bb
12 changes: 11 additions & 1 deletion onnxruntime/python/tools/transformers/models/llama/benchmark.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# -------------------------------------------------------------------------
Fixed Show fixed Hide fixed
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import argparse
import datetime
import gc
@@ -19,6 +24,7 @@
get_msft_sample_inputs,
get_sample_inputs,
get_sample_with_past_kv_inputs,
verify_ort_inputs,
)
from optimum.onnxruntime import ORTModelForCausalLM
from torch.profiler import ProfilerActivity, profile, record_function
@@ -55,7 +61,11 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
max_seq_len = (
2048
if args.benchmark_type == "ort-msft"
else 16384 if "codellama" in temp_name else 4096 if "llama2" in temp_name else 2048
else 16384
if "codellama" in temp_name
else 4096
if "llama2" in temp_name
else 2048
)

if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import argparse
import datetime
import json
Expand Down
164 changes: 117 additions & 47 deletions onnxruntime/python/tools/transformers/models/llama/benchmark_e2e.py

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from __future__ import annotations

import argparse
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import os

import torch.distributed as dist
Expand Down
73 changes: 49 additions & 24 deletions onnxruntime/python/tools/transformers/models/llama/llama_inputs.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from __future__ import annotations

import numpy as np
@@ -368,7 +373,7 @@ def add_io_bindings_as_tensors(
device_id=0 if v.device.type == "cpu" else v.device.index,
element_type=pt_to_np[repr(v.dtype)],
shape=tuple(v.shape),
buffer_ptr=v.data_ptr()
buffer_ptr=v.data_ptr(),
)
device = v.device

@@ -383,7 +388,7 @@ def add_io_bindings_as_tensors(
device_id=v.device.index,
element_type=np.float16,
shape=tuple(v.shape),
buffer_ptr=v.data_ptr()
buffer_ptr=v.data_ptr(),
)
else:
v = outputs[name]
@@ -393,7 +398,7 @@ def add_io_bindings_as_tensors(
device_id=0 if device.type == "cpu" else device.index,
element_type=(np.float16 if use_fp16 else np.float32),
shape=tuple(v.shape),
buffer_ptr=v.data_ptr()
buffer_ptr=v.data_ptr(),
)

return io_binding
@@ -437,13 +442,16 @@ def get_initial_inputs_and_outputs(
attention_mask = torch.hstack((attention_mask_first_col, attention_mask))
position_ids = get_position_ids(attention_mask, use_past_kv=False)

tokenized_length = input_ids.shape[-1]
assert tokenized_length == requested_length

# Create inputs
inputs = {
"input_ids": input_ids.contiguous() if args.engine == "ort" else input_ids,
"attention_mask": attention_mask.contiguous() if args.engine == "ort" else attention_mask,
"position_ids": position_ids.contiguous() if args.engine == "ort" else position_ids,
"input_ids": input_ids.contiguous() if engine == "ort" else input_ids,
"attention_mask": attention_mask.contiguous() if engine == "ort" else attention_mask,
"position_ids": position_ids.contiguous() if engine == "ort" else position_ids,
}
if args.engine != "ort":
if engine != "ort":
inputs["past_key_values"] = []

# Get shape of KV cache inputs
@@ -454,30 +462,47 @@

# Create KV cache inputs
for i in range(config.num_hidden_layers):
past_key = torch.zeros(batch_size, num_heads, max_sequence_length if use_buffer_share else 0, head_size, device=device, dtype=torch_dtype)
past_value = torch.zeros(batch_size, num_heads, max_sequence_length if use_buffer_share else 0, head_size, device=device, dtype=torch_dtype)
if args.engine == "ort":
inputs.update({
f"past_key_values.{i}.key": past_key.contiguous(),
f"past_key_values.{i}.value": past_value.contiguous()
})
past_key = torch.zeros(
batch_size,
num_heads,
max_sequence_length if use_buffer_share else 0,
head_size,
device=device,
dtype=torch_dtype,
)
past_value = torch.zeros(
batch_size,
num_heads,
max_sequence_length if use_buffer_share else 0,
head_size,
device=device,
dtype=torch_dtype,
)
if engine == "ort":
inputs.update(
{
f"past_key_values.{i}.key": past_key.contiguous(),
f"past_key_values.{i}.value": past_value.contiguous(),
}
)
else:
inputs["past_key_values"].append((past_key, past_value))

outputs = None
if args.engine == "ort":
if engine == "ort":
# Create outputs
logits = torch.zeros(batch_size, sequence_length, config.vocab_size, device=device, dtype=torch_dtype)
outputs = {
"logits": logits.contiguous()
}
outputs = {"logits": logits.contiguous()}
if not use_buffer_share:
for i in range(config.num_hidden_layers):
present_key = torch.zeros(batch_size, num_heads, sequence_length, head_size, device=device, dtype=torch_dtype)
present_value = torch.zeros(batch_size, num_heads, sequence_length, head_size, device=device, dtype=torch_dtype)
outputs.update({
f"present.{i}.key": present_key.contiguous(),
f"present.{i}.value": present_value.contiguous()
})
present_key = torch.zeros(
batch_size, num_heads, sequence_length, head_size, device=device, dtype=torch_dtype
)
present_value = torch.zeros(
batch_size, num_heads, sequence_length, head_size, device=device, dtype=torch_dtype
)
outputs.update(
{f"present.{i}.key": present_key.contiguous(), f"present.{i}.value": present_value.contiguous()}
)

return inputs, outputs
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from __future__ import annotations

import argparse
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import logging
import os

Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import argparse

import numpy as np
Expand Down