Dev (#208)
commit eced39fc81
BIN  assets/EmoLLM.png (new file, 12 KiB, binary file not shown)
BIN  assets/EmoLLM_logo.png (new file, 41 KiB, binary file not shown)
BIN  assets/EmoLLM_logo_L.png (new file, 33 KiB, binary file not shown)
demo/cli_Llama3-8B-Instruct.py (new file, 233 lines)
@@ -0,0 +1,233 @@
from transformers import AutoTokenizer, AutoConfig, AddedToken, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
from dataclasses import dataclass
from typing import Dict
import torch
import copy

import warnings

warnings.filterwarnings("ignore")
warnings.filterwarnings("ignore", category=DeprecationWarning)


## Define the chat template
@dataclass
class Template:
    template_name: str
    system_format: str
    user_format: str
    assistant_format: str
    system: str
    stop_word: str


template_dict: Dict[str, Template] = dict()


def register_template(template_name, system_format, user_format, assistant_format, system, stop_word=None):
    template_dict[template_name] = Template(
        template_name=template_name,
        system_format=system_format,
        user_format=user_format,
        assistant_format=assistant_format,
        system=system,
        stop_word=stop_word,
    )


# This system prompt is the one used during training; feel free to modify it at inference time.
register_template(
    template_name='llama3',
    system_format='<|begin_of_text|><system>\n{content}\n<system>\n\n<|eot_id|>',
    user_format='<|start_header_id|>user<|end_header_id|>\n\n{content}<|eot_id|>',
    assistant_format='<|start_header_id|>assistant<|end_header_id|>\n\n{content}\n',  # \n\n{content}<|eot_id|>\n
    system="你是心理健康助手EmoLLM, 由EmoLLM团队打造, 是一个研究过无数具有心理健康问题的病人与心理健康医生对话的心理专家, 在心理方面拥有广博的知识储备和丰富的研究咨询经验。你旨在通过专业心理咨询, 协助来访者完成心理诊断。请充分利用专业心理学知识与咨询技术, 一步步帮助来访者解决心理问题。",
    stop_word='<|eot_id|>'
)


## Load the model
def load_model(model_name_or_path, load_in_4bit=False, adapter_name_or_path=None):
    if load_in_4bit:
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            llm_int8_threshold=6.0,
            llm_int8_has_fp16_weight=False,
        )
    else:
        quantization_config = None

    # Load the base model
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        # load_in_4bit=load_in_4bit,
        # ValueError: You can't pass `load_in_4bit` or `load_in_8bit` as a kwarg when passing `quantization_config` argument at the same time.
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
        device_map='auto',
        quantization_config=quantization_config
    )

    # Load the adapter
    if adapter_name_or_path is not None:
        model = PeftModel.from_pretrained(model, adapter_name_or_path)

    return model


## Load the tokenizer
def load_tokenizer(model_name_or_path):
    tokenizer = AutoTokenizer.from_pretrained(
        model_name_or_path,
        trust_remote_code=True,
        use_fast=False
    )

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    return tokenizer


## Build the prompt
def build_prompt(tokenizer, template, query, history, system=None):
    template_name = template.template_name
    system_format = template.system_format
    user_format = template.user_format
    assistant_format = template.assistant_format
    system = system if system is not None else template.system

    history.append({"role": 'user', 'message': query})
    input_ids = []

    # Add the system message
    if system_format is not None:
        if system is not None:
            system_text = system_format.format(content=system)
            input_ids = tokenizer.encode(system_text, add_special_tokens=False)
    # Concatenate the conversation history
    for item in history:
        role, message = item['role'], item['message']
        if role == 'user':
            message = user_format.format(content=message, stop_token=tokenizer.eos_token)
        else:
            message = assistant_format.format(content=message, stop_token=tokenizer.eos_token)
        tokens = tokenizer.encode(message, add_special_tokens=False)
        input_ids += tokens
    input_ids = torch.tensor([input_ids], dtype=torch.long)

    return input_ids


def main():

    # download model in openxlab
    # download(model_repo='MrCat/Meta-Llama-3-8B-Instruct',
    #          output='MrCat/Meta-Llama-3-8B-Instruct')
    # model_name_or_path = 'MrCat/Meta-Llama-3-8B-Instruct'

    # # download model in modelscope
    # model_name_or_path = snapshot_download('LLM-Research/Meta-Llama-3-8B-Instruct',
    #                                        cache_dir='LLM-Research/Meta-Llama-3-8B-Instruct')

    # # offline model
    # model_name_or_path = '/root/share/new_models/meta-llama/Meta-Llama-3-8B-Instruct'
    # adapter_name_or_path = None

    # model_name_or_path = "xtuner_config/merged_Llama3_8b_instruct_e3"
    # adapter_name_or_path = 'xtuner_config/hf_llama3_e1_sc2'

    # model_name_or_path = "xtuner_config/merged_Llama3_8b_instruct_e1_sc"
    # adapter_name_or_path = None
    # NOTE: every model_name_or_path / adapter_name_or_path option above is commented out;
    # uncomment one pair (or add your own) before running, otherwise the references below
    # raise a NameError.

    print_user = False  # whether to echo the user's input; set to True when running in a notebook

    template_name = 'llama3'

    template = template_dict[template_name]

    # 4-bit inference saves a lot of GPU memory, but quality may drop
    load_in_4bit = False  # True  # 6291MiB

    # Generation hyperparameters; tune them for better results
    max_new_tokens = 500        # maximum number of tokens generated per reply
    top_p = 0.9
    temperature = 0.6           # higher = more creative, lower = more conservative
    repetition_penalty = 1.1    # higher values reduce repetition

    # Load the model
    print(f'Loading model from: {model_name_or_path}')
    print(f'adapter_name_or_path: {adapter_name_or_path}')
    model = load_model(
        model_name_or_path,
        load_in_4bit=load_in_4bit,
        adapter_name_or_path=adapter_name_or_path
    ).eval()
    tokenizer = load_tokenizer(model_name_or_path if adapter_name_or_path is None else adapter_name_or_path)
    if template.stop_word is None:
        template.stop_word = tokenizer.eos_token
    stop_token_id = tokenizer.encode(template.stop_word, add_special_tokens=True)
    assert len(stop_token_id) == 1
    stop_token_id = stop_token_id[0]

    print("================================================================================")
    print("=============欢迎来到Llama3 EmoLLM 心理咨询室, 输入'exit'退出程序===============")
    print("================================================================================")
    history = []

    print("============请输入聊天内容, 按回车键结束输入, 输入'clear'清空聊天信息==============")
    print("================================================================================")
    print("================================================================================")
    print("===============================让我们开启对话吧=================================\n\n")
    if print_user:
        query = input('用户:')
        print("# 用户:{}".format(query))
    else:
        query = input('# 用户: ')

    while True:
        if query == 'exit':
            break
        if query.strip() == "clear":
            history = []
            print("\n\n=============欢迎来到Llama3 EmoLLM 心理咨询室, 输入'exit'退出程序===============")
            print("============请输入聊天内容, 按回车键结束输入, 输入'clear'清空聊天信息===========")
            print("================================================================================")
            print("================================================================================")
            if print_user:
                query = input('用户:')
                print("# 用户:{}".format(query))
            else:
                query = input('# 用户: ')
            continue

        query = query.strip()
        input_ids = build_prompt(tokenizer, template, query, copy.deepcopy(history), system=None).to(model.device)
        outputs = model.generate(
            input_ids=input_ids, max_new_tokens=max_new_tokens, do_sample=True,
            top_p=top_p, temperature=temperature, repetition_penalty=repetition_penalty,
            eos_token_id=stop_token_id, pad_token_id=tokenizer.eos_token_id
        )
        outputs = outputs.tolist()[0][len(input_ids[0]):]
        response = tokenizer.decode(outputs)
        response = response.strip().replace(template.stop_word, "").strip()

        # Store the conversation history
        history.append({"role": 'user', 'message': query})
        history.append({"role": 'assistant', 'message': response})

        # When the conversation exceeds 6 turns, drop the earliest messages; adjust as needed.
        if len(history) > 12:
            history = history[-12:]  # keep only the 12 most recent messages

        print("# Llama3 EmoLLM 心理咨询师:{}".format(response.replace('\n', '').replace('<|start_header_id|>', '').replace('assistant<|end_header_id|>', '').replace('>', '')))
        print()
        query = input('# 用户:')
        if print_user:
            print("# 用户:{}".format(query))
    print("\n\n===============感谢使用Llama3 EmoLLM 心理咨询室, 祝您生活愉快~===============\n\n")


if __name__ == '__main__':
    main()
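For reference, a minimal way to try the new CLI demo might look like the sketch below. It assumes the packages the script imports (torch, transformers, peft, bitsandbytes) are installed and that one of the model_name_or_path / adapter_name_or_path pairs inside main() has been uncommented to point at a real checkpoint; the install command and invocation are illustrative and not part of this commit.

```bash
# illustrative only: set model_name_or_path inside main() before running
pip install torch transformers peft bitsandbytes
python demo/cli_Llama3-8B-Instruct.py
```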
@@ -72,8 +72,12 @@ OpenXLab浦源 内容平台 是面向 AI 研究员和开发者提供 AI 领域

### 1. Install git lfs

```bash
# better method
conda install git-lfs

# old method
# curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh
# apt install git-lfs
```

### 2. Configure git and lfs
web_demo-Llama3_online.py (new file, 369 lines)
@@ -0,0 +1,369 @@
# isort: skip_file
import copy
import warnings
from dataclasses import asdict, dataclass
from typing import Callable, List, Optional

import streamlit as st
import torch
from torch import nn
from transformers.generation.utils import (LogitsProcessorList,
                                           StoppingCriteriaList)
from transformers.utils import logging

from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig  # isort: skip
from peft import PeftModel

import warnings

warnings.filterwarnings("ignore")
warnings.filterwarnings("ignore", category=DeprecationWarning)
logger = logging.get_logger(__name__)

online = True
if online:
    from openxlab.model import download
    download(model_repo='chg0901/EmoLLM-Llama3-8B-Instruct2.0',
             output='model')


@dataclass
class GenerationConfig:
    # this config is used for chat to provide more diversity
    max_length: int = 500
    top_p: float = 0.9
    temperature: float = 0.6
    do_sample: bool = True
    repetition_penalty: float = 1.1


@torch.inference_mode()
def generate_interactive(
    model,
    tokenizer,
    prompt,
    generation_config: Optional[GenerationConfig] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor],
                                                List[int]]] = None,
    additional_eos_token_id: Optional[int] = None,
    **kwargs,
):
    inputs = tokenizer([prompt], return_tensors='pt')
    input_length = len(inputs['input_ids'][0])
    for k, v in inputs.items():
        inputs[k] = v.cuda()
    input_ids = inputs['input_ids']
    _, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
    if generation_config is None:
        generation_config = model.generation_config
    generation_config = copy.deepcopy(generation_config)
    model_kwargs = generation_config.update(**kwargs)
    bos_token_id, eos_token_id = (  # noqa: F841  # pylint: disable=W0612
        generation_config.bos_token_id,
        generation_config.eos_token_id,
    )
    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    if additional_eos_token_id is not None:
        eos_token_id.append(additional_eos_token_id)
    has_default_max_length = kwargs.get(
        'max_length') is None and generation_config.max_length is not None
    if has_default_max_length and generation_config.max_new_tokens is None:
        warnings.warn(
            f"Using 'max_length''s default ({repr(generation_config.max_length)}) \
                to control the generation length. "
            'This behaviour is deprecated and will be removed from the \
                config in v5 of Transformers -- we'
            ' recommend using `max_new_tokens` to control the maximum \
                length of the generation.',
            UserWarning,
        )
    elif generation_config.max_new_tokens is not None:
        generation_config.max_length = generation_config.max_new_tokens + \
            input_ids_seq_length
        if not has_default_max_length:
            logger.warn(  # pylint: disable=W4902
                f"Both 'max_new_tokens' (={generation_config.max_new_tokens}) "
                f"and 'max_length'(={generation_config.max_length}) seem to "
                "have been set. 'max_new_tokens' will take precedence. "
                'Please refer to the documentation for more information. '
                '(https://huggingface.co/docs/transformers/main/'
                'en/main_classes/text_generation)',
                UserWarning,
            )

    if input_ids_seq_length >= generation_config.max_length:
        input_ids_string = 'input_ids'
        logger.warning(
            f"Input length of {input_ids_string} is {input_ids_seq_length}, "
            f"but 'max_length' is set to {generation_config.max_length}. "
            'This can lead to unexpected behavior. You should consider'
            " increasing 'max_new_tokens'.")

    # 2. Set generation parameters if not already defined
    logits_processor = logits_processor if logits_processor is not None \
        else LogitsProcessorList()
    stopping_criteria = stopping_criteria if stopping_criteria is not None \
        else StoppingCriteriaList()

    logits_processor = model._get_logits_processor(
        generation_config=generation_config,
        input_ids_seq_length=input_ids_seq_length,
        encoder_input_ids=input_ids,
        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
        logits_processor=logits_processor,
    )

    stopping_criteria = model._get_stopping_criteria(
        generation_config=generation_config,
        stopping_criteria=stopping_criteria)
    logits_warper = model._get_logits_warper(generation_config)

    unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
    scores = None
    while True:
        model_inputs = model.prepare_inputs_for_generation(
            input_ids, **model_kwargs)
        # forward pass to get next token
        outputs = model(
            **model_inputs,
            return_dict=True,
            output_attentions=False,
            output_hidden_states=False,
        )

        next_token_logits = outputs.logits[:, -1, :]

        # pre-process distribution
        next_token_scores = logits_processor(input_ids, next_token_logits)
        next_token_scores = logits_warper(input_ids, next_token_scores)

        # sample
        probs = nn.functional.softmax(next_token_scores, dim=-1)
        if generation_config.do_sample:
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
        else:
            next_tokens = torch.argmax(probs, dim=-1)

        # update generated ids, model inputs, and length for next step
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        model_kwargs = model._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=False)
        unfinished_sequences = unfinished_sequences.mul(
            (min(next_tokens != i for i in eos_token_id)).long())

        output_token_ids = input_ids[0].cpu().tolist()
        output_token_ids = output_token_ids[input_length:]
        for each_eos_token_id in eos_token_id:
            if output_token_ids[-1] == each_eos_token_id:
                output_token_ids = output_token_ids[:-1]
        response = tokenizer.decode(output_token_ids)

        yield response
        # stop when each sentence is finished
        # or if we exceed the maximum length
        if unfinished_sequences.max() == 0 or stopping_criteria(
                input_ids, scores):
            break


def on_btn_click():
    del st.session_state.messages


# @st.cache_resource
# def load_model(arg1):
#     # model = AutoModelForCausalLM.from_pretrained(args.m).cuda()
#     # tokenizer = AutoTokenizer.from_pretrained(args.m, trust_remote_code=True)
#     model = AutoModelForCausalLM.from_pretrained(arg1, torch_dtype=torch.float16).cuda()
#     tokenizer = AutoTokenizer.from_pretrained(arg1, trust_remote_code=True)
#     return model, tokenizer
@st.cache_resource
def load_model(model_name_or_path, load_in_4bit=False, adapter_name_or_path=None):
    if load_in_4bit:
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            llm_int8_threshold=6.0,
            llm_int8_has_fp16_weight=False,
        )
    else:
        quantization_config = None

    # Load the base model
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        # load_in_4bit=load_in_4bit,
        # ValueError: You can't pass `load_in_4bit` or `load_in_8bit` as a kwarg when passing `quantization_config` argument at the same time.
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
        device_map='auto',
        quantization_config=quantization_config
    )

    # Load the adapter
    if adapter_name_or_path is not None:
        model = PeftModel.from_pretrained(model, adapter_name_or_path)

    ## Load the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        model_name_or_path if adapter_name_or_path is None else adapter_name_or_path,
        trust_remote_code=True,
        use_fast=False
    )

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return model, tokenizer


def prepare_generation_config():
    with st.sidebar:

        # Use Streamlit's markdown function to add Markdown text
        st.image('assets/EmoLLM_logo_L.png', width=1, caption='EmoLLM Logo', use_column_width=True)
        st.markdown("[访问 **EmoLLM** 官方repo: **SmartFlowAI/EmoLLM**](https://github.com/SmartFlowAI/EmoLLM)")

        max_length = st.slider('Max Length',
                               min_value=8,
                               max_value=8192,
                               value=500)
        top_p = st.slider('Top P', 0.0, 1.0, 0.9, step=0.01)
        temperature = st.slider('Temperature', 0.0, 1.0, 0.6, step=0.01)
        repetition_penalty = st.slider('Repetition penalty', 0.0, 1.5, 1.1, step=0.01)
        st.button('Clear Chat History', on_click=on_btn_click)

    generation_config = GenerationConfig(max_length=max_length,
                                         top_p=top_p,
                                         temperature=temperature,
                                         repetition_penalty=repetition_penalty,
                                         do_sample=True)

    return generation_config


user_prompt = '<|start_header_id|>user<|end_header_id|>\n\n{user}<|eot_id|>'
robot_prompt = '<|start_header_id|>assistant<|end_header_id|>\n\n{robot}<|eot_id|>'
cur_query_prompt = '<|start_header_id|>user<|end_header_id|>\n\n{user}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n'


def combine_history(prompt):
    messages = st.session_state.messages

    meta_instruction = (
        "你是心理健康助手EmoLLM, 由EmoLLM团队打造, 是一个研究过无数具有心理健康问题的病人与心理健康医生对话的心理专家, 在心理方面拥有广博的知识储备和丰富的研究咨询经验。你旨在通过专业心理咨询, 协助来访者完成心理诊断。请充分利用专业心理学知识与咨询技术, 一步步帮助来访者解决心理问题。。"
    )
    total_prompt = f"<|start_header_id|>system<|end_header_id|>\n{meta_instruction}<|eot_id|>\n"

    for message in messages:
        cur_content = message['content']
        if message['role'] == 'user':
            cur_prompt = user_prompt.format(user=cur_content)
        elif message['role'] == 'robot':
            cur_prompt = robot_prompt.format(robot=cur_content)
        else:
            raise RuntimeError
        total_prompt += cur_prompt
    total_prompt = total_prompt + cur_query_prompt.format(user=prompt)
    return total_prompt


def main(arg1):

    if online:
        model_name_or_path = 'model'
        adapter_name_or_path = None
    else:
        # model_name_or_path = "/root/StableCascade/emollm2/EmoLLM/xtuner_config/merged_Llama3_8b_instruct_e3"
        # adapter_name_or_path = '/root/StableCascade/emollm2/EmoLLM/xtuner_config/hf_llama3_e1_sc2'

        model_name_or_path = "/root/StableCascade/emollm2/EmoLLM/xtuner_config/merged_Llama3_8b_instruct_e1_sc"
        adapter_name_or_path = None

    # 4-bit inference saves a lot of GPU memory, but quality may drop
    load_in_4bit = False  # True  # 6291MiB

    # torch.cuda.empty_cache()
    print('load model begin.')
    # Load the model
    print(f'Loading model from: {model_name_or_path}')
    print(f'adapter_name_or_path: {adapter_name_or_path}')
    # model, tokenizer = load_model(arg1)
    model, tokenizer = load_model(
        arg1 if arg1 is not None else model_name_or_path,
        load_in_4bit=load_in_4bit,
        adapter_name_or_path=adapter_name_or_path
    )
    model.eval()
    print('load model end.')

    user_avator = "assets/user.png"
    robot_avator = "assets/EmoLLM.png"

    st.title('EmoLLM Llama3心理咨询室V2.0')

    generation_config = prepare_generation_config()

    # Initialize chat history
    if 'messages' not in st.session_state:
        st.session_state.messages = []

    # Display chat messages from history on app rerun
    for message in st.session_state.messages:
        with st.chat_message(message['role'], avatar=message.get("avatar")):
            st.markdown(message['content'])

    # Accept user input
    if prompt := st.chat_input('你好,欢迎来到Llama3 EmoLLM 心理咨询室'):
        # Display user message in chat message container
        with st.chat_message('user', avatar=user_avator):
            st.markdown(prompt)
        real_prompt = combine_history(prompt)
        # Add user message to chat history
        st.session_state.messages.append({
            'role': 'user',
            'content': prompt,
            'avatar': user_avator
        })

        # stop_token_id = tokenizer.encode('<|eot_id|>', add_special_tokens=True)
        # assert len(stop_token_id) == 1
        # stop_token_id = stop_token_id[0]

        with st.chat_message('robot', avatar=robot_avator):
            message_placeholder = st.empty()
            for cur_response in generate_interactive(
                    model=model,
                    tokenizer=tokenizer,
                    prompt=real_prompt,
                    additional_eos_token_id=128009,  # <|eot_id|>
                    eos_token_id=128009,
                    pad_token_id=128009,
                    **asdict(generation_config),
            ):
                # Display robot response in chat message container
                message_placeholder.markdown(cur_response + '▌')
            message_placeholder.markdown(cur_response)
        # Add robot response to chat history
        st.session_state.messages.append({
            'role': 'robot',
            'content': cur_response,  # pylint: disable=undefined-loop-variable
            "avatar": robot_avator,
        })
        torch.cuda.empty_cache()


if __name__ == '__main__':

    import sys
    arg1 = sys.argv[1]
    main(arg1)
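A possible way to launch the new Streamlit demo is sketched below, assuming streamlit and openxlab are installed. With online = True the script downloads chg0901/EmoLLM-Llama3-8B-Instruct2.0 into ./model at startup, and main() reads a model path from sys.argv[1], so a script argument is required; the commands are illustrative and not part of this commit.

```bash
# illustrative only: arguments after "--" are passed to the script as sys.argv
pip install streamlit openxlab
streamlit run web_demo-Llama3_online.py -- model
```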
@@ -1,219 +0,0 @@ (deleted file)
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from datasets import load_dataset
from mmengine.dataset import DefaultSampler
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
                            LoggerHook, ParamSchedulerHook)
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
from peft import LoraConfig
from torch.optim import AdamW
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          BitsAndBytesConfig)

from xtuner.dataset import process_hf_dataset
from xtuner.dataset.collate_fns import default_collate_fn
from xtuner.dataset.map_fns import alpaca_map_fn, template_map_fn_factory
from xtuner.engine.hooks import (DatasetInfoHook, EvaluateChatHook,
                                 VarlenAttnArgsToMessageHubHook)
from xtuner.engine.runner import TrainLoop
from xtuner.model import SupervisedFinetune
from xtuner.parallel.sequence import SequenceParallelSampler
from xtuner.utils import PROMPT_TEMPLATE, SYSTEM_TEMPLATE

#######################################################################
#                          PART 1  Settings                           #
#######################################################################
# Model
pretrained_model_name_or_path = 'meta-llama/Meta-Llama-3-8B-Instruct'
use_varlen_attn = False

# Data
alpaca_en_path = 'tatsu-lab/alpaca'
prompt_template = PROMPT_TEMPLATE.llama3_chat
max_length = 512
pack_to_max_length = True

# parallel
sequence_parallel_size = 1

# Scheduler & Optimizer
batch_size = 1  # per_device
accumulative_counts = 16
accumulative_counts *= sequence_parallel_size
dataloader_num_workers = 0
max_epochs = 3
optim_type = AdamW
lr = 2e-4
betas = (0.9, 0.999)
weight_decay = 0
max_norm = 1  # grad clip
warmup_ratio = 0.03

# Save
save_steps = 500
save_total_limit = 2  # Maximum checkpoints to keep (-1 means unlimited)

# Evaluate the generation performance during the training
evaluation_freq = 500
SYSTEM = SYSTEM_TEMPLATE.alpaca
evaluation_inputs = [
    '请给我介绍五个上海的景点', 'Please tell me five scenic spots in Shanghai'
]

#######################################################################
#                      PART 2  Model & Tokenizer                      #
#######################################################################
tokenizer = dict(
    type=AutoTokenizer.from_pretrained,
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    trust_remote_code=True,
    padding_side='right')

model = dict(
    type=SupervisedFinetune,
    use_varlen_attn=use_varlen_attn,
    llm=dict(
        type=AutoModelForCausalLM.from_pretrained,
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        quantization_config=dict(
            type=BitsAndBytesConfig,
            load_in_4bit=True,
            load_in_8bit=False,
            llm_int8_threshold=6.0,
            llm_int8_has_fp16_weight=False,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4')),
    lora=dict(
        type=LoraConfig,
        r=16,
        lora_alpha=16,
        lora_dropout=0.1,
        bias='none',
        task_type='CAUSAL_LM'))

#######################################################################
#                      PART 3  Dataset & Dataloader                   #
#######################################################################
alpaca_en = dict(
    type=process_hf_dataset,
    dataset=dict(type=load_dataset, path=alpaca_en_path),
    tokenizer=tokenizer,
    max_length=max_length,
    dataset_map_fn=alpaca_map_fn,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    remove_unused_columns=True,
    shuffle_before_pack=True,
    pack_to_max_length=pack_to_max_length,
    use_varlen_attn=use_varlen_attn)

sampler = SequenceParallelSampler \
    if sequence_parallel_size > 1 else DefaultSampler
train_dataloader = dict(
    batch_size=batch_size,
    num_workers=dataloader_num_workers,
    dataset=alpaca_en,
    sampler=dict(type=sampler, shuffle=True),
    collate_fn=dict(type=default_collate_fn, use_varlen_attn=use_varlen_attn))

#######################################################################
#                    PART 4  Scheduler & Optimizer                    #
#######################################################################
# optimizer
optim_wrapper = dict(
    type=AmpOptimWrapper,
    optimizer=dict(
        type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
    clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
    accumulative_counts=accumulative_counts,
    loss_scale='dynamic',
    dtype='float16')

# learning policy
# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md  # noqa: E501
param_scheduler = [
    dict(
        type=LinearLR,
        start_factor=1e-5,
        by_epoch=True,
        begin=0,
        end=warmup_ratio * max_epochs,
        convert_to_iter_based=True),
    dict(
        type=CosineAnnealingLR,
        eta_min=0.0,
        by_epoch=True,
        begin=warmup_ratio * max_epochs,
        end=max_epochs,
        convert_to_iter_based=True)
]

# train, val, test setting
train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)

#######################################################################
#                           PART 5  Runtime                           #
#######################################################################
# Log the dialogue periodically during the training process, optional
custom_hooks = [
    dict(type=DatasetInfoHook, tokenizer=tokenizer),
    dict(
        type=EvaluateChatHook,
        tokenizer=tokenizer,
        every_n_iters=evaluation_freq,
        evaluation_inputs=evaluation_inputs,
        system=SYSTEM,
        prompt_template=prompt_template)
]

if use_varlen_attn:
    custom_hooks += [dict(type=VarlenAttnArgsToMessageHubHook)]

# configure default hooks
default_hooks = dict(
    # record the time of every iteration.
    timer=dict(type=IterTimerHook),
    # print log every 10 iterations.
    logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
    # enable the parameter scheduler.
    param_scheduler=dict(type=ParamSchedulerHook),
    # save checkpoint per `save_steps`.
    checkpoint=dict(
        type=CheckpointHook,
        by_epoch=False,
        interval=save_steps,
        max_keep_ckpts=save_total_limit),
    # set sampler seed in distributed environment.
    sampler_seed=dict(type=DistSamplerSeedHook),
)

# configure environment
env_cfg = dict(
    # whether to enable cudnn benchmark
    cudnn_benchmark=False,
    # set multi process parameters
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    # set distributed parameters
    dist_cfg=dict(backend='nccl'),
)

# set visualizer
visualizer = None

# set log level
log_level = 'INFO'

# load from which checkpoint
load_from = None

# whether to resume training from the loaded checkpoint
resume = False

# Defaults to use random seed and disable `deterministic`
randomness = dict(seed=None, deterministic=False)

# set log processor
log_processor = dict(by_epoch=False)
@@ -1,219 +0,0 @@ (deleted file)
(Second deleted XTuner config, 219 lines. Its content is identical to the deleted config above except that it sets max_length = 8192 instead of max_length = 512.)