feat: Update Aiwei configuration.
Commit: c696e163cd
@@ -60,7 +60,7 @@
 - 评估和诊断工具:为了有效促进心理健康,需要有科学的工具来评估个体的心理状态,以及诊断可能存在的心理问题。

 ### 最近更新
+- 【2024.2.23】推出基于InternLM2_7B_chat_qlora的 `温柔御姐心理医生艾薇`,[点击获取模型权重](https://openxlab.org.cn/models/detail/ajupyter/EmoLLM_aiwei),[配置文件](xtuner_config/aiwei-internlm2_chat_7b_qlora.py)
 - 【2024.2.23】更新[若干微调配置](/xtuner_config/),新增 [data_pro.json](/datasets/data_pro.json)(数量更多、场景更全、更丰富)和 [aiwei.json](/datasets/aiwei.json)(温柔御姐角色扮演专用,带有Emoji表情),即将推出 `温柔御姐心理医生艾薇`
 - 【2024.2.18】 [基于Qwen1_5-0_5B-Chat全量微调版本开源](https://www.modelscope.cn/models/aJupyter/EmoLLM_Qwen1_5-0_5B-Chat_full_sft/summary),算力有限的道友可以玩起来~
 - 【2024.2.6】 EmoLLM在[**Openxlab**](https://openxlab.org.cn/models/detail/jujimeizuo/EmoLLM_Model) 平台下载量高达18.7k,欢迎大家体验!
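Note: the model weights linked in the new entry can also be fetched programmatically; the call below is the same one the new web_demo-aiwei.py in this commit runs at startup (it assumes the `openxlab` package is installed and any required OpenXLab credentials are configured).

```python
# Fetch the aiwei weights from OpenXLab into ./model, as web_demo-aiwei.py does on startup.
from openxlab.model import download

download(model_repo='ajupyter/EmoLLM_aiwei', output='model')
```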

app.py | 3
@@ -1,2 +1,3 @@
 import os
-os.system('streamlit run web_internlm2.py --server.address=0.0.0.0 --server.port 7860')
+# os.system('streamlit run web_internlm2.py --server.address=0.0.0.0 --server.port 7860')
+os.system('streamlit run web_demo-aiwei.py --server.address=0.0.0.0 --server.port 7860')

assets/aiwei_logo.jpg | BIN (new file, 80 KiB; binary file not shown)

web_demo-aiwei.py | 267 (new file)
@@ -0,0 +1,267 @@
"""
This script refers to the dialogue example of streamlit, the interactive generation code of chatglm2 and transformers.
We mainly modified part of the code logic to adapt to the generation of our model.
Please refer to these links below for more information:
    1. streamlit chat example: https://docs.streamlit.io/knowledge-base/tutorials/build-conversational-apps
    2. chatglm2: https://github.com/THUDM/ChatGLM2-6B
    3. transformers: https://github.com/huggingface/transformers
Please run with the command `streamlit run path/to/web_demo.py --server.address=0.0.0.0 --server.port 7860`.
Using `python path/to/web_demo.py` may cause unknown problems.
"""
import copy
import warnings
from dataclasses import asdict, dataclass
from typing import Callable, List, Optional

import streamlit as st
import torch
from torch import nn
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
from transformers.utils import logging

from transformers import AutoTokenizer, AutoModelForCausalLM  # isort: skip
from openxlab.model import download

logger = logging.get_logger(__name__)

download(model_repo='ajupyter/EmoLLM_aiwei',
         output='model')


@dataclass
class GenerationConfig:
    # this config is used for chat to provide more diversity
    max_length: int = 32768
    top_p: float = 0.8
    temperature: float = 0.8
    do_sample: bool = True
    repetition_penalty: float = 1.005


@torch.inference_mode()
def generate_interactive(
    model,
    tokenizer,
    prompt,
    generation_config: Optional[GenerationConfig] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
    additional_eos_token_id: Optional[int] = None,
    **kwargs,
):
    inputs = tokenizer([prompt], padding=True, return_tensors="pt")
    input_length = len(inputs["input_ids"][0])
    for k, v in inputs.items():
        inputs[k] = v.cuda()
    input_ids = inputs["input_ids"]
    batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]  # noqa: F841  # pylint: disable=W0612
    if generation_config is None:
        generation_config = model.generation_config
    generation_config = copy.deepcopy(generation_config)
    model_kwargs = generation_config.update(**kwargs)
    bos_token_id, eos_token_id = (  # noqa: F841  # pylint: disable=W0612
        generation_config.bos_token_id,
        generation_config.eos_token_id,
    )
    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    if additional_eos_token_id is not None:
        eos_token_id.append(additional_eos_token_id)
    has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
    if has_default_max_length and generation_config.max_new_tokens is None:
        warnings.warn(
            f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
            "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
            " recommend using `max_new_tokens` to control the maximum length of the generation.",
            UserWarning,
        )
    elif generation_config.max_new_tokens is not None:
        generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
        if not has_default_max_length:
            logger.warn(  # pylint: disable=W4902
                f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
                f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
                "Please refer to the documentation for more information. "
                "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
                UserWarning,
            )

    if input_ids_seq_length >= generation_config.max_length:
        input_ids_string = "input_ids"
        logger.warning(
            f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
            f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
            " increasing `max_new_tokens`."
        )

    # 2. Set generation parameters if not already defined
    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

    logits_processor = model._get_logits_processor(
        generation_config=generation_config,
        input_ids_seq_length=input_ids_seq_length,
        encoder_input_ids=input_ids,
        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
        logits_processor=logits_processor,
    )

    stopping_criteria = model._get_stopping_criteria(
        generation_config=generation_config, stopping_criteria=stopping_criteria
    )
    logits_warper = model._get_logits_warper(generation_config)

    unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
    scores = None
    while True:
        model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
        # forward pass to get next token
        outputs = model(
            **model_inputs,
            return_dict=True,
            output_attentions=False,
            output_hidden_states=False,
        )

        next_token_logits = outputs.logits[:, -1, :]

        # pre-process distribution
        next_token_scores = logits_processor(input_ids, next_token_logits)
        next_token_scores = logits_warper(input_ids, next_token_scores)

        # sample
        probs = nn.functional.softmax(next_token_scores, dim=-1)
        if generation_config.do_sample:
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
        else:
            next_tokens = torch.argmax(probs, dim=-1)

        # update generated ids, model inputs, and length for next step
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        model_kwargs = model._update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False)
        unfinished_sequences = unfinished_sequences.mul((min(next_tokens != i for i in eos_token_id)).long())

        output_token_ids = input_ids[0].cpu().tolist()
        output_token_ids = output_token_ids[input_length:]
        for each_eos_token_id in eos_token_id:
            if output_token_ids[-1] == each_eos_token_id:
                output_token_ids = output_token_ids[:-1]
        response = tokenizer.decode(output_token_ids)

        yield response
        # stop when each sentence is finished, or if we exceed the maximum length
        if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
            break


def on_btn_click():
    del st.session_state.messages


@st.cache_resource
def load_model():
    model = (
        AutoModelForCausalLM.from_pretrained("model", trust_remote_code=True)
        .to(torch.bfloat16)
        .cuda()
    )
    tokenizer = AutoTokenizer.from_pretrained("model", trust_remote_code=True)
    return model, tokenizer


def prepare_generation_config():
    with st.sidebar:
        # Add Markdown content to the sidebar with Streamlit's markdown/image functions
        st.image('assets/aiwei_logo.jpg', width=1, caption='EmoLLM-aiwei AI Logo', use_column_width=True)
        st.markdown("[访问 EmoLLM 官方repo](https://github.com/aJupyter/EmoLLM)")

        max_length = st.slider("Max Length", min_value=8, max_value=32768, value=32768)
        top_p = st.slider("Top P", 0.0, 1.0, 0.8, step=0.01)
        temperature = st.slider("Temperature", 0.0, 1.0, 0.7, step=0.01)
        st.button("Clear Chat History", on_click=on_btn_click)

    generation_config = GenerationConfig(max_length=max_length, top_p=top_p, temperature=temperature)

    return generation_config


user_prompt = "<|im_start|>user\n{user}<|im_end|>\n"
robot_prompt = "<|im_start|>assistant\n{robot}<|im_end|>\n"
cur_query_prompt = "<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n"


def combine_history(prompt):
    messages = st.session_state.messages
    meta_instruction = (
        "你是一个拥有丰富心理学知识的温柔邻家温柔大姐姐艾薇,我有一些心理问题,请你用专业的知识和温柔、可爱、俏皮、的口吻帮我解决,回复中可以穿插一些可爱的Emoji表情符号或者文本符号。\n"
    )
    total_prompt = f"<s><|im_start|>system\n{meta_instruction}<|im_end|>\n"
    for message in messages:
        cur_content = message["content"]
        if message["role"] == "user":
            cur_prompt = user_prompt.format(user=cur_content)
        elif message["role"] == "robot":
            cur_prompt = robot_prompt.format(robot=cur_content)
        else:
            raise RuntimeError
        total_prompt += cur_prompt
    total_prompt = total_prompt + cur_query_prompt.format(user=prompt)
    return total_prompt


def main():
    # torch.cuda.empty_cache()
    print("load model begin.")
    model, tokenizer = load_model()
    print("load model end.")

    user_avator = "assets/user.png"
    robot_avator = "assets/robot.jpeg"

    st.title("EmoLLM-温柔御姐艾薇(aiwei)")

    generation_config = prepare_generation_config()

    # Initialize chat history
    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Display chat messages from history on app rerun
    for message in st.session_state.messages:
        with st.chat_message(message["role"], avatar=message.get("avatar")):
            st.markdown(message["content"])

    # Accept user input
    if prompt := st.chat_input("What is up?"):
        # Display user message in chat message container
        with st.chat_message("user", avatar=user_avator):
            st.markdown(prompt)
        real_prompt = combine_history(prompt)
        # Add user message to chat history
        st.session_state.messages.append({"role": "user", "content": prompt, "avatar": user_avator})

        with st.chat_message("robot", avatar=robot_avator):
            message_placeholder = st.empty()
            for cur_response in generate_interactive(
                model=model,
                tokenizer=tokenizer,
                prompt=real_prompt,
                additional_eos_token_id=92542,
                **asdict(generation_config),
            ):
                # Display robot response in chat message container
                message_placeholder.markdown(cur_response + "▌")
            message_placeholder.markdown(cur_response)  # pylint: disable=undefined-loop-variable
            # Add robot response to chat history
            st.session_state.messages.append(
                {
                    "role": "robot",
                    "content": cur_response,  # pylint: disable=undefined-loop-variable
                    "avatar": robot_avator,
                }
            )
        torch.cuda.empty_cache()


if __name__ == "__main__":
    main()
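For reference, here is a minimal sketch — not part of the commit — of driving `generate_interactive` outside Streamlit. It assumes `GenerationConfig`, `generate_interactive`, and `load_model` have been copied into an importable module (the hyphenated file name web_demo-aiwei.py cannot be imported directly; `demo_core.py` below is a hypothetical name) and that the weights are already in ./model.

```python
# Hypothetical standalone driver for generate_interactive (illustration only).
from dataclasses import asdict

from demo_core import GenerationConfig, generate_interactive, load_model  # assumed module name

model, tokenizer = load_model()

# Single-turn prompt in the same internlm2-chat format the demo builds in combine_history().
system = "你是温柔御姐心理医生艾薇。"  # shortened system prompt, for illustration only
query = "最近总是焦虑,晚上睡不好,我该怎么办?"
prompt = (f"<s><|im_start|>system\n{system}<|im_end|>\n"
          f"<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n")

config = GenerationConfig(max_length=2048, top_p=0.8, temperature=0.7)
printed = ""
for response in generate_interactive(model=model, tokenizer=tokenizer, prompt=prompt,
                                     additional_eos_token_id=92542,  # same extra EOS id the demo passes
                                     **asdict(config)):
    # each yield is the full decoded reply so far; print only the newly generated part
    print(response[len(printed):], end="", flush=True)
    printed = response
print()
```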

xtuner_config/aiwei-internlm2_chat_7b_qlora.py | 218 (new file)
@@ -0,0 +1,218 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from datasets import load_dataset
from mmengine.dataset import DefaultSampler
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
                            LoggerHook, ParamSchedulerHook)
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
from peft import LoraConfig
from torch.optim import AdamW
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          BitsAndBytesConfig)

from xtuner.dataset import process_hf_dataset
from xtuner.dataset.collate_fns import default_collate_fn
from xtuner.dataset.map_fns import alpaca_map_fn, template_map_fn_factory
from xtuner.engine.hooks import (DatasetInfoHook, EvaluateChatHook,
                                 VarlenAttnArgsToMessageHubHook)
from xtuner.engine.runner import TrainLoop
from xtuner.model import SupervisedFinetune
from xtuner.utils import PROMPT_TEMPLATE, SYSTEM_TEMPLATE

from mmengine.visualization import Visualizer, WandbVisBackend, TensorboardVisBackend

#######################################################################
#                          PART 1  Settings                           #
#######################################################################
# Model
pretrained_model_name_or_path = '/root/share/model_repos/internlm2-chat-7b'
# /root/share/model_repos/internlm2-chat-7b
use_varlen_attn = False

# Data
data_path = './aiwei.json'
prompt_template = PROMPT_TEMPLATE.internlm2_chat
max_length = 2048
pack_to_max_length = True

# Scheduler & Optimizer
batch_size = 16  # per_device
accumulative_counts = 1
dataloader_num_workers = 0
max_epochs = 5
optim_type = AdamW
lr = 1e-5
betas = (0.9, 0.999)
weight_decay = 0.0001
max_norm = 1  # grad clip
warmup_ratio = 0.03

# Save
save_steps = 100
save_total_limit = 2  # Maximum checkpoints to keep (-1 means unlimited)

# Evaluate the generation performance during the training
evaluation_freq = 100
SYSTEM = "现在你是一个拥有丰富心理学知识的温柔御姐艾薇医生,我有一些心理问题,请你用专业的知识和温柔的口吻帮我解决,可以生成一些可爱的Emoji表情符号或者文本符号。"
evaluation_inputs = [
    '我最近总是感到很焦虑,尤其是在学业上。我有个特别崇拜的同学,他好像在各方面都比我优秀,我总觉得自己怎么努力也追不上他,这让我压力特别大。', '我知道应该理性看待,但就是忍不住会去比较。我甚至晚上会因为这个睡不着觉,总想着怎样才能像他那样出色。'
]

#######################################################################
#                      PART 2  Model & Tokenizer                      #
#######################################################################
tokenizer = dict(
    type=AutoTokenizer.from_pretrained,
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    trust_remote_code=True,
    padding_side='right')

model = dict(
    type=SupervisedFinetune,
    use_varlen_attn=use_varlen_attn,
    llm=dict(
        type=AutoModelForCausalLM.from_pretrained,
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        quantization_config=dict(
            type=BitsAndBytesConfig,
            load_in_4bit=True,
            load_in_8bit=False,
            llm_int8_threshold=6.0,
            llm_int8_has_fp16_weight=False,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4')),
    lora=dict(
        type=LoraConfig,
        r=64,
        lora_alpha=16,
        lora_dropout=0.1,
        bias='none',
        task_type='CAUSAL_LM'))

#######################################################################
#                      PART 3  Dataset & Dataloader                   #
#######################################################################
alpaca_en = dict(
    type=process_hf_dataset,
    dataset=dict(type=load_dataset, path='json', data_files=dict(train=data_path)),
    tokenizer=tokenizer,
    max_length=max_length,
    dataset_map_fn=None,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    remove_unused_columns=True,
    shuffle_before_pack=True,
    pack_to_max_length=pack_to_max_length,
    use_varlen_attn=use_varlen_attn)

train_dataloader = dict(
    batch_size=batch_size,
    num_workers=dataloader_num_workers,
    dataset=alpaca_en,
    sampler=dict(type=DefaultSampler, shuffle=True),
    collate_fn=dict(type=default_collate_fn, use_varlen_attn=use_varlen_attn))

#######################################################################
#                    PART 4  Scheduler & Optimizer                    #
#######################################################################
# optimizer
optim_wrapper = dict(
    type=AmpOptimWrapper,
    optimizer=dict(
        type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
    clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
    accumulative_counts=accumulative_counts,
    loss_scale='dynamic',
    dtype='float16')

# learning policy
# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md  # noqa: E501
param_scheduler = [
    dict(
        type=LinearLR,
        start_factor=1e-5,
        by_epoch=True,
        begin=0,
        end=warmup_ratio * max_epochs,
        convert_to_iter_based=True),
    dict(
        type=CosineAnnealingLR,
        eta_min=0.0,
        by_epoch=True,
        begin=warmup_ratio * max_epochs,
        end=max_epochs,
        convert_to_iter_based=True)
]

# train, val, test setting
train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)

#######################################################################
#                           PART 5  Runtime                           #
#######################################################################
# Log the dialogue periodically during the training process, optional
custom_hooks = [
    dict(type=DatasetInfoHook, tokenizer=tokenizer),
    dict(
        type=EvaluateChatHook,
        tokenizer=tokenizer,
        every_n_iters=evaluation_freq,
        evaluation_inputs=evaluation_inputs,
        system=SYSTEM,
        prompt_template=prompt_template)
]

if use_varlen_attn:
    custom_hooks += [dict(type=VarlenAttnArgsToMessageHubHook)]

# configure default hooks
default_hooks = dict(
    # record the time of every iteration.
    timer=dict(type=IterTimerHook),
    # print log every 10 iterations.
    logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
    # enable the parameter scheduler.
    param_scheduler=dict(type=ParamSchedulerHook),
    # save checkpoint per `save_steps`.
    checkpoint=dict(
        type=CheckpointHook,
        by_epoch=False,
        interval=save_steps,
        max_keep_ckpts=save_total_limit),
    # set sampler seed in distributed environment.
    sampler_seed=dict(type=DistSamplerSeedHook),
)

# configure environment
env_cfg = dict(
    # whether to enable cudnn benchmark
    cudnn_benchmark=False,
    # set multi process parameters
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    # set distributed parameters
    dist_cfg=dict(backend='nccl'),
)

# set visualizer
visualizer = dict(
    type=Visualizer,
    vis_backends=[dict(type=WandbVisBackend)]
)

# set log level
log_level = 'INFO'

# load from which checkpoint
load_from = None

# whether to resume training from the loaded checkpoint
resume = True

# Defaults to use random seed and disable `deterministic`
randomness = dict(seed=None, deterministic=False)

# set log processor
log_processor = dict(by_epoch=False)