Add Baichuan2-7B-Chat model API file (labring#404)
* Update image

* Update image info

* Update image info

* Create openai_api.py

* Create requirements.txt
stakeswky authored Oct 18, 2023
1 parent 3b776b6 commit b23e00f
Showing 2 changed files with 247 additions and 0 deletions.
233 changes: 233 additions & 0 deletions files/models/Baichuan2/openai_api.py
@@ -0,0 +1,233 @@
# coding=utf-8
# Implements API for Baichuan2-7B-Chat in OpenAI's format. (https://platform.openai.com/docs/api-reference/chat)
# Usage: python openai_api.py
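#
# A minimal example client call (a sketch; it assumes the server is running on
# localhost:6006 and that the `requests` package is available). Only stream mode
# is implemented, so "stream" must be true:
#
#   import requests
#   resp = requests.post(
#       "http://localhost:6006/v1/chat/completions",
#       json={"model": "gpt-3.5-turbo",
#             "messages": [{"role": "user", "content": "Hello"}],
#             "stream": True},
#       stream=True,
#   )
#   for line in resp.iter_lines():
#       print(line.decode("utf-8"))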

import gc
import time
import torch
import uvicorn
from pydantic import BaseModel, Field, validator
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from typing import Any, Dict, List, Optional, Union
from transformers import AutoModelForCausalLM, AutoTokenizer
from sse_starlette.sse import EventSourceResponse
from transformers.generation.utils import GenerationConfig
import random
import string


@asynccontextmanager
async def lifespan(app: FastAPI):  # collects GPU memory
    yield
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()


app = FastAPI(lifespan=lifespan)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

class ModelCard(BaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "owner"
    root: Optional[str] = None
    parent: Optional[str] = None
    permission: Optional[list] = None


class ModelList(BaseModel):
    object: str = "list"
    data: List[ModelCard] = []

class ChatMessage(BaseModel):
    role: str
    content: str

    @validator('role')
    def check_role(cls, v):
        if v not in ["user", "assistant", "system"]:
            raise ValueError('role must be one of "user", "assistant", "system"')
        return v


class DeltaMessage(BaseModel):
    role: Optional[str] = None
    content: Optional[str] = None

    @validator('role', allow_reuse=True)
    def check_role(cls, v):
        if v is not None and v not in ["user", "assistant", "system"]:
            raise ValueError('role must be one of "user", "assistant", "system"')
        return v

class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[ChatMessage]
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    max_length: Optional[int] = 8192
    stream: Optional[bool] = False


class ChatCompletionResponseChoice(BaseModel):
    index: int
    message: ChatMessage
    finish_reason: str

    @validator('finish_reason')
    def check_finish_reason(cls, v):
        if v not in ["stop", "length"]:
            raise ValueError('finish_reason must be one of "stop" or "length"')
        return v


class ChatCompletionResponseStreamChoice(BaseModel):
    index: int
    delta: DeltaMessage
    finish_reason: Optional[str]

    @validator('finish_reason', allow_reuse=True)
    def check_finish_reason(cls, v):
        if v is not None and v not in ["stop", "length"]:
            raise ValueError('finish_reason must be one of "stop" or "length"')
        return v

class ChatCompletionResponse(BaseModel):
    id: str
    object: str
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]

    @validator('object')
    def check_object(cls, v):
        if v not in ["chat.completion", "chat.completion.chunk"]:
            raise ValueError("object must be one of 'chat.completion' or 'chat.completion.chunk'")
        return v


def generate_id():
    # Generate an OpenAI-style completion id: "chatcmpl-" plus 29 random alphanumerics.
    possible_characters = string.ascii_letters + string.digits
    random_string = ''.join(random.choices(possible_characters, k=29))
    return 'chatcmpl-' + random_string


@app.get("/v1/models", response_model=ModelList)
async def list_models():
global model_args
model_card = ModelCard(id="gpt-3.5-turbo")
return ModelList(data=[model_card])


@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
    global model, tokenizer
    if request.messages[-1].role != "user":
        raise HTTPException(status_code=400, detail="Invalid request")
    query = request.messages[-1].content
    prev_messages = request.messages[:-1]
    # Prepend the system prompt (if any) to the user query.
    if len(prev_messages) > 0 and prev_messages[0].role == "system":
        query = prev_messages.pop(0).content + query
    messages = []
    for message in prev_messages:
        messages.append({"role": message.role, "content": message.content})

    messages.append({"role": "user", "content": query})

    if request.stream:
        generate = predict(messages, request.model)
        return EventSourceResponse(generate, media_type="text/event-stream")

    # Non-stream mode is not implemented; return a fixed notice instead.
    response = 'This endpoint does not support non-stream mode.'
    choice_data = ChatCompletionResponseChoice(
        index=0,
        message=ChatMessage(role="assistant", content=response),
        finish_reason="stop"
    )

    return ChatCompletionResponse(id=generate_id(), model=request.model, choices=[choice_data], object="chat.completion")


async def predict(messages: List[Dict[str, str]], model_id: str):
    global model, tokenizer
    id = generate_id()
    created = int(time.time())

    # First chunk: announce the assistant role.
    choice_data = ChatCompletionResponseStreamChoice(
        index=0,
        delta=DeltaMessage(role="assistant", content=""),
        finish_reason=None
    )
    chunk = ChatCompletionResponse(id=id, object="chat.completion.chunk", created=created, model=model_id, choices=[choice_data])
    yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))

    current_length = 0

    # model.chat(..., stream=True) yields the full response so far; emit only the new suffix.
    for new_response in model.chat(tokenizer, messages, stream=True):
        if len(new_response) == current_length:
            continue

        new_text = new_response[current_length:]
        current_length = len(new_response)

        choice_data = ChatCompletionResponseStreamChoice(
            index=0,
            delta=DeltaMessage(content=new_text),
            finish_reason=None
        )
        chunk = ChatCompletionResponse(id=id, object="chat.completion.chunk", created=created, model=model_id, choices=[choice_data])
        yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))

    # Final chunk: signal completion, then the SSE terminator.
    choice_data = ChatCompletionResponseStreamChoice(
        index=0,
        delta=DeltaMessage(),
        finish_reason="stop"
    )
    chunk = ChatCompletionResponse(id=id, object="chat.completion.chunk", created=created, model=model_id, choices=[choice_data])
    yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
    yield '[DONE]'


def load_models():
    print("Loading language model: Baichuan2-7B-Chat")
    tokenizer = AutoTokenizer.from_pretrained("baichuan-inc/Baichuan2-7B-Chat", use_fast=False, trust_remote_code=True)
    # model = AutoModelForCausalLM.from_pretrained("Baichuan2-13B-Chat", torch_dtype=torch.float32, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained("baichuan-inc/Baichuan2-7B-Chat", torch_dtype=torch.float16, trust_remote_code=True)
    model = model.cuda()
    model.generation_config = GenerationConfig.from_pretrained("baichuan-inc/Baichuan2-7B-Chat")
    return tokenizer, model

if __name__ == "__main__":
    tokenizer, model = load_models()
    uvicorn.run(app, host='0.0.0.0', port=6006, workers=1)

    # Note: uvicorn.run() blocks, so the loop below only runs after the server exits.
    while True:
        try:
            # Your program logic goes here.

            # Check GPU memory usage; if it exceeds a threshold (e.g. 90%), trigger garbage collection.
            if torch.cuda.is_available():
                gpu_memory_usage = torch.cuda.memory_allocated() / torch.cuda.max_memory_allocated()
                if gpu_memory_usage > 0.9:
                    gc.collect()
                    torch.cuda.empty_cache()
        except RuntimeError as e:
            if "out of memory" in str(e):
                print("GPU out of memory, reloading the model...")
                gc.collect()
                torch.cuda.empty_cache()
                time.sleep(5)  # wait a moment to make sure GPU memory has been released
                tokenizer, model = load_models()
            else:
                raise e


14 changes: 14 additions & 0 deletions files/models/Baichuan2/requirements.txt
@@ -0,0 +1,14 @@
protobuf
transformers==4.30.2
cpm_kernels
torch>=2.0
gradio
mdtex2html
sentencepiece
accelerate
sse-starlette
fastapi==0.99.1
pydantic==1.10.7
uvicorn==0.21.1
xformers
bitsandbytes
