更新测试对话文件,添加执行截图

This commit is contained in:
HongCheng 2024-04-20 04:53:42 +09:00
parent 282fb3bf69
commit f58f6711ca

View File

@ -261,46 +261,222 @@ xtuner convert merge /root/models/LLM-Research/Meta-Llama-3-8B-Instruct ./hf_lla
## 测试
在EmoLLM的demo文件夹下创建cli_Llama3.py注意这里我们采用本地离线测试offline model在线测试可以上传模型到有关平台后再下载测试
在EmoLLM的demo文件夹下创建`cli_Llama3.py`注意这里我们采用本地离线测试offline model在线测试可以上传模型到有关平台后再下载测试
```python
from transformers import AutoTokenizer, AutoConfig, AddedToken, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
from dataclasses import dataclass
from typing import Dict
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from openxlab.model import download
from modelscope import snapshot_download
import os
import copy
# download model in openxlab
# download(model_repo='MrCat/Meta-Llama-3-8B-Instruct',
# output='MrCat/Meta-Llama-3-8B-Instruct')
# model_name_or_path = 'MrCat/Meta-Llama-3-8B-Instruct'
import warnings
# # download model in modelscope
# model_name_or_path = snapshot_download('LLM-Research/Meta-Llama-3-8B-Instruct',
# cache_dir='LLM-Research/Meta-Llama-3-8B-Instruct')
warnings.filterwarnings("ignore")
warnings.filterwarnings("ignore", category=DeprecationWarning)
## 定义聊天模板
@dataclass
class Template:
template_name:str
system_format: str
user_format: str
assistant_format: str
system: str
stop_word: str
template_dict: Dict[str, Template] = dict()
def register_template(template_name, system_format, user_format, assistant_format, system, stop_word=None):
template_dict[template_name] = Template(
template_name=template_name,
system_format=system_format,
user_format=user_format,
assistant_format=assistant_format,
system=system,
stop_word=stop_word,
)
# 这里的系统提示词是训练时使用的,推理时可以自行尝试修改效果
register_template(
template_name='llama3',
system_format='<|begin_of_text|><<SYS>>\n{content}\n<</SYS>>\n\n<|eot_id|>',
user_format='<|start_header_id|>user<|end_header_id|>\n\n{content}<|eot_id|>',
assistant_format='<|start_header_id|>assistant<|end_header_id|>\n\n{content}\n', # \n\n{content}<|eot_id|>\n
system="你由EmoLLM团队打造的中文领域心理健康助手, 是一个研究过无数具有心理健康问题的病人与心理健康医生对话的心理专家, 在心理方面拥有广博的知识储备和丰富的研究咨询经验,接下来你将只使用中文来回答和咨询问题。",
stop_word='<|eot_id|>'
)
# offline model
model_name_or_path = "/root/EmoLLM/xtuner_config/merged_Llama"
## 加载模型
def load_model(model_name_or_path, load_in_4bit=False, adapter_name_or_path=None):
if load_in_4bit:
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
llm_int8_threshold=6.0,
llm_int8_has_fp16_weight=False,
)
else:
quantization_config = None
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, torch_dtype=torch.bfloat16, device_map='auto')
model = model.eval()
# 加载base model
model = AutoModelForCausalLM.from_pretrained(
model_name_or_path,
load_in_4bit=load_in_4bit,
trust_remote_code=True,
low_cpu_mem_usage=True,
torch_dtype=torch.float16,
device_map='auto',
quantization_config=quantization_config
)
system_prompt = '你由EmoLLM团队打造的中文领域心理健康助手, 是一个研究过无数具有心理健康问题的病人与心理健康医生对话的心理专家, 在心理方面拥有广博的知识储备和丰富的研究咨询经验,接下来你将只使用中文来回答和咨询问题。'
# 加载adapter
if adapter_name_or_path is not None:
model = PeftModel.from_pretrained(model, adapter_name_or_path)
messages = [(system_prompt, '')]
return model
print("=============Welcome to InternLM chatbot, type 'exit' to exit.=============")
## 加载tokenzier
def load_tokenizer(model_name_or_path):
tokenizer = AutoTokenizer.from_pretrained(
model_name_or_path,
trust_remote_code=True,
use_fast=False
)
while True:
input_text = input("User >>> ")
input_text.replace(' ', '')
if input_text == "exit":
break
response, history = model.generate(tokenizer, input_text, history=messages)
messages.append((input_text, response))
print(f"robot >>> {response}")
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
return tokenizer
## 构建prompt
def build_prompt(tokenizer, template, query, history, system=None):
template_name = template.template_name
system_format = template.system_format
user_format = template.user_format
assistant_format = template.assistant_format
system = system if system is not None else template.system
history.append({"role": 'user', 'message': query})
input_ids = []
# 添加系统信息
if system_format is not None:
if system is not None:
system_text = system_format.format(content=system)
input_ids = tokenizer.encode(system_text, add_special_tokens=False)
# 拼接历史对话
for item in history:
role, message = item['role'], item['message']
if role == 'user':
message = user_format.format(content=message, stop_token=tokenizer.eos_token)
else:
message = assistant_format.format(content=message, stop_token=tokenizer.eos_token)
tokens = tokenizer.encode(message, add_special_tokens=False)
input_ids += tokens
input_ids = torch.tensor([input_ids], dtype=torch.long)
return input_ids
def main():
# download model in openxlab
# download(model_repo='MrCat/Meta-Llama-3-8B-Instruct',
# output='MrCat/Meta-Llama-3-8B-Instruct')
# model_name_or_path = 'MrCat/Meta-Llama-3-8B-Instruct'
# # download model in modelscope
# model_name_or_path = snapshot_download('LLM-Research/Meta-Llama-3-8B-Instruct',
# cache_dir='LLM-Research/Meta-Llama-3-8B-Instruct')
# offline model
model_name_or_path = "/root/EmoLLM/xtuner_config/merged_Llama3_8b_instruct"
print_user = False # 控制是否输入提示输入框用于notebook时改为True
template_name = 'llama3'
adapter_name_or_path = None
template = template_dict[template_name]
# 若开启4bit推理能够节省很多显存但效果可能下降
load_in_4bit = False
# 生成超参配置,可修改以取得更好的效果
max_new_tokens = 500 # 每次回复时AI生成文本的最大长度
top_p = 0.9
temperature = 0.6 # 越大越有创造性,越小越保守
repetition_penalty = 1.1 # 越大越能避免吐字重复
# 加载模型
print(f'Loading model from: {model_name_or_path}')
print(f'adapter_name_or_path: {adapter_name_or_path}')
model = load_model(
model_name_or_path,
load_in_4bit=load_in_4bit,
adapter_name_or_path=adapter_name_or_path
).eval()
tokenizer = load_tokenizer(model_name_or_path if adapter_name_or_path is None else adapter_name_or_path)
if template.stop_word is None:
template.stop_word = tokenizer.eos_token
stop_token_id = tokenizer.encode(template.stop_word, add_special_tokens=True)
assert len(stop_token_id) == 1
stop_token_id = stop_token_id[0]
print("================================================================================")
print("=============欢迎来到Llama3 EmoLLM 心理咨询室, 输入'exit'退出程序==============")
print("================================================================================")
history = []
print('=======================请输入咨询或聊天内容, 按回车键结束=======================')
print("================================================================================")
print("================================================================================")
print("===============================让我们开启对话吧=================================\n\n")
if print_user:
query = input('用户:')
print("# 用户:{}".format(query))
else:
query = input('# 用户: ')
while True:
if query=='exit':
break
query = query.strip()
input_ids = build_prompt(tokenizer, template, query, copy.deepcopy(history), system=None).to(model.device)
outputs = model.generate(
input_ids=input_ids, max_new_tokens=max_new_tokens, do_sample=True,
top_p=top_p, temperature=temperature, repetition_penalty=repetition_penalty,
eos_token_id=stop_token_id, pad_token_id=tokenizer.eos_token_id
)
outputs = outputs.tolist()[0][len(input_ids[0]):]
response = tokenizer.decode(outputs)
response = response.strip().replace(template.stop_word, "").strip()
# 存储对话历史
history.append({"role": 'user', 'message': query})
history.append({"role": 'assistant', 'message': response})
# 当对话长度超过6轮时清空最早的对话可自行修改
if len(history) > 12:
history = history[:-12]
print("# Llama3 EmoLLM 心理咨询师:{}".format(response.replace('\n','').replace('<|start_header_id|>','').replace('assistant<|end_header_id|>','')))
print()
query = input('# 用户:')
if print_user:
print("# 用户:{}".format(query))
print("\n\n=============感谢使用Llama3 EmoLLM 心理咨询室, 祝您生活愉快~=============\n\n")
if __name__ == '__main__':
main()
```
执行
@ -310,6 +486,12 @@ cd demo
python cli_Llama3.py
```
执行对话结果如下
![](https://cdn.nlark.com/yuque/0/2024/png/43035260/1713556239463-e0cb78f7-d3ab-40d8-9d08-9e30eb9340a8.png?x-oss-process=image%2Fformat%2Cwebp)
![](https://cdn.nlark.com/yuque/0/2024/png/43035260/1713556239545-e7f4e48c-0738-4d28-a3b0-51b6d281800c.png?x-oss-process=image%2Fformat%2Cwebp)
## 其他
欢迎大家给[Xtuner](https://link.zhihu.com/?target=https%3A//github.com/InternLM/xtuner)和[EmoLLM](https://link.zhihu.com/?target=https%3A//github.com/aJupyter/EmoLLM)点点star~