[Model] Refactor and decouple phi3v image embedding (vllm-project#6621)
Isotr0py authored and phil committed Aug 6, 2024
1 parent 802dfb7 · commit 63a7fa1
Showing 1 changed file with 118 additions and 119 deletions.
237 changes: 118 additions & 119 deletions vllm/model_executor/models/phi3v.py
@@ -43,6 +43,7 @@
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
input_processor_for_clip)
from .interfaces import SupportsVision
+from .utils import merge_vision_embeddings

logger = init_logger(__name__)

@@ -71,9 +72,8 @@

class Phi3ImageEmbeddingBase(nn.Module):

-    def __init__(self, wte=None) -> None:
+    def __init__(self) -> None:
        super().__init__()
-        self.wte = wte
self.layer_idx: int
self.type_feature: str
self.img_processor: CLIPVisionModel
@@ -100,10 +100,9 @@ def get_img_features(self,
class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
"""Phi3 Image embedding with HD transform."""

-    def __init__(self, config: PretrainedConfig, wte=None) -> None:
-        super().__init__(wte)
+    def __init__(self, config: PretrainedConfig) -> None:
+        super().__init__()

-        self.image_token_id = _IMAGE_TOKEN_ID
# n_embed or hidden_size
hidden_size = config.n_embd if hasattr(
config, 'n_embd') else config.hidden_size
@@ -149,118 +148,115 @@ def __init__(self, config: PretrainedConfig, wte=None) -> None:
nn.Linear(dim_projection, dim_projection)])
self.img_projection = nn.Sequential(*layers)

-        self.vocab_size = config.vocab_size
        self.type_feature = config.img_processor.get('type_feature', 'patch')

-    def forward(self, input_ids: torch.LongTensor,
-                pixel_values: torch.FloatTensor,
+    def forward(self, pixel_values: torch.FloatTensor,
                image_sizes: torch.Tensor) -> torch.FloatTensor:
-        """process and merge text embeddings with image embeddings."""
-
-        # (batch_size, max_num_crops, 3, height, width)
-        img_embeds = pixel_values
-
-        # (batch_size, 2)
-        img_sizes = image_sizes
-
-        input_shape = input_ids.size()
-        input_ids = input_ids.view(-1, input_shape[-1])
-
-        positions = torch.nonzero(input_ids == self.image_token_id)
-
-        select = False
-
-        target_dtype = self.img_projection[0].bias.dtype
-
-        if len(positions.tolist()) > 0:
-            # if self.use_hd_transform and img_sizes:
-            # img_embeds: (num_images, max_num_crops, 3, H, W)
-            # img_sizes: (num_images, 2).view(1, -1)
-
-            bs = img_embeds.shape[0]
-            # Nx(HW)xC
-            img_features = self.get_img_features(img_embeds.flatten(0, 1))
-            base_feat_height = base_feat_width = int(
-                img_features.shape[1]**0.5)
-
-            # bs x max_num_crops x (24x24) x C
-            img_features = img_features.view(
-                bs, -1, base_feat_height * base_feat_width, self.image_dim_out)
-            C = self.image_dim_out
-            H = base_feat_height
-
-            output_imgs = []
-            output_len = []
-
-            for _bs in range(bs):
-                h, w = img_sizes[_bs]
-                h = h // 336
-                w = w // 336
-                B_ = h * w
-
-                # 1 x (24x24) x 1024
-                global_img_feature = img_features[_bs, :1]
-
-                # 1 x 12 x 12 x 4096
-                glb_img = global_img_feature \
-                    .reshape(1, H // 2, 2, H // 2, 2,C) \
-                    .permute(0, 1, 3, 2, 4, 5) \
-                    .reshape(1, H // 2, H // 2, 4 * C)
-                temp_glb_GN = self.sub_GN.repeat(1, H // 2, 1, 1)
-
-                # 1 x 156 x 4096
-                glb_img = torch.cat([glb_img, temp_glb_GN],
-                                    dim=2).reshape(1, -1, 4 * C)
-
-                # (max_num_crops-1) x (12x12) x C
-                sub_img = img_features[_bs, 1:]
-                # 16x574x1024
-                # get rid of padding sub_img
-                sub_img = sub_img[:B_]
-
-                sub_img = sub_img.reshape(B_, H // 2, 2, H // 2, 2, C) \
-                    .permute(0, 1, 3, 2, 4, 5).reshape(B_, -1, 4 * C)
-                sub_img = sub_img.reshape(1, h, w, 12, 12, -1) \
-                    .permute(0, 1, 3, 2, 4, 5) \
-                    .reshape(1, h * 12, w * 12, 4 * C)
-                temp_sub_GN = self.sub_GN.repeat(1, h * 12, 1, 1)
-                sub_img = torch.cat([sub_img, temp_sub_GN],
-                                    dim=2).reshape(1, -1, 4 * C)
-                # (1, num_img_tokens, 1024*4)
-
-                # glb + sub
-                if self.hd_transform_order == 'glb_sub':
-                    output_imgs.append(
-                        torch.cat([glb_img, self.glb_GN, sub_img], dim=1))
-                elif self.hd_transform_order == 'sub_glb':
-                    output_imgs.append(
-                        torch.cat([sub_img, self.glb_GN, glb_img], dim=1))
-
-                temp_len = int((h * w + 1) * 144 + 1 + (h + 1) * 12)
-                output_len.append(temp_len)
-
-            num_img_tokens = output_len
-            img_set_tensor = []
-            for _output_img in output_imgs:
-                img_feature_proj = self.img_projection(
-                    _output_img.to(target_dtype))
-                img_set_tensor.append(img_feature_proj)
-            select = True
-
-        input_ids.clamp_min_(0).clamp_max_(self.vocab_size)
-
-        hidden_states = self.wte(input_ids)
-
-        if select:
-            idx = 0
-            for i, cnt in enumerate(num_img_tokens):
-                hidden_states[positions[idx, 0],
-                              positions[idx, 1]:positions[idx, 1] +
-                              cnt] = (img_set_tensor[i].to(
-                                  hidden_states.dtype))
-                idx += cnt
-
-        return hidden_states.squeeze(0)
"""
process image and return vision embeddings.
pixel_values: (num_images, num_crops, c, h, w)
output: (num_images, num_img_tokens, hidden_size)
"""
num_images, num_crops, c, h, w = pixel_values.shape
pixel_values = pixel_values.flatten(0, 1)
img_features = self.get_img_features(pixel_values)
img_features = img_features.reshape(num_images, num_crops, -1,
self.image_dim_out)
image_features_proj = self.hd_feature_transform(
img_features, image_sizes)
return image_features_proj

def hd_feature_transform(self, image_features, image_sizes):
"""
image_features: (num_images, num_crops+1, 24*24, 1024)
"""
assert (
self.hd_transform_order == 'sub_glb'
), f'hd_transform_order `{self.hd_transform_order}` not implemented'
if isinstance(self.img_projection, nn.Sequential):
target_device = self.img_projection[0].bias.device
target_dtype = self.img_projection[0].bias.dtype
else: # It's a single nn.Linear layer
target_device = self.img_projection.bias.device
target_dtype = self.img_projection.bias.dtype

global_image_features = image_features[:,
0] # (num_images, 24*24, 1024)
# global feature can be viewed as a special HD case with num_crops 1x1
global_image_features_hd = self.reshape_hd_patches_2x2merge(
global_image_features, 1, 1)
global_image_features_hd_newline = self.add_image_newline(
global_image_features_hd)

all_image_embeddings = []
# need a for loop to process each image because of different image sizes
# (patch arrangement is different for each image)
for i, img_size in enumerate(image_sizes):
h, w = img_size
h_crop = h // 336
w_crop = w // 336
num_crops = h_crop * w_crop

# NOTE: real num_crops is padded
# (num_crops, 24*24, 1024)
sub_image_features = image_features[i, 1:1 + num_crops]
sub_image_features_hd = self.reshape_hd_patches_2x2merge(
sub_image_features, h_crop, w_crop)
sub_image_features_hd_newline = self.add_image_newline(
sub_image_features_hd)

# [sub features, separator, global features]
all_image_embeddings.append(
torch.cat([
sub_image_features_hd_newline.squeeze(
0), # (h_crop*12*(w_crop*12+1), 4096)
self.glb_GN.squeeze(0),
global_image_features_hd_newline[i],
]))

image_features_proj = self.img_projection(
torch.stack(all_image_embeddings).to(target_device, target_dtype)
) # (num_images, (h_crop*12*(w_crop*12+1)+1), hidden_size)

return image_features_proj

def reshape_hd_patches_2x2merge(self, image_features, h_crop, w_crop):
"""
image_features: (num_images*num_crops, 24*24, 1024)
output: (num_images, h_crop*12, w_crop*12, 4096)
where h_crop*w_crop == num_crops
"""
N, L, C = image_features.shape
assert L == 576 and C == 1024 and N % (h_crop * w_crop) == 0
num_images = N // (h_crop * w_crop)
H = int(L**0.5)
image_features_hd = (
image_features.reshape(N, H, H, C) # N, 24, 24, 1024
.reshape(N, H // 2, 2, H // 2, 2, C) # N, 12, 2, 12, 2, 1024
.permute(0, 1, 3, 2, 4, 5) # N, 12, 12, 2, 2, 1024
.reshape(N, -1, 4 * C) # N, 144, 4096
.reshape(num_images, h_crop, w_crop, H // 2, H // 2,
-1) # n_img, h_crop, w_crop, 12, 12, 4096
.permute(0, 1, 3, 2, 4, 5) # n_img, h_crop, 12, w_crop, 12, 4096
.reshape(num_images, h_crop * H // 2, w_crop * H // 2,
4 * C) # n_img, h_crop*12, w_crop*12, 4096
)
return image_features_hd

def add_image_newline(self, image_features_hd):
"""
image_features_hd: (num_images, h_crop*12, w_crop*12, 4096)
output: (num_images, (h_crop*12) * (w_crop*12+1), 4096)
"""
num_images, h, w, hid_dim = image_features_hd.shape
# add the newline token to the HD image feature patches
newline_embeddings = self.sub_GN.expand(num_images, h, -1,
-1) # (n_img, h, 1, hid_dim)
image_features_hd_newline = torch.cat(
[image_features_hd, newline_embeddings],
dim=2).reshape(num_images, -1, hid_dim)
return image_features_hd_newline


class Phi3VImagePixelInputs(TypedDict):
@@ -458,12 +454,12 @@ def __init__(self,

self.config = config
self.multimodal_config = multimodal_config
+        self.image_token_id = _IMAGE_TOKEN_ID

self.model = LlamaModel(config, cache_config, quant_config)

# TODO: Optionally initializes this for supporting embeddings.
-        self.vision_embed_tokens = Phi3HDImageEmbedding(
-            config, self.model.embed_tokens)
+        self.vision_embed_tokens = Phi3HDImageEmbedding(config)
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
@@ -530,9 +526,12 @@ def forward(self,
image_input = self._parse_and_validate_image_input(**kwargs)

if image_input is not None:
-            inputs_embeds = self.vision_embed_tokens(
-                input_ids, image_input["data"], image_input["image_sizes"])
-
+            vision_embeddings = self.vision_embed_tokens(
+                image_input["data"], image_input["image_sizes"])
+            inputs_embeds = self.model.get_input_embeddings(input_ids)
+            inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds,
+                                                    vision_embeddings,
+                                                    self.image_token_id)
input_ids = None
else:
inputs_embeds = None
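
For context: after this refactor, Phi3HDImageEmbedding.forward only returns vision embeddings, and Phi3VForCausalLM.forward splices them into the text embeddings via the newly imported merge_vision_embeddings helper from .utils. A minimal sketch of that splicing step follows; the _sketch name, shapes, and assert are illustrative assumptions, not the actual vLLM implementation.

import torch

def merge_vision_embeddings_sketch(input_ids: torch.Tensor,
                                   inputs_embeds: torch.Tensor,
                                   vision_embeddings: torch.Tensor,
                                   image_token_id: int) -> torch.Tensor:
    """Overwrite image-placeholder positions with vision embeddings.

    input_ids:         (num_tokens,) flattened token ids
    inputs_embeds:     (num_tokens, hidden_size) text embeddings
    vision_embeddings: (num_images, num_img_tokens, hidden_size)
    """
    mask = input_ids == image_token_id
    flattened = vision_embeddings.reshape(-1, vision_embeddings.shape[-1])
    # Every placeholder token must be matched by exactly one vision token.
    assert int(mask.sum()) == flattened.shape[0]
    inputs_embeds[mask] = flattened.to(dtype=inputs_embeds.dtype)
    return inputs_embeds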

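As a sanity check on the new HD transform layout: reshape_hd_patches_2x2merge produces an (h_crop*12) x (w_crop*12) grid of 4096-dim features, add_image_newline appends one sub_GN token per row, a single glb_GN separator follows, and the global crop contributes a 12x13 block of its own. The worked example below (assuming a hypothetical 672x1008 input) shows this matches the temp_len formula from the removed forward().

# Worked example: per-image token count for the Phi-3-vision HD transform.
# Assumes an input of 672x1008 pixels, i.e. h_crop = 672 // 336 = 2 and
# w_crop = 1008 // 336 = 3.
h_crop, w_crop = 2, 3

# New layout: sub-image rows, each followed by a newline token, one glb_GN
# separator, then the 12x12 global crop with its own newline tokens.
num_tokens_new = (h_crop * 12) * (w_crop * 12 + 1) + 1 + 12 * (12 + 1)

# Old formula from the removed forward():
num_tokens_old = (h_crop * w_crop + 1) * 144 + 1 + (h_crop + 1) * 12

assert num_tokens_new == num_tokens_old == 1045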