[Core] Rename PromptInputs and inputs with backward compatibility #8876

Merged (2 commits, Sep 27, 2024)

Changes from 1 commit
rename PromptInputs and inputs with backward compatibility (#8760)
DarkLight1337 committed Sep 26, 2024
commit ab5a9372749cd671e4060cd1fb009008ce00cc81
8 changes: 4 additions & 4 deletions benchmarks/benchmark_latency.py
@@ -11,7 +11,7 @@

from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import DEVICE_OPTIONS, EngineArgs
-from vllm.inputs import PromptInputs
+from vllm.inputs import PromptType
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.utils import FlexibleArgumentParser

@@ -61,7 +61,7 @@ def main(args: argparse.Namespace):
dummy_prompt_token_ids = np.random.randint(10000,
size=(args.batch_size,
args.input_len))
-dummy_inputs: List[PromptInputs] = [{
+dummy_prompts: List[PromptType] = [{
"prompt_token_ids": batch
} for batch in dummy_prompt_token_ids.tolist()]

@@ -74,13 +74,13 @@ def run_to_completion(profile_dir: Optional[str] = None):
],
on_trace_ready=torch.profiler.tensorboard_trace_handler(
str(profile_dir))) as p:
-llm.generate(dummy_inputs,
+llm.generate(dummy_prompts,
sampling_params=sampling_params,
use_tqdm=False)
print(p.key_averages())
else:
start_time = time.perf_counter()
-llm.generate(dummy_inputs,
+llm.generate(dummy_prompts,
sampling_params=sampling_params,
use_tqdm=False)
end_time = time.perf_counter()
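A quick illustration for readers skimming the diff: a minimal sketch of what the renamed annotation looks like at a call site, assuming a vLLM build that already exports `vllm.inputs.PromptType`. The model name, batch size, and sequence lengths are placeholders, not values taken from the benchmark.

```python
from typing import List

import numpy as np

from vllm import LLM, SamplingParams
from vllm.inputs import PromptType  # formerly PromptInputs

# Illustrative values; the real benchmark reads these from CLI arguments.
batch_size, input_len, output_len = 8, 32, 16

llm = LLM(model="facebook/opt-125m")  # placeholder model for a smoke test
sampling_params = SamplingParams(max_tokens=output_len)

dummy_token_ids = np.random.randint(10000, size=(batch_size, input_len))
dummy_prompts: List[PromptType] = [{"prompt_token_ids": batch}
                                   for batch in dummy_token_ids.tolist()]

outputs = llm.generate(dummy_prompts,
                       sampling_params=sampling_params,
                       use_tqdm=False)
print(f"generated {len(outputs)} outputs")
```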
2 changes: 1 addition & 1 deletion docs/source/dev/multimodal/multimodal_index.rst
@@ -8,7 +8,7 @@ Multi-Modality
vLLM provides experimental support for multi-modal models through the :mod:`vllm.multimodal` package.

Multi-modal inputs can be passed alongside text and token prompts to :ref:`supported models <supported_vlms>`
-via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptInputs`.
+via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptType`.

Currently, vLLM only has built-in support for image data. You can extend vLLM to process additional modalities
by following :ref:`this guide <adding_multimodal_plugin>`.
2 changes: 1 addition & 1 deletion docs/source/dev/offline_inference/llm_inputs.rst
@@ -1,7 +1,7 @@
LLM Inputs
==========

-.. autodata:: vllm.inputs.PromptInputs
+.. autodata:: vllm.inputs.PromptType

.. autoclass:: vllm.inputs.TextPrompt
:show-inheritance:
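As a quick illustration of what the re-targeted autodata entry covers, a hedged sketch of the prompt forms that are intended to satisfy `PromptType`; the token ids are made up, and this assumes the union keeps the text and token prompt members documented alongside it.

```python
from vllm.inputs import PromptType, TextPrompt, TokensPrompt

# All three forms below are intended to satisfy the PromptType alias.
as_plain_string: PromptType = "Hello, my name is"
as_text_prompt: PromptType = TextPrompt(prompt="Hello, my name is")
as_token_prompt: PromptType = TokensPrompt(prompt_token_ids=[1, 15043, 29892])
```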
2 changes: 1 addition & 1 deletion docs/source/models/vlm.rst
@@ -27,7 +27,7 @@ The :class:`~vllm.LLM` class can be instantiated in much the same way as language
We have removed all vision language related CLI args in the ``0.5.1`` release. **This is a breaking change**, so please update your code to follow
the above snippet. Specifically, ``image_feature_size`` can no longer be specified as we now calculate that internally for each model.

-To pass an image to the model, note the following in :class:`vllm.inputs.PromptInputs`:
+To pass an image to the model, note the following in :class:`vllm.inputs.PromptType`:

* ``prompt``: The prompt should follow the format that is documented on HuggingFace.
* ``multi_modal_data``: This is a dictionary that follows the schema defined in :class:`vllm.multimodal.MultiModalDataDict`.
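To make the two bullets above concrete, a hedged sketch of a single-image request. The model name, prompt format, and image path are placeholders; the actual ``prompt`` string must follow whatever format the chosen model documents on HuggingFace.

```python
from PIL import Image

from vllm import LLM, SamplingParams

# Placeholder vision-language model; substitute the model you actually serve.
llm = LLM(model="llava-hf/llava-1.5-7b-hf")

# Any RGB image; the path here is illustrative.
image = Image.open("example.jpg").convert("RGB")

outputs = llm.generate(
    {
        # The prompt string must follow the model's documented chat format.
        "prompt": "USER: <image>\nWhat is shown in this image?\nASSISTANT:",
        # Follows the vllm.multimodal.MultiModalDataDict schema: modality -> data.
        "multi_modal_data": {"image": image},
    },
    sampling_params=SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)
```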
8 changes: 5 additions & 3 deletions tests/async_engine/test_async_llm_engine.py
@@ -86,17 +86,19 @@ class MockAsyncLLMEngine(AsyncLLMEngine):

@pytest.mark.asyncio
async def test_new_requests_event():
+params = SamplingParams()

engine = MockAsyncLLMEngine()
engine.start_background_loop()
await asyncio.sleep(0.01)
assert engine.engine.step_calls == 0

await engine.add_request("1", "", None)
await engine.add_request("1", "", params)
await asyncio.sleep(0.01)
assert engine.engine.add_request_calls == 1
assert engine.engine.step_calls == 1

await engine.add_request("2", "", None)
await engine.add_request("2", "", params)
engine.engine.generate("2")
await asyncio.sleep(0)
await asyncio.sleep(0)
@@ -111,7 +113,7 @@ async def test_new_requests_event():
await asyncio.sleep(0.001)
assert engine.engine.step_calls == old_step_calls

await engine.add_request("3", "", None)
await engine.add_request("3", "", params)
await asyncio.sleep(0.01)
assert engine.engine.add_request_calls == 3
assert engine.engine.step_calls == old_step_calls + 1
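The mock test now passes a real `SamplingParams` object into `add_request` instead of `None`. For orientation, a hedged sketch of how the async engine is typically driven after this change; the model name is a placeholder, and the positional `prompt` argument reflects the post-rename signature assumed by this PR.

```python
import asyncio

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine


async def main() -> None:
    # Placeholder model; any small causal LM is enough for a smoke test.
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m"))
    params = SamplingParams(max_tokens=16)

    final = None
    # generate() streams RequestOutput objects; the prompt is now the first
    # argument (previously the `inputs` parameter).
    async for output in engine.generate("Hello my name is", params,
                                        request_id="req-1"):
        final = output
    if final is not None:
        print(final.outputs[0].text)


asyncio.run(main())
```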
34 changes: 0 additions & 34 deletions tests/entrypoints/llm/test_encode.py
@@ -49,21 +49,6 @@ def assert_outputs_equal(o1: List[EmbeddingRequestOutput],
assert [o.outputs for o in o1] == [o.outputs for o in o2]


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize('prompt', PROMPTS)
def test_v1_v2_api_consistency_single_prompt_string(llm: LLM, prompt):
pooling_params = PoolingParams()

with pytest.warns(DeprecationWarning, match="'prompts'"):
v1_output = llm.encode(prompts=prompt, pooling_params=pooling_params)

v2_output = llm.encode(prompt, pooling_params=pooling_params)
assert_outputs_equal(v1_output, v2_output)

v2_output = llm.encode({"prompt": prompt}, pooling_params=pooling_params)
assert_outputs_equal(v1_output, v2_output)


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS)
def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
@@ -79,25 +64,6 @@ def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
assert_outputs_equal(v1_output, v2_output)


@pytest.mark.skip_global_cleanup
def test_v1_v2_api_consistency_multi_prompt_string(llm: LLM):
pooling_params = PoolingParams()

with pytest.warns(DeprecationWarning, match="'prompts'"):
v1_output = llm.encode(prompts=PROMPTS, pooling_params=pooling_params)

v2_output = llm.encode(PROMPTS, pooling_params=pooling_params)
assert_outputs_equal(v1_output, v2_output)

v2_output = llm.encode(
[{
"prompt": p
} for p in PROMPTS],
pooling_params=pooling_params,
)
assert_outputs_equal(v1_output, v2_output)


@pytest.mark.skip_global_cleanup
def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM):
pooling_params = PoolingParams()
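The deleted cases exercised the deprecated `prompts=` keyword with string inputs. A rough sketch of the surviving (v2-style) call shapes for `LLM.encode`, assuming an embedding-capable model; the model name and prompts are placeholders.

```python
from vllm import LLM, PoolingParams

# Placeholder embedding model; LLM.encode needs an embedding-capable model.
llm = LLM(model="intfloat/e5-mistral-7b-instruct")
pooling_params = PoolingParams()

# A positional string prompt ...
out_a = llm.encode("Hello, my name is", pooling_params=pooling_params)

# ... an equivalent explicit TextPrompt dictionary ...
out_b = llm.encode({"prompt": "Hello, my name is"}, pooling_params=pooling_params)

# ... or a list of prompts.
out_c = llm.encode(["Hello, my name is", "The capital of France is"],
                   pooling_params=pooling_params)
```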
37 changes: 0 additions & 37 deletions tests/entrypoints/llm/test_generate.py
@@ -47,23 +47,6 @@ def assert_outputs_equal(o1: List[RequestOutput], o2: List[RequestOutput]):
assert [o.outputs for o in o1] == [o.outputs for o in o2]


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize('prompt', PROMPTS)
def test_v1_v2_api_consistency_single_prompt_string(llm: LLM, prompt):
sampling_params = SamplingParams(temperature=0.0, top_p=1.0)

with pytest.warns(DeprecationWarning, match="'prompts'"):
v1_output = llm.generate(prompts=prompt,
sampling_params=sampling_params)

v2_output = llm.generate(prompt, sampling_params=sampling_params)
assert_outputs_equal(v1_output, v2_output)

v2_output = llm.generate({"prompt": prompt},
sampling_params=sampling_params)
assert_outputs_equal(v1_output, v2_output)


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS)
def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
@@ -79,26 +62,6 @@ def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
assert_outputs_equal(v1_output, v2_output)


@pytest.mark.skip_global_cleanup
def test_v1_v2_api_consistency_multi_prompt_string(llm: LLM):
sampling_params = SamplingParams(temperature=0.0, top_p=1.0)

with pytest.warns(DeprecationWarning, match="'prompts'"):
v1_output = llm.generate(prompts=PROMPTS,
sampling_params=sampling_params)

v2_output = llm.generate(PROMPTS, sampling_params=sampling_params)
assert_outputs_equal(v1_output, v2_output)

v2_output = llm.generate(
[{
"prompt": p
} for p in PROMPTS],
sampling_params=sampling_params,
)
assert_outputs_equal(v1_output, v2_output)


@pytest.mark.skip_global_cleanup
def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM):
sampling_params = SamplingParams(temperature=0.0, top_p=1.0)
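Analogously, the deleted `LLM.generate` cases had asserted a `DeprecationWarning` for `prompts=` with string inputs; their removal suggests that keyword is no longer treated as deprecated for the batched entry points after the rename. A brief, hedged sketch of the call shapes the remaining tests still cover, with placeholder model and prompts:

```python
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # placeholder model
sampling_params = SamplingParams(temperature=0.0, top_p=1.0)  # greedy decoding

# A single prompt string, passed positionally ...
out_str = llm.generate("Hello, my name is", sampling_params=sampling_params)

# ... an explicit TextPrompt dictionary ...
out_dict = llm.generate({"prompt": "Hello, my name is"},
                        sampling_params=sampling_params)

# ... or a list of prompts for batched generation.
out_batch = llm.generate(["Hello, my name is", "The capital of France is"],
                         sampling_params=sampling_params)
```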
12 changes: 6 additions & 6 deletions tests/mq_llm_engine/test_error_handling.py
@@ -61,15 +61,15 @@ async def test_evil_forward(tmp_socket):

# Throws an error in first forward pass.
with pytest.raises(RAISED_ERROR):
-async for _ in client.generate(inputs="Hello my name is",
+async for _ in client.generate(prompt="Hello my name is",
sampling_params=SamplingParams(),
request_id=uuid.uuid4()):
pass
assert client.errored

# Engine is errored, should get ENGINE_DEAD_ERROR.
with pytest.raises(MQEngineDeadError):
-async for _ in client.generate(inputs="Hello my name is",
+async for _ in client.generate(prompt="Hello my name is",
sampling_params=SamplingParams(),
request_id=uuid.uuid4()):
pass
@@ -118,7 +118,7 @@ async def test_failed_health_check(tmp_socket):

# Generate call should throw ENGINE_DEAD_ERROR
with pytest.raises(MQEngineDeadError):
-async for _ in client.generate(inputs="Hello my name is",
+async for _ in client.generate(prompt="Hello my name is",
sampling_params=SamplingParams(),
request_id=uuid.uuid4()):
pass
@@ -160,7 +160,7 @@ async def test_failed_abort(tmp_socket):
# with reference to the original KeyError("foo")
with pytest.raises(MQEngineDeadError) as execinfo:
async for _ in client.generate(
inputs="Hello my name is",
prompt="Hello my name is",
sampling_params=SamplingParams(max_tokens=10),
request_id=uuid.uuid4()):
pass
@@ -183,7 +183,7 @@ async def test_bad_request(tmp_socket):

# Invalid request should fail, but not crash the server.
with pytest.raises(ValueError):
-async for _ in client.generate(inputs="Hello my name is",
+async for _ in client.generate(prompt="Hello my name is",
sampling_params=SamplingParams(),
request_id="abcd-1",
lora_request=LoRARequest(
@@ -192,7 +192,7 @@
pass

# This request should be okay.
-async for _ in client.generate(inputs="Hello my name is",
+async for _ in client.generate(prompt="Hello my name is",
sampling_params=SamplingParams(),
request_id="abcd-2"):
pass
2 changes: 1 addition & 1 deletion tests/mq_llm_engine/utils.py
@@ -20,7 +20,7 @@ async def generate(
count = 0
async for out in client.generate(
request_id=request_id,
inputs="Hello my name is Robert and",
prompt="Hello my name is Robert and",
sampling_params=SamplingParams(max_tokens=num_tokens,
temperature=0)):

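Across the engine-level clients touched in this PR, only the keyword changes: `inputs=` becomes `prompt=`. A hedged sketch of a caller against an already-connected multiprocessing engine client; client construction is omitted, and the exact client type is assumed rather than shown here.

```python
from vllm import SamplingParams


async def stream_completion(client, request_id: str) -> str:
    """Drive one streaming request against an already-connected MQ engine client.

    The client construction (socket path, engine startup) is omitted; the
    renamed keyword is the only point of this sketch.
    """
    text = ""
    # Pre-rename this keyword was `inputs=`; after this PR it is `prompt=`.
    async for out in client.generate(prompt="Hello my name is Robert and",
                                     sampling_params=SamplingParams(max_tokens=16),
                                     request_id=request_id):
        text = out.outputs[0].text
    return text
```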
4 changes: 2 additions & 2 deletions vllm/__init__.py
@@ -5,7 +5,7 @@
from vllm.engine.llm_engine import LLMEngine
from vllm.entrypoints.llm import LLM
from vllm.executor.ray_utils import initialize_ray_cluster
-from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt
+from vllm.inputs import PromptType, TextPrompt, TokensPrompt
from vllm.model_executor.models import ModelRegistry
from vllm.outputs import (CompletionOutput, EmbeddingOutput,
EmbeddingRequestOutput, RequestOutput)
@@ -19,7 +19,7 @@
"__version_tuple__",
"LLM",
"ModelRegistry",
"PromptInputs",
"PromptType",
"TextPrompt",
"TokensPrompt",
"SamplingParams",
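Finally, on the "backward compatibility" half of the title: the hunks in this commit only swap the exported name, so below is a generic, hedged sketch of one way a renamed public symbol can stay importable with a warning (a module-level `__getattr__` per PEP 562). This illustrates the pattern only; it is not necessarily the mechanism the PR actually uses.

```python
# Illustrative pattern only, meant to live in a package's __init__.py.
import warnings
from typing import Any

from vllm.inputs import PromptType


def __getattr__(name: str) -> Any:
    # Module-level __getattr__ (PEP 562) runs for missing attributes, so
    # `from this_module import PromptInputs` still resolves but warns.
    if name == "PromptInputs":
        warnings.warn(
            "PromptInputs has been renamed to PromptType; "
            "please update your imports.",
            DeprecationWarning,
            stacklevel=2,
        )
        return PromptType
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```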