support minicpm-v-v2_6-chat (modelscope#1609)
Jintao-Huang committed Aug 6, 2024
1 parent 26cf37e commit b43e56e
Showing 7 changed files with 150 additions and 8 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -55,6 +55,7 @@ You can contact us and communicate with us by adding our group:
<img src="asset/discord_qr.jpg" width="200" height="200"> | <img src="asset/wechat.png" width="200" height="200">

## 🎉 News
- 2024.08.06: Support for minicpm-v-v2_6-chat is available. You can try it with `swift infer --model_type minicpm-v-v2_6-chat`. Best practices can be found [here](https://github.com/modelscope/swift/issues/1613).
- 2024.08.06: Supports internlm2.5 series of 1.8b and 20b. Experience it using `swift infer --model_type internlm2_5-1_8b-chat`.
- 🔥2024.08.05: Support evaluation for multi-modal models! Same command with [new datasets](https://swift.readthedocs.io/en/latest/LLM/LLM-eval.html#introduction).
- 🔥2024.08.02: Support Fourier Ft. Use `--sft_type fourierft` to begin, Check parameter documentation [here](https://swift.readthedocs.io/en/latest/LLM/Command-line-parameters.html).
@@ -612,7 +613,7 @@ The complete list of supported models and datasets can be found at [Supported Mo
| YI-VL | [01AI's YI series vision models](https://github.com/01-ai) | Chinese<br>English | 6B-34B | chat model |
| XComposer2<br>XComposer2.5 | [Pujiang AI Lab InternLM vision model](https://github.com/InternLM/InternLM-XComposer) | Chinese<br>English | 7B | chat model |
| DeepSeek-VL | [DeepSeek series vision models](https://github.com/deepseek-ai) | Chinese<br>English | 1.3B-7B | chat model |
| MiniCPM-V<br>MiniCPM-V-2<br>MiniCPM-V-2_5 | [OpenBmB MiniCPM vision model](https://github.com/OpenBMB/MiniCPM) | Chinese<br>English | 3B-9B | chat model |
| MiniCPM-V<br>MiniCPM-V-2<br>MiniCPM-V-2.5<br>MiniCPM-V-2.6 | [OpenBmB MiniCPM vision model](https://github.com/OpenBMB/MiniCPM) | Chinese<br>English | 3B-9B | chat model |
| CogVLM<br>CogAgent<br>CogVLM2<br>CogVLM2-Video<br>GLM4V | [Zhipu ChatGLM visual QA and Agent model](https://github.com/THUDM/) | Chinese<br>English | 9B-19B | chat model |
| Llava1.5<br>Llava1.6 | [Llava series models](https://github.com/haotian-liu/LLaVA) | English | 7B-34B | chat model |
| Llava-Next<br>Llava-Next-Video | [Llava-Next series models](https://github.com/LLaVA-VL/LLaVA-NeXT) | Chinese<br>English | 7B-110B | chat model |
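Beyond the `swift infer` CLI mentioned in the news entry, the newly registered model type can also be loaded through swift's Python API. The snippet below is a minimal sketch of that pattern; the exact `inference(...)` signature (in particular the `images` keyword) and the sample image URL are assumptions to verify against the linked best-practice issue and your installed swift version.

```python
# Hedged sketch of Python-API inference with the newly registered model type.
# The `images=` keyword of `inference` and the image URL are assumptions; see
# the best-practice issue (modelscope/swift#1613) for the authoritative example.
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

from swift.llm import (ModelType, get_default_template_type,
                       get_model_tokenizer, get_template, inference)
from swift.utils import seed_everything

model_type = ModelType.minicpm_v_v2_6_chat            # added by this commit
template_type = get_default_template_type(model_type)

model, tokenizer = get_model_tokenizer(model_type, model_kwargs={'device_map': 'auto'})
template = get_template(template_type, tokenizer)
seed_everything(42)

query = '<image>Describe this picture.'
images = ['https://example.com/cat.png']              # placeholder URL
response, _ = inference(model, template, query, images=images)  # assumed kwarg
print(response)
```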
3 changes: 2 additions & 1 deletion README_CN.md
@@ -56,6 +56,7 @@ SWIFT具有丰富全面的文档,请查看我们的文档网站:


## 🎉 新闻
- 2024.08.06: 支持minicpm-v-v2_6-chat, 使用`swift infer --model_type minicpm-v-v2_6-chat`进行推理体验, 最佳实践可以查看[这里](https://github.com/modelscope/swift/issues/1613).
- 2024.08.06: 支持internlm2.5的1.8b和20b系列. 使用`swift infer --model_type internlm2_5-1_8b-chat`进行体验.
- 🔥2024.08.05: 支持多模态数据集的评测!命令行完全一致,新增了许多[多模态数据集](https://swift.readthedocs.io/zh-cn/latest/LLM/LLM%E8%AF%84%E6%B5%8B%E6%96%87%E6%A1%A3.html#id2).
- 🔥2024.08.02: 支持Fourier Ft训练. 使用方式为`--sft_type fourierft`, 参数可以参考[这里](https://swift.readthedocs.io/zh-cn/latest/LLM/%E5%91%BD%E4%BB%A4%E8%A1%8C%E5%8F%82%E6%95%B0.html).
@@ -606,7 +607,7 @@ CUDA_VISIBLE_DEVICES=0 swift deploy \
| YI-VL | [01AI的YI系列视觉模型](https://github.com/01-ai) | 中文<br>英文 | 6B-34B | chat模型 |
| XComposer2<br>XComposer2.5 | [浦江实验室书生浦语视觉模型](https://github.com/InternLM/InternLM-XComposer) | 中文<br>英文 | 7B | chat模型 |
| DeepSeek-VL | [幻方系列视觉模型](https://github.com/deepseek-ai) | 中文<br>英文 | 1.3B-7B | chat模型 |
| MiniCPM-V<br>MiniCPM-V-2<br>MiniCPM-V-2_5 | [OpenBmB MiniCPM视觉模型](https://github.com/OpenBMB/MiniCPM) | 中文<br>英文 | 3B-9B | chat模型 |
| MiniCPM-V<br>MiniCPM-V-2<br>MiniCPM-V-2.5<br>MiniCPM-V-2.6 | [OpenBmB MiniCPM视觉模型](https://github.com/OpenBMB/MiniCPM) | 中文<br>英文 | 3B-9B | chat模型 |
| CogVLM<br>CogAgent<br>CogVLM2<br>CogVLM2-Video<br>GLM4V | [智谱ChatGLM视觉问答和Agent模型](https://github.com/THUDM/) | 中文<br>英文 | 9B-19B | chat模型 |
| Llava1.5<br>Llava1.6 | [Llava系列模型](https://github.com/haotian-liu/LLaVA) | 英文 | 7B-34B | chat模型 |
| Llava-Next<br>Llava-Next-Video | [Llava-Next系列模型](https://github.com/LLaVA-VL/LLaVA-NeXT) | 中文<br>英文 | 7B-110B | chat模型 |
1 change: 1 addition & 0 deletions docs/source/LLM/支持的模型和数据集.md
@@ -394,6 +394,7 @@
|minicpm-v-3b-chat|[OpenBMB/MiniCPM-V](https://modelscope.cn/models/OpenBMB/MiniCPM-V/summary)|^(llm\|resampler)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|minicpm-v|&#x2714;|&#x2718;|&#x2718;|&#x2718;|timm, transformers<4.42|vision|[openbmb/MiniCPM-V](https://huggingface.co/openbmb/MiniCPM-V)|
|minicpm-v-v2-chat|[OpenBMB/MiniCPM-V-2](https://modelscope.cn/models/OpenBMB/MiniCPM-V-2/summary)|^(llm\|resampler)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|minicpm-v|&#x2714;|&#x2718;|&#x2718;|&#x2718;|timm, transformers<4.42|vision|[openbmb/MiniCPM-V-2](https://huggingface.co/openbmb/MiniCPM-V-2)|
|minicpm-v-v2_5-chat|[OpenBMB/MiniCPM-Llama3-V-2_5](https://modelscope.cn/models/OpenBMB/MiniCPM-Llama3-V-2_5/summary)|^(llm\|resampler)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|minicpm-v-v2_5|&#x2714;|&#x2718;|&#x2718;|&#x2718;|timm, transformers>=4.36|vision|[openbmb/MiniCPM-Llama3-V-2_5](https://huggingface.co/openbmb/MiniCPM-Llama3-V-2_5)|
|minicpm-v-v2_6-chat|[OpenBMB/MiniCPM-V-2_6](https://modelscope.cn/models/OpenBMB/MiniCPM-V-2_6/summary)|^(llm\|resampler)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|minicpm-v-v2_6|&#x2714;|&#x2718;|&#x2718;|&#x2718;|timm, transformers>=4.36, decord|vision|[openbmb/MiniCPM-V-2_6](https://huggingface.co/openbmb/MiniCPM-V-2_6)|
|mplug-owl2-chat|[iic/mPLUG-Owl2](https://modelscope.cn/models/iic/mPLUG-Owl2/summary)|q_proj, k_proj.multiway.0, k_proj.multiway.1, v_proj.multiway.0, v_proj.multiway.1|mplug-owl2|&#x2714;|&#x2718;|&#x2718;|&#x2718;|transformers<4.35, icecream|vision|[MAGAer13/mplug-owl2-llama2-7b](https://huggingface.co/MAGAer13/mplug-owl2-llama2-7b)|
|mplug-owl2_1-chat|[iic/mPLUG-Owl2.1](https://modelscope.cn/models/iic/mPLUG-Owl2.1/summary)|c_attn.multiway.0, c_attn.multiway.1|mplug-owl2|&#x2714;|&#x2718;|&#x2718;|&#x2718;|transformers<4.35, icecream|vision|[Mizukiluke/mplug_owl_2_1](https://huggingface.co/Mizukiluke/mplug_owl_2_1)|
|phi3-vision-128k-instruct|[LLM-Research/Phi-3-vision-128k-instruct](https://modelscope.cn/models/LLM-Research/Phi-3-vision-128k-instruct/summary)|^(model.layers\|model.vision_embed_tokens.img_projection)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|phi3-vl|&#x2714;|&#x2714;|&#x2718;|&#x2718;|transformers>=4.36|vision|[microsoft/Phi-3-vision-128k-instruct](https://huggingface.co/microsoft/Phi-3-vision-128k-instruct)|
1 change: 1 addition & 0 deletions docs/source_en/LLM/Supported-models-datasets.md
@@ -394,6 +394,7 @@ The table below introduces all models supported by SWIFT:
|minicpm-v-3b-chat|[OpenBMB/MiniCPM-V](https://modelscope.cn/models/OpenBMB/MiniCPM-V/summary)|^(llm\|resampler)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|minicpm-v|&#x2714;|&#x2718;|&#x2718;|&#x2718;|timm, transformers<4.42|vision|[openbmb/MiniCPM-V](https://huggingface.co/openbmb/MiniCPM-V)|
|minicpm-v-v2-chat|[OpenBMB/MiniCPM-V-2](https://modelscope.cn/models/OpenBMB/MiniCPM-V-2/summary)|^(llm\|resampler)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|minicpm-v|&#x2714;|&#x2718;|&#x2718;|&#x2718;|timm, transformers<4.42|vision|[openbmb/MiniCPM-V-2](https://huggingface.co/openbmb/MiniCPM-V-2)|
|minicpm-v-v2_5-chat|[OpenBMB/MiniCPM-Llama3-V-2_5](https://modelscope.cn/models/OpenBMB/MiniCPM-Llama3-V-2_5/summary)|^(llm\|resampler)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|minicpm-v-v2_5|&#x2714;|&#x2718;|&#x2718;|&#x2718;|timm, transformers>=4.36|vision|[openbmb/MiniCPM-Llama3-V-2_5](https://huggingface.co/openbmb/MiniCPM-Llama3-V-2_5)|
|minicpm-v-v2_6-chat|[OpenBMB/MiniCPM-V-2_6](https://modelscope.cn/models/OpenBMB/MiniCPM-V-2_6/summary)|^(llm\|resampler)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|minicpm-v-v2_6|&#x2714;|&#x2718;|&#x2718;|&#x2718;|timm, transformers>=4.36, decord|vision|[openbmb/MiniCPM-V-2_6](https://huggingface.co/openbmb/MiniCPM-V-2_6)|
|mplug-owl2-chat|[iic/mPLUG-Owl2](https://modelscope.cn/models/iic/mPLUG-Owl2/summary)|q_proj, k_proj.multiway.0, k_proj.multiway.1, v_proj.multiway.0, v_proj.multiway.1|mplug-owl2|&#x2714;|&#x2718;|&#x2718;|&#x2718;|transformers<4.35, icecream|vision|[MAGAer13/mplug-owl2-llama2-7b](https://huggingface.co/MAGAer13/mplug-owl2-llama2-7b)|
|mplug-owl2_1-chat|[iic/mPLUG-Owl2.1](https://modelscope.cn/models/iic/mPLUG-Owl2.1/summary)|c_attn.multiway.0, c_attn.multiway.1|mplug-owl2|&#x2714;|&#x2718;|&#x2718;|&#x2718;|transformers<4.35, icecream|vision|[Mizukiluke/mplug_owl_2_1](https://huggingface.co/Mizukiluke/mplug_owl_2_1)|
|phi3-vision-128k-instruct|[LLM-Research/Phi-3-vision-128k-instruct](https://modelscope.cn/models/LLM-Research/Phi-3-vision-128k-instruct/summary)|^(model.layers\|model.vision_embed_tokens.img_projection)(?!.\*(lm_head\|output\|emb\|wte\|shared)).\*|phi3-vl|&#x2714;|&#x2714;|&#x2718;|&#x2718;|transformers>=4.36|vision|[microsoft/Phi-3-vision-128k-instruct](https://huggingface.co/microsoft/Phi-3-vision-128k-instruct)|
2 changes: 1 addition & 1 deletion swift/llm/utils/dataset.py
@@ -532,7 +532,7 @@ def _preprocess_vision_dataset(dataset: HfDataset) -> HfDataset:
response_key = 'caption'

dataset._info.features._column_requires_decoding['image'] = False
query_format = f'Picture 1:<img>{{image_path}}</img>\n{prompt}'
query_format = f'<img>{{image_path}}</img>{prompt}'
query = []
response = []
for d in tqdm(dataset):
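The dataset.py change above drops the `Picture 1:` prefix and the trailing newline from the caption-dataset query template, so the query now carries only the bare `<img>...</img>` tag followed by the prompt. A tiny illustration of the resulting string (the prompt text and image path are placeholder values):

```python
# Illustration only; `prompt` and the image path are placeholders.
prompt = 'please describe the image.'
query_format = f'<img>{{image_path}}</img>{prompt}'
query = query_format.format(image_path='/data/coco/train2014/000000000009.jpg')
print(query)
# <img>/data/coco/train2014/000000000009.jpg</img>please describe the image.
```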
25 changes: 23 additions & 2 deletions swift/llm/utils/model.py
@@ -359,6 +359,7 @@ class ModelType:
minicpm_v_3b_chat = 'minicpm-v-3b-chat'
minicpm_v_v2_chat = 'minicpm-v-v2-chat'
minicpm_v_v2_5_chat = 'minicpm-v-v2_5-chat'
minicpm_v_v2_6_chat = 'minicpm-v-v2_6-chat'
# openbuddy
openbuddy_llama_65b_chat = 'openbuddy-llama-65b-chat'
openbuddy_llama2_13b_chat = 'openbuddy-llama2-13b-chat'
@@ -5705,6 +5706,17 @@ def get_model_tokenizer_minicpm_v(model_dir: str,
return model, tokenizer


@register_model(
ModelType.minicpm_v_v2_6_chat,
'OpenBMB/MiniCPM-V-2_6',
LoRATM.minicpm_v,
TemplateType.minicpm_v_v2_6,
support_flash_attn=True,
requires=['timm', 'transformers>=4.36', 'decord'],
placeholder_tokens=['<unk>'],
function_kwargs={'version': 'v2.6'},
tags=['multi-modal', 'vision'],
hf_model_id='openbmb/MiniCPM-V-2_6')
@register_model(
ModelType.minicpm_v_v2_5_chat,
'OpenBMB/MiniCPM-Llama3-V-2_5',
@@ -5715,13 +5727,17 @@ def get_model_tokenizer_minicpm_v(model_dir: str,
placeholder_tokens=['<unk>'],
tags=['multi-modal', 'vision'],
hf_model_id='openbmb/MiniCPM-Llama3-V-2_5')
def get_model_tokenizer_minicpm_v_2_5(model_dir: str,
def get_model_tokenizer_minicpm_v_2_x(model_dir: str,
torch_dtype: Dtype,
model_kwargs: Dict[str, Any],
load_model: bool = True,
**kwargs):
from transformers import AutoProcessor
processor = AutoProcessor.from_pretrained(model_dir, trust_remote_code=True)
version = kwargs.get('version', 'v2.5')
if version == 'v2.6':
model_cls = get_class_from_dynamic_module('modeling_navit_siglip.SiglipVisionTransformer', model_dir)
model_cls._no_split_modules = []
model, tokenizer = get_model_tokenizer_minicpm_v(model_dir, torch_dtype, model_kwargs, load_model, **kwargs)
tokenizer.processor = processor
if load_model:
@@ -6277,7 +6293,12 @@ def get_model_tokenizer(model_type: str,


def get_additional_saved_files(model_type: str) -> List[str]:
files_mapping = {'qwen-vl': ['SimSun.ttf'], 'qwen-audio': ['mel_filters.npz'], 'yi-vl': ['vit']}
files_mapping = {
'qwen-vl': ['SimSun.ttf'],
'qwen-audio': ['mel_filters.npz'],
'yi-vl': ['vit'],
'minicpm-v-v2_6-chat': ['modeling_navit_siglip.py']
}
for key, files_list in files_mapping.items():
if key in model_type:
return files_list
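The last hunk extends the additional-saved-files mapping so that exporting a minicpm-v-v2_6-chat checkpoint also copies the model's remote-code vision file. Below is a standalone restatement of that substring lookup with a worked call; the empty-list fallback is an assumption, since the diff does not show the function's tail.

```python
from typing import List

def get_additional_saved_files(model_type: str) -> List[str]:
    """Return extra files to copy alongside a saved checkpoint (sketch)."""
    files_mapping = {
        'qwen-vl': ['SimSun.ttf'],
        'qwen-audio': ['mel_filters.npz'],
        'yi-vl': ['vit'],
        'minicpm-v-v2_6-chat': ['modeling_navit_siglip.py'],
    }
    # The first key that occurs as a substring of model_type wins.
    for key, files_list in files_mapping.items():
        if key in model_type:
            return files_list
    return []  # assumed fallback; not shown in the diff

print(get_additional_saved_files('minicpm-v-v2_6-chat'))
# ['modeling_navit_siglip.py']
```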
123 changes: 120 additions & 3 deletions swift/llm/utils/template.py
@@ -91,6 +91,7 @@ class TemplateType:
minicpm = 'minicpm'
minicpm_v = 'minicpm-v'
minicpm_v_v2_5 = 'minicpm-v-v2_5'
minicpm_v_v2_6 = 'minicpm-v-v2_6'
gemma = 'gemma'
paligemma = 'paligemma'
mplug_owl2 = 'mplug-owl2'
@@ -2466,7 +2467,6 @@ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, An
if len(inputs) == 0:
return inputs, {}
images = example['images']
image = images[0]
input_ids = inputs['input_ids']
labels = inputs['labels']
idx_list = _findall(input_ids, -1)
@@ -2483,7 +2483,7 @@ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, An
pixel_values = to_device(image_inputs['pixel_values'], self.model.device)
tgt_sizes = image_inputs['tgt_sizes']
else:
images, placeholder = self.model.get_slice_image_placeholder(image, self.tokenizer)
images, placeholder = self.model.get_slice_image_placeholder(images[0], self.tokenizer)
pixel_values = [[self.model.transform(img).to(device=self.model.device) for img in images]]
placeholder += '\n'
placeholder_id = self.tokenizer.encode(placeholder, add_special_tokens=False)
@@ -2506,7 +2506,7 @@ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, An
if labels is not None:
labels = (labels[:idx] + [-100] * len(placeholder_id) + labels[idx + 1:])
image_bound = [torch.tensor([[idx, idx + config.query_num]])]
pixel_values = [[self.model.transform(image).to(device=self.model.device)]]
pixel_values = [[self.model.transform(images[0]).to(device=self.model.device)]]
data = {
'input_ids': torch.tensor(input_ids)[None].to(device=self.model.device),
'image_bound': image_bound,
@@ -2535,6 +2535,123 @@ def get_generate_ids(generate_ids: Tensor, input_token_len: int) -> List[int]:
dataloader_num_workers=0,
dataloader_pin_memory=False)


def _encode_video(video_path):
from .vision_utils import _read_video
mp4_stream = _read_video(video_path)
MAX_NUM_FRAMES = 64

from PIL import Image
from decord import VideoReader, cpu # pip install decord

def uniform_sample(_l, _n):
gap = len(_l) / _n
idxs = [int(i * gap + gap / 2) for i in range(_n)]
return [_l[i] for i in idxs]

vr = VideoReader(mp4_stream, ctx=cpu(0))
sample_fps = round(vr.get_avg_fps() / 1) # FPS
frame_idx = [i for i in range(0, len(vr), sample_fps)]

if len(frame_idx) > MAX_NUM_FRAMES:
frame_idx = uniform_sample(frame_idx, MAX_NUM_FRAMES)
frames = vr.get_batch(frame_idx).asnumpy()
frames = [Image.fromarray(v.astype('uint8')) for v in frames]
return frames
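Editorial note on the sampling above: frames are first taken at roughly one frame per second (`sample_fps` steps of the reader index), then capped at `MAX_NUM_FRAMES` by `uniform_sample`, which spreads `_n` picks evenly over `_l` using the midpoint of each gap. A standalone restatement with a worked example (values are illustrative):

```python
# Standalone restatement of uniform_sample with a worked example.
def uniform_sample(_l, _n):
    gap = len(_l) / _n
    idxs = [int(i * gap + gap / 2) for i in range(_n)]  # midpoint of each gap
    return [_l[i] for i in idxs]

print(uniform_sample(list(range(10)), 4))  # [1, 3, 6, 8]
```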


class MiniCPMV2_6Template(Template):

def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index, example) -> List[Context]:
assert media_type in {'image', 'video'}
return [[-1]]

def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
from .vision_utils import _read_batch
inputs, _ = super()._encode(example)
if len(inputs) == 0:
return inputs, {}
images = example.get('images')
videos_path = example.get('videos')
is_plain_text = not images and not videos_path
images = [images]
use_image_id = True
max_slice_nums = None

if videos_path:
images = _read_batch(videos_path, _encode_video)
use_image_id = False
max_slice_nums = 1 # or 2

input_ids = inputs['input_ids']
labels = inputs['labels']
idx_list = _findall(input_ids, -1)
idx_list.insert(0, -1)

from .utils import to_device
image_processor = self.tokenizer.processor.image_processor
image_inputs = image_processor(images, return_tensors='pt', max_slice_nums=max_slice_nums).to(self.model.dtype)
pixel_values = to_device(image_inputs['pixel_values'], self.model.device)
tgt_sizes = image_inputs['tgt_sizes']

res_input_ids = []
res_labels = []
for i in range(len(idx_list) - 1):
placeholder = image_processor.get_slice_image_placeholder(
image_inputs.image_sizes[0][i], image_idx=i, max_slice_nums=max_slice_nums, use_image_id=use_image_id)
placeholder += '\n'
placeholder_id = self.tokenizer.encode(placeholder, add_special_tokens=False)
res_input_ids += input_ids[idx_list[i] + 1:idx_list[i + 1]] + placeholder_id
if labels is not None:
res_labels += labels[idx_list[i] + 1:idx_list[i + 1]] + [-100] * len(placeholder_id)
res_input_ids += input_ids[idx_list[-1] + 1:]
input_ids = res_input_ids
if labels is not None:
res_labels += labels[idx_list[-1] + 1:]
labels = res_labels
if not is_plain_text:
input_tensor_ids = torch.tensor(input_ids)
unk_token = self.tokenizer.encode('<unk>', add_special_tokens=False)[0]
indices = (input_tensor_ids == unk_token).nonzero(as_tuple=True)[0].tolist()

ranges = []
start = indices[0]
for i in range(1, len(indices)):
if indices[i] != indices[i - 1] + 1:
ranges.append([start, indices[i - 1] + 1])
start = indices[i]
ranges.append([start, indices[-1] + 1])
image_bound = [torch.tensor(ranges)]
else:
image_bound = []

data = {
'input_ids': torch.tensor(input_ids)[None].to(device=self.model.device),
'image_bound': image_bound,
'pixel_values': pixel_values,
'tgt_sizes': tgt_sizes
}
inputs_embeds, _ = self.model.get_vllm_embedding(data)
inputs_embeds = inputs_embeds.detach()
inputs['input_ids'] = input_ids
inputs['labels'] = labels
inputs['inputs_embeds'] = inputs_embeds[0]
return inputs, {}

@staticmethod
def get_generate_ids(generate_ids: Tensor, input_token_len: int) -> List[int]:
return generate_ids[0].tolist()
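The `image_bound` construction in `_encode` above collapses runs of consecutive `<unk>` positions into half-open `[start, end)` ranges. A standalone restatement of that grouping with a worked example:

```python
import torch

def group_consecutive(indices):
    # Collapse runs of consecutive indices into half-open [start, end) ranges,
    # mirroring the image_bound construction in MiniCPMV2_6Template._encode.
    ranges = []
    start = indices[0]
    for i in range(1, len(indices)):
        if indices[i] != indices[i - 1] + 1:
            ranges.append([start, indices[i - 1] + 1])
            start = indices[i]
    ranges.append([start, indices[-1] + 1])
    return torch.tensor(ranges)

print(group_consecutive([5, 6, 7, 12, 13]))
# tensor([[ 5,  8],
#         [12, 14]])
```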


register_template(
TemplateType.minicpm_v_v2_6,
MiniCPMV2_6Template([], ['<|im_start|>user\n{{QUERY}}<|im_end|>\n<|im_start|>assistant\n'], ['<|im_end|>\n'],
['<|im_end|>'], DEFAULT_SYSTEM, ['<|im_start|>system\n{{SYSTEM}}<|im_end|>\n']),
use_model=True,
lazy_tokenize=True,
dataloader_num_workers=0,
dataloader_pin_memory=False)
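Pieced together, the prefix/prompt/suffix fragments registered above yield a ChatML-style prompt. An illustrative single-turn rendering follows; the system string is a placeholder standing in for whatever `DEFAULT_SYSTEM` resolves to in swift, and the exact assembly is handled internally by the template machinery.

```python
# Illustrative rendering only; `system` is a placeholder for swift's DEFAULT_SYSTEM.
system = 'You are a helpful assistant.'
query = '<image>What is in this picture?'

prompt = (f'<|im_start|>system\n{system}<|im_end|>\n'
          f'<|im_start|>user\n{query}<|im_end|>\n'
          f'<|im_start|>assistant\n')
print(prompt)
```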

register_template(
TemplateType.minicpm_v_v2_5,
MiniCPMVTemplate(['<|begin_of_text|>{{SYSTEM}}'], [
