[Core] Support serving encoder/decoder models #7258

Merged: 35 commits, merged Aug 9, 2024

Commits
33c9e25  Introduce `is_list_of` (DarkLight1337, Aug 7, 2024)
e6dd6f5  Avoid circular imports (DarkLight1337, Aug 7, 2024)
f938c86  Refactor prompt parsing and extend this to async engine (DarkLight1337, Aug 7, 2024)
6332d1e  Remove unnecessary comments (DarkLight1337, Aug 7, 2024)
07b4d21  Enable full async (DarkLight1337, Aug 7, 2024)
e29864c  grammar (DarkLight1337, Aug 7, 2024)
c9dfb40  Add description (DarkLight1337, Aug 7, 2024)
1233192  Fix wrong type annotations (DarkLight1337, Aug 7, 2024)
f332275  Merge branch 'upstream' into inputs-parser (DarkLight1337, Aug 7, 2024)
dcdebee  Remove redundant docs (DarkLight1337, Aug 7, 2024)
65db3f1  Be more strict (DarkLight1337, Aug 7, 2024)
9ffeb22  Fix docs (DarkLight1337, Aug 7, 2024)
c9e0b08  Fix 2 (DarkLight1337, Aug 7, 2024)
14bca1f  Disallow multi-modal data for enc/dec models (DarkLight1337, Aug 7, 2024)
8fc7099  Improve type narrowing behavior using `TypeIs` (DarkLight1337, Aug 7, 2024)
3a8a072  Avoid sequential await (DarkLight1337, Aug 7, 2024)
ef5327c  Fix type annotations based on test files (DarkLight1337, Aug 7, 2024)
8a835cc  Properly handle `inputs["decoder_prompt"]=None` (DarkLight1337, Aug 7, 2024)
e0024c2  Clean (DarkLight1337, Aug 7, 2024)
76af172  Clean (DarkLight1337, Aug 7, 2024)
5c16f2e  Fix incorrect decoder inputs in singleton case (DarkLight1337, Aug 7, 2024)
e239ba9  Clean (DarkLight1337, Aug 7, 2024)
4b0e3df  Move functions to a more appropriate place (DarkLight1337, Aug 7, 2024)
53f7f50  Remove outdated comment (DarkLight1337, Aug 7, 2024)
3afdbc5  Fix mismatch between hf and vllm output text (DarkLight1337, Aug 7, 2024)
c61b01f  Factor out duplicate code (DarkLight1337, Aug 7, 2024)
f8ed373  Factor out more duplicate code (DarkLight1337, Aug 7, 2024)
a4df70a  Remove default values to avoid accidentally miss those arguments (DarkLight1337, Aug 7, 2024)
5240bb3  Add test for serving encoder/decoder model with OpenAI server (DarkLight1337, Aug 7, 2024)
d321c82  Use two type variables (DarkLight1337, Aug 7, 2024)
931d1f6  Merge branch 'upstream' into inputs-parser (DarkLight1337, Aug 7, 2024)
a06c67f  Merge branch 'upstream' into inputs-parser (DarkLight1337, Aug 7, 2024)
9f64a05  Merge branch 'upstream' into inputs-parser (DarkLight1337, Aug 7, 2024)
e4c5c21  Update error message (DarkLight1337, Aug 8, 2024)
68fbf5a  Merge branch 'upstream' into inputs-parser (DarkLight1337, Aug 8, 2024)
.github/workflows/mypy.yaml (2 changes: 1 addition & 1 deletion)
@@ -25,7 +25,7 @@ jobs:
     - name: Install dependencies
       run: |
         python -m pip install --upgrade pip
-        pip install mypy==1.9.0
+        pip install mypy==1.11.1
         pip install types-setuptools
         pip install types-PyYAML
         pip install types-requests
examples/offline_inference_encoder_decoder.py (8 changes: 4 additions & 4 deletions)
@@ -4,8 +4,8 @@
 '''

 from vllm import LLM, SamplingParams
-from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt
-from vllm.utils import zip_enc_dec_prompt_lists
+from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
+                         TokensPrompt, zip_enc_dec_prompts)

 dtype = "float"

@@ -61,9 +61,9 @@
 )

 # - Finally, here's a useful helper function for zipping encoder and
-#   decoder prompt lists together into a list of ExplicitEncoderDecoderPrompt
+#   decoder prompts together into a list of ExplicitEncoderDecoderPrompt
 #   instances
-zipped_prompt_list = zip_enc_dec_prompt_lists(
+zipped_prompt_list = zip_enc_dec_prompts(
     ['An encoder prompt', 'Another encoder prompt'],
     ['A decoder prompt', 'Another decoder prompt'])

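As a usage sketch (not part of the diff): the renamed helper feeds directly into LLM.generate, following the example script above. The model name and sampling settings here are illustrative assumptions:

```python
from vllm import LLM, SamplingParams
from vllm.inputs import zip_enc_dec_prompts

# Pair each encoder prompt with its decoder prompt; a decoder prompt of None
# lets the model start generation from its default decoder start token.
prompts = zip_enc_dec_prompts(
    ['An encoder prompt', 'Another encoder prompt'],
    ['A decoder prompt', None],
)

# Illustrative model/settings; the PR's example script targets BART-style
# encoder/decoder models.
llm = LLM(model="facebook/bart-large-cnn", dtype="float")
outputs = llm.generate(prompts, SamplingParams(temperature=0, max_tokens=20))

for output in outputs:
    print(output.outputs[0].text)
```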
requirements-common.txt (2 changes: 1 addition & 1 deletion)
@@ -19,7 +19,7 @@ prometheus-fastapi-instrumentator >= 7.0.0
 tiktoken >= 0.6.0  # Required for DBRX tokenizer
 lm-format-enforcer == 0.10.3
 outlines >= 0.0.43, < 0.1  # Requires torch >= 2.1.0
-typing_extensions
+typing_extensions >= 4.10
 filelock >= 3.10.4  # filelock starts to support `mode` argument from 3.10.4
 pyzmq
 gguf == 0.9.1
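
The typing_extensions floor moves to 4.10 because commit 8fc7099 introduces TypeIs-based type narrowing, and TypeIs first shipped in typing_extensions 4.10; the mypy bump to 1.11.1 earlier in the diff likely serves the same purpose, since mypy 1.11 added TypeIs support. A minimal sketch of an is_list_of-style helper built on TypeIs follows; the actual signature in vLLM may differ:

```python
from typing import List, Type, TypeVar

from typing_extensions import TypeIs

T = TypeVar("T")


def is_list_of(value: object, typ: Type[T]) -> TypeIs[List[T]]:
    """Narrow `value` to List[T] when every element is an instance of `typ`."""
    return isinstance(value, list) and all(isinstance(v, typ) for v in value)


prompt: object = ["Hello", "world"]
if is_list_of(prompt, str):
    # Type checkers that understand TypeIs treat `prompt` as List[str] here.
    print(", ".join(prompt))
```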
requirements-lint.txt (2 changes: 1 addition & 1 deletion)
@@ -8,7 +8,7 @@ isort==5.13.2
 clang-format==18.1.5

 # type checking
-mypy==1.9.0
+mypy==1.11.1
 types-PyYAML
 types-requests
 types-setuptools
requirements-openvino.txt (2 changes: 1 addition & 1 deletion)
@@ -22,7 +22,7 @@ prometheus-fastapi-instrumentator >= 7.0.0
 tiktoken >= 0.6.0  # Required for DBRX tokenizer
 lm-format-enforcer == 0.10.3
 outlines >= 0.0.43, < 0.1  # Requires torch >= 2.1.0
-typing_extensions
+typing_extensions >= 4.10
 filelock >= 3.10.4  # filelock starts to support `mode` argument from 3.10.4
 pyzmq
 gguf == 0.9.1
tests/conftest.py (32 changes: 19 additions & 13 deletions)
@@ -3,6 +3,7 @@
 import os
 import sys
 from collections import UserList
+from enum import Enum
 from typing import Any, Dict, List, Optional, Tuple, TypedDict, TypeVar, Union

 import pytest

@@ -14,20 +15,19 @@
                           AutoModelForVision2Seq, AutoTokenizer, BatchEncoding,
                           BatchFeature)

-from tests.models.utils import DecoderPromptType
 from vllm import LLM, SamplingParams
 from vllm.assets.image import ImageAsset
 from vllm.config import TokenizerPoolConfig
 from vllm.connections import global_http_connection
 from vllm.distributed import (destroy_distributed_environment,
                               destroy_model_parallel)
-from vllm.inputs import TextPrompt
+from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
+                         to_enc_dec_tuple_list, zip_enc_dec_prompts)
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sequence import SampleLogprobs
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
-                        is_cpu, to_enc_dec_tuple_list,
-                        zip_enc_dec_prompt_lists)
+                        is_cpu)

 logger = init_logger(__name__)

@@ -124,10 +124,16 @@ def example_prompts() -> List[str]:
     return prompts


+class DecoderPromptType(Enum):
+    """For encoder/decoder models only."""
+    CUSTOM = 1
+    NONE = 2
+    EMPTY_STR = 3
+
+
 @pytest.fixture
-def example_encoder_decoder_prompts() \
-        -> Dict[DecoderPromptType,
-                Tuple[List[str], List[Optional[str]]]]:
+def example_encoder_decoder_prompts(
+) -> Dict[DecoderPromptType, List[ExplicitEncoderDecoderPrompt]]:
     '''
     Returns an encoder prompt list and a decoder prompt list, wherein each pair
     of same-index entries in both lists corresponds to an (encoder prompt,

@@ -150,11 +156,11 @@ def example_encoder_decoder_prompts() \
     # NONE decoder prompt type
     return {
         DecoderPromptType.NONE:
-        zip_enc_dec_prompt_lists(encoder_prompts, none_decoder_prompts),
+        zip_enc_dec_prompts(encoder_prompts, none_decoder_prompts),
         DecoderPromptType.EMPTY_STR:
-        zip_enc_dec_prompt_lists(encoder_prompts, empty_str_decoder_prompts),
+        zip_enc_dec_prompts(encoder_prompts, empty_str_decoder_prompts),
         DecoderPromptType.CUSTOM:
-        zip_enc_dec_prompt_lists(encoder_prompts, custom_decoder_prompts),
+        zip_enc_dec_prompts(encoder_prompts, custom_decoder_prompts),
     }

@@ -444,7 +450,7 @@ def generate_greedy_logprobs_limit(

     def generate_encoder_decoder_greedy_logprobs_limit(
         self,
-        encoder_decoder_prompts: Tuple[List[str], List[str]],
+        encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str]],
         max_tokens: int,
         num_logprobs: int,
         **kwargs: Any,

@@ -608,7 +614,7 @@ def generate_w_logprobs(

     def generate_encoder_decoder_w_logprobs(
         self,
-        encoder_decoder_prompts: Tuple[List[str], List[str]],
+        encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str]],
         sampling_params: SamplingParams,
     ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
         '''

@@ -653,7 +659,7 @@ def generate_greedy_logprobs(

     def generate_encoder_decoder_greedy_logprobs(
         self,
-        encoder_decoder_prompts: Tuple[List[str], List[str]],
+        encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str]],
         max_tokens: int,
         num_logprobs: int,
     ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
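
Commit 8a835cc ("Properly handle `inputs["decoder_prompt"]=None`") suggests the explicit prompt type is essentially a pairing of an encoder prompt with an optional decoder prompt. A rough, illustrative sketch of the shape of data the fixture above returns for each DecoderPromptType (made-up strings, not the actual vLLM types or the real fixture contents):

```python
# Illustrative only: shows the shape of the fixture's return value, assuming
# zip_enc_dec_prompts simply zips the two lists into explicit prompt pairs.
encoder_prompts = ["The rocket lifted off", "BART is an encoder/decoder model"]

example = {
    # No decoder prompt: the model begins from its decoder start token.
    "NONE": [{"encoder_prompt": p, "decoder_prompt": None} for p in encoder_prompts],
    # Empty-string decoder prompt.
    "EMPTY_STR": [{"encoder_prompt": p, "decoder_prompt": ""} for p in encoder_prompts],
    # Custom decoder prompt supplied by the test.
    "CUSTOM": [{"encoder_prompt": p, "decoder_prompt": "A summary:"} for p in encoder_prompts],
}
```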
(test file, name not captured in this view)
@@ -11,9 +11,9 @@

 import pytest

-from tests.models.utils import DecoderPromptType
 from vllm.utils import cuda_device_count_stateless

+from ..conftest import DecoderPromptType
 from ..models.utils import check_logprobs_close
 from ..utils import fork_new_process_for_each_test

tests/models/test_bart.py (38 changes: 27 additions & 11 deletions)
@@ -2,6 +2,8 @@

 Run `pytest tests/models/test_bart.py`.
 """
+from typing import List, Optional, Tuple
+
 from vllm.utils import is_cpu

 if not is_cpu():

@@ -11,22 +13,31 @@

 import pytest

-from tests.models.utils import DecoderPromptType
+from vllm.sequence import SampleLogprobs

+from ..conftest import DecoderPromptType
 from .utils import check_logprobs_close

 MODELS = ["facebook/bart-base", "facebook/bart-large-cnn"]

-DECODER_PROMPT_TYPES = ([
-    DecoderPromptType.CUSTOM, DecoderPromptType.EMPTY_STR,
-    DecoderPromptType.NONE
-])
+def vllm_to_hf_output(
+    vllm_output: Tuple[List[int], str, Optional[SampleLogprobs]],
+    decoder_prompt_type: DecoderPromptType,
+):
+    """Sanitize vllm output to be comparable with hf output."""
+    output_ids, output_str, out_logprobs = vllm_output
+
+    hf_output_str = output_str + "</s>"
+    if decoder_prompt_type == DecoderPromptType.NONE:
+        hf_output_str = "<s>" + hf_output_str
+
+    return output_ids, hf_output_str, out_logprobs

 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["float", "bfloat16"])
 @pytest.mark.parametrize("max_tokens", [64])
 @pytest.mark.parametrize("num_logprobs", [5])
-@pytest.mark.parametrize("decoder_prompt_type", DECODER_PROMPT_TYPES)
+@pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType))
 def test_models(
     hf_runner,
     vllm_runner,

@@ -146,8 +157,13 @@ def test_models(
     hf_skip_tokens = (1 if decoder_prompt_type == DecoderPromptType.NONE
                       else 0)

-    check_logprobs_close(outputs_0_lst=hf_outputs,
-                         outputs_1_lst=vllm_outputs,
-                         name_0="hf",
-                         name_1="vllm",
-                         num_outputs_0_skip_tokens=hf_skip_tokens)
+    check_logprobs_close(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=[
+            vllm_to_hf_output(vllm_output, decoder_prompt_type)
+            for vllm_output in vllm_outputs
+        ],
+        name_0="hf",
+        name_1="vllm",
+        num_outputs_0_skip_tokens=hf_skip_tokens,
+    )
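
The sanitizer exists because HF's BART reference output keeps special tokens that vLLM strips (commit 3afdbc5, "Fix mismatch between hf and vllm output text"). A tiny self-contained check of the intended behavior, with made-up output strings and a stand-in enum so it runs without vLLM installed:

```python
from enum import Enum
from typing import List, Optional, Tuple


class DecoderPromptType(Enum):
    # Stand-in for the enum defined in tests/conftest.py above.
    CUSTOM = 1
    NONE = 2
    EMPTY_STR = 3


def vllm_to_hf_output(
    vllm_output: Tuple[List[int], str, Optional[list]],
    decoder_prompt_type: DecoderPromptType,
):
    """Append (and, when there is no decoder prompt, prepend) BART's special tokens."""
    output_ids, output_str, out_logprobs = vllm_output
    hf_output_str = output_str + "</s>"
    if decoder_prompt_type == DecoderPromptType.NONE:
        hf_output_str = "<s>" + hf_output_str
    return output_ids, hf_output_str, out_logprobs


# Made-up outputs purely to show the transformation.
_, text, _ = vllm_to_hf_output(([0, 1, 2], "A summary.", None), DecoderPromptType.NONE)
assert text == "<s>A summary.</s>"
_, text, _ = vllm_to_hf_output(([0, 1, 2], "A summary.", None), DecoderPromptType.CUSTOM)
assert text == "A summary.</s>"
```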
tests/models/utils.py (11 changes: 0 additions & 11 deletions)
@@ -1,5 +1,4 @@
 import warnings
-from enum import Enum
 from typing import Dict, List, Optional, Sequence, Tuple, Union

 from vllm.sequence import SampleLogprobs

@@ -136,13 +135,3 @@ def check_logprobs_close(
         warnings.simplefilter("always")

         warnings.warn(fail_msg, stacklevel=2)
-
-
-class DecoderPromptType(Enum):
-    '''
-    For encoder/decoder models only -
-
-    '''
-    CUSTOM = 1
-    NONE = 2
-    EMPTY_STR = 3
tests/test_inputs.py (2 changes: 1 addition & 1 deletion)
@@ -2,7 +2,7 @@

 import pytest

-from vllm.inputs import parse_and_batch_prompt
+from vllm.inputs.parse import parse_and_batch_prompt

 STRING_INPUTS = [
     '',
vllm/config.py (10 changes: 10 additions & 0 deletions)
@@ -457,6 +457,16 @@ def _get_num_seqlen_agnostic_layers(
             if t != "attention"
         ])

+    @property
+    def is_encoder_decoder_model(self) -> bool:
+        """Extract the HF encoder/decoder model flag."""
+        return getattr(self.hf_config, "is_encoder_decoder", False)
+
+    @property
+    def is_embedding_model(self) -> bool:
+        """Extract the embedding model flag."""
+        return self.embedding_mode
+

 class CacheConfig:
     """Configuration for the KV cache.
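
Elsewhere in the PR these flags gate encoder/decoder-specific behavior; for example, commit 14bca1f disallows multi-modal data for enc/dec models. A hypothetical guard of that shape (the helper name and error text are illustrative, not the PR's actual code):

```python
from typing import Any, Optional


def _validate_multi_modal_data(model_config: Any,
                               multi_modal_data: Optional[dict]) -> None:
    # Hypothetical helper: encoder/decoder models do not accept multi-modal
    # inputs in this PR, so reject them up front.
    if multi_modal_data and model_config.is_encoder_decoder_model:
        raise ValueError(
            "Multi-modal inputs are not supported for encoder/decoder models")
```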