diff --git a/examples/llm_engine_example.py b/examples/llm_engine_example.py
index df97ae5d6a8f..cea15f8045d8 100644
--- a/examples/llm_engine_example.py
+++ b/examples/llm_engine_example.py
@@ -1,15 +1,12 @@
 import argparse
+from typing import List, Tuple
 
-from vllm import EngineArgs, LLMEngine, SamplingParams
+from vllm import EngineArgs, LLMEngine, SamplingParams, RequestOutput
 
 
-def main(args: argparse.Namespace):
-    # Parse the CLI argument and initialize the engine.
-    engine_args = EngineArgs.from_cli_args(args)
-    engine = LLMEngine.from_engine_args(engine_args)
-
-    # Test the following prompts.
-    test_prompts = [
+def create_test_prompts() -> List[Tuple[str, SamplingParams]]:
+    """Create a list of test prompts with their sampling parameters."""
+    return [
         ("A robot may not injure a human being",
          SamplingParams(temperature=0.0, logprobs=1, prompt_logprobs=1)),
         ("To be or not to be,",
@@ -25,22 +22,36 @@ def main(args: argparse.Namespace):
                         temperature=0.0)),
     ]
 
-    # Run the engine by calling `engine.step()` manually.
+
+def process_requests(engine: LLMEngine,
+                     test_prompts: List[Tuple[str, SamplingParams]]):
+    """Continuously process a list of prompts and handle the outputs."""
     request_id = 0
-    while True:
-        # To test continuous batching, we add one request at each step.
+
+    while test_prompts or engine.has_unfinished_requests():
         if test_prompts:
             prompt, sampling_params = test_prompts.pop(0)
            engine.add_request(str(request_id), prompt, sampling_params)
             request_id += 1
 
-        request_outputs = engine.step()
+        request_outputs: List[RequestOutput] = engine.step()
+
         for request_output in request_outputs:
             if request_output.finished:
                 print(request_output)
 
-        if not (engine.has_unfinished_requests() or test_prompts):
-            break
+
+def initialize_engine(args: argparse.Namespace) -> LLMEngine:
+    """Initialize the LLMEngine from the command line arguments."""
+    engine_args = EngineArgs.from_cli_args(args)
+    return LLMEngine.from_engine_args(engine_args)
+
+
+def main(args: argparse.Namespace):
+    """Main function that sets up and runs the prompt processing."""
+    engine = initialize_engine(args)
+    test_prompts = create_test_prompts()
+    process_requests(engine, test_prompts)
 
 
 if __name__ == '__main__':
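
For context, the loop the diff extracts into process_requests() can also be driven programmatically, without the script's CLI. Below is a minimal sketch using the same LLMEngine calls shown in the diff (add_request, step, has_unfinished_requests); constructing EngineArgs directly and the model name "facebook/opt-125m" are illustrative assumptions, not part of this change:

    from vllm import EngineArgs, LLMEngine, SamplingParams

    # Build the engine directly instead of parsing CLI flags
    # (the example script itself goes through EngineArgs.from_cli_args).
    engine = LLMEngine.from_engine_args(EngineArgs(model="facebook/opt-125m"))

    # Queue a single request; request IDs are arbitrary unique strings.
    engine.add_request("0", "The capital of France is",
                       SamplingParams(temperature=0.0))

    # Step the engine until the request finishes, mirroring process_requests().
    while engine.has_unfinished_requests():
        for output in engine.step():
            if output.finished:
                print(output.outputs[0].text)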