Skip to content

Commit

Permalink
feat: add xunfei spark bot
Browse files Browse the repository at this point in the history
  • Loading branch information
zhayujie committed Aug 25, 2023
1 parent 1171b04 commit a086f19
Show file tree
Hide file tree
Showing 8 changed files with 290 additions and 16 deletions.
15 changes: 12 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
最新版本支持的功能如下:

- [x] **多端部署:** 有多种部署方式可选择且功能完备,目前已支持个人微信,微信公众号和企业微信应用等部署方式
- [x] **基础对话:** 私聊及群聊的消息智能回复,支持多轮会话上下文记忆,支持 GPT-3, GPT-3.5, GPT-4, 文心一言模型
- [x] **基础对话:** 私聊及群聊的消息智能回复,支持多轮会话上下文记忆,支持 GPT-3, GPT-3.5, GPT-4, 文心一言, 讯飞星火
- [x] **语音识别:** 可识别语音消息,通过文字或语音回复,支持 azure, baidu, google, openai等多种语音模型
- [x] **图片生成:** 支持图片生成 和 图生图(如照片修复),可选择 Dell-E, stable diffusion, replicate, midjourney模型
- [x] **丰富插件:** 支持个性化插件扩展,已实现多角色切换、文字冒险、敏感词过滤、聊天记录总结等插件
Expand Down Expand Up @@ -113,7 +113,7 @@ pip3 install azure-cognitiveservices-speech
# config.json文件内容示例
{
"open_ai_api_key": "YOUR API KEY", # 填入上面创建的 OpenAI API KEY
"model": "gpt-3.5-turbo", # 模型名称。当use_azure_chatgpt为true时,其名称为Azure上model deployment名称
"model": "gpt-3.5-turbo", # 模型名称, 支持 gpt-3.5-turbo, gpt-3.5-turbo-16k, gpt-4, wenxin, xunfei
"proxy": "", # 代理客户端的ip和端口,国内环境开启代理的需要填写该项,如 "127.0.0.1:7890"
"single_chat_prefix": ["bot", "@bot"], # 私聊时文本需要包含该前缀才能触发机器人回复
"single_chat_reply_prefix": "[bot] ", # 私聊时自动回复的前缀,用于区分真人
Expand All @@ -129,7 +129,10 @@ pip3 install azure-cognitiveservices-speech
"azure_api_version": "", # 采用Azure ChatGPT时,API版本
"character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。", # 人格描述
# 订阅消息,公众号和企业微信channel中请填写,当被订阅时会自动回复,可使用特殊占位符。目前支持的占位符有{trigger_prefix},在程序中它会自动替换成bot的触发词。
"subscribe_msg": "感谢您的关注!\n这里是ChatGPT,可以自由对话。\n支持语音对话。\n支持图片输出,画字开头的消息将按要求创作图片。\n支持角色扮演和文字冒险等丰富插件。\n输入{trigger_prefix}#help 查看详细指令。"
"subscribe_msg": "感谢您的关注!\n这里是ChatGPT,可以自由对话。\n支持语音对话。\n支持图片输出,画字开头的消息将按要求创作图片。\n支持角色扮演和文字冒险等丰富插件。\n输入{trigger_prefix}#help 查看详细指令。",
"use_linkai": false, # 是否使用LinkAI接口,默认关闭,开启后可国内访问,使用知识库和MJ
"linkai_api_key": "", # LinkAI Api Key
"linkai_app_code": "" # LinkAI 应用code
}
```
**配置说明:**
Expand Down Expand Up @@ -166,6 +169,12 @@ pip3 install azure-cognitiveservices-speech
+ `character_desc` 配置中保存着你对机器人说的一段话,他会记住这段话并作为他的设定,你可以为他定制任何人格 (关于会话上下文的更多内容参考该 [issue](https://github.com/zhayujie/chatgpt-on-wechat/issues/43))
+ `subscribe_msg`:订阅消息,公众号和企业微信channel中请填写,当被订阅时会自动回复, 可使用特殊占位符。目前支持的占位符有{trigger_prefix},在程序中它会自动替换成bot的触发词。

**5.LinkAI配置 (可选)**

+ `use_linkai`: 是否使用LinkAI接口,开启后可国内访问,使用知识库和 `Midjourney` 绘画, 参考 [文档](https://link-ai.tech/platform/link-app/wechat)
+ `linkai_api_key`: LinkAI Api Key,可在 [控制台](https://chat.link-ai.tech/console/interface) 创建
+ `linkai_app_code`: LinkAI 应用code,选填

**本说明文档可能会未及时更新,当前所有可选的配置项均在该[`config.py`](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/config.py)中列出。**

## 运行
Expand Down
10 changes: 4 additions & 6 deletions bot/bot_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,31 +14,29 @@ def create_bot(bot_type):
# 替换Baidu Unit为Baidu文心千帆对话接口
# from bot.baidu.baidu_unit_bot import BaiduUnitBot
# return BaiduUnitBot()

from bot.baidu.baidu_wenxin import BaiduWenxinBot

return BaiduWenxinBot()

elif bot_type == const.CHATGPT:
# ChatGPT 网页端web接口
from bot.chatgpt.chat_gpt_bot import ChatGPTBot

return ChatGPTBot()

elif bot_type == const.OPEN_AI:
# OpenAI 官方对话模型API
from bot.openai.open_ai_bot import OpenAIBot

return OpenAIBot()

elif bot_type == const.CHATGPTONAZURE:
# Azure chatgpt service https://azure.microsoft.com/en-in/products/cognitive-services/openai-service/
from bot.chatgpt.chat_gpt_bot import AzureChatGPTBot

return AzureChatGPTBot()

elif bot_type == const.XUNFEI:
from bot.xunfei.xunfei_spark_bot import XunFeiBot
return XunFeiBot()

elif bot_type == const.LINKAI:
from bot.linkai.link_ai_bot import LinkAIBot
return LinkAIBot()

raise RuntimeError
15 changes: 14 additions & 1 deletion bot/chatgpt/chat_gpt_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,16 @@ def calc_tokens(self):
# refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
def num_tokens_from_messages(messages, model):
"""Returns the number of tokens used by a list of messages."""

if model in ["wenxin", "xunfei"]:
return num_tokens_by_character(messages)

import tiktoken

if model in ["gpt-3.5-turbo-0301", "gpt-35-turbo"]:
return num_tokens_from_messages(messages, model="gpt-3.5-turbo")
elif model in ["gpt-4-0314", "gpt-4-0613", "gpt-4-32k", "gpt-4-32k-0613", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613", "gpt-35-turbo-16k"]:
elif model in ["gpt-4-0314", "gpt-4-0613", "gpt-4-32k", "gpt-4-32k-0613", "gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613", "gpt-35-turbo-16k"]:
return num_tokens_from_messages(messages, model="gpt-4")

try:
Expand All @@ -85,3 +90,11 @@ def num_tokens_from_messages(messages, model):
num_tokens += tokens_per_name
num_tokens += 3 # every reply is primed with <|start|>assistant<|message|>
return num_tokens


def num_tokens_by_character(messages):
"""Returns the number of tokens used by a list of messages."""
tokens = 0
for msg in messages:
tokens += len(msg["content"])
return tokens
246 changes: 246 additions & 0 deletions bot/xunfei/xunfei_spark_bot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,246 @@
# encoding:utf-8

import requests, json
from bot.bot import Bot
from bot.session_manager import SessionManager
from bot.baidu.baidu_wenxin_session import BaiduWenxinSession
from bridge.context import ContextType, Context
from bridge.reply import Reply, ReplyType
from common.log import logger
from config import conf
from common import const
import time
import _thread as thread
import datetime
from datetime import datetime
from wsgiref.handlers import format_date_time
from urllib.parse import urlencode
import base64
import ssl
import hashlib
import hmac
import json
from time import mktime
from urllib.parse import urlparse
import websocket
import queue
import threading
import random

# 消息队列 map
queue_map = dict()


class XunFeiBot(Bot):
def __init__(self):
super().__init__()
self.app_id = conf().get("xunfei_app_id")
self.api_key = conf().get("xunfei_api_key")
self.api_secret = conf().get("xunfei_api_secret")
# 默认使用v2.0版本,1.5版本可设置为 general
self.domain = "generalv2"
# 默认使用v2.0版本,1.5版本可设置为 "ws://spark-api.xf-yun.com/v1.1/chat"
self.spark_url = "ws://spark-api.xf-yun.com/v2.1/chat"
self.host = urlparse(self.spark_url).netloc
self.path = urlparse(self.spark_url).path
self.answer = ""
# 和wenxin使用相同的session机制
self.sessions = SessionManager(BaiduWenxinSession, model=const.XUNFEI)

def reply(self, query, context: Context = None) -> Reply:
if context.type == ContextType.TEXT:
logger.info("[XunFei] query={}".format(query))
session_id = context["session_id"]
request_id = self.gen_request_id(session_id)
session = self.sessions.session_query(query, session_id)
threading.Thread(target=self.create_web_socket, args=(session.messages, request_id)).start()
depth = 0
time.sleep(0.1)
t1 = time.time()
usage = {}
while depth <= 300:
try:
data_queue = queue_map.get(request_id)
if not data_queue:
depth += 1
time.sleep(0.1)
continue
data_item = data_queue.get(block=True, timeout=0.1)
if data_item.is_end:
# 请求结束
del queue_map[request_id]
if data_item.reply:
self.answer += data_item.reply
usage = data_item.usage
break

self.answer += data_item.reply
depth += 1
except Exception as e:
depth += 1
continue
t2 = time.time()
logger.info(f"[XunFei-API] response={self.answer}, time={t2 - t1}s, usage={usage}")
self.sessions.session_reply(self.answer, session_id, usage.get("total_tokens"))
reply = Reply(ReplyType.TEXT, self.answer)
return reply
else:
reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
return reply

def create_web_socket(self, prompt, session_id, temperature=0.5):
logger.info(f"[XunFei] start connect, prompt={prompt}")
websocket.enableTrace(False)
wsUrl = self.create_url()
ws = websocket.WebSocketApp(wsUrl, on_message=on_message, on_error=on_error, on_close=on_close,
on_open=on_open)
data_queue = queue.Queue(1000)
queue_map[session_id] = data_queue
ws.appid = self.app_id
ws.question = prompt
ws.domain = self.domain
ws.session_id = session_id
ws.temperature = temperature
ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})

def gen_request_id(self, session_id: str):
return session_id + "_" + str(int(time.time())) + "" + str(random.randint(0, 100))

# 生成url
def create_url(self):
# 生成RFC1123格式的时间戳
now = datetime.now()
date = format_date_time(mktime(now.timetuple()))

# 拼接字符串
signature_origin = "host: " + self.host + "\n"
signature_origin += "date: " + date + "\n"
signature_origin += "GET " + self.path + " HTTP/1.1"

# 进行hmac-sha256进行加密
signature_sha = hmac.new(self.api_secret.encode('utf-8'), signature_origin.encode('utf-8'),
digestmod=hashlib.sha256).digest()

signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')

authorization_origin = f'api_key="{self.api_key}", algorithm="hmac-sha256", headers="host date request-line", ' \
f'signature="{signature_sha_base64}"'

authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')

# 将请求的鉴权参数组合为字典
v = {
"authorization": authorization,
"date": date,
"host": self.host
}
# 拼接鉴权参数,生成url
url = self.spark_url + '?' + urlencode(v)
# 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致
return url

def gen_params(self, appid, domain, question):
"""
通过appid和用户的提问来生成请参数
"""
data = {
"header": {
"app_id": appid,
"uid": "1234"
},
"parameter": {
"chat": {
"domain": domain,
"random_threshold": 0.5,
"max_tokens": 2048,
"auditing": "default"
}
},
"payload": {
"message": {
"text": question
}
}
}
return data


class ReplyItem:
def __init__(self, reply, usage=None, is_end=False):
self.is_end = is_end
self.reply = reply
self.usage = usage


# 收到websocket错误的处理
def on_error(ws, error):
logger.error("[XunFei] error:", error)


# 收到websocket关闭的处理
def on_close(ws, one, two):
data_queue = queue_map.get(ws.session_id)
data_queue.put("END")


# 收到websocket连接建立的处理
def on_open(ws):
logger.info(f"[XunFei] Start websocket, session_id={ws.session_id}")
thread.start_new_thread(run, (ws,))


def run(ws, *args):
data = json.dumps(gen_params(appid=ws.appid, domain=ws.domain, question=ws.question, temperature=ws.temperature))
ws.send(data)


# Websocket 操作
# 收到websocket消息的处理
def on_message(ws, message):
data = json.loads(message)
code = data['header']['code']
if code != 0:
logger.error(f'请求错误: {code}, {data}')
ws.close()
else:
choices = data["payload"]["choices"]
status = choices["status"]
content = choices["text"][0]["content"]
data_queue = queue_map.get(ws.session_id)
if not data_queue:
logger.error(f"[XunFei] can't find data queue, session_id={ws.session_id}")
return
reply_item = ReplyItem(content)
if status == 2:
usage = data["payload"].get("usage")
reply_item = ReplyItem(content, usage)
reply_item.is_end = True
ws.close()
data_queue.put(reply_item)


def gen_params(appid, domain, question, temperature=0.5):
"""
通过appid和用户的提问来生成请参数
"""
data = {
"header": {
"app_id": appid,
"uid": "1234"
},
"parameter": {
"chat": {
"domain": domain,
"temperature": temperature,
"random_threshold": 0.5,
"max_tokens": 2048,
"auditing": "default"
}
},
"payload": {
"message": {
"text": question
}
}
}
return data
2 changes: 2 additions & 0 deletions bridge/bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ def __init__(self):
self.btype["chat"] = const.CHATGPTONAZURE
if model_type in ["wenxin"]:
self.btype["chat"] = const.BAIDU
if model_type in ["xunfei"]:
self.btype["chat"] = const.XUNFEI
if conf().get("use_linkai") and conf().get("linkai_api_key"):
self.btype["chat"] = const.LINKAI
self.bots = {}
Expand Down
1 change: 1 addition & 0 deletions common/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
OPEN_AI = "openAI"
CHATGPT = "chatGPT"
BAIDU = "baidu"
XUNFEI = "xunfei"
CHATGPTONAZURE = "chatGPTOnAzure"
LINKAI = "linkai"

Expand Down
12 changes: 8 additions & 4 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
"open_ai_api_base": "https://api.openai.com/v1",
"proxy": "", # openai使用的代理
# chatgpt模型, 当use_azure_chatgpt为true时,其名称为Azure上model deployment名称
"model": "gpt-3.5-turbo", # 还支持 gpt-3.5-turbo-16k, gpt-4, wenxin
"model": "gpt-3.5-turbo", # 还支持 gpt-3.5-turbo-16k, gpt-4, wenxin, xunfei
"use_azure_chatgpt": False, # 是否使用azure的chatgpt
"azure_deployment_id": "", # azure 模型部署名称
"azure_api_version": "", # azure api版本
Expand Down Expand Up @@ -52,9 +52,13 @@
"request_timeout": 60, # chatgpt请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
"timeout": 120, # chatgpt重试超时时间,在这个时间内,将会自动重试
# Baidu 文心一言参数
"baidu_wenxin_model": "eb-instant", # 默认使用ERNIE-Bot-turbo模型
"baidu_wenxin_api_key": "", # Baidu api key
"baidu_wenxin_secret_key": "", # Baidu secret key
"baidu_wenxin_model": "eb-instant", # 默认使用ERNIE-Bot-turbo模型
"baidu_wenxin_api_key": "", # Baidu api key
"baidu_wenxin_secret_key": "", # Baidu secret key
# 讯飞星火API
"xunfei_app_id": "", # 讯飞应用ID
"xunfei_api_key": "", # 讯飞 API key
"xunfei_api_secret": "", # 讯飞 API secret
# 语音设置
"speech_recognition": False, # 是否开启语音识别
"group_speech_recognition": False, # 是否开启群组语音识别
Expand Down
Loading

0 comments on commit a086f19

Please sign in to comment.