From 63a7fa13734afc1aa3cc7e815a81e53fa5a14ef7 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Mon, 22 Jul 2024 07:07:58 +0800 Subject: [PATCH] [Model] Refactor and decouple phi3v image embedding (#6621) --- vllm/model_executor/models/phi3v.py | 237 ++++++++++++++-------------- 1 file changed, 118 insertions(+), 119 deletions(-) diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 8b2c425289f0a..75e2f5fc95cb7 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -43,6 +43,7 @@ from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip, input_processor_for_clip) from .interfaces import SupportsVision +from .utils import merge_vision_embeddings logger = init_logger(__name__) @@ -71,9 +72,8 @@ class Phi3ImageEmbeddingBase(nn.Module): - def __init__(self, wte=None) -> None: + def __init__(self) -> None: super().__init__() - self.wte = wte self.layer_idx: int self.type_feature: str self.img_processor: CLIPVisionModel @@ -100,10 +100,9 @@ def get_img_features(self, class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase): """Phi3 Image embedding with HD transform.""" - def __init__(self, config: PretrainedConfig, wte=None) -> None: - super().__init__(wte) + def __init__(self, config: PretrainedConfig) -> None: + super().__init__() - self.image_token_id = _IMAGE_TOKEN_ID # n_embed or hidden_size hidden_size = config.n_embd if hasattr( config, 'n_embd') else config.hidden_size @@ -149,118 +148,115 @@ def __init__(self, config: PretrainedConfig, wte=None) -> None: nn.Linear(dim_projection, dim_projection)]) self.img_projection = nn.Sequential(*layers) - self.vocab_size = config.vocab_size self.type_feature = config.img_processor.get('type_feature', 'patch') - def forward(self, input_ids: torch.LongTensor, - pixel_values: torch.FloatTensor, + def forward(self, pixel_values: torch.FloatTensor, image_sizes: torch.Tensor) -> torch.FloatTensor: - """process and merge text embeddings with image embeddings.""" - - # (batch_size, max_num_crops, 3, height, width) - img_embeds = pixel_values - - # (batch_size, 2) - img_sizes = image_sizes - - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - - positions = torch.nonzero(input_ids == self.image_token_id) - - select = False - - target_dtype = self.img_projection[0].bias.dtype - - if len(positions.tolist()) > 0: - # if self.use_hd_transform and img_sizes: - # img_embeds: (num_images, max_num_crops, 3, H, W) - # img_sizes: (num_images, 2).view(1, -1) - - bs = img_embeds.shape[0] - # Nx(HW)xC - img_features = self.get_img_features(img_embeds.flatten(0, 1)) - base_feat_height = base_feat_width = int( - img_features.shape[1]**0.5) - - # bs x max_num_crops x (24x24) x C - img_features = img_features.view( - bs, -1, base_feat_height * base_feat_width, self.image_dim_out) - C = self.image_dim_out - H = base_feat_height - - output_imgs = [] - output_len = [] - - for _bs in range(bs): - h, w = img_sizes[_bs] - h = h // 336 - w = w // 336 - B_ = h * w - - # 1 x (24x24) x 1024 - global_img_feature = img_features[_bs, :1] - - # 1 x 12 x 12 x 4096 - glb_img = global_img_feature \ - .reshape(1, H // 2, 2, H // 2, 2,C) \ - .permute(0, 1, 3, 2, 4, 5) \ - .reshape(1, H // 2, H // 2, 4 * C) - temp_glb_GN = self.sub_GN.repeat(1, H // 2, 1, 1) - - # 1 x 156 x 4096 - glb_img = torch.cat([glb_img, temp_glb_GN], - dim=2).reshape(1, -1, 4 * C) - - # (max_num_crops-1) x (12x12) x C - sub_img = img_features[_bs, 1:] - # 16x574x1024 - # get rid of padding sub_img - sub_img = sub_img[:B_] - - sub_img = sub_img.reshape(B_, H // 2, 2, H // 2, 2, C) \ - .permute(0, 1, 3, 2, 4, 5).reshape(B_, -1, 4 * C) - sub_img = sub_img.reshape(1, h, w, 12, 12, -1) \ - .permute(0, 1, 3, 2, 4, 5) \ - .reshape(1, h * 12, w * 12, 4 * C) - temp_sub_GN = self.sub_GN.repeat(1, h * 12, 1, 1) - sub_img = torch.cat([sub_img, temp_sub_GN], - dim=2).reshape(1, -1, 4 * C) - # (1, num_img_tokens, 1024*4) - - # glb + sub - if self.hd_transform_order == 'glb_sub': - output_imgs.append( - torch.cat([glb_img, self.glb_GN, sub_img], dim=1)) - elif self.hd_transform_order == 'sub_glb': - output_imgs.append( - torch.cat([sub_img, self.glb_GN, glb_img], dim=1)) - - temp_len = int((h * w + 1) * 144 + 1 + (h + 1) * 12) - output_len.append(temp_len) - - num_img_tokens = output_len - img_set_tensor = [] - for _output_img in output_imgs: - img_feature_proj = self.img_projection( - _output_img.to(target_dtype)) - img_set_tensor.append(img_feature_proj) - select = True - - input_ids.clamp_min_(0).clamp_max_(self.vocab_size) - - hidden_states = self.wte(input_ids) - - if select: - idx = 0 - for i, cnt in enumerate(num_img_tokens): - hidden_states[positions[idx, 0], - positions[idx, 1]:positions[idx, 1] + - cnt] = (img_set_tensor[i].to( - hidden_states.dtype)) - idx += cnt - - return hidden_states.squeeze(0) + """ + process image and return vision embeddings. + + pixel_values: (num_images, num_crops, c, h, w) + output: (num_images, num_img_tokens, hidden_size) + """ + num_images, num_crops, c, h, w = pixel_values.shape + pixel_values = pixel_values.flatten(0, 1) + img_features = self.get_img_features(pixel_values) + img_features = img_features.reshape(num_images, num_crops, -1, + self.image_dim_out) + image_features_proj = self.hd_feature_transform( + img_features, image_sizes) + return image_features_proj + + def hd_feature_transform(self, image_features, image_sizes): + """ + image_features: (num_images, num_crops+1, 24*24, 1024) + """ + assert ( + self.hd_transform_order == 'sub_glb' + ), f'hd_transform_order `{self.hd_transform_order}` not implemented' + if isinstance(self.img_projection, nn.Sequential): + target_device = self.img_projection[0].bias.device + target_dtype = self.img_projection[0].bias.dtype + else: # It's a single nn.Linear layer + target_device = self.img_projection.bias.device + target_dtype = self.img_projection.bias.dtype + + global_image_features = image_features[:, + 0] # (num_images, 24*24, 1024) + # global feature can be viewed as a special HD case with num_crops 1x1 + global_image_features_hd = self.reshape_hd_patches_2x2merge( + global_image_features, 1, 1) + global_image_features_hd_newline = self.add_image_newline( + global_image_features_hd) + + all_image_embeddings = [] + # need a for loop to process each image because of different image sizes + # (patch arrangement is different for each image) + for i, img_size in enumerate(image_sizes): + h, w = img_size + h_crop = h // 336 + w_crop = w // 336 + num_crops = h_crop * w_crop + + # NOTE: real num_crops is padded + # (num_crops, 24*24, 1024) + sub_image_features = image_features[i, 1:1 + num_crops] + sub_image_features_hd = self.reshape_hd_patches_2x2merge( + sub_image_features, h_crop, w_crop) + sub_image_features_hd_newline = self.add_image_newline( + sub_image_features_hd) + + # [sub features, separator, global features] + all_image_embeddings.append( + torch.cat([ + sub_image_features_hd_newline.squeeze( + 0), # (h_crop*12*(w_crop*12+1), 4096) + self.glb_GN.squeeze(0), + global_image_features_hd_newline[i], + ])) + + image_features_proj = self.img_projection( + torch.stack(all_image_embeddings).to(target_device, target_dtype) + ) # (num_images, (h_crop*12*(w_crop*12+1)+1), hidden_size) + + return image_features_proj + + def reshape_hd_patches_2x2merge(self, image_features, h_crop, w_crop): + """ + image_features: (num_images*num_crops, 24*24, 1024) + output: (num_images, h_crop*12, w_crop*12, 4096) + where h_crop*w_crop == num_crops + """ + N, L, C = image_features.shape + assert L == 576 and C == 1024 and N % (h_crop * w_crop) == 0 + num_images = N // (h_crop * w_crop) + H = int(L**0.5) + image_features_hd = ( + image_features.reshape(N, H, H, C) # N, 24, 24, 1024 + .reshape(N, H // 2, 2, H // 2, 2, C) # N, 12, 2, 12, 2, 1024 + .permute(0, 1, 3, 2, 4, 5) # N, 12, 12, 2, 2, 1024 + .reshape(N, -1, 4 * C) # N, 144, 4096 + .reshape(num_images, h_crop, w_crop, H // 2, H // 2, + -1) # n_img, h_crop, w_crop, 12, 12, 4096 + .permute(0, 1, 3, 2, 4, 5) # n_img, h_crop, 12, w_crop, 12, 4096 + .reshape(num_images, h_crop * H // 2, w_crop * H // 2, + 4 * C) # n_img, h_crop*12, w_crop*12, 4096 + ) + return image_features_hd + + def add_image_newline(self, image_features_hd): + """ + image_features_hd: (num_images, h_crop*12, w_crop*12, 4096) + output: (num_images, (h_crop*12) * (w_crop*12+1), 4096) + """ + num_images, h, w, hid_dim = image_features_hd.shape + # add the newline token to the HD image feature patches + newline_embeddings = self.sub_GN.expand(num_images, h, -1, + -1) # (n_img, h, 1, hid_dim) + image_features_hd_newline = torch.cat( + [image_features_hd, newline_embeddings], + dim=2).reshape(num_images, -1, hid_dim) + return image_features_hd_newline class Phi3VImagePixelInputs(TypedDict): @@ -458,12 +454,12 @@ def __init__(self, self.config = config self.multimodal_config = multimodal_config + self.image_token_id = _IMAGE_TOKEN_ID self.model = LlamaModel(config, cache_config, quant_config) # TODO: Optionally initializes this for supporting embeddings. - self.vision_embed_tokens = Phi3HDImageEmbedding( - config, self.model.embed_tokens) + self.vision_embed_tokens = Phi3HDImageEmbedding(config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, quant_config=quant_config) @@ -530,9 +526,12 @@ def forward(self, image_input = self._parse_and_validate_image_input(**kwargs) if image_input is not None: - inputs_embeds = self.vision_embed_tokens( - input_ids, image_input["data"], image_input["image_sizes"]) - + vision_embeddings = self.vision_embed_tokens( + image_input["data"], image_input["image_sizes"]) + inputs_embeds = self.model.get_input_embeddings(input_ids) + inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds, + vision_embeddings, + self.image_token_id) input_ids = None else: inputs_embeds = None