diff --git a/README.md b/README.md index 9654f70ff..e97107e01 100644 --- a/README.md +++ b/README.md @@ -47,6 +47,7 @@ SWIFT has rich documentations for users, please check [here](https://github.com/ SWIFT web-ui is available both on [Huggingface space](https://huggingface.co/spaces/tastelikefeet/swift) and [ModelScope studio](https://www.modelscope.cn/studios/iic/Scalable-lightWeight-Infrastructure-for-Fine-Tuning/summary), please feel free to try! ## 🎉 News +- 2024.07.02: Support for `llava1_6-vicuna-7b-chat`, `llava1_6-vicuna-13b-chat` and other llava-hf models. For best practices, refer to [here](docs/source_en/Multi-Modal/llava-best-practice.md). - 🔥2024.06.29: Support [eval-scope](https://github.com/modelscope/eval-scope)&[open-compass](https://github.com/open-compass/opencompass) for evaluation! Now we have supported over 50 eval datasets like `BoolQ, ocnli, humaneval, math, ceval, mmlu, gsk8k, ARC_e`, please check our [Eval Doc](https://github.com/modelscope/swift/blob/main/docs/source_en/LLM/LLM-eval.md) to begin! Next sprint we will support Multi-modal and Agent evaluation, remember to follow us : ) - 🔥2024.06.28: Support for **Florence** series model! See [document](docs/source_en/Multi-Modal/florence-best-pratice.md) - 🔥2024.06.28: Support for Gemma2 series models: gemma2-9b, gemma2-9b-instruct, gemma2-27b, gemma2-27b-instruct. diff --git a/README_CN.md b/README_CN.md index f7ba67219..09c481108 100644 --- a/README_CN.md +++ b/README_CN.md @@ -48,6 +48,7 @@ SWIFT具有丰富的文档体系,如有使用问题请请查看[这里](https: 可以在[Huggingface space](https://huggingface.co/spaces/tastelikefeet/swift) 和 [ModelScope创空间](https://www.modelscope.cn/studios/iic/Scalable-lightWeight-Infrastructure-for-Fine-Tuning/summary) 中体验SWIFT web-ui功能了。 ## 🎉 新闻 +- 2024.07.02: 支持`llava1_6-vicuna-7b-chat`, `llava1_6-vicuna-13b-chat`等llava-hf模型. 最佳实践可以查看[这里](docs/source/Multi-Modal/llava最佳实践.md). - 🔥2024.06.29: 支持[eval-scope](https://github.com/modelscope/eval-scope)&[open-compass](https://github.com/open-compass/opencompass)评测! 我们支持了包含`BoolQ, ocnli, humaneval, math, ceval, mmlu, gsk8k, ARC_e`等50+标准数据集在内的评测流程, 请查看我们的[评测文档](https://github.com/modelscope/swift/blob/main/docs/source/LLM/LLM评测文档.md)来使用。下个迭代我们会支持多模态评测和Agent评测,记得持续关注我们: ) - 🔥2024.06.28: 支持**Florence**系列模型: 可以查看[Florence最佳实践](docs/source/Multi-Modal/florence最佳实践.md). - 🔥2024.06.28: 支持**Gemma2**系列模型: gemma2-9b, gemma2-9b-instruct, gemma2-27b, gemma2-27b-instruct. 
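Before the per-file changes below, a quick smoke test of what the new news entry announces. This is a minimal sketch (not part of the patch itself); it only assumes the `ModelType` and `get_default_template_type` exports from `swift.llm` that the llava best-practice document further down already imports, and the expected template names in the comments come from the `register_model` calls added to `swift/llm/utils/model.py` in this patch.

```python
# Minimal sketch: list the llava-hf model types added by this patch and print the
# default template each one resolves to after registration.
from swift.llm import ModelType, get_default_template_type

new_llava_hf_model_types = [
    ModelType.llava1_5_13b_chat,         # expected template: 'llava1_5'
    ModelType.llava1_6_mistral_7b_chat,  # expected template: 'llava-mistral'
    ModelType.llava1_6_vicuna_7b_chat,   # expected template: 'llava-vicuna'
    ModelType.llava1_6_vicuna_13b_chat,  # expected template: 'llava-vicuna'
    ModelType.llava1_6_yi_34b_chat,      # expected template: 'llava-yi'
]
for model_type in new_llava_hf_model_types:
    print(f'{model_type}: {get_default_template_type(model_type)}')
```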
diff --git "a/docs/source/LLM/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md" "b/docs/source/LLM/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md" index db7bb8260..e8ab9de00 100644 --- "a/docs/source/LLM/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md" +++ "b/docs/source/LLM/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md" @@ -238,6 +238,7 @@ |mistral-7b-v2|[AI-ModelScope/Mistral-7B-v0.2-hf](https://modelscope.cn/models/AI-ModelScope/Mistral-7B-v0.2-hf/summary)|q_proj, k_proj, v_proj|default-generation|✔|✔|transformers>=4.34|-|[alpindale/Mistral-7B-v0.2-hf](https://huggingface.co/alpindale/Mistral-7B-v0.2-hf)| |mistral-7b-instruct|[AI-ModelScope/Mistral-7B-Instruct-v0.1](https://modelscope.cn/models/AI-ModelScope/Mistral-7B-Instruct-v0.1/summary)|q_proj, k_proj, v_proj|llama|✔|✔|transformers>=4.34|-|[mistralai/Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)| |mistral-7b-instruct-v2|[AI-ModelScope/Mistral-7B-Instruct-v0.2](https://modelscope.cn/models/AI-ModelScope/Mistral-7B-Instruct-v0.2/summary)|q_proj, k_proj, v_proj|llama|✔|✔|transformers>=4.34|-|[mistralai/Mistral-7B-Instruct-v0.2](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2)| +|mistral-7b-instruct-v3|[LLM-Research/Mistral-7B-Instruct-v0.3](https://modelscope.cn/models/LLM-Research/Mistral-7B-Instruct-v0.3/summary)|q_proj, k_proj, v_proj|llama|✔|✔|transformers>=4.34|-|[mistralai/Mistral-7B-Instruct-v0.3](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3)| |mixtral-moe-7b|[AI-ModelScope/Mixtral-8x7B-v0.1](https://modelscope.cn/models/AI-ModelScope/Mixtral-8x7B-v0.1/summary)|q_proj, k_proj, v_proj|default-generation|✔|✔|transformers>=4.36|-|[mistralai/Mixtral-8x7B-v0.1](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)| |mixtral-moe-7b-instruct|[AI-ModelScope/Mixtral-8x7B-Instruct-v0.1](https://modelscope.cn/models/AI-ModelScope/Mixtral-8x7B-Instruct-v0.1/summary)|q_proj, k_proj, v_proj|llama|✔|✔|transformers>=4.36|-|[mistralai/Mixtral-8x7B-Instruct-v0.1](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1)| |mixtral-moe-7b-aqlm-2bit-1x16|[AI-ModelScope/Mixtral-8x7b-AQLM-2Bit-1x16-hf](https://modelscope.cn/models/AI-ModelScope/Mixtral-8x7b-AQLM-2Bit-1x16-hf/summary)|q_proj, k_proj, v_proj|default-generation|✔|✘|transformers>=4.38, aqlm, torch>=2.2.0|-|[ISTA-DASLab/Mixtral-8x7b-AQLM-2Bit-1x16-hf](https://huggingface.co/ISTA-DASLab/Mixtral-8x7b-AQLM-2Bit-1x16-hf)| @@ -327,8 +328,11 @@ |qwen-audio-chat|[qwen/Qwen-Audio-Chat](https://modelscope.cn/models/qwen/Qwen-Audio-Chat/summary)|c_attn|qwen-audio|✔|✘||audio|[Qwen/Qwen-Audio-Chat](https://huggingface.co/Qwen/Qwen-Audio-Chat)| |glm4v-9b-chat|[ZhipuAI/glm-4v-9b](https://modelscope.cn/models/ZhipuAI/glm-4v-9b/summary)|self_attention.query_key_value|glm4v|✘|✘||vision|[THUDM/glm-4v-9b](https://huggingface.co/THUDM/glm-4v-9b)| |llava1_5-7b-chat|[huangjintao/llava-1.5-7b-hf](https://modelscope.cn/models/huangjintao/llava-1.5-7b-hf/summary)|q_proj, k_proj, v_proj|llava1_5|✔|✘|transformers>=4.36|vision|[llava-hf/llava-1.5-7b-hf](https://huggingface.co/llava-hf/llava-1.5-7b-hf)| -|llava1_6-mistral-7b-instruct|[AI-ModelScope/llava-v1.6-mistral-7b](https://modelscope.cn/models/AI-ModelScope/llava-v1.6-mistral-7b/summary)|q_proj, k_proj, 
v_proj|llava-mistral-instruct|✔|✘|transformers>=4.34|vision|[liuhaotian/llava-v1.6-mistral-7b](https://huggingface.co/liuhaotian/llava-v1.6-mistral-7b)| -|llava1_6-yi-34b-instruct|[AI-ModelScope/llava-v1.6-34b](https://modelscope.cn/models/AI-ModelScope/llava-v1.6-34b/summary)|q_proj, k_proj, v_proj|llava-yi-instruct|✔|✘||vision|[liuhaotian/llava-v1.6-34b](https://huggingface.co/liuhaotian/llava-v1.6-34b)| +|llava1_5-13b-chat|[huangjintao/llava-1.5-13b-hf](https://modelscope.cn/models/huangjintao/llava-1.5-13b-hf/summary)|q_proj, k_proj, v_proj|llava1_5|✔|✘|transformers>=4.36|vision|[llava-hf/llava-1.5-13b-hf](https://huggingface.co/llava-hf/llava-1.5-13b-hf)| +|llava1_6-mistral-7b-chat|[huangjintao/llava-v1.6-mistral-7b-hf](https://modelscope.cn/models/huangjintao/llava-v1.6-mistral-7b-hf/summary)|q_proj, k_proj, v_proj|llava-mistral|✔|✘|transformers>=4.36|vision|[llava-hf/llava-v1.6-mistral-7b-hf](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf)| +|llava1_6-vicuna-7b-chat|[huangjintao/llava-v1.6-vicuna-7b-hf](https://modelscope.cn/models/huangjintao/llava-v1.6-vicuna-7b-hf/summary)|q_proj, k_proj, v_proj|llava-vicuna|✔|✘|transformers>=4.36|vision|[llava-hf/llava-v1.6-vicuna-7b-hf](https://huggingface.co/llava-hf/llava-v1.6-vicuna-7b-hf)| +|llava1_6-vicuna-13b-chat|[huangjintao/llava-v1.6-vicuna-13b-hf](https://modelscope.cn/models/huangjintao/llava-v1.6-vicuna-13b-hf/summary)|q_proj, k_proj, v_proj|llava-vicuna|✔|✘|transformers>=4.36|vision|[llava-hf/llava-v1.6-vicuna-13b-hf](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf)| +|llava1_6-yi-34b-chat|[huangjintao/llava-v1.6-34b-hf](https://modelscope.cn/models/huangjintao/llava-v1.6-34b-hf/summary)|q_proj, k_proj, v_proj|llava-yi|✔|✘|transformers>=4.36|vision|[llava-hf/llava-v1.6-34b-hf](https://huggingface.co/llava-hf/llava-v1.6-34b-hf)| |llama3-llava-next-8b|[AI-Modelscope/llama3-llava-next-8b](https://modelscope.cn/models/AI-Modelscope/llama3-llava-next-8b/summary)|q_proj, k_proj, v_proj|llama-llava-next|✔|✘||vision|[lmms-lab/llama3-llava-next-8b](https://huggingface.co/lmms-lab/llama3-llava-next-8b)| |llava-next-72b|[AI-Modelscope/llava-next-72b](https://modelscope.cn/models/AI-Modelscope/llava-next-72b/summary)|q_proj, k_proj, v_proj|llava-qwen-instruct|✔|✘||vision|[lmms-lab/llava-next-72b](https://huggingface.co/lmms-lab/llava-next-72b)| |llava-next-110b|[AI-Modelscope/llava-next-110b](https://modelscope.cn/models/AI-Modelscope/llava-next-110b/summary)|q_proj, k_proj, v_proj|llava-qwen-instruct|✔|✘||vision|[lmms-lab/llava-next-110b](https://huggingface.co/lmms-lab/llava-next-110b)| diff --git "a/docs/source/Multi-Modal/llava\346\234\200\344\275\263\345\256\236\350\267\265.md" "b/docs/source/Multi-Modal/llava\346\234\200\344\275\263\345\256\236\350\267\265.md" index 9e5e9b7f0..ce400e2a6 100644 --- "a/docs/source/Multi-Modal/llava\346\234\200\344\275\263\345\256\236\350\267\265.md" +++ "b/docs/source/Multi-Modal/llava\346\234\200\344\275\263\345\256\236\350\267\265.md" @@ -1,16 +1,17 @@ - # Llava 最佳实践 -本篇文档对应的模型 +本篇文档涉及的模型如下: + +- [llava1_5-7b-chat](https://modelscope.cn/models/huangjintao/llava-1.5-7b-hf) +- [llava1_5-13b-chat](https://modelscope.cn/models/huangjintao/llava-1.5-13b-hf) +- [llava1_6-mistral-7b-chat](https://modelscope.cn/models/huangjintao/llava-v1.6-mistral-7b-hf) +- [llava1_6-vicuna-7b-chat](https://modelscope.cn/models/huangjintao/llava-v1.6-vicuna-7b-hf) +- [llava1_6-vicuna-13b-chat](https://modelscope.cn/models/huangjintao/llava-v1.6-vicuna-13b-hf) +- 
[llava1_6-yi-34b-chat](https://modelscope.cn/models/huangjintao/llava-v1.6-34b-hf) +- [llava-next-72b](https://modelscope.cn/models/AI-Modelscope/llava-next-72b) +- [llava-next-110b](https://modelscope.cn/models/AI-Modelscope/llava-next-110b) -| model | model_type | -|-------|------------| -| [llava-v1.6-mistral-7b](https://modelscope.cn/models/AI-ModelScope/llava-v1.6-mistral-7b/summary) | llava1_6-mistral-7b-instruct | -| [llava-v1.6-34b](https://www.modelscope.cn/models/AI-ModelScope/llava-v1.6-34b/summary) | llava1_6-yi-34b-instruct | -|[llama3-llava-next-8b](https://modelscope.cn/models/AI-ModelScope/llama3-llava-next-8b/summary)|llama3-llava-next-8b| -|[llava-next-72b](https://modelscope.cn/models/AI-ModelScope/llava-next-72b/summary)|llava-next-72b| -|[llava-next-110b](https://modelscope.cn/models/AI-ModelScope/llava-next-110b/summary)|llava-next-110b| -以下实践以`llava-v1.6-mistral-7b`为例,你也可以通过指定`--model_type`切换为其他模型 +以下实践以`llava1_6-mistral-7b-chat`为例,你也可以通过指定`--model_type`切换为其他模型. ## 目录 - [环境准备](#环境准备) @@ -30,13 +31,13 @@ pip install -e '.[llm]' ```shell # Experimental environment: A100 # 20GB GPU memory -CUDA_VISIBLE_DEVICES=0 swift infer --model_type llava1_6-mistral-7b-instruct +CUDA_VISIBLE_DEVICES=0 swift infer --model_type llava1_6-mistral-7b-chat # 70GB GPU memory -CUDA_VISIBLE_DEVICES=0 swift infer --model_type llava1_6-yi-34b-instruct +CUDA_VISIBLE_DEVICES=0 swift infer --model_type llava1_6-yi-34b-chat # 4*20GB GPU memory -CUDA_VISIBLE_DEVICES=0,1,2,3 swift infer --model_type llava1_6-yi-34b-instruct +CUDA_VISIBLE_DEVICES=0,1,2,3 swift infer --model_type llava1_6-yi-34b-chat ``` 输出: (支持传入本地路径或URL) @@ -54,9 +55,10 @@ The image shows a close-up of a kitten with a soft, blurred background that sugg Input a media path or URL <<< http://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/animal.png There are four sheep in the picture. -------------------------------------------------- +<<< clear <<< What is the calculation result? Input a media path or URL <<< http://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/math.png -The calculation result is 14352 + 45304 = 145304. +The calculation result is 1452 + 453004 = 453006. -------------------------------------------------- <<< Write a poem based on the content of the picture. Input a media path or URL <<< http://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/poem.png @@ -141,7 +143,7 @@ from swift.llm import ( from swift.utils import seed_everything import torch -model_type = 'llava1_6-mistral-7b-instruct' +model_type = 'llava1_6-mistral-7b-chat' template_type = get_default_template_type(model_type) print(f'template_type: {template_type}') @@ -198,13 +200,13 @@ LoRA微调: # Experimental environment: A10, 3090, V100... # 21GB GPU memory CUDA_VISIBLE_DEVICES=0 swift sft \ - --model_type llava1_6-mistral-7b-instruct \ + --model_type llava1_6-mistral-7b-chat\ --dataset coco-en-2-mini \ # Experimental environment: 2*A100... 
# 2*45GB GPU memory CUDA_VISIBLE_DEVICES=0,1 swift sft \ - --model_type llava1_6-yi-34b-instruct \ + --model_type llava1_6-yi-34b-chat \ --dataset coco-en-2-mini \ ``` @@ -213,14 +215,14 @@ CUDA_VISIBLE_DEVICES=0,1 swift sft \ # Experimental environment: 4 * A100 # 4 * 70 GPU memory NPROC_PER_NODE=4 CUDA_VISIBLE_DEVICES=0,1,2,3 swift sft \ - --model_type llava1_6-mistral-7b-instruct \ + --model_type llava1_6-mistral-7b-chat\ --dataset coco-en-2-mini \ --sft_type full \ --deepspeed default-zero2 # 8 * 50 GPU memory CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 swift sft \ - --model_type llava1_6-yi-34b-instruct \ + --model_type llava1_6-yi-34b-chat \ --dataset coco-en-2-mini \ --sft_type full \ ``` @@ -239,7 +241,7 @@ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 swift sft \ ## 微调后推理 直接推理: ```shell -model_type="llava1_6-mistral-7b-instruct" +model_type="llava1_6-mistral-7b-chat" CUDA_VISIBLE_DEVICES=0 swift infer \ --ckpt_dir output/${model_type}/vx-xxx/checkpoint-xxx \ @@ -248,7 +250,8 @@ CUDA_VISIBLE_DEVICES=0 swift infer \ **merge-lora**并推理: ```shell -model_type="llava1_6-mistral-7b-instruct" +model_type="llava1_6-mistral-7b-chat" + CUDA_VISIBLE_DEVICES=0 swift export \ --ckpt_dir "output/${model_type}/vx-xxx/checkpoint-xxx" \ --merge_lora true diff --git a/docs/source_en/LLM/Supported-models-datasets.md b/docs/source_en/LLM/Supported-models-datasets.md index e8c4f5a9c..d2e6f57f6 100644 --- a/docs/source_en/LLM/Supported-models-datasets.md +++ b/docs/source_en/LLM/Supported-models-datasets.md @@ -238,6 +238,7 @@ The table below introcudes all models supported by SWIFT: |mistral-7b-v2|[AI-ModelScope/Mistral-7B-v0.2-hf](https://modelscope.cn/models/AI-ModelScope/Mistral-7B-v0.2-hf/summary)|q_proj, k_proj, v_proj|default-generation|✔|✔|transformers>=4.34|-|[alpindale/Mistral-7B-v0.2-hf](https://huggingface.co/alpindale/Mistral-7B-v0.2-hf)| |mistral-7b-instruct|[AI-ModelScope/Mistral-7B-Instruct-v0.1](https://modelscope.cn/models/AI-ModelScope/Mistral-7B-Instruct-v0.1/summary)|q_proj, k_proj, v_proj|llama|✔|✔|transformers>=4.34|-|[mistralai/Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)| |mistral-7b-instruct-v2|[AI-ModelScope/Mistral-7B-Instruct-v0.2](https://modelscope.cn/models/AI-ModelScope/Mistral-7B-Instruct-v0.2/summary)|q_proj, k_proj, v_proj|llama|✔|✔|transformers>=4.34|-|[mistralai/Mistral-7B-Instruct-v0.2](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2)| +|mistral-7b-instruct-v3|[LLM-Research/Mistral-7B-Instruct-v0.3](https://modelscope.cn/models/LLM-Research/Mistral-7B-Instruct-v0.3/summary)|q_proj, k_proj, v_proj|llama|✔|✔|transformers>=4.34|-|[mistralai/Mistral-7B-Instruct-v0.3](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3)| |mixtral-moe-7b|[AI-ModelScope/Mixtral-8x7B-v0.1](https://modelscope.cn/models/AI-ModelScope/Mixtral-8x7B-v0.1/summary)|q_proj, k_proj, v_proj|default-generation|✔|✔|transformers>=4.36|-|[mistralai/Mixtral-8x7B-v0.1](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)| |mixtral-moe-7b-instruct|[AI-ModelScope/Mixtral-8x7B-Instruct-v0.1](https://modelscope.cn/models/AI-ModelScope/Mixtral-8x7B-Instruct-v0.1/summary)|q_proj, k_proj, v_proj|llama|✔|✔|transformers>=4.36|-|[mistralai/Mixtral-8x7B-Instruct-v0.1](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1)| |mixtral-moe-7b-aqlm-2bit-1x16|[AI-ModelScope/Mixtral-8x7b-AQLM-2Bit-1x16-hf](https://modelscope.cn/models/AI-ModelScope/Mixtral-8x7b-AQLM-2Bit-1x16-hf/summary)|q_proj, k_proj, v_proj|default-generation|✔|✘|transformers>=4.38, aqlm, 
torch>=2.2.0|-|[ISTA-DASLab/Mixtral-8x7b-AQLM-2Bit-1x16-hf](https://huggingface.co/ISTA-DASLab/Mixtral-8x7b-AQLM-2Bit-1x16-hf)| @@ -327,8 +328,11 @@ The table below introcudes all models supported by SWIFT: |qwen-audio-chat|[qwen/Qwen-Audio-Chat](https://modelscope.cn/models/qwen/Qwen-Audio-Chat/summary)|c_attn|qwen-audio|✔|✘||audio|[Qwen/Qwen-Audio-Chat](https://huggingface.co/Qwen/Qwen-Audio-Chat)| |glm4v-9b-chat|[ZhipuAI/glm-4v-9b](https://modelscope.cn/models/ZhipuAI/glm-4v-9b/summary)|self_attention.query_key_value|glm4v|✘|✘||vision|[THUDM/glm-4v-9b](https://huggingface.co/THUDM/glm-4v-9b)| |llava1_5-7b-chat|[huangjintao/llava-1.5-7b-hf](https://modelscope.cn/models/huangjintao/llava-1.5-7b-hf/summary)|q_proj, k_proj, v_proj|llava1_5|✔|✘|transformers>=4.36|vision|[llava-hf/llava-1.5-7b-hf](https://huggingface.co/llava-hf/llava-1.5-7b-hf)| -|llava1_6-mistral-7b-instruct|[AI-ModelScope/llava-v1.6-mistral-7b](https://modelscope.cn/models/AI-ModelScope/llava-v1.6-mistral-7b/summary)|q_proj, k_proj, v_proj|llava-mistral-instruct|✔|✘|transformers>=4.34|vision|[liuhaotian/llava-v1.6-mistral-7b](https://huggingface.co/liuhaotian/llava-v1.6-mistral-7b)| -|llava1_6-yi-34b-instruct|[AI-ModelScope/llava-v1.6-34b](https://modelscope.cn/models/AI-ModelScope/llava-v1.6-34b/summary)|q_proj, k_proj, v_proj|llava-yi-instruct|✔|✘||vision|[liuhaotian/llava-v1.6-34b](https://huggingface.co/liuhaotian/llava-v1.6-34b)| +|llava1_5-13b-chat|[huangjintao/llava-1.5-13b-hf](https://modelscope.cn/models/huangjintao/llava-1.5-13b-hf/summary)|q_proj, k_proj, v_proj|llava1_5|✔|✘|transformers>=4.36|vision|[llava-hf/llava-1.5-13b-hf](https://huggingface.co/llava-hf/llava-1.5-13b-hf)| +|llava1_6-mistral-7b-chat|[huangjintao/llava-v1.6-mistral-7b-hf](https://modelscope.cn/models/huangjintao/llava-v1.6-mistral-7b-hf/summary)|q_proj, k_proj, v_proj|llava-mistral|✔|✘|transformers>=4.36|vision|[llava-hf/llava-v1.6-mistral-7b-hf](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf)| +|llava1_6-vicuna-7b-chat|[huangjintao/llava-v1.6-vicuna-7b-hf](https://modelscope.cn/models/huangjintao/llava-v1.6-vicuna-7b-hf/summary)|q_proj, k_proj, v_proj|llava-vicuna|✔|✘|transformers>=4.36|vision|[llava-hf/llava-v1.6-vicuna-7b-hf](https://huggingface.co/llava-hf/llava-v1.6-vicuna-7b-hf)| +|llava1_6-vicuna-13b-chat|[huangjintao/llava-v1.6-vicuna-13b-hf](https://modelscope.cn/models/huangjintao/llava-v1.6-vicuna-13b-hf/summary)|q_proj, k_proj, v_proj|llava-vicuna|✔|✘|transformers>=4.36|vision|[llava-hf/llava-v1.6-vicuna-13b-hf](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf)| +|llava1_6-yi-34b-chat|[huangjintao/llava-v1.6-34b-hf](https://modelscope.cn/models/huangjintao/llava-v1.6-34b-hf/summary)|q_proj, k_proj, v_proj|llava-yi|✔|✘|transformers>=4.36|vision|[llava-hf/llava-v1.6-34b-hf](https://huggingface.co/llava-hf/llava-v1.6-34b-hf)| |llama3-llava-next-8b|[AI-Modelscope/llama3-llava-next-8b](https://modelscope.cn/models/AI-Modelscope/llama3-llava-next-8b/summary)|q_proj, k_proj, v_proj|llama-llava-next|✔|✘||vision|[lmms-lab/llama3-llava-next-8b](https://huggingface.co/lmms-lab/llama3-llava-next-8b)| |llava-next-72b|[AI-Modelscope/llava-next-72b](https://modelscope.cn/models/AI-Modelscope/llava-next-72b/summary)|q_proj, k_proj, v_proj|llava-qwen-instruct|✔|✘||vision|[lmms-lab/llava-next-72b](https://huggingface.co/lmms-lab/llava-next-72b)| |llava-next-110b|[AI-Modelscope/llava-next-110b](https://modelscope.cn/models/AI-Modelscope/llava-next-110b/summary)|q_proj, k_proj, 
v_proj|llava-qwen-instruct|✔|✘||vision|[lmms-lab/llava-next-110b](https://huggingface.co/lmms-lab/llava-next-110b)| diff --git a/docs/source_en/Multi-Modal/llava-best-practice.md b/docs/source_en/Multi-Modal/llava-best-practice.md index 8620ed355..b20520845 100644 --- a/docs/source_en/Multi-Modal/llava-best-practice.md +++ b/docs/source_en/Multi-Modal/llava-best-practice.md @@ -1,15 +1,16 @@ # Llava Best Practice -The document corresponds to the following models +The document corresponds to the following models: -| model | model_type | -|-------|------------| -| [llava-v1.6-mistral-7b](https://modelscope.cn/models/AI-ModelScope/llava-v1.6-mistral-7b/summary) | llava1_6-mistral-7b-instruct | -| [llava-v1.6-34b](https://www.modelscope.cn/models/AI-ModelScope/llava-v1.6-34b/summary) | llava1_6-yi-34b-instruct | -|[llama3-llava-next-8b](https://modelscope.cn/models/AI-ModelScope/llama3-llava-next-8b/summary)|llama3-llava-next-8b| -|[llava-next-72b](https://modelscope.cn/models/AI-ModelScope/llava-next-72b/summary)|llava-next-72b| -|[llava-next-110b](https://modelscope.cn/models/AI-ModelScope/llava-next-110b/summary)|llava-next-110b| +- [llava1_5-7b-chat](https://modelscope.cn/models/huangjintao/llava-1.5-7b-hf) +- [llava1_5-13b-chat](https://modelscope.cn/models/huangjintao/llava-1.5-13b-hf) +- [llava1_6-mistral-7b-chat](https://modelscope.cn/models/huangjintao/llava-v1.6-mistral-7b-hf) +- [llava1_6-vicuna-7b-chat](https://modelscope.cn/models/huangjintao/llava-v1.6-vicuna-7b-hf) +- [llava1_6-vicuna-13b-chat](https://modelscope.cn/models/huangjintao/llava-v1.6-vicuna-13b-hf) +- [llava1_6-yi-34b-chat](https://modelscope.cn/models/huangjintao/llava-v1.6-34b-hf) +- [llava-next-72b](https://modelscope.cn/models/AI-Modelscope/llava-next-72b) +- [llava-next-110b](https://modelscope.cn/models/AI-Modelscope/llava-next-110b) -The following practices take `llava-v1.6-mistral-7b` as an example. You can also switch to other models by specifying the `--model_type`. +The following practice takes `llava1_6-mistral-7b-chat` as an example, and you can also switch to other models by specifying `--model_type`. ## Table of Contents @@ -29,13 +30,13 @@ pip install -e '.[llm]' ```shell # Experimental environment: A100 # 20GB GPU memory -CUDA_VISIBLE_DEVICES=0 swift infer --model_type llava1_6-mistral-7b-instruct +CUDA_VISIBLE_DEVICES=0 swift infer --model_type llava1_6-mistral-7b-chat # 70GB GPU memory -CUDA_VISIBLE_DEVICES=0 swift infer --model_type llava1_6-yi-34b-instruct +CUDA_VISIBLE_DEVICES=0 swift infer --model_type llava1_6-yi-34b-chat # 4*20GB GPU memory -CUDA_VISIBLE_DEVICES=0,1,2,3 swift infer --model_type llava1_6-yi-34b-instruct +CUDA_VISIBLE_DEVICES=0,1,2,3 swift infer --model_type llava1_6-yi-34b-chat ``` Output: (supports passing in local path or URL) @@ -49,9 +50,10 @@ The image shows a close-up of a kitten with a soft, blurred background that sugg Input a media path or URL <<< http://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/animal.png There are four sheep in the picture. -------------------------------------------------- +<<< clear <<< What is the calculation result? Input a media path or URL <<< http://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/math.png -The calculation result is 14352 + 45304 = 145304. +The calculation result is 1452 + 453004 = 453006. -------------------------------------------------- <<< Write a poem based on the content of the picture. 
Input a media path or URL <<< http://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/poem.png @@ -84,6 +86,20 @@ The boat, a symbol of solitude, In the vast expanse of the universe's beauty, A lone journey, a solitary quest, In the quiet of the night, it finds its rest. +-------------------------------------------------- +<<< Perform OCR on the image. +Input a media path or URL <<< https://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/ocr_en.png +The text in the image is as follows: + +INTRODUCTION + +SWIFT supports training, inference, evaluation and deployment of 250+ LLMs (multimodal large models). Developers can directly apply our framework to their own research and production environments to realize the complete workflow from model training and evaluation to application. In addition, SWIFT provides a complete Adapters library to support the latest training techniques such as NLP, Vision, etc. This adapter library can be used directly in your own custom workflow without our training scripts. + +To facilitate use by users unfamiliar with deep learning, we provide a Grado web-ui for controlling training and inference, as well as accompanying deep learning courses and best practices for beginners. + +SWIFT has rich documentation for users, please check here. + +SWIFT is web-ui available both on Huggingface space and ModelScope studio, please feel free to try! """ ``` @@ -118,7 +134,7 @@ from swift.llm import ( from swift.utils import seed_everything import torch -model_type = 'llava1_6-mistral-7b-instruct' +model_type = 'llava1_6-mistral-7b-chat' template_type = get_default_template_type(model_type) print(f'template_type: {template_type}') @@ -175,12 +191,12 @@ LoRA fine-tuning: # Experimental environment: A10, 3090, V100... # 21GB GPU memory CUDA_VISIBLE_DEVICES=0 swift sft \ - --model_type llava1_6-mistral-7b-instruct \ + --model_type llava1_6-mistral-7b-chat \ --dataset coco-en-2-mini \ # 2*45GB GPU memory CUDA_VISIBLE_DEVICES=0,1 swift sft \ - --model_type llava1_6-yi-34b-instruct \ + --model_type llava1_6-yi-34b-chat \ --dataset coco-en-2-mini \ ``` @@ -189,14 +205,14 @@ Full parameter fine-tuning: # Experimental environment: 4 * A100 # 4 * 70 GPU memory NPROC_PER_NODE=4 CUDA_VISIBLE_DEVICES=0,1,2,3 swift sft \ - --model_type llava1_6-mistral-7b-instruct \ + --model_type llava1_6-mistral-7b-chat \ --dataset coco-en-2-mini \ --sft_type full \ --deepspeed default-zero2 # 8 * 50 GPU memory CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 swift sft \ - --model_type llava1_6-yi-34b-instruct \ + --model_type llava1_6-yi-34b-chat \ --dataset coco-en-2-mini \ --sft_type full \ ``` @@ -215,7 +231,7 @@ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 swift sft \ ## Inference after Fine-tuning Direct inference: ```shell -model_type="llava1_6-mistral-7b-instruct" +model_type="llava1_6-mistral-7b-chat" CUDA_VISIBLE_DEVICES=0 swift infer \ --ckpt_dir output/${model_type}/vx-xxx/checkpoint-xxx \ --load_dataset_config true @@ -223,7 +239,7 @@ CUDA_VISIBLE_DEVICES=0 swift infer \ **merge-lora** and inference: ```shell -model_type="llava1_6-mistral-7b-instruct" +model_type="llava1_6-mistral-7b-chat" CUDA_VISIBLE_DEVICES=0 swift export \ --ckpt_dir "output/${model_type}/vx-xxx/checkpoint-xxx" \ --merge_lora true diff --git a/swift/llm/app_ui.py b/swift/llm/app_ui.py index f3627b59e..1b5844b7d 100644 --- a/swift/llm/app_ui.py +++ b/swift/llm/app_ui.py @@ -76,10 +76,11 @@ def model_chat(query: str, history: History) -> Iterator[Tuple[str, History]]: gr.Markdown(f'
<center><font size=8>{model_name} Bot</font></center>
') chatbot = gr.Chatbot(label=f'{model_name}') - message = gr.Textbox(lines=2, label='Input') + message = gr.Textbox(lines=1, label='Input') with gr.Row(): clear_history = gr.Button('🧹 清除历史对话') send = gr.Button('🚀 发送') + message.submit(model_chat, inputs=[message, chatbot], outputs=[message, chatbot]) send.click(model_chat, inputs=[message, chatbot], outputs=[message, chatbot]) clear_history.click(fn=clear_session, inputs=[], outputs=[chatbot], queue=False) # Compatible with InferArguments diff --git a/swift/llm/tuner.py b/swift/llm/tuner.py index dcf113ea3..35da42594 100644 --- a/swift/llm/tuner.py +++ b/swift/llm/tuner.py @@ -242,6 +242,9 @@ def prepare_model(model, args: SftArguments): is_logging = True p.data = p.data.to(dtype=torch.float32) elif args.sft_type == 'full': + model.train() + model.requires_grad_(True) + if args.freeze_parameters > 0: freeze_model_parameters(model, args.freeze_parameters) if len(args.additional_trainable_parameters) > 0: diff --git a/swift/llm/utils/argument.py b/swift/llm/utils/argument.py index f0ddcd8a5..042aa687f 100644 --- a/swift/llm/utils/argument.py +++ b/swift/llm/utils/argument.py @@ -56,6 +56,12 @@ def _check_path(cls, k: str, value: Union[str, List[str]], value = res return value + @staticmethod + def _is_multimodal(model_type: str) -> bool: + model_info = MODEL_MAPPING[model_type] + tags = model_info.get('tags') or [] + return 'multi-modal' in tags + def handle_path(self: Union['SftArguments', 'InferArguments']) -> None: check_exist_path = ['ckpt_dir', 'resume_from_checkpoint', 'custom_register_path'] maybe_check_exist_path = ['model_id_or_path', 'custom_dataset_info'] @@ -181,8 +187,10 @@ def handle_compatibility(self: Union['SftArguments', 'InferArguments']) -> None: 'cogvlm-17b-instruct': 'cogvlm-17b-chat', 'minicpm-v-v2': 'minicpm-v-v2-chat', 'mplug-owl2d1-chat': 'mplug-owl2_1-chat', - 'llava1d6-mistral-7b-instruct': 'llava1_6-mistral-7b-instruct', - 'llava1d6-yi-34b-instruct': 'llava1_6-yi-34b-instruct', + 'llava1d6-mistral-7b-instruct': 'llava1_6-mistral-7b-chat', + 'llava1d6-yi-34b-instruct': 'llava1_6-yi-34b-chat', + 'llava1_6-mistral-7b-instruct': 'llava1_6-mistral-7b-chat', + 'llava1_6-yi-34b-instruct': 'llava1_6-yi-34b-chat', } dataset_name_mapping = { 'ms-bench-mini': 'ms-bench#20000', @@ -775,6 +783,7 @@ def __post_init__(self) -> None: self.set_model_type() self.check_flash_attn() self.handle_generation_config() + self.is_multimodal = self._is_multimodal(self.model_type) self.lora_use_embedding = False self.lora_use_all = False @@ -1157,6 +1166,7 @@ def __post_init__(self) -> None: self.set_model_type() self.check_flash_attn() self.handle_generation_config() + self.is_multimodal = self._is_multimodal(self.model_type) self.torch_dtype, _, _ = self.select_dtype() self.prepare_template() @@ -1298,9 +1308,6 @@ class DeployArguments(InferArguments): def __post_init__(self): super().__post_init__() - model_info = MODEL_MAPPING[self.model_type] - tags = model_info.get('tags') or [] - self.is_multimodal = 'multi-modal' in tags @dataclass diff --git a/swift/llm/utils/model.py b/swift/llm/utils/model.py index 4aabfb364..6eb2141af 100644 --- a/swift/llm/utils/model.py +++ b/swift/llm/utils/model.py @@ -32,7 +32,7 @@ logger = get_logger() -# Model Home: 'https://modelscope.cn/models/{model_id_or_path}/summary' +# Model Home: 'https://modelscope.cn/models/{model_id_or_path}' MODEL_MAPPING: Dict[str, Dict[str, Any]] = {} @@ -190,8 +190,11 @@ class ModelType: atom_7b_chat = 'atom-7b-chat' # llava llava1_5_7b_chat = 'llava1_5-7b-chat' - 
llava1_6_mistral_7b_instruct = 'llava1_6-mistral-7b-instruct' - llava1_6_yi_34b_instruct = 'llava1_6-yi-34b-instruct' + llava1_5_13b_chat = 'llava1_5-13b-chat' + llava1_6_mistral_7b_chat = 'llava1_6-mistral-7b-chat' + llava1_6_vicuna_7b_chat = 'llava1_6-vicuna-7b-chat' + llava1_6_vicuna_13b_chat = 'llava1_6-vicuna-13b-chat' + llava1_6_yi_34b_chat = 'llava1_6-yi-34b-chat' llama3_llava_next_8b = 'llama3-llava-next-8b' llava_next_72b = 'llava-next-72b' llava_next_110b = 'llava-next-110b' @@ -326,6 +329,7 @@ class ModelType: mistral_7b_v2 = 'mistral-7b-v2' mistral_7b_instruct = 'mistral-7b-instruct' mistral_7b_instruct_v2 = 'mistral-7b-instruct-v2' + mistral_7b_instruct_v3 = 'mistral-7b-instruct-v3' mixtral_moe_7b = 'mixtral-moe-7b' mixtral_moe_7b_instruct = 'mixtral-moe-7b-instruct' mixtral_moe_7b_aqlm_2bit_1x16 = 'mixtral-moe-7b-aqlm-2bit-1x16' # aqlm @@ -925,9 +929,6 @@ def get_model_tokenizer_from_repo(model_dir: str, with context: model = automodel_class.from_pretrained( model_dir, config=model_config, torch_dtype=torch_dtype, trust_remote_code=True, **model_kwargs) - if is_training: - model.train() - model.requires_grad_(True) model.is_gptq = is_gptq model.is_awq = is_awq model.is_aqlm = is_aqlm @@ -2336,6 +2337,16 @@ def _output_device_map_hook(module, input, output): support_flash_attn=True, support_vllm=True, hf_model_id='mistralai/Mistral-7B-Instruct-v0.2') +@register_model( + ModelType.mistral_7b_instruct_v3, + 'LLM-Research/Mistral-7B-Instruct-v0.3', + LoRATM.llama, + TemplateType.llama, + ignore_file_pattern=['consolidated.safetensors'], + requires=['transformers>=4.34'], + support_flash_attn=True, + support_vllm=True, + hf_model_id='mistralai/Mistral-7B-Instruct-v0.3') @register_model( ModelType.mistral_7b, 'AI-ModelScope/Mistral-7B-v0.1', @@ -4897,6 +4908,24 @@ def _new_generate(inputs=None, *args, **kwargs): model.generate = _new_generate +def get_model_tokenizer_llava_hf(model_dir: str, *args, **kwargs): + from transformers import AutoProcessor + processor = AutoProcessor.from_pretrained(model_dir) + model, tokenizer = get_model_tokenizer_with_flash_attn(model_dir, *args, **kwargs) + tokenizer.processor = processor + return model, tokenizer + + +@register_model( + ModelType.llava1_5_13b_chat, + 'huangjintao/llava-1.5-13b-hf', + LoRATM.llama, + TemplateType.llava1_5, + eos_token='', + support_flash_attn=True, + requires=['transformers>=4.36'], + tags=['multi-modal', 'vision'], + hf_model_id='llava-hf/llava-1.5-13b-hf') @register_model( ModelType.llava1_5_7b_chat, 'huangjintao/llava-1.5-7b-hf', @@ -4907,35 +4936,62 @@ def _new_generate(inputs=None, *args, **kwargs): requires=['transformers>=4.36'], tags=['multi-modal', 'vision'], hf_model_id='llava-hf/llava-1.5-7b-hf') -def get_model_tokenizer_llava1_5(model_dir: str, *args, **kwargs): - from transformers import AutoProcessor, LlavaForConditionalGeneration - processor = AutoProcessor.from_pretrained(model_dir) - model, tokenizer = get_model_tokenizer_with_flash_attn( - model_dir, *args, automodel_class=LlavaForConditionalGeneration, **kwargs) - tokenizer.processor = processor - return model, tokenizer +def get_model_tokenizer_llava_1_5(*args, **kwargs): + from transformers import LlavaForConditionalGeneration + kwargs['automodel_class'] = LlavaForConditionalGeneration + return get_model_tokenizer_llava_hf(*args, **kwargs) @register_model( - ModelType.llava1_6_yi_34b_instruct, - 'AI-ModelScope/llava-v1.6-34b', + ModelType.llava1_6_vicuna_7b_chat, + 'huangjintao/llava-v1.6-vicuna-7b-hf', LoRATM.llama, - 
TemplateType.llava_yi_instruct, - eos_token='<|im_end|>', + TemplateType.llava_vicuna, support_flash_attn=True, - function_kwargs={'llm_model_type': 'llama'}, + requires=['transformers>=4.36'], tags=['multi-modal', 'vision'], - hf_model_id='liuhaotian/llava-v1.6-34b') + hf_model_id='llava-hf/llava-v1.6-vicuna-7b-hf') @register_model( - ModelType.llava1_6_mistral_7b_instruct, - 'AI-ModelScope/llava-v1.6-mistral-7b', + ModelType.llava1_6_vicuna_13b_chat, + 'huangjintao/llava-v1.6-vicuna-13b-hf', LoRATM.llama, - TemplateType.llava_mistral_instruct, - requires=['transformers>=4.34'], + TemplateType.llava_vicuna, + support_flash_attn=True, + requires=['transformers>=4.36'], + tags=['multi-modal', 'vision'], + hf_model_id='llava-hf/llava-v1.6-vicuna-13b-hf') +@register_model( + ModelType.llava1_6_mistral_7b_chat, + 'huangjintao/llava-v1.6-mistral-7b-hf', + LoRATM.llama, + TemplateType.llava_mistral, + support_flash_attn=True, + requires=['transformers>=4.36'], + tags=['multi-modal', 'vision'], + hf_model_id='llava-hf/llava-v1.6-mistral-7b-hf') +def get_model_tokenizer_llava_next(*args, **kwargs): + from transformers import LlavaNextForConditionalGeneration + kwargs['automodel_class'] = LlavaNextForConditionalGeneration + return get_model_tokenizer_llava_hf(*args, **kwargs) + + +@register_model( + ModelType.llava1_6_yi_34b_chat, + 'huangjintao/llava-v1.6-34b-hf', + LoRATM.llama, + TemplateType.llava_yi, support_flash_attn=True, - function_kwargs={'llm_model_type': 'mistral'}, + eos_token='<|im_end|>', + requires=['transformers>=4.36'], tags=['multi-modal', 'vision'], - hf_model_id='liuhaotian/llava-v1.6-mistral-7b') + hf_model_id='llava-hf/llava-v1.6-34b-hf') +def get_model_tokenizer_llava_next_yi(*args, **kwargs): + model, tokenizer = get_model_tokenizer_llava_next(*args, **kwargs) + if model is not None: + model.config.image_token_index = 64003 + return model, tokenizer + + @register_model( ModelType.llama3_llava_next_8b, 'AI-Modelscope/llama3-llava-next-8b', diff --git a/swift/llm/utils/template.py b/swift/llm/utils/template.py index fc72a91c3..4880eba70 100644 --- a/swift/llm/utils/template.py +++ b/swift/llm/utils/template.py @@ -41,8 +41,9 @@ class TemplateType: llama = 'llama' # llama2 llama3 = 'llama3' llava1_5 = 'llava1_5' - llava_mistral_instruct = 'llava-mistral-instruct' - llava_yi_instruct = 'llava-yi-instruct' + llava_mistral = 'llava-mistral' + llava_vicuna = 'llava-vicuna' + llava_yi = 'llava-yi' llava_llama_instruct = 'llava-llama-instruct' llava_qwen_instruct = 'llava-qwen-instruct' llama_llava_next = 'llama-llava-next' @@ -643,6 +644,10 @@ def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] = if len(pixel_values) > 0: res['pixel_values'] = torch.concat(pixel_values) + image_sizes = [b['image_sizes'] for b in batch if b.get('image_sizes') is not None] + if len(image_sizes) > 0: + res['image_sizes'] = torch.concat(image_sizes) + if loss_scale is not None: res['loss_scale'] = loss_scale return res @@ -973,10 +978,7 @@ def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any placeholder_id = self.tokenizer.encode(placeholder, add_special_tokens=False) input_ids = (input_ids[:idx] + placeholder_id + input_ids[idx + 1:]) if labels is not None: - image_size: int = self.model.config.vision_config['image_size'] - patch_size: int = self.model.config.vision_config['patch_size'] - num_patches = (image_size // patch_size // 2)**2 - labels = (labels[:idx] + [-100] * (len(placeholder_id) + num_patches - 1) + labels[idx + 1:]) + labels = 
(labels[:idx] + [-100] * len(placeholder_id) + labels[idx + 1:]) messages = history_to_messages(example.get('history') or [], example['query'], example.get('system')) messages[0]['image'] = image inputs2: Dict[str, Any] = self.tokenizer.apply_chat_template(messages, return_dict=True) @@ -1445,10 +1447,7 @@ def post_process_generate_response(self, response, example): 'and other non-computer science questions, you will refuse to answer\n'))) -class Llava1_5Template(Template): - - def __init__(self): - super().__init__([''], ['USER: {{QUERY}}\nASSISTANT:'], ['\n'], ['']) +class LlavaHfTemplate(Template): def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index, example) -> List[Context]: assert media_type == 'image' @@ -1462,10 +1461,19 @@ def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any images = _read_batch(images_path) image_processor = self.tokenizer.processor.image_processor if images: - inputs['pixel_values'] = image_processor(images, return_tensors='pt')['pixel_values'].to(self.model.dtype) + image_inputs = image_processor(images, return_tensors='pt').to(self.model.dtype) + inputs['pixel_values'] = image_inputs['pixel_values'] + if 'image_sizes' in image_inputs: + inputs['image_sizes'] = image_inputs['image_sizes'] return inputs, {} +class Llava1_5Template(LlavaHfTemplate): + + def __init__(self): + super().__init__([''], ['USER: {{QUERY}}\nASSISTANT:'], [''], ['']) + + register_template( TemplateType.llava1_5, Llava1_5Template(), use_model=True, infer_media_type='round', lazy_tokenize=True) @@ -1518,19 +1526,40 @@ def get_generate_ids(generate_ids: Tensor, input_token_len: int) -> List[int]: return generate_ids[0].tolist() +class Llava1_6MistralTemplate(LlavaHfTemplate): + + def __init__(self): + super().__init__(['[INST] '], ['{{QUERY}} [/INST]'], [''], [''], + system_prefix=['<>\n{{system}}\n<>\n\n']) + + +class Llava1_6VicunaTemplate(LlavaHfTemplate): + system = ('A chat between a curious human and an artificial intelligence assistant. 
' + "The assistant gives helpful, detailed, and polite answers to the human's questions.") + + def __init__(self): + super().__init__([''], ['USER: {{QUERY}} ASSISTANT:'], [''], [''], + self.system, + system_prefix=['{{SYSTEM}} ']) + + +register_template( + TemplateType.llava_mistral, Llava1_6MistralTemplate(), use_model=True, infer_media_type='round', lazy_tokenize=True) + register_template( - TemplateType.llava_mistral_instruct, LLavaTemplate(), use_model=True, infer_media_type='round', lazy_tokenize=True) + TemplateType.llava_vicuna, Llava1_6VicunaTemplate(), use_model=True, infer_media_type='round', lazy_tokenize=True) -class LLavaYiTemplate(LLavaTemplate): - llavayi_query_template = '\n<|im_start|>user\n{{QUERY}}<|im_end|>\n<|im_start|>assistant\n' +class LLavaYiTemplate(LlavaHfTemplate): def __init__(self): - Template.__init__(self, [], [self.llavayi_query_template], None, ['<|im_end|>']) + super().__init__([], ['<|im_start|>user\n{{QUERY}}<|im_end|><|im_start|>assistant\n'], ['<|im_end|>'], + ['<|im_end|>'], + system_prefix=['<|im_start|>system\n{{SYSTEM}}<|im_end|>']) register_template( - TemplateType.llava_yi_instruct, LLavaYiTemplate(), use_model=True, infer_media_type='round', lazy_tokenize=True) + TemplateType.llava_yi, LLavaYiTemplate(), use_model=True, infer_media_type='round', lazy_tokenize=True) class LLavaLlamaTemplate(Template): @@ -1651,12 +1680,6 @@ def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any inputs['labels'] = labels return inputs, {} - def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] = None) -> Dict[str, Any]: - res = super().data_collator(batch, padding_to) - if 'pixel_values' in res: - res['image_sizes'] = torch.concat([b['image_sizes'] for b in batch if 'image_sizes' in b]) - return res - register_template(TemplateType.phi3_vl, Phi3VisionTemplate(), lazy_tokenize=True) @@ -2033,7 +2056,7 @@ def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any images_path = example.get('images') or [] images = _read_batch(images_path) for i, image in enumerate(images): - # ref: https://modelscope.cn/models/iic/mPLUG-Owl2.1/summary + # ref: https://modelscope.cn/models/iic/mPLUG-Owl2.1 max_edge = max(image.size) image = image.resize((max_edge, max_edge)) images[i] = image diff --git a/swift/llm/utils/vllm_utils.py b/swift/llm/utils/vllm_utils.py index bc476efac..71d9f1c0e 100644 --- a/swift/llm/utils/vllm_utils.py +++ b/swift/llm/utils/vllm_utils.py @@ -102,6 +102,7 @@ def get_vllm_engine( _engine = llm_engine.engine else: _engine = llm_engine + llm_engine.dtype = _engine.model_config.dtype # compat with pt # compatible with vllm==0.3.* if version.parse(vllm.__version__) >= version.parse('0.3'): assert isinstance(_engine.tokenizer.tokenizer, PreTrainedTokenizerBase) diff --git a/swift/utils/hub.py b/swift/utils/hub.py index 20aa15df1..2eaf3b4f9 100644 --- a/swift/utils/hub.py +++ b/swift/utils/hub.py @@ -71,7 +71,7 @@ def push_to_ms_hub(ckpt_dir: str, else: subprocess_run(['git', '-C', ckpt_dir, 'commit', '-m', commit_message]) subprocess_run(['git', '-C', ckpt_dir, 'push']) - url = f'https://www.modelscope.cn/models/{hub_model_id}/summary' + url = f'https://www.modelscope.cn/models/{hub_model_id}' logger.info(f'Push to Modelscope successful. url: `{url}`.')
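As a closing note for reviewers, the heart of the template change is the prompt shape each new `TemplateType` produces. Below is a rough, self-contained sketch of a single round under the new `llava-vicuna` template, built only from the literal strings defined in `Llava1_6VicunaTemplate` above; BOS/EOS special tokens and the image placeholder that `replace_tag` prepends are handled by the template machinery and are deliberately omitted, so treat this as an approximation rather than the exact output of `template.encode`.

```python
# Hypothetical illustration only: mirrors the system_prefix '{{SYSTEM}} ' and the
# per-round prompt 'USER: {{QUERY}} ASSISTANT:' defined in Llava1_6VicunaTemplate.
DEFAULT_SYSTEM = (
    'A chat between a curious human and an artificial intelligence assistant. '
    "The assistant gives helpful, detailed, and polite answers to the human's questions.")

def build_vicuna_round(query: str, system: str = DEFAULT_SYSTEM) -> str:
    # Special tokens and the per-image placeholder are not reproduced here.
    return f'{system} USER: {query} ASSISTANT:'

print(build_vicuna_round('How many sheep are in the picture?'))
# -> 'A chat between ... questions. USER: How many sheep are in the picture? ASSISTANT:'
```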