forked from labring/FastGPT
Add Baichuan2-7B-Chat model API file (labring#404)
* Update image * Update image info * Update image info * Create openai_api.py * Create requirements.txt
Showing 2 changed files with 247 additions and 0 deletions.
openai_api.py
# coding=utf-8
# Implements an API for Baichuan2-7B-Chat in OpenAI's format. (https://platform.openai.com/docs/api-reference/chat)
# Usage: python openai_api.py

import gc
import time
import torch
import uvicorn
from pydantic import BaseModel, Field, validator
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from typing import Any, Dict, List, Optional, Union
from transformers import AutoModelForCausalLM, AutoTokenizer
from sse_starlette.sse import ServerSentEvent, EventSourceResponse
from transformers.generation.utils import GenerationConfig
import random
import string


@asynccontextmanager
async def lifespan(app: FastAPI):  # collects GPU memory on shutdown
    yield
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()


app = FastAPI(lifespan=lifespan)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class ModelCard(BaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "owner"
    root: Optional[str] = None
    parent: Optional[str] = None
    permission: Optional[list] = None


class ModelList(BaseModel):
    object: str = "list"
    data: List[ModelCard] = []


class ChatMessage(BaseModel):
    role: str
    content: str

    @validator('role')
    def check_role(cls, v):
        if v not in ["user", "assistant", "system"]:
            raise ValueError('role must be one of "user", "assistant", "system"')
        return v


class DeltaMessage(BaseModel):
    role: Optional[str] = None
    content: Optional[str] = None

    @validator('role', allow_reuse=True)
    def check_role(cls, v):
        if v is not None and v not in ["user", "assistant", "system"]:
            raise ValueError('role must be one of "user", "assistant", "system"')
        return v


class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[ChatMessage]
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    max_length: Optional[int] = 8192
    stream: Optional[bool] = False


class ChatCompletionResponseChoice(BaseModel):
    index: int
    message: ChatMessage
    finish_reason: str

    @validator('finish_reason')
    def check_finish_reason(cls, v):
        if v not in ["stop", "length"]:
            raise ValueError('finish_reason must be one of "stop" or "length"')
        return v


class ChatCompletionResponseStreamChoice(BaseModel):
    index: int
    delta: DeltaMessage
    finish_reason: Optional[str]

    @validator('finish_reason', allow_reuse=True)
    def check_finish_reason(cls, v):
        if v is not None and v not in ["stop", "length"]:
            raise ValueError('finish_reason must be one of "stop" or "length"')
        return v


class ChatCompletionResponse(BaseModel):
    id: str
    object: str

    @validator('object')
    def check_object(cls, v):
        if v not in ["chat.completion", "chat.completion.chunk"]:
            raise ValueError("object must be one of 'chat.completion' or 'chat.completion.chunk'")
        return v

    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]


def generate_id():
    # OpenAI-style completion ids: "chatcmpl-" followed by 29 random characters.
    possible_characters = string.ascii_letters + string.digits
    random_string = ''.join(random.choices(possible_characters, k=29))
    return 'chatcmpl-' + random_string


@app.get("/v1/models", response_model=ModelList)
async def list_models():
    model_card = ModelCard(id="gpt-3.5-turbo")
    return ModelList(data=[model_card])


@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
    global model, tokenizer
    if request.messages[-1].role != "user":
        raise HTTPException(status_code=400, detail="Invalid request")
    query = request.messages[-1].content
    prev_messages = request.messages[:-1]
    # If the first message is a system prompt, prepend it to the user query.
    if len(prev_messages) > 0 and prev_messages[0].role == "system":
        query = prev_messages.pop(0).content + query
    messages = []
    for message in prev_messages:
        messages.append({"role": message.role, "content": message.content})

    messages.append({"role": "user", "content": query})

    if request.stream:
        generate = predict(messages, request.model)
        return EventSourceResponse(generate, media_type="text/event-stream")

    response = 'This endpoint does not support non-stream mode'
    choice_data = ChatCompletionResponseChoice(
        index=0,
        message=ChatMessage(role="assistant", content=response),
        finish_reason="stop"
    )

    return ChatCompletionResponse(id=generate_id(), model=request.model, choices=[choice_data], object="chat.completion")


async def predict(messages: List[Dict[str, str]], model_id: str):
    global model, tokenizer
    id = generate_id()
    created = int(time.time())
    # First chunk carries only the assistant role.
    choice_data = ChatCompletionResponseStreamChoice(
        index=0,
        delta=DeltaMessage(role="assistant", content=""),
        finish_reason=None
    )
    chunk = ChatCompletionResponse(id=id, object="chat.completion.chunk", created=created, model=model_id, choices=[choice_data])
    yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))

    current_length = 0

    for new_response in model.chat(tokenizer, messages, stream=True):
        if len(new_response) == current_length:
            continue

        new_text = new_response[current_length:]
        current_length = len(new_response)

        choice_data = ChatCompletionResponseStreamChoice(
            index=0,
            delta=DeltaMessage(content=new_text),
            finish_reason=None
        )
        chunk = ChatCompletionResponse(id=id, object="chat.completion.chunk", created=created, model=model_id, choices=[choice_data])
        yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))

    # Final chunk signals the end of the stream.
    choice_data = ChatCompletionResponseStreamChoice(
        index=0,
        delta=DeltaMessage(),
        finish_reason="stop"
    )
    chunk = ChatCompletionResponse(id=id, object="chat.completion.chunk", created=created, model=model_id, choices=[choice_data])
    yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
    yield '[DONE]'


def load_models():
    print("Loading LLM: Baichuan2-7B-Chat")
    tokenizer = AutoTokenizer.from_pretrained("baichuan-inc/Baichuan2-7B-Chat", use_fast=False, trust_remote_code=True)
    # model = AutoModelForCausalLM.from_pretrained("Baichuan2-13B-Chat", torch_dtype=torch.float32, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained("baichuan-inc/Baichuan2-7B-Chat", torch_dtype=torch.float16, trust_remote_code=True)
    model = model.cuda()
    model.generation_config = GenerationConfig.from_pretrained("baichuan-inc/Baichuan2-7B-Chat")
    return tokenizer, model


if __name__ == "__main__":
    tokenizer, model = load_models()
    uvicorn.run(app, host='0.0.0.0', port=6006, workers=1)

    # Note: uvicorn.run() blocks until the server exits, so this loop only runs
    # after shutdown; it is kept from the original as a best-effort recovery loop.
    while True:
        try:
            # Check GPU memory usage; if it exceeds a threshold (e.g. 90%), trigger garbage collection.
            if torch.cuda.is_available():
                gpu_memory_usage = torch.cuda.memory_allocated() / torch.cuda.max_memory_allocated()
                if gpu_memory_usage > 0.9:
                    gc.collect()
                    torch.cuda.empty_cache()
        except RuntimeError as e:
            if "out of memory" in str(e):
                print("Out of GPU memory, restarting...")
                gc.collect()
                torch.cuda.empty_cache()
                time.sleep(5)  # wait a moment to ensure GPU memory has been released
                tokenizer, model = load_models()
            else:
                raise e
requirements.txt
protobuf
transformers==4.30.2
cpm_kernels
torch>=2.0
gradio
mdtex2html
sentencepiece
accelerate
sse-starlette
fastapi==0.99.1
pydantic==1.10.7
uvicorn==0.21.1
xformers
bitsandbytes
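The pinned versions matter here: pydantic 1.10.7 and fastapi 0.99.1 predate pydantic v2, matching the v1-style `@validator` decorators used in openai_api.py. The dependencies can be installed with `pip install -r requirements.txt` before starting the server.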