Fix minicpm device map (modelscope#978)
Jintao-Huang committed May 21, 2024
1 parent 8c841d4 commit 54233a2
Showing 3 changed files with 53 additions and 2 deletions.
docs/source/Multi-Modal/minicpm-v-2.5最佳实践.md (4 changes: 3 additions & 1 deletion)
@@ -10,7 +10,9 @@

 ## Environment Setup
 ```shell
-pip install 'ms-swift[llm]' -U
+git clone https://github.com/modelscope/swift.git
+cd swift
+pip install -e '.[llm]'
 ```
 Model link:
 - minicpm-v-v2_5-chat: [https://modelscope.cn/models/OpenBMB/MiniCPM-Llama3-V-2_5/summary](https://modelscope.cn/models/OpenBMB/MiniCPM-Llama3-V-2_5/summary)
docs/source/Multi-Modal/minicpm-v-2最佳实践.md (4 changes: 3 additions & 1 deletion)
@@ -10,7 +10,9 @@

 ## Environment Setup
 ```shell
-pip install 'ms-swift[llm]' -U
+git clone https://github.com/modelscope/swift.git
+cd swift
+pip install -e '.[llm]'
 ```
 
 ## Inference
swift/llm/utils/model.py (47 changes: 47 additions & 0 deletions)
@@ -3870,6 +3870,52 @@ def get_model_tokenizer_yi_vl(model_dir: str,
     return model, tokenizer
 
 
+def _patch_minicpm_v_device_map(model) -> None:
+    if not hasattr(model, 'hf_device_map') or len(model.hf_device_map.values()) == 1:
+        return
+    if hasattr(model.llm, '__old_forward'):
+        # avoid double patching
+        return
+    device = list(model.hf_device_map.values())[0]
+    if hasattr(model, 'get_vision_embedding'):  # minicpm-v-v2-chat
+        _old_get_vision_embedding = model.get_vision_embedding
+
+        def _get_vision_embedding(pixel_values):
+            if len(pixel_values) == 0:
+                return _old_get_vision_embedding(pixel_values)
+            output = _old_get_vision_embedding(pixel_values)
+            return output.to(device=device)
+
+        model._old_get_vision_embedding = _old_get_vision_embedding
+        model.get_vision_embedding = _get_vision_embedding
+
+    if hasattr(model, 'resampler'):  # minicpm-v-v2_5-chat
+        __old_resampler_forward = model.resampler.forward
+
+        def _new_resampler_forward(*args, **kwargs) -> Tensor:
+            output = __old_resampler_forward(*args, **kwargs)
+            return output.to(device=device)
+
+        model.resampler.forward = _new_resampler_forward
+
+    __old_forward = model.llm.forward
+
+    def _new_forward(*args, **kwargs) -> Tensor:
+        inputs = kwargs.get('inputs_embeds')
+        if inputs is None:
+            inputs = kwargs.get('input_ids')
+        device = inputs.device
+        output = __old_forward(*args, **kwargs)
+        if output.logits is not None:
+            output.logits = output.logits.to(device)
+        if output.loss is not None:
+            output.loss = output.loss.to(device)
+        return output
+
+    model.llm.forward = _new_forward
+    model.llm.__old_forward = __old_forward
+
+
 @register_model(
     ModelType.minicpm_v_3b_chat,
     'OpenBMB/MiniCPM-V',
@@ -3904,6 +3950,7 @@ def get_model_tokenizer_minicpm_v(model_dir: str,
     model, tokenizer = get_model_tokenizer_with_flash_attn(model_dir, torch_dtype, model_kwargs, load_model, **kwargs)
     if load_model:
         model.resampler.to(torch_dtype)  # fix float32
+        _patch_minicpm_v_device_map(model)
         func_list = ['generate', 'get_input_embeddings', 'forward']
         _use_submodel_func(model, 'llm', func_list)
         if patching_embedding:
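For readers outside the diff context: with `device_map='auto'`, a multimodal model can be sharded across several GPUs, so the vision encoder's output may sit on a different device than the language model expects, and the returned logits may land on a different device than the generation loop's inputs. The patch above handles this by wrapping the relevant `forward` methods and moving their outputs onto a known device. Below is a minimal, standalone sketch of that wrap-and-move pattern; the helper name is illustrative and not part of swift's API:

```python
import torch
from torch import Tensor, nn


def _move_outputs_to(module: nn.Module, device: torch.device) -> None:
    """Wrap module.forward so its output tensor always lands on `device`."""
    if hasattr(module, '_old_forward'):  # avoid double patching
        return
    old_forward = module.forward

    def new_forward(*args, **kwargs) -> Tensor:
        # Run the original forward, then move the result to the target device.
        return old_forward(*args, **kwargs).to(device=device)

    module._old_forward = old_forward
    module.forward = new_forward


# Toy check on CPU: the wrapped layer's output is forced onto the target device.
layer = nn.Linear(4, 4)
_move_outputs_to(layer, torch.device('cpu'))
assert layer(torch.randn(2, 4)).device.type == 'cpu'
```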

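The second hunk shows that the patch is applied inside `get_model_tokenizer_minicpm_v`, so it should take effect transparently whenever an affected model is loaded through swift with a multi-GPU device map. A hedged usage sketch follows; the `ModelType` attribute name is inferred from the `minicpm-v-v2_5-chat` model id in the docs above and should be treated as an assumption:

```python
from swift.llm import ModelType, get_model_tokenizer

# Assumed model type name, inferred from 'minicpm-v-v2_5-chat' above.
# device_map='auto' is the sharded setup that the patch targets.
model, tokenizer = get_model_tokenizer(
    ModelType.minicpm_v_v2_5_chat,
    model_kwargs={'device_map': 'auto'})
```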