Skip to content

Commit

Permalink
Merge branch 'main' of github.com:YerongLi/ChatGLM-Efficient-Tuning
Browse files Browse the repository at this point in the history
  • Loading branch information
YerongLi committed Jul 6, 2023
2 parents c76bede + dee88ac commit 61fafdd
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 5 deletions.
3 changes: 2 additions & 1 deletion demo.sh
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
python src/cli_demo.py \
--model_name_or_path THUDM/chatglm-6b \
--checkpoint_dir harry_potter/checkpoint-700
--checkpoint_dir harry_potter/checkpoint-700 \
--quantization_bit 4
25 changes: 21 additions & 4 deletions src/cli_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import os
import torch
import random
import signal
import platform

Expand Down Expand Up @@ -44,6 +45,22 @@ def signal_handler(signal, frame):
("If you could give one piece of advice to young witches and wizards starting their magical education, what would it be?", "I would advise them to believe in themselves and not be afraid to ask for help when needed.")
]

def truncate_history(history):
total_words = 0
selected_history = []

for i in range(len(history)-1, -1, -1):
question, answer = history[i]
words = len(question.split()) + len(answer.split())

if total_words + words <= 1800:
selected_history.append((question, answer))
total_words += words
else:
break

return selected_history

def main():

global stop_stream
Expand All @@ -60,7 +77,7 @@ def main():
model.eval()


history = []
history = buffered_history.copy()
print(welcome)
while True:
try:
Expand All @@ -74,13 +91,13 @@ def main():
if query.strip() == "stop":
break
if query.strip() == "clear":
history = []
history = buffered_history.copy()
os.system(clear_command)
print(welcome)
continue

history = truncate_history(history)
count = 0
for _, history in model.stream_chat(tokenizer, query, history=(buffered_history + history)[:-6], **generating_args.to_dict()):
for _, history in model.stream_chat(tokenizer, query, history=history, **generating_args.to_dict()):
if stop_stream:
stop_stream = False
break
Expand Down

0 comments on commit 61fafdd

Please sign in to comment.