diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index c1129316a6e3..6675aa0109a6 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -83,10 +83,24 @@ def run_phi3v(question, modality): # In this example, we override max_num_seqs to 5 while # keeping the original context length of 128k. + + # num_crops is an override kwarg to the multimodal image processor; + # For some models, e.g., Phi-3.5-vision-instruct, it is recommended + # to use 16 for single frame scenarios, and 4 for multi-frame. + # + # Generally speaking, a larger value for num_crops results in more + # tokens per image instance, because it may scale the image more in + # the image preprocessing. Some references in the model docs and the + # formula for image tokens after the preprocessing + # transform can be found below. + # + # https://huggingface.co/microsoft/Phi-3.5-vision-instruct#loading-the-model-locally + # https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/main/processing_phi3_v.py#L194 llm = LLM( model="microsoft/Phi-3-vision-128k-instruct", trust_remote_code=True, max_num_seqs=5, + mm_processor_kwargs={"num_crops": 16}, ) stop_token_ids = None return llm, prompt, stop_token_ids diff --git a/examples/offline_inference_vision_language_multi_image.py b/examples/offline_inference_vision_language_multi_image.py index 92ab4f42baa8..8c5f1a7b7af0 100644 --- a/examples/offline_inference_vision_language_multi_image.py +++ b/examples/offline_inference_vision_language_multi_image.py @@ -67,11 +67,24 @@ def load_qwenvl_chat(question: str, image_urls: List[str]) -> ModelRequestData: def load_phi3v(question: str, image_urls: List[str]) -> ModelRequestData: + # num_crops is an override kwarg to the multimodal image processor; + # For some models, e.g., Phi-3.5-vision-instruct, it is recommended + # to use 16 for single frame scenarios, and 4 for multi-frame. + # + # Generally speaking, a larger value for num_crops results in more + # tokens per image instance, because it may scale the image more in + # the image preprocessing. Some references in the model docs and the + # formula for image tokens after the preprocessing + # transform can be found below. + # + # https://huggingface.co/microsoft/Phi-3.5-vision-instruct#loading-the-model-locally + # https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/main/processing_phi3_v.py#L194 llm = LLM( model="microsoft/Phi-3.5-vision-instruct", trust_remote_code=True, max_model_len=4096, limit_mm_per_prompt={"image": len(image_urls)}, + mm_processor_kwargs={"num_crops": 4}, ) placeholders = "\n".join(f"<|image_{i}|>" for i, _ in enumerate(image_urls, start=1)) diff --git a/tests/models/decoder_only/vision_language/test_phi3v.py b/tests/models/decoder_only/vision_language/test_phi3v.py index e248151c40a6..eba0a1a1bce4 100644 --- a/tests/models/decoder_only/vision_language/test_phi3v.py +++ b/tests/models/decoder_only/vision_language/test_phi3v.py @@ -1,16 +1,21 @@ import os import re -from typing import List, Optional, Tuple, Type +from typing import Callable, List, Optional, Tuple, Type import pytest -from transformers import AutoTokenizer +import torch +from transformers import AutoImageProcessor, AutoTokenizer +from vllm.inputs import InputContext, LLMInputs +from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID +from vllm.multimodal import MultiModalRegistry from vllm.multimodal.utils import rescale_image_size from vllm.sequence import SampleLogprobs from vllm.utils import is_cpu, is_hip -from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner -from ...utils import check_logprobs_close +from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner, + _ImageAssets) +from ...utils import build_model_context, check_logprobs_close HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ "stop_sign": @@ -71,7 +76,7 @@ def run_test( All the image fixtures for the test are from IMAGE_ASSETS. For huggingface runner, we provide the PIL images as input. - For vllm runner, we provide MultiModalDataDict objects + For vllm runner, we provide MultiModalDataDict objects and corresponding MultiModalConfig as input. Note, the text input is also adjusted to abide by vllm contract. The text output is sanitized to be able to compare with hf. @@ -230,3 +235,174 @@ def test_multi_images_models(hf_runner, vllm_runner, image_assets, model, mm_limit=2, tensor_parallel_size=1, ) + + +### Fast tests for correctness in processor_kwarg override handling + + +# Wrap lazy imports to avoid initializing CUDA during test collection +@pytest.fixture() +def input_processor_for_phi3v(): + from vllm.model_executor.models.phi3v import input_processor_for_phi3v + return input_processor_for_phi3v + + +@pytest.fixture() +def dummy_data_for_phi3v(): + from vllm.model_executor.models.phi3v import dummy_data_for_phi3v + return dummy_data_for_phi3v + + +@pytest.fixture() +def get_max_phi3v_image_tokens(): + from vllm.model_executor.models.phi3v import get_max_phi3v_image_tokens + return get_max_phi3v_image_tokens + + +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize("num_crops", [4, 16, None]) +def test_input_mapper_override(model: str, image_assets: _ImageAssets, + num_crops: Optional[int]): + """Ensure that the [default] input mapper handles num_crops properly.""" + # We pass the processor kwargs here since for this model, we fall back to + # the default mapper; this will fall back to the HF mapper and forward + # mm_processor_kwargs to it. + mm_processor_kwargs = { + "num_crops": num_crops + } if num_crops is not None else {} + ctx = build_model_context( + model_name=model, + tokenizer_name=model, + trust_remote_code=True, + mm_processor_kwargs=mm_processor_kwargs, + ) + + hf_processor = AutoImageProcessor.from_pretrained(model, + trust_remote_code=True, + **mm_processor_kwargs) + + mm_registry = MultiModalRegistry() + mm_registry.init_mm_limits_per_prompt(ctx.model_config) + + image = image_assets[0].pil_image + hf_result = hf_processor.preprocess( + image, + return_tensors="pt", + ) + + vllm_result = mm_registry.map_input( + ctx.model_config, + {"image": image}, + ) + + assert torch.all(hf_result["image_sizes"] == vllm_result["image_sizes"]) + assert torch.all( + hf_result["num_img_tokens"] == vllm_result["num_img_tokens"]) + + # For pixel values, the second axis should be the num_crops + 1 + # for the rescaled original image. The default value in VLLM falls + # back to the HF config, which is why we compare to the processor num_crops + assert torch.all(hf_result["pixel_values"] == vllm_result["pixel_values"]) + assert vllm_result["pixel_values"].shape[1] == hf_processor.num_crops + 1 + + +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize("num_crops,expected_max_tokens", [ + (4, 781), + (16, 2653), +]) +def test_max_tokens_override(get_max_phi3v_image_tokens: Callable, model: str, + num_crops: int, expected_max_tokens: int): + """Ensure get_max_phi3v_image_tokens handles num_crops properly.""" + # NOTE: mm_processor_kwargs on the context in this test is unused, since + # this is testing the mapper directly. In practice, the processor kwargs + # are wrapped in a closure when calling the max tokens func. We explicitly + # do NOT use the mm_processor_kwargs in the model context here to ensure + # that the max image tokens implementation is referencing a mix of the + # kwargs to the function and the original mm_processor_kwargs in case + # values are somehow updated and end up in a bad state. + ctx = build_model_context( + model_name=model, + tokenizer_name=model, + trust_remote_code=True, + mm_processor_kwargs=None, + ) + + actual_max_tokens = get_max_phi3v_image_tokens( + InputContext(ctx.model_config), + num_crops=num_crops, + ) + + assert expected_max_tokens == actual_max_tokens + + +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize("num_crops,toks_per_img,num_imgs", [ + (4, 781, 1), + (4, 781, 2), + (16, 2653, 1), + (16, 2653, 2), +]) +def test_dummy_data_override(dummy_data_for_phi3v: Callable, model: str, + num_crops: int, toks_per_img: int, num_imgs: int): + """Ensure dummy_data_for_phi3v handles num_crops properly.""" + # Same as the previous test - don't initialize mm_processor_kwargs + # in this test and assume that the kwargs will be correctly expanded by + # the partial when calling the dummy data func. + ctx = build_model_context( + model_name=model, + tokenizer_name=model, + trust_remote_code=True, + mm_processor_kwargs=None, + ) + + sequence_data, _, = dummy_data_for_phi3v( + ctx=ctx, + seq_len=8192, # Should be bigger than num_imgs * toks_per_img + mm_counts={"image": num_imgs}, + num_crops=num_crops, + ) + # Ensure we have the right number of placeholders per num_crops size + img_tok_count = sequence_data.get_token_ids().count(_IMAGE_TOKEN_ID) + assert img_tok_count == toks_per_img * num_imgs + + +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize("num_crops,expected_toks_per_img,num_imgs", [ + (4, 757, 1), + (4, 757, 2), + (16, 1921, 1), + (16, 1921, 2), +]) +def test_input_processor_override(input_processor_for_phi3v: Callable, + image_assets: _ImageAssets, model: str, + num_crops: int, expected_toks_per_img: int, + num_imgs: int): + """Ensure input_processor_for_phi3v handles num_crops properly.""" + # Same as the previous test - don't initialize mm_processor_kwargs + # in this test and assume that the kwargs will be correctly expanded by + # the partial when calling the custom input processor. + ctx = build_model_context( + model_name=model, + tokenizer_name=model, + trust_remote_code=True, + ) + tokenizer = AutoTokenizer.from_pretrained(model) + # Build the image str / prompt based on the number of images we pass + img_str = "".join([f"<|image_{idx}|>\n" for idx in range(1, num_imgs + 1)]) + prompt = f"<|user|>\n{img_str}<|end|>\n<|assistant|>\n" + images = [image_assets[0].pil_image] * num_imgs + + llm_inputs = LLMInputs(prompt_token_ids=tokenizer.encode(prompt), + prompt=prompt, + multi_modal_data={"image": images}) + + proc_llm_inputs = input_processor_for_phi3v( + ctx=ctx, + llm_inputs=llm_inputs, + num_crops=num_crops, + ) + + # Ensure we have the right number of placeholders per num_crops size + img_tok_count = proc_llm_inputs["prompt_token_ids"].count(_IMAGE_TOKEN_ID) + assert img_tok_count == expected_toks_per_img * num_imgs diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 6f17f571ccae..245381518a7f 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -307,7 +307,7 @@ def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336): # Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L90 -def _calc_hd_transform_size(*, width: int, height: int, hd_num: int = 16): +def _calc_hd_transform_size(*, width: int, height: int, hd_num: int): transposed = False if width < height: width, height = height, width @@ -337,8 +337,10 @@ def get_phi3v_image_feature_size( *, input_height: int, input_width: int, + num_crops: int, ) -> int: - num_crops = hf_config.get("num_crops", 16) + if num_crops is None: + num_crops = hf_config.get("num_crops", 16) new_width, new_height = _calc_hd_transform_size(width=input_width, height=input_height, hd_num=num_crops) @@ -347,20 +349,26 @@ def get_phi3v_image_feature_size( + (new_height // 336 + 1) * 12 -def get_max_phi3v_image_tokens(ctx: InputContext): +def get_max_phi3v_image_tokens(ctx: InputContext, + *, + num_crops: Optional[int] = None): return get_phi3v_image_feature_size( ctx.get_hf_image_processor_config(), input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT, input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH, + num_crops=num_crops, ) -def dummy_data_for_phi3v(ctx: InputContext, seq_len: int, - mm_counts: Mapping[str, int]): +def dummy_data_for_phi3v(ctx: InputContext, + seq_len: int, + mm_counts: Mapping[str, int], + *, + num_crops: Optional[int] = None): num_images = mm_counts["image"] - image_feature_size = get_max_phi3v_image_tokens(ctx) + image_feature_size = get_max_phi3v_image_tokens(ctx, num_crops=num_crops) seq_data = dummy_seq_data_for_clip( CLIP_VIT_LARGE_PATCH14_336_CONFIG, @@ -398,7 +406,10 @@ def _get_image_placeholder_token_ids(model_config: ModelConfig, return image_placeholder_token_ids -def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs): +def input_processor_for_phi3v(ctx: InputContext, + llm_inputs: LLMInputs, + *, + num_crops: Optional[int] = None): multi_modal_data = llm_inputs.get("multi_modal_data") if multi_modal_data is None or "image" not in multi_modal_data: return llm_inputs @@ -412,7 +423,8 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs): image_feature_size = [ get_phi3v_image_feature_size(hf_config, input_width=w, - input_height=h) + input_height=h, + num_crops=num_crops) ] image_data = [image_data] elif is_list_of(image_data, Image.Image): @@ -422,7 +434,8 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs): image_feature_size.append( get_phi3v_image_feature_size(hf_config, input_width=w, - input_height=h)) + input_height=h, + num_crops=num_crops)) elif isinstance(image_data, torch.Tensor): num_images, image_feature_size, hidden_size = image_data.shape elif is_list_of(image_data, torch.Tensor):