From 886ada87ebd7e6f614bfdc70fa5f89521701b776 Mon Sep 17 00:00:00 2001 From: "yuze.zyz" Date: Mon, 22 Jul 2024 21:55:44 +0800 Subject: [PATCH 01/19] wip --- swift/llm/utils/dataset.py | 14 ++++++-- swift/llm/utils/template.py | 64 ++++++++++++++++++++++++++++++++++--- 2 files changed, 72 insertions(+), 6 deletions(-) diff --git a/swift/llm/utils/dataset.py b/swift/llm/utils/dataset.py index 5d731c9e2..47bf3f973 100644 --- a/swift/llm/utils/dataset.py +++ b/swift/llm/utils/dataset.py @@ -1203,7 +1203,12 @@ def preprocess(row): bbox[i] = round(float(bbox[i])) res = {} - objects = [[caption, bbox]] + objects = [{ + 'caption': caption, + 'bbox': bbox, + 'bbox_type': 'real', + 'image': 0, + }] media_tag(res, [image_path]) res['images'] = [image_path] res['objects'] = json.dumps(objects, ensure_ascii=False) @@ -1248,7 +1253,12 @@ def preprocess(row): bbox[i] = round(float(bbox[i])) res = {} - objects = [[caption, bbox]] + objects = [{ + 'caption': caption, + 'bbox': bbox, + 'bbox_type': 'real', + 'image': 0, + }] media_tag(res, [image_path]) res['images'] = [image_path] res['objects'] = json.dumps(objects, ensure_ascii=False) diff --git a/swift/llm/utils/template.py b/swift/llm/utils/template.py index 28200d0c4..76a44acab 100644 --- a/swift/llm/utils/template.py +++ b/swift/llm/utils/template.py @@ -487,6 +487,47 @@ def replace_box(self, index: int, example: Dict[str, Any]) -> List[Context]: else: return [''] + @classmethod + def normalize_bbox(cls, objects, images, to_type: Literal['real', 'norm_1000', 'norm_1']): + if not objects or not images: + return + + for object in objects: + bbox = object['bbox'] + bbox_type = object['bbox_type'] + idx = object['image'] + image = images[idx] + if bbox_type == 'real': + if to_type == 'real': + continue + width, height = image.width, image.height + object['bbox'] = [ + int(coord / dim * 999) if to_type == 'norm_1000' else coord / dim for coord, dim in + zip(bbox, [width, height, width, height]) + ] + elif bbox_type == 'norm_1000': + if to_type == 'norm_1000': + continue + if to_type == 'norm_1': + object['bbox'] = [coord/999. for coord in bbox] + elif to_type == 'real': + width, height = image.width, image.height + object['bbox'] = [ + int(coord / 999. 
* dim) for coord, dim in
+                    zip(bbox, [width, height, width, height])
+                ]
+        elif bbox_type == 'norm_1':
+            if to_type == 'norm_1':
+                continue
+            if to_type == 'norm_1000':
+                object['bbox'] = [int(coord * 999) for coord in bbox]
+            elif to_type == 'real':
+                width, height = image.width, image.height
+                object['bbox'] = [
+                    int(coord * dim) for coord, dim in
+                    zip(bbox, [width, height, width, height])
+                ]
+
     def pre_tokenize(self, context_list: List[Context], loss_scale_list: List[float],
                      **kwargs) -> Tuple[List[Context], List[float]]:
         # replace tag/object/box
@@ -1452,6 +1493,22 @@ def replace_tag(self, media_type, index, example) -> List[Context]:
             context_list.append('\n')
         return context_list
 
+    def replace_object(self, index: int, example: Dict[str, Any]) -> List[Context]:
+        objects = example.get('objects')
+        if objects:
+            object_ = objects[index]
+            return [f'<ref>{object_}</ref>']
+        else:
+            return ['']
+
+    def replace_box(self, index: int, example: Dict[str, Any]) -> List[Context]:
+        objects = example.get('objects')
+        if objects:
+            object_ = objects[index]
+            return [f'<box> [[{object_[1][0]}, {object_[1][1]}, {object_[1][2]}, {object_[1][3]}]] </box>']
+        else:
+            return ['']
+
     def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
         example = example.copy()
         history = example.pop('history', None)
@@ -1477,6 +1534,7 @@ def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any
         from .vision_utils import load_image, load_video
         pixel_values_images = _read_batch(images_path, load_image)
         videos_path = [path for path in videos_path if path is not None]
+        self.normalize_bbox(example.get('objects'), pixel_values_images, to_type='norm_1000')
         if pixel_values_images:
             pixel_values = pixel_values_images
             assert len(pixel_values) == len(idx_list)
@@ -1607,10 +1665,7 @@ def __init__(self):
         }
 
     def replace_box(self, index: int, example: Dict[str, Any]) -> List[Context]:
-        width, height = example['_image'].width, example['_image'].height
-        x1, y1, x2, y2 = [
-            int(coord / dim * 999) for coord, dim in zip(example['objects'][index][1], [width, height, width, height])
-        ]
+        x1, y1, x2, y2 = example['objects'][index][1]
         return [f'<loc_{x1}><loc_{y1}><loc_{x2}><loc_{y2}>']
 
     def _construct_prompts(self, text):
@@ -1641,6 +1696,7 @@ def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any
 
         images = _load_image(images_path[0])
         example['_image'] = images
+        self.normalize_bbox(example.get('objects'), images, to_type='norm_1000')
 
         # process bbox
         if example.get('objects') is not None:

From 8b0fa4c2984d6d447eed5333144c2492d4b64767 Mon Sep 17 00:00:00 2001
From: "yuze.zyz"
Date: Mon, 22 Jul 2024 23:15:50 +0800
Subject: [PATCH 02/19] fix

---
 swift/llm/utils/template.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/swift/llm/utils/template.py b/swift/llm/utils/template.py
index 76a44acab..e3daa6165 100644
--- a/swift/llm/utils/template.py
+++ b/swift/llm/utils/template.py
@@ -1435,7 +1435,7 @@ def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any
         idx_list = _findall(input_ids, -100)
         labels = inputs.get('labels')
         images_path = example.get('images') or []
-        if isinstance(images_path, str):
+        if not isinstance(images_path, (list, tuple)):
             images_path = [images_path]
         from .vision_utils import load_image
         pixel_values = _read_batch(images_path, load_image)
@@ -1527,9 +1527,9 @@ def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any
         idx_list = _findall(input_ids, -100)
         labels = inputs.get('labels')
         videos_path = example.get('videos') or []
-        if isinstance(images_path, str):
+        if not isinstance(images_path, (list, tuple)):
             images_path = [images_path]
-        if isinstance(videos_path, str):
+        if not isinstance(images_path, (list, tuple)):
             videos_path = [videos_path]
         from .vision_utils import load_image, load_video
         pixel_values_images = _read_batch(images_path, load_image)

From 79709a7936f77406c8046dd3093d588aca22ce32 Mon Sep 17 00:00:00 2001
From: tastelikefeet
Date: Mon, 22 Jul 2024 23:44:50 +0800
Subject: [PATCH 03/19] fix

---
 swift/llm/utils/template.py | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/swift/llm/utils/template.py b/swift/llm/utils/template.py
index e3daa6165..faa7a8b9a 100644
--- a/swift/llm/utils/template.py
+++ b/swift/llm/utils/template.py
@@ -1520,12 +1520,6 @@ def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any
             images.extend(images_path)
             example['images'] = images
         images_path = example.get('images') or []
-        inputs, _ = super(InternvlTemplate, self).encode(example)
-        if len(inputs) == 0:
-            return inputs, {}
-        input_ids = inputs['input_ids']
-        idx_list = _findall(input_ids, -100)
-        labels = inputs.get('labels')
         videos_path = example.get('videos') or []
         if not isinstance(images_path, (list, tuple)):
             images_path = [images_path]
         if not isinstance(images_path, (list, tuple)):
             videos_path = [videos_path]
         from .vision_utils import load_image, load_video
         pixel_values_images = _read_batch(images_path, load_image)
         videos_path = [path for path in videos_path if path is not None]
+        if example.get('objects'):
+            example['objects'] = json.loads(example.get('objects'))
         self.normalize_bbox(example.get('objects'), pixel_values_images, to_type='norm_1000')
+        inputs, _ = super(InternvlTemplate, self).encode(example)
+        if len(inputs) == 0:
+            return inputs, {}
+        input_ids = inputs['input_ids']
+        idx_list = _findall(input_ids, -100)
+        labels = inputs.get('labels')
         if pixel_values_images:
             pixel_values = pixel_values_images
             assert len(pixel_values) == len(idx_list)

From dd5fdb295ba4f73c7d876011c1be0c29229c178d Mon Sep 17 00:00:00 2001
From: "yuze.zyz"
Date: Tue, 23 Jul 2024 10:18:52 +0800
Subject: [PATCH 04/19] wip

---
 swift/llm/utils/template.py     | 202 ++++++++++++++++++--------------
 swift/llm/utils/vision_utils.py |   6 +-
 2 files changed, 118 insertions(+), 90 deletions(-)

diff --git a/swift/llm/utils/template.py b/swift/llm/utils/template.py
index faa7a8b9a..3215cd5d5 100644
--- a/swift/llm/utils/template.py
+++ b/swift/llm/utils/template.py
@@ -14,6 +14,7 @@
 from transformers import PreTrainedTokenizerBase, StoppingCriteria
 from transformers.dynamic_module_utils import get_class_from_dynamic_module
 
+from swift.llm import MediaTag
 from swift.llm.agent.utils import calculate_loss_scale, get_tools_prompt
 from swift.torchacc_utils import pad_and_split_batch
 from swift.utils import get_dist_setting, upper_bound, use_torchacc
@@ -164,6 +165,7 @@ class Template:
     special_tokens = ['<image>', '<video>', '<audio>', '<bbox>', '<ref-object>']
     special_keys = ['images', 'videos', 'audios', 'objects']
+    grounding_type = 'norm_1000'
 
     def __init__(self,
                  prefix: Prompt,
@@ -318,33 +320,49 @@ def _prepare_vllm_images(self, images: List['PIL.Image.Image']) -> List['PIL.Ima
 
         return new_images
 
-    def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
-        """return: inputs, tokenizer_kwargs"""
+    def preprocess(self, example):
+        # Duplicate example and create a new one to prepare in-place changes
+        example = example.copy()
+        template_type: Optional[str] = getattr(self, 'template_type', 
None) + tools: Union[List[Any], str] = example.get('tools') or [] + + # Template needs to be initialized if not self._is_init: raise ValueError( 'Template is not initialized, please use the `get_template` function to obtain the template.') - if example.get('images') and not isinstance(example['images'], (tuple, list)): - # change images field to list - example['images'] = [example['images']] - example = example.copy() + + # Check whether this template supports multi-round + history: History = example.get('history') or [] + if len(history) > 0: + assert self.support_multi_round, ( + f'The template does not support multi-round chat, template_type: {template_type}') + + # Format media_keys to list + for media_key in MediaTag.media_keys.values(): + if example.get(media_key) and not isinstance(example[media_key], (tuple, list)): + # change images field to list + example[media_key] = [example[media_key]] + + # Parse format images and merged into images key + example['query'], example['history'], images_path = replace_img_tag(example['query'], history, '') + if images_path: + images = example.get('images', []) + images.extend(images_path) + example['images'] = images + + # Add default tags to examples to note where to put the medias into the sequence self.add_default_tags(example) + + # Check the example that whether matching the very template's rules self.check_example(example) + + # Format objects(groundings/refs) to json if example.get('objects') and isinstance(example['objects'], str): # reload grounding from str example['objects'] = json.loads(example['objects']) - query: str = example.get('query') or '' - query_role: str = example.get('query_role') or 'user' - response: Optional[str] = example.get('response') - history: History = example.get('history') or [] - history_roles: Optional[History] = example.get('history_roles') - system: Optional[str] = example.get('system', None) - template_type: Optional[str] = getattr(self, 'template_type', None) - tools: Union[List[Any], str] = example.get('tools') or [] - is_multi_modal: bool = any([example.get(key) for key in Template.special_keys]) - if len(history) > 0: - assert self.support_multi_round, ( - f'The template does not support multi-round chat, template_type: {template_type}') + # Reset system (by default value and agent tools) + system: Optional[str] = example.get('system', None) if system is None: if self.use_default_system: system = self.default_system @@ -359,10 +377,36 @@ def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any if system is None: system = '' system += get_tools_prompt(tools, self.tools_prompt) + + example['system'] = system + + # Set history_roles + history_roles: Optional[History] = example.get('history_roles') if history_roles is None: - history_roles = [['user', 'assistant'] for _ in range(len(history))] + example['history_roles'] = [['user', 'assistant'] for _ in range(len(history))] + + # Load image into PIL format + from .vision_utils import load_image, load_video + if example.get('images'): + example['images'] = [load_image(img) for img in example['images']] + # Normalize grounding bboxes + self.normalize_bbox(example.get('objects'), example.get('images'), to_type=self.grounding_type) + + def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: + example = self.preprocess(example) + return self._encode(example) - inputs, tokenizer_kwargs = self._encode( + def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: + """return: inputs, 
tokenizer_kwargs""" + query: str = example.get('query') or '' + query_role: str = example.get('query_role') or 'user' + response: Optional[str] = example.get('response') + history: History = example.get('history') or [] + history_roles: Optional[History] = example.get('history_roles') + system: Optional[str] = example.get('system', None) + is_multi_modal: bool = any([example.get(key) for key in Template.special_keys]) + + inputs, tokenizer_kwargs = self._concat_and_tokenize( query, query_role, response, @@ -509,7 +553,7 @@ def normalize_bbox(cls, objects, images, to_type: Literal['real', 'norm_1000', ' if to_type == 'norm_1000': continue if to_type == 'norm_1': - object['bbox'] = [coord/999. for coord in bbox] + object['bbox'] = [coord / 999. for coord in bbox] elif to_type == 'real': width, height = image.width, image.height object['bbox'] = [ @@ -581,16 +625,16 @@ def _encode_context_list(self, context_list: List[Context], loss_scale.extend([loss_weight] * len(token_list)) return input_ids, labels, loss_scale, tokenizer_kwargs - def _encode(self, - query: str, - query_role: str, - response: Optional[str], - history: History, - history_roles: History, - system: Optional[str], - truncation_strategy: str, - auto_add_bos: bool = False, - **kwargs) -> Tuple[Dict[str, Any], Dict[str, Any]]: + def _concat_and_tokenize(self, + query: str, + query_role: str, + response: Optional[str], + history: History, + history_roles: History, + system: Optional[str], + truncation_strategy: str, + auto_add_bos: bool = False, + **kwargs) -> Tuple[Dict[str, Any], Dict[str, Any]]: """ return: inputs, tokenizer_kwargs """ @@ -710,7 +754,7 @@ def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] = labels[0] = F.pad(labels[0], (0, padding_len) if padding_right else (padding_len, 0), 'constant', -100) if loss_scale: loss_scale[0] = F.pad(loss_scale[0], (0, padding_to - labels[0].shape[-1]) if padding_right else - (padding_to - labels[0].shape[-1], 0), 'constant', 0.) + (padding_to - labels[0].shape[-1], 0), 'constant', 0.) 
if input_ids is None: inputs_embeds = self.pad_sequence(inputs_embeds, 0, self.padding_side) @@ -797,15 +841,15 @@ def _get_safe_print_idx(cls, response: str, print_idx: int, is_finished: bool = return print_idx def generate_ids_to_response( - self, - generate_ids: List[int], - is_finished: bool = True, - *, - tokenizer_kwargs: Optional[Dict[str, Any]] = None, - # only stream=True - return_delta: bool = False, - print_idx: Optional[List[int]] = None, - first_num_space: Optional[List[int]] = None, + self, + generate_ids: List[int], + is_finished: bool = True, + *, + tokenizer_kwargs: Optional[Dict[str, Any]] = None, + # only stream=True + return_delta: bool = False, + print_idx: Optional[List[int]] = None, + first_num_space: Optional[List[int]] = None, ): if tokenizer_kwargs is None: tokenizer_kwargs = {} @@ -1048,7 +1092,7 @@ def replace_tag(self, media_type, index, example) -> List[Context]: assert media_type == 'image' return [[-200], '\n'] - def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: + def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: inputs, _ = super().encode(example) if len(inputs) == 0: return inputs, {} @@ -1107,7 +1151,7 @@ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index, exa assert media_type == 'image' return [[-100]] - def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: + def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: from .utils import history_to_messages inputs, _ = super().encode(example) @@ -1291,7 +1335,7 @@ def __init__(self): system_prefix = ['[UNUSED_TOKEN_146]system\n{{SYSTEM}}[UNUSED_TOKEN_145]\n'] super().__init__(prefix, prompt, chat_sep, suffix, self.INTERNLM_XCOMPOSER_SYSTEM, system_prefix) - def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: + def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: example = example.copy() history = example.pop('history', None) if history is None: @@ -1427,7 +1471,7 @@ def replace_tag(self, media_type, index, example) -> List[Context]: assert media_type == 'image' return [[-100]] - def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: + def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: inputs, _ = super().encode(example) if len(inputs) == 0: return inputs, {} @@ -1472,7 +1516,6 @@ def get_generate_ids(generate_ids: Tensor, input_token_len: int) -> List[int]: class Internvl2Template(InternvlTemplate): - video_segments = 8 def __init__(self): @@ -1509,28 +1552,9 @@ def replace_box(self, index: int, example: Dict[str, Any]) -> List[Context]: else: return [''] - def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: + def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: example = example.copy() - history = example.pop('history', None) - if history is None: - history = [] - example['query'], example['history'], images_path = replace_img_tag(example['query'], history, '') - if images_path: - images = example.get('images') or [] - images.extend(images_path) - example['images'] = images - images_path = example.get('images') or [] - videos_path = example.get('videos') or [] - if not isinstance(images_path, (list, tuple)): - images_path = [images_path] - if not isinstance(images_path, (list, tuple)): - videos_path = [videos_path] - from .vision_utils import load_image, load_video - 
pixel_values_images = _read_batch(images_path, load_image) - videos_path = [path for path in videos_path if path is not None] - if example.get('objects'): - example['objects'] = json.loads(example.get('objects')) - self.normalize_bbox(example.get('objects'), pixel_values_images, to_type='norm_1000') + history = example.pop('history', []) inputs, _ = super(InternvlTemplate, self).encode(example) if len(inputs) == 0: return inputs, {} @@ -1689,7 +1713,7 @@ def _construct_prompts(self, text): prompts.append(_text) return prompts - def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: + def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: example = example.copy() # read image processor = self.tokenizer.processor @@ -1810,7 +1834,7 @@ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index, exa else: return ['\n'] - def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: + def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: inputs, _ = super().encode(example) if len(inputs) == 0: return inputs, {} @@ -1837,7 +1861,7 @@ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index, exa else: return [''], system_prefix=['<>\n{{system}}\n<>\n\n']) - def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: + def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: inputs, _ = super().encode(example) if 'pixel_values' in inputs: inputs['pixel_values'] = inputs['pixel_values'].squeeze(0) @@ -1961,7 +1985,7 @@ def __init__(self): self.system, system_prefix=['{{SYSTEM}} ']) - def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: + def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: inputs, _ = super().encode(example) if 'pixel_values' in inputs: inputs['pixel_values'] = inputs['pixel_values'].squeeze(0) @@ -1982,7 +2006,7 @@ def __init__(self): ['<|im_end|>'], system_prefix=['<|im_start|>system\n{{SYSTEM}}<|im_end|>']) - def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: + def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: inputs, _ = super().encode(example) if 'pixel_values' in inputs: inputs['pixel_values'] = inputs['pixel_values'].squeeze(0) @@ -2003,7 +2027,7 @@ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index, exa def __init__(self): Template.__init__(self, [], [self.llavallama_query_template], ['<|eot_id|>'], ['<|eot_id|>']) - def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: + def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: inputs, _ = super().encode(example) if len(inputs) == 0: return inputs, {} @@ -2036,7 +2060,7 @@ def replace_tag(self, media_type, index, example) -> List[Context]: assert media_type == 'image' return ['' * self.tokenizer.processor.image_seq_length] - def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: + def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: inputs, _ = super().encode(example) if len(inputs) == 0: return inputs, {} @@ -2075,7 +2099,7 @@ def __init__(self): def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index, example) -> List[Context]: return ['<|image|>'] - def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: + 
def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: example = example.copy() history = example.pop('history', None) if history is None: @@ -2182,7 +2206,7 @@ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index, exa assert media_type == 'image' return [''] - def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: + def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: example = example.copy() history = example.pop('history', None) if history is None: @@ -2264,7 +2288,7 @@ def check_example(self, example): def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index, example) -> List[Context]: return [] - def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: + def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: inputs, _ = super().encode(example) if len(inputs) == 0: return inputs, {} @@ -2361,7 +2385,7 @@ def check_example(self, example): videos = example.get('videos') or [] assert len(videos) <= 1 - def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: + def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: inputs, _ = super(CogTemplate, self).encode(example) if len(inputs) == 0: return inputs, {} @@ -2422,7 +2446,7 @@ def check_example(self, example): images = example.get('images') or [] assert len(images) == 1 - def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: + def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: inputs, _ = super().encode(example) if len(inputs) == 0: return inputs, {} @@ -2563,7 +2587,7 @@ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index, exa assert media_type == 'image' return [[-200]] - def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: + def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: from mplug_owl2.mm_utils import process_images processor = self.tokenizer.processor images_path = example.get('images') or [] @@ -2617,13 +2641,13 @@ def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] = def get_template( - template_type: str, - tokenizer: PreTrainedTokenizerBase, - default_system: Optional[str] = None, - max_length: Optional[int] = None, - truncation_strategy: Literal['delete', 'truncation_left'] = 'delete', - model=None, - **kwargs, + template_type: str, + tokenizer: PreTrainedTokenizerBase, + default_system: Optional[str] = None, + max_length: Optional[int] = None, + truncation_strategy: Literal['delete', 'truncation_left'] = 'delete', + model=None, + **kwargs, ) -> Template: template_info = TEMPLATE_MAPPING[template_type] template = deepcopy(template_info['template']) diff --git a/swift/llm/utils/vision_utils.py b/swift/llm/utils/vision_utils.py index 1d0b8e876..f3e8707fe 100644 --- a/swift/llm/utils/vision_utils.py +++ b/swift/llm/utils/vision_utils.py @@ -75,7 +75,7 @@ def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnai return processed_images -def load_image(img_path, input_size=448, max_num=6): +def load_image(img_path): if isinstance(img_path, str): img_path = img_path.strip() if img_path.startswith('http'): @@ -93,6 +93,10 @@ def load_image(img_path, input_size=448, max_num=6): image = img_path if image.mode != 'RGB': image = image.convert('RGB') + return image + + +def transform_image(image, 
input_size=448, max_num=6):
     transform = build_transform(input_size=input_size)
     images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
     pixel_values = [transform(image) for image in images]

From 61c2bd9979064b67b7d669cc9b88f50cd97e8f91 Mon Sep 17 00:00:00 2001
From: "yuze.zyz"
Date: Tue, 23 Jul 2024 14:27:08 +0800
Subject: [PATCH 05/19] wip

---
 swift/llm/utils/template.py     | 130 +++++++++----------------------
 swift/llm/utils/vision_utils.py |  20 ++++-
 2 files changed, 55 insertions(+), 95 deletions(-)

diff --git a/swift/llm/utils/template.py b/swift/llm/utils/template.py
index 5e94c8144..edfac7ee6 100644
--- a/swift/llm/utils/template.py
+++ b/swift/llm/utils/template.py
@@ -166,6 +166,7 @@ class Template:
     special_tokens = ['<image>', '<video>', '<audio>', '<bbox>', '<ref-object>']
     special_keys = ['images', 'videos', 'audios', 'objects']
     grounding_type = 'norm_1000'
+    image_placeholder = '<image>'
 
     def __init__(self,
                  prefix: Prompt,
@@ -344,7 +345,7 @@ def preprocess(self, example):
                 example[media_key] = [example[media_key]]
 
         # Parse format images and merged into images key
-        example['query'], example['history'], images_path = replace_img_tag(example['query'], history, '<image>')
+        example['query'], example['history'], images_path = replace_img_tag(example['query'], history, self.image_placeholder)
         if images_path:
             images = example.get('images', [])
             images.extend(images_path)
             example['images'] = images
@@ -1049,31 +1050,6 @@ class QwenAudioGenerationTemplate(_QwenAudioTemplateMixin, DefaultGenerationTemp
                               '仔细阅读所有的图像，并对人类的问题做出信息丰富、有帮助、详细的和礼貌的回答。')
 
 
-def _load_image(img_path: Union[str, 'PIL.Image.Image']) -> 'PIL.Image.Image':
-    from PIL import Image, UnidentifiedImageError
-    import os
-    import base64
-    import binascii
-    if isinstance(img_path, str):
-        img_path = img_path.strip()
-        if img_path.startswith('http'):
-            content = requests.get(img_path).content
-            image = Image.open(BytesIO(content))
-        elif os.path.exists(img_path):
-            image = Image.open(img_path)
-        else:  # base64_str
-            try:
-                image_data = base64.b64decode(img_path)
-                image = Image.open(BytesIO(image_data))
-            except (binascii.Error, UnidentifiedImageError) as error:
-                raise ValueError(f'invalid image: {error}')
-    else:
-        image = img_path
-    if image.mode != 'RGB':
-        image = image.convert('RGB')
-    return image
-
-
 def _load_video_llava(video_path: str) -> np.ndarray:
     import av
     container = av.open(video_path)
@@ -1094,16 +1070,6 @@ def _load_video_llava(video_path: str) -> np.ndarray:
 _T = TypeVar('_T')
 
 
 class YiVLTemplate(Template):
 
     def replace_tag(self, media_type, index, example) -> List[Context]:
@@ -1120,8 +1086,7 @@ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, An
         if not hasattr(model, 'vision_tower'):
             model = model.model
         image_processor = model.vision_tower.image_processor
-        images_path = example.get('images') or []
-        images = _read_batch(images_path)
+        images = example.get('images', [])
         for i, image in enumerate(images):
             background_color = tuple(int(x * 255) for x in image_processor.image_mean)
             image = expand2square(image, background_color)
@@ -1180,8 +1145,7 @@ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, An
         idx_list = _findall(input_ids, -100)
         if idx_list:
             idx = idx_list[0]
-            images_path = example.get('images') or []
-            image = _read_batch(images_path)[0]
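+            # example['images'] already holds PIL.Image objects loaded by preprocess()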
+            image = example.get('images', [])
             placeholder = '<|begin_of_image|><|endoftext|><|end_of_image|>'
             placeholder_id = self.tokenizer.encode(placeholder, add_special_tokens=False)
             input_ids = (input_ids[:idx] + placeholder_id + input_ids[idx + 1:])
@@ -1344,6 +1308,7 @@ class InternLMXComposer2Template(Template):
         '- InternLM-XComposer (浦语·灵笔) can understand and communicate fluently in the language chosen '
         'by the user such as English and 中文.')
     is_v2_5 = False
+    image_placeholder = '<ImageHere>'
 
     def __init__(self):
         prefix = ['<s>']
@@ -1354,17 +1319,11 @@ def __init__(self):
         super().__init__(prefix, prompt, chat_sep, suffix, self.INTERNLM_XCOMPOSER_SYSTEM, system_prefix)
 
     def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
-        example = example.copy()
-        history = example.pop('history', None)
-        if history is None:
-            history = []
-        example['query'], example['history'], images_path = replace_img_tag(example['query'], history, '<ImageHere>')
         inputs, _ = super().encode(example)
         if len(inputs) == 0:
             return inputs, {}
         dtype = self.model.dtype
-        images_path.extend(example.get('images') or [])
-        images = _read_batch(images_path)
+        images = example.get('images', [])
         if self.is_v2_5:
             hd_num = 24
             Image_transform = get_class_from_dynamic_module('ixc_utils.Image_transform', self.tokenizer.model_dir)
@@ -1496,11 +1455,8 @@ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, An
         input_ids = inputs['input_ids']
         idx_list = _findall(input_ids, -100)
         labels = inputs.get('labels')
-        images_path = example.get('images') or []
-        if not isinstance(images_path, (list, tuple)):
-            images_path = [images_path]
-        from .vision_utils import load_image
-        pixel_values = _read_batch(images_path, load_image)
+        from .vision_utils import transform_image
+        pixel_values = [transform_image(image) for image in example.get('images', [])]
         if pixel_values:
             pixel_values = torch.cat(pixel_values, dim=0)
             image_bs = pixel_values.shape[0]
@@ -1571,14 +1527,15 @@ def replace_box(self, index: int, example: Dict[str, Any]) -> List[Context]:
         return ['']
 
     def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
-        example = example.copy()
-        history = example.pop('history', [])
         inputs, _ = super(InternvlTemplate, self).encode(example)
         if len(inputs) == 0:
             return inputs, {}
         input_ids = inputs['input_ids']
         idx_list = _findall(input_ids, -100)
         labels = inputs.get('labels')
+        from swift.llm.utils.vision_utils import transform_image
+        pixel_values_images = [transform_image(image) for image in example.get('images', [])]
+        videos_path = example.get('videos_path', [])
         if pixel_values_images:
             pixel_values = pixel_values_images
             assert len(pixel_values) == len(idx_list)
@@ -1600,6 +1557,7 @@ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, An
             inputs['image_flags'] = torch.ones(patches)
         elif videos_path:
             assert len(videos_path) == 1
+            from swift.llm.utils.vision_utils import load_video
             pixel_values, num_patches = load_video(videos_path[0], num_segments=self.video_segments)
             assert len(num_patches) == len(idx_list)
             added_tokens_len = 0
@@ -1732,15 +1690,12 @@ def _construct_prompts(self, text):
         return prompts
 
     def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
-        example = example.copy()
-        # read image
         processor = self.tokenizer.processor
-        images_path = example.get('images') or []
-        assert len(images_path) == 1, 'Florence series models only supports input with a single image.'
- - images = _load_image(images_path[0]) + images = example.get('images', []) + assert len(images) == 1, 'Florence series models only supports input with a single image.' + from .vision_utils import transform_image + images = transform_image(images[0]) example['_image'] = images - self.normalize_bbox(example.get('objects'), images, to_type='norm_1000') # process bbox if example.get('objects') is not None: @@ -1856,8 +1811,7 @@ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, An inputs, _ = super().encode(example) if len(inputs) == 0: return inputs, {} - images_path = example.get('images') or [] - images = _read_batch(images_path) + images = example.get('images', []) image_processor = self.tokenizer.processor.image_processor if self._is_vllm: images = self._prepare_vllm_images(images) @@ -1883,7 +1837,7 @@ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, An inputs, _ = super().encode(example) if len(inputs) == 0: return inputs, {} - media_files = example.get('videos') or [] + media_files = example.get('videos', []) images_path, videos_path = [], [] for media_file in media_files: if media_file is None: @@ -1893,11 +1847,13 @@ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, An else: videos_path.append(media_file) if len(videos_path) > 0: + from swift.llm.utils.vision_utils import _read_batch videos = _read_batch(videos_path, _load_video_llava) video_processor = self.tokenizer.processor.video_processor video_inputs = video_processor(videos, return_tensors='pt').to(self.model.dtype) inputs['pixel_values_videos'] = video_inputs['pixel_values_videos'] if len(images_path) > 0: + from swift.llm.utils.vision_utils import _read_batch images = _read_batch(images_path) image_processor = self.tokenizer.processor.image_processor image_inputs = image_processor(images, return_tensors='pt').to(self.model.dtype) @@ -1949,8 +1905,7 @@ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, An inputs, _ = super().encode(example) if len(inputs) == 0: return inputs, {} - images_path = example.get('images') or [] - images = _read_batch(images_path) + images = example.get('images', []) image_sizes = [x.size for x in images] from llava.mm_utils import process_images model = self.model.model @@ -2049,8 +2004,7 @@ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, An inputs, _ = super().encode(example) if len(inputs) == 0: return inputs, {} - image_path = example.get('images') or [] - raw_image = _read_batch(image_path) + raw_image = example.get('images', []) if raw_image: pixel_values = self.tokenizer.processor.image_processor(raw_image[0], return_tensors='pt')['pixel_values'] inputs['pixel_values'] = pixel_values.to(self.model.dtype) @@ -2082,7 +2036,7 @@ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, An inputs, _ = super().encode(example) if len(inputs) == 0: return inputs, {} - image_path = example.get('images') or [] + raw_image = example.get('images', []) processor = self.tokenizer.processor if inputs['labels'] is not None: n = upper_bound(0, len(inputs['labels']), lambda idx: inputs['labels'][idx] == -100) @@ -2090,7 +2044,6 @@ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, An inputs['token_type_ids'] = [0] * n + [1] * n2 else: inputs['token_type_ids'] = [0] * len(inputs['input_ids']) - raw_image = _read_batch(image_path) if raw_image: model_inputs = processor(text=example['query'], images=raw_image[0], 
return_tensors='pt')
             inputs['pixel_values'] = model_inputs['pixel_values']
@@ -2110,6 +2063,8 @@ def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] =
 
 class Phi3VisionTemplate(Template):
+    image_placeholder = '<|image|>'
+
     def __init__(self):
         Template.__init__(self, ['<s>'], ['<|user|>\n{{QUERY}}<|end|>\n<|assistant|>\n'], ['<|end|>\n'], ['<|end|>'], None, ['<|system|>\n{{SYSTEM}}<|end|>\n'])
 
     def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index, example) -> List[Context]:
         return ['<|image|>']
 
     def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
-        example = example.copy()
-        history = example.pop('history', None)
-        if history is None:
-            history = []
-        example['query'], example['history'], images_path = replace_img_tag(example['query'], history, '<|image|>')
-        images_path.extend(example.get('images') or [])
-        images = _read_batch(images_path)
+        images = example.get('images', [])
         inputs, _ = super().encode(example)
         if len(inputs) == 0:
             return inputs, {}
@@ -2216,6 +2165,8 @@ class DeepseekVLTemplate(Template):
     'You are able to understand the visual content that the user provides, '
     'and assist the user with a variety of tasks using natural language.')
 
+    image_placeholder = '<image_placeholder>'
+
     def __init__(self):
         super().__init__(['<|begin▁of▁sentence|>{{SYSTEM}}\n\n'], ['User: {{QUERY}}\n\nAssistant:'],
                          ['<|end▁of▁sentence|>'], ['<|end▁of▁sentence|>'], self.DEEPSEEK_VL_SYSTEM)
 
     def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index, example) -> List[Context]:
         return ['<image_placeholder>']
 
     def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
-        example = example.copy()
-        history = example.pop('history', None)
-        if history is None:
-            history = []
-
-        example['query'], example['history'], images_path = replace_img_tag(example['query'], history,
-                                                                            '<image_placeholder>')
         inputs, _ = super().encode(example)
         if len(inputs) == 0:
             return inputs, {}
-        images_path.extend(example.get('images') or [])
-        images = _read_batch(images_path)
+        images = example.get('images')
         processor = self.tokenizer.processor
         input_ids, labels = inputs['input_ids'], inputs['labels']
         idx_list = _findall(input_ids, processor.image_id)
@@ -2310,8 +2253,7 @@ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, An
         inputs, _ = super().encode(example)
         if len(inputs) == 0:
             return inputs, {}
-        images_path = example.get('images') or []
-        image = _read_batch(images_path)
+        image = example.get('images', [])
         inputs.pop('loss_scale', None)
         model = self.model
         inputs2 = model.build_conversation_input_ids(
@@ -2407,7 +2349,8 @@ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, An
         inputs, _ = super(CogTemplate, self).encode(example)
         if len(inputs) == 0:
             return inputs, {}
-        videos_path = example.get('videos') or []
+        videos_path = example.get('videos', [])
+        from swift.llm.utils.vision_utils import _read_batch
         video = _read_batch(videos_path, _load_video_cogvlm2)
         inputs.pop('loss_scale', None)
         model = self.model
@@ -2461,15 +2404,15 @@ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index, exa
         return [[-1]]
 
     def check_example(self, example):
-        images = example.get('images') or []
+        images = example.get('images', [])
         assert len(images) == 1
 
     def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
         inputs, _ = super().encode(example)
         if len(inputs) == 0:
             return inputs, {}
-        images_path = example['images']
-        image = _load_image(images_path[0])
+        images = example['images']
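+        # check_example() guarantees exactly one image, already loaded as a PIL.Image by preprocess()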
+ image = images[0] input_ids = inputs['input_ids'] labels = inputs['labels'] idx_list = _findall(input_ids, -1) @@ -2608,8 +2551,7 @@ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index, exa def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: from mplug_owl2.mm_utils import process_images processor = self.tokenizer.processor - images_path = example.get('images') or [] - images = _read_batch(images_path) + images = example.get('images', []) for i, image in enumerate(images): # ref: https://modelscope.cn/models/iic/mPLUG-Owl2.1 max_edge = max(image.size) diff --git a/swift/llm/utils/vision_utils.py b/swift/llm/utils/vision_utils.py index f3e8707fe..b91375708 100644 --- a/swift/llm/utils/vision_utils.py +++ b/swift/llm/utils/vision_utils.py @@ -2,6 +2,7 @@ import binascii import os from io import BytesIO +from typing import Union, List, Callable, TypeVar import numpy as np import requests @@ -75,7 +76,11 @@ def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnai return processed_images -def load_image(img_path): +def _load_image(img_path: Union[str, 'PIL.Image.Image']) -> 'PIL.Image.Image': + from PIL import Image, UnidentifiedImageError + import os + import base64 + import binascii if isinstance(img_path, str): img_path = img_path.strip() if img_path.startswith('http'): @@ -96,6 +101,19 @@ def load_image(img_path): return image +_T = TypeVar('_T') + + +def _read_batch(path_list: List[Union[str, 'PIL.Image.Image', None]], + load_func: Callable[[str], _T] = _load_image) -> List[_T]: + res = [] + for path in path_list: + if path is None: # ignore None + continue + res.append(load_func(path)) + return res + + def transform_image(image, input_size=448, max_num=6): transform = build_transform(input_size=input_size) images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num) From 686b355e17569b3840a26b31f2614b1213124513 Mon Sep 17 00:00:00 2001 From: "yuze.zyz" Date: Tue, 23 Jul 2024 14:33:48 +0800 Subject: [PATCH 06/19] fix --- swift/llm/utils/template.py | 2 +- swift/llm/utils/vision_utils.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/swift/llm/utils/template.py b/swift/llm/utils/template.py index edfac7ee6..667063f85 100644 --- a/swift/llm/utils/template.py +++ b/swift/llm/utils/template.py @@ -387,7 +387,7 @@ def preprocess(self, example): example['history_roles'] = [['user', 'assistant'] for _ in range(len(history))] # Load image into PIL format - from .vision_utils import load_image, load_video + from .vision_utils import load_image if example.get('images'): example['images'] = [load_image(img) for img in example['images']] # Normalize grounding bboxes diff --git a/swift/llm/utils/vision_utils.py b/swift/llm/utils/vision_utils.py index b91375708..531d66963 100644 --- a/swift/llm/utils/vision_utils.py +++ b/swift/llm/utils/vision_utils.py @@ -76,7 +76,7 @@ def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnai return processed_images -def _load_image(img_path: Union[str, 'PIL.Image.Image']) -> 'PIL.Image.Image': +def load_image(img_path: Union[str, 'PIL.Image.Image']) -> 'PIL.Image.Image': from PIL import Image, UnidentifiedImageError import os import base64 @@ -105,7 +105,7 @@ def _load_image(img_path: Union[str, 'PIL.Image.Image']) -> 'PIL.Image.Image': def _read_batch(path_list: List[Union[str, 'PIL.Image.Image', None]], - load_func: Callable[[str], _T] = _load_image) -> List[_T]: + load_func: Callable[[str], _T] = 
load_image) -> List[_T]: res = [] for path in path_list: if path is None: # ignore None From 7b51f718750a588c18d6f7dfa3d1fbb43f0fbde4 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Tue, 23 Jul 2024 15:12:12 +0800 Subject: [PATCH 07/19] fix --- swift/llm/utils/template.py | 57 +++++++++++++++++++------------------ 1 file changed, 29 insertions(+), 28 deletions(-) diff --git a/swift/llm/utils/template.py b/swift/llm/utils/template.py index 667063f85..e056c648c 100644 --- a/swift/llm/utils/template.py +++ b/swift/llm/utils/template.py @@ -14,7 +14,6 @@ from transformers import PreTrainedTokenizerBase, StoppingCriteria from transformers.dynamic_module_utils import get_class_from_dynamic_module -from swift.llm import MediaTag from swift.llm.agent.utils import calculate_loss_scale, get_tools_prompt from swift.torchacc_utils import pad_and_split_batch from swift.utils import get_dist_setting, upper_bound, use_torchacc @@ -338,6 +337,7 @@ def preprocess(self, example): assert self.support_multi_round, ( f'The template does not support multi-round chat, template_type: {template_type}') + from swift.llm import MediaTag # Format media_keys to list for media_key in MediaTag.media_keys.values(): if example.get(media_key) and not isinstance(example[media_key], (tuple, list)): @@ -392,6 +392,7 @@ def preprocess(self, example): example['images'] = [load_image(img) for img in example['images']] # Normalize grounding bboxes self.normalize_bbox(example.get('objects'), example.get('images'), to_type=self.grounding_type) + return example def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: example = self.preprocess(example) @@ -525,7 +526,7 @@ def replace_object(self, index: int, example: Dict[str, Any]) -> List[Context]: objects = example.get('objects') if objects: object_ = objects[index] - return [object_[0]] + return [object_['caption']] else: return [''] @@ -533,7 +534,7 @@ def replace_box(self, index: int, example: Dict[str, Any]) -> List[Context]: objects = example.get('objects') if objects: object_ = objects[index] - return [f'({object_[1][0]},{object_[1][1]}),({object_[1][2]},{object_[1][3]})'] + return [f'({object_["bbox"][0]},{object_["bbox"][1]}),({object_["bbox"][2]},{object_["bbox"][3]})'] else: return [''] @@ -971,12 +972,12 @@ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int def replace_object(self, index: int, example: Dict[str, Any]) -> List[Context]: objects = example['objects'] object_ = objects[index] - return [f'{object_[0]}'] + return [f'{object_["caption"]}'] def replace_box(self, index: int, example: Dict[str, Any]) -> List[Context]: objects = example['objects'] object_ = objects[index] - return [f'({object_[1][0]},{object_[1][1]}),({object_[1][2]},{object_[1][3]})'] + return [f'({object_["bbox"][0]},{object_["bbox"][1]}),({object_["bbox"][2]},{object_["bbox"][3]})'] register_template(TemplateType.qwen, QwenTemplate()) @@ -991,8 +992,8 @@ def replace_box(self, index: int, example: Dict[str, Any]) -> List[Context]: class _QwenAudioTemplateMixin: - def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: - inputs, tokenizer_kwargs = super().encode(example) + def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: + inputs, tokenizer_kwargs = super()._encode(example) if len(inputs) == 0: return inputs, tokenizer_kwargs inputs.pop('loss_scale', None) @@ -1077,7 +1078,7 @@ def replace_tag(self, media_type, index, example) -> List[Context]: return [[-200], '\n'] def 
_encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: - inputs, _ = super().encode(example) + inputs, _ = super()._encode(example) if len(inputs) == 0: return inputs, {} inputs.pop('loss_scale', None) @@ -1137,7 +1138,7 @@ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index, exa def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: from .utils import history_to_messages - inputs, _ = super().encode(example) + inputs, _ = super()._encode(example) if len(inputs) == 0: return inputs, {} input_ids = inputs['input_ids'] @@ -1319,7 +1320,7 @@ def __init__(self): super().__init__(prefix, prompt, chat_sep, suffix, self.INTERNLM_XCOMPOSER_SYSTEM, system_prefix) def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: - inputs, _ = super().encode(example) + inputs, _ = super()._encode(example) if len(inputs) == 0: return inputs, {} dtype = self.model.dtype @@ -1449,7 +1450,7 @@ def replace_tag(self, media_type, index, example) -> List[Context]: return [[-100]] def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: - inputs, _ = super().encode(example) + inputs, _ = super()._encode(example) if len(inputs) == 0: return inputs, {} input_ids = inputs['input_ids'] @@ -1514,7 +1515,7 @@ def replace_object(self, index: int, example: Dict[str, Any]) -> List[Context]: objects = example.get('objects') if objects: object_ = objects[index] - return [f'{object_}'] + return [f'{object_["caption"]}'] else: return [''] @@ -1522,12 +1523,12 @@ def replace_box(self, index: int, example: Dict[str, Any]) -> List[Context]: objects = example.get('objects') if objects: object_ = objects[index] - return [f' [[{object_[1][0]}, {object_[1][1]}, {object_[1][2]}, {object_[1][3]}]] '] + return [f' [[{object_["bbox"][0]}, {object_["bbox"][1]}, {object_["bbox"][2]}, {object_["bbox"][3]}]] '] else: return [''] def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: - inputs, _ = super(InternvlTemplate, self).encode(example) + inputs, _ = super(InternvlTemplate, self)._encode(example) if len(inputs) == 0: return inputs, {} input_ids = inputs['input_ids'] @@ -1808,7 +1809,7 @@ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index, exa return ['\n'] def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: - inputs, _ = super().encode(example) + inputs, _ = super()._encode(example) if len(inputs) == 0: return inputs, {} images = example.get('images', []) @@ -1834,7 +1835,7 @@ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index, exa return ['