Skip to content

Commit

Permalink
Merge pull request hiyouga#321 from Yang-HangWA/main
Browse files Browse the repository at this point in the history
code simplification
  • Loading branch information
hiyouga committed Jul 19, 2023
2 parents 2053fc7 + 75282d3 commit 81453be
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 17 deletions.
9 changes: 4 additions & 5 deletions src/glmtuner/chat/stream_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,14 @@ def __init__(
else:
self.model = self.model.cuda()

self.source_prefix = data_args.source_prefix if data_args.source_prefix else ""
self.source_prefix = data_args.source_prefix or ""
self.generating_args = generating_args

def get_prompt(
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None
) -> str:
prefix = prefix + "\n" if prefix else "" # add separator for non-empty prefix
history = history if history else []
history = history or []
prompt = ""
for i, (old_query, response) in enumerate(history):
prompt += "[Round {}]\n\n问:{}\n\n答:{}\n\n".format(i+1, old_query, response)
Expand All @@ -45,7 +45,7 @@ def get_prompt(
def process_args(
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
) -> Tuple[Dict[str, Any], int]:
prefix = prefix if prefix else self.source_prefix
prefix = prefix or self.source_prefix

inputs = self.tokenizer([self.get_prompt(query, history, prefix)], return_tensors="pt")
inputs = inputs.to(self.model.device)
Expand Down Expand Up @@ -100,5 +100,4 @@ def stream_chat(
thread = Thread(target=self.model.generate, kwargs=gen_kwargs)
thread.start()

for new_text in streamer:
yield new_text
yield from streamer
2 changes: 1 addition & 1 deletion src/glmtuner/dsets/collator.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def __call__(self, features: Sequence[Dict[str, Any]]) -> BatchEncoding:
labels = [feature["labels"].clone().detach().flip(0) for feature in features]
else:
labels = [torch.tensor(feature["labels"]).flip(0) for feature in features]
input_ids = input_ids + labels # pad them to the same length
input_ids += labels # pad them to the same length

input_ids = torch.nn.utils.rnn.pad_sequence(
input_ids,
Expand Down
14 changes: 3 additions & 11 deletions src/web_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,9 @@ def main():

manager = Manager([{"lang": lang}, chat_elems])

demo.load(
manager.gen_label,
[lang],
[lang] + [elem for elem in chat_elems.values()],
)

lang.change(
manager.gen_label,
[lang],
[lang] + [elem for elem in chat_elems.values()],
)
demo.load(manager.gen_label, [lang], [lang] + list(chat_elems.values()))

lang.change(manager.gen_label, [lang], [lang] + list(chat_elems.values()))

demo.queue()
demo.launch(server_name="0.0.0.0", share=False, inbrowser=True)
Expand Down

0 comments on commit 81453be

Please sign in to comment.