From 550378be96448232a1e47f4a3af774fc3eca3cb7 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Mon, 16 Sep 2024 17:29:14 -0400 Subject: [PATCH 01/31] Allow for processor kwarg overrides Signed-off-by: Alex-Brooks --- tests/engine/test_arg_utils.py | 21 ++++++++++ vllm/config.py | 6 ++- vllm/engine/arg_utils.py | 8 ++++ vllm/engine/llm_engine.py | 3 +- vllm/entrypoints/llm.py | 2 + vllm/inputs/registry.py | 71 ++++++++++++++++++++++++++++++---- 6 files changed, 102 insertions(+), 9 deletions(-) diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py index 8dd200b35d0f..fabf37aa2a68 100644 --- a/tests/engine/test_arg_utils.py +++ b/tests/engine/test_arg_utils.py @@ -40,3 +40,24 @@ def test_limit_mm_per_prompt_parser(arg, expected): def test_bad_nullable_kvs(arg): with pytest.raises(ArgumentTypeError): nullable_kvs(arg) + + +@pytest.mark.parametrize(("arg", "expected"), [ + (None, None), + ("{}", {}), + ('{"num_crops": 4}', { + "num_crops": 4 + }), + ('{"foo": {"bar": "baz"}}', { + "foo": { + "bar": "baz" + } + }), +]) +def test_processor_kwargs_prompt_parser(arg, expected): + parser = EngineArgs.add_cli_args(FlexibleArgumentParser()) + if arg is None: + args = parser.parse_args([]) + else: + args = parser.parse_args(["--processor-kwargs", arg]) + assert args.processor_kwargs == expected diff --git a/vllm/config.py b/vllm/config.py index 7a15606836dc..94552a22cc25 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -122,6 +122,8 @@ class ModelConfig: can not be gathered from the vllm arguments. config_format: The config format which shall be loaded. Defaults to 'auto' which defaults to 'hf'. + processor_kwargs: Arguments to be forwarded to the model's processor, + e.g., tokenizer, image processor, or custom processor callable. """ def __init__(self, @@ -150,7 +152,8 @@ def __init__(self, limit_mm_per_prompt: Optional[Mapping[str, int]] = None, use_async_output_proc: bool = True, override_neuron_config: Optional[Dict[str, Any]] = None, - config_format: ConfigFormat = ConfigFormat.AUTO) -> None: + config_format: ConfigFormat = ConfigFormat.AUTO, + processor_kwargs: Optional[Dict[str, Any]] = None) -> None: self.model = model self.tokenizer = tokenizer self.tokenizer_mode = tokenizer_mode @@ -184,6 +187,7 @@ def __init__(self, self.model, revision) self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) self.use_async_output_proc = use_async_output_proc + self.processor_kwargs = processor_kwargs # Set enforce_eager to False if the value is unset. if self.enforce_eager is None: diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 4139eca9c183..ca1f334de535 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -175,6 +175,7 @@ class EngineArgs: collect_detailed_traces: Optional[str] = None disable_async_output_proc: bool = False override_neuron_config: Optional[Dict[str, Any]] = None + processor_kwargs: Optional[Dict[str, Any]] = None def __post_init__(self): if self.tokenizer is None: @@ -513,6 +514,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: 'e.g.: `image=16,video=2` allows a maximum of 16 ' 'images and 2 videos per prompt. Defaults to 1 for ' 'each modality.')) + parser.add_argument( + '--processor-kwargs', + default=None, + type=json.loads, + help=('Overrides for the model processor, e.g., tokenizer or ' + 'image processor. For example: {"num_crops": 4}.')) # LoRA related configs parser.add_argument('--enable-lora', @@ -822,6 +829,7 @@ def create_model_config(self) -> ModelConfig: use_async_output_proc=not self.disable_async_output_proc, override_neuron_config=self.override_neuron_config, config_format=self.config_format, + processor_kwargs=self.processor_kwargs, ) def create_load_config(self) -> LoadConfig: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 2743d5c7d228..a482cbbe2009 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -235,7 +235,7 @@ def __init__( "decoding_config=%r, observability_config=%r, " "seed=%d, served_model_name=%s, use_v2_block_manager=%s, " "num_scheduler_steps=%d, enable_prefix_caching=%s, " - "use_async_output_proc=%s)", + "use_async_output_proc=%s, processor_kwargs=%s)", VLLM_VERSION, model_config.model, speculative_config, @@ -268,6 +268,7 @@ def __init__( scheduler_config.num_scheduler_steps, cache_config.enable_prefix_caching, model_config.use_async_output_proc, + model_config.processor_kwargs, ) # TODO(woosuk): Print more configs in debug mode. from vllm.plugins import load_general_plugins diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 248b070611cd..6304851233ce 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -134,6 +134,7 @@ def __init__( max_seq_len_to_capture: int = 8192, disable_custom_all_reduce: bool = False, disable_async_output_proc: bool = False, + processor_kwargs=None, **kwargs, ) -> None: ''' @@ -174,6 +175,7 @@ def __init__( max_seq_len_to_capture=max_seq_len_to_capture, disable_custom_all_reduce=disable_custom_all_reduce, disable_async_output_proc=disable_async_output_proc, + processor_kwargs=processor_kwargs, **kwargs, ) self.llm_engine = LLMEngine.from_engine_args( diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index ae6c6c05d9f7..eb816baa6e8c 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -1,4 +1,5 @@ import functools +import inspect from array import array from collections import UserDict from dataclasses import dataclass @@ -245,6 +246,34 @@ def process_input(self, model_config: "ModelConfig", See also: :ref:`input_processing_pipeline` """ + processor = self._get_model_input_processor(model_config) + return processor(InputContext(model_config), inputs) + + def create_input_processor(self, model_config: "ModelConfig"): + """ + Create an input processor (see :meth:`process_input`) for a + specific model. + """ + # Determine which kwargs can be leveraged for the input processor + # and drop + warn for kwargs that are unimplemented. + processor_kwargs = self._get_allowed_kwarg_overrides( + callable=self._get_model_input_processor(model_config), + overrides=model_config.processor_kwargs, + ) + return functools.partial(self.process_input, model_config, + **processor_kwargs) + + def _get_model_input_processor(self, + model_config: "ModelConfig") -> Callable: + """Grabs the input processor for the provided model. + + Args: + model_config: Config whose model architecture we can leverage to + grab the callable input processor. + + Returns: + Callable input processor for this model. + """ # Avoid circular import from vllm.model_executor.model_loader import get_model_architecture @@ -252,12 +281,40 @@ def process_input(self, model_config: "ModelConfig", processor = self._input_processors_by_model_type \ .get(model_cls, self._default_input_processor) + return processor - return processor(InputContext(model_config), inputs) - - def create_input_processor(self, model_config: "ModelConfig"): - """ - Create an input processor (see :meth:`process_input`) for a - specific model. + def _get_allowed_kwarg_overrides( + self, + callable: Callable, + overrides: Optional[Dict[str, Any]], + ) -> Dict[str, Any]: + """Given a callable processor, determine which kwarg overrides provided + via the model config are valid keyword arguments, and drop any that + are not. + + Args: + processor: Callable processor which takes 0 or more kwargs. + model_config: Config which may contain init time processor kwargs. + + Returns: + Dictionary containing the processor kwargs to be wrapped when + creating the callable processor partial. """ - return functools.partial(self.process_input, model_config) + if not isinstance(overrides, dict): + return {} + allowed_kwargs = list(inspect.signature(callable).parameters.keys()) + # Drop any processor_kwargs provided by the user that are + # not kwarg names accepted by the provided input processor. + filtered_overrides = { + kwarg_name: val + for kwarg_name, val in overrides.items() + if kwarg_name in allowed_kwargs + } + + # If anything is dropped, log a warning + dropped_keys = set(overrides) - set(filtered_overrides) + if dropped_keys: + logger.warning( + "The following kwarg overrides are not implemented " + "by the input processor and will be dropped: %s", dropped_keys) + return filtered_overrides From 190606f4d619de75b81af9ff6bba5031b5837393 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Tue, 17 Sep 2024 15:29:40 -0400 Subject: [PATCH 02/31] Pass processor through to partial Signed-off-by: Alex-Brooks --- vllm/inputs/registry.py | 31 +++++++++++++++++++------------ 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index eb816baa6e8c..55d3aaec271f 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -212,7 +212,7 @@ def _default_input_processor(self, ctx: InputContext, """The default input processor is a no-op.""" return inputs - def register_input_processor(self, processor: InputProcessor): + def register_input_processor(self, processor: InputProcessor) -> Callable: """ Register an input processor to a model class. @@ -236,36 +236,42 @@ def wrapper(model_cls: N) -> N: return wrapper - def process_input(self, model_config: "ModelConfig", - inputs: LLMInputs) -> LLMInputs: + def _process_input(self, inputs: LLMInputs, model_config: "ModelConfig", + processor: Callable, **processor_kwargs) -> LLMInputs: """ - Apply an input processor to an instance of model inputs. + Apply an input processor to an instance of model inputs. This will + usually not be invoked be directly, and instead will be wrapped in + a functools partial once the processor is created. The model is identified by ``model_config``. See also: :ref:`input_processing_pipeline` """ - processor = self._get_model_input_processor(model_config) - return processor(InputContext(model_config), inputs) + return processor(InputContext(model_config), inputs, + **processor_kwargs) - def create_input_processor(self, model_config: "ModelConfig"): + def create_input_processor(self, model_config: "ModelConfig") -> Callable: """ - Create an input processor (see :meth:`process_input`) for a + Create an input processor (see :meth:`_process_input`) for a specific model. """ # Determine which kwargs can be leveraged for the input processor # and drop + warn for kwargs that are unimplemented. + processor = self._get_model_input_processor(model_config) processor_kwargs = self._get_allowed_kwarg_overrides( - callable=self._get_model_input_processor(model_config), + callable=processor, overrides=model_config.processor_kwargs, ) - return functools.partial(self.process_input, model_config, + return functools.partial(self._process_input, + model_config=model_config, + processor=processor, **processor_kwargs) def _get_model_input_processor(self, model_config: "ModelConfig") -> Callable: - """Grabs the input processor for the provided model. + """ + Grabs the input processor for the provided model. Args: model_config: Config whose model architecture we can leverage to @@ -288,7 +294,8 @@ def _get_allowed_kwarg_overrides( callable: Callable, overrides: Optional[Dict[str, Any]], ) -> Dict[str, Any]: - """Given a callable processor, determine which kwarg overrides provided + """ + Given a callable processor, determine which kwarg overrides provided via the model config are valid keyword arguments, and drop any that are not. From b1ca0417dcd0548dd7a86e0a9f5f2db5a218f1c6 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Tue, 17 Sep 2024 15:30:46 -0400 Subject: [PATCH 03/31] Add default & processor kwarg override tests Signed-off-by: Alex-Brooks --- tests/multimodal/test_processor.py | 91 ++++++++++++++++++++++++++++++ 1 file changed, 91 insertions(+) create mode 100644 tests/multimodal/test_processor.py diff --git a/tests/multimodal/test_processor.py b/tests/multimodal/test_processor.py new file mode 100644 index 000000000000..f7ecb0b01ddb --- /dev/null +++ b/tests/multimodal/test_processor.py @@ -0,0 +1,91 @@ +import pytest +from vllm.inputs.registry import InputRegistry +from vllm.config import ModelConfig +from unittest.mock import patch +from vllm.inputs import InputContext, LLMInputs + +DUMMY_MODEL_ID = "facebook/opt-125m" +# For processor kwargs - we test overrides by defining a callable with a +# default for the `num_crops`, then override the value through the processor +# kwargs +DEFAULT_NUM_CROPS = 4 +NUM_CROPS_OVERRIDE = 16 + +@pytest.fixture +def processor_mock(): + """Patches the internal model input processor with an override callable.""" + def custom_processor(ctx: InputContext, llm_inputs: LLMInputs, num_crops=DEFAULT_NUM_CROPS): + # For testing purposes, we don't worry about the llm inputs / return + # type validation, and just return the value of the kwarg that we + # clobber. + return num_crops + with patch("vllm.inputs.registry.InputRegistry._get_model_input_processor", return_value=custom_processor): + yield + + +def get_model_config(processor_kwargs=None): + """Creates a handle to a model config, which may have processor kwargs.""" + # NOTE - values / architecture don't matter too much here since we patch + # the return values for stuff like the input processor anyway. + return ModelConfig( + DUMMY_MODEL_ID, + DUMMY_MODEL_ID, + tokenizer_mode="auto", + trust_remote_code=False, + dtype="float16", + seed=0, + processor_kwargs=processor_kwargs + ) + + +def test_default_processor_is_a_noop(): + """Ensure that by default, there is no processor override.""" + dummy_registry = InputRegistry() + model_config = get_model_config() + processor = dummy_registry.create_input_processor(model_config) + proc_inputs = LLMInputs(prompt="foobar") + proc_outputs = processor(inputs=proc_inputs) + # We should get the same object back since this is a no-op by default + assert proc_inputs is proc_outputs + + + +def test_processor_default_kwargs(processor_mock): + """Ensure we can call a processor that has extra kwargs & no overrides.""" + dummy_registry = InputRegistry() + model_config = get_model_config() + processor = dummy_registry.create_input_processor(model_config) + # The patched fixture patches the processor to return the value of + # num_crops in the processor call, which should be 4 by default. + num_crops_val = processor(LLMInputs(prompt="foobar")) + assert num_crops_val == DEFAULT_NUM_CROPS + + +def test_processor_default_kwargs_with_override(processor_mock): + """Ensure we can call a processor that has extra kwargs & no overrides.""" + dummy_registry = InputRegistry() + # Create processor_kwargs to override the value used + # for num_crops in the patched processor callable + model_config = get_model_config( + processor_kwargs={"num_crops": NUM_CROPS_OVERRIDE} + ) + processor = dummy_registry.create_input_processor(model_config) + num_crops_val = processor(LLMInputs(prompt="foobar")) + # Since the patched processor is an echo, we should get the + # override value we passed to processor_kwargs instead. + assert num_crops_val == NUM_CROPS_OVERRIDE + + +def test_processor_with_sad_kwarg_overrides(processor_mock): + """Ensure that processor kwargs that are unused do not fail.""" + dummy_registry = InputRegistry() + # Since the processor does not take `does_not_exist` as an arg, + # it will be filtered, then warn + drop it from the callable + # to prevent the processor from failing. + model_config = get_model_config( + processor_kwargs={"does_not_exist": 100}, + ) + + processor = dummy_registry.create_input_processor(model_config) + num_crops_val = processor(LLMInputs(prompt="foobar")) + assert num_crops_val == DEFAULT_NUM_CROPS From 195e31ccdcf93d871a0acce3d354c7ff82dcc98d Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Tue, 17 Sep 2024 15:49:54 -0400 Subject: [PATCH 04/31] Don't allow ctx or inputs as kwargs Signed-off-by: Alex-Brooks --- tests/multimodal/test_processor.py | 52 +++++++++++++++++++----------- vllm/inputs/registry.py | 16 ++++++++- 2 files changed, 48 insertions(+), 20 deletions(-) diff --git a/tests/multimodal/test_processor.py b/tests/multimodal/test_processor.py index f7ecb0b01ddb..6a9f88be50b4 100644 --- a/tests/multimodal/test_processor.py +++ b/tests/multimodal/test_processor.py @@ -1,8 +1,10 @@ +from unittest.mock import patch + import pytest -from vllm.inputs.registry import InputRegistry + from vllm.config import ModelConfig -from unittest.mock import patch from vllm.inputs import InputContext, LLMInputs +from vllm.inputs.registry import InputRegistry DUMMY_MODEL_ID = "facebook/opt-125m" # For processor kwargs - we test overrides by defining a callable with a @@ -11,15 +13,21 @@ DEFAULT_NUM_CROPS = 4 NUM_CROPS_OVERRIDE = 16 + @pytest.fixture def processor_mock(): """Patches the internal model input processor with an override callable.""" - def custom_processor(ctx: InputContext, llm_inputs: LLMInputs, num_crops=DEFAULT_NUM_CROPS): + + def custom_processor(ctx: InputContext, + llm_inputs: LLMInputs, + num_crops=DEFAULT_NUM_CROPS): # For testing purposes, we don't worry about the llm inputs / return # type validation, and just return the value of the kwarg that we # clobber. return num_crops - with patch("vllm.inputs.registry.InputRegistry._get_model_input_processor", return_value=custom_processor): + + with patch("vllm.inputs.registry.InputRegistry._get_model_input_processor", + return_value=custom_processor): yield @@ -27,15 +35,13 @@ def get_model_config(processor_kwargs=None): """Creates a handle to a model config, which may have processor kwargs.""" # NOTE - values / architecture don't matter too much here since we patch # the return values for stuff like the input processor anyway. - return ModelConfig( - DUMMY_MODEL_ID, - DUMMY_MODEL_ID, - tokenizer_mode="auto", - trust_remote_code=False, - dtype="float16", - seed=0, - processor_kwargs=processor_kwargs - ) + return ModelConfig(DUMMY_MODEL_ID, + DUMMY_MODEL_ID, + tokenizer_mode="auto", + trust_remote_code=False, + dtype="float16", + seed=0, + processor_kwargs=processor_kwargs) def test_default_processor_is_a_noop(): @@ -49,7 +55,6 @@ def test_default_processor_is_a_noop(): assert proc_inputs is proc_outputs - def test_processor_default_kwargs(processor_mock): """Ensure we can call a processor that has extra kwargs & no overrides.""" dummy_registry = InputRegistry() @@ -67,8 +72,7 @@ def test_processor_default_kwargs_with_override(processor_mock): # Create processor_kwargs to override the value used # for num_crops in the patched processor callable model_config = get_model_config( - processor_kwargs={"num_crops": NUM_CROPS_OVERRIDE} - ) + processor_kwargs={"num_crops": NUM_CROPS_OVERRIDE}) processor = dummy_registry.create_input_processor(model_config) num_crops_val = processor(LLMInputs(prompt="foobar")) # Since the patched processor is an echo, we should get the @@ -82,10 +86,20 @@ def test_processor_with_sad_kwarg_overrides(processor_mock): # Since the processor does not take `does_not_exist` as an arg, # it will be filtered, then warn + drop it from the callable # to prevent the processor from failing. - model_config = get_model_config( - processor_kwargs={"does_not_exist": 100}, - ) + model_config = get_model_config(processor_kwargs={"does_not_exist": 100}, ) processor = dummy_registry.create_input_processor(model_config) num_crops_val = processor(LLMInputs(prompt="foobar")) assert num_crops_val == DEFAULT_NUM_CROPS + + +def test_processor_kwargs_cannot_clobber_reserved_kwargs(processor_mock): + """Ensure that special kwargs cannot be overridden.""" + dummy_registry = InputRegistry() + model_config = get_model_config(processor_kwargs={"ctx": + "something bad"}, ) + processor = dummy_registry.create_input_processor(model_config) + # It's good enough to make sure this is callable, because if we had + # an override pushed through, we'd run into issues with multiple + # values provided for a single argument + processor(LLMInputs(prompt="foobar")) diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 55d3aaec271f..305a0daca04a 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -258,11 +258,13 @@ def create_input_processor(self, model_config: "ModelConfig") -> Callable: """ # Determine which kwargs can be leveraged for the input processor # and drop + warn for kwargs that are unimplemented. + # NOTE: we don't allow override values for ctx/inputs, since doing + # so can lead to value collisions etc. processor = self._get_model_input_processor(model_config) processor_kwargs = self._get_allowed_kwarg_overrides( callable=processor, overrides=model_config.processor_kwargs, - ) + immutable_kwargs=("ctx", "inputs")) return functools.partial(self._process_input, model_config=model_config, processor=processor, @@ -293,6 +295,7 @@ def _get_allowed_kwarg_overrides( self, callable: Callable, overrides: Optional[Dict[str, Any]], + immutable_kwargs: Optional[Tuple[str, ...]], ) -> Dict[str, Any]: """ Given a callable processor, determine which kwarg overrides provided @@ -302,6 +305,7 @@ def _get_allowed_kwarg_overrides( Args: processor: Callable processor which takes 0 or more kwargs. model_config: Config which may contain init time processor kwargs. + immutable_kwargs: Reserved kwarg keys that can't be overridden. Returns: Dictionary containing the processor kwargs to be wrapped when @@ -309,6 +313,15 @@ def _get_allowed_kwarg_overrides( """ if not isinstance(overrides, dict): return {} + + if immutable_kwargs: + for name in immutable_kwargs: + if name in overrides: + logger.warning( + "%s is a reserved kwarg and will be dropped " + "from the input processor overrides", name) + del overrides[name] + allowed_kwargs = list(inspect.signature(callable).parameters.keys()) # Drop any processor_kwargs provided by the user that are # not kwarg names accepted by the provided input processor. @@ -324,4 +337,5 @@ def _get_allowed_kwarg_overrides( logger.warning( "The following kwarg overrides are not implemented " "by the input processor and will be dropped: %s", dropped_keys) + return filtered_overrides From 1472d0438edc6ddebfc7fc8991c8504598d49718 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Tue, 17 Sep 2024 17:08:42 -0400 Subject: [PATCH 05/31] Add kwarg override for processor to dummy data factories Signed-off-by: Alex-Brooks --- vllm/inputs/registry.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 305a0daca04a..37ded8edd694 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -74,12 +74,16 @@ def __call__( ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int], + **processor_kwargs: Any, ) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]: """ Create dummy data to be inputted into the model. Note: :data:`InputProcessor` is not applied to the dummy data. + The processor_kwargs are overrides provided at initialization + time to values in the config whose values may affect the number + of tokens per instance. """ ... @@ -185,10 +189,17 @@ def dummy_data_for_profiling( .get(model_cls, self._default_dummy_data_factory) mm_counts = mm_registry.get_mm_limits_per_prompt(model_config) + # Check to see if this model expects additional processor kwargs; + # even though the processor isn't used on the dummy data, values + # passed to it that override the config may have implications on + # the number dummy data, e.g., the number of image tokens per instance. + df_kwargs = self._get_dummy_factory_processor_kwargs( + model_config, dummy_factory) seq_data, mm_data = dummy_factory( InputContext(model_config), seq_len, _MultiModalCounts(mm_counts), + **df_kwargs, ) # Having more tokens is over-conservative but otherwise fine @@ -207,6 +218,21 @@ def dummy_data_for_profiling( return seq_data, mm_data + def _get_dummy_factory_processor_kwargs( + self, model_config: "ModelConfig", + dummy_factory: Callable) -> Dict[str, Any]: + # Dummy factory takes no additional kwargs; presumably this means that + # image processor kwargs have either not been implemented, or they have + # no affect on the token counts. + if len(inspect.signature(dummy_factory).parameters) < 4: + return {} + # Otherwise we may have overrides; filter them in the + # same way we filter the input processor overrides + return self._get_allowed_kwarg_overrides( + callable=dummy_factory, + overrides=model_config.processor_kwargs, + immutable_kwargs=("ctx", "seq_len", "mm_counts")) + def _default_input_processor(self, ctx: InputContext, inputs: LLMInputs) -> LLMInputs: """The default input processor is a no-op.""" From f10601fae34d71dbae12cecbec5bd88361aae392 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Thu, 19 Sep 2024 01:13:12 -0400 Subject: [PATCH 06/31] Add kwarg override forr processor to max token calc Signed-off-by: Alex-Brooks --- vllm/inputs/registry.py | 102 ++++++++++++++++++++-------------------- vllm/multimodal/base.py | 8 +++- 2 files changed, 58 insertions(+), 52 deletions(-) diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 37ded8edd694..7393a883778d 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -28,6 +28,55 @@ VLLM_TOKEN_ID_ARRAY_TYPE = "l" +def get_allowed_kwarg_overrides( + callable: Callable, + overrides: Optional[Dict[str, Any]], + immutable_kwargs: Optional[Tuple[str, ...]], +) -> Dict[str, Any]: + """ + Given a callable processor, determine which kwarg overrides provided + via the model config are valid keyword arguments, and drop any that + are not. + + Args: + processor: Callable processor which takes 0 or more kwargs. + model_config: Config which may contain init time processor kwargs. + immutable_kwargs: Reserved kwarg keys that can't be overridden. + + Returns: + Dictionary containing the processor kwargs to be wrapped when + creating the callable processor partial. + """ + if not isinstance(overrides, dict): + return {} + + if immutable_kwargs: + for name in immutable_kwargs: + if name in overrides: + logger.warning( + "%s is a reserved kwarg and will be dropped " + "from the input processor overrides", name) + del overrides[name] + + allowed_kwargs = list(inspect.signature(callable).parameters.keys()) + # Drop any processor_kwargs provided by the user that are + # not kwarg names accepted by the provided input processor. + filtered_overrides = { + kwarg_name: val + for kwarg_name, val in overrides.items() + if kwarg_name in allowed_kwargs + } + + # If anything is dropped, log a warning + dropped_keys = set(overrides) - set(filtered_overrides) + if dropped_keys: + logger.warning( + "The following kwarg overrides are not implemented " + "by the input processor and will be dropped: %s", dropped_keys) + + return filtered_overrides + + @dataclass(frozen=True) class InputContext: """ @@ -228,7 +277,7 @@ def _get_dummy_factory_processor_kwargs( return {} # Otherwise we may have overrides; filter them in the # same way we filter the input processor overrides - return self._get_allowed_kwarg_overrides( + return get_allowed_kwarg_overrides( callable=dummy_factory, overrides=model_config.processor_kwargs, immutable_kwargs=("ctx", "seq_len", "mm_counts")) @@ -287,7 +336,7 @@ def create_input_processor(self, model_config: "ModelConfig") -> Callable: # NOTE: we don't allow override values for ctx/inputs, since doing # so can lead to value collisions etc. processor = self._get_model_input_processor(model_config) - processor_kwargs = self._get_allowed_kwarg_overrides( + processor_kwargs = get_allowed_kwarg_overrides( callable=processor, overrides=model_config.processor_kwargs, immutable_kwargs=("ctx", "inputs")) @@ -316,52 +365,3 @@ def _get_model_input_processor(self, processor = self._input_processors_by_model_type \ .get(model_cls, self._default_input_processor) return processor - - def _get_allowed_kwarg_overrides( - self, - callable: Callable, - overrides: Optional[Dict[str, Any]], - immutable_kwargs: Optional[Tuple[str, ...]], - ) -> Dict[str, Any]: - """ - Given a callable processor, determine which kwarg overrides provided - via the model config are valid keyword arguments, and drop any that - are not. - - Args: - processor: Callable processor which takes 0 or more kwargs. - model_config: Config which may contain init time processor kwargs. - immutable_kwargs: Reserved kwarg keys that can't be overridden. - - Returns: - Dictionary containing the processor kwargs to be wrapped when - creating the callable processor partial. - """ - if not isinstance(overrides, dict): - return {} - - if immutable_kwargs: - for name in immutable_kwargs: - if name in overrides: - logger.warning( - "%s is a reserved kwarg and will be dropped " - "from the input processor overrides", name) - del overrides[name] - - allowed_kwargs = list(inspect.signature(callable).parameters.keys()) - # Drop any processor_kwargs provided by the user that are - # not kwarg names accepted by the provided input processor. - filtered_overrides = { - kwarg_name: val - for kwarg_name, val in overrides.items() - if kwarg_name in allowed_kwargs - } - - # If anything is dropped, log a warning - dropped_keys = set(overrides) - set(filtered_overrides) - if dropped_keys: - logger.warning( - "The following kwarg overrides are not implemented " - "by the input processor and will be dropped: %s", dropped_keys) - - return filtered_overrides diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index 032964fe0ac4..0623de1d523d 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -13,6 +13,7 @@ from vllm.config import ModelConfig from vllm.inputs import InputContext +from vllm.inputs.registry import get_allowed_kwarg_overrides from vllm.logger import init_logger from vllm.utils import JSONTree, is_list_of, json_map_leaves @@ -333,7 +334,12 @@ def get_max_multimodal_tokens(self, model_config: ModelConfig) -> int: f"for model class {model_cls.__name__} in {self}.") if callable(max_mm_tokens): - max_mm_tokens = max_mm_tokens(InputContext(model_config)) + processor_kwargs = get_allowed_kwarg_overrides( + callable=max_mm_tokens, + overrides=model_config.processor_kwargs, + immutable_kwargs=("ctx",)) + max_mm_tokens = max_mm_tokens(InputContext(model_config), + **processor_kwargs) self._validate_max_multimodal_tokens(max_mm_tokens) From 429097aa71c2cd5984a4e26df895e88863c0a36b Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Thu, 19 Sep 2024 01:18:04 -0400 Subject: [PATCH 07/31] Move kwarg only override func to utils Signed-off-by: Alex-Brooks --- vllm/inputs/registry.py | 54 +++-------------------------------------- vllm/multimodal/base.py | 8 +++--- vllm/utils.py | 50 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 57 insertions(+), 55 deletions(-) diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 7393a883778d..36682adaed40 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -11,6 +11,7 @@ from typing_extensions import TypeVar from vllm.logger import init_logger +from vllm.utils import get_allowed_kwarg_only_overrides from .data import LLMInputs @@ -28,55 +29,6 @@ VLLM_TOKEN_ID_ARRAY_TYPE = "l" -def get_allowed_kwarg_overrides( - callable: Callable, - overrides: Optional[Dict[str, Any]], - immutable_kwargs: Optional[Tuple[str, ...]], -) -> Dict[str, Any]: - """ - Given a callable processor, determine which kwarg overrides provided - via the model config are valid keyword arguments, and drop any that - are not. - - Args: - processor: Callable processor which takes 0 or more kwargs. - model_config: Config which may contain init time processor kwargs. - immutable_kwargs: Reserved kwarg keys that can't be overridden. - - Returns: - Dictionary containing the processor kwargs to be wrapped when - creating the callable processor partial. - """ - if not isinstance(overrides, dict): - return {} - - if immutable_kwargs: - for name in immutable_kwargs: - if name in overrides: - logger.warning( - "%s is a reserved kwarg and will be dropped " - "from the input processor overrides", name) - del overrides[name] - - allowed_kwargs = list(inspect.signature(callable).parameters.keys()) - # Drop any processor_kwargs provided by the user that are - # not kwarg names accepted by the provided input processor. - filtered_overrides = { - kwarg_name: val - for kwarg_name, val in overrides.items() - if kwarg_name in allowed_kwargs - } - - # If anything is dropped, log a warning - dropped_keys = set(overrides) - set(filtered_overrides) - if dropped_keys: - logger.warning( - "The following kwarg overrides are not implemented " - "by the input processor and will be dropped: %s", dropped_keys) - - return filtered_overrides - - @dataclass(frozen=True) class InputContext: """ @@ -277,7 +229,7 @@ def _get_dummy_factory_processor_kwargs( return {} # Otherwise we may have overrides; filter them in the # same way we filter the input processor overrides - return get_allowed_kwarg_overrides( + return get_allowed_kwarg_only_overrides( callable=dummy_factory, overrides=model_config.processor_kwargs, immutable_kwargs=("ctx", "seq_len", "mm_counts")) @@ -336,7 +288,7 @@ def create_input_processor(self, model_config: "ModelConfig") -> Callable: # NOTE: we don't allow override values for ctx/inputs, since doing # so can lead to value collisions etc. processor = self._get_model_input_processor(model_config) - processor_kwargs = get_allowed_kwarg_overrides( + processor_kwargs = get_allowed_kwarg_only_overrides( callable=processor, overrides=model_config.processor_kwargs, immutable_kwargs=("ctx", "inputs")) diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index 0623de1d523d..7af50d4a55aa 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -13,9 +13,9 @@ from vllm.config import ModelConfig from vllm.inputs import InputContext -from vllm.inputs.registry import get_allowed_kwarg_overrides from vllm.logger import init_logger -from vllm.utils import JSONTree, is_list_of, json_map_leaves +from vllm.utils import (JSONTree, get_allowed_kwarg_only_overrides, is_list_of, + json_map_leaves) logger = init_logger(__name__) @@ -334,10 +334,10 @@ def get_max_multimodal_tokens(self, model_config: ModelConfig) -> int: f"for model class {model_cls.__name__} in {self}.") if callable(max_mm_tokens): - processor_kwargs = get_allowed_kwarg_overrides( + processor_kwargs = get_allowed_kwarg_only_overrides( callable=max_mm_tokens, overrides=model_config.processor_kwargs, - immutable_kwargs=("ctx",)) + immutable_kwargs=("ctx", )) max_mm_tokens = max_mm_tokens(InputContext(model_config), **processor_kwargs) diff --git a/vllm/utils.py b/vllm/utils.py index 060b387ec783..45db573809bb 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -4,6 +4,7 @@ import datetime import enum import gc +import inspect import os import random import socket @@ -1237,6 +1238,55 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args, return await task(*args, **kwargs) +def get_allowed_kwarg_only_overrides( + callable: Callable, + overrides: Optional[Dict[str, Any]], + immutable_kwargs: Optional[Tuple[str, ...]], +) -> Dict[str, Any]: + """ + Given a callable processor, determine which kwarg overrides provided + via the model config are valid keyword arguments, and drop any that + are not. + + Args: + processor: Callable processor which takes 0 or more kwargs. + model_config: Config which may contain init time processor kwargs. + immutable_kwargs: Reserved kwarg keys that can't be overridden. + + Returns: + Dictionary containing the processor kwargs to be wrapped when + creating the callable processor partial. + """ + if not isinstance(overrides, dict): + return {} + + if immutable_kwargs: + for name in immutable_kwargs: + if name in overrides: + logger.warning( + "%s is a reserved kwarg and will be dropped " + "from the input processor overrides", name) + del overrides[name] + + allowed_kwargs = list(inspect.signature(callable).parameters.keys()) + # Drop any processor_kwargs provided by the user that are + # not kwarg names accepted by the provided input processor. + filtered_overrides = { + kwarg_name: val + for kwarg_name, val in overrides.items() + if kwarg_name in allowed_kwargs + } + + # If anything is dropped, log a warning + dropped_keys = set(overrides) - set(filtered_overrides) + if dropped_keys: + logger.warning( + "The following kwarg overrides are not implemented " + "by the input processor and will be dropped: %s", dropped_keys) + + return filtered_overrides + + # Using dynamo with vLLM doesn't really work well with PyTorch versions < 2.4.0. # In particular, the FakeScalarType is not supported for earlier versions of # PyTorch which breaks dynamo for any ops registered using ScalarType. From 159cfc26d5a2153e4ef1b450611b99475d83a19c Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Thu, 19 Sep 2024 01:41:21 -0400 Subject: [PATCH 08/31] Force processor kwargs to be keyword-only Signed-off-by: Alex-Brooks --- tests/multimodal/test_processor.py | 2 ++ vllm/inputs/registry.py | 8 ++----- vllm/multimodal/base.py | 3 +-- vllm/utils.py | 38 ++++++++++++++---------------- 4 files changed, 23 insertions(+), 28 deletions(-) diff --git a/tests/multimodal/test_processor.py b/tests/multimodal/test_processor.py index 6a9f88be50b4..c38bf9078356 100644 --- a/tests/multimodal/test_processor.py +++ b/tests/multimodal/test_processor.py @@ -18,8 +18,10 @@ def processor_mock(): """Patches the internal model input processor with an override callable.""" + # NOTE: processor kwargs must be keyword-only. def custom_processor(ctx: InputContext, llm_inputs: LLMInputs, + *, num_crops=DEFAULT_NUM_CROPS): # For testing purposes, we don't worry about the llm inputs / return # type validation, and just return the value of the kwarg that we diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 36682adaed40..7d669c311591 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -230,9 +230,7 @@ def _get_dummy_factory_processor_kwargs( # Otherwise we may have overrides; filter them in the # same way we filter the input processor overrides return get_allowed_kwarg_only_overrides( - callable=dummy_factory, - overrides=model_config.processor_kwargs, - immutable_kwargs=("ctx", "seq_len", "mm_counts")) + callable=dummy_factory, overrides=model_config.processor_kwargs) def _default_input_processor(self, ctx: InputContext, inputs: LLMInputs) -> LLMInputs: @@ -289,9 +287,7 @@ def create_input_processor(self, model_config: "ModelConfig") -> Callable: # so can lead to value collisions etc. processor = self._get_model_input_processor(model_config) processor_kwargs = get_allowed_kwarg_only_overrides( - callable=processor, - overrides=model_config.processor_kwargs, - immutable_kwargs=("ctx", "inputs")) + callable=processor, overrides=model_config.processor_kwargs) return functools.partial(self._process_input, model_config=model_config, processor=processor, diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index 7af50d4a55aa..b0118c71c26a 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -336,8 +336,7 @@ def get_max_multimodal_tokens(self, model_config: ModelConfig) -> int: if callable(max_mm_tokens): processor_kwargs = get_allowed_kwarg_only_overrides( callable=max_mm_tokens, - overrides=model_config.processor_kwargs, - immutable_kwargs=("ctx", )) + overrides=model_config.processor_kwargs) max_mm_tokens = max_mm_tokens(InputContext(model_config), **processor_kwargs) diff --git a/vllm/utils.py b/vllm/utils.py index 45db573809bb..22c6804246a5 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1241,48 +1241,46 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args, def get_allowed_kwarg_only_overrides( callable: Callable, overrides: Optional[Dict[str, Any]], - immutable_kwargs: Optional[Tuple[str, ...]], ) -> Dict[str, Any]: """ - Given a callable processor, determine which kwarg overrides provided - via the model config are valid keyword arguments, and drop any that - are not. + Given a callable which has one or more keyword only params and a dict + mapping param names to values, drop values that can be not be kwarg + expanded to overwrite one or more keyword-only args. This is used in a + few places to handle custom processor overrides for multimodal models, + e.g., for profiling when processor options provided by the user + may affect the number of mm tokens per instance. Args: - processor: Callable processor which takes 0 or more kwargs. - model_config: Config which may contain init time processor kwargs. - immutable_kwargs: Reserved kwarg keys that can't be overridden. + callable: Callable which takes 0 or more keyword only arguments. + overrides: Potential overrides to be used when invoking the callable. Returns: - Dictionary containing the processor kwargs to be wrapped when - creating the callable processor partial. + Dictionary containing the kwargs to be leveraged which may be used + to overwrite one or more keyword only arguments when invoking the + callable. """ if not isinstance(overrides, dict): return {} - if immutable_kwargs: - for name in immutable_kwargs: - if name in overrides: - logger.warning( - "%s is a reserved kwarg and will be dropped " - "from the input processor overrides", name) - del overrides[name] + allowed_override_names = [ + name for name, param in inspect.signature(callable).parameters.items() + if param.kind == inspect.Parameter.KEYWORD_ONLY + ] - allowed_kwargs = list(inspect.signature(callable).parameters.keys()) # Drop any processor_kwargs provided by the user that are # not kwarg names accepted by the provided input processor. filtered_overrides = { kwarg_name: val for kwarg_name, val in overrides.items() - if kwarg_name in allowed_kwargs + if kwarg_name in allowed_override_names } # If anything is dropped, log a warning dropped_keys = set(overrides) - set(filtered_overrides) if dropped_keys: logger.warning( - "The following kwarg overrides are not implemented " - "by the input processor and will be dropped: %s", dropped_keys) + "The following intended overrides are not keyword-only args " + "and and will be dropped: %s", dropped_keys) return filtered_overrides From af919301fc9eec8c6962394907d2abc9971d2967 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Thu, 19 Sep 2024 03:05:04 -0400 Subject: [PATCH 09/31] Pass unfiltered processor kwargs to default mapper Signed-off-by: Alex-Brooks --- vllm/multimodal/image.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py index 6cdde949bc2b..137a574c7d1a 100644 --- a/vllm/multimodal/image.py +++ b/vllm/multimodal/image.py @@ -1,4 +1,5 @@ from functools import lru_cache +from typing import Any, Dict import torch from PIL import Image @@ -22,10 +23,12 @@ class ImagePlugin(MultiModalPlugin): def get_data_key(self) -> str: return "image" - def _get_hf_image_processor(self, model_config: ModelConfig): + def _get_hf_image_processor(self, model_config: ModelConfig, + processor_kwargs: Dict[str, Any]): return cached_get_image_processor( model_config.model, - trust_remote_code=model_config.trust_remote_code) + trust_remote_code=model_config.trust_remote_code, + **processor_kwargs) def _default_input_mapper( self, @@ -36,7 +39,13 @@ def _default_input_mapper( # PIL image if isinstance(data, Image.Image) or is_list_of(data, Image.Image): - image_processor = self._get_hf_image_processor(model_config) + processor_kwargs = ({} if model_config.processor_kwargs is None + else model_config.processor_kwargs) + + image_processor = self._get_hf_image_processor( + model_config, + processor_kwargs, + ) if image_processor is None: raise RuntimeError("No HuggingFace processor is available " "to process the image object") From 9adad10623580ad048ad17ae236d06b60c668969 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Thu, 19 Sep 2024 03:57:25 -0400 Subject: [PATCH 10/31] Add hack for mapper preprocessor kwargs Signed-off-by: Alex-Brooks --- vllm/inputs/registry.py | 12 +++++------- vllm/multimodal/base.py | 5 ++++- vllm/multimodal/registry.py | 8 ++++++++ vllm/utils.py | 4 ++-- 4 files changed, 19 insertions(+), 10 deletions(-) diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 7d669c311591..d422ddedd38a 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -194,14 +194,12 @@ def dummy_data_for_profiling( # even though the processor isn't used on the dummy data, values # passed to it that override the config may have implications on # the number dummy data, e.g., the number of image tokens per instance. - df_kwargs = self._get_dummy_factory_processor_kwargs( + processor_kwargs = self._get_dummy_factory_processor_kwargs( model_config, dummy_factory) - seq_data, mm_data = dummy_factory( - InputContext(model_config), - seq_len, - _MultiModalCounts(mm_counts), - **df_kwargs, - ) + + seq_data, mm_data = dummy_factory(InputContext(model_config), seq_len, + _MultiModalCounts(mm_counts), + **processor_kwargs) # Having more tokens is over-conservative but otherwise fine num_tokens = seq_data.prompt_token_ids diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index b0118c71c26a..d0d11e6cbda7 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -257,11 +257,14 @@ def map_input(self, model_config: ModelConfig, model_cls, _ = get_model_architecture(model_config) mapper = self._input_mappers.get(model_cls) + processor_kwargs = get_allowed_kwarg_only_overrides( + callable=mapper, overrides=model_config.processor_kwargs) + if mapper is None: raise KeyError(f"No input mapper in {self} is registered for " f"model class {model_cls.__name__}.") - return mapper(InputContext(model_config), data) + return mapper(InputContext(model_config), data, **processor_kwargs) @abstractmethod def _default_max_multimodal_tokens(self, ctx: InputContext) -> int: diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index 745fc715caf4..f1c56226e044 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -138,6 +138,14 @@ def create_input_mapper(self, model_config: ModelConfig): """ Create an input mapper (see :meth:`map_input`) for a specific model. """ + # TODO - there is a bit of weirdness here in the way mapper handles + # the args, because for the HF one, we pass processor_kwargs at init + # time and don't need them at func time, for the function's we are + # wrapping in processor like interfaces, we pass them at the time + # of invocation. + # + # Currently it works, but warns when the default processor is used, + # which is bad. return functools.partial(self.map_input, model_config) def register_max_multimodal_tokens( diff --git a/vllm/utils.py b/vllm/utils.py index 22c6804246a5..e1b8ccfd6aad 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1239,7 +1239,7 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args, def get_allowed_kwarg_only_overrides( - callable: Callable, + callable: Optional[Callable], overrides: Optional[Dict[str, Any]], ) -> Dict[str, Any]: """ @@ -1259,7 +1259,7 @@ def get_allowed_kwarg_only_overrides( to overwrite one or more keyword only arguments when invoking the callable. """ - if not isinstance(overrides, dict): + if not overrides or not callable: return {} allowed_override_names = [ From 9f7aed8e7e6db89146668c199a8b3926834e6797 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Thu, 19 Sep 2024 18:38:18 -0400 Subject: [PATCH 11/31] Simplify dummy data processor kwarg & add tests Signed-off-by: Alex-Brooks --- tests/multimodal/test_processor.py | 167 +++++++++++++++++++---------- vllm/inputs/registry.py | 22 +--- 2 files changed, 114 insertions(+), 75 deletions(-) diff --git a/tests/multimodal/test_processor.py b/tests/multimodal/test_processor.py index c38bf9078356..cc56f26795a0 100644 --- a/tests/multimodal/test_processor.py +++ b/tests/multimodal/test_processor.py @@ -1,3 +1,5 @@ +from array import array +from typing import Mapping from unittest.mock import patch import pytest @@ -6,19 +8,35 @@ from vllm.inputs import InputContext, LLMInputs from vllm.inputs.registry import InputRegistry +from vllm.multimodal import MultiModalRegistry +from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData + + + DUMMY_MODEL_ID = "facebook/opt-125m" -# For processor kwargs - we test overrides by defining a callable with a -# default for the `num_crops`, then override the value through the processor -# kwargs +# For processor_kwargs - we test overrides by defining mocks for each place +# it is used, and ensuring that we can pass processor kwargs an override value +# to receive the intended result for things like sequence length etc. DEFAULT_NUM_CROPS = 4 NUM_CROPS_OVERRIDE = 16 +def get_model_config(processor_kwargs=None): + """Creates a handle to a model config, which may have processor kwargs.""" + # NOTE - values / architecture don't matter too much here since we patch + # the return values for stuff like the input processor anyway. + return ModelConfig(DUMMY_MODEL_ID, + DUMMY_MODEL_ID, + tokenizer_mode="auto", + trust_remote_code=False, + dtype="float16", + seed=0, + processor_kwargs=processor_kwargs) +# Mocks for all of the places that we use the processor_kwargs +# to override values in different callables @pytest.fixture -def processor_mock(): +def use_processor_mock(): """Patches the internal model input processor with an override callable.""" - - # NOTE: processor kwargs must be keyword-only. def custom_processor(ctx: InputContext, llm_inputs: LLMInputs, *, @@ -32,76 +50,111 @@ def custom_processor(ctx: InputContext, return_value=custom_processor): yield - -def get_model_config(processor_kwargs=None): - """Creates a handle to a model config, which may have processor kwargs.""" - # NOTE - values / architecture don't matter too much here since we patch - # the return values for stuff like the input processor anyway. - return ModelConfig(DUMMY_MODEL_ID, - DUMMY_MODEL_ID, - tokenizer_mode="auto", - trust_remote_code=False, - dtype="float16", - seed=0, - processor_kwargs=processor_kwargs) +@pytest.fixture +def use_dummy_data_mock(): + """Patches the internal model input processor with an override callable.""" + def custom_dummy_data_factory(self, + ctx: InputContext, + seq_len: int, + mm_counts: Mapping[str, int], + *, + num_crops=DEFAULT_NUM_CROPS): + seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * num_crops)) + return seq_data, None + + with patch("vllm.inputs.registry.InputRegistry._default_dummy_data_factory", + custom_dummy_data_factory): + yield +### Test for default processor logic & processor_kwargs wrapping def test_default_processor_is_a_noop(): """Ensure that by default, there is no processor override.""" dummy_registry = InputRegistry() model_config = get_model_config() processor = dummy_registry.create_input_processor(model_config) - proc_inputs = LLMInputs(prompt="foobar") + proc_inputs = LLMInputs(prompt_token_ids=[], prompt="") proc_outputs = processor(inputs=proc_inputs) - # We should get the same object back since this is a no-op by default assert proc_inputs is proc_outputs - -def test_processor_default_kwargs(processor_mock): - """Ensure we can call a processor that has extra kwargs & no overrides.""" +@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE]) +def test_processor_default_kwargs(use_processor_mock, num_crops): + """Ensure that we can override processor kwargs.""" dummy_registry = InputRegistry() - model_config = get_model_config() + # If we have a value for num_crops, pass the override value and make + # sure we get that value as a return-value from out mock processor, + # otherwise fall back to the default value + processor_kwargs = None if num_crops is None else {"num_crops": num_crops} + expected_num_crops = DEFAULT_NUM_CROPS if num_crops is None else num_crops + model_config = get_model_config(processor_kwargs=processor_kwargs) processor = dummy_registry.create_input_processor(model_config) - # The patched fixture patches the processor to return the value of - # num_crops in the processor call, which should be 4 by default. - num_crops_val = processor(LLMInputs(prompt="foobar")) - assert num_crops_val == DEFAULT_NUM_CROPS - -def test_processor_default_kwargs_with_override(processor_mock): - """Ensure we can call a processor that has extra kwargs & no overrides.""" - dummy_registry = InputRegistry() - # Create processor_kwargs to override the value used - # for num_crops in the patched processor callable - model_config = get_model_config( - processor_kwargs={"num_crops": NUM_CROPS_OVERRIDE}) - processor = dummy_registry.create_input_processor(model_config) - num_crops_val = processor(LLMInputs(prompt="foobar")) - # Since the patched processor is an echo, we should get the - # override value we passed to processor_kwargs instead. - assert num_crops_val == NUM_CROPS_OVERRIDE + num_crops_val = processor(LLMInputs(prompt_token_ids=[], prompt="")) + assert num_crops_val == expected_num_crops -def test_processor_with_sad_kwarg_overrides(processor_mock): - """Ensure that processor kwargs that are unused do not fail.""" +@pytest.mark.parametrize("processor_kwargs", + [ + {"does_not_exist": 100}, # Not part of the signature + {"ctx": "something bad"} # Part of the signature, not keyword only + ] +) +def test_processor_with_sad_kwarg_overrides(use_processor_mock, + processor_kwargs): + """Ensure invalid processor_kwargs can't be used in the input processor.""" dummy_registry = InputRegistry() - # Since the processor does not take `does_not_exist` as an arg, - # it will be filtered, then warn + drop it from the callable - # to prevent the processor from failing. - model_config = get_model_config(processor_kwargs={"does_not_exist": 100}, ) + + model_config = get_model_config(processor_kwargs=processor_kwargs) processor = dummy_registry.create_input_processor(model_config) - num_crops_val = processor(LLMInputs(prompt="foobar")) + num_crops_val = processor(LLMInputs(prompt_token_ids=[], prompt="")) assert num_crops_val == DEFAULT_NUM_CROPS -def test_processor_kwargs_cannot_clobber_reserved_kwargs(processor_mock): - """Ensure that special kwargs cannot be overridden.""" +### Test overrides for the dummy data +@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE]) +def test_dummy_data_kwarg_overrides(use_dummy_data_mock, num_crops): + processor_kwargs = None if num_crops is None else {"num_crops": num_crops} + expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops dummy_registry = InputRegistry() - model_config = get_model_config(processor_kwargs={"ctx": - "something bad"}, ) - processor = dummy_registry.create_input_processor(model_config) - # It's good enough to make sure this is callable, because if we had - # an override pushed through, we'd run into issues with multiple - # values provided for a single argument - processor(LLMInputs(prompt="foobar")) + model_config = get_model_config( + processor_kwargs=processor_kwargs, + ) + mm_registry = MultiModalRegistry() + mm_registry.init_mm_limits_per_prompt(model_config) + + # NOTE: seq_len is thrown away here since this will leverage the + # default dummy data factory that we have patched in, whose seq + # len is solely dependent on the value of the processor_kwargs. + seq_data, _ = dummy_registry.dummy_data_for_profiling( + model_config, + seq_len=-1, + mm_registry=mm_registry + ) + assert len(seq_data.prompt_token_ids) == expected_seq_count + + +@pytest.mark.parametrize("processor_kwargs", + [ + {"does_not_exist": 100}, # Not part of the signature + {"ctx": "something bad"} # Part of the signature, not keyword only + ] +) +def test_dummy_data_with_sad_kwarg_overrides(use_dummy_data_mock, processor_kwargs): + """Ensure that dummy_data kwargs that are unused do not fail.""" + dummy_registry = InputRegistry() + model_config = get_model_config( + processor_kwargs=processor_kwargs, + ) + mm_registry = MultiModalRegistry() + mm_registry.init_mm_limits_per_prompt(model_config) + + # NOTE: seq_len is thrown away here since this will leverage the + # default dummy data factory that we have patched in, whose seq + # len is solely dependent on the value of the processor_kwargs. + seq_data, _ = dummy_registry.dummy_data_for_profiling( + model_config, + seq_len=-1, + mm_registry=mm_registry + ) + assert len(seq_data.prompt_token_ids) == DEFAULT_NUM_CROPS diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index d422ddedd38a..7b74c24d7344 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -190,12 +190,10 @@ def dummy_data_for_profiling( .get(model_cls, self._default_dummy_data_factory) mm_counts = mm_registry.get_mm_limits_per_prompt(model_config) - # Check to see if this model expects additional processor kwargs; - # even though the processor isn't used on the dummy data, values - # passed to it that override the config may have implications on - # the number dummy data, e.g., the number of image tokens per instance. - processor_kwargs = self._get_dummy_factory_processor_kwargs( - model_config, dummy_factory) + processor_kwargs = get_allowed_kwarg_only_overrides( + callable=dummy_factory, + overrides=model_config.processor_kwargs + ) seq_data, mm_data = dummy_factory(InputContext(model_config), seq_len, _MultiModalCounts(mm_counts), @@ -217,18 +215,6 @@ def dummy_data_for_profiling( return seq_data, mm_data - def _get_dummy_factory_processor_kwargs( - self, model_config: "ModelConfig", - dummy_factory: Callable) -> Dict[str, Any]: - # Dummy factory takes no additional kwargs; presumably this means that - # image processor kwargs have either not been implemented, or they have - # no affect on the token counts. - if len(inspect.signature(dummy_factory).parameters) < 4: - return {} - # Otherwise we may have overrides; filter them in the - # same way we filter the input processor overrides - return get_allowed_kwarg_only_overrides( - callable=dummy_factory, overrides=model_config.processor_kwargs) def _default_input_processor(self, ctx: InputContext, inputs: LLMInputs) -> LLMInputs: From ff59e44c351ed34f4c2e14ea136752bae7d93856 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Thu, 19 Sep 2024 19:58:05 -0400 Subject: [PATCH 12/31] Add tests for max multimodal token kwarg overrides Signed-off-by: Alex-Brooks --- tests/multimodal/test_processor.py | 146 +++++++++++++++++++++++------ 1 file changed, 115 insertions(+), 31 deletions(-) diff --git a/tests/multimodal/test_processor.py b/tests/multimodal/test_processor.py index cc56f26795a0..f94a4f5abce6 100644 --- a/tests/multimodal/test_processor.py +++ b/tests/multimodal/test_processor.py @@ -7,19 +7,22 @@ from vllm.config import ModelConfig from vllm.inputs import InputContext, LLMInputs from vllm.inputs.registry import InputRegistry - from vllm.multimodal import MultiModalRegistry from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData +from vllm.model_executor.models.phi3v import Phi3VForCausalLM - - +# Used for fast tests where the model doesn't matter DUMMY_MODEL_ID = "facebook/opt-125m" +# Used for tests that need a multimodal model +MULTIMODAL_MODEL_ID = "microsoft/Phi-3.5-vision-instruct" + # For processor_kwargs - we test overrides by defining mocks for each place # it is used, and ensuring that we can pass processor kwargs an override value # to receive the intended result for things like sequence length etc. DEFAULT_NUM_CROPS = 4 NUM_CROPS_OVERRIDE = 16 + def get_model_config(processor_kwargs=None): """Creates a handle to a model config, which may have processor kwargs.""" # NOTE - values / architecture don't matter too much here since we patch @@ -32,11 +35,13 @@ def get_model_config(processor_kwargs=None): seed=0, processor_kwargs=processor_kwargs) + # Mocks for all of the places that we use the processor_kwargs # to override values in different callables @pytest.fixture def use_processor_mock(): """Patches the internal model input processor with an override callable.""" + def custom_processor(ctx: InputContext, llm_inputs: LLMInputs, *, @@ -50,23 +55,31 @@ def custom_processor(ctx: InputContext, return_value=custom_processor): yield + @pytest.fixture def use_dummy_data_mock(): """Patches the internal model input processor with an override callable.""" + def custom_dummy_data_factory(self, ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int], *, num_crops=DEFAULT_NUM_CROPS): - seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * num_crops)) + seq_data = SequenceData( + array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * num_crops)) return seq_data, None - with patch("vllm.inputs.registry.InputRegistry._default_dummy_data_factory", - custom_dummy_data_factory): + with patch( + "vllm.inputs.registry.InputRegistry._default_dummy_data_factory", + custom_dummy_data_factory): yield +# lambda whose signature matches max token calcs + extra kwargs +get_num_crops = lambda ctx, *, num_crops=DEFAULT_NUM_CROPS: num_crops + + ### Test for default processor logic & processor_kwargs wrapping def test_default_processor_is_a_noop(): """Ensure that by default, there is no processor override.""" @@ -77,6 +90,7 @@ def test_default_processor_is_a_noop(): proc_outputs = processor(inputs=proc_inputs) assert proc_inputs is proc_outputs + @pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE]) def test_processor_default_kwargs(use_processor_mock, num_crops): """Ensure that we can override processor kwargs.""" @@ -93,12 +107,16 @@ def test_processor_default_kwargs(use_processor_mock, num_crops): assert num_crops_val == expected_num_crops -@pytest.mark.parametrize("processor_kwargs", +@pytest.mark.parametrize( + "processor_kwargs", [ - {"does_not_exist": 100}, # Not part of the signature - {"ctx": "something bad"} # Part of the signature, not keyword only - ] -) + { + "does_not_exist": 100 + }, # Not part of the signature + { + "ctx": "something bad" + } # Part of the signature, not keyword only + ]) def test_processor_with_sad_kwarg_overrides(use_processor_mock, processor_kwargs): """Ensure invalid processor_kwargs can't be used in the input processor.""" @@ -117,9 +135,7 @@ def test_dummy_data_kwarg_overrides(use_dummy_data_mock, num_crops): processor_kwargs = None if num_crops is None else {"num_crops": num_crops} expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops dummy_registry = InputRegistry() - model_config = get_model_config( - processor_kwargs=processor_kwargs, - ) + model_config = get_model_config(processor_kwargs=processor_kwargs, ) mm_registry = MultiModalRegistry() mm_registry.init_mm_limits_per_prompt(model_config) @@ -127,25 +143,25 @@ def test_dummy_data_kwarg_overrides(use_dummy_data_mock, num_crops): # default dummy data factory that we have patched in, whose seq # len is solely dependent on the value of the processor_kwargs. seq_data, _ = dummy_registry.dummy_data_for_profiling( - model_config, - seq_len=-1, - mm_registry=mm_registry - ) + model_config, seq_len=-1, mm_registry=mm_registry) assert len(seq_data.prompt_token_ids) == expected_seq_count -@pytest.mark.parametrize("processor_kwargs", +@pytest.mark.parametrize( + "processor_kwargs", [ - {"does_not_exist": 100}, # Not part of the signature - {"ctx": "something bad"} # Part of the signature, not keyword only - ] -) -def test_dummy_data_with_sad_kwarg_overrides(use_dummy_data_mock, processor_kwargs): + { + "does_not_exist": 100 + }, # Not part of the signature + { + "ctx": "something bad" + } # Part of the signature, not keyword only + ]) +def test_dummy_data_with_sad_kwarg_overrides(use_dummy_data_mock, + processor_kwargs): """Ensure that dummy_data kwargs that are unused do not fail.""" dummy_registry = InputRegistry() - model_config = get_model_config( - processor_kwargs=processor_kwargs, - ) + model_config = get_model_config(processor_kwargs=processor_kwargs, ) mm_registry = MultiModalRegistry() mm_registry.init_mm_limits_per_prompt(model_config) @@ -153,8 +169,76 @@ def test_dummy_data_with_sad_kwarg_overrides(use_dummy_data_mock, processor_kwar # default dummy data factory that we have patched in, whose seq # len is solely dependent on the value of the processor_kwargs. seq_data, _ = dummy_registry.dummy_data_for_profiling( - model_config, - seq_len=-1, - mm_registry=mm_registry - ) + model_config, seq_len=-1, mm_registry=mm_registry) assert len(seq_data.prompt_token_ids) == DEFAULT_NUM_CROPS + + +### Test overrides for the max token count per multimodal instance +@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE]) +def test_max_tokens_kwarg_overrides(num_crops): + processor_kwargs = None if num_crops is None else {"num_crops": num_crops} + expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops + + model_config = ModelConfig( + MULTIMODAL_MODEL_ID, + MULTIMODAL_MODEL_ID, + tokenizer_mode="auto", + trust_remote_code=True, + dtype="float16", + seed=0, + processor_kwargs=processor_kwargs, + limit_mm_per_prompt={"image": 1}, + ) + + mm_registry = MultiModalRegistry() + mm_registry.init_mm_limits_per_prompt(model_config) + # Patch the image registry for phi3v with our lambda that is compatible + # with overrides, then ensure that calling the method correctly echos + # our num_crops value back from the processor_kwargs. + with patch.object( + mm_registry._get_plugin("image"), + "_max_mm_tokens", + {Phi3VForCausalLM: get_num_crops}, + ): + max_multimodal_tokens = mm_registry.get_max_multimodal_tokens( + model_config) + + assert expected_seq_count == max_multimodal_tokens + + +@pytest.mark.parametrize( + "processor_kwargs", + [ + { + "does_not_exist": 100 + }, # Not part of the signature + { + "ctx": "something bad" + } # Part of the signature, not keyword only + ]) +def test_max_tokens_with_sad_kwarg_overrides(processor_kwargs): + model_config = ModelConfig( + MULTIMODAL_MODEL_ID, + MULTIMODAL_MODEL_ID, + tokenizer_mode="auto", + trust_remote_code=True, + dtype="float16", + seed=0, + processor_kwargs=processor_kwargs, + limit_mm_per_prompt={"image": 1}, + ) + + mm_registry = MultiModalRegistry() + mm_registry.init_mm_limits_per_prompt(model_config) + + # Similar before, but since these kwargs get filtered, + # we always get our default value back. + with patch.object( + mm_registry._get_plugin("image"), + "_max_mm_tokens", + {Phi3VForCausalLM: get_num_crops}, + ): + max_multimodal_tokens = mm_registry.get_max_multimodal_tokens( + model_config) + + assert max_multimodal_tokens == DEFAULT_NUM_CROPS From 6b264547e55c943faf8924955d5e9635d6189d1d Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Thu, 19 Sep 2024 20:52:15 -0400 Subject: [PATCH 13/31] Format registry Signed-off-by: Alex-Brooks --- vllm/inputs/registry.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 7b74c24d7344..08f516e17aef 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -1,5 +1,4 @@ import functools -import inspect from array import array from collections import UserDict from dataclasses import dataclass @@ -191,9 +190,7 @@ def dummy_data_for_profiling( mm_counts = mm_registry.get_mm_limits_per_prompt(model_config) processor_kwargs = get_allowed_kwarg_only_overrides( - callable=dummy_factory, - overrides=model_config.processor_kwargs - ) + callable=dummy_factory, overrides=model_config.processor_kwargs) seq_data, mm_data = dummy_factory(InputContext(model_config), seq_len, _MultiModalCounts(mm_counts), @@ -215,7 +212,6 @@ def dummy_data_for_profiling( return seq_data, mm_data - def _default_input_processor(self, ctx: InputContext, inputs: LLMInputs) -> LLMInputs: """The default input processor is a no-op.""" From 0e2d53d9baacb6dc0c77e1299bbdd77d4c1c1257 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Thu, 19 Sep 2024 21:06:34 -0400 Subject: [PATCH 14/31] Fix default mapper comparison Signed-off-by: Alex-Brooks --- vllm/multimodal/base.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index d0d11e6cbda7..06cbb528c34b 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -257,8 +257,14 @@ def map_input(self, model_config: ModelConfig, model_cls, _ = get_model_architecture(model_config) mapper = self._input_mappers.get(model_cls) - processor_kwargs = get_allowed_kwarg_only_overrides( - callable=mapper, overrides=model_config.processor_kwargs) + # Only get processor kwargs at mapping time if we are not using the + # input mapper; no overrides are used on the default here because they + # should be passed to the huggingface resource at initialization time. + if mapper != self._default_input_mapper: + processor_kwargs = get_allowed_kwarg_only_overrides( + callable=mapper, overrides=model_config.processor_kwargs) + else: + processor_kwargs = {} if mapper is None: raise KeyError(f"No input mapper in {self} is registered for " From 5a3341bd75c3d1ef8b3d696736d34d927085b7d2 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Thu, 19 Sep 2024 21:07:02 -0400 Subject: [PATCH 15/31] Move kwarg filtering into hf processor getter Signed-off-by: Alex-Brooks --- vllm/multimodal/image.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py index 137a574c7d1a..c2657a112173 100644 --- a/vllm/multimodal/image.py +++ b/vllm/multimodal/image.py @@ -1,5 +1,4 @@ from functools import lru_cache -from typing import Any, Dict import torch from PIL import Image @@ -23,8 +22,11 @@ class ImagePlugin(MultiModalPlugin): def get_data_key(self) -> str: return "image" - def _get_hf_image_processor(self, model_config: ModelConfig, - processor_kwargs: Dict[str, Any]): + def _get_hf_image_processor(self, model_config: ModelConfig): + processor_kwargs = ({} if model_config.processor_kwargs is None else + model_config.processor_kwargs) + # We don't explicitly check kwarg overrides to the HF class + # since the automodel just takes kwargs, so we can't inspect it return cached_get_image_processor( model_config.model, trust_remote_code=model_config.trust_remote_code, @@ -39,13 +41,8 @@ def _default_input_mapper( # PIL image if isinstance(data, Image.Image) or is_list_of(data, Image.Image): - processor_kwargs = ({} if model_config.processor_kwargs is None - else model_config.processor_kwargs) + image_processor = self._get_hf_image_processor(model_config) - image_processor = self._get_hf_image_processor( - model_config, - processor_kwargs, - ) if image_processor is None: raise RuntimeError("No HuggingFace processor is available " "to process the image object") From 3e1fe54acc605b7d4f62a5e0235dcb147881e8f5 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Thu, 19 Sep 2024 21:08:01 -0400 Subject: [PATCH 16/31] Enable processor_kwargs in video processor Signed-off-by: Alex-Brooks --- vllm/multimodal/video.py | 7 ++++++- vllm/transformers_utils/image_processor.py | 9 ++++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/vllm/multimodal/video.py b/vllm/multimodal/video.py index 4401d1315792..aff920977662 100644 --- a/vllm/multimodal/video.py +++ b/vllm/multimodal/video.py @@ -37,9 +37,14 @@ def get_data_key(self) -> str: return "video" def _get_hf_video_processor(self, model_config: ModelConfig): + processor_kwargs = ({} if model_config.processor_kwargs is None else + model_config.processor_kwargs) + # We don't explicitly check kwarg overrides to the HF class + # since the automodel just takes kwargs, so we can't inspect it return cached_get_video_processor( model_config.model, - trust_remote_code=model_config.trust_remote_code) + trust_remote_code=model_config.trust_remote_code, + **processor_kwargs) def _default_input_mapper( self, diff --git a/vllm/transformers_utils/image_processor.py b/vllm/transformers_utils/image_processor.py index 4cffac3724ba..61b338972e0f 100644 --- a/vllm/transformers_utils/image_processor.py +++ b/vllm/transformers_utils/image_processor.py @@ -3,7 +3,9 @@ def get_video_processor( processor_name: str, + *args, trust_remote_code: bool = False, + **kwargs, ): """ Gets a processor for the given model name via HuggingFace. @@ -11,7 +13,12 @@ def get_video_processor( from transformers import AutoProcessor try: - processor = AutoProcessor.from_pretrained(processor_name) + processor = AutoProcessor.from_pretrained( + processor_name, + *args, + trust_remote_code=trust_remote_code, + **kwargs, + ) video_processor = processor.video_processor except ValueError as e: From feccfd7c7575b121ad18b6bcec4d27a003b80a46 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Thu, 19 Sep 2024 21:08:38 -0400 Subject: [PATCH 17/31] Add tests for mapper processor_kwargs Signed-off-by: Alex-Brooks --- tests/multimodal/test_processor.py | 113 ++++++++++++++++++++++++++++- 1 file changed, 112 insertions(+), 1 deletion(-) diff --git a/tests/multimodal/test_processor.py b/tests/multimodal/test_processor.py index f94a4f5abce6..c86fa5d3c7e4 100644 --- a/tests/multimodal/test_processor.py +++ b/tests/multimodal/test_processor.py @@ -3,13 +3,14 @@ from unittest.mock import patch import pytest +import torch from vllm.config import ModelConfig from vllm.inputs import InputContext, LLMInputs from vllm.inputs.registry import InputRegistry +from vllm.model_executor.models.phi3v import Phi3VForCausalLM from vllm.multimodal import MultiModalRegistry from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData -from vllm.model_executor.models.phi3v import Phi3VForCausalLM # Used for fast tests where the model doesn't matter DUMMY_MODEL_ID = "facebook/opt-125m" @@ -78,6 +79,9 @@ def custom_dummy_data_factory(self, # lambda whose signature matches max token calcs + extra kwargs get_num_crops = lambda ctx, *, num_crops=DEFAULT_NUM_CROPS: num_crops +custom_mapper = lambda ctx, data, *, num_crops=DEFAULT_NUM_CROPS: { + "num_pixels": torch.zeros(size=(1, num_crops + 1, 3, 336, 336)) +} ### Test for default processor logic & processor_kwargs wrapping @@ -242,3 +246,110 @@ def test_max_tokens_with_sad_kwarg_overrides(processor_kwargs): model_config) assert max_multimodal_tokens == DEFAULT_NUM_CROPS + + +### Test overrides for the mapper +@pytest.mark.parametrize("num_crops", [DEFAULT_NUM_CROPS, NUM_CROPS_OVERRIDE]) +def test_default_mapper_with_processer_kwargs(image_assets, num_crops): + """Ensure that the mapper processor kwargs can fall back to HF models.""" + # NOTE - we don't validate bad inputs for the default mapper, because it's + # through the automodel interface in transformers, so we can't easily + # inspect what kwargs are or are not allowed. + model_config = ModelConfig( + MULTIMODAL_MODEL_ID, + MULTIMODAL_MODEL_ID, + tokenizer_mode="auto", + trust_remote_code=True, + dtype="float16", + seed=0, + processor_kwargs={"num_crops": num_crops}, + limit_mm_per_prompt={"image": 1}, + ) + + mm_registry = MultiModalRegistry() + mm_registry.init_mm_limits_per_prompt(model_config) + + image = image_assets[0].pil_image + mm_inputs = {"image": image} + + mapped_inputs = mm_registry.map_input(model_config, mm_inputs) + # Phi3v pixel vals should have shape: [batch, num_crops+1, 3, 336, 336] + assert mapped_inputs["pixel_values"].shape[1] == num_crops + 1 + + +@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE]) +def test_custom_mapper_kwarg_overrides(image_assets, num_crops): + """Ensure that custom mappers can consume processor_kwargs.""" + processor_kwargs = None if num_crops is None else {"num_crops": num_crops} + expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops + + model_config = ModelConfig( + MULTIMODAL_MODEL_ID, + MULTIMODAL_MODEL_ID, + tokenizer_mode="auto", + trust_remote_code=True, + dtype="float16", + seed=0, + processor_kwargs=processor_kwargs, + limit_mm_per_prompt={"image": 1}, + ) + + mm_registry = MultiModalRegistry() + mm_registry.init_mm_limits_per_prompt(model_config) + # Patch the image registry for phi3v with our lambda that is compatible + # with overrides, then ensure that calling the method correctly echos + # our num_crops value back from the processor_kwargs. + image = image_assets[0].pil_image + mm_inputs = {"image": image} + + with patch.object( + mm_registry._get_plugin("image"), + "_default_input_mapper", + {Phi3VForCausalLM: custom_mapper}, + ): + mapped_inputs = mm_registry.map_input(model_config, mm_inputs) + + assert mapped_inputs["pixel_values"].shape[1] == expected_seq_count + 1 + + +@pytest.mark.parametrize( + "processor_kwargs", + [ + { + "does_not_exist": 100 + }, # Not part of the signature + { + "ctx": "something bad" + } # Part of the signature, not keyword only + ]) +def test_custom_mapper_with_sad_kwarg_overrides(image_assets, + processor_kwargs): + """Ensure that custom mappers can filter out invalid processor_kwargs.""" + + model_config = ModelConfig( + MULTIMODAL_MODEL_ID, + MULTIMODAL_MODEL_ID, + tokenizer_mode="auto", + trust_remote_code=True, + dtype="float16", + seed=0, + processor_kwargs=processor_kwargs, + limit_mm_per_prompt={"image": 1}, + ) + + mm_registry = MultiModalRegistry() + mm_registry.init_mm_limits_per_prompt(model_config) + # Patch the image registry for phi3v with our lambda that is compatible + # with overrides, then ensure that calling the method correctly echos + # our num_crops value back from the processor_kwargs. + image = image_assets[0].pil_image + mm_inputs = {"image": image} + + with patch.object( + mm_registry._get_plugin("image"), + "_default_input_mapper", + {Phi3VForCausalLM: custom_mapper}, + ): + mapped_inputs = mm_registry.map_input(model_config, mm_inputs) + + assert mapped_inputs["pixel_values"].shape[1] == DEFAULT_NUM_CROPS + 1 From 3ada64de23250759389a7d49000ae87048b1efae Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Thu, 19 Sep 2024 21:21:23 -0400 Subject: [PATCH 18/31] Update mapper not on multimodal processor kwargs Signed-off-by: Alex-Brooks --- vllm/multimodal/registry.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index f1c56226e044..3940e1671b57 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -138,14 +138,15 @@ def create_input_mapper(self, model_config: ModelConfig): """ Create an input mapper (see :meth:`map_input`) for a specific model. """ - # TODO - there is a bit of weirdness here in the way mapper handles - # the args, because for the HF one, we pass processor_kwargs at init - # time and don't need them at func time, for the function's we are - # wrapping in processor like interfaces, we pass them at the time - # of invocation. + # NOTE - we currently make the assumption that if a model has multiple + # supported modalities, they take the same kwargs. For the default, + # this could be an issue in the future if it falls back to two HF + # resources and we can't inspect the signature easily since it's + # getting initialized through the autoclass. # - # Currently it works, but warns when the default processor is used, - # which is bad. + # If this is a problem in the future, we should revisit it, but since + # it potentially introduces a lot of complexity for a currently + # uncommon case, we do not for simplicity of both use & implementation return functools.partial(self.map_input, model_config) def register_max_multimodal_tokens( From 58dcc63ce4b11fd38b212e9fc65d893da6e8a336 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Thu, 19 Sep 2024 21:36:53 -0400 Subject: [PATCH 19/31] processor kwarg test cleanup Signed-off-by: Alex-Brooks --- ..._processor.py => test_processor_kwargs.py} | 152 +++++++----------- 1 file changed, 60 insertions(+), 92 deletions(-) rename tests/multimodal/{test_processor.py => test_processor_kwargs.py} (74%) diff --git a/tests/multimodal/test_processor.py b/tests/multimodal/test_processor_kwargs.py similarity index 74% rename from tests/multimodal/test_processor.py rename to tests/multimodal/test_processor_kwargs.py index c86fa5d3c7e4..1c84cd265e26 100644 --- a/tests/multimodal/test_processor.py +++ b/tests/multimodal/test_processor_kwargs.py @@ -24,17 +24,18 @@ NUM_CROPS_OVERRIDE = 16 -def get_model_config(processor_kwargs=None): +def get_model_config(model_name, trust_remote_code=False, processor_kwargs=None, limit_mm_per_prompt=None): """Creates a handle to a model config, which may have processor kwargs.""" # NOTE - values / architecture don't matter too much here since we patch # the return values for stuff like the input processor anyway. - return ModelConfig(DUMMY_MODEL_ID, - DUMMY_MODEL_ID, + return ModelConfig(model_name, + model_name, tokenizer_mode="auto", - trust_remote_code=False, + trust_remote_code=trust_remote_code, dtype="float16", seed=0, - processor_kwargs=processor_kwargs) + processor_kwargs=processor_kwargs, + limit_mm_per_prompt=limit_mm_per_prompt) # Mocks for all of the places that we use the processor_kwargs @@ -77,7 +78,7 @@ def custom_dummy_data_factory(self, yield -# lambda whose signature matches max token calcs + extra kwargs +# lambda whose signature matches max token calcs + extra kwargs & mapper respectively get_num_crops = lambda ctx, *, num_crops=DEFAULT_NUM_CROPS: num_crops custom_mapper = lambda ctx, data, *, num_crops=DEFAULT_NUM_CROPS: { "num_pixels": torch.zeros(size=(1, num_crops + 1, 3, 336, 336)) @@ -88,7 +89,7 @@ def custom_dummy_data_factory(self, def test_default_processor_is_a_noop(): """Ensure that by default, there is no processor override.""" dummy_registry = InputRegistry() - model_config = get_model_config() + model_config = get_model_config(DUMMY_MODEL_ID) processor = dummy_registry.create_input_processor(model_config) proc_inputs = LLMInputs(prompt_token_ids=[], prompt="") proc_outputs = processor(inputs=proc_inputs) @@ -97,14 +98,15 @@ def test_default_processor_is_a_noop(): @pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE]) def test_processor_default_kwargs(use_processor_mock, num_crops): - """Ensure that we can override processor kwargs.""" + """Ensure input processors can use processor kwargs.""" dummy_registry = InputRegistry() # If we have a value for num_crops, pass the override value and make # sure we get that value as a return-value from out mock processor, # otherwise fall back to the default value processor_kwargs = None if num_crops is None else {"num_crops": num_crops} expected_num_crops = DEFAULT_NUM_CROPS if num_crops is None else num_crops - model_config = get_model_config(processor_kwargs=processor_kwargs) + model_config = get_model_config(DUMMY_MODEL_ID, + processor_kwargs=processor_kwargs) processor = dummy_registry.create_input_processor(model_config) num_crops_val = processor(LLMInputs(prompt_token_ids=[], prompt="")) @@ -114,19 +116,18 @@ def test_processor_default_kwargs(use_processor_mock, num_crops): @pytest.mark.parametrize( "processor_kwargs", [ - { - "does_not_exist": 100 - }, # Not part of the signature - { - "ctx": "something bad" - } # Part of the signature, not keyword only + # Not part of the signature + {"does_not_exist": 100}, + # Part of the signature, not keyword only + {"ctx": "something bad"} ]) def test_processor_with_sad_kwarg_overrides(use_processor_mock, processor_kwargs): - """Ensure invalid processor_kwargs can't be used in the input processor.""" + """Ensure that input processors filter out invalid processor_kwargs.""" dummy_registry = InputRegistry() - model_config = get_model_config(processor_kwargs=processor_kwargs) + model_config = get_model_config(DUMMY_MODEL_ID, + processor_kwargs=processor_kwargs) processor = dummy_registry.create_input_processor(model_config) num_crops_val = processor(LLMInputs(prompt_token_ids=[], prompt="")) @@ -136,10 +137,12 @@ def test_processor_with_sad_kwarg_overrides(use_processor_mock, ### Test overrides for the dummy data @pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE]) def test_dummy_data_kwarg_overrides(use_dummy_data_mock, num_crops): + """Ensure dummy data factories can use processor kwargs.""" processor_kwargs = None if num_crops is None else {"num_crops": num_crops} expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops dummy_registry = InputRegistry() - model_config = get_model_config(processor_kwargs=processor_kwargs, ) + model_config = get_model_config(DUMMY_MODEL_ID, + processor_kwargs=processor_kwargs) mm_registry = MultiModalRegistry() mm_registry.init_mm_limits_per_prompt(model_config) @@ -154,18 +157,17 @@ def test_dummy_data_kwarg_overrides(use_dummy_data_mock, num_crops): @pytest.mark.parametrize( "processor_kwargs", [ - { - "does_not_exist": 100 - }, # Not part of the signature - { - "ctx": "something bad" - } # Part of the signature, not keyword only + # Not part of the signature + {"does_not_exist": 100}, + # Part of the signature, not keyword only + {"ctx": "something bad"} ]) def test_dummy_data_with_sad_kwarg_overrides(use_dummy_data_mock, processor_kwargs): - """Ensure that dummy_data kwargs that are unused do not fail.""" + """Ensure that dummy data factory filters out invalid processor_kwargs.""" dummy_registry = InputRegistry() - model_config = get_model_config(processor_kwargs=processor_kwargs, ) + model_config = get_model_config(DUMMY_MODEL_ID, + processor_kwargs=processor_kwargs) mm_registry = MultiModalRegistry() mm_registry.init_mm_limits_per_prompt(model_config) @@ -180,19 +182,14 @@ def test_dummy_data_with_sad_kwarg_overrides(use_dummy_data_mock, ### Test overrides for the max token count per multimodal instance @pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE]) def test_max_tokens_kwarg_overrides(num_crops): + """Ensure max token calcs can use processor kwargs.""" processor_kwargs = None if num_crops is None else {"num_crops": num_crops} expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops - model_config = ModelConfig( - MULTIMODAL_MODEL_ID, - MULTIMODAL_MODEL_ID, - tokenizer_mode="auto", - trust_remote_code=True, - dtype="float16", - seed=0, - processor_kwargs=processor_kwargs, - limit_mm_per_prompt={"image": 1}, - ) + model_config = get_model_config(MULTIMODAL_MODEL_ID, + trust_remote_code=True, + processor_kwargs=processor_kwargs, + limit_mm_per_prompt={"image": 1}) mm_registry = MultiModalRegistry() mm_registry.init_mm_limits_per_prompt(model_config) @@ -213,24 +210,17 @@ def test_max_tokens_kwarg_overrides(num_crops): @pytest.mark.parametrize( "processor_kwargs", [ - { - "does_not_exist": 100 - }, # Not part of the signature - { - "ctx": "something bad" - } # Part of the signature, not keyword only + # Not part of the signature + {"does_not_exist": 100}, + # Part of the signature, not keyword only + {"ctx": "something bad"} ]) def test_max_tokens_with_sad_kwarg_overrides(processor_kwargs): - model_config = ModelConfig( - MULTIMODAL_MODEL_ID, - MULTIMODAL_MODEL_ID, - tokenizer_mode="auto", - trust_remote_code=True, - dtype="float16", - seed=0, - processor_kwargs=processor_kwargs, - limit_mm_per_prompt={"image": 1}, - ) + """Ensure that max token calcs filters out invalid processor_kwargs.""" + model_config = get_model_config(MULTIMODAL_MODEL_ID, + trust_remote_code=True, + processor_kwargs=processor_kwargs, + limit_mm_per_prompt={"image": 1}) mm_registry = MultiModalRegistry() mm_registry.init_mm_limits_per_prompt(model_config) @@ -255,16 +245,10 @@ def test_default_mapper_with_processer_kwargs(image_assets, num_crops): # NOTE - we don't validate bad inputs for the default mapper, because it's # through the automodel interface in transformers, so we can't easily # inspect what kwargs are or are not allowed. - model_config = ModelConfig( - MULTIMODAL_MODEL_ID, - MULTIMODAL_MODEL_ID, - tokenizer_mode="auto", - trust_remote_code=True, - dtype="float16", - seed=0, - processor_kwargs={"num_crops": num_crops}, - limit_mm_per_prompt={"image": 1}, - ) + model_config = get_model_config(MULTIMODAL_MODEL_ID, + trust_remote_code=True, + processor_kwargs={"num_crops": num_crops}, + limit_mm_per_prompt={"image": 1}) mm_registry = MultiModalRegistry() mm_registry.init_mm_limits_per_prompt(model_config) @@ -279,20 +263,13 @@ def test_default_mapper_with_processer_kwargs(image_assets, num_crops): @pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE]) def test_custom_mapper_kwarg_overrides(image_assets, num_crops): - """Ensure that custom mappers can consume processor_kwargs.""" + """Ensure custom mappers can use processor kwargs.""" processor_kwargs = None if num_crops is None else {"num_crops": num_crops} expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops - - model_config = ModelConfig( - MULTIMODAL_MODEL_ID, - MULTIMODAL_MODEL_ID, - tokenizer_mode="auto", - trust_remote_code=True, - dtype="float16", - seed=0, - processor_kwargs=processor_kwargs, - limit_mm_per_prompt={"image": 1}, - ) + model_config = get_model_config(MULTIMODAL_MODEL_ID, + trust_remote_code=True, + processor_kwargs=processor_kwargs, + limit_mm_per_prompt={"image": 1}) mm_registry = MultiModalRegistry() mm_registry.init_mm_limits_per_prompt(model_config) @@ -315,27 +292,18 @@ def test_custom_mapper_kwarg_overrides(image_assets, num_crops): @pytest.mark.parametrize( "processor_kwargs", [ - { - "does_not_exist": 100 - }, # Not part of the signature - { - "ctx": "something bad" - } # Part of the signature, not keyword only + # Not part of the signature + {"does_not_exist": 100}, + # Part of the signature, not keyword only + {"ctx": "something bad"} ]) def test_custom_mapper_with_sad_kwarg_overrides(image_assets, processor_kwargs): - """Ensure that custom mappers can filter out invalid processor_kwargs.""" - - model_config = ModelConfig( - MULTIMODAL_MODEL_ID, - MULTIMODAL_MODEL_ID, - tokenizer_mode="auto", - trust_remote_code=True, - dtype="float16", - seed=0, - processor_kwargs=processor_kwargs, - limit_mm_per_prompt={"image": 1}, - ) + """Ensure that custom mappers filters out invalid processor_kwargs.""" + model_config = get_model_config(MULTIMODAL_MODEL_ID, + trust_remote_code=True, + processor_kwargs=processor_kwargs, + limit_mm_per_prompt={"image": 1}) mm_registry = MultiModalRegistry() mm_registry.init_mm_limits_per_prompt(model_config) From 1cee21558d1eff37bfa81cda3825de262a24d07a Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Thu, 19 Sep 2024 16:59:16 -0400 Subject: [PATCH 20/31] Move context builder to test utils Signed-off-by: Alex-Brooks --- .../decoder_only/vision_language/test_qwen.py | 29 +---------------- tests/models/utils.py | 32 +++++++++++++++++++ 2 files changed, 33 insertions(+), 28 deletions(-) diff --git a/tests/models/decoder_only/vision_language/test_qwen.py b/tests/models/decoder_only/vision_language/test_qwen.py index e4f79092b760..638fb68b8f87 100644 --- a/tests/models/decoder_only/vision_language/test_qwen.py +++ b/tests/models/decoder_only/vision_language/test_qwen.py @@ -5,14 +5,13 @@ import torch from PIL.Image import Image -from vllm.config import ModelConfig from vllm.inputs import InputContext, LLMInputs from vllm.multimodal.base import MultiModalInputs from vllm.multimodal.utils import cached_get_tokenizer, rescale_image_size from ....conftest import (IMAGE_ASSETS, HfRunner, ImageAsset, PromptImageInput, VllmRunner, _ImageAssets) -from ...utils import check_logprobs_close +from ...utils import build_model_context, check_logprobs_close text_only_models = [ "Qwen/Qwen-7B-Chat" # Has no visual component @@ -42,32 +41,6 @@ IMG_SIZE = 448 -def build_model_context(model_name: str, - tokenizer_name: Optional[str] = None, - trust_remote_code: bool = False): - """Creates an InputContext for a given model. - - Args: - model_name: Name of the model being considered. - tokenizer_name: Name of the tokenizer being considered. - trust_remote_code: Whether or not to allow loading remote code. - - Returns: - InputContext for the model being considered. - """ - if tokenizer_name is None: - tokenizer_name = model_name - model_config = ModelConfig( - model_name, - tokenizer_name, - tokenizer_mode="auto", - trust_remote_code=trust_remote_code, - dtype="float32", - seed=0, - ) - return InputContext(model_config) - - @pytest.fixture() def input_mapper_for_qwen(): # Lazy import to avoid initializing CUDA during test collection diff --git a/tests/models/utils.py b/tests/models/utils.py index 8e31a1d6eefe..0c3e876dd6cd 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -1,6 +1,8 @@ import warnings from typing import Dict, List, Optional, Sequence, Tuple, Union +from vllm.config import ModelConfig +from vllm.inputs import InputContext from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs TokensText = Tuple[List[int], str] @@ -240,3 +242,33 @@ def check_logprobs_close( warnings.simplefilter("always") warnings.warn(fail_msg, stacklevel=2) + + +def build_model_context(model_name: str, + tokenizer_name: Optional[str] = None, + trust_remote_code: bool = False, + processor_kwargs: Optional[Dict] = None): + """Creates an InputContext for a given model. + + Args: + model_name: Name of the model being considered. + tokenizer_name: Name of the tokenizer being considered. + trust_remote_code: Whether or not to allow loading remote code. + processor_kwargs: optional processor kwargs for to be leveraged + in the input processor, mapper, dummy data creation, etc. + + Returns: + InputContext for the model being considered. + """ + if tokenizer_name is None: + tokenizer_name = model_name + model_config = ModelConfig( + model_name, + tokenizer_name, + tokenizer_mode="auto", + trust_remote_code=trust_remote_code, + dtype="float32", + seed=0, + processor_kwargs=processor_kwargs, + ) + return InputContext(model_config) From d5f9efa94a80e2a4751a69f109027df0334789c7 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Thu, 19 Sep 2024 21:55:42 -0400 Subject: [PATCH 21/31] Use common context builder in processor kwarg tests Signed-off-by: Alex-Brooks --- tests/models/utils.py | 5 +- tests/multimodal/test_processor_kwargs.py | 144 +++++++++++----------- 2 files changed, 77 insertions(+), 72 deletions(-) diff --git a/tests/models/utils.py b/tests/models/utils.py index 0c3e876dd6cd..77a7e054bf68 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -247,7 +247,8 @@ def check_logprobs_close( def build_model_context(model_name: str, tokenizer_name: Optional[str] = None, trust_remote_code: bool = False, - processor_kwargs: Optional[Dict] = None): + processor_kwargs: Optional[Dict] = None, + limit_mm_per_prompt: Optional[Dict] = None): """Creates an InputContext for a given model. Args: @@ -256,6 +257,7 @@ def build_model_context(model_name: str, trust_remote_code: Whether or not to allow loading remote code. processor_kwargs: optional processor kwargs for to be leveraged in the input processor, mapper, dummy data creation, etc. + limit_mm_per_prompt: Multimodal limits. Returns: InputContext for the model being considered. @@ -270,5 +272,6 @@ def build_model_context(model_name: str, dtype="float32", seed=0, processor_kwargs=processor_kwargs, + limit_mm_per_prompt=limit_mm_per_prompt, ) return InputContext(model_config) diff --git a/tests/multimodal/test_processor_kwargs.py b/tests/multimodal/test_processor_kwargs.py index 1c84cd265e26..35df3fe1492e 100644 --- a/tests/multimodal/test_processor_kwargs.py +++ b/tests/multimodal/test_processor_kwargs.py @@ -5,13 +5,14 @@ import pytest import torch -from vllm.config import ModelConfig from vllm.inputs import InputContext, LLMInputs from vllm.inputs.registry import InputRegistry from vllm.model_executor.models.phi3v import Phi3VForCausalLM from vllm.multimodal import MultiModalRegistry from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData +from ..models.utils import build_model_context + # Used for fast tests where the model doesn't matter DUMMY_MODEL_ID = "facebook/opt-125m" # Used for tests that need a multimodal model @@ -24,20 +25,6 @@ NUM_CROPS_OVERRIDE = 16 -def get_model_config(model_name, trust_remote_code=False, processor_kwargs=None, limit_mm_per_prompt=None): - """Creates a handle to a model config, which may have processor kwargs.""" - # NOTE - values / architecture don't matter too much here since we patch - # the return values for stuff like the input processor anyway. - return ModelConfig(model_name, - model_name, - tokenizer_mode="auto", - trust_remote_code=trust_remote_code, - dtype="float16", - seed=0, - processor_kwargs=processor_kwargs, - limit_mm_per_prompt=limit_mm_per_prompt) - - # Mocks for all of the places that we use the processor_kwargs # to override values in different callables @pytest.fixture @@ -78,7 +65,7 @@ def custom_dummy_data_factory(self, yield -# lambda whose signature matches max token calcs + extra kwargs & mapper respectively +# lambda whose signature matches max token calcs extra & mapper + extra kwargs get_num_crops = lambda ctx, *, num_crops=DEFAULT_NUM_CROPS: num_crops custom_mapper = lambda ctx, data, *, num_crops=DEFAULT_NUM_CROPS: { "num_pixels": torch.zeros(size=(1, num_crops + 1, 3, 336, 336)) @@ -89,8 +76,8 @@ def custom_dummy_data_factory(self, def test_default_processor_is_a_noop(): """Ensure that by default, there is no processor override.""" dummy_registry = InputRegistry() - model_config = get_model_config(DUMMY_MODEL_ID) - processor = dummy_registry.create_input_processor(model_config) + ctx = build_model_context(DUMMY_MODEL_ID) + processor = dummy_registry.create_input_processor(ctx.model_config) proc_inputs = LLMInputs(prompt_token_ids=[], prompt="") proc_outputs = processor(inputs=proc_inputs) assert proc_inputs is proc_outputs @@ -105,9 +92,9 @@ def test_processor_default_kwargs(use_processor_mock, num_crops): # otherwise fall back to the default value processor_kwargs = None if num_crops is None else {"num_crops": num_crops} expected_num_crops = DEFAULT_NUM_CROPS if num_crops is None else num_crops - model_config = get_model_config(DUMMY_MODEL_ID, - processor_kwargs=processor_kwargs) - processor = dummy_registry.create_input_processor(model_config) + ctx = build_model_context(DUMMY_MODEL_ID, + processor_kwargs=processor_kwargs) + processor = dummy_registry.create_input_processor(ctx.model_config) num_crops_val = processor(LLMInputs(prompt_token_ids=[], prompt="")) assert num_crops_val == expected_num_crops @@ -117,19 +104,22 @@ def test_processor_default_kwargs(use_processor_mock, num_crops): "processor_kwargs", [ # Not part of the signature - {"does_not_exist": 100}, + { + "does_not_exist": 100 + }, # Part of the signature, not keyword only - {"ctx": "something bad"} + { + "ctx": "something bad" + } ]) def test_processor_with_sad_kwarg_overrides(use_processor_mock, processor_kwargs): """Ensure that input processors filter out invalid processor_kwargs.""" dummy_registry = InputRegistry() + ctx = build_model_context(DUMMY_MODEL_ID, + processor_kwargs=processor_kwargs) - model_config = get_model_config(DUMMY_MODEL_ID, - processor_kwargs=processor_kwargs) - - processor = dummy_registry.create_input_processor(model_config) + processor = dummy_registry.create_input_processor(ctx.model_config) num_crops_val = processor(LLMInputs(prompt_token_ids=[], prompt="")) assert num_crops_val == DEFAULT_NUM_CROPS @@ -141,16 +131,16 @@ def test_dummy_data_kwarg_overrides(use_dummy_data_mock, num_crops): processor_kwargs = None if num_crops is None else {"num_crops": num_crops} expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops dummy_registry = InputRegistry() - model_config = get_model_config(DUMMY_MODEL_ID, - processor_kwargs=processor_kwargs) + ctx = build_model_context(DUMMY_MODEL_ID, + processor_kwargs=processor_kwargs) mm_registry = MultiModalRegistry() - mm_registry.init_mm_limits_per_prompt(model_config) + mm_registry.init_mm_limits_per_prompt(ctx.model_config) # NOTE: seq_len is thrown away here since this will leverage the # default dummy data factory that we have patched in, whose seq # len is solely dependent on the value of the processor_kwargs. seq_data, _ = dummy_registry.dummy_data_for_profiling( - model_config, seq_len=-1, mm_registry=mm_registry) + ctx.model_config, seq_len=-1, mm_registry=mm_registry) assert len(seq_data.prompt_token_ids) == expected_seq_count @@ -158,24 +148,28 @@ def test_dummy_data_kwarg_overrides(use_dummy_data_mock, num_crops): "processor_kwargs", [ # Not part of the signature - {"does_not_exist": 100}, + { + "does_not_exist": 100 + }, # Part of the signature, not keyword only - {"ctx": "something bad"} + { + "ctx": "something bad" + } ]) def test_dummy_data_with_sad_kwarg_overrides(use_dummy_data_mock, processor_kwargs): """Ensure that dummy data factory filters out invalid processor_kwargs.""" dummy_registry = InputRegistry() - model_config = get_model_config(DUMMY_MODEL_ID, - processor_kwargs=processor_kwargs) + ctx = build_model_context(DUMMY_MODEL_ID, + processor_kwargs=processor_kwargs) mm_registry = MultiModalRegistry() - mm_registry.init_mm_limits_per_prompt(model_config) + mm_registry.init_mm_limits_per_prompt(ctx.model_config) # NOTE: seq_len is thrown away here since this will leverage the # default dummy data factory that we have patched in, whose seq # len is solely dependent on the value of the processor_kwargs. seq_data, _ = dummy_registry.dummy_data_for_profiling( - model_config, seq_len=-1, mm_registry=mm_registry) + ctx.model_config, seq_len=-1, mm_registry=mm_registry) assert len(seq_data.prompt_token_ids) == DEFAULT_NUM_CROPS @@ -186,13 +180,13 @@ def test_max_tokens_kwarg_overrides(num_crops): processor_kwargs = None if num_crops is None else {"num_crops": num_crops} expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops - model_config = get_model_config(MULTIMODAL_MODEL_ID, - trust_remote_code=True, - processor_kwargs=processor_kwargs, - limit_mm_per_prompt={"image": 1}) + ctx = build_model_context(MULTIMODAL_MODEL_ID, + trust_remote_code=True, + processor_kwargs=processor_kwargs, + limit_mm_per_prompt={"image": 1}) mm_registry = MultiModalRegistry() - mm_registry.init_mm_limits_per_prompt(model_config) + mm_registry.init_mm_limits_per_prompt(ctx.model_config) # Patch the image registry for phi3v with our lambda that is compatible # with overrides, then ensure that calling the method correctly echos # our num_crops value back from the processor_kwargs. @@ -202,7 +196,7 @@ def test_max_tokens_kwarg_overrides(num_crops): {Phi3VForCausalLM: get_num_crops}, ): max_multimodal_tokens = mm_registry.get_max_multimodal_tokens( - model_config) + ctx.model_config) assert expected_seq_count == max_multimodal_tokens @@ -211,19 +205,23 @@ def test_max_tokens_kwarg_overrides(num_crops): "processor_kwargs", [ # Not part of the signature - {"does_not_exist": 100}, + { + "does_not_exist": 100 + }, # Part of the signature, not keyword only - {"ctx": "something bad"} + { + "ctx": "something bad" + } ]) def test_max_tokens_with_sad_kwarg_overrides(processor_kwargs): """Ensure that max token calcs filters out invalid processor_kwargs.""" - model_config = get_model_config(MULTIMODAL_MODEL_ID, - trust_remote_code=True, - processor_kwargs=processor_kwargs, - limit_mm_per_prompt={"image": 1}) + ctx = build_model_context(MULTIMODAL_MODEL_ID, + trust_remote_code=True, + processor_kwargs=processor_kwargs, + limit_mm_per_prompt={"image": 1}) mm_registry = MultiModalRegistry() - mm_registry.init_mm_limits_per_prompt(model_config) + mm_registry.init_mm_limits_per_prompt(ctx.model_config) # Similar before, but since these kwargs get filtered, # we always get our default value back. @@ -233,7 +231,7 @@ def test_max_tokens_with_sad_kwarg_overrides(processor_kwargs): {Phi3VForCausalLM: get_num_crops}, ): max_multimodal_tokens = mm_registry.get_max_multimodal_tokens( - model_config) + ctx.model_config) assert max_multimodal_tokens == DEFAULT_NUM_CROPS @@ -245,18 +243,18 @@ def test_default_mapper_with_processer_kwargs(image_assets, num_crops): # NOTE - we don't validate bad inputs for the default mapper, because it's # through the automodel interface in transformers, so we can't easily # inspect what kwargs are or are not allowed. - model_config = get_model_config(MULTIMODAL_MODEL_ID, - trust_remote_code=True, - processor_kwargs={"num_crops": num_crops}, - limit_mm_per_prompt={"image": 1}) + ctx = build_model_context(MULTIMODAL_MODEL_ID, + trust_remote_code=True, + processor_kwargs={"num_crops": num_crops}, + limit_mm_per_prompt={"image": 1}) mm_registry = MultiModalRegistry() - mm_registry.init_mm_limits_per_prompt(model_config) + mm_registry.init_mm_limits_per_prompt(ctx.model_config) image = image_assets[0].pil_image mm_inputs = {"image": image} - mapped_inputs = mm_registry.map_input(model_config, mm_inputs) + mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs) # Phi3v pixel vals should have shape: [batch, num_crops+1, 3, 336, 336] assert mapped_inputs["pixel_values"].shape[1] == num_crops + 1 @@ -266,13 +264,13 @@ def test_custom_mapper_kwarg_overrides(image_assets, num_crops): """Ensure custom mappers can use processor kwargs.""" processor_kwargs = None if num_crops is None else {"num_crops": num_crops} expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops - model_config = get_model_config(MULTIMODAL_MODEL_ID, - trust_remote_code=True, - processor_kwargs=processor_kwargs, - limit_mm_per_prompt={"image": 1}) + ctx = build_model_context(MULTIMODAL_MODEL_ID, + trust_remote_code=True, + processor_kwargs=processor_kwargs, + limit_mm_per_prompt={"image": 1}) mm_registry = MultiModalRegistry() - mm_registry.init_mm_limits_per_prompt(model_config) + mm_registry.init_mm_limits_per_prompt(ctx.model_config) # Patch the image registry for phi3v with our lambda that is compatible # with overrides, then ensure that calling the method correctly echos # our num_crops value back from the processor_kwargs. @@ -284,7 +282,7 @@ def test_custom_mapper_kwarg_overrides(image_assets, num_crops): "_default_input_mapper", {Phi3VForCausalLM: custom_mapper}, ): - mapped_inputs = mm_registry.map_input(model_config, mm_inputs) + mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs) assert mapped_inputs["pixel_values"].shape[1] == expected_seq_count + 1 @@ -293,20 +291,24 @@ def test_custom_mapper_kwarg_overrides(image_assets, num_crops): "processor_kwargs", [ # Not part of the signature - {"does_not_exist": 100}, + { + "does_not_exist": 100 + }, # Part of the signature, not keyword only - {"ctx": "something bad"} + { + "ctx": "something bad" + } ]) def test_custom_mapper_with_sad_kwarg_overrides(image_assets, processor_kwargs): """Ensure that custom mappers filters out invalid processor_kwargs.""" - model_config = get_model_config(MULTIMODAL_MODEL_ID, - trust_remote_code=True, - processor_kwargs=processor_kwargs, - limit_mm_per_prompt={"image": 1}) + ctx = build_model_context(MULTIMODAL_MODEL_ID, + trust_remote_code=True, + processor_kwargs=processor_kwargs, + limit_mm_per_prompt={"image": 1}) mm_registry = MultiModalRegistry() - mm_registry.init_mm_limits_per_prompt(model_config) + mm_registry.init_mm_limits_per_prompt(ctx.model_config) # Patch the image registry for phi3v with our lambda that is compatible # with overrides, then ensure that calling the method correctly echos # our num_crops value back from the processor_kwargs. @@ -318,6 +320,6 @@ def test_custom_mapper_with_sad_kwarg_overrides(image_assets, "_default_input_mapper", {Phi3VForCausalLM: custom_mapper}, ): - mapped_inputs = mm_registry.map_input(model_config, mm_inputs) + mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs) assert mapped_inputs["pixel_values"].shape[1] == DEFAULT_NUM_CROPS + 1 From b5d434b5c9c906ded90f56b10372c434e94751dc Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Sun, 22 Sep 2024 00:05:26 -0600 Subject: [PATCH 22/31] Update vllm/entrypoints/llm.py Co-authored-by: Cyrus Leung --- vllm/entrypoints/llm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 6304851233ce..d27ea214ff37 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -134,7 +134,7 @@ def __init__( max_seq_len_to_capture: int = 8192, disable_custom_all_reduce: bool = False, disable_async_output_proc: bool = False, - processor_kwargs=None, + processor_kwargs: Optional[Dict[str, Any]] = None, **kwargs, ) -> None: ''' From a0963014c73941bf2fa75f059dd96f2c68133731 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Sun, 22 Sep 2024 00:06:41 -0600 Subject: [PATCH 23/31] Update vllm/inputs/registry.py Co-authored-by: Cyrus Leung --- vllm/inputs/registry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 08f516e17aef..e9d528ad9067 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -217,7 +217,7 @@ def _default_input_processor(self, ctx: InputContext, """The default input processor is a no-op.""" return inputs - def register_input_processor(self, processor: InputProcessor) -> Callable: + def register_input_processor(self, processor: InputProcessor): """ Register an input processor to a model class. From 79962e02e55d70c2fa4672c863454e898a1a71a3 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Sun, 22 Sep 2024 00:06:54 -0600 Subject: [PATCH 24/31] Update vllm/inputs/registry.py Co-authored-by: Cyrus Leung --- vllm/inputs/registry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index e9d528ad9067..c0043f35711b 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -256,7 +256,7 @@ def _process_input(self, inputs: LLMInputs, model_config: "ModelConfig", return processor(InputContext(model_config), inputs, **processor_kwargs) - def create_input_processor(self, model_config: "ModelConfig") -> Callable: + def create_input_processor(self, model_config: "ModelConfig"): """ Create an input processor (see :meth:`_process_input`) for a specific model. From 2cb1f72c0b8c53d6c7278959113a0a6a8ead2b05 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Sun, 22 Sep 2024 00:07:10 -0600 Subject: [PATCH 25/31] Update vllm/inputs/registry.py Co-authored-by: Cyrus Leung --- vllm/inputs/registry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index c0043f35711b..410ae3021f4e 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -274,7 +274,7 @@ def create_input_processor(self, model_config: "ModelConfig"): **processor_kwargs) def _get_model_input_processor(self, - model_config: "ModelConfig") -> Callable: + model_config: "ModelConfig"): """ Grabs the input processor for the provided model. From 37eb5324855268587217011ebbd6e2410c57d721 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Sun, 22 Sep 2024 00:07:25 -0600 Subject: [PATCH 26/31] Update vllm/inputs/registry.py Co-authored-by: Cyrus Leung --- vllm/inputs/registry.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 410ae3021f4e..a5b5cddc6c6c 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -81,7 +81,8 @@ def __call__( Note: :data:`InputProcessor` is not applied to the dummy data. - The processor_kwargs are overrides provided at initialization + + The :code:`processor_kwargs` are overrides provided at initialization time to values in the config whose values may affect the number of tokens per instance. """ From a4c7c3dea4753684d08a42bc8a97476518a95aa0 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Sun, 22 Sep 2024 00:08:31 -0600 Subject: [PATCH 27/31] Update vllm/inputs/registry.py Co-authored-by: Cyrus Leung --- vllm/inputs/registry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index a5b5cddc6c6c..caaaa0cfc713 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -243,7 +243,7 @@ def wrapper(model_cls: N) -> N: return wrapper def _process_input(self, inputs: LLMInputs, model_config: "ModelConfig", - processor: Callable, **processor_kwargs) -> LLMInputs: + processor: InputProcessor, **processor_kwargs: Any) -> LLMInputs: """ Apply an input processor to an instance of model inputs. This will usually not be invoked be directly, and instead will be wrapped in From 36dd2cba7cb5150ff53cdb1560e3c80b01eb3712 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Sun, 22 Sep 2024 02:12:25 -0400 Subject: [PATCH 28/31] Fix formatting Signed-off-by: Alex-Brooks --- vllm/inputs/registry.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index caaaa0cfc713..2a4a2250aba7 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -82,9 +82,9 @@ def __call__( Note: :data:`InputProcessor` is not applied to the dummy data. - The :code:`processor_kwargs` are overrides provided at initialization - time to values in the config whose values may affect the number - of tokens per instance. + The :code:`processor_kwargs` are overrides provided at + initialization time to values in the config whose values + may affect the number of tokens per instance. """ ... @@ -243,7 +243,8 @@ def wrapper(model_cls: N) -> N: return wrapper def _process_input(self, inputs: LLMInputs, model_config: "ModelConfig", - processor: InputProcessor, **processor_kwargs: Any) -> LLMInputs: + processor: InputProcessor, + **processor_kwargs: Any) -> LLMInputs: """ Apply an input processor to an instance of model inputs. This will usually not be invoked be directly, and instead will be wrapped in @@ -274,8 +275,7 @@ def create_input_processor(self, model_config: "ModelConfig"): processor=processor, **processor_kwargs) - def _get_model_input_processor(self, - model_config: "ModelConfig"): + def _get_model_input_processor(self, model_config: "ModelConfig"): """ Grabs the input processor for the provided model. From f95c86f7798b2fa6925a7bc993fb5a39f662f5ee Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Sun, 22 Sep 2024 04:29:21 -0400 Subject: [PATCH 29/31] Rename processor kwargs to mm processor kwargs Signed-off-by: Alex-Brooks --- tests/engine/test_arg_utils.py | 6 +- tests/models/utils.py | 6 +- tests/multimodal/test_processor_kwargs.py | 74 +++++++++++++---------- vllm/config.py | 6 +- vllm/engine/arg_utils.py | 10 +-- vllm/engine/llm_engine.py | 4 +- vllm/entrypoints/llm.py | 4 +- vllm/inputs/registry.py | 20 +++--- vllm/multimodal/base.py | 14 ++--- vllm/multimodal/image.py | 6 +- vllm/multimodal/video.py | 6 +- vllm/utils.py | 2 +- 12 files changed, 83 insertions(+), 75 deletions(-) diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py index fabf37aa2a68..360ac1bfbad9 100644 --- a/tests/engine/test_arg_utils.py +++ b/tests/engine/test_arg_utils.py @@ -54,10 +54,10 @@ def test_bad_nullable_kvs(arg): } }), ]) -def test_processor_kwargs_prompt_parser(arg, expected): +def test_mm_processor_kwargs_prompt_parser(arg, expected): parser = EngineArgs.add_cli_args(FlexibleArgumentParser()) if arg is None: args = parser.parse_args([]) else: - args = parser.parse_args(["--processor-kwargs", arg]) - assert args.processor_kwargs == expected + args = parser.parse_args(["--mm-processor-kwargs", arg]) + assert args.mm_processor_kwargs == expected diff --git a/tests/models/utils.py b/tests/models/utils.py index 77a7e054bf68..eb6254f18182 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -247,7 +247,7 @@ def check_logprobs_close( def build_model_context(model_name: str, tokenizer_name: Optional[str] = None, trust_remote_code: bool = False, - processor_kwargs: Optional[Dict] = None, + mm_processor_kwargs: Optional[Dict] = None, limit_mm_per_prompt: Optional[Dict] = None): """Creates an InputContext for a given model. @@ -255,7 +255,7 @@ def build_model_context(model_name: str, model_name: Name of the model being considered. tokenizer_name: Name of the tokenizer being considered. trust_remote_code: Whether or not to allow loading remote code. - processor_kwargs: optional processor kwargs for to be leveraged + mm_processor_kwargs: optional processor kwargs for to be leveraged in the input processor, mapper, dummy data creation, etc. limit_mm_per_prompt: Multimodal limits. @@ -271,7 +271,7 @@ def build_model_context(model_name: str, trust_remote_code=trust_remote_code, dtype="float32", seed=0, - processor_kwargs=processor_kwargs, + mm_processor_kwargs=mm_processor_kwargs, limit_mm_per_prompt=limit_mm_per_prompt, ) return InputContext(model_config) diff --git a/tests/multimodal/test_processor_kwargs.py b/tests/multimodal/test_processor_kwargs.py index 35df3fe1492e..d7fa32a7f214 100644 --- a/tests/multimodal/test_processor_kwargs.py +++ b/tests/multimodal/test_processor_kwargs.py @@ -18,14 +18,14 @@ # Used for tests that need a multimodal model MULTIMODAL_MODEL_ID = "microsoft/Phi-3.5-vision-instruct" -# For processor_kwargs - we test overrides by defining mocks for each place +# For mm_processor_kwargs - we test overrides by defining mocks for each place # it is used, and ensuring that we can pass processor kwargs an override value # to receive the intended result for things like sequence length etc. DEFAULT_NUM_CROPS = 4 NUM_CROPS_OVERRIDE = 16 -# Mocks for all of the places that we use the processor_kwargs +# Mocks for all of the places that we use the mm_processor_kwargs # to override values in different callables @pytest.fixture def use_processor_mock(): @@ -72,7 +72,7 @@ def custom_dummy_data_factory(self, } -### Test for default processor logic & processor_kwargs wrapping +### Test for default processor logic & mm_processor_kwargs wrapping def test_default_processor_is_a_noop(): """Ensure that by default, there is no processor override.""" dummy_registry = InputRegistry() @@ -90,10 +90,12 @@ def test_processor_default_kwargs(use_processor_mock, num_crops): # If we have a value for num_crops, pass the override value and make # sure we get that value as a return-value from out mock processor, # otherwise fall back to the default value - processor_kwargs = None if num_crops is None else {"num_crops": num_crops} + mm_processor_kwargs = None if num_crops is None else { + "num_crops": num_crops + } expected_num_crops = DEFAULT_NUM_CROPS if num_crops is None else num_crops ctx = build_model_context(DUMMY_MODEL_ID, - processor_kwargs=processor_kwargs) + mm_processor_kwargs=mm_processor_kwargs) processor = dummy_registry.create_input_processor(ctx.model_config) num_crops_val = processor(LLMInputs(prompt_token_ids=[], prompt="")) @@ -101,7 +103,7 @@ def test_processor_default_kwargs(use_processor_mock, num_crops): @pytest.mark.parametrize( - "processor_kwargs", + "mm_processor_kwargs", [ # Not part of the signature { @@ -113,11 +115,11 @@ def test_processor_default_kwargs(use_processor_mock, num_crops): } ]) def test_processor_with_sad_kwarg_overrides(use_processor_mock, - processor_kwargs): - """Ensure that input processors filter out invalid processor_kwargs.""" + mm_processor_kwargs): + """Ensure that input processors filter out invalid mm_processor_kwargs""" dummy_registry = InputRegistry() ctx = build_model_context(DUMMY_MODEL_ID, - processor_kwargs=processor_kwargs) + mm_processor_kwargs=mm_processor_kwargs) processor = dummy_registry.create_input_processor(ctx.model_config) num_crops_val = processor(LLMInputs(prompt_token_ids=[], prompt="")) @@ -128,24 +130,26 @@ def test_processor_with_sad_kwarg_overrides(use_processor_mock, @pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE]) def test_dummy_data_kwarg_overrides(use_dummy_data_mock, num_crops): """Ensure dummy data factories can use processor kwargs.""" - processor_kwargs = None if num_crops is None else {"num_crops": num_crops} + mm_processor_kwargs = None if num_crops is None else { + "num_crops": num_crops + } expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops dummy_registry = InputRegistry() ctx = build_model_context(DUMMY_MODEL_ID, - processor_kwargs=processor_kwargs) + mm_processor_kwargs=mm_processor_kwargs) mm_registry = MultiModalRegistry() mm_registry.init_mm_limits_per_prompt(ctx.model_config) # NOTE: seq_len is thrown away here since this will leverage the # default dummy data factory that we have patched in, whose seq - # len is solely dependent on the value of the processor_kwargs. + # len is solely dependent on the value of the mm_processor_kwargs. seq_data, _ = dummy_registry.dummy_data_for_profiling( ctx.model_config, seq_len=-1, mm_registry=mm_registry) assert len(seq_data.prompt_token_ids) == expected_seq_count @pytest.mark.parametrize( - "processor_kwargs", + "mm_processor_kwargs", [ # Not part of the signature { @@ -157,17 +161,17 @@ def test_dummy_data_kwarg_overrides(use_dummy_data_mock, num_crops): } ]) def test_dummy_data_with_sad_kwarg_overrides(use_dummy_data_mock, - processor_kwargs): - """Ensure that dummy data factory filters out invalid processor_kwargs.""" + mm_processor_kwargs): + """Ensure the dummy data factory filters out invalid mm_processor_kwargs""" dummy_registry = InputRegistry() ctx = build_model_context(DUMMY_MODEL_ID, - processor_kwargs=processor_kwargs) + mm_processor_kwargs=mm_processor_kwargs) mm_registry = MultiModalRegistry() mm_registry.init_mm_limits_per_prompt(ctx.model_config) # NOTE: seq_len is thrown away here since this will leverage the # default dummy data factory that we have patched in, whose seq - # len is solely dependent on the value of the processor_kwargs. + # len is solely dependent on the value of the mm_processor_kwargs. seq_data, _ = dummy_registry.dummy_data_for_profiling( ctx.model_config, seq_len=-1, mm_registry=mm_registry) assert len(seq_data.prompt_token_ids) == DEFAULT_NUM_CROPS @@ -177,19 +181,21 @@ def test_dummy_data_with_sad_kwarg_overrides(use_dummy_data_mock, @pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE]) def test_max_tokens_kwarg_overrides(num_crops): """Ensure max token calcs can use processor kwargs.""" - processor_kwargs = None if num_crops is None else {"num_crops": num_crops} + mm_processor_kwargs = None if num_crops is None else { + "num_crops": num_crops + } expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops ctx = build_model_context(MULTIMODAL_MODEL_ID, trust_remote_code=True, - processor_kwargs=processor_kwargs, + mm_processor_kwargs=mm_processor_kwargs, limit_mm_per_prompt={"image": 1}) mm_registry = MultiModalRegistry() mm_registry.init_mm_limits_per_prompt(ctx.model_config) # Patch the image registry for phi3v with our lambda that is compatible # with overrides, then ensure that calling the method correctly echos - # our num_crops value back from the processor_kwargs. + # our num_crops value back from the mm_processor_kwargs. with patch.object( mm_registry._get_plugin("image"), "_max_mm_tokens", @@ -202,7 +208,7 @@ def test_max_tokens_kwarg_overrides(num_crops): @pytest.mark.parametrize( - "processor_kwargs", + "mm_processor_kwargs", [ # Not part of the signature { @@ -213,11 +219,11 @@ def test_max_tokens_kwarg_overrides(num_crops): "ctx": "something bad" } ]) -def test_max_tokens_with_sad_kwarg_overrides(processor_kwargs): - """Ensure that max token calcs filters out invalid processor_kwargs.""" +def test_max_tokens_with_sad_kwarg_overrides(mm_processor_kwargs): + """Ensure that max token calcs filters out invalid mm_processor_kwargs""" ctx = build_model_context(MULTIMODAL_MODEL_ID, trust_remote_code=True, - processor_kwargs=processor_kwargs, + mm_processor_kwargs=mm_processor_kwargs, limit_mm_per_prompt={"image": 1}) mm_registry = MultiModalRegistry() @@ -245,7 +251,7 @@ def test_default_mapper_with_processer_kwargs(image_assets, num_crops): # inspect what kwargs are or are not allowed. ctx = build_model_context(MULTIMODAL_MODEL_ID, trust_remote_code=True, - processor_kwargs={"num_crops": num_crops}, + mm_processor_kwargs={"num_crops": num_crops}, limit_mm_per_prompt={"image": 1}) mm_registry = MultiModalRegistry() @@ -262,18 +268,20 @@ def test_default_mapper_with_processer_kwargs(image_assets, num_crops): @pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE]) def test_custom_mapper_kwarg_overrides(image_assets, num_crops): """Ensure custom mappers can use processor kwargs.""" - processor_kwargs = None if num_crops is None else {"num_crops": num_crops} + mm_processor_kwargs = None if num_crops is None else { + "num_crops": num_crops + } expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops ctx = build_model_context(MULTIMODAL_MODEL_ID, trust_remote_code=True, - processor_kwargs=processor_kwargs, + mm_processor_kwargs=mm_processor_kwargs, limit_mm_per_prompt={"image": 1}) mm_registry = MultiModalRegistry() mm_registry.init_mm_limits_per_prompt(ctx.model_config) # Patch the image registry for phi3v with our lambda that is compatible # with overrides, then ensure that calling the method correctly echos - # our num_crops value back from the processor_kwargs. + # our num_crops value back from the mm_processor_kwargs. image = image_assets[0].pil_image mm_inputs = {"image": image} @@ -288,7 +296,7 @@ def test_custom_mapper_kwarg_overrides(image_assets, num_crops): @pytest.mark.parametrize( - "processor_kwargs", + "mm_processor_kwargs", [ # Not part of the signature { @@ -300,18 +308,18 @@ def test_custom_mapper_kwarg_overrides(image_assets, num_crops): } ]) def test_custom_mapper_with_sad_kwarg_overrides(image_assets, - processor_kwargs): - """Ensure that custom mappers filters out invalid processor_kwargs.""" + mm_processor_kwargs): + """Ensure that custom mappers filters out invalid mm_processor_kwargs""" ctx = build_model_context(MULTIMODAL_MODEL_ID, trust_remote_code=True, - processor_kwargs=processor_kwargs, + mm_processor_kwargs=mm_processor_kwargs, limit_mm_per_prompt={"image": 1}) mm_registry = MultiModalRegistry() mm_registry.init_mm_limits_per_prompt(ctx.model_config) # Patch the image registry for phi3v with our lambda that is compatible # with overrides, then ensure that calling the method correctly echos - # our num_crops value back from the processor_kwargs. + # our num_crops value back from the mm_processor_kwargs. image = image_assets[0].pil_image mm_inputs = {"image": image} diff --git a/vllm/config.py b/vllm/config.py index 94552a22cc25..c30867565132 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -122,7 +122,7 @@ class ModelConfig: can not be gathered from the vllm arguments. config_format: The config format which shall be loaded. Defaults to 'auto' which defaults to 'hf'. - processor_kwargs: Arguments to be forwarded to the model's processor, + mm_processor_kwargs: Arguments to be forwarded to the model's processor, e.g., tokenizer, image processor, or custom processor callable. """ @@ -153,7 +153,7 @@ def __init__(self, use_async_output_proc: bool = True, override_neuron_config: Optional[Dict[str, Any]] = None, config_format: ConfigFormat = ConfigFormat.AUTO, - processor_kwargs: Optional[Dict[str, Any]] = None) -> None: + mm_processor_kwargs: Optional[Dict[str, Any]] = None) -> None: self.model = model self.tokenizer = tokenizer self.tokenizer_mode = tokenizer_mode @@ -187,7 +187,7 @@ def __init__(self, self.model, revision) self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) self.use_async_output_proc = use_async_output_proc - self.processor_kwargs = processor_kwargs + self.mm_processor_kwargs = mm_processor_kwargs # Set enforce_eager to False if the value is unset. if self.enforce_eager is None: diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index ca1f334de535..ca6034ddbe5c 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -175,7 +175,7 @@ class EngineArgs: collect_detailed_traces: Optional[str] = None disable_async_output_proc: bool = False override_neuron_config: Optional[Dict[str, Any]] = None - processor_kwargs: Optional[Dict[str, Any]] = None + mm_processor_kwargs: Optional[Dict[str, Any]] = None def __post_init__(self): if self.tokenizer is None: @@ -515,11 +515,11 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: 'images and 2 videos per prompt. Defaults to 1 for ' 'each modality.')) parser.add_argument( - '--processor-kwargs', + '--mm-processor-kwargs', default=None, type=json.loads, - help=('Overrides for the model processor, e.g., tokenizer or ' - 'image processor. For example: {"num_crops": 4}.')) + help=('Overrides for the multimodal input mapping/processing,' + 'e.g., image processor. For example: {"num_crops": 4}.')) # LoRA related configs parser.add_argument('--enable-lora', @@ -829,7 +829,7 @@ def create_model_config(self) -> ModelConfig: use_async_output_proc=not self.disable_async_output_proc, override_neuron_config=self.override_neuron_config, config_format=self.config_format, - processor_kwargs=self.processor_kwargs, + mm_processor_kwargs=self.mm_processor_kwargs, ) def create_load_config(self) -> LoadConfig: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index a482cbbe2009..4d9696e464bc 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -235,7 +235,7 @@ def __init__( "decoding_config=%r, observability_config=%r, " "seed=%d, served_model_name=%s, use_v2_block_manager=%s, " "num_scheduler_steps=%d, enable_prefix_caching=%s, " - "use_async_output_proc=%s, processor_kwargs=%s)", + "use_async_output_proc=%s, mm_processor_kwargs=%s)", VLLM_VERSION, model_config.model, speculative_config, @@ -268,7 +268,7 @@ def __init__( scheduler_config.num_scheduler_steps, cache_config.enable_prefix_caching, model_config.use_async_output_proc, - model_config.processor_kwargs, + model_config.mm_processor_kwargs, ) # TODO(woosuk): Print more configs in debug mode. from vllm.plugins import load_general_plugins diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index d27ea214ff37..5dd02d0f9a12 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -134,7 +134,7 @@ def __init__( max_seq_len_to_capture: int = 8192, disable_custom_all_reduce: bool = False, disable_async_output_proc: bool = False, - processor_kwargs: Optional[Dict[str, Any]] = None, + mm_processor_kwargs: Optional[Dict[str, Any]] = None, **kwargs, ) -> None: ''' @@ -175,7 +175,7 @@ def __init__( max_seq_len_to_capture=max_seq_len_to_capture, disable_custom_all_reduce=disable_custom_all_reduce, disable_async_output_proc=disable_async_output_proc, - processor_kwargs=processor_kwargs, + mm_processor_kwargs=mm_processor_kwargs, **kwargs, ) self.llm_engine = LLMEngine.from_engine_args( diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 2a4a2250aba7..f6e53a08bb48 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -74,7 +74,7 @@ def __call__( ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int], - **processor_kwargs: Any, + **mm_processor_kwargs: Any, ) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]: """ Create dummy data to be inputted into the model. @@ -82,7 +82,7 @@ def __call__( Note: :data:`InputProcessor` is not applied to the dummy data. - The :code:`processor_kwargs` are overrides provided at + The :code:`mm_processor_kwargs` are overrides provided at initialization time to values in the config whose values may affect the number of tokens per instance. """ @@ -190,12 +190,12 @@ def dummy_data_for_profiling( .get(model_cls, self._default_dummy_data_factory) mm_counts = mm_registry.get_mm_limits_per_prompt(model_config) - processor_kwargs = get_allowed_kwarg_only_overrides( - callable=dummy_factory, overrides=model_config.processor_kwargs) + mm_processor_kwargs = get_allowed_kwarg_only_overrides( + callable=dummy_factory, overrides=model_config.mm_processor_kwargs) seq_data, mm_data = dummy_factory(InputContext(model_config), seq_len, _MultiModalCounts(mm_counts), - **processor_kwargs) + **mm_processor_kwargs) # Having more tokens is over-conservative but otherwise fine num_tokens = seq_data.prompt_token_ids @@ -244,7 +244,7 @@ def wrapper(model_cls: N) -> N: def _process_input(self, inputs: LLMInputs, model_config: "ModelConfig", processor: InputProcessor, - **processor_kwargs: Any) -> LLMInputs: + **mm_processor_kwargs: Any) -> LLMInputs: """ Apply an input processor to an instance of model inputs. This will usually not be invoked be directly, and instead will be wrapped in @@ -256,7 +256,7 @@ def _process_input(self, inputs: LLMInputs, model_config: "ModelConfig", :ref:`input_processing_pipeline` """ return processor(InputContext(model_config), inputs, - **processor_kwargs) + **mm_processor_kwargs) def create_input_processor(self, model_config: "ModelConfig"): """ @@ -268,12 +268,12 @@ def create_input_processor(self, model_config: "ModelConfig"): # NOTE: we don't allow override values for ctx/inputs, since doing # so can lead to value collisions etc. processor = self._get_model_input_processor(model_config) - processor_kwargs = get_allowed_kwarg_only_overrides( - callable=processor, overrides=model_config.processor_kwargs) + mm_processor_kwargs = get_allowed_kwarg_only_overrides( + callable=processor, overrides=model_config.mm_processor_kwargs) return functools.partial(self._process_input, model_config=model_config, processor=processor, - **processor_kwargs) + **mm_processor_kwargs) def _get_model_input_processor(self, model_config: "ModelConfig"): """ diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index 06cbb528c34b..ee840caabe0b 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -261,16 +261,16 @@ def map_input(self, model_config: ModelConfig, # input mapper; no overrides are used on the default here because they # should be passed to the huggingface resource at initialization time. if mapper != self._default_input_mapper: - processor_kwargs = get_allowed_kwarg_only_overrides( - callable=mapper, overrides=model_config.processor_kwargs) + mm_processor_kwargs = get_allowed_kwarg_only_overrides( + callable=mapper, overrides=model_config.mm_processor_kwargs) else: - processor_kwargs = {} + mm_processor_kwargs = {} if mapper is None: raise KeyError(f"No input mapper in {self} is registered for " f"model class {model_cls.__name__}.") - return mapper(InputContext(model_config), data, **processor_kwargs) + return mapper(InputContext(model_config), data, **mm_processor_kwargs) @abstractmethod def _default_max_multimodal_tokens(self, ctx: InputContext) -> int: @@ -343,11 +343,11 @@ def get_max_multimodal_tokens(self, model_config: ModelConfig) -> int: f"for model class {model_cls.__name__} in {self}.") if callable(max_mm_tokens): - processor_kwargs = get_allowed_kwarg_only_overrides( + mm_processor_kwargs = get_allowed_kwarg_only_overrides( callable=max_mm_tokens, - overrides=model_config.processor_kwargs) + overrides=model_config.mm_processor_kwargs) max_mm_tokens = max_mm_tokens(InputContext(model_config), - **processor_kwargs) + **mm_processor_kwargs) self._validate_max_multimodal_tokens(max_mm_tokens) diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py index c2657a112173..d71e24d71f2e 100644 --- a/vllm/multimodal/image.py +++ b/vllm/multimodal/image.py @@ -23,14 +23,14 @@ def get_data_key(self) -> str: return "image" def _get_hf_image_processor(self, model_config: ModelConfig): - processor_kwargs = ({} if model_config.processor_kwargs is None else - model_config.processor_kwargs) + mm_processor_kwargs = ({} if model_config.mm_processor_kwargs is None + else model_config.mm_processor_kwargs) # We don't explicitly check kwarg overrides to the HF class # since the automodel just takes kwargs, so we can't inspect it return cached_get_image_processor( model_config.model, trust_remote_code=model_config.trust_remote_code, - **processor_kwargs) + **mm_processor_kwargs) def _default_input_mapper( self, diff --git a/vllm/multimodal/video.py b/vllm/multimodal/video.py index aff920977662..75216df451b3 100644 --- a/vllm/multimodal/video.py +++ b/vllm/multimodal/video.py @@ -37,14 +37,14 @@ def get_data_key(self) -> str: return "video" def _get_hf_video_processor(self, model_config: ModelConfig): - processor_kwargs = ({} if model_config.processor_kwargs is None else - model_config.processor_kwargs) + mm_processor_kwargs = ({} if model_config.mm_processor_kwargs is None + else model_config.mm_processor_kwargs) # We don't explicitly check kwarg overrides to the HF class # since the automodel just takes kwargs, so we can't inspect it return cached_get_video_processor( model_config.model, trust_remote_code=model_config.trust_remote_code, - **processor_kwargs) + **mm_processor_kwargs) def _default_input_mapper( self, diff --git a/vllm/utils.py b/vllm/utils.py index e1b8ccfd6aad..3369a8672909 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1267,7 +1267,7 @@ def get_allowed_kwarg_only_overrides( if param.kind == inspect.Parameter.KEYWORD_ONLY ] - # Drop any processor_kwargs provided by the user that are + # Drop any mm_processor_kwargs provided by the user that are # not kwarg names accepted by the provided input processor. filtered_overrides = { kwarg_name: val From 632dac106fd888697d6cdbf3aa9ab4478579b5e6 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Thu, 19 Sep 2024 04:04:49 -0400 Subject: [PATCH 30/31] Expose phi3v num crops processor override Signed-off-by: Alex-Brooks Add tests for processing_kwarg overrides in phi3v Signed-off-by: Alex-Brooks Add processor_kwargs override to phi3v offline inference Signed-off-by: Alex-Brooks rename processor kwargs to mm processor kwargs Signed-off-by: Alex-Brooks --- examples/offline_inference_vision_language.py | 1 + .../vision_language/test_phi3v.py | 186 +++++++++++++++++- vllm/model_executor/models/phi3v.py | 31 ++- 3 files changed, 204 insertions(+), 14 deletions(-) diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index 464eaf334e3d..8732c5d9dff4 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -64,6 +64,7 @@ def run_phi3v(question): model="microsoft/Phi-3-vision-128k-instruct", trust_remote_code=True, max_num_seqs=5, + processor_kwargs={"num_crops": 16}, ) stop_token_ids = None return llm, prompt, stop_token_ids diff --git a/tests/models/decoder_only/vision_language/test_phi3v.py b/tests/models/decoder_only/vision_language/test_phi3v.py index e248151c40a6..eba0a1a1bce4 100644 --- a/tests/models/decoder_only/vision_language/test_phi3v.py +++ b/tests/models/decoder_only/vision_language/test_phi3v.py @@ -1,16 +1,21 @@ import os import re -from typing import List, Optional, Tuple, Type +from typing import Callable, List, Optional, Tuple, Type import pytest -from transformers import AutoTokenizer +import torch +from transformers import AutoImageProcessor, AutoTokenizer +from vllm.inputs import InputContext, LLMInputs +from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID +from vllm.multimodal import MultiModalRegistry from vllm.multimodal.utils import rescale_image_size from vllm.sequence import SampleLogprobs from vllm.utils import is_cpu, is_hip -from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner -from ...utils import check_logprobs_close +from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner, + _ImageAssets) +from ...utils import build_model_context, check_logprobs_close HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ "stop_sign": @@ -71,7 +76,7 @@ def run_test( All the image fixtures for the test are from IMAGE_ASSETS. For huggingface runner, we provide the PIL images as input. - For vllm runner, we provide MultiModalDataDict objects + For vllm runner, we provide MultiModalDataDict objects and corresponding MultiModalConfig as input. Note, the text input is also adjusted to abide by vllm contract. The text output is sanitized to be able to compare with hf. @@ -230,3 +235,174 @@ def test_multi_images_models(hf_runner, vllm_runner, image_assets, model, mm_limit=2, tensor_parallel_size=1, ) + + +### Fast tests for correctness in processor_kwarg override handling + + +# Wrap lazy imports to avoid initializing CUDA during test collection +@pytest.fixture() +def input_processor_for_phi3v(): + from vllm.model_executor.models.phi3v import input_processor_for_phi3v + return input_processor_for_phi3v + + +@pytest.fixture() +def dummy_data_for_phi3v(): + from vllm.model_executor.models.phi3v import dummy_data_for_phi3v + return dummy_data_for_phi3v + + +@pytest.fixture() +def get_max_phi3v_image_tokens(): + from vllm.model_executor.models.phi3v import get_max_phi3v_image_tokens + return get_max_phi3v_image_tokens + + +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize("num_crops", [4, 16, None]) +def test_input_mapper_override(model: str, image_assets: _ImageAssets, + num_crops: Optional[int]): + """Ensure that the [default] input mapper handles num_crops properly.""" + # We pass the processor kwargs here since for this model, we fall back to + # the default mapper; this will fall back to the HF mapper and forward + # mm_processor_kwargs to it. + mm_processor_kwargs = { + "num_crops": num_crops + } if num_crops is not None else {} + ctx = build_model_context( + model_name=model, + tokenizer_name=model, + trust_remote_code=True, + mm_processor_kwargs=mm_processor_kwargs, + ) + + hf_processor = AutoImageProcessor.from_pretrained(model, + trust_remote_code=True, + **mm_processor_kwargs) + + mm_registry = MultiModalRegistry() + mm_registry.init_mm_limits_per_prompt(ctx.model_config) + + image = image_assets[0].pil_image + hf_result = hf_processor.preprocess( + image, + return_tensors="pt", + ) + + vllm_result = mm_registry.map_input( + ctx.model_config, + {"image": image}, + ) + + assert torch.all(hf_result["image_sizes"] == vllm_result["image_sizes"]) + assert torch.all( + hf_result["num_img_tokens"] == vllm_result["num_img_tokens"]) + + # For pixel values, the second axis should be the num_crops + 1 + # for the rescaled original image. The default value in VLLM falls + # back to the HF config, which is why we compare to the processor num_crops + assert torch.all(hf_result["pixel_values"] == vllm_result["pixel_values"]) + assert vllm_result["pixel_values"].shape[1] == hf_processor.num_crops + 1 + + +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize("num_crops,expected_max_tokens", [ + (4, 781), + (16, 2653), +]) +def test_max_tokens_override(get_max_phi3v_image_tokens: Callable, model: str, + num_crops: int, expected_max_tokens: int): + """Ensure get_max_phi3v_image_tokens handles num_crops properly.""" + # NOTE: mm_processor_kwargs on the context in this test is unused, since + # this is testing the mapper directly. In practice, the processor kwargs + # are wrapped in a closure when calling the max tokens func. We explicitly + # do NOT use the mm_processor_kwargs in the model context here to ensure + # that the max image tokens implementation is referencing a mix of the + # kwargs to the function and the original mm_processor_kwargs in case + # values are somehow updated and end up in a bad state. + ctx = build_model_context( + model_name=model, + tokenizer_name=model, + trust_remote_code=True, + mm_processor_kwargs=None, + ) + + actual_max_tokens = get_max_phi3v_image_tokens( + InputContext(ctx.model_config), + num_crops=num_crops, + ) + + assert expected_max_tokens == actual_max_tokens + + +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize("num_crops,toks_per_img,num_imgs", [ + (4, 781, 1), + (4, 781, 2), + (16, 2653, 1), + (16, 2653, 2), +]) +def test_dummy_data_override(dummy_data_for_phi3v: Callable, model: str, + num_crops: int, toks_per_img: int, num_imgs: int): + """Ensure dummy_data_for_phi3v handles num_crops properly.""" + # Same as the previous test - don't initialize mm_processor_kwargs + # in this test and assume that the kwargs will be correctly expanded by + # the partial when calling the dummy data func. + ctx = build_model_context( + model_name=model, + tokenizer_name=model, + trust_remote_code=True, + mm_processor_kwargs=None, + ) + + sequence_data, _, = dummy_data_for_phi3v( + ctx=ctx, + seq_len=8192, # Should be bigger than num_imgs * toks_per_img + mm_counts={"image": num_imgs}, + num_crops=num_crops, + ) + # Ensure we have the right number of placeholders per num_crops size + img_tok_count = sequence_data.get_token_ids().count(_IMAGE_TOKEN_ID) + assert img_tok_count == toks_per_img * num_imgs + + +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize("num_crops,expected_toks_per_img,num_imgs", [ + (4, 757, 1), + (4, 757, 2), + (16, 1921, 1), + (16, 1921, 2), +]) +def test_input_processor_override(input_processor_for_phi3v: Callable, + image_assets: _ImageAssets, model: str, + num_crops: int, expected_toks_per_img: int, + num_imgs: int): + """Ensure input_processor_for_phi3v handles num_crops properly.""" + # Same as the previous test - don't initialize mm_processor_kwargs + # in this test and assume that the kwargs will be correctly expanded by + # the partial when calling the custom input processor. + ctx = build_model_context( + model_name=model, + tokenizer_name=model, + trust_remote_code=True, + ) + tokenizer = AutoTokenizer.from_pretrained(model) + # Build the image str / prompt based on the number of images we pass + img_str = "".join([f"<|image_{idx}|>\n" for idx in range(1, num_imgs + 1)]) + prompt = f"<|user|>\n{img_str}<|end|>\n<|assistant|>\n" + images = [image_assets[0].pil_image] * num_imgs + + llm_inputs = LLMInputs(prompt_token_ids=tokenizer.encode(prompt), + prompt=prompt, + multi_modal_data={"image": images}) + + proc_llm_inputs = input_processor_for_phi3v( + ctx=ctx, + llm_inputs=llm_inputs, + num_crops=num_crops, + ) + + # Ensure we have the right number of placeholders per num_crops size + img_tok_count = proc_llm_inputs["prompt_token_ids"].count(_IMAGE_TOKEN_ID) + assert img_tok_count == expected_toks_per_img * num_imgs diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 6f17f571ccae..245381518a7f 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -307,7 +307,7 @@ def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336): # Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L90 -def _calc_hd_transform_size(*, width: int, height: int, hd_num: int = 16): +def _calc_hd_transform_size(*, width: int, height: int, hd_num: int): transposed = False if width < height: width, height = height, width @@ -337,8 +337,10 @@ def get_phi3v_image_feature_size( *, input_height: int, input_width: int, + num_crops: int, ) -> int: - num_crops = hf_config.get("num_crops", 16) + if num_crops is None: + num_crops = hf_config.get("num_crops", 16) new_width, new_height = _calc_hd_transform_size(width=input_width, height=input_height, hd_num=num_crops) @@ -347,20 +349,26 @@ def get_phi3v_image_feature_size( + (new_height // 336 + 1) * 12 -def get_max_phi3v_image_tokens(ctx: InputContext): +def get_max_phi3v_image_tokens(ctx: InputContext, + *, + num_crops: Optional[int] = None): return get_phi3v_image_feature_size( ctx.get_hf_image_processor_config(), input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT, input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH, + num_crops=num_crops, ) -def dummy_data_for_phi3v(ctx: InputContext, seq_len: int, - mm_counts: Mapping[str, int]): +def dummy_data_for_phi3v(ctx: InputContext, + seq_len: int, + mm_counts: Mapping[str, int], + *, + num_crops: Optional[int] = None): num_images = mm_counts["image"] - image_feature_size = get_max_phi3v_image_tokens(ctx) + image_feature_size = get_max_phi3v_image_tokens(ctx, num_crops=num_crops) seq_data = dummy_seq_data_for_clip( CLIP_VIT_LARGE_PATCH14_336_CONFIG, @@ -398,7 +406,10 @@ def _get_image_placeholder_token_ids(model_config: ModelConfig, return image_placeholder_token_ids -def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs): +def input_processor_for_phi3v(ctx: InputContext, + llm_inputs: LLMInputs, + *, + num_crops: Optional[int] = None): multi_modal_data = llm_inputs.get("multi_modal_data") if multi_modal_data is None or "image" not in multi_modal_data: return llm_inputs @@ -412,7 +423,8 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs): image_feature_size = [ get_phi3v_image_feature_size(hf_config, input_width=w, - input_height=h) + input_height=h, + num_crops=num_crops) ] image_data = [image_data] elif is_list_of(image_data, Image.Image): @@ -422,7 +434,8 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs): image_feature_size.append( get_phi3v_image_feature_size(hf_config, input_width=w, - input_height=h)) + input_height=h, + num_crops=num_crops)) elif isinstance(image_data, torch.Tensor): num_images, image_feature_size, hidden_size = image_data.shape elif is_list_of(image_data, torch.Tensor): From 4a9ccaebab67e8f73da49469664a71f56aab3e14 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Mon, 23 Sep 2024 19:40:17 -0400 Subject: [PATCH 31/31] Update phi3v examples with num crops overrides --- examples/offline_inference_vision_language.py | 15 ++++++++++++++- ...fline_inference_vision_language_multi_image.py | 13 +++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index 394361317bb8..6675aa0109a6 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -83,11 +83,24 @@ def run_phi3v(question, modality): # In this example, we override max_num_seqs to 5 while # keeping the original context length of 128k. + + # num_crops is an override kwarg to the multimodal image processor; + # For some models, e.g., Phi-3.5-vision-instruct, it is recommended + # to use 16 for single frame scenarios, and 4 for multi-frame. + # + # Generally speaking, a larger value for num_crops results in more + # tokens per image instance, because it may scale the image more in + # the image preprocessing. Some references in the model docs and the + # formula for image tokens after the preprocessing + # transform can be found below. + # + # https://huggingface.co/microsoft/Phi-3.5-vision-instruct#loading-the-model-locally + # https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/main/processing_phi3_v.py#L194 llm = LLM( model="microsoft/Phi-3-vision-128k-instruct", trust_remote_code=True, max_num_seqs=5, - processor_kwargs={"num_crops": 16}, + mm_processor_kwargs={"num_crops": 16}, ) stop_token_ids = None return llm, prompt, stop_token_ids diff --git a/examples/offline_inference_vision_language_multi_image.py b/examples/offline_inference_vision_language_multi_image.py index 92ab4f42baa8..8c5f1a7b7af0 100644 --- a/examples/offline_inference_vision_language_multi_image.py +++ b/examples/offline_inference_vision_language_multi_image.py @@ -67,11 +67,24 @@ def load_qwenvl_chat(question: str, image_urls: List[str]) -> ModelRequestData: def load_phi3v(question: str, image_urls: List[str]) -> ModelRequestData: + # num_crops is an override kwarg to the multimodal image processor; + # For some models, e.g., Phi-3.5-vision-instruct, it is recommended + # to use 16 for single frame scenarios, and 4 for multi-frame. + # + # Generally speaking, a larger value for num_crops results in more + # tokens per image instance, because it may scale the image more in + # the image preprocessing. Some references in the model docs and the + # formula for image tokens after the preprocessing + # transform can be found below. + # + # https://huggingface.co/microsoft/Phi-3.5-vision-instruct#loading-the-model-locally + # https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/main/processing_phi3_v.py#L194 llm = LLM( model="microsoft/Phi-3.5-vision-instruct", trust_remote_code=True, max_model_len=4096, limit_mm_per_prompt={"image": len(image_urls)}, + mm_processor_kwargs={"num_crops": 4}, ) placeholders = "\n".join(f"<|image_{i}|>" for i, _ in enumerate(image_urls, start=1))