Skip to content

Commit

Permalink
feat: Multi-model command line
Browse files Browse the repository at this point in the history
  • Loading branch information
fangyinc committed Aug 30, 2023
1 parent d467092 commit dd86fb8
Show file tree
Hide file tree
Showing 20 changed files with 317 additions and 35 deletions.
151 changes: 151 additions & 0 deletions pilot/model/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
import click
import functools

from pilot.model.controller.registry import ModelRegistryClient
from pilot.model.worker.manager import (
RemoteWorkerManager,
WorkerApplyRequest,
WorkerApplyType,
)
from pilot.utils import get_or_create_event_loop


@click.group("model")
def model_cli_group():
pass


@model_cli_group.command()
@click.option(
"--address",
type=str,
default="http://127.0.0.1:8000",
required=False,
help=(
"Address of the Model Controller to connect to."
"Just support light deploy model"
),
)
@click.option(
"--model-name", type=str, default=None, required=False, help=("The name of model")
)
@click.option(
"--model-type", type=str, default="llm", required=False, help=("The type of model")
)
def list(address: str, model_name: str, model_type: str):
"""List model instances"""
from prettytable import PrettyTable

loop = get_or_create_event_loop()
registry = ModelRegistryClient(address)

if not model_name:
instances = loop.run_until_complete(registry.get_all_model_instances())
else:
if not model_type:
model_type = "llm"
register_model_name = f"{model_name}@{model_type}"
instances = loop.run_until_complete(
registry.get_all_instances(register_model_name)
)
table = PrettyTable()

table.field_names = [
"Model Name",
"Model Type",
"Host",
"Port",
"Healthy",
"Enabled",
"Prompt Template",
"Last Heartbeat",
]
for instance in instances:
model_name, model_type = instance.model_name.split("@")
table.add_row(
[
model_name,
model_type,
instance.host,
instance.port,
instance.healthy,
instance.enabled,
instance.prompt_template,
instance.last_heartbeat,
]
)

print(table)


def add_model_options(func):
@click.option(
"--address",
type=str,
default="http://127.0.0.1:8000",
required=False,
help=(
"Address of the Model Controller to connect to."
"Just support light deploy model"
),
)
@click.option(
"--model-name",
type=str,
default=None,
required=True,
help=("The name of model"),
)
@click.option(
"--model-type",
type=str,
default="llm",
required=False,
help=("The type of model"),
)
@functools.wraps(func)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)

return wrapper


@model_cli_group.command()
@add_model_options
def stop(address: str, model_name: str, model_type: str):
"""Stop model instances"""
worker_apply(address, model_name, model_type, WorkerApplyType.STOP)


@model_cli_group.command()
@add_model_options
def start(address: str, model_name: str, model_type: str):
"""Start model instances"""
worker_apply(address, model_name, model_type, WorkerApplyType.START)


@model_cli_group.command()
@add_model_options
def restart(address: str, model_name: str, model_type: str):
"""Restart model instances"""
worker_apply(address, model_name, model_type, WorkerApplyType.RESTART)


# @model_cli_group.command()
# @add_model_options
# def modify(address: str, model_name: str, model_type: str):
# """Restart model instances"""
# worker_apply(address, model_name, model_type, WorkerApplyType.UPDATE_PARAMS)


def worker_apply(
address: str, model_name: str, model_type: str, apply_type: WorkerApplyType
):
loop = get_or_create_event_loop()
registry = ModelRegistryClient(address)
worker_manager = RemoteWorkerManager(registry)
apply_req = WorkerApplyRequest(
model=model_name, worker_type=model_type, apply_type=apply_type
)
res = loop.run_until_complete(worker_manager.worker_apply(apply_req))
print(res)
9 changes: 7 additions & 2 deletions pilot/model/controller/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ async def get_all_instances(
)
return await self.registry.get_all_instances(model_name, healthy_only)

async def get_all_model_instances(self) -> List[ModelInstance]:
return await self.registry.get_all_model_instances()

