
Commit

add tests
SolitaryThinker committed Aug 16, 2024
1 parent c0f0929 commit 32b03b7
Showing 2 changed files with 89 additions and 0 deletions.
Empty file added tests/multi_step/__init__.py
89 changes: 89 additions & 0 deletions tests/multi_step/test_correctness.py
@@ -0,0 +1,89 @@
# Test the AsyncLLMEngine with multi-step-decoding and chunked prefill
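#
# Multi-step decoding runs several decode iterations per scheduler call
# (controlled by --num-scheduler-steps), amortizing scheduling overhead
# across steps. With temperature=0, its output should match single-step
# scheduling token for token, which is what these tests assert.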

from typing import List

import pytest

from ..utils import RemoteOpenAIServer

MODELS = [
    "JackFram/llama-160m",
]
NUM_SCHEDULER_STEPS = [8, 16] # Multi-step decoding steps
NUM_PROMPTS = [10]

DEFAULT_SERVER_ARGS: List[str] = [
    "--disable-log-requests",
    "--use-v2-block-manager",
    "--worker-use-ray",
    "--gpu-memory-utilization",
    "0.85",
    "--swap-space",
    "16",
]


async def completions_with_server_args(prompts: List[str], model_name: str,
                                       server_cli_args: List[str]):
    """Start a RemoteOpenAIServer with the given CLI args and return the
    completions generated for the given prompts."""
    outputs = None
    with RemoteOpenAIServer(model_name, server_cli_args) as server:
        client = server.get_async_client()
        outputs = await client.completions.create(model=model_name,
                                                  prompt=prompts,
                                                  temperature=0,
                                                  stream=False,
                                                  max_tokens=5)
    assert outputs is not None

    return outputs


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize(("tp_size, pp_size"), [
(1, 1),
(2, 2),
(1, 2),
(2, 1),
])
@pytest.mark.parametrize("eager_mode", [False, True])
@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
@pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
@pytest.mark.asyncio
async def test_multi_step_with_chunked_prefill(example_prompts, model: str,
                                               tp_size: int, pp_size: int,
                                               eager_mode: bool,
                                               num_scheduler_steps: int,
                                               num_prompts: int):

    # Repeat the example prompts as needed to reach exactly num_prompts.
    prompts = example_prompts
    if len(prompts) < num_prompts:
        prompts = prompts * ((num_prompts // len(prompts)) + 1)
    prompts = prompts[:num_prompts]
    assert len(prompts) == num_prompts

    server_args = DEFAULT_SERVER_ARGS + ["--enforce-eager"]
    ms_server_args = DEFAULT_SERVER_ARGS + \
        ["--num-scheduler-steps", f"{num_scheduler_steps}"]

    if eager_mode:
        ms_server_args.append("--enforce-eager")

    distributed_args = [
        "--tensor-parallel-size",
        str(tp_size),
        "--pipeline-parallel-size",
        str(pp_size),
    ]

    # Reference run uses single-step scheduling; test run uses multi-step.
    ref_completions = await completions_with_server_args(
        prompts, model, server_args + distributed_args)
    test_completions = await completions_with_server_args(
        prompts, model, ms_server_args + distributed_args)

    def get_text_generations(completions):
        return [x.text for x in completions.choices]

    # With temperature=0, both schedulers should produce identical text.
    ref_generations = get_text_generations(ref_completions)
    test_generations = get_text_generations(test_completions)
    assert ref_generations == test_generations
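
A minimal sketch of invoking this suite locally, assuming a vLLM checkout with pytest and pytest-asyncio installed (the wrapper below and its -k filter are hypothetical, not part of this commit; the tp_size=2 or pp_size=2 cases need multiple GPUs):

import sys

import pytest

if __name__ == "__main__":
    # Restrict to the tp_size=1, pp_size=1 cases so a single GPU suffices;
    # the exact test ids depend on pytest's parametrize naming.
    sys.exit(pytest.main(["tests/multi_step/test_correctness.py",
                          "-v", "-k", "1-1"]))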
