Commit

hiyouga committed Jul 18, 2023
1 parent e4ba88f commit 6ec8108
Showing 3 changed files with 35 additions and 18 deletions.
24 changes: 13 additions & 11 deletions src/glmtuner/api/app.py
@@ -10,6 +10,8 @@
 from glmtuner.extras.misc import torch_gc
 from glmtuner.chat.stream_chat import ChatModel
 from glmtuner.api.protocol import (
+    Role,
+    Finish,
     ModelCard,
     ModelList,
     ChatMessage,
@@ -49,18 +51,18 @@ async def list_models():
 
 @app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
 async def create_chat_completion(request: ChatCompletionRequest):
-    if request.messages[-1].role != "user":
+    if request.messages[-1].role != Role.USER:
         raise HTTPException(status_code=400, detail="Invalid request")
     query = request.messages[-1].content
 
     prev_messages = request.messages[:-1]
-    if len(prev_messages) > 0 and prev_messages[0].role == "system":
+    if len(prev_messages) > 0 and prev_messages[0].role == Role.SYSTEM:
         query = prev_messages.pop(0).content + query
 
     history = []
     if len(prev_messages) % 2 == 0:
         for i in range(0, len(prev_messages), 2):
-            if prev_messages[i].role == "user" and prev_messages[i+1].role == "assistant":
+            if prev_messages[i].role == Role.USER and prev_messages[i+1].role == Role.ASSISTANT:
                 history.append([prev_messages[i].content, prev_messages[i+1].content])
 
     if request.stream:
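The message-folding logic in this hunk can be read in isolation: the last user message becomes the query, a leading system message is prepended to it, and the remaining turns are paired off into history. A minimal, self-contained sketch (plain dicts instead of pydantic ChatMessage objects; fold is a hypothetical name, not part of the commit):

    from typing import List, Tuple

    def fold(messages: List[dict]) -> Tuple[str, List[List[str]]]:
        # The last user message becomes the query.
        query = messages[-1]["content"]
        prev = messages[:-1]
        # A leading system message is prepended to the query text.
        if prev and prev[0]["role"] == "system":
            query = prev.pop(0)["content"] + query
        # Remaining messages are paired off as (user, assistant) turns.
        history = []
        if len(prev) % 2 == 0:
            for i in range(0, len(prev), 2):
                if prev[i]["role"] == "user" and prev[i + 1]["role"] == "assistant":
                    history.append([prev[i]["content"], prev[i + 1]["content"]])
        return query, history

    # fold([{"role": "system", "content": "Be brief. "},
    #       {"role": "user", "content": "Hi"},
    #       {"role": "assistant", "content": "Hello!"},
    #       {"role": "user", "content": "What is GLM?"}])
    # -> ("Be brief. What is GLM?", [["Hi", "Hello!"]])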
@@ -79,19 +81,19 @@ async def create_chat_completion(request: ChatCompletionRequest):
 
     choice_data = ChatCompletionResponseChoice(
         index=0,
-        message=ChatMessage(role="assistant", content=response),
-        finish_reason="stop"
+        message=ChatMessage(role=Role.ASSISTANT, content=response),
+        finish_reason=Finish.STOP
     )
 
-    return ChatCompletionResponse(model=request.model, choices=[choice_data], usage=usage, object="chat.completion")
+    return ChatCompletionResponse(model=request.model, choices=[choice_data], usage=usage)
 
 async def predict(query: str, history: List[Tuple[str, str]], request: ChatCompletionRequest):
     choice_data = ChatCompletionResponseStreamChoice(
         index=0,
-        delta=DeltaMessage(role="assistant"),
+        delta=DeltaMessage(role=Role.ASSISTANT),
         finish_reason=None
     )
-    chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data], object="chat.completion.chunk")
+    chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
     yield json.dumps(chunk, ensure_ascii=False)
 
     for new_text in chat_model.stream_chat(
@@ -105,15 +107,15 @@ async def predict(query: str, history: List[Tuple[str, str]], request: ChatCompletionRequest):
             delta=DeltaMessage(content=new_text),
             finish_reason=None
         )
-        chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data], object="chat.completion.chunk")
+        chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
         yield json.dumps(chunk, ensure_ascii=False)
 
     choice_data = ChatCompletionResponseStreamChoice(
         index=0,
         delta=DeltaMessage(),
-        finish_reason="stop"
+        finish_reason=Finish.STOP
     )
-    chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data], object="chat.completion.chunk")
+    chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
     yield json.dumps(chunk, ensure_ascii=False)
     yield "[DONE]"
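Taken together, predict emits one chunk carrying only the assistant role, one chunk per generated fragment, a final chunk whose only payload is finish_reason, and then the bare sentinel "[DONE]". An illustrative view of the resulting wire format (field values are made up; note also that json.dumps cannot encode a pydantic BaseModel directly, so in practice the chunk needs a chunk.dict()-style conversion first):

    {"id": "chatcmpl-default", "object": "chat.completion.chunk", "model": "...", "choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": null}]}
    {"id": "chatcmpl-default", "object": "chat.completion.chunk", "model": "...", "choices": [{"index": 0, "delta": {"content": "Hello"}, "finish_reason": null}]}
    {"id": "chatcmpl-default", "object": "chat.completion.chunk", "model": "...", "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}]}
    [DONE]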

26 changes: 19 additions & 7 deletions src/glmtuner/api/protocol.py
@@ -1,6 +1,18 @@
 import time
+from enum import Enum
 from pydantic import BaseModel, Field
-from typing import List, Literal, Optional
+from typing import List, Optional
+
+
+class Role(str, Enum):
+    USER = "user"
+    ASSISTANT = "assistant"
+    SYSTEM = "system"
+
+
+class Finish(str, Enum):
+    STOP = "stop"
+    LENGTH = "length"
 
 
 class ModelCard(BaseModel):
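Because Role and Finish mix in str, their members compare equal to the bare literals and JSON-encode as plain strings, which is why app.py can compare incoming message roles against Role.USER and still emit OpenAI-compatible payloads. A standalone illustration (not part of the commit):

    import json
    from enum import Enum

    class Role(str, Enum):
        USER = "user"
        ASSISTANT = "assistant"
        SYSTEM = "system"

    assert Role.USER == "user"                 # equal to the raw literal
    assert json.dumps(Role.USER) == '"user"'   # encodes as the underlying string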
@@ -19,12 +31,12 @@ class ModelList(BaseModel):
 
 
 class ChatMessage(BaseModel):
-    role: Literal["user", "assistant", "system"]
+    role: Role
     content: str
 
 
 class DeltaMessage(BaseModel):
-    role: Optional[Literal["user", "assistant", "system"]] = None
+    role: Optional[Role] = None
     content: Optional[str] = None
 
 
@@ -41,13 +53,13 @@ class ChatCompletionRequest(BaseModel):
 class ChatCompletionResponseChoice(BaseModel):
     index: int
     message: ChatMessage
-    finish_reason: Literal["stop", "length"]
+    finish_reason: Finish
 
 
 class ChatCompletionResponseStreamChoice(BaseModel):
     index: int
     delta: DeltaMessage
-    finish_reason: Optional[Literal["stop", "length"]] = None
+    finish_reason: Optional[Finish] = None
 
 
 class ChatCompletionResponseUsage(BaseModel):
@@ -58,7 +70,7 @@ class ChatCompletionResponseUsage(BaseModel):
 
 class ChatCompletionResponse(BaseModel):
     id: Optional[str] = "chatcmpl-default"
-    object: Literal["chat.completion"]
+    object: Optional[str] = "chat.completion"
     created: Optional[int] = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: List[ChatCompletionResponseChoice]
@@ -67,7 +79,7 @@ class ChatCompletionStreamResponse(BaseModel):
 
 class ChatCompletionStreamResponse(BaseModel):
     id: Optional[str] = "chatcmpl-default"
-    object: Literal["chat.completion.chunk"]
+    object: Optional[str] = "chat.completion.chunk"
     created: Optional[int] = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: List[ChatCompletionResponseStreamChoice]
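Loosening object from a required Literal to Optional[str] with a default is what lets app.py drop its explicit object="chat.completion" / object="chat.completion.chunk" arguments: pydantic now fills the field in at construction time. A minimal sketch under pydantic v1 semantics (stand-in class name, not the repo's):

    from typing import Optional
    from pydantic import BaseModel

    class StreamResponse(BaseModel):  # stand-in for ChatCompletionStreamResponse
        object: Optional[str] = "chat.completion.chunk"

    resp = StreamResponse()           # no object= needed at the call site
    assert resp.object == "chat.completion.chunk"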
3 changes: 3 additions & 0 deletions src/glmtuner/extras/callbacks.py
@@ -47,6 +47,9 @@ def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl
         r"""
         Event called after logging the last logs.
         """
+        if not state.is_world_process_zero:
+            return
+
         cur_time = time.time()
         cur_steps = state.log_history[-1].get("step")
         elapsed_time = cur_time - self.start_time
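Under distributed training the HF Trainer fires on_log in every process, while state.is_world_process_zero is True only on the global rank-0 process, so the early return added above keeps the timing bookkeeping from running once per replica. The pattern in isolation (illustrative callback, not the repo's LogCallback):

    from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl

    class MainProcessLogger(TrainerCallback):
        def on_log(self, args: TrainingArguments, state: TrainerState,
                   control: TrainerControl, **kwargs):
            if not state.is_world_process_zero:
                return  # non-main ranks skip logging entirely
            if state.log_history:
                print(state.log_history[-1])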
