[Model] Expose Phi3v num_crops as a mm_processor_kwarg #8658

Merged · 33 commits · Sep 24, 2024

Changes from 1 commit

Commits (33)
550378b
Allow for processor kwarg overrides
alex-jw-brooks Sep 16, 2024
190606f
Pass processor through to partial
alex-jw-brooks Sep 17, 2024
b1ca041
Add default & processor kwarg override tests
alex-jw-brooks Sep 17, 2024
195e31c
Don't allow ctx or inputs as kwargs
alex-jw-brooks Sep 17, 2024
1472d04
Add kwarg override for processor to dummy data factories
alex-jw-brooks Sep 17, 2024
f10601f
Add kwarg override for processor to max token calc
alex-jw-brooks Sep 19, 2024
429097a
Move kwarg only override func to utils
alex-jw-brooks Sep 19, 2024
159cfc2
Force processor kwargs to be keyword-only
alex-jw-brooks Sep 19, 2024
af91930
Pass unfiltered processor kwargs to default mapper
alex-jw-brooks Sep 19, 2024
9adad10
Add hack for mapper preprocessor kwargs
alex-jw-brooks Sep 19, 2024
9f7aed8
Simplify dummy data processor kwarg & add tests
alex-jw-brooks Sep 19, 2024
ff59e44
Add tests for max multimodal token kwarg overrides
alex-jw-brooks Sep 19, 2024
6b26454
Format registry
alex-jw-brooks Sep 20, 2024
0e2d53d
Fix default mapper comparison
alex-jw-brooks Sep 20, 2024
5a3341b
Move kwarg filtering into hf processor getter
alex-jw-brooks Sep 20, 2024
3e1fe54
Enable processor_kwargs in video processor
alex-jw-brooks Sep 20, 2024
feccfd7
Add tests for mapper processor_kwargs
alex-jw-brooks Sep 20, 2024
3ada64d
Update mapper note on multimodal processor kwargs
alex-jw-brooks Sep 20, 2024
58dcc63
processor kwarg test cleanup
alex-jw-brooks Sep 20, 2024
1cee215
Move context builder to test utils
alex-jw-brooks Sep 19, 2024
d5f9efa
Use common context builder in processor kwarg tests
alex-jw-brooks Sep 20, 2024
b5d434b
Update vllm/entrypoints/llm.py
alex-jw-brooks Sep 22, 2024
a096301
Update vllm/inputs/registry.py
alex-jw-brooks Sep 22, 2024
79962e0
Update vllm/inputs/registry.py
alex-jw-brooks Sep 22, 2024
2cb1f72
Update vllm/inputs/registry.py
alex-jw-brooks Sep 22, 2024
37eb532
Update vllm/inputs/registry.py
alex-jw-brooks Sep 22, 2024
a4c7c3d
Update vllm/inputs/registry.py
alex-jw-brooks Sep 22, 2024
36dd2cb
Fix formatting
alex-jw-brooks Sep 22, 2024
f95c86f
Rename processor kwargs to mm processor kwargs
alex-jw-brooks Sep 22, 2024
632dac1
Expose phi3v num crops processor override
alex-jw-brooks Sep 19, 2024
9eca61a
Merge branch 'main' into phi3v_num_crops
DarkLight1337 Sep 23, 2024
a3ab6cb
Merge branch 'main' into phi3v_num_crops
DarkLight1337 Sep 23, 2024
4a9ccae
Update phi3v examples with num crops overrides
alex-jw-brooks Sep 23, 2024
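
Net effect of the series: the engine's processor override plumbing is renamed to mm_processor_kwargs, and Phi3v's num_crops becomes overridable through it. A minimal sketch of the resulting usage, assuming the Phi-3.5-vision checkpoint (whose HF processor defaults to 4 crops):

from vllm import LLM

# num_crops is forwarded to the HF image processor in place of its default.
llm = LLM(
    model="microsoft/Phi-3.5-vision-instruct",
    trust_remote_code=True,
    mm_processor_kwargs={"num_crops": 16},
)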
Rename processor kwargs to mm processor kwargs
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
alex-jw-brooks committed Sep 22, 2024
commit f95c86f7798b2fa6925a7bc993fb5a39f662f5ee
6 changes: 3 additions & 3 deletions tests/engine/test_arg_utils.py
@@ -54,10 +54,10 @@ def test_bad_nullable_kvs(arg):
         }
     }),
 ])
-def test_processor_kwargs_prompt_parser(arg, expected):
+def test_mm_processor_kwargs_prompt_parser(arg, expected):
     parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
     if arg is None:
         args = parser.parse_args([])
     else:
-        args = parser.parse_args(["--processor-kwargs", arg])
-    assert args.processor_kwargs == expected
+        args = parser.parse_args(["--mm-processor-kwargs", arg])
+    assert args.mm_processor_kwargs == expected
6 changes: 3 additions & 3 deletions tests/models/utils.py
@@ -247,15 +247,15 @@ def check_logprobs_close(
 def build_model_context(model_name: str,
                         tokenizer_name: Optional[str] = None,
                         trust_remote_code: bool = False,
-                        processor_kwargs: Optional[Dict] = None,
+                        mm_processor_kwargs: Optional[Dict] = None,
                         limit_mm_per_prompt: Optional[Dict] = None):
     """Creates an InputContext for a given model.
 
     Args:
         model_name: Name of the model being considered.
         tokenizer_name: Name of the tokenizer being considered.
         trust_remote_code: Whether or not to allow loading remote code.
-        processor_kwargs: optional processor kwargs for to be leveraged
+        mm_processor_kwargs: optional processor kwargs for to be leveraged
             in the input processor, mapper, dummy data creation, etc.
         limit_mm_per_prompt: Multimodal limits.
@@ -271,7 +271,7 @@ def build_model_context(model_name: str,
         trust_remote_code=trust_remote_code,
         dtype="float32",
         seed=0,
-        processor_kwargs=processor_kwargs,
+        mm_processor_kwargs=mm_processor_kwargs,
         limit_mm_per_prompt=limit_mm_per_prompt,
     )
     return InputContext(model_config)
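
For orientation, a sketch of driving the renamed helper the way the kwarg tests below do; the import path assumes the vLLM repo root is on sys.path:

from tests.models.utils import build_model_context

# Build an InputContext whose ModelConfig carries the renamed kwarg.
ctx = build_model_context(
    "microsoft/Phi-3.5-vision-instruct",
    trust_remote_code=True,
    mm_processor_kwargs={"num_crops": 16},
    limit_mm_per_prompt={"image": 1},
)
assert ctx.model_config.mm_processor_kwargs == {"num_crops": 16}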
74 changes: 41 additions & 33 deletions tests/multimodal/test_processor_kwargs.py
@@ -18,14 +18,14 @@
 # Used for tests that need a multimodal model
 MULTIMODAL_MODEL_ID = "microsoft/Phi-3.5-vision-instruct"
 
-# For processor_kwargs - we test overrides by defining mocks for each place
+# For mm_processor_kwargs - we test overrides by defining mocks for each place
 # it is used, and ensuring that we can pass processor kwargs an override value
 # to receive the intended result for things like sequence length etc.
 DEFAULT_NUM_CROPS = 4
 NUM_CROPS_OVERRIDE = 16
 
 
-# Mocks for all of the places that we use the processor_kwargs
+# Mocks for all of the places that we use the mm_processor_kwargs
 # to override values in different callables
 @pytest.fixture
 def use_processor_mock():
@@ -72,7 +72,7 @@ def custom_dummy_data_factory(self,
     }
 
 
-### Test for default processor logic & processor_kwargs wrapping
+### Test for default processor logic & mm_processor_kwargs wrapping
 def test_default_processor_is_a_noop():
     """Ensure that by default, there is no processor override."""
     dummy_registry = InputRegistry()
@@ -90,18 +90,20 @@ def test_processor_default_kwargs(use_processor_mock, num_crops):
     # If we have a value for num_crops, pass the override value and make
     # sure we get that value as a return-value from out mock processor,
     # otherwise fall back to the default value
-    processor_kwargs = None if num_crops is None else {"num_crops": num_crops}
+    mm_processor_kwargs = None if num_crops is None else {
+        "num_crops": num_crops
+    }
     expected_num_crops = DEFAULT_NUM_CROPS if num_crops is None else num_crops
     ctx = build_model_context(DUMMY_MODEL_ID,
-                              processor_kwargs=processor_kwargs)
+                              mm_processor_kwargs=mm_processor_kwargs)
     processor = dummy_registry.create_input_processor(ctx.model_config)
 
     num_crops_val = processor(LLMInputs(prompt_token_ids=[], prompt=""))
     assert num_crops_val == expected_num_crops
 
 
 @pytest.mark.parametrize(
-    "processor_kwargs",
+    "mm_processor_kwargs",
     [
         # Not part of the signature
         {
@@ -113,11 +113,11 @@ def test_processor_default_kwargs(use_processor_mock, num_crops):
         }
     ])
 def test_processor_with_sad_kwarg_overrides(use_processor_mock,
-                                            processor_kwargs):
-    """Ensure that input processors filter out invalid processor_kwargs."""
+                                            mm_processor_kwargs):
+    """Ensure that input processors filter out invalid mm_processor_kwargs"""
     dummy_registry = InputRegistry()
     ctx = build_model_context(DUMMY_MODEL_ID,
-                              processor_kwargs=processor_kwargs)
+                              mm_processor_kwargs=mm_processor_kwargs)
 
     processor = dummy_registry.create_input_processor(ctx.model_config)
     num_crops_val = processor(LLMInputs(prompt_token_ids=[], prompt=""))
@@ -128,24 +130,26 @@ def test_processor_with_sad_kwarg_overrides(use_processor_mock,
 @pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE])
 def test_dummy_data_kwarg_overrides(use_dummy_data_mock, num_crops):
     """Ensure dummy data factories can use processor kwargs."""
-    processor_kwargs = None if num_crops is None else {"num_crops": num_crops}
+    mm_processor_kwargs = None if num_crops is None else {
+        "num_crops": num_crops
+    }
     expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops
     dummy_registry = InputRegistry()
     ctx = build_model_context(DUMMY_MODEL_ID,
-                              processor_kwargs=processor_kwargs)
+                              mm_processor_kwargs=mm_processor_kwargs)
     mm_registry = MultiModalRegistry()
     mm_registry.init_mm_limits_per_prompt(ctx.model_config)
 
     # NOTE: seq_len is thrown away here since this will leverage the
     # default dummy data factory that we have patched in, whose seq
-    # len is solely dependent on the value of the processor_kwargs.
+    # len is solely dependent on the value of the mm_processor_kwargs.
     seq_data, _ = dummy_registry.dummy_data_for_profiling(
         ctx.model_config, seq_len=-1, mm_registry=mm_registry)
     assert len(seq_data.prompt_token_ids) == expected_seq_count
 
 
 @pytest.mark.parametrize(
-    "processor_kwargs",
+    "mm_processor_kwargs",
     [
         # Not part of the signature
         {
@@ -157,17 +161,17 @@ def test_dummy_data_kwarg_overrides(use_dummy_data_mock, num_crops):
         }
     ])
 def test_dummy_data_with_sad_kwarg_overrides(use_dummy_data_mock,
-                                             processor_kwargs):
-    """Ensure that dummy data factory filters out invalid processor_kwargs."""
+                                             mm_processor_kwargs):
+    """Ensure the dummy data factory filters out invalid mm_processor_kwargs"""
     dummy_registry = InputRegistry()
     ctx = build_model_context(DUMMY_MODEL_ID,
-                              processor_kwargs=processor_kwargs)
+                              mm_processor_kwargs=mm_processor_kwargs)
     mm_registry = MultiModalRegistry()
     mm_registry.init_mm_limits_per_prompt(ctx.model_config)
 
     # NOTE: seq_len is thrown away here since this will leverage the
     # default dummy data factory that we have patched in, whose seq
-    # len is solely dependent on the value of the processor_kwargs.
+    # len is solely dependent on the value of the mm_processor_kwargs.
     seq_data, _ = dummy_registry.dummy_data_for_profiling(
         ctx.model_config, seq_len=-1, mm_registry=mm_registry)
     assert len(seq_data.prompt_token_ids) == DEFAULT_NUM_CROPS
@@ -177,19 +181,21 @@ def test_dummy_data_with_sad_kwarg_overrides(use_dummy_data_mock,
 @pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE])
 def test_max_tokens_kwarg_overrides(num_crops):
     """Ensure max token calcs can use processor kwargs."""
-    processor_kwargs = None if num_crops is None else {"num_crops": num_crops}
+    mm_processor_kwargs = None if num_crops is None else {
+        "num_crops": num_crops
+    }
     expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops
 
     ctx = build_model_context(MULTIMODAL_MODEL_ID,
                               trust_remote_code=True,
-                              processor_kwargs=processor_kwargs,
+                              mm_processor_kwargs=mm_processor_kwargs,
                               limit_mm_per_prompt={"image": 1})
 
     mm_registry = MultiModalRegistry()
     mm_registry.init_mm_limits_per_prompt(ctx.model_config)
     # Patch the image registry for phi3v with our lambda that is compatible
     # with overrides, then ensure that calling the method correctly echos
-    # our num_crops value back from the processor_kwargs.
+    # our num_crops value back from the mm_processor_kwargs.
     with patch.object(
             mm_registry._get_plugin("image"),
             "_max_mm_tokens",
@@ -202,7 +208,7 @@ def test_max_tokens_kwarg_overrides(num_crops):
 
 
 @pytest.mark.parametrize(
-    "processor_kwargs",
+    "mm_processor_kwargs",
     [
         # Not part of the signature
         {
@@ -213,11 +219,11 @@ def test_max_tokens_kwarg_overrides(num_crops):
             "ctx": "something bad"
         }
     ])
-def test_max_tokens_with_sad_kwarg_overrides(processor_kwargs):
-    """Ensure that max token calcs filters out invalid processor_kwargs."""
+def test_max_tokens_with_sad_kwarg_overrides(mm_processor_kwargs):
+    """Ensure that max token calcs filters out invalid mm_processor_kwargs"""
     ctx = build_model_context(MULTIMODAL_MODEL_ID,
                               trust_remote_code=True,
-                              processor_kwargs=processor_kwargs,
+                              mm_processor_kwargs=mm_processor_kwargs,
                               limit_mm_per_prompt={"image": 1})
 
     mm_registry = MultiModalRegistry()
@@ -245,7 +251,7 @@ def test_default_mapper_with_processer_kwargs(image_assets, num_crops):
     # inspect what kwargs are or are not allowed.
     ctx = build_model_context(MULTIMODAL_MODEL_ID,
                               trust_remote_code=True,
-                              processor_kwargs={"num_crops": num_crops},
+                              mm_processor_kwargs={"num_crops": num_crops},
                               limit_mm_per_prompt={"image": 1})
 
     mm_registry = MultiModalRegistry()
@@ -262,18 +268,20 @@ def test_default_mapper_with_processer_kwargs(image_assets, num_crops):
 @pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE])
 def test_custom_mapper_kwarg_overrides(image_assets, num_crops):
     """Ensure custom mappers can use processor kwargs."""
-    processor_kwargs = None if num_crops is None else {"num_crops": num_crops}
+    mm_processor_kwargs = None if num_crops is None else {
+        "num_crops": num_crops
+    }
     expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops
     ctx = build_model_context(MULTIMODAL_MODEL_ID,
                               trust_remote_code=True,
-                              processor_kwargs=processor_kwargs,
+                              mm_processor_kwargs=mm_processor_kwargs,
                               limit_mm_per_prompt={"image": 1})
 
     mm_registry = MultiModalRegistry()
     mm_registry.init_mm_limits_per_prompt(ctx.model_config)
     # Patch the image registry for phi3v with our lambda that is compatible
     # with overrides, then ensure that calling the method correctly echos
-    # our num_crops value back from the processor_kwargs.
+    # our num_crops value back from the mm_processor_kwargs.
     image = image_assets[0].pil_image
     mm_inputs = {"image": image}
 
@@ -288,7 +296,7 @@ def test_custom_mapper_kwarg_overrides(image_assets, num_crops):
 
 
 @pytest.mark.parametrize(
-    "processor_kwargs",
+    "mm_processor_kwargs",
     [
         # Not part of the signature
         {
@@ -300,18 +308,18 @@ def test_custom_mapper_kwarg_overrides(image_assets, num_crops):
         }
     ])
 def test_custom_mapper_with_sad_kwarg_overrides(image_assets,
-                                                processor_kwargs):
-    """Ensure that custom mappers filters out invalid processor_kwargs."""
+                                                mm_processor_kwargs):
+    """Ensure that custom mappers filters out invalid mm_processor_kwargs"""
     ctx = build_model_context(MULTIMODAL_MODEL_ID,
                               trust_remote_code=True,
-                              processor_kwargs=processor_kwargs,
+                              mm_processor_kwargs=mm_processor_kwargs,
                               limit_mm_per_prompt={"image": 1})
 
     mm_registry = MultiModalRegistry()
     mm_registry.init_mm_limits_per_prompt(ctx.model_config)
     # Patch the image registry for phi3v with our lambda that is compatible
     # with overrides, then ensure that calling the method correctly echos
-    # our num_crops value back from the processor_kwargs.
+    # our num_crops value back from the mm_processor_kwargs.
     image = image_assets[0].pil_image
     mm_inputs = {"image": image}
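
The pattern these tests exercise: override targets declare their overridable values as keyword-only parameters, and the registry filters mm_processor_kwargs against the callable's signature, so positional names like ctx are never injected. A sketch of a processor in that style, mirroring the mocks above; the names and import path are illustrative, not part of the diff:

from vllm.inputs import InputContext, LLMInputs

DEFAULT_NUM_CROPS = 4

# num_crops must be keyword-only for the kwarg filtering to pick it up from
# mm_processor_kwargs; invalid keys are dropped rather than passed through.
def custom_processor(ctx: InputContext, llm_inputs: LLMInputs, *,
                     num_crops: int = DEFAULT_NUM_CROPS) -> LLMInputs:
    # Echo the resolved value back through the token ids, as the mocks do.
    return LLMInputs(prompt_token_ids=[0] * num_crops, prompt="")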
6 changes: 3 additions & 3 deletions vllm/config.py
@@ -122,7 +122,7 @@ class ModelConfig:
             can not be gathered from the vllm arguments.
         config_format: The config format which shall be loaded.
             Defaults to 'auto' which defaults to 'hf'.
-        processor_kwargs: Arguments to be forwarded to the model's processor,
+        mm_processor_kwargs: Arguments to be forwarded to the model's processor,
             e.g., tokenizer, image processor, or custom processor callable.
     """
 
@@ -153,7 +153,7 @@ def __init__(self,
                  use_async_output_proc: bool = True,
                  override_neuron_config: Optional[Dict[str, Any]] = None,
                  config_format: ConfigFormat = ConfigFormat.AUTO,
-                 processor_kwargs: Optional[Dict[str, Any]] = None) -> None:
+                 mm_processor_kwargs: Optional[Dict[str, Any]] = None) -> None:
         self.model = model
         self.tokenizer = tokenizer
         self.tokenizer_mode = tokenizer_mode
@@ -187,7 +187,7 @@ def __init__(self,
             self.model, revision)
         self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
         self.use_async_output_proc = use_async_output_proc
-        self.processor_kwargs = processor_kwargs
+        self.mm_processor_kwargs = mm_processor_kwargs
 
         # Set enforce_eager to False if the value is unset.
         if self.enforce_eager is None:
10 changes: 5 additions & 5 deletions vllm/engine/arg_utils.py
@@ -175,7 +175,7 @@ class EngineArgs:
     collect_detailed_traces: Optional[str] = None
     disable_async_output_proc: bool = False
     override_neuron_config: Optional[Dict[str, Any]] = None
-    processor_kwargs: Optional[Dict[str, Any]] = None
+    mm_processor_kwargs: Optional[Dict[str, Any]] = None
 
     def __post_init__(self):
         if self.tokenizer is None:
@@ -515,11 +515,11 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
                   'images and 2 videos per prompt. Defaults to 1 for '
                   'each modality.'))
         parser.add_argument(
-            '--processor-kwargs',
+            '--mm-processor-kwargs',
            default=None,
            type=json.loads,
-            help=('Overrides for the model processor, e.g., tokenizer or '
-                  'image processor. For example: {"num_crops": 4}.'))
+            help=('Overrides for the multimodal input mapping/processing,'
+                  'e.g., image processor. For example: {"num_crops": 4}.'))
 
         # LoRA related configs
         parser.add_argument('--enable-lora',
@@ -829,7 +829,7 @@ def create_model_config(self) -> ModelConfig:
             use_async_output_proc=not self.disable_async_output_proc,
             override_neuron_config=self.override_neuron_config,
             config_format=self.config_format,
-            processor_kwargs=self.processor_kwargs,
+            mm_processor_kwargs=self.mm_processor_kwargs,
         )
 
     def create_load_config(self) -> LoadConfig:
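
A sketch of the renamed flag end to end, mirroring test_mm_processor_kwargs_prompt_parser above:

from vllm.engine.arg_utils import EngineArgs
from vllm.utils import FlexibleArgumentParser

# Parse the renamed flag; the JSON string becomes a dict on the namespace.
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
args = parser.parse_args(['--mm-processor-kwargs', '{"num_crops": 4}'])
assert args.mm_processor_kwargs == {"num_crops": 4}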
4 changes: 2 additions & 2 deletions vllm/engine/llm_engine.py
@@ -235,7 +235,7 @@ def __init__(
             "decoding_config=%r, observability_config=%r, "
             "seed=%d, served_model_name=%s, use_v2_block_manager=%s, "
             "num_scheduler_steps=%d, enable_prefix_caching=%s, "
-            "use_async_output_proc=%s, processor_kwargs=%s)",
+            "use_async_output_proc=%s, mm_processor_kwargs=%s)",
             VLLM_VERSION,
             model_config.model,
             speculative_config,
@@ -268,7 +268,7 @@ def __init__(
             scheduler_config.num_scheduler_steps,
             cache_config.enable_prefix_caching,
             model_config.use_async_output_proc,
-            model_config.processor_kwargs,
+            model_config.mm_processor_kwargs,
         )
         # TODO(woosuk): Print more configs in debug mode.
         from vllm.plugins import load_general_plugins
4 changes: 2 additions & 2 deletions vllm/entrypoints/llm.py
@@ -134,7 +134,7 @@ def __init__(
         max_seq_len_to_capture: int = 8192,
         disable_custom_all_reduce: bool = False,
         disable_async_output_proc: bool = False,
-        processor_kwargs: Optional[Dict[str, Any]] = None,
+        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> None:
        '''
@@ -175,7 +175,7 @@ def __init__(
             max_seq_len_to_capture=max_seq_len_to_capture,
             disable_custom_all_reduce=disable_custom_all_reduce,
             disable_async_output_proc=disable_async_output_proc,
-            processor_kwargs=processor_kwargs,
+            mm_processor_kwargs=mm_processor_kwargs,
             **kwargs,
         )
         self.llm_engine = LLMEngine.from_engine_args(
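
Downstream of this constructor change, a usage sketch with generation; the Phi-3-vision prompt template and the local image path are assumptions, not part of this diff:

from PIL import Image
from vllm import LLM, SamplingParams

llm = LLM(
    model="microsoft/Phi-3.5-vision-instruct",
    trust_remote_code=True,
    mm_processor_kwargs={"num_crops": 16},  # flows through EngineArgs to ModelConfig
    limit_mm_per_prompt={"image": 1},
)
# Assumed Phi-3-vision style prompt with an image placeholder.
prompt = "<|user|>\n<|image_1|>\nWhat is shown in this image?<|end|>\n<|assistant|>\n"
image = Image.open("example.jpg")  # hypothetical local file
outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": image}},
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)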