Expose phi3v num crops processor override
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>

Add tests for processing_kwarg overrides in phi3v

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>

Add processor_kwargs override to phi3v offline inference

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>

Rename processor kwargs to mm processor kwargs

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
alex-jw-brooks committed Sep 22, 2024
1 parent f95c86f commit 632dac1
Showing 3 changed files with 204 additions and 14 deletions.
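In practical terms, this commit lets the number of HD-transform crops used by Phi-3-vision's image processor be overridden when the engine is constructed, which in turn changes how many image placeholder tokens each image consumes. A minimal usage sketch, assuming a multimodal-capable vLLM build; the keyword mirrors the example diff below (per the last commit message it is being renamed to mm_processor_kwargs, so the exact name may differ by version):

from vllm import LLM

# Minimal sketch: override the number of crops forwarded to the HF image
# processor for Phi-3-vision. The keyword mirrors the example diff below.
llm = LLM(
    model="microsoft/Phi-3-vision-128k-instruct",
    trust_remote_code=True,
    max_num_seqs=5,
    processor_kwargs={"num_crops": 16},
)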
1 change: 1 addition & 0 deletions examples/offline_inference_vision_language.py
@@ -64,6 +64,7 @@ def run_phi3v(question):
model="microsoft/Phi-3-vision-128k-instruct",
trust_remote_code=True,
max_num_seqs=5,
processor_kwargs={"num_crops": 16},
)
stop_token_ids = None
return llm, prompt, stop_token_ids
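For reference, the objects returned by run_phi3v are consumed later in the example script; a rough sketch of that call follows (the variable names image and sampling_params are assumptions, not lines from this diff):

# Hypothetical follow-up usage of the objects returned above.
outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {"image": image},
    },
    sampling_params=sampling_params,
)
for o in outputs:
    print(o.outputs[0].text)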
186 changes: 181 additions & 5 deletions tests/models/decoder_only/vision_language/test_phi3v.py
@@ -1,16 +1,21 @@
import os
import re
from typing import List, Optional, Tuple, Type
from typing import Callable, List, Optional, Tuple, Type

import pytest
from transformers import AutoTokenizer
import torch
from transformers import AutoImageProcessor, AutoTokenizer

from vllm.inputs import InputContext, LLMInputs
from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID
from vllm.multimodal import MultiModalRegistry
from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs
from vllm.utils import is_cpu, is_hip

from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner
from ...utils import check_logprobs_close
from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
                          _ImageAssets)
from ...utils import build_model_context, check_logprobs_close

HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
"stop_sign":
@@ -71,7 +76,7 @@ def run_test(
    All the image fixtures for the test are from IMAGE_ASSETS.
    For huggingface runner, we provide the PIL images as input.
    For vllm runner, we provide MultiModalDataDict objects
    and corresponding MultiModalConfig as input.
    Note, the text input is also adjusted to abide by vllm contract.
    The text output is sanitized to be able to compare with hf.
@@ -230,3 +235,174 @@ def test_multi_images_models(hf_runner, vllm_runner, image_assets, model,
        mm_limit=2,
        tensor_parallel_size=1,
    )


### Fast tests for correctness in processor_kwarg override handling


# Wrap lazy imports to avoid initializing CUDA during test collection
@pytest.fixture()
def input_processor_for_phi3v():
    from vllm.model_executor.models.phi3v import input_processor_for_phi3v
    return input_processor_for_phi3v


@pytest.fixture()
def dummy_data_for_phi3v():
    from vllm.model_executor.models.phi3v import dummy_data_for_phi3v
    return dummy_data_for_phi3v


@pytest.fixture()
def get_max_phi3v_image_tokens():
    from vllm.model_executor.models.phi3v import get_max_phi3v_image_tokens
    return get_max_phi3v_image_tokens


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("num_crops", [4, 16, None])
def test_input_mapper_override(model: str, image_assets: _ImageAssets,
                               num_crops: Optional[int]):
    """Ensure that the [default] input mapper handles num_crops properly."""
    # We pass the processor kwargs here because this model uses the default
    # input mapper, which falls back to the HF image processor and forwards
    # mm_processor_kwargs to it.
    mm_processor_kwargs = {
        "num_crops": num_crops
    } if num_crops is not None else {}
    ctx = build_model_context(
        model_name=model,
        tokenizer_name=model,
        trust_remote_code=True,
        mm_processor_kwargs=mm_processor_kwargs,
    )

    hf_processor = AutoImageProcessor.from_pretrained(model,
                                                      trust_remote_code=True,
                                                      **mm_processor_kwargs)

    mm_registry = MultiModalRegistry()
    mm_registry.init_mm_limits_per_prompt(ctx.model_config)

    image = image_assets[0].pil_image
    hf_result = hf_processor.preprocess(
        image,
        return_tensors="pt",
    )

    vllm_result = mm_registry.map_input(
        ctx.model_config,
        {"image": image},
    )

    assert torch.all(hf_result["image_sizes"] == vllm_result["image_sizes"])
    assert torch.all(
        hf_result["num_img_tokens"] == vllm_result["num_img_tokens"])

    # For pixel values, the second axis should be num_crops + 1, where the
    # extra entry is the rescaled original image. When num_crops is not
    # overridden, vLLM falls back to the HF config value, which is why we
    # compare against the processor's num_crops.
    assert torch.all(hf_result["pixel_values"] == vllm_result["pixel_values"])
    assert vllm_result["pixel_values"].shape[1] == hf_processor.num_crops + 1


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("num_crops,expected_max_tokens", [
    (4, 781),
    (16, 2653),
])
def test_max_tokens_override(get_max_phi3v_image_tokens: Callable, model: str,
                             num_crops: int, expected_max_tokens: int):
    """Ensure get_max_phi3v_image_tokens handles num_crops properly."""
    # NOTE: mm_processor_kwargs on the context in this test is unused, since
    # we call the max-token function directly. In practice, the processor
    # kwargs are wrapped in a closure when calling the max tokens func. We
    # explicitly do NOT set mm_processor_kwargs on the model context here, to
    # ensure that the max image tokens implementation only uses the kwargs
    # passed to the function and cannot end up in a bad state by mixing them
    # with values from the context.
    ctx = build_model_context(
        model_name=model,
        tokenizer_name=model,
        trust_remote_code=True,
        mm_processor_kwargs=None,
    )

    actual_max_tokens = get_max_phi3v_image_tokens(
        InputContext(ctx.model_config),
        num_crops=num_crops,
    )

    assert expected_max_tokens == actual_max_tokens


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("num_crops,toks_per_img,num_imgs", [
    (4, 781, 1),
    (4, 781, 2),
    (16, 2653, 1),
    (16, 2653, 2),
])
def test_dummy_data_override(dummy_data_for_phi3v: Callable, model: str,
                             num_crops: int, toks_per_img: int, num_imgs: int):
    """Ensure dummy_data_for_phi3v handles num_crops properly."""
    # Same as the previous test - don't initialize mm_processor_kwargs
    # in this test and assume that the kwargs will be correctly expanded by
    # the partial when calling the dummy data func.
    ctx = build_model_context(
        model_name=model,
        tokenizer_name=model,
        trust_remote_code=True,
        mm_processor_kwargs=None,
    )

    sequence_data, _ = dummy_data_for_phi3v(
        ctx=ctx,
        seq_len=8192,  # Should be bigger than num_imgs * toks_per_img
        mm_counts={"image": num_imgs},
        num_crops=num_crops,
    )
    # Ensure we have the right number of placeholders per num_crops size
    img_tok_count = sequence_data.get_token_ids().count(_IMAGE_TOKEN_ID)
    assert img_tok_count == toks_per_img * num_imgs


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("num_crops,expected_toks_per_img,num_imgs", [
    (4, 757, 1),
    (4, 757, 2),
    (16, 1921, 1),
    (16, 1921, 2),
])
def test_input_processor_override(input_processor_for_phi3v: Callable,
                                  image_assets: _ImageAssets, model: str,
                                  num_crops: int, expected_toks_per_img: int,
                                  num_imgs: int):
    """Ensure input_processor_for_phi3v handles num_crops properly."""
    # Same as the previous test - don't initialize mm_processor_kwargs
    # in this test and assume that the kwargs will be correctly expanded by
    # the partial when calling the custom input processor.
    ctx = build_model_context(
        model_name=model,
        tokenizer_name=model,
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(model)
    # Build the image str / prompt based on the number of images we pass
    img_str = "".join([f"<|image_{idx}|>\n" for idx in range(1, num_imgs + 1)])
    prompt = f"<|user|>\n{img_str}<|end|>\n<|assistant|>\n"
    images = [image_assets[0].pil_image] * num_imgs

    llm_inputs = LLMInputs(prompt_token_ids=tokenizer.encode(prompt),
                           prompt=prompt,
                           multi_modal_data={"image": images})

    proc_llm_inputs = input_processor_for_phi3v(
        ctx=ctx,
        llm_inputs=llm_inputs,
        num_crops=num_crops,
    )

    # Ensure we have the right number of placeholders per num_crops size
    img_tok_count = proc_llm_inputs["prompt_token_ids"].count(_IMAGE_TOKEN_ID)
    assert img_tok_count == expected_toks_per_img * num_imgs
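Taken together, the parametrizations above give a compact picture of how num_crops drives Phi-3-vision token usage. The values below are copied from the new tests; the max counts assume the worst-case image size, while the per-prompt counts apply to the specific test image:

# Reference values from the tests in this commit (num_crops -> token count).
EXPECTED_MAX_IMAGE_TOKENS = {4: 781, 16: 2653}  # get_max_phi3v_image_tokens
EXPECTED_PROMPT_TOKENS_PER_IMAGE = {4: 757, 16: 1921}  # input_processor_for_phi3v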
31 changes: 22 additions & 9 deletions vllm/model_executor/models/phi3v.py
@@ -307,7 +307,7 @@ def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336):


# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L90
def _calc_hd_transform_size(*, width: int, height: int, hd_num: int = 16):
def _calc_hd_transform_size(*, width: int, height: int, hd_num: int):
    transposed = False
    if width < height:
        width, height = height, width
@@ -337,8 +337,10 @@ def get_phi3v_image_feature_size(
    *,
    input_height: int,
    input_width: int,
    num_crops: int,
) -> int:
    num_crops = hf_config.get("num_crops", 16)
    if num_crops is None:
        num_crops = hf_config.get("num_crops", 16)
    new_width, new_height = _calc_hd_transform_size(width=input_width,
                                                    height=input_height,
                                                    hd_num=num_crops)
@@ -347,20 +349,26 @@
        + (new_height // 336 + 1) * 12


def get_max_phi3v_image_tokens(ctx: InputContext):
def get_max_phi3v_image_tokens(ctx: InputContext,
                               *,
                               num_crops: Optional[int] = None):

    return get_phi3v_image_feature_size(
        ctx.get_hf_image_processor_config(),
        input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
        input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
        num_crops=num_crops,
    )


def dummy_data_for_phi3v(ctx: InputContext, seq_len: int,
                         mm_counts: Mapping[str, int]):
def dummy_data_for_phi3v(ctx: InputContext,
                         seq_len: int,
                         mm_counts: Mapping[str, int],
                         *,
                         num_crops: Optional[int] = None):
    num_images = mm_counts["image"]

    image_feature_size = get_max_phi3v_image_tokens(ctx)
    image_feature_size = get_max_phi3v_image_tokens(ctx, num_crops=num_crops)

    seq_data = dummy_seq_data_for_clip(
        CLIP_VIT_LARGE_PATCH14_336_CONFIG,
@@ -398,7 +406,10 @@ def _get_image_placeholder_token_ids(model_config: ModelConfig,
    return image_placeholder_token_ids


def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
def input_processor_for_phi3v(ctx: InputContext,
                              llm_inputs: LLMInputs,
                              *,
                              num_crops: Optional[int] = None):
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs
@@ -412,7 +423,8 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
        image_feature_size = [
            get_phi3v_image_feature_size(hf_config,
                                         input_width=w,
                                         input_height=h)
                                         input_height=h,
                                         num_crops=num_crops)
        ]
        image_data = [image_data]
    elif is_list_of(image_data, Image.Image):
@@ -422,7 +434,8 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
            image_feature_size.append(
                get_phi3v_image_feature_size(hf_config,
                                             input_width=w,
                                             input_height=h))
                                             input_height=h,
                                             num_crops=num_crops))
    elif isinstance(image_data, torch.Tensor):
        num_images, image_feature_size, hidden_size = image_data.shape
    elif is_list_of(image_data, torch.Tensor):
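The new keyword-only num_crops parameters above are meant to be filled in by the multimodal registry rather than passed at every call site; the test comments describe the kwargs as being wrapped in a closure or expanded by a partial. A rough sketch of that wiring, assuming functools.partial binding of user-supplied mm_processor_kwargs (the binding code itself is not part of this diff):

from functools import partial

from vllm.model_executor.models.phi3v import (dummy_data_for_phi3v,
                                              get_max_phi3v_image_tokens,
                                              input_processor_for_phi3v)

# Hypothetical wiring: bind user-supplied overrides onto the per-model hooks
# so later calls only need the usual positional arguments.
mm_processor_kwargs = {"num_crops": 16}

max_image_tokens_fn = partial(get_max_phi3v_image_tokens, **mm_processor_kwargs)
dummy_data_fn = partial(dummy_data_for_phi3v, **mm_processor_kwargs)
input_processor_fn = partial(input_processor_for_phi3v, **mm_processor_kwargs)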