diff --git a/src/glmtuner/api/app.py b/src/glmtuner/api/app.py
index 3c240fc..63dbd72 100644
--- a/src/glmtuner/api/app.py
+++ b/src/glmtuner/api/app.py
@@ -10,6 +10,8 @@
 from glmtuner.extras.misc import torch_gc
 from glmtuner.chat.stream_chat import ChatModel
 from glmtuner.api.protocol import (
+    Role,
+    Finish,
     ModelCard,
     ModelList,
     ChatMessage,
@@ -49,18 +51,18 @@ async def list_models():
 
 @app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
 async def create_chat_completion(request: ChatCompletionRequest):
-    if request.messages[-1].role != "user":
+    if request.messages[-1].role != Role.USER:
         raise HTTPException(status_code=400, detail="Invalid request")
     query = request.messages[-1].content
 
     prev_messages = request.messages[:-1]
-    if len(prev_messages) > 0 and prev_messages[0].role == "system":
+    if len(prev_messages) > 0 and prev_messages[0].role == Role.SYSTEM:
         query = prev_messages.pop(0).content + query
 
     history = []
     if len(prev_messages) % 2 == 0:
         for i in range(0, len(prev_messages), 2):
-            if prev_messages[i].role == "user" and prev_messages[i+1].role == "assistant":
+            if prev_messages[i].role == Role.USER and prev_messages[i+1].role == Role.ASSISTANT:
                 history.append([prev_messages[i].content, prev_messages[i+1].content])
 
     if request.stream:
@@ -79,19 +81,19 @@ async def create_chat_completion(request: ChatCompletionRequest):
 
     choice_data = ChatCompletionResponseChoice(
         index=0,
-        message=ChatMessage(role="assistant", content=response),
-        finish_reason="stop"
+        message=ChatMessage(role=Role.ASSISTANT, content=response),
+        finish_reason=Finish.STOP
     )
-    return ChatCompletionResponse(model=request.model, choices=[choice_data], usage=usage, object="chat.completion")
+    return ChatCompletionResponse(model=request.model, choices=[choice_data], usage=usage)
 
 
 async def predict(query: str, history: List[Tuple[str, str]], request: ChatCompletionRequest):
     choice_data = ChatCompletionResponseStreamChoice(
         index=0,
-        delta=DeltaMessage(role="assistant"),
+        delta=DeltaMessage(role=Role.ASSISTANT),
         finish_reason=None
     )
-    chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data], object="chat.completion.chunk")
+    chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
     yield json.dumps(chunk, ensure_ascii=False)
 
     for new_text in chat_model.stream_chat(
@@ -105,15 +107,15 @@ async def predict(query: str, history: List[Tuple[str, str]], request: ChatCompl
             delta=DeltaMessage(content=new_text),
             finish_reason=None
         )
-        chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data], object="chat.completion.chunk")
+        chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
         yield json.dumps(chunk, ensure_ascii=False)
 
     choice_data = ChatCompletionResponseStreamChoice(
         index=0,
         delta=DeltaMessage(),
-        finish_reason="stop"
+        finish_reason=Finish.STOP
    )
-    chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data], object="chat.completion.chunk")
+    chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
     yield json.dumps(chunk, ensure_ascii=False)
 
     yield "[DONE]"
diff --git a/src/glmtuner/api/protocol.py b/src/glmtuner/api/protocol.py
index 08aea3c..cba0b6a 100644
--- a/src/glmtuner/api/protocol.py
+++ b/src/glmtuner/api/protocol.py
@@ -1,6 +1,18 @@
 import time
+from enum import Enum
 from pydantic import BaseModel, Field
-from typing import List, Literal, Optional
+from typing import List, Optional
+
+
+class Role(str, Enum):
+    USER = "user"
+    ASSISTANT = "assistant"
+    SYSTEM = "system"
+
+
+class Finish(str, Enum):
+    STOP = "stop"
+    LENGTH = "length"
 
 
 class ModelCard(BaseModel):
@@ -19,12 +31,12 @@ class ModelList(BaseModel):
 
 
 class ChatMessage(BaseModel):
-    role: Literal["user", "assistant", "system"]
+    role: Role
     content: str
 
 
 class DeltaMessage(BaseModel):
-    role: Optional[Literal["user", "assistant", "system"]] = None
+    role: Optional[Role] = None
     content: Optional[str] = None
 
 
@@ -41,13 +53,13 @@ class ChatCompletionRequest(BaseModel):
 class ChatCompletionResponseChoice(BaseModel):
     index: int
     message: ChatMessage
-    finish_reason: Literal["stop", "length"]
+    finish_reason: Finish
 
 
 class ChatCompletionResponseStreamChoice(BaseModel):
     index: int
     delta: DeltaMessage
-    finish_reason: Optional[Literal["stop", "length"]] = None
+    finish_reason: Optional[Finish] = None
 
 
 class ChatCompletionResponseUsage(BaseModel):
@@ -58,7 +70,7 @@ class ChatCompletionResponseUsage(BaseModel):
 
 class ChatCompletionResponse(BaseModel):
     id: Optional[str] = "chatcmpl-default"
-    object: Literal["chat.completion"]
+    object: Optional[str] = "chat.completion"
     created: Optional[int] = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: List[ChatCompletionResponseChoice]
@@ -67,7 +79,7 @@
 
 class ChatCompletionStreamResponse(BaseModel):
     id: Optional[str] = "chatcmpl-default"
-    object: Literal["chat.completion.chunk"]
+    object: Optional[str] = "chat.completion.chunk"
     created: Optional[int] = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: List[ChatCompletionResponseStreamChoice]
diff --git a/src/glmtuner/extras/callbacks.py b/src/glmtuner/extras/callbacks.py
index 493596b..112178b 100644
--- a/src/glmtuner/extras/callbacks.py
+++ b/src/glmtuner/extras/callbacks.py
@@ -47,6 +47,9 @@ def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerC
         r"""
         Event called after logging the last logs.
         """
+        if not state.is_world_process_zero:
+            return
+
         cur_time = time.time()
         cur_steps = state.log_history[-1].get("step")
         elapsed_time = cur_time - self.start_time
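Note (not part of the patch): the swap from Literal[...] annotations to Role/Finish is behavior-preserving because both enums mix in str, so their members still compare equal to the old literal strings and encode to the same JSON values. A minimal sketch of that property, using only the standard library and a copy of the Role definition from protocol.py:

import json
from enum import Enum


class Role(str, Enum):  # mirrors the Role enum added in protocol.py
    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"


# str-mixin members compare equal to the plain strings they replace,
assert Role.USER == "user"
# and json encodes them as their underlying string value.
assert json.dumps({"role": Role.ASSISTANT}) == '{"role": "assistant"}'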