async def send_heartbeat(self, instance: ModelInstance) -> bool:
return await self.registry.send_heartbeat(instance)

Expand All @@ -51,10 +54,12 @@ async def api_deregister_instance(request: ModelInstance):


@router.get("/controller/models")
async def api_get_all_instances(model_name: str, healthy_only: bool = False):
async def api_get_all_instances(model_name: str = None, healthy_only: bool = False):
if not model_name:
return await controller.get_all_model_instances()
return await controller.get_all_instances(model_name, healthy_only=healthy_only)


@router.post("/controller/heartbeat")
async def api_get_all_instances(request: ModelInstance):
async def api_model_heartbeat(request: ModelInstance):
return await controller.send_heartbeat(request)
19 changes: 18 additions & 1 deletion pilot/model/controller/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Dict, List, Tuple
import itertools

from pilot.model.base import ModelInstance

Expand Down Expand Up @@ -57,6 +58,15 @@ async def get_all_instances(
- List[ModelInstance]: A list of instances for the given model.
"""

@abstractmethod
async def get_all_model_instances(self) -> List[ModelInstance]:
"""
Fetch all instances of all models
Returns:
- List[ModelInstance]: A list of instances for the all models.
"""

async def select_one_health_instance(self, model_name: str) -> ModelInstance:
"""
Selects one healthy and enabled instance for a given model.
Expand Down Expand Up @@ -154,12 +164,15 @@ async def deregister_instance(self, instance: ModelInstance) -> bool:
async def get_all_instances(
self, model_name: str, healthy_only: bool = False
) -> List[ModelInstance]:
print(self.registry)
instances = self.registry[model_name]
if healthy_only:
instances = [ins for ins in instances if ins.healthy == True]
return instances

async def get_all_model_instances(self) -> List[ModelInstance]:
print(self.registry)
return list(itertools.chain(*self.registry.values()))

async def send_heartbeat(self, instance: ModelInstance) -> bool:
_, exist_ins = self._get_instances(
instance.model_name, instance.host, instance.port, healthy_only=False
Expand Down Expand Up @@ -194,6 +207,10 @@ async def get_all_instances(
) -> List[ModelInstance]:
pass

@api_remote(path="/api/controller/models")
async def get_all_model_instances(self) -> List[ModelInstance]:
pass

@api_remote(path="/api/controller/models")
async def select_one_health_instance(self, model_name: str) -> ModelInstance:
instances = await self.get_all_instances(model_name, healthy_only=True)
Expand Down
1 change: 0 additions & 1 deletion pilot/model/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,6 @@ def loader(
def loader_with_params(self, model_params: ModelParameters):
llm_adapter = get_llm_model_adapter(self.model_name, self.model_path)
model_type = llm_adapter.model_type()
param_cls = llm_adapter.model_param_class(model_type)
self.prompt_template = model_params.prompt_template
logger.info(f"model_params:\n{model_params}")
if model_type == ModelType.HF:
Expand Down
13 changes: 13 additions & 0 deletions pilot/model/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,19 @@ def update_from(self, source: Union["BaseParameters", dict]) -> bool:

return updated

def __str__(self) -> str:
class_name = self.__class__.__name__
parameters = [
f"\n\n=========================== {class_name} ===========================\n"
]
for field_info in fields(self):
value = getattr(self, field_info.name)
parameters.append(f"{field_info.name}: {value}")
parameters.append(
"\n======================================================================\n\n"
)
return "\n".join(parameters)


@dataclass
class ModelWorkerParameters(BaseParameters):
Expand Down
14 changes: 9 additions & 5 deletions pilot/model/worker/default_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@

import torch
from pilot.configs.model_config import DEVICE
from pilot.model.adapter import get_llm_model_adapter
from pilot.model.adapter import get_llm_model_adapter, BaseLLMAdaper
from pilot.model.base import ModelOutput
from pilot.model.loader import ModelLoader, _get_model_real_path
from pilot.model.parameter import EnvArgumentParser, ModelParameters
from pilot.model.worker.base import ModelWorker
from pilot.server.chat_adapter import get_llm_chat_adapter
from pilot.server.chat_adapter import get_llm_chat_adapter, BaseChatAdpter
from pilot.utils.model_utils import _clear_torch_cache

logger = logging.getLogger("model_worker")
Expand All @@ -20,6 +20,8 @@ def __init__(self) -> None:
self.model = None
self.tokenizer = None
self._model_params = None
self.llm_adapter: BaseLLMAdaper = None
self.llm_chat_adapter: BaseChatAdpter = None

def load_worker(self, model_name: str, model_path: str, **kwargs) -> None:
if model_path.endswith("/"):
Expand All @@ -28,9 +30,9 @@ def load_worker(self, model_name: str, model_path: str, **kwargs) -> None:
self.model_name = model_name
self.model_path = model_path

llm_adapter = get_llm_model_adapter(self.model_name, self.model_path)
model_type = llm_adapter.model_type()
self.param_cls = llm_adapter.model_param_class(model_type)
self.llm_adapter = get_llm_model_adapter(self.model_name, self.model_path)
model_type = self.llm_adapter.model_type()
self.param_cls = self.llm_adapter.model_param_class(model_type)

self.llm_chat_adapter = get_llm_chat_adapter(self.model_name, self.model_path)
self.generate_stream_func = self.llm_chat_adapter.get_generate_stream_func(
Expand All @@ -50,12 +52,14 @@ def parse_parameters(self, command_args: List[str] = None) -> ModelParameters:
param_cls = self.model_param_class()
model_args = EnvArgumentParser()
env_prefix = EnvArgumentParser.get_env_prefix(self.model_name)
model_type = self.llm_adapter.model_type()
model_params: ModelParameters = model_args.parse_args_into_dataclass(
param_cls,
env_prefix=env_prefix,
command_args=command_args,
model_name=self.model_name,
model_path=self.model_path,
model_type=model_type,
)
if not model_params.device:
model_params.device = DEVICE
Expand Down
29 changes: 27 additions & 2 deletions pilot/model/worker/manager.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import httpx
import itertools
import json
import os
Expand Down Expand Up @@ -147,7 +148,7 @@ def __init__(
self.model_registry = model_registry

def _worker_key(self, worker_type: str, model_name: str) -> str:
return f"$${worker_type}_$$_{model_name}"
return f"{model_name}@{worker_type}"

def add_worker(
self,
Expand Down Expand Up @@ -311,7 +312,9 @@ async def _apply_worker(
# Apply to all workers
worker_instances = list(itertools.chain(*self.workers.values()))
logger.info(f"Apply to all workers: {worker_instances}")
await asyncio.gather(*(apply_func(worker) for worker in worker_instances))
return await asyncio.gather(
*(apply_func(worker) for worker in worker_instances)
)

async def _start_all_worker(
self, apply_req: WorkerApplyRequest
Expand Down Expand Up @@ -423,6 +426,28 @@ async def get_model_instances(
worker_instances.append(wr)
return worker_instances

async def worker_apply(self, apply_req: WorkerApplyRequest) -> WorkerApplyOutput:
async def _remote_apply_func(worker_run_data: WorkerRunData):
worker_addr = worker_run_data.worker.worker_addr
async with httpx.AsyncClient() as client:
response = await client.post(
worker_addr + "/apply",
headers=worker_run_data.worker.headers,
json=apply_req.dict(),
timeout=worker_run_data.worker.timeout,
)
if response.status_code == 200:
output = WorkerApplyOutput(**response.json())
logger.info(f"worker_apply success: {output}")
else:
output = WorkerApplyOutput(message=response.text)
logger.warn(f"worker_apply failed: {output}")
return output

results = await self._apply_worker(apply_req, _remote_apply_func)
if results:
return results[0]


class WorkerManagerAdapter(WorkerManager):
def __init__(self, worker_manager: WorkerManager = None) -> None:
Expand Down
Loading

0 comments on commit dd86fb8

Please sign in to comment.