EmoLLM V3.0: (#274)

This commit is contained in:
xzw 2024-07-15 16:46:41 +08:00 committed by GitHub
commit f9fcc714ed
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 539 additions and 2 deletions

View File

@ -281,6 +281,7 @@ git clone https://github.com/SmartFlowAI/EmoLLM.git
| :----------------------------------------------------------: | :------------------------------------------------: | :----------------------------------------------------------: | :-------------------------------------------: | | :----------------------------------------------------------: | :------------------------------------------------: | :----------------------------------------------------------: | :-------------------------------------------: |
| [aJupyter](https://github.com/aJupyter) | 南开大学在读硕士 | DataWhale成员 | 项目发起人 | | [aJupyter](https://github.com/aJupyter) | 南开大学在读硕士 | DataWhale成员 | 项目发起人 |
| [MING-ZCH](https://github.com/MING-ZCH) | 华中科技大学在读本科生 | LLM x Mental health 研究者 | 项目联合负责人 | | [MING-ZCH](https://github.com/MING-ZCH) | 华中科技大学在读本科生 | LLM x Mental health 研究者 | 项目联合负责人 |
| [chg0901](https://github.com/chg0901) | 韩国光云大学在读博士 [MiniSora](https://github.com/mini-sora/minisora/) | DataWhale意向成员 DataWhale鲸英助教团成员 | 项目联合负责人 |
| [jujimeizuo](https://github.com/jujimeizuo) | 江南大学在读硕士 | | | | [jujimeizuo](https://github.com/jujimeizuo) | 江南大学在读硕士 | | |
| [Smiling-Weeping-zhr](https://github.com/Smiling-Weeping-zhr) | 哈尔滨工业大学(威海)在读本科生 | | | | [Smiling-Weeping-zhr](https://github.com/Smiling-Weeping-zhr) | 哈尔滨工业大学(威海)在读本科生 | | |
| [8baby8](https://github.com/8baby8) | 飞桨领航团区域主管 | 文心大模型核心开发者 | | | [8baby8](https://github.com/8baby8) | 飞桨领航团区域主管 | 文心大模型核心开发者 | |
@ -290,7 +291,6 @@ git clone https://github.com/SmartFlowAI/EmoLLM.git
| [ZeyuBa](https://github.com/ZeyuBa) | 自动化所在读硕士 | | | | [ZeyuBa](https://github.com/ZeyuBa) | 自动化所在读硕士 | | |
| [aiyinyuedejustin](https://github.com/aiyinyuedejustin) | 宾夕法尼亚大学在读硕士 | | | | [aiyinyuedejustin](https://github.com/aiyinyuedejustin) | 宾夕法尼亚大学在读硕士 | | |
| [Nobody-ML](https://github.com/Nobody-ML) | 中国石油大学(华东)在读本科生 | | | | [Nobody-ML](https://github.com/Nobody-ML) | 中国石油大学(华东)在读本科生 | | |
| [chg0901](https://github.com/chg0901) | [MiniSora](https://github.com/mini-sora/minisora/) | [MiniSora](https://github.com/mini-sora/minisora/)主要维护者,管理员 | LLM预训练和微调、模型上传、数据清洗、文档翻译 |
| [Mxoder](https://github.com/Mxoder) | 北京航空航天大学在读本科生 | | | | [Mxoder](https://github.com/Mxoder) | 北京航空航天大学在读本科生 | | |
| [Anooyman](https://github.com/Anooyman) | 南京理工大学硕士 | | | | [Anooyman](https://github.com/Anooyman) | 南京理工大学硕士 | | |
| [Vicky-3021](https://github.com/Vicky-3021) | 西安电子科技大学硕士研0 | | | | [Vicky-3021](https://github.com/Vicky-3021) | 西安电子科技大学硕士研0 | | |

View File

@ -285,6 +285,7 @@ This project uses Git for version control. You can see the currently available v
| :----------------------------------------------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: | | :----------------------------------------------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: |
| [aJupyter](https://github.com/aJupyter) | Nankai University, Master's student | DataWhale member | Project initiator | | [aJupyter](https://github.com/aJupyter) | Nankai University, Master's student | DataWhale member | Project initiator |
| [MING-ZCH](https://github.com/MING-ZCH) | Huazhong University of Science and Technology, Undergraduate student | LLM X Mental health researcher | Project co-leader | | [MING-ZCH](https://github.com/MING-ZCH) | Huazhong University of Science and Technology, Undergraduate student | LLM X Mental health researcher | Project co-leader |
| [chg0901](https://github.com/chg0901) | Ph.D Student of Kwangwoon University in South Korea| [MiniSora](https://github.com/mini-sora/minisora) | Project co-leader |
| [jujimeizuo](https://github.com/jujimeizuo) | Jiangnan University, Master's student | | | | [jujimeizuo](https://github.com/jujimeizuo) | Jiangnan University, Master's student | | |
| [Smiling-Weeping-zhr](https://github.com/Smiling-Weeping-zhr) | Harbin Institute of Technology (Weihai), Undergraduate student | | | | [Smiling-Weeping-zhr](https://github.com/Smiling-Weeping-zhr) | Harbin Institute of Technology (Weihai), Undergraduate student | | |
| [8baby8](https://github.com/8baby8) | PaddlePaddle Pilot Team Regional Director | Wenxin Large Model core developer | | | [8baby8](https://github.com/8baby8) | PaddlePaddle Pilot Team Regional Director | Wenxin Large Model core developer | |
@ -294,7 +295,6 @@ This project uses Git for version control. You can see the currently available v
| [ZeyuBa](https://github.com/ZeyuBa) | Institute of Automation, Master's student | | | | [ZeyuBa](https://github.com/ZeyuBa) | Institute of Automation, Master's student | | |
| [aiyinyuedejustin](https://github.com/aiyinyuedejustin) | University of Pennsylvania, Master's student | | | | [aiyinyuedejustin](https://github.com/aiyinyuedejustin) | University of Pennsylvania, Master's student | | |
| [Nobody-ML](https://github.com/Nobody-ML) | China University of Petroleum (East China), Undergraduate student | | | | [Nobody-ML](https://github.com/Nobody-ML) | China University of Petroleum (East China), Undergraduate student | | |
| [chg0901](https://github.com/chg0901) | [MiniSora](https://github.com/mini-sora/minisora) | Maintainer and Admin of [MiniSora](https://github.com/mini-sora/minisora) | LLM Pre-Training and Fine-Tuning, Model Uploading, Data Cleaning and Docs Translation |
| [Mxoder](https://github.com/Mxoder) | Beihang University, Undergraduate student | | | | [Mxoder](https://github.com/Mxoder) | Beihang University, Undergraduate student | | |
| [Anooyman](https://github.com/Anooyman) | Nanjing University of Science and Technology, Master's student | | | | [Anooyman](https://github.com/Anooyman) | Nanjing University of Science and Technology, Master's student | | |
| [Vicky-3021](https://github.com/Vicky-3021) | Xidian University, Master's student (Research Year 0) | | | | [Vicky-3021](https://github.com/Vicky-3021) | Xidian University, Master's student (Research Year 0) | | |

294
web_internlm2_5.py Normal file
View File

@ -0,0 +1,294 @@
"""This script refers to the dialogue example of streamlit, the interactive
generation code of chatglm2 and transformers.
We mainly modified part of the code logic to adapt to the
generation of our model.
Please refer to these links below for more information:
1. streamlit chat example:
https://docs.streamlit.io/knowledge-base/tutorials/build-conversational-apps
2. chatglm2:
https://github.com/THUDM/ChatGLM2-6B
3. transformers:
https://github.com/huggingface/transformers
Please run with the command `streamlit run path/to/web_demo.py
--server.address=0.0.0.0 --server.port 7860`.
Using `python path/to/web_demo.py` may cause unknown problems.
"""
# isort: skip_file
import copy
import warnings
from dataclasses import asdict, dataclass
from typing import Callable, List, Optional
import streamlit as st
import torch
from torch import nn
from transformers.generation.utils import (LogitsProcessorList,
StoppingCriteriaList)
from transformers.utils import logging
from transformers import AutoTokenizer, AutoModelForCausalLM # isort: skip
logger = logging.get_logger(__name__)
# local
model_path = '/root/EmoLLM/xtuner_config/hf4'
# Online downloading will be added later
@dataclass
class GenerationConfig:
# this config is used for chat to provide more diversity
max_length: int = 32768
top_p: float = 0.8
temperature: float = 0.8
do_sample: bool = True
repetition_penalty: float = 1.005
@torch.inference_mode()
def generate_interactive(
model,
tokenizer,
prompt,
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor],
List[int]]] = None,
additional_eos_token_id: Optional[int] = None,
**kwargs,
):
inputs = tokenizer([prompt], padding=True, return_tensors='pt')
input_length = len(inputs['input_ids'][0])
for k, v in inputs.items():
inputs[k] = v.cuda()
input_ids = inputs['input_ids']
_, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
if generation_config is None:
generation_config = model.generation_config
generation_config = copy.deepcopy(generation_config)
model_kwargs = generation_config.update(**kwargs)
bos_token_id, eos_token_id = ( # noqa: F841 # pylint: disable=W0612
generation_config.bos_token_id,
generation_config.eos_token_id,
)
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
if additional_eos_token_id is not None:
eos_token_id.append(additional_eos_token_id)
has_default_max_length = kwargs.get(
'max_length') is None and generation_config.max_length is not None
if has_default_max_length and generation_config.max_new_tokens is None:
warnings.warn(
f"Using 'max_length''s default \
({repr(generation_config.max_length)}) \
to control the generation length. "
'This behaviour is deprecated and will be removed from the \
config in v5 of Transformers -- we'
' recommend using `max_new_tokens` to control the maximum \
length of the generation.',
UserWarning,
)
elif generation_config.max_new_tokens is not None:
generation_config.max_length = generation_config.max_new_tokens + \
input_ids_seq_length
if not has_default_max_length:
logger.warn( # pylint: disable=W4902
f"Both 'max_new_tokens' (={generation_config.max_new_tokens}) "
f"and 'max_length'(={generation_config.max_length}) seem to "
"have been set. 'max_new_tokens' will take precedence. "
'Please refer to the documentation for more information. '
'(https://huggingface.co/docs/transformers/main/'
'en/main_classes/text_generation)',
UserWarning,
)
if input_ids_seq_length >= generation_config.max_length:
input_ids_string = 'input_ids'
logger.warning(
f'Input length of {input_ids_string} is {input_ids_seq_length}, '
f"but 'max_length' is set to {generation_config.max_length}. "
'This can lead to unexpected behavior. You should consider'
" increasing 'max_new_tokens'.")
# 2. Set generation parameters if not already defined
logits_processor = logits_processor if logits_processor is not None \
else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None \
else StoppingCriteriaList()
logits_processor = model._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids_seq_length,
encoder_input_ids=input_ids,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
logits_processor=logits_processor,
)
stopping_criteria = model._get_stopping_criteria(
generation_config=generation_config,
stopping_criteria=stopping_criteria)
logits_warper = model._get_logits_warper(generation_config)
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
scores = None
while True:
model_inputs = model.prepare_inputs_for_generation(
input_ids, **model_kwargs)
# forward pass to get next token
outputs = model(
**model_inputs,
return_dict=True,
output_attentions=False,
output_hidden_states=False,
)
next_token_logits = outputs.logits[:, -1, :]
# pre-process distribution
next_token_scores = logits_processor(input_ids, next_token_logits)
next_token_scores = logits_warper(input_ids, next_token_scores)
# sample
probs = nn.functional.softmax(next_token_scores, dim=-1)
if generation_config.do_sample:
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
else:
next_tokens = torch.argmax(probs, dim=-1)
# update generated ids, model inputs, and length for next step
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
model_kwargs = model._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=False)
unfinished_sequences = unfinished_sequences.mul(
(min(next_tokens != i for i in eos_token_id)).long())
output_token_ids = input_ids[0].cpu().tolist()
output_token_ids = output_token_ids[input_length:]
for each_eos_token_id in eos_token_id:
if output_token_ids[-1] == each_eos_token_id:
output_token_ids = output_token_ids[:-1]
response = tokenizer.decode(output_token_ids)
yield response
# stop when each sentence is finished
# or if we exceed the maximum length
if unfinished_sequences.max() == 0 or stopping_criteria(
input_ids, scores):
break
def on_btn_click():
del st.session_state.messages
@st.cache_resource
def load_model():
model = (AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).to(torch.bfloat16).cuda())
tokenizer = AutoTokenizer.from_pretrained(model_path,trust_remote_code=True)
return model, tokenizer
def prepare_generation_config():
with st.sidebar:
# 使用 Streamlit 的 markdown 函数添加 Markdown 文本
st.image('assets/EmoLLM_logo_L.png', width=1, caption='EmoLLM Logo', use_column_width=True)
st.markdown("[访问 EmoLLM 官方repo](https://github.com/SmartFlowAI/EmoLLM)")
max_length = st.slider('Max Length',
min_value=8,
max_value=32768,
value=32768)
top_p = st.slider('Top P', 0.0, 1.0, 0.8, step=0.01)
temperature = st.slider('Temperature', 0.0, 1.0, 0.7, step=0.01)
st.button('Clear Chat History', on_click=on_btn_click)
generation_config = GenerationConfig(max_length=max_length,
top_p=top_p,
temperature=temperature)
return generation_config
user_prompt = '<|im_start|>user\n{user}<|im_end|>\n'
robot_prompt = '<|im_start|>assistant\n{robot}<|im_end|>\n'
cur_query_prompt = '<|im_start|>user\n{user}<|im_end|>\n\
<|im_start|>assistant\n'
def combine_history(prompt):
messages = st.session_state.messages
meta_instruction = ('你是EmoLLM心理咨询师, 由EmoLLM团队打造, 是一个研究过无数具有心理咨询者与顶级专业心理咨询师对话的心理学教授, 在心理方面拥有广博的知识储备和丰富的研究咨询经验。你旨在通过专业心理咨询, 协助来访者完成心理诊断, 利用专业心理学知识与咨询技术一步步帮助来访者解决心理问题。')
total_prompt = f'<s><|im_start|>system\n{meta_instruction}<|im_end|>\n'
for message in messages:
cur_content = message['content']
if message['role'] == 'user':
cur_prompt = user_prompt.format(user=cur_content)
elif message['role'] == 'robot':
cur_prompt = robot_prompt.format(robot=cur_content)
else:
raise RuntimeError
total_prompt += cur_prompt
total_prompt = total_prompt + cur_query_prompt.format(user=prompt)
return total_prompt
def main():
st.markdown("我在这里,准备好倾听你的心声了。", unsafe_allow_html=True)
# torch.cuda.empty_cache()
print('load model begin.')
model, tokenizer = load_model()
print('load model end.')
user_avator = 'assets/user.png'
robot_avator = 'assets/EmoLLM.png'
st.title('EmoLLM V3.0 心理咨询室')
generation_config = prepare_generation_config()
# Initialize chat history
if 'messages' not in st.session_state:
st.session_state.messages = []
# Display chat messages from history on app rerun
for message in st.session_state.messages:
with st.chat_message(message['role'], avatar=message.get('avatar')):
st.markdown(message['content'])
# Accept user input
if prompt := st.chat_input('我在这里准备好倾听你的心声了~'):
# Display user message in chat message container
with st.chat_message('user', avatar=user_avator):
st.markdown(prompt)
real_prompt = combine_history(prompt)
# Add user message to chat history
st.session_state.messages.append({
'role': 'user',
'content': prompt,
'avatar': user_avator
})
with st.chat_message('robot', avatar=robot_avator):
message_placeholder = st.empty()
for cur_response in generate_interactive(
model=model,
tokenizer=tokenizer,
prompt=real_prompt,
additional_eos_token_id=92542,
**asdict(generation_config),
):
# Display robot response in chat message container
message_placeholder.markdown(cur_response + '')
message_placeholder.markdown(cur_response)
# Add robot response to chat history
st.session_state.messages.append({
'role': 'robot',
'content': cur_response, # pylint: disable=undefined-loop-variable
'avatar': robot_avator,
})
torch.cuda.empty_cache()
if __name__ == '__main__':
main()

View File

@ -0,0 +1,243 @@
# Copyright (c) OpenMMLab. All rights reserved.
"""
Ref: https://github.com/InternLM/xtuner/edit/main/xtuner/configs/internlm/internlm2_5_chat_7b/internlm2_5_chat_7b_full_finetune_custom_dataset_e1.py
Data format:
[
{
"conversation": [
{
"system": "",
"input": "xxx",
"output": "xxx"
},
{
"input": "xxx",
"output": "xxx"
}
]
},
...
]
Please refer to https://github.com/InternLM/xtuner/blob/main/docs/en/user_guides/dataset_format.md for details.
""" # noqa: E501
from datasets import load_dataset
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
LoggerHook, ParamSchedulerHook)
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR
from torch.optim import AdamW
from torch.utils.data import BatchSampler
from transformers import AutoModelForCausalLM, AutoTokenizer
from xtuner.dataset import process_hf_dataset
from xtuner.dataset.collate_fns import default_collate_fn
from xtuner.dataset.map_fns import template_map_fn_factory
from xtuner.dataset.samplers import InternRepoSampler
from xtuner.engine import (DatasetInfoHook, EvaluateChatHook, ThroughputHook,
VarlenAttnArgsToMessageHubHook)
from xtuner.engine.runner import TrainLoop
from xtuner.model import SupervisedFinetune
from xtuner.utils import PROMPT_TEMPLATE
#######################################################################
# PART 1 Settings #
#######################################################################
# Model
pretrained_model_name_or_path = '/root/share/new_models/Shanghai_AI_Laboratory/internlm2_5-7b-chat'
use_varlen_attn = True
# Data
data_files = ['/root/EmoLLM/datasets/multi_turn_dataset_2.json']
prompt_template = PROMPT_TEMPLATE.internlm2_chat
# max_length = 32768
max_length = int(32768/4) ## A100*2
pack_to_max_length = True
# parallel
sequence_parallel_size = 1
# Scheduler & Optimizer
# batch size per device, set to 1 if `use_varlen_attn` = True
# To clarify, enlarging the batch size essentially enlarges the `max_length`.
# For example, doubling the max length is tantamount to doubling the batch size
batch_size = 1
accumulative_counts = 1 # 1bs * 1acc * 64gpu = 64 batchsize
accumulative_counts *= sequence_parallel_size
dataloader_num_workers = 4
max_epochs = 3
optim_type = AdamW
lr = 4e-5
betas = (0.9, 0.95)
weight_decay = 0.01
max_norm = 1 # grad clip
warm_up_ratio = 0.025
# Save
save_steps = 500
save_total_limit = 2 # Maximum checkpoints to keep (-1 means unlimited)
# Evaluate the generation performance during the training
evaluation_freq = 500
SYSTEM = "你由EmoLLM团队打造的中文领域心理健康助手, 是一个研究过无数具有心理健康问题的病人与心理健康医生对话的心理专家, 在心理方面拥有广博的知识储备和丰富的研究咨询经验,你旨在通过专业心理咨询, 协助来访者完成心理诊断。请充分利用专业心理学知识与咨询技术, 一步步帮助来访者解决心理问题, 接下来你将只使用中文来回答和咨询问题。"
evaluation_inputs = [
# '躲在云朵下面就不怕淋雨了', # ruozhi train
# '李白如果告语文书侵权的话能赔不少钱吧', # ruozhi test
# '雨天,我走进水坑里,不小心踩碎了天空。', # ruozhi test
'请介绍你自己', # self cognition
'你好',
'我今天心情不好,感觉不开心,很烦。',
'我最近总是感到很焦虑,尤其是在学业上。我有个特别崇拜的同学,他好像在各方面都比我优秀,我总觉得自己怎么努力也追不上他,这让我压力特别大。',
]
# "这是一句富有想象力和幽默感的表达。在现实生活中,躲在云朵下面并不能避免淋雨,因为云朵实际上是由水蒸气凝结形成的,而雨就是由这些水滴凝结而成的。\n\n这样的表达可能是在夸张和幽默的语境中通过一种天马行空的方式来表达逃避现实或者寻找避难的愿望。在文学或口语表达中常常会运用夸张和幽默来传达情感或者引起共鸣。",
# "如果李白701年2月28日—762年12月真的能就侵权诉讼获得赔偿那确实可能是一笔不小的金额。然而这种想法主要是一种有趣的假设因为现实中有几个关键因素使这种情况不可能发生\n\n1. **时间差异**李白生活在唐朝距今大约1200多年前。那 个时代的法律体系与现代的知识产权法律截然不同,当时没有现代意义上的版权法。\n\n2. **版权法的适用范围**:即使在现代,版权 法也有一定的时效限制。在大多数国家版权保护通常在作者去世后一定年数内有效如我国是作者去世后50年。李白去世已超过1250年因此其作品已经进入公共领域任何人都可以自由使用而无需支付版权费用。\n\n3. **历史与现实的区别**:历史人物无法在现代 法律体系中提起诉讼,因为他们不再是活跃的法律主体。\n\n所以虽然这是一个有趣的想法但在现实中李白或其他古代作者无法因其作品被现代出版物使用而获得赔偿。",
# "这个描述似乎是一个修辞手法,比喻性地描述了雨天的场景。在这个描述中,说“我走进水坑里,不小心踩碎了天空”,实际上并非字面意义上的发生,而是一种用词语来比喻雨天的场景。\n\n通常情况下当雨水落在水坑或者蓄水池时水面会泛起涟漪或者波纹可能会反射天空的颜色或者天空的倒影。因此这句话可能是通过“踩碎了天空”的说法来比喻雨天时踩进水坑的情景描述雨水落在水坑中形成的波纹或者涟漪产生了一种倒映天空的效果。\n\n这种形象化的表达方式可能是为了更生动地描述一个平常的场景赋予它一些诗意或者意境。",
#######################################################################
# PART 2 Model & Tokenizer #
#######################################################################
tokenizer = dict(
type=AutoTokenizer.from_pretrained,
pretrained_model_name_or_path=pretrained_model_name_or_path,
trust_remote_code=True,
padding_side='right')
model = dict(
type=SupervisedFinetune,
use_varlen_attn=use_varlen_attn,
llm=dict(
type=AutoModelForCausalLM.from_pretrained,
pretrained_model_name_or_path=pretrained_model_name_or_path,
trust_remote_code=True))
#######################################################################
# PART 3 Dataset & Dataloader #
#######################################################################
train_dataset = dict(
type=process_hf_dataset,
use_varlen_attn=use_varlen_attn,
dataset=dict(type=load_dataset, path='json', data_files=data_files),
tokenizer=tokenizer,
max_length=max_length,
dataset_map_fn=None,
template_map_fn=dict(
type=template_map_fn_factory, template=prompt_template),
remove_unused_columns=True,
shuffle_before_pack=True,
pack_to_max_length=pack_to_max_length)
train_dataloader = dict(
batch_size=batch_size,
num_workers=dataloader_num_workers,
dataset=train_dataset,
sampler=dict(type=InternRepoSampler, shuffle=True, seed=1024),
batch_sampler=dict(
type=BatchSampler, drop_last=True, batch_size=batch_size),
collate_fn=dict(type=default_collate_fn, use_varlen_attn=use_varlen_attn))
#######################################################################
# PART 4 Scheduler & Optimizer #
#######################################################################
# optimizer
optim_wrapper = dict(
type=AmpOptimWrapper,
optimizer=dict(
type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
accumulative_counts=accumulative_counts,
loss_scale='dynamic',
)
# learning policy
# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
param_scheduler = [
dict(
type='LinearLR',
start_factor=1 / 40,
by_epoch=True,
begin=0,
end=warm_up_ratio * max_epochs,
convert_to_iter_based=True),
dict(
type=CosineAnnealingLR,
eta_min=lr * 0.15,
by_epoch=True,
begin=warm_up_ratio * max_epochs,
end=max_epochs,
convert_to_iter_based=True)
]
# train, val, test setting
train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
#######################################################################
# PART 5 Runtime #
#######################################################################
# Log the dialogue periodically during the training process, optional
custom_hooks = [
dict(
type=DatasetInfoHook, tokenizer=tokenizer,
is_intern_repo_dataset=True),
dict(
type=EvaluateChatHook,
tokenizer=tokenizer,
every_n_iters=evaluation_freq,
evaluation_inputs=evaluation_inputs,
system=SYSTEM,
prompt_template=prompt_template),
dict(type=ThroughputHook)
]
if use_varlen_attn:
custom_hooks += [dict(type=VarlenAttnArgsToMessageHubHook)]
# configure default hooks
default_hooks = dict(
# record the time of every iteration.
timer=dict(type=IterTimerHook),
# print log every 100 iterations.
logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=1),
# enable the parameter scheduler.
param_scheduler=dict(type=ParamSchedulerHook),
# save checkpoint per `save_steps`.
checkpoint=dict(
type=CheckpointHook,
by_epoch=False,
interval=save_steps,
max_keep_ckpts=save_total_limit),
# set sampler seed in distributed evrionment.
sampler_seed=dict(type=DistSamplerSeedHook),
)
# configure environment
env_cfg = dict(
# whether to enable cudnn benchmark
cudnn_benchmark=False,
# set multi process parameters
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
# set distributed parameters
dist_cfg=dict(backend='nccl'),
)
# set visualizer
visualizer = None
# set log level
log_level = 'INFO'
# load from which checkpoint
load_from = None
# whether to resume training from the loaded checkpoint
resume = False
# Defaults to use random seed and disable `deterministic`
randomness = dict(seed=None, deterministic=False)
log_processor = dict(
by_epoch=False,
window_size=1,
mean_pattern=r'.*(loss|time|data_time|grad_norm|tflops).*')