fix streaming response in API
hiyouga committed Jul 5, 2023
1 parent 05ac8d0 commit 5f29b8a
Showing 4 changed files with 17 additions and 12 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -84,6 +84,7 @@ Our script now supports the following fine-tuning methods:
 - protobuf, cpm_kernels and sentencepiece
 - jieba, rouge_chinese and nltk (used at evaluation)
 - gradio and mdtex2html (used in web_demo.py)
+- uvicorn, fastapi and sse_starlette (used in api_demo.py)
 
 And **powerful GPUs**!

1 change: 1 addition & 0 deletions README_zh.md
@@ -90,6 +90,7 @@ huggingface-cli login
 - protobuf, cpm_kernels, sentencepiece
 - jieba, rouge_chinese, nltk(用于评估)
 - gradio, mdtex2html(用于网页端交互)
+- uvicorn, fastapi, sse_starlette(用于 API)
 
 以及 **强而有力的 GPU**

13 changes: 8 additions & 5 deletions requirements.txt
@@ -1,14 +1,17 @@
 torch>=1.13.1
-protobuf
-cpm_kernels
-sentencepiece
 transformers>=4.27.4
 datasets>=2.10.0
-accelerate>=0.18.0
+accelerate>=0.19.0
 peft>=0.3.0
-trl>=0.4.1
+trl>=0.4.4
+protobuf
+cpm_kernels
+sentencepiece
 jieba
 rouge_chinese
 nltk
 gradio
 mdtex2html
+uvicorn
+fastapi
+sse_starlette
14 changes: 7 additions & 7 deletions src/api_demo.py
@@ -11,7 +11,7 @@
 from fastapi import FastAPI, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
 from contextlib import asynccontextmanager
-from starlette.responses import StreamingResponse
+from sse_starlette import EventSourceResponse
 from typing import Any, Dict, List, Literal, Optional, Union
 
 from utils import (
@@ -134,7 +134,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
 
     if request.stream:
         generate = predict(query, history, gen_kwargs, request.model)
-        return StreamingResponse(generate, media_type="text/event-stream")
+        return EventSourceResponse(generate, media_type="text/event-stream")
 
     response, _ = model.chat(tokenizer, query, history=history, **gen_kwargs)
     choice_data = ChatCompletionResponseChoice(
@@ -155,7 +155,7 @@ async def predict(query: str, history: List[List[str]], gen_kwargs: Dict[str, An
         finish_reason=None
     )
     chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
-    yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
+    yield chunk.json(exclude_unset=True, ensure_ascii=False)
 
     current_length = 0

@@ -172,15 +172,16 @@ async def predict(query: str, history: List[List[str]], gen_kwargs: Dict[str, An
             finish_reason=None
         )
         chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
-        yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
+        yield chunk.json(exclude_unset=True, ensure_ascii=False)
 
     choice_data = ChatCompletionResponseStreamChoice(
         index=0,
         delta=DeltaMessage(),
         finish_reason="stop"
     )
     chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
-    yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
+    yield chunk.json(exclude_unset=True, ensure_ascii=False)
+    yield "[DONE]"
 
 
 if __name__ == "__main__":
@@ -195,5 +196,4 @@ async def predict(query: str, history: List[List[str]], gen_kwargs: Dict[str, An
     model = model.cuda()
 
     model.eval()
-
-    uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
+    uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)

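A note on what this fix changes: EventSourceResponse from sse_starlette frames every string yielded by predict() as a Server-Sent Events message, adding the data: prefix and the blank-line separator itself. The generator therefore now yields bare JSON chunks plus a final [DONE] sentinel instead of hand-formatted data: strings. Below is a minimal sketch of how a client might consume the resulting stream; the route path, model name, and messages field are assumptions based on the OpenAI-style request and response classes in api_demo.py, not details confirmed by this commit.

    # Hypothetical client for the streaming endpoint served by api_demo.py.
    # The route path and request fields are assumptions modeled on the
    # OpenAI chat-completions schema; only the host and port come from the diff.
    import json
    import requests

    payload = {
        "model": "chatglm2-6b",  # placeholder model name (assumption)
        "messages": [{"role": "user", "content": "Hello"}],
        "stream": True,
    }

    with requests.post("http://localhost:8000/v1/chat/completions",
                       json=payload, stream=True) as resp:
        for line in resp.iter_lines(decode_unicode=True):
            if not line or not line.startswith("data:"):
                continue  # skip blank separators and ping comments
            data = line[len("data:"):].strip()
            if data == "[DONE]":  # sentinel yielded at the end of predict()
                break
            chunk = json.loads(data)
            delta = chunk["choices"][0].get("delta", {})
            print(delta.get("content", ""), end="", flush=True)

Switching to EventSourceResponse also keeps the event framing consistent with the SSE specification and enables sse_starlette's periodic ping events, which a plain StreamingResponse does not provide.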