diff --git a/README.md b/README.md index 50f49b9..2360b61 100644 --- a/README.md +++ b/README.md @@ -281,6 +281,7 @@ git clone https://github.com/SmartFlowAI/EmoLLM.git | :----------------------------------------------------------: | :------------------------------------------------: | :----------------------------------------------------------: | :-------------------------------------------: | | [aJupyter](https://github.com/aJupyter) | 南开大学在读硕士 | DataWhale成员 | 项目发起人 | | [MING-ZCH](https://github.com/MING-ZCH) | 华中科技大学在读本科生 | LLM x Mental health 研究者 | 项目联合负责人 | +| [chg0901](https://github.com/chg0901) | 韩国光云大学在读博士 [MiniSora](https://github.com/mini-sora/minisora/) | DataWhale意向成员 DataWhale鲸英助教团成员 | 项目联合负责人 | | [jujimeizuo](https://github.com/jujimeizuo) | 江南大学在读硕士 | | | | [Smiling-Weeping-zhr](https://github.com/Smiling-Weeping-zhr) | 哈尔滨工业大学(威海)在读本科生 | | | | [8baby8](https://github.com/8baby8) | 飞桨领航团区域主管 | 文心大模型核心开发者 | | @@ -290,7 +291,6 @@ git clone https://github.com/SmartFlowAI/EmoLLM.git | [ZeyuBa](https://github.com/ZeyuBa) | 自动化所在读硕士 | | | | [aiyinyuedejustin](https://github.com/aiyinyuedejustin) | 宾夕法尼亚大学在读硕士 | | | | [Nobody-ML](https://github.com/Nobody-ML) | 中国石油大学(华东)在读本科生 | | | -| [chg0901](https://github.com/chg0901) | [MiniSora](https://github.com/mini-sora/minisora/) | [MiniSora](https://github.com/mini-sora/minisora/)主要维护者,管理员 | LLM预训练和微调、模型上传、数据清洗、文档翻译 | | [Mxoder](https://github.com/Mxoder) | 北京航空航天大学在读本科生 | | | | [Anooyman](https://github.com/Anooyman) | 南京理工大学硕士 | | | | [Vicky-3021](https://github.com/Vicky-3021) | 西安电子科技大学硕士(研0) | | | diff --git a/README_EN.md b/README_EN.md index 12da481..8a21bb7 100644 --- a/README_EN.md +++ b/README_EN.md @@ -285,6 +285,7 @@ This project uses Git for version control. You can see the currently available v | :----------------------------------------------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: | | [aJupyter](https://github.com/aJupyter) | Nankai University, Master's student | DataWhale member | Project initiator | | [MING-ZCH](https://github.com/MING-ZCH) | Huazhong University of Science and Technology, Undergraduate student | LLM X Mental health researcher | Project co-leader | +| [chg0901](https://github.com/chg0901) | Ph.D Student of Kwangwoon University in South Korea| [MiniSora](https://github.com/mini-sora/minisora) | Project co-leader | | [jujimeizuo](https://github.com/jujimeizuo) | Jiangnan University, Master's student | | | | [Smiling-Weeping-zhr](https://github.com/Smiling-Weeping-zhr) | Harbin Institute of Technology (Weihai), Undergraduate student | | | | [8baby8](https://github.com/8baby8) | PaddlePaddle Pilot Team Regional Director | Wenxin Large Model core developer | | @@ -294,7 +295,6 @@ This project uses Git for version control. You can see the currently available v | [ZeyuBa](https://github.com/ZeyuBa) | Institute of Automation, Master's student | | | | [aiyinyuedejustin](https://github.com/aiyinyuedejustin) | University of Pennsylvania, Master's student | | | | [Nobody-ML](https://github.com/Nobody-ML) | China University of Petroleum (East China), Undergraduate student | | | -| [chg0901](https://github.com/chg0901) | [MiniSora](https://github.com/mini-sora/minisora) | Maintainer and Admin of [MiniSora](https://github.com/mini-sora/minisora) | LLM Pre-Training and Fine-Tuning, Model Uploading, Data Cleaning and Docs Translation | | [Mxoder](https://github.com/Mxoder) | Beihang University, Undergraduate student | | | | [Anooyman](https://github.com/Anooyman) | Nanjing University of Science and Technology, Master's student | | | | [Vicky-3021](https://github.com/Vicky-3021) | Xidian University, Master's student (Research Year 0) | | | diff --git a/rag/src/pipeline.py b/rag/src/pipeline.py index 08b9b96..2550381 100644 --- a/rag/src/pipeline.py +++ b/rag/src/pipeline.py @@ -45,7 +45,7 @@ class EmoLLMRAG(object): def get_retrieval_content(self, query) -> str: """ Input: 用户提问, 是否需要rerank - ouput: 检索后并且 rerank 的内容 + output: 检索后并且 rerank 的内容 """ content = [] diff --git a/web_internlm2_5.py b/web_internlm2_5.py new file mode 100644 index 0000000..64dbd57 --- /dev/null +++ b/web_internlm2_5.py @@ -0,0 +1,294 @@ +"""This script refers to the dialogue example of streamlit, the interactive +generation code of chatglm2 and transformers. + +We mainly modified part of the code logic to adapt to the +generation of our model. +Please refer to these links below for more information: + 1. streamlit chat example: + https://docs.streamlit.io/knowledge-base/tutorials/build-conversational-apps + 2. chatglm2: + https://github.com/THUDM/ChatGLM2-6B + 3. transformers: + https://github.com/huggingface/transformers +Please run with the command `streamlit run path/to/web_demo.py + --server.address=0.0.0.0 --server.port 7860`. +Using `python path/to/web_demo.py` may cause unknown problems. +""" +# isort: skip_file +import copy +import warnings +from dataclasses import asdict, dataclass +from typing import Callable, List, Optional + +import streamlit as st +import torch +from torch import nn +from transformers.generation.utils import (LogitsProcessorList, + StoppingCriteriaList) +from transformers.utils import logging + +from transformers import AutoTokenizer, AutoModelForCausalLM # isort: skip + +logger = logging.get_logger(__name__) + +# local +model_path = '/root/EmoLLM/xtuner_config/hf4' +# Online downloading will be added later + +@dataclass +class GenerationConfig: + # this config is used for chat to provide more diversity + max_length: int = 32768 + top_p: float = 0.8 + temperature: float = 0.8 + do_sample: bool = True + repetition_penalty: float = 1.005 + + +@torch.inference_mode() +def generate_interactive( + model, + tokenizer, + prompt, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], + List[int]]] = None, + additional_eos_token_id: Optional[int] = None, + **kwargs, +): + inputs = tokenizer([prompt], padding=True, return_tensors='pt') + input_length = len(inputs['input_ids'][0]) + for k, v in inputs.items(): + inputs[k] = v.cuda() + input_ids = inputs['input_ids'] + _, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1] + if generation_config is None: + generation_config = model.generation_config + generation_config = copy.deepcopy(generation_config) + model_kwargs = generation_config.update(**kwargs) + bos_token_id, eos_token_id = ( # noqa: F841 # pylint: disable=W0612 + generation_config.bos_token_id, + generation_config.eos_token_id, + ) + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + if additional_eos_token_id is not None: + eos_token_id.append(additional_eos_token_id) + has_default_max_length = kwargs.get( + 'max_length') is None and generation_config.max_length is not None + if has_default_max_length and generation_config.max_new_tokens is None: + warnings.warn( + f"Using 'max_length''s default \ + ({repr(generation_config.max_length)}) \ + to control the generation length. " + 'This behaviour is deprecated and will be removed from the \ + config in v5 of Transformers -- we' + ' recommend using `max_new_tokens` to control the maximum \ + length of the generation.', + UserWarning, + ) + elif generation_config.max_new_tokens is not None: + generation_config.max_length = generation_config.max_new_tokens + \ + input_ids_seq_length + if not has_default_max_length: + logger.warn( # pylint: disable=W4902 + f"Both 'max_new_tokens' (={generation_config.max_new_tokens}) " + f"and 'max_length'(={generation_config.max_length}) seem to " + "have been set. 'max_new_tokens' will take precedence. " + 'Please refer to the documentation for more information. ' + '(https://huggingface.co/docs/transformers/main/' + 'en/main_classes/text_generation)', + UserWarning, + ) + + if input_ids_seq_length >= generation_config.max_length: + input_ids_string = 'input_ids' + logger.warning( + f'Input length of {input_ids_string} is {input_ids_seq_length}, ' + f"but 'max_length' is set to {generation_config.max_length}. " + 'This can lead to unexpected behavior. You should consider' + " increasing 'max_new_tokens'.") + + # 2. Set generation parameters if not already defined + logits_processor = logits_processor if logits_processor is not None \ + else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None \ + else StoppingCriteriaList() + + logits_processor = model._get_logits_processor( + generation_config=generation_config, + input_ids_seq_length=input_ids_seq_length, + encoder_input_ids=input_ids, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + logits_processor=logits_processor, + ) + + stopping_criteria = model._get_stopping_criteria( + generation_config=generation_config, + stopping_criteria=stopping_criteria) + logits_warper = model._get_logits_warper(generation_config) + + unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) + scores = None + while True: + model_inputs = model.prepare_inputs_for_generation( + input_ids, **model_kwargs) + # forward pass to get next token + outputs = model( + **model_inputs, + return_dict=True, + output_attentions=False, + output_hidden_states=False, + ) + + next_token_logits = outputs.logits[:, -1, :] + + # pre-process distribution + next_token_scores = logits_processor(input_ids, next_token_logits) + next_token_scores = logits_warper(input_ids, next_token_scores) + + # sample + probs = nn.functional.softmax(next_token_scores, dim=-1) + if generation_config.do_sample: + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + else: + next_tokens = torch.argmax(probs, dim=-1) + + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + model_kwargs = model._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=False) + unfinished_sequences = unfinished_sequences.mul( + (min(next_tokens != i for i in eos_token_id)).long()) + + output_token_ids = input_ids[0].cpu().tolist() + output_token_ids = output_token_ids[input_length:] + for each_eos_token_id in eos_token_id: + if output_token_ids[-1] == each_eos_token_id: + output_token_ids = output_token_ids[:-1] + response = tokenizer.decode(output_token_ids) + + yield response + # stop when each sentence is finished + # or if we exceed the maximum length + if unfinished_sequences.max() == 0 or stopping_criteria( + input_ids, scores): + break + + +def on_btn_click(): + del st.session_state.messages + + +@st.cache_resource +def load_model(): + model = (AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).to(torch.bfloat16).cuda()) + tokenizer = AutoTokenizer.from_pretrained(model_path,trust_remote_code=True) + return model, tokenizer + + +def prepare_generation_config(): + with st.sidebar: + # 使用 Streamlit 的 markdown 函数添加 Markdown 文本 + st.image('assets/EmoLLM_logo_L.png', width=1, caption='EmoLLM Logo', use_column_width=True) + st.markdown("[访问 EmoLLM 官方repo](https://github.com/SmartFlowAI/EmoLLM)") + + max_length = st.slider('Max Length', + min_value=8, + max_value=32768, + value=32768) + top_p = st.slider('Top P', 0.0, 1.0, 0.8, step=0.01) + temperature = st.slider('Temperature', 0.0, 1.0, 0.7, step=0.01) + st.button('Clear Chat History', on_click=on_btn_click) + + generation_config = GenerationConfig(max_length=max_length, + top_p=top_p, + temperature=temperature) + + return generation_config + + +user_prompt = '<|im_start|>user\n{user}<|im_end|>\n' +robot_prompt = '<|im_start|>assistant\n{robot}<|im_end|>\n' +cur_query_prompt = '<|im_start|>user\n{user}<|im_end|>\n\ + <|im_start|>assistant\n' + + +def combine_history(prompt): + messages = st.session_state.messages + meta_instruction = ('你是EmoLLM心理咨询师, 由EmoLLM团队打造, 是一个研究过无数具有心理咨询者与顶级专业心理咨询师对话的心理学教授, 在心理方面拥有广博的知识储备和丰富的研究咨询经验。你旨在通过专业心理咨询, 协助来访者完成心理诊断, 利用专业心理学知识与咨询技术一步步帮助来访者解决心理问题。') + total_prompt = f'<|im_start|>system\n{meta_instruction}<|im_end|>\n' + for message in messages: + cur_content = message['content'] + if message['role'] == 'user': + cur_prompt = user_prompt.format(user=cur_content) + elif message['role'] == 'robot': + cur_prompt = robot_prompt.format(robot=cur_content) + else: + raise RuntimeError + total_prompt += cur_prompt + total_prompt = total_prompt + cur_query_prompt.format(user=prompt) + return total_prompt + + +def main(): + st.markdown("我在这里,准备好倾听你的心声了。", unsafe_allow_html=True) + # torch.cuda.empty_cache() + print('load model begin.') + model, tokenizer = load_model() + print('load model end.') + + user_avator = 'assets/user.png' + robot_avator = 'assets/EmoLLM.png' + + st.title('EmoLLM V3.0 心理咨询室') + + generation_config = prepare_generation_config() + + # Initialize chat history + if 'messages' not in st.session_state: + st.session_state.messages = [] + + # Display chat messages from history on app rerun + for message in st.session_state.messages: + with st.chat_message(message['role'], avatar=message.get('avatar')): + st.markdown(message['content']) + + # Accept user input + if prompt := st.chat_input('我在这里准备好倾听你的心声了~'): + # Display user message in chat message container + with st.chat_message('user', avatar=user_avator): + st.markdown(prompt) + real_prompt = combine_history(prompt) + # Add user message to chat history + st.session_state.messages.append({ + 'role': 'user', + 'content': prompt, + 'avatar': user_avator + }) + + with st.chat_message('robot', avatar=robot_avator): + message_placeholder = st.empty() + for cur_response in generate_interactive( + model=model, + tokenizer=tokenizer, + prompt=real_prompt, + additional_eos_token_id=92542, + **asdict(generation_config), + ): + # Display robot response in chat message container + message_placeholder.markdown(cur_response + '▌') + message_placeholder.markdown(cur_response) + # Add robot response to chat history + st.session_state.messages.append({ + 'role': 'robot', + 'content': cur_response, # pylint: disable=undefined-loop-variable + 'avatar': robot_avator, + }) + torch.cuda.empty_cache() + + +if __name__ == '__main__': + main() diff --git a/xtuner_config/README_internlm2_7b_base_qlora.md b/xtuner_config/README_internlm2_7b_base_qlora.md index f583b42..3276ae6 100644 --- a/xtuner_config/README_internlm2_7b_base_qlora.md +++ b/xtuner_config/README_internlm2_7b_base_qlora.md @@ -138,7 +138,7 @@ model = dict( ### 数据处理 - 使用 `../datasets/process.py` 以处理 **multi_turn_dataset(1 和 2,QA数据转单轮对话)**, `data.json` 和 `data_pro.json` 文件(两个多轮对话),以添加或者调整 **`system` prompt** -- 使用 `../datasets/processed/process_single_turn_conversation_construction.py` 处理 **single-turn dataset** (1 和 2),修改 (`input` 和 `ouput`) ,并在每次 **conversation** 中添加 **`system` prompt** +- 使用 `../datasets/processed/process_single_turn_conversation_construction.py` 处理 **single-turn dataset** (1 和 2),修改 (`input` 和 `output`) ,并在每次 **conversation** 中添加 **`system` prompt** - 使用 `../datasets/processed/process_merge.py` 用于合并 `../datasets/processed/` 目录下**6个更新后的数据集**,生成一个合并后的数据集 `combined_data.json`用于最终训练 ## 基于XTuner的微调🎉🎉🎉🎉🎉 diff --git a/xtuner_config/internlm2_5_chat_7b_full.py b/xtuner_config/internlm2_5_chat_7b_full.py new file mode 100644 index 0000000..d580c77 --- /dev/null +++ b/xtuner_config/internlm2_5_chat_7b_full.py @@ -0,0 +1,243 @@ +# Copyright (c) OpenMMLab. All rights reserved. +""" +Ref: https://github.com/InternLM/xtuner/edit/main/xtuner/configs/internlm/internlm2_5_chat_7b/internlm2_5_chat_7b_full_finetune_custom_dataset_e1.py + +Data format: +[ + { + "conversation": [ + { + "system": "", + "input": "xxx", + "output": "xxx" + }, + { + "input": "xxx", + "output": "xxx" + } + ] + }, +... +] +Please refer to https://github.com/InternLM/xtuner/blob/main/docs/en/user_guides/dataset_format.md for details. +""" # noqa: E501 +from datasets import load_dataset +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR +from torch.optim import AdamW +from torch.utils.data import BatchSampler +from transformers import AutoModelForCausalLM, AutoTokenizer + +from xtuner.dataset import process_hf_dataset +from xtuner.dataset.collate_fns import default_collate_fn +from xtuner.dataset.map_fns import template_map_fn_factory +from xtuner.dataset.samplers import InternRepoSampler +from xtuner.engine import (DatasetInfoHook, EvaluateChatHook, ThroughputHook, + VarlenAttnArgsToMessageHubHook) +from xtuner.engine.runner import TrainLoop +from xtuner.model import SupervisedFinetune +from xtuner.utils import PROMPT_TEMPLATE + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model +pretrained_model_name_or_path = '/root/share/new_models/Shanghai_AI_Laboratory/internlm2_5-7b-chat' +use_varlen_attn = True + +# Data +data_files = ['/root/EmoLLM/datasets/multi_turn_dataset_2.json'] +prompt_template = PROMPT_TEMPLATE.internlm2_chat +# max_length = 32768 +max_length = int(32768/4) ## A100*2 +pack_to_max_length = True + +# parallel +sequence_parallel_size = 1 + +# Scheduler & Optimizer +# batch size per device, set to 1 if `use_varlen_attn` = True +# To clarify, enlarging the batch size essentially enlarges the `max_length`. +# For example, doubling the max length is tantamount to doubling the batch size +batch_size = 1 +accumulative_counts = 1 # 1bs * 1acc * 64gpu = 64 batchsize +accumulative_counts *= sequence_parallel_size +dataloader_num_workers = 4 +max_epochs = 3 + +optim_type = AdamW +lr = 4e-5 +betas = (0.9, 0.95) +weight_decay = 0.01 +max_norm = 1 # grad clip +warm_up_ratio = 0.025 + +# Save +save_steps = 500 +save_total_limit = 2 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 500 +SYSTEM = "你由EmoLLM团队打造的中文领域心理健康助手, 是一个研究过无数具有心理健康问题的病人与心理健康医生对话的心理专家, 在心理方面拥有广博的知识储备和丰富的研究咨询经验,你旨在通过专业心理咨询, 协助来访者完成心理诊断。请充分利用专业心理学知识与咨询技术, 一步步帮助来访者解决心理问题, 接下来你将只使用中文来回答和咨询问题。" +evaluation_inputs = [ + # '躲在云朵下面就不怕淋雨了', # ruozhi train + # '李白如果告语文书侵权的话能赔不少钱吧', # ruozhi test + # '雨天,我走进水坑里,不小心踩碎了天空。', # ruozhi test + '请介绍你自己', # self cognition + '你好', + '我今天心情不好,感觉不开心,很烦。', + '我最近总是感到很焦虑,尤其是在学业上。我有个特别崇拜的同学,他好像在各方面都比我优秀,我总觉得自己怎么努力也追不上他,这让我压力特别大。', +] + +# "这是一句富有想象力和幽默感的表达。在现实生活中,躲在云朵下面并不能避免淋雨,因为云朵实际上是由水蒸气凝结形成的,而雨就是由这些水滴凝结而成的。\n\n这样的表达可能是在夸张和幽默的语境中,通过一种天马行空的方式来表达逃避现实或者寻找避难的愿望。在文学或口语表达中,常常会运用夸张和幽默来传达情感或者引起共鸣。", + +# "如果李白(701年2月28日—762年12月)真的能就侵权诉讼获得赔偿,那确实可能是一笔不小的金额。然而,这种想法主要是一种有趣的假设,因为现实中有几个关键因素使这种情况不可能发生:\n\n1. **时间差异**:李白生活在唐朝,距今大约1200多年前。那 个时代的法律体系与现代的知识产权法律截然不同,当时没有现代意义上的版权法。\n\n2. **版权法的适用范围**:即使在现代,版权 法也有一定的时效限制。在大多数国家,版权保护通常在作者去世后一定年数内有效(如我国是作者去世后50年)。李白去世已超过1250年,因此其作品已经进入公共领域,任何人都可以自由使用而无需支付版权费用。\n\n3. **历史与现实的区别**:历史人物无法在现代 法律体系中提起诉讼,因为他们不再是活跃的法律主体。\n\n所以,虽然这是一个有趣的想法,但在现实中,李白或其他古代作者无法因其作品被现代出版物使用而获得赔偿。", + +# "这个描述似乎是一个修辞手法,比喻性地描述了雨天的场景。在这个描述中,说“我走进水坑里,不小心踩碎了天空”,实际上并非字面意义上的发生,而是一种用词语来比喻雨天的场景。\n\n通常情况下,当雨水落在水坑或者蓄水池时,水面会泛起涟漪或者波纹,可能会反射天空的颜色或者天空的倒影。因此,这句话可能是通过“踩碎了天空”的说法来比喻雨天时踩进水坑的情景,描述雨水落在水坑中形成的波纹或者涟漪,产生了一种倒映天空的效果。\n\n这种形象化的表达方式可能是为了更生动地描述一个平常的场景,赋予它一些诗意或者意境。", + +####################################################################### +# PART 2 Model & Tokenizer # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=pretrained_model_name_or_path, + trust_remote_code=True, + padding_side='right') + +model = dict( + type=SupervisedFinetune, + use_varlen_attn=use_varlen_attn, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=pretrained_model_name_or_path, + trust_remote_code=True)) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +train_dataset = dict( + type=process_hf_dataset, + use_varlen_attn=use_varlen_attn, + dataset=dict(type=load_dataset, path='json', data_files=data_files), + tokenizer=tokenizer, + max_length=max_length, + dataset_map_fn=None, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + remove_unused_columns=True, + shuffle_before_pack=True, + pack_to_max_length=pack_to_max_length) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=train_dataset, + sampler=dict(type=InternRepoSampler, shuffle=True, seed=1024), + batch_sampler=dict( + type=BatchSampler, drop_last=True, batch_size=batch_size), + collate_fn=dict(type=default_collate_fn, use_varlen_attn=use_varlen_attn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', +) + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type='LinearLR', + start_factor=1 / 40, + by_epoch=True, + begin=0, + end=warm_up_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=lr * 0.15, + by_epoch=True, + begin=warm_up_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict( + type=DatasetInfoHook, tokenizer=tokenizer, + is_intern_repo_dataset=True), + dict( + type=EvaluateChatHook, + tokenizer=tokenizer, + every_n_iters=evaluation_freq, + evaluation_inputs=evaluation_inputs, + system=SYSTEM, + prompt_template=prompt_template), + dict(type=ThroughputHook) +] + +if use_varlen_attn: + custom_hooks += [dict(type=VarlenAttnArgsToMessageHubHook)] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 100 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=1), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. + checkpoint=dict( + type=CheckpointHook, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed evrionment. + sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +log_processor = dict( + by_epoch=False, + window_size=1, + mean_pattern=r'.*(loss|time|data_time|grad_norm|tflops).*')