From dd86fb86b16b89059ea3f27ee8f00a48a1406517 Mon Sep 17 00:00:00 2001 From: FangYin Cheng Date: Wed, 30 Aug 2023 11:07:35 +0800 Subject: [PATCH] feat: Multi-model command line --- pilot/model/cli.py | 151 +++++++++++++++++++++++++++ pilot/model/controller/controller.py | 9 +- pilot/model/controller/registry.py | 19 +++- pilot/model/loader.py | 1 - pilot/model/parameter.py | 13 +++ pilot/model/worker/default_worker.py | 14 ++- pilot/model/worker/manager.py | 29 ++++- pilot/scene/base_chat.py | 15 +-- pilot/scripts/__init__.py | 0 pilot/scripts/cli_scripts.py | 45 ++++++++ pilot/server/dbgpt_server.py | 10 +- pilot/speech/brian.py | 3 +- pilot/speech/eleven_labs.py | 2 +- pilot/speech/gtts.py | 3 +- pilot/utils/__init__.py | 1 + pilot/utils/api_utils.py | 9 +- pilot/utils/utils.py | 16 ++- requirements.txt | 5 +- setup.py | 4 +- tools/cli/cli_scripts.py | 3 +- 20 files changed, 317 insertions(+), 35 deletions(-) create mode 100644 pilot/model/cli.py create mode 100644 pilot/scripts/__init__.py create mode 100644 pilot/scripts/cli_scripts.py diff --git a/pilot/model/cli.py b/pilot/model/cli.py new file mode 100644 index 000000000..6109d11ab --- /dev/null +++ b/pilot/model/cli.py @@ -0,0 +1,151 @@ +import click +import functools + +from pilot.model.controller.registry import ModelRegistryClient +from pilot.model.worker.manager import ( + RemoteWorkerManager, + WorkerApplyRequest, + WorkerApplyType, +) +from pilot.utils import get_or_create_event_loop + + +@click.group("model") +def model_cli_group(): + pass + + +@model_cli_group.command() +@click.option( + "--address", + type=str, + default="http://127.0.0.1:8000", + required=False, + help=( + "Address of the Model Controller to connect to." + "Just support light deploy model" + ), +) +@click.option( + "--model-name", type=str, default=None, required=False, help=("The name of model") +) +@click.option( + "--model-type", type=str, default="llm", required=False, help=("The type of model") +) +def list(address: str, model_name: str, model_type: str): + """List model instances""" + from prettytable import PrettyTable + + loop = get_or_create_event_loop() + registry = ModelRegistryClient(address) + + if not model_name: + instances = loop.run_until_complete(registry.get_all_model_instances()) + else: + if not model_type: + model_type = "llm" + register_model_name = f"{model_name}@{model_type}" + instances = loop.run_until_complete( + registry.get_all_instances(register_model_name) + ) + table = PrettyTable() + + table.field_names = [ + "Model Name", + "Model Type", + "Host", + "Port", + "Healthy", + "Enabled", + "Prompt Template", + "Last Heartbeat", + ] + for instance in instances: + model_name, model_type = instance.model_name.split("@") + table.add_row( + [ + model_name, + model_type, + instance.host, + instance.port, + instance.healthy, + instance.enabled, + instance.prompt_template, + instance.last_heartbeat, + ] + ) + + print(table) + + +def add_model_options(func): + @click.option( + "--address", + type=str, + default="http://127.0.0.1:8000", + required=False, + help=( + "Address of the Model Controller to connect to." + "Just support light deploy model" + ), + ) + @click.option( + "--model-name", + type=str, + default=None, + required=True, + help=("The name of model"), + ) + @click.option( + "--model-type", + type=str, + default="llm", + required=False, + help=("The type of model"), + ) + @functools.wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + return wrapper + + +@model_cli_group.command() +@add_model_options +def stop(address: str, model_name: str, model_type: str): + """Stop model instances""" + worker_apply(address, model_name, model_type, WorkerApplyType.STOP) + + +@model_cli_group.command() +@add_model_options +def start(address: str, model_name: str, model_type: str): + """Start model instances""" + worker_apply(address, model_name, model_type, WorkerApplyType.START) + + +@model_cli_group.command() +@add_model_options +def restart(address: str, model_name: str, model_type: str): + """Restart model instances""" + worker_apply(address, model_name, model_type, WorkerApplyType.RESTART) + + +# @model_cli_group.command() +# @add_model_options +# def modify(address: str, model_name: str, model_type: str): +# """Restart model instances""" +# worker_apply(address, model_name, model_type, WorkerApplyType.UPDATE_PARAMS) + + +def worker_apply( + address: str, model_name: str, model_type: str, apply_type: WorkerApplyType +): + loop = get_or_create_event_loop() + registry = ModelRegistryClient(address) + worker_manager = RemoteWorkerManager(registry) + apply_req = WorkerApplyRequest( + model=model_name, worker_type=model_type, apply_type=apply_type + ) + res = loop.run_until_complete(worker_manager.worker_apply(apply_req)) + print(res) diff --git a/pilot/model/controller/controller.py b/pilot/model/controller/controller.py index aff42f29b..84d4dfb29 100644 --- a/pilot/model/controller/controller.py +++ b/pilot/model/controller/controller.py @@ -27,6 +27,9 @@ async def get_all_instances( ) return await self.registry.get_all_instances(model_name, healthy_only) + async def get_all_model_instances(self) -> List[ModelInstance]: + return await self.registry.get_all_model_instances() + async def send_heartbeat(self, instance: ModelInstance) -> bool: return await self.registry.send_heartbeat(instance) @@ -51,10 +54,12 @@ async def api_deregister_instance(request: ModelInstance): @router.get("/controller/models") -async def api_get_all_instances(model_name: str, healthy_only: bool = False): +async def api_get_all_instances(model_name: str = None, healthy_only: bool = False): + if not model_name: + return await controller.get_all_model_instances() return await controller.get_all_instances(model_name, healthy_only=healthy_only) @router.post("/controller/heartbeat") -async def api_get_all_instances(request: ModelInstance): +async def api_model_heartbeat(request: ModelInstance): return await controller.send_heartbeat(request) diff --git a/pilot/model/controller/registry.py b/pilot/model/controller/registry.py index 6a68fa573..445e68c26 100644 --- a/pilot/model/controller/registry.py +++ b/pilot/model/controller/registry.py @@ -5,6 +5,7 @@ from collections import defaultdict from datetime import datetime, timedelta from typing import Dict, List, Tuple +import itertools from pilot.model.base import ModelInstance @@ -57,6 +58,15 @@ async def get_all_instances( - List[ModelInstance]: A list of instances for the given model. """ + @abstractmethod + async def get_all_model_instances(self) -> List[ModelInstance]: + """ + Fetch all instances of all models + + Returns: + - List[ModelInstance]: A list of instances for the all models. + """ + async def select_one_health_instance(self, model_name: str) -> ModelInstance: """ Selects one healthy and enabled instance for a given model. @@ -154,12 +164,15 @@ async def deregister_instance(self, instance: ModelInstance) -> bool: async def get_all_instances( self, model_name: str, healthy_only: bool = False ) -> List[ModelInstance]: - print(self.registry) instances = self.registry[model_name] if healthy_only: instances = [ins for ins in instances if ins.healthy == True] return instances + async def get_all_model_instances(self) -> List[ModelInstance]: + print(self.registry) + return list(itertools.chain(*self.registry.values())) + async def send_heartbeat(self, instance: ModelInstance) -> bool: _, exist_ins = self._get_instances( instance.model_name, instance.host, instance.port, healthy_only=False @@ -194,6 +207,10 @@ async def get_all_instances( ) -> List[ModelInstance]: pass + @api_remote(path="/api/controller/models") + async def get_all_model_instances(self) -> List[ModelInstance]: + pass + @api_remote(path="/api/controller/models") async def select_one_health_instance(self, model_name: str) -> ModelInstance: instances = await self.get_all_instances(model_name, healthy_only=True) diff --git a/pilot/model/loader.py b/pilot/model/loader.py index 490088b7d..e4478450a 100644 --- a/pilot/model/loader.py +++ b/pilot/model/loader.py @@ -118,7 +118,6 @@ def loader( def loader_with_params(self, model_params: ModelParameters): llm_adapter = get_llm_model_adapter(self.model_name, self.model_path) model_type = llm_adapter.model_type() - param_cls = llm_adapter.model_param_class(model_type) self.prompt_template = model_params.prompt_template logger.info(f"model_params:\n{model_params}") if model_type == ModelType.HF: diff --git a/pilot/model/parameter.py b/pilot/model/parameter.py index e78d8f450..baa646711 100644 --- a/pilot/model/parameter.py +++ b/pilot/model/parameter.py @@ -184,6 +184,19 @@ def update_from(self, source: Union["BaseParameters", dict]) -> bool: return updated + def __str__(self) -> str: + class_name = self.__class__.__name__ + parameters = [ + f"\n\n=========================== {class_name} ===========================\n" + ] + for field_info in fields(self): + value = getattr(self, field_info.name) + parameters.append(f"{field_info.name}: {value}") + parameters.append( + "\n======================================================================\n\n" + ) + return "\n".join(parameters) + @dataclass class ModelWorkerParameters(BaseParameters): diff --git a/pilot/model/worker/default_worker.py b/pilot/model/worker/default_worker.py index 199e273f0..deea90191 100644 --- a/pilot/model/worker/default_worker.py +++ b/pilot/model/worker/default_worker.py @@ -4,12 +4,12 @@ import torch from pilot.configs.model_config import DEVICE -from pilot.model.adapter import get_llm_model_adapter +from pilot.model.adapter import get_llm_model_adapter, BaseLLMAdaper from pilot.model.base import ModelOutput from pilot.model.loader import ModelLoader, _get_model_real_path from pilot.model.parameter import EnvArgumentParser, ModelParameters from pilot.model.worker.base import ModelWorker -from pilot.server.chat_adapter import get_llm_chat_adapter +from pilot.server.chat_adapter import get_llm_chat_adapter, BaseChatAdpter from pilot.utils.model_utils import _clear_torch_cache logger = logging.getLogger("model_worker") @@ -20,6 +20,8 @@ def __init__(self) -> None: self.model = None self.tokenizer = None self._model_params = None + self.llm_adapter: BaseLLMAdaper = None + self.llm_chat_adapter: BaseChatAdpter = None def load_worker(self, model_name: str, model_path: str, **kwargs) -> None: if model_path.endswith("/"): @@ -28,9 +30,9 @@ def load_worker(self, model_name: str, model_path: str, **kwargs) -> None: self.model_name = model_name self.model_path = model_path - llm_adapter = get_llm_model_adapter(self.model_name, self.model_path) - model_type = llm_adapter.model_type() - self.param_cls = llm_adapter.model_param_class(model_type) + self.llm_adapter = get_llm_model_adapter(self.model_name, self.model_path) + model_type = self.llm_adapter.model_type() + self.param_cls = self.llm_adapter.model_param_class(model_type) self.llm_chat_adapter = get_llm_chat_adapter(self.model_name, self.model_path) self.generate_stream_func = self.llm_chat_adapter.get_generate_stream_func( @@ -50,12 +52,14 @@ def parse_parameters(self, command_args: List[str] = None) -> ModelParameters: param_cls = self.model_param_class() model_args = EnvArgumentParser() env_prefix = EnvArgumentParser.get_env_prefix(self.model_name) + model_type = self.llm_adapter.model_type() model_params: ModelParameters = model_args.parse_args_into_dataclass( param_cls, env_prefix=env_prefix, command_args=command_args, model_name=self.model_name, model_path=self.model_path, + model_type=model_type, ) if not model_params.device: model_params.device = DEVICE diff --git a/pilot/model/worker/manager.py b/pilot/model/worker/manager.py index b858d5f55..3d18088e9 100644 --- a/pilot/model/worker/manager.py +++ b/pilot/model/worker/manager.py @@ -1,4 +1,5 @@ import asyncio +import httpx import itertools import json import os @@ -147,7 +148,7 @@ def __init__( self.model_registry = model_registry def _worker_key(self, worker_type: str, model_name: str) -> str: - return f"$${worker_type}_$$_{model_name}" + return f"{model_name}@{worker_type}" def add_worker( self, @@ -311,7 +312,9 @@ async def _apply_worker( # Apply to all workers worker_instances = list(itertools.chain(*self.workers.values())) logger.info(f"Apply to all workers: {worker_instances}") - await asyncio.gather(*(apply_func(worker) for worker in worker_instances)) + return await asyncio.gather( + *(apply_func(worker) for worker in worker_instances) + ) async def _start_all_worker( self, apply_req: WorkerApplyRequest @@ -423,6 +426,28 @@ async def get_model_instances( worker_instances.append(wr) return worker_instances + async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput: + async def _remote_apply_func(worker_run_data: WorkerRunData): + worker_addr = worker_run_data.worker.worker_addr + async with httpx.AsyncClient() as client: + response = await client.post( + worker_addr + "/apply", + headers=worker_run_data.worker.headers, + json=apply_req.dict(), + timeout=worker_run_data.worker.timeout, + ) + if response.status_code == 200: + output = WorkerApplyOutput(**response.json()) + logger.info(f"worker_apply success: {output}") + else: + output = WorkerApplyOutput(message=response.text) + logger.warn(f"worker_apply failed: {output}") + return output + + results = await self._apply_worker(apply_req, _remote_apply_func) + if results: + return results[0] + class WorkerManagerAdapter(WorkerManager): def __init__(self, worker_manager: WorkerManager = None) -> None: diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py index c5c6dfb4a..88f457935 100644 --- a/pilot/scene/base_chat.py +++ b/pilot/scene/base_chat.py @@ -28,10 +28,7 @@ from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory from pilot.configs.model_config import LOGDIR, DATASETS_DIR -from pilot.utils import ( - build_logger, - server_error_msg, -) +from pilot.utils import build_logger, server_error_msg, get_or_create_event_loop from pilot.scene.base_message import ( BaseMessage, SystemMessage, @@ -222,13 +219,10 @@ async def nostream_call(self): return self.current_ai_response() def _blocking_stream_call(self): - import asyncio - logger.warn( "_blocking_stream_call is only temporarily used in webserver and will be deleted soon, please use stream_call to replace it for higher performance" ) - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) + loop = get_or_create_event_loop() async_gen = self.stream_call() while True: try: @@ -238,13 +232,10 @@ def _blocking_stream_call(self): break def _blocking_nostream_call(self): - import asyncio - logger.warn( "_blocking_nostream_call is only temporarily used in webserver and will be deleted soon, please use nostream_call to replace it for higher performance" ) - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) + loop = get_or_create_event_loop() return loop.run_until_complete(self.nostream_call()) def call(self): diff --git a/pilot/scripts/__init__.py b/pilot/scripts/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/scripts/cli_scripts.py b/pilot/scripts/cli_scripts.py new file mode 100644 index 000000000..537b0ed25 --- /dev/null +++ b/pilot/scripts/cli_scripts.py @@ -0,0 +1,45 @@ +import sys +import click +import os +import copy +import logging + +sys.path.append( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) +) + + +@click.group() +@click.option( + "--log-level", + required=False, + type=str, + default="warn", + help="Log level", +) +@click.version_option() +def cli(log_level: str): + # TODO not working now + logging.basicConfig(level=log_level, encoding="utf-8") + + +def add_command_alias(command, name: str, hidden: bool = False): + new_command = copy.deepcopy(command) + new_command.hidden = hidden + cli.add_command(new_command, name=name) + + +try: + from pilot.model.cli import model_cli_group + + add_command_alias(model_cli_group, name="model") +except ImportError as e: + logging.warning(f"Integrating dbgpt model command line tool failed: {e}") + + +def main(): + return cli() + + +if __name__ == "__main__": + main() diff --git a/pilot/server/dbgpt_server.py b/pilot/server/dbgpt_server.py index ef4b7fc65..86ca49330 100644 --- a/pilot/server/dbgpt_server.py +++ b/pilot/server/dbgpt_server.py @@ -75,14 +75,20 @@ def swagger_monkey_patch(*args, **kwargs): app.include_router(knowledge_router) # app.include_router(api_editor_route_v1) + def mount_static_files(app): os.makedirs(static_message_img_path, exist_ok=True) app.mount( - "/images", StaticFiles(directory=static_message_img_path, html=True), name="static2" + "/images", + StaticFiles(directory=static_message_img_path, html=True), + name="static2", + ) + app.mount( + "/_next/static", StaticFiles(directory=static_file_path + "/_next/static") ) - app.mount("/_next/static", StaticFiles(directory=static_file_path + "/_next/static")) app.mount("/", StaticFiles(directory=static_file_path, html=True), name="static") + app.add_exception_handler(RequestValidationError, validation_exception_handler) if __name__ == "__main__": diff --git a/pilot/speech/brian.py b/pilot/speech/brian.py index 505c9a6f8..e0b59a87c 100644 --- a/pilot/speech/brian.py +++ b/pilot/speech/brian.py @@ -2,7 +2,6 @@ import os import requests -from playsound import playsound from pilot.speech.base import VoiceBase @@ -23,6 +22,8 @@ def _speech(self, text: str, _: int = 0) -> bool: Returns: bool: True if the request was successful, False otherwise """ + from playsound import playsound + tts_url = ( f"https://api.streamelements.com/kappa/v2/speech?voice=Brian&text={text}" ) diff --git a/pilot/speech/eleven_labs.py b/pilot/speech/eleven_labs.py index dad841517..671a3d729 100644 --- a/pilot/speech/eleven_labs.py +++ b/pilot/speech/eleven_labs.py @@ -2,7 +2,6 @@ import os import requests -from playsound import playsound from pilot.configs.config import Config from pilot.speech.base import VoiceBase @@ -70,6 +69,7 @@ def _speech(self, text: str, voice_index: int = 0) -> bool: bool: True if the request was successful, False otherwise """ from pilot.logs import logger + from playsound import playsound tts_url = ( f"https://api.elevenlabs.io/v1/text-to-speech/{self._voices[voice_index]}" diff --git a/pilot/speech/gtts.py b/pilot/speech/gtts.py index 7ad164f30..8fc7df19c 100644 --- a/pilot/speech/gtts.py +++ b/pilot/speech/gtts.py @@ -2,7 +2,6 @@ import os import gtts -from playsound import playsound from pilot.speech.base import VoiceBase @@ -15,6 +14,8 @@ def _setup(self) -> None: def _speech(self, text: str, _: int = 0) -> bool: """Play the given text.""" + from playsound import playsound + tts = gtts.gTTS(text) tts.save("speech.mp3") playsound("speech.mp3", True) diff --git a/pilot/utils/__init__.py b/pilot/utils/__init__.py index 82aa640ae..8a84bc0ec 100644 --- a/pilot/utils/__init__.py +++ b/pilot/utils/__init__.py @@ -5,4 +5,5 @@ disable_torch_init, pretty_print_semaphore, server_error_msg, + get_or_create_event_loop, ) diff --git a/pilot/utils/api_utils.py b/pilot/utils/api_utils.py index af45ea3c6..ea2bef6ef 100644 --- a/pilot/utils/api_utils.py +++ b/pilot/utils/api_utils.py @@ -1,6 +1,7 @@ import httpx from inspect import signature import typing_inspect +import logging from typing import get_type_hints, List, Type, TypeVar, Union, Optional, Tuple from dataclasses import is_dataclass, asdict @@ -21,7 +22,9 @@ def decorator(func): raise TypeError("Return type must be annotated in the decorated function.") actual_dataclass = _extract_dataclass_from_generic(return_type) - print(f"return_type: {return_type}, actual_dataclass: {actual_dataclass}") + logging.debug( + f"return_type: {return_type}, actual_dataclass: {actual_dataclass}" + ) if not actual_dataclass: actual_dataclass = return_type sig = signature(func) @@ -57,7 +60,9 @@ async def wrapper(self, *args, **kwargs): else: # For GET, DELETE, etc. request_params["params"] = request_data - print(f"request_params: {request_params}, args: {args}, kwargs: {kwargs}") + logging.info( + f"request_params: {request_params}, args: {args}, kwargs: {kwargs}" + ) async with httpx.AsyncClient() as client: response = await client.request(**request_params) diff --git a/pilot/utils/utils.py b/pilot/utils/utils.py index e5fef7281..ca7cf9d3c 100644 --- a/pilot/utils/utils.py +++ b/pilot/utils/utils.py @@ -5,8 +5,8 @@ import logging.handlers import os import sys +import asyncio -import torch from pilot.configs.model_config import LOGDIR server_error_msg = ( @@ -17,6 +17,8 @@ def get_gpu_memory(max_gpus=None): + import torch + gpu_memory = [] num_gpus = ( torch.cuda.device_count() @@ -130,3 +132,15 @@ def pretty_print_semaphore(semaphore): if semaphore is None: return "None" return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" + + +def get_or_create_event_loop() -> asyncio.BaseEventLoop: + try: + loop = asyncio.get_event_loop() + except Exception as e: + if not "no running event loop" in str(e): + raise e + logging.warning("Cant not get running event loop, create new event loop now") + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + return loop diff --git a/requirements.txt b/requirements.txt index e4dd4def2..daec0dc85 100644 --- a/requirements.txt +++ b/requirements.txt @@ -76,4 +76,7 @@ bardapi==0.1.29 # TODO moved to optional dependencies pymysql duckdb -duckdb-engine \ No newline at end of file +duckdb-engine + +# cli +prettytable \ No newline at end of file diff --git a/setup.py b/setup.py index a9ee31213..d717b0bc6 100644 --- a/setup.py +++ b/setup.py @@ -267,7 +267,7 @@ def llama_cpp_python_cuda_requires(): llama_cpp_version = "0.1.77" py_version = "cp310" os_pkg_name = "linux_x86_64" if os_type == OSType.LINUX else "win_amd64" - extra_index_url = f"{base_url}/llama_cpp_python_cuda-{llama_cpp_version}+{device}{cpu_avx}-{py_version}-{py_version}-{os_pkg_name}.whl" + extra_index_url = f"{base_url}/llama_cpp_python_cuda-{llama_cpp_version}+{device}-{py_version}-{py_version}-{os_pkg_name}.whl" extra_index_url, _ = encode_url(extra_index_url) print(f"Install llama_cpp_python_cuda from {extra_index_url}") @@ -361,7 +361,7 @@ def init_install_requires(): extras_require=setup_spec.extras, entry_points={ "console_scripts": [ - "dbgpt_server=pilot.server:webserver", + "dbgpt=pilot.scripts.cli_scripts:main", ], }, ) diff --git a/tools/cli/cli_scripts.py b/tools/cli/cli_scripts.py index 545acb4a5..2176cc272 100644 --- a/tools/cli/cli_scripts.py +++ b/tools/cli/cli_scripts.py @@ -25,7 +25,6 @@ from pilot.configs.model_config import DATASETS_DIR -from tools.cli.knowledge_client import knowledge_init API_ADDRESS: str = "http://127.0.0.1:5000" @@ -97,6 +96,8 @@ def knowledge( verbose: bool, ): """Knowledge command line tool""" + from tools.cli.knowledge_client import knowledge_init + knowledge_init( API_ADDRESS, vector_name,