forked from modelscope/ms-swift
-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
30 lines (20 loc) · 917 Bytes
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from dataclasses import dataclass
from swift.llm import get_default_template_type, get_template, get_vllm_engine, inference_vllm
from swift.utils import get_main
@dataclass
class VLLMTestArgs:
    """Arguments accepted by the vLLM test entry point."""

    # Model identifier; `test_vllm` passes it to `get_vllm_engine` and
    # `get_default_template_type`.
    model_type: str
def test_vllm(args: VLLMTestArgs) -> None:
    """Send two fixed queries through a vLLM engine and print each reply.

    Builds an engine and a template for ``args.model_type``, sets the
    engine's ``generation_config.max_new_tokens`` to 256, runs both queries
    in a single ``inference_vllm`` call, and prints every query next to the
    ``'response'`` entry of its result.

    Args:
        args: Test arguments; only ``args.model_type`` is read.
    """
    model_type = args.model_type
    llm_engine = get_vllm_engine(model_type)

    # The template is built from the model's default template type and the
    # tokenizer exposed by the engine.
    template_type = get_default_template_type(model_type)
    template = get_template(template_type, llm_engine.hf_tokenizer)

    llm_engine.generation_config.max_new_tokens = 256

    # Fixed prompts (Chinese): "Hello!" and "Where is the capital of
    # Zhejiang?".
    request_list = [{'query': '你好!'}, {'query': '浙江的省会在哪?'}]
    resp_list = inference_vllm(llm_engine, template, request_list)

    # NOTE(review): pairing by position assumes `inference_vllm` returns one
    # result per request, in request order -- confirm against its docs.
    for request, resp in zip(request_list, resp_list):
        print(f"query: {request['query']}")
        print(f"response: {resp['response']}")
# Entry point produced by wrapping `test_vllm` with `get_main`.
# NOTE(review): presumably `get_main` parses command-line arguments into a
# `VLLMTestArgs` before calling `test_vllm` -- confirm in swift.utils.
test_vllm_main = get_main(VLLMTestArgs, test_vllm)

# Run the test only when this file is executed as a script, not on import.
if __name__ == '__main__':
    test_vllm_main()