Optimize the "amount used" (已使用金额) billing display
zsunreal committed May 5, 2023
1 parent a455023 commit 9514f78
Showing 2 changed files with 38 additions and 33 deletions.
ChuanhuChatbot.py: 4 changes (2 additions, 2 deletions)
@@ -266,8 +266,8 @@ def create_new_model():
                 changeProxyBtn = gr.Button(i18n("🔄 设置代理地址"))
                 default_btn = gr.Button(i18n("🔙 恢复默认设置"))

-    gr.Markdown(CHUANHU_DESCRIPTION, elem_id="description")
-    gr.HTML(FOOTER.format(versions=versions_html()), elem_id="footer")
+    # gr.Markdown(CHUANHU_DESCRIPTION, elem_id="description")
+    # gr.HTML(FOOTER.format(versions=versions_html()), elem_id="footer")

     # https://github.com/gradio-app/gradio/pull/3296
     def create_greeting(request: gr.Request):
modules/models/models.py: 67 changes (36 additions, 31 deletions)
@@ -20,6 +20,7 @@
 import aiohttp
 from enum import Enum
 import uuid
+from datetime import timedelta

 from ..presets import *
 from ..llama_func import *
@@ -32,13 +33,13 @@

 class OpenAIClient(BaseLLMModel):
     def __init__(
-            self,
-            model_name,
-            api_key,
-            system_prompt=INITIAL_SYSTEM_PROMPT,
-            temperature=1.0,
-            top_p=1.0,
-            user_name=""
+        self,
+        model_name,
+        api_key,
+        system_prompt=INITIAL_SYSTEM_PROMPT,
+        temperature=1.0,
+        top_p=1.0,
+        user_name=""
     ) -> None:
         super().__init__(
             model_name=model_name,
@@ -81,10 +82,11 @@ def count_token(self, user_input):
     def billing_info(self):
         try:
             curr_time = datetime.datetime.now()
+            curr_time_f = (curr_time + datetime.timedelta(days=1)).strftime("%Y-%m-%d")
-            last_day_of_month = get_last_day_of_month(
-                curr_time).strftime("%Y-%m-%d")
-            first_day_of_month = curr_time.replace(day=1).strftime("%Y-%m-%d")
-            usage_url = f"{shared.state.usage_api_url}?start_date={first_day_of_month}&end_date={last_day_of_month}"
+            first_day_of_month = (curr_time.replace(day=1) - datetime.timedelta(weeks=11)).strftime("%Y-%m-%d")
+            usage_url = f"{shared.state.usage_api_url}?start_date={first_day_of_month}&end_date={curr_time_f}"
             try:
                 usage_data = self._get_billing_data(usage_url)
             except Exception as e:
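Taken together, this hunk replaces the calendar-month window (first to last day of the current month) with a rolling one: the end date becomes tomorrow, and the start date is the first of the current month pushed back 11 weeks. A minimal sketch of the new window, using only the standard library (the sample date is illustrative):

```python
import datetime

def usage_window(now: datetime.datetime) -> tuple[str, str]:
    # End date: tomorrow, so that today's usage is included in the result.
    end = (now + datetime.timedelta(days=1)).strftime("%Y-%m-%d")
    # Start date: the 1st of the current month, pushed back a further 11 weeks.
    start = (now.replace(day=1) - datetime.timedelta(weeks=11)).strftime("%Y-%m-%d")
    return start, end

print(usage_window(datetime.datetime(2023, 5, 5)))  # ('2023-02-13', '2023-05-06')
```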
@@ -95,7 +97,7 @@ def billing_info(self):
             usage_percent = round(usage_data["total_usage"] / usage_limit, 2)
             # return i18n("**本月使用金额** ") + f"\u3000 ${rounded_usage}"
             return """\
-                <b>""" + i18n("本月使用金额") + f"""</b>
+                <b>""" + i18n("已使用金额") + f"""</b>
                 <div class="progress-bar">
                 <div class="progress" style="width: {usage_percent}%;">
                 <span class="progress-text">{usage_percent}%</span>
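`_get_billing_data` itself is outside this diff; what the surrounding code implies is a GET against `usage_url` returning JSON with a `total_usage` field. A hedged sketch under those assumptions (the auth header and timeout are guesses, not the project's actual helper):

```python
import requests

def get_billing_data(url: str, api_key: str) -> dict:
    # Assumed response shape: a JSON object exposing "total_usage",
    # which billing_info divides by the account's usage limit.
    resp = requests.get(url, headers={"Authorization": f"Bearer {api_key}"}, timeout=10)
    resp.raise_for_status()
    return resp.json()
```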
@@ -105,7 +107,7 @@ def billing_info(self):
                 """
         except requests.exceptions.ConnectTimeout:
             status_text = (
-                    STANDARD_ERROR_MSG + CONNECTION_TIMEOUT_MSG + ERROR_RETRIEVE_MSG
+                STANDARD_ERROR_MSG + CONNECTION_TIMEOUT_MSG + ERROR_RETRIEVE_MSG
             )
             return status_text
         except requests.exceptions.ReadTimeout:
@@ -275,7 +277,7 @@ def _get_glm_style_input(self):
         logging.debug(colorama.Fore.YELLOW +
                       f"{history}" + colorama.Fore.RESET)
         assert (
-                len(history) % 2 == 0
+            len(history) % 2 == 0
         ), f"History should be even length. current history is: {history}"
         history = [[history[i], history[i + 1]]
                    for i in range(0, len(history), 2)]
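The pairing step above folds a flat, even-length message list into [user, assistant] pairs; a tiny worked example with invented messages:

```python
history = ["hi", "hello!", "how are you?", "fine, thanks."]
assert len(history) % 2 == 0, "History should be even length."
pairs = [[history[i], history[i + 1]] for i in range(0, len(history), 2)]
print(pairs)  # [['hi', 'hello!'], ['how are you?', 'fine, thanks.']]
```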
@@ -290,22 +292,22 @@ def get_answer_at_once(self):
     def get_answer_stream_iter(self):
         history, query = self._get_glm_style_input()
         for response, history in CHATGLM_MODEL.stream_chat(
-                CHATGLM_TOKENIZER,
-                query,
-                history,
-                max_length=self.token_upper_limit,
-                top_p=self.top_p,
-                temperature=self.temperature,
+            CHATGLM_TOKENIZER,
+            query,
+            history,
+            max_length=self.token_upper_limit,
+            top_p=self.top_p,
+            temperature=self.temperature,
         ):
             yield response


 class LLaMA_Client(BaseLLMModel):
     def __init__(
-            self,
-            model_name,
-            lora_path=None,
-            user_name=""
+        self,
+        model_name,
+        lora_path=None,
+        user_name=""
     ) -> None:
         super().__init__(model_name=model_name, user=user_name)
         from lmflow.datasets.dataset import Dataset
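ChatGLM's `stream_chat` yields the whole answer generated so far on each iteration, so `get_answer_stream_iter` lets a caller simply re-render every yielded value; a hypothetical consumer (the `client` instance is assumed, not constructed here):

```python
# Hypothetical usage; `client` is an already-constructed ChatGLM-backed model.
for partial in client.get_answer_stream_iter():
    print(partial)  # each `partial` is the full answer accumulated so far
```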
@@ -334,8 +336,11 @@ def __init__(
         # raise Exception(f"models目录下没有这个模型: {model_name}")
         if lora_path is not None:
             lora_path = f"lora/{lora_path}"
-        model_args = ModelArguments(model_name_or_path=model_source, lora_model_path=lora_path, model_type=None, config_overrides=None, config_name=None, tokenizer_name=None, cache_dir=None,
-                                    use_fast_tokenizer=True, model_revision='main', use_auth_token=False, torch_dtype=None, use_lora=False, lora_r=8, lora_alpha=32, lora_dropout=0.1, use_ram_optimized_load=True)
+        model_args = ModelArguments(model_name_or_path=model_source, lora_model_path=lora_path, model_type=None,
+                                    config_overrides=None, config_name=None, tokenizer_name=None, cache_dir=None,
+                                    use_fast_tokenizer=True, model_revision='main', use_auth_token=False,
+                                    torch_dtype=None, use_lora=False, lora_r=8, lora_alpha=32, lora_dropout=0.1,
+                                    use_ram_optimized_load=True)
         pipeline_args = InferencerArguments(
             local_rank=0, random_seed=1, deepspeed='configs/ds_config_chatbot.json', mixed_precision='bf16')
@@ -540,13 +545,13 @@ def get_answer_at_once(self):


 def get_model(
-        model_name,
-        lora_model_path=None,
-        access_key=None,
-        temperature=None,
-        top_p=None,
-        system_prompt=None,
-        user_name=""
+    model_name,
+    lora_model_path=None,
+    access_key=None,
+    temperature=None,
+    top_p=None,
+    system_prompt=None,
+    user_name=""
 ) -> BaseLLMModel:
     msg = i18n("模型设置为了:") + f" {model_name}"
     model_type = ModelType.get_type(model_name)
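Per its annotation, `get_model` returns a `BaseLLMModel`; a hypothetical call with example arguments (the values are placeholders, not defaults from the project):

```python
# Placeholder values; access_key in particular is not a real key.
client = get_model(
    model_name="gpt-3.5-turbo",
    access_key="sk-placeholder",
    temperature=1.0,
    top_p=1.0,
    user_name="alice",
)
```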
