From 0fd73b934aa6f97d14e594df38a0a7f5a12a2aff Mon Sep 17 00:00:00 2001
From: johnsmith253325
Date: Thu, 19 Oct 2023 23:34:02 +0800
Subject: [PATCH 1/3] feat: add 通义千问 (Qwen) support
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 modules/models/Qwen.py       | 57 ++++++++++++++++++++++++++++++++++++
 modules/models/base_model.py | 12 ++++----
 modules/models/models.py     |  5 +++-
 modules/presets.py           |  8 +++++
 requirements_advanced.txt    |  4 +++
 5 files changed, 80 insertions(+), 6 deletions(-)
 create mode 100644 modules/models/Qwen.py

diff --git a/modules/models/Qwen.py b/modules/models/Qwen.py
new file mode 100644
index 00000000..f5fc8d1b
--- /dev/null
+++ b/modules/models/Qwen.py
@@ -0,0 +1,57 @@
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers.generation import GenerationConfig
+import logging
+import colorama
+from .base_model import BaseLLMModel
+from ..presets import MODEL_METADATA
+
+
+class Qwen_Client(BaseLLMModel):
+    def __init__(self, model_name, user_name="") -> None:
+        super().__init__(model_name=model_name, user=user_name)
+        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_METADATA[model_name]["repo_id"], trust_remote_code=True, resume_download=True)
+        self.model = AutoModelForCausalLM.from_pretrained(MODEL_METADATA[model_name]["repo_id"], device_map="auto", trust_remote_code=True, resume_download=True).eval()
+
+    def generation_config(self):
+        return GenerationConfig.from_dict({
+            "chat_format": "chatml",
+            "do_sample": True,
+            "eos_token_id": 151643,
+            "max_length": self.token_upper_limit,
+            "max_new_tokens": 512,
+            "max_window_size": 6144,
+            "pad_token_id": 151643,
+            "top_k": 0,
+            "top_p": self.top_p,
+            "transformers_version": "4.33.2",
+            "trust_remote_code": True,
+            "temperature": self.temperature,
+        })
+
+    def _get_glm_style_input(self):
+        history = [x["content"] for x in self.history]
+        query = history.pop()
+        logging.debug(colorama.Fore.YELLOW +
+                      f"{history}" + colorama.Fore.RESET)
+        assert (
+            len(history) % 2 == 0
+        ), f"History should be even length. current history is: {history}"
+        history = [[history[i], history[i + 1]]
+                   for i in range(0, len(history), 2)]
+        return history, query
+
+    def get_answer_at_once(self):
+        history, query = self._get_glm_style_input()
+        self.model.generation_config = self.generation_config()
+        response, history = self.model.chat(self.tokenizer, query, history=history)
+        return response, len(response)
+
+    def get_answer_stream_iter(self):
+        history, query = self._get_glm_style_input()
+        self.model.generation_config = self.generation_config()
+        for response in self.model.chat_stream(
+            self.tokenizer,
+            query,
+            history,
+        ):
+            yield response
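
Note: a minimal sketch of the history conversion performed by `_get_glm_style_input()`
above, run on hypothetical history entries (the role/content dict shape is an
assumption about what BaseLLMModel stores). The client flattens the stored
messages, splits off the last user message as the query, and pairs the rest into
[user, assistant] turns, which is the history format Qwen's `chat()` expects.

    # Hypothetical history in BaseLLMModel's storage format (assumed role/content dicts).
    history = [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi! How can I help?"},
        {"role": "user", "content": "Introduce Qwen."},
    ]
    contents = [x["content"] for x in history]
    query = contents.pop()              # "Introduce Qwen." -- the pending user turn
    assert len(contents) % 2 == 0       # remainder must be complete (user, assistant) turns
    pairs = [[contents[i], contents[i + 1]] for i in range(0, len(contents), 2)]
    # pairs == [["Hello", "Hi! How can I help?"]] -- passed as history= to model.chat()
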
diff --git a/modules/models/base_model.py b/modules/models/base_model.py
index b3b8ff92..08657eb1 100644
--- a/modules/models/base_model.py
+++ b/modules/models/base_model.py
@@ -146,6 +146,7 @@ class ModelType(Enum):
     Spark = 12
     OpenAIInstruct = 13
     Claude = 14
+    Qwen = 15

     @classmethod
     def get_type(cls, model_name: str):
@@ -181,7 +182,9 @@ def get_type(cls, model_name: str):
         elif "星火大模型" in model_name_lower:
             model_type = ModelType.Spark
         elif "claude" in model_name_lower:
-            model_type = ModelType.Claude
+            model_type = ModelType.Claude
+        elif "qwen" in model_name_lower:
+            model_type = ModelType.Qwen
         else:
             model_type = ModelType.LLaMA
         return model_type
@@ -656,14 +659,13 @@ def delete_first_conversation(self):
     def delete_last_conversation(self, chatbot):
         if len(chatbot) > 0 and STANDARD_ERROR_MSG in chatbot[-1][1]:
             msg = "由于包含报错信息,只删除chatbot记录"
-            chatbot.pop()
+            chatbot = chatbot[:-1]
             return chatbot, self.history
         if len(self.history) > 0:
-            self.history.pop()
-            self.history.pop()
+            self.history = self.history[:-2]
         if len(chatbot) > 0:
             msg = "删除了一组chatbot对话"
-            chatbot.pop()
+            chatbot = chatbot[:-1]
         if len(self.all_token_counts) > 0:
             msg = "删除了一组对话的token计数记录"
             self.all_token_counts.pop()
diff --git a/modules/models/models.py b/modules/models/models.py
index 8fdfc482..9f31a95b 100644
--- a/modules/models/models.py
+++ b/modules/models/models.py
@@ -116,9 +116,12 @@ def get_model(
             from .spark import Spark_Client
             model = Spark_Client(model_name, os.getenv("SPARK_APPID"), os.getenv(
                 "SPARK_API_KEY"), os.getenv("SPARK_API_SECRET"), user_name=user_name)
-        elif model_type == ModelType.Claude:
+        elif model_type == ModelType.Claude:
             from .Claude import Claude_Client
             model = Claude_Client(model_name="claude-2", api_secret=os.getenv("CLAUDE_API_SECRET"))
+        elif model_type == ModelType.Qwen:
+            from .Qwen import Qwen_Client
+            model = Qwen_Client(model_name, user_name=user_name)
         elif model_type == ModelType.Unknown:
             raise ValueError(f"未知模型: {model_name}")
         logging.info(msg)
diff --git a/modules/presets.py b/modules/presets.py
index 829d7d24..c9fb10af 100644
--- a/modules/presets.py
+++ b/modules/presets.py
@@ -87,6 +87,8 @@
     "StableLM",
     "MOSS",
     "Llama-2-7B-Chat",
+    "Qwen 7B",
+    "Qwen 14B"
 ]

 # Additional metadata for local models
@@ -98,6 +100,12 @@
     "Llama-2-7B-Chat":{
         "repo_id": "TheBloke/Llama-2-7b-Chat-GGUF",
         "filelist": ["llama-2-7b-chat.Q6_K.gguf"],
+    },
+    "Qwen 7B": {
+        "repo_id": "Qwen/Qwen-7B-Chat-Int4",
+    },
+    "Qwen 14B": {
+        "repo_id": "Qwen/Qwen-14B-Chat-Int4",
     }
 }

diff --git a/requirements_advanced.txt b/requirements_advanced.txt
index d43f3ce0..a8420f5f 100644
--- a/requirements_advanced.txt
+++ b/requirements_advanced.txt
@@ -6,3 +6,7 @@ sentence_transformers
 accelerate
 sentencepiece
 llama-cpp-python
+transformers_stream_generator
+einops
+optimum
+auto-gptq
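
Note: a standalone sketch of the load-and-chat path that Qwen_Client wraps,
assuming Hugging Face Hub access and a GPU; the repo id mirrors
MODEL_METADATA["Qwen 7B"] above, and `chat()`/`chat_stream()` are provided by
Qwen's trust_remote_code modeling code rather than by the core transformers API.
The four new requirements (transformers_stream_generator, einops, optimum,
auto-gptq) cover what the Int4 GPTQ checkpoints and the streaming generator
import at load time.

    from transformers import AutoModelForCausalLM, AutoTokenizer

    repo_id = "Qwen/Qwen-7B-Chat-Int4"  # MODEL_METADATA["Qwen 7B"]["repo_id"]
    tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        repo_id, device_map="auto", trust_remote_code=True
    ).eval()

    # One-shot answer, as in get_answer_at_once(); history is a list of [user, assistant] pairs.
    response, history = model.chat(tokenizer, "你好", history=[])

    # Streaming variant used by get_answer_stream_iter(); yields growing partial text.
    for partial in model.chat_stream(tokenizer, "你好", history=[]):
        print(partial)
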
From 2c72b5e1999de8b83c5ac3c52c497901ff41cbea Mon Sep 17 00:00:00 2001
From: johnsmith253325
Date: Fri, 20 Oct 2023 23:22:47 +0800
Subject: [PATCH 2/3] doc: add 通义千问 (Qwen) to the README
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 README.md           | 2 +-
 readme/README_en.md | 2 +-
 readme/README_ja.md | 2 +-
 readme/README_ru.md | 2 +-
 4 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/README.md b/README.md
index f3054c51..c357962c 100644
--- a/README.md
+++ b/README.md
@@ -70,7 +70,7 @@
 | [Azure OpenAI](https://azure.microsoft.com/en-us/products/ai-services/openai-service) | | [LLaMA](https://github.com/facebookresearch/llama) | 支持 Lora 模型 |
 | [Google PaLM](https://developers.generativeai.google/products/palm) | 不支持流式传输 | [StableLM](https://github.com/Stability-AI/StableLM) |
 | [讯飞星火认知大模型](https://xinghuo.xfyun.cn) | | [MOSS](https://github.com/OpenLMLab/MOSS)
-| [Inspur Yuan 1.0](https://air.inspur.com/home) | |
+| [Inspur Yuan 1.0](https://air.inspur.com/home) | | [通义千问](https://github.com/QwenLM/Qwen/tree/main) |
 | [MiniMax](https://api.minimax.chat/) | | [XMChat](https://github.com/MILVLG/xmchat) | 不支持流式传输
 | [Midjourney](https://www.midjourney.com/) | 不支持流式传输

diff --git a/readme/README_en.md b/readme/README_en.md
index 34163a79..138d745f 100644
--- a/readme/README_en.md
+++ b/readme/README_en.md
@@ -66,7 +66,7 @@
 | [Azure OpenAI](https://azure.microsoft.com/en-us/products/ai-services/openai-service) | | [LLaMA](https://github.com/facebookresearch/llama) | Supports Lora models |
 | [Google PaLM](https://developers.generativeai.google/products/palm) | Does not support streaming | [StableLM](https://github.com/Stability-AI/StableLM) |
 | [iFlytek Starfire Cognition Large Model](https://xinghuo.xfyun.cn) | | [MOSS](https://github.com/OpenLMLab/MOSS)
-| [Inspur Yuan 1.0](https://air.inspur.com/home) | |
+| [Inspur Yuan 1.0](https://air.inspur.com/home) | | [Qwen](https://github.com/QwenLM/Qwen/tree/main) |
 | [MiniMax](https://api.minimax.chat/) | | [XMChat](https://github.com/MILVLG/xmchat) | Does not support streaming
 | [Midjourney](https://www.midjourney.com/) | Does not support streaming

diff --git a/readme/README_ja.md b/readme/README_ja.md
index fd37dbed..b8bb9da5 100644
--- a/readme/README_ja.md
+++ b/readme/README_ja.md
@@ -65,7 +65,7 @@
 | [Azure OpenAI](https://azure.microsoft.com/en-us/products/ai-services/openai-service) | | [LLaMA](https://github.com/facebookresearch/llama) | Loraモデルのサポートあり |
 | [Google PaLM](https://developers.generativeai.google/products/palm) | ストリーミング転送はサポートされていません | [StableLM](https://github.com/Stability-AI/StableLM) |
 | [讯飞星火认知大模型](https://xinghuo.xfyun.cn) | | [MOSS](https://github.com/OpenLMLab/MOSS)
-| [Inspur Yuan 1.0](https://air.inspur.com/home) | |
+| [Inspur Yuan 1.0](https://air.inspur.com/home) | | [Qwen](https://github.com/QwenLM/Qwen/tree/main) |
 | [MiniMax](https://api.minimax.chat/) | | [XMChat](https://github.com/MILVLG/xmchat) | ストリーミング転送はサポートされていません
 | [Midjourney](https://www.midjourney.com/) | ストリーミング転送はサポートされていません

diff --git a/readme/README_ru.md b/readme/README_ru.md
index 3dd87875..93ea4c8d 100644
--- a/readme/README_ru.md
+++ b/readme/README_ru.md
@@ -65,7 +65,7 @@
 | [Azure OpenAI](https://azure.microsoft.com/en-us/products/ai-services/openai-service) | | [LLaMA](https://github.com/facebookresearch/llama) | Поддерживает модель Lora |
 | [Google PaLM](https://developers.generativeai.google/products/palm) | Не поддерживает потоковую передачу данных | [StableLM](https://github.com/Stability-AI/StableLM) |
 | [Xunfei Xinghuo Cognitive Model](https://xinghuo.xfyun.cn) | | [MOSS](https://github.com/OpenLMLab/MOSS)
-| [Inspur Yuan 1.0](https://air.inspur.com/home) | |
+| [Inspur Yuan 1.0](https://air.inspur.com/home) | | [Qwen](https://github.com/QwenLM/Qwen/tree/main) |
 | [MiniMax](https://api.minimax.chat/) | | [XMChat](https://github.com/MILVLG/xmchat) | Не поддерживает потоковую передачу данных
 | [Midjourney](https://www.midjourney.com/) | Не поддерживает потоковую передачу данных
From f9abb09fcdc948ed1d7c217dfb8b87d51db9219c Mon Sep 17 00:00:00 2001
From: johnsmith253325
Date: Fri, 20 Oct 2023 23:26:03 +0800
Subject: [PATCH 3/3] chore: stop printing commit time
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 modules/repo.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/modules/repo.py b/modules/repo.py
index 75b88f03..cb69adee 100644
--- a/modules/repo.py
+++ b/modules/repo.py
@@ -152,7 +152,7 @@ def version_time():
         )
         commit_time = commit_datetime.strftime("%Y-%m-%dT%H:%M:%SZ")
-        logging.info(f"commit time: {commit_time}")
+        # logging.info(f"commit time: {commit_time}")
     except Exception:
         commit_time = "unknown"
     return commit_time
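
Note: a small illustration of the timestamp the silenced log line used to print;
the strftime pattern comes from the hunk above, while the datetime value here is
made up. `version_time()` still computes and returns the value, it just no
longer logs it.

    from datetime import datetime, timezone

    commit_datetime = datetime(2023, 10, 20, 15, 26, 3, tzinfo=timezone.utc)
    commit_time = commit_datetime.strftime("%Y-%m-%dT%H:%M:%SZ")
    print(commit_time)  # "2023-10-20T15:26:03Z"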