feat: finetune Qwen and demo

commit 45b143b6ef (parent dc9208f4d5)

.github/workflows/reademe-contributors (vendored, 14 lines removed)
@@ -1,14 +0,0 @@
-on:
-  push:
-    branches:
-      - main
-name: Generate a list of contributors
-jobs:
-  contrib-readme-en-job:
-    runs-on: ubuntu-latest
-    name: A job to automate contrib in readme
-    steps:
-      - name: Contribute List
-        uses: akhilmhdh/contributors-readme-action@v2.3.4
-        env:
-          GITHUB_TOKEN: ${{ secrets.CONTRIBUTORS_TOKEN }}

.gitignore (vendored, 2 changed lines)
@@ -2,4 +2,4 @@ ESConv.json
 .DS_Store
 __pycache__/
 tmp/
-data/zhipuai/
+zhipuai/

@@ -3,6 +3,6 @@

 ## 🌟 Contributors

-[![EmoLLM contributors](https://contrib.rocks/image?repo=aJupyter/EmoLLM&max=2000)](https://github.com/aJupyter/EmoLLM/graphs/contributors)
+[![EmoLLM contributors](https://contrib.rocks/image?repo=aJupyter/EmoLLM&max=200)](https://github.com/aJupyter/EmoLLM/graphs/contributors)


demo/cli_qwen.py (new file, 210 lines)
@@ -0,0 +1,210 @@
# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""A simple command-line interactive chat demo."""

import argparse
import os
import platform
import shutil
from copy import deepcopy

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
from transformers.trainer_utils import set_seed

DEFAULT_CKPT_PATH = './merged'

_WELCOME_MSG = '''\
Welcome to use Emo-Chat model, type text to start chat, type :h to show command help.
(欢迎使用 Emo-Chat 模型,输入内容即可进行对话,:h 显示命令帮助。)

Note: This demo is governed by the original license of Qwen.
We strongly advise users not to knowingly generate or allow others to knowingly generate harmful content, including hate speech, violence, pornography, deception, etc.
(注:本演示受EmoLLM的许可协议限制。我们强烈建议,用户不应传播及不应允许他人传播以下内容,包括但不限于仇恨言论、暴力、色情、欺诈相关的有害信息。)
'''
_HELP_MSG = '''\
Commands:
    :help / :h              Show this help message              显示帮助信息
    :exit / :quit / :q      Exit the demo                       退出Demo
    :clear / :cl            Clear screen                        清屏
    :clear-his / :clh       Clear history                       清除对话历史
    :history / :his         Show history                        显示对话历史
    :seed                   Show current random seed            显示当前随机种子
    :seed <N>               Set random seed to <N>              设置随机种子
    :conf                   Show current generation config      显示生成配置
    :conf <key>=<value>     Change generation config            修改生成配置
    :reset-conf             Reset generation config             重置生成配置
'''


def _load_model_tokenizer(args):
    tokenizer = AutoTokenizer.from_pretrained(
        args.checkpoint_path, trust_remote_code=True, resume_download=True,
    )

    if args.cpu_only:
        device_map = "cpu"
    else:
        device_map = "auto"

    model = AutoModelForCausalLM.from_pretrained(
        args.checkpoint_path,
        device_map=device_map,
        trust_remote_code=True,
        resume_download=True,
    ).eval()

    config = GenerationConfig.from_pretrained(
        args.checkpoint_path, trust_remote_code=True, resume_download=True,
    )

    return model, tokenizer, config


def _gc():
    import gc
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def _clear_screen():
    if platform.system() == "Windows":
        os.system("cls")
    else:
        os.system("clear")


def _print_history(history):
    terminal_width = shutil.get_terminal_size()[0]
    print(f'History ({len(history)})'.center(terminal_width, '='))
    for index, (query, response) in enumerate(history):
        print(f'User[{index}]: {query}')
        print(f'QWen[{index}]: {response}')
    print('=' * terminal_width)


def _get_input() -> str:
    while True:
        try:
            message = input('User> ').strip()
        except UnicodeDecodeError:
            print('[ERROR] Encoding error in input')
            continue
        except KeyboardInterrupt:
            exit(1)
        if message:
            return message
        print('[ERROR] Query is empty')


def main():
    parser = argparse.ArgumentParser(
        description='QWen-Chat command-line interactive chat demo.')
    parser.add_argument("-c", "--checkpoint-path", type=str, default=DEFAULT_CKPT_PATH,
                        help="Checkpoint name or path, default to %(default)r")
    parser.add_argument("-s", "--seed", type=int, default=1234, help="Random seed")
    parser.add_argument("--cpu-only", action="store_true", help="Run demo with CPU only")
    args = parser.parse_args()

    history, response = [], ''

    model, tokenizer, config = _load_model_tokenizer(args)
    orig_gen_config = deepcopy(model.generation_config)

    _clear_screen()
    print(_WELCOME_MSG)

    seed = args.seed

    while True:
        query = _get_input()

        # Process commands.
        if query.startswith(':'):
            command_words = query[1:].strip().split()
            if not command_words:
                command = ''
            else:
                command = command_words[0]

            if command in ['exit', 'quit', 'q']:
                break
            elif command in ['clear', 'cl']:
                _clear_screen()
                print(_WELCOME_MSG)
                _gc()
                continue
            elif command in ['clear-history', 'clh']:
                print(f'[INFO] All {len(history)} history cleared')
                history.clear()
                _gc()
                continue
            elif command in ['help', 'h']:
                print(_HELP_MSG)
                continue
            elif command in ['history', 'his']:
                _print_history(history)
                continue
            elif command in ['seed']:
                if len(command_words) == 1:
                    print(f'[INFO] Current random seed: {seed}')
                    continue
                else:
                    new_seed_s = command_words[1]
                    try:
                        new_seed = int(new_seed_s)
                    except ValueError:
                        print(f'[WARNING] Fail to change random seed: {new_seed_s!r} is not a valid number')
                    else:
                        print(f'[INFO] Random seed changed to {new_seed}')
                        seed = new_seed
                    continue
            elif command in ['conf']:
                if len(command_words) == 1:
                    print(model.generation_config)
                else:
                    for key_value_pairs_str in command_words[1:]:
                        eq_idx = key_value_pairs_str.find('=')
                        if eq_idx == -1:
                            print('[WARNING] format: <key>=<value>')
                            continue
                        conf_key, conf_value_str = key_value_pairs_str[:eq_idx], key_value_pairs_str[eq_idx + 1:]
                        try:
                            conf_value = eval(conf_value_str)
                        except Exception as e:
                            print(e)
                            continue
                        else:
                            print(f'[INFO] Change config: model.generation_config.{conf_key} = {conf_value}')
                            setattr(model.generation_config, conf_key, conf_value)
                continue
            elif command in ['reset-conf']:
                print('[INFO] Reset generation config')
                model.generation_config = deepcopy(orig_gen_config)
                print(model.generation_config)
                continue
            else:
                # As normal query.
                pass

        # Run chat.
        set_seed(seed)
        try:
            for response in model.chat_stream(tokenizer, query, history=history, generation_config=config):
                _clear_screen()
                print(f"\nUser: {query}")
                print(f"\nQwen-Chat: {response}")
        except KeyboardInterrupt:
            print('[WARNING] Generation interrupted')
            continue

        history.append((query, response))


if __name__ == "__main__":
    main()
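
For a quick non-interactive smoke test of the same merged checkpoint, a sketch like the following can be used; it assumes the weights in ./merged ship Qwen's remote modeling code, which exposes model.chat alongside the chat_stream used by the demo:

# sketch only, not part of the commit: one-shot query against the merged checkpoint
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig

ckpt = './merged'
tokenizer = AutoTokenizer.from_pretrained(ckpt, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    ckpt, device_map='auto', trust_remote_code=True).eval()
model.generation_config = GenerationConfig.from_pretrained(ckpt, trust_remote_code=True)

# single turn, no prior history
response, history = model.chat(tokenizer, 'I have been feeling anxious about school lately.', history=None)
print(response)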

demo/requirements_qwen.txt (new file, 2 lines)
@@ -0,0 +1,2 @@
gradio<3.42
mdtex2html

demo/web_qwen.py (new file, 209 lines)
@@ -0,0 +1,209 @@
# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""A simple web interactive chat demo based on gradio."""
import os
from argparse import ArgumentParser

import gradio as gr
import mdtex2html

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig


DEFAULT_CKPT_PATH = './merged'


def _get_args():
    parser = ArgumentParser()
    parser.add_argument("-c", "--checkpoint-path", type=str, default=DEFAULT_CKPT_PATH,
                        help="Checkpoint name or path, default to %(default)r")
    parser.add_argument("--cpu-only", action="store_true", help="Run demo with CPU only")

    parser.add_argument("--share", action="store_true", default=False,
                        help="Create a publicly shareable link for the interface.")
    parser.add_argument("--inbrowser", action="store_true", default=False,
                        help="Automatically launch the interface in a new tab on the default browser.")
    parser.add_argument("--server-port", type=int, default=6006,
                        help="Demo server port.")
    parser.add_argument("--server-name", type=str, default="127.0.0.1",
                        help="Demo server name.")

    args = parser.parse_args()
    return args


def _load_model_tokenizer(args):
    tokenizer = AutoTokenizer.from_pretrained(
        args.checkpoint_path, trust_remote_code=True, resume_download=True,
    )

    if args.cpu_only:
        device_map = "cpu"
    else:
        device_map = "auto"

    model = AutoModelForCausalLM.from_pretrained(
        args.checkpoint_path,
        device_map=device_map,
        trust_remote_code=True,
        resume_download=True,
    ).eval()

    config = GenerationConfig.from_pretrained(
        args.checkpoint_path, trust_remote_code=True, resume_download=True,
    )

    return model, tokenizer, config


def postprocess(self, y):
    if y is None:
        return []
    for i, (message, response) in enumerate(y):
        y[i] = (
            None if message is None else mdtex2html.convert(message),
            None if response is None else mdtex2html.convert(response),
        )
    return y


gr.Chatbot.postprocess = postprocess


def _parse_text(text):
    lines = text.split("\n")
    lines = [line for line in lines if line != ""]
    count = 0
    for i, line in enumerate(lines):
        if "```" in line:
            count += 1
            items = line.split("`")
            if count % 2 == 1:
                lines[i] = f'<pre><code class="language-{items[-1]}">'
            else:
                lines[i] = f"<br></code></pre>"
        else:
            if i > 0:
                if count % 2 == 1:
                    line = line.replace("`", r"\`")
                    line = line.replace("<", "&lt;")
                    line = line.replace(">", "&gt;")
                    line = line.replace(" ", "&nbsp;")
                    line = line.replace("*", "&ast;")
                    line = line.replace("_", "&lowbar;")
                    line = line.replace("-", "&#45;")
                    line = line.replace(".", "&#46;")
                    line = line.replace("!", "&#33;")
                    line = line.replace("(", "&#40;")
                    line = line.replace(")", "&#41;")
                    line = line.replace("$", "&#36;")
                lines[i] = "<br>" + line
    text = "".join(lines)
    return text


def _gc():
    import gc
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def _launch_demo(args, model, tokenizer, config):

    def predict(_query, _chatbot, _task_history):
        print(f"User: {_parse_text(_query)}")
        _chatbot.append((_parse_text(_query), ""))
        full_response = ""

        for response in model.chat_stream(tokenizer, _query, history=_task_history, generation_config=config):
            _chatbot[-1] = (_parse_text(_query), _parse_text(response))

            yield _chatbot
            full_response = _parse_text(response)

        print(f"History: {_task_history}")
        _task_history.append((_query, full_response))
        print(f"Qwen-Chat: {_parse_text(full_response)}")

    def regenerate(_chatbot, _task_history):
        if not _task_history:
            yield _chatbot
            return
        item = _task_history.pop(-1)
        _chatbot.pop(-1)
        yield from predict(item[0], _chatbot, _task_history)

    def reset_user_input():
        return gr.update(value="")

    def reset_state(_chatbot, _task_history):
        _task_history.clear()
        _chatbot.clear()
        _gc()
        return _chatbot

    with gr.Blocks() as demo:
        gr.Markdown("""\
<p align="center"><img src="https://qianwen-res.oss-cn-beijing.aliyuncs.com/logo_qwen.jpg" style="height: 80px"/><p>""")
        gr.Markdown("""<center><font size=8>Qwen-Chat Bot</center>""")
        gr.Markdown(
            """\
<center><font size=3>This WebUI is based on Qwen-Chat, developed by Alibaba Cloud. \
(本WebUI基于Qwen-Chat打造,实现聊天机器人功能。)</center>""")
        gr.Markdown("""\
<center><font size=4>
Qwen-7B <a href="https://modelscope.cn/models/qwen/Qwen-7B/summary">🤖 </a> |
<a href="https://huggingface.co/Qwen/Qwen-7B">🤗</a>  |
Qwen-7B-Chat <a href="https://modelscope.cn/models/qwen/Qwen-7B-Chat/summary">🤖 </a> |
<a href="https://huggingface.co/Qwen/Qwen-7B-Chat">🤗</a>  |
Qwen-14B <a href="https://modelscope.cn/models/qwen/Qwen-14B/summary">🤖 </a> |
<a href="https://huggingface.co/Qwen/Qwen-14B">🤗</a>  |
Qwen-14B-Chat <a href="https://modelscope.cn/models/qwen/Qwen-14B-Chat/summary">🤖 </a> |
<a href="https://huggingface.co/Qwen/Qwen-14B-Chat">🤗</a>  |
 <a href="https://github.com/QwenLM/Qwen">Github</a></center>""")

        chatbot = gr.Chatbot(label='Qwen-Chat', elem_classes="control-height")
        query = gr.Textbox(lines=2, label='Input')
        task_history = gr.State([])

        with gr.Row():
            empty_btn = gr.Button("🧹 Clear History (清除历史)")
            submit_btn = gr.Button("🚀 Submit (发送)")
            regen_btn = gr.Button("🤔️ Regenerate (重试)")

        submit_btn.click(predict, [query, chatbot, task_history], [chatbot], show_progress=True)
        submit_btn.click(reset_user_input, [], [query])
        empty_btn.click(reset_state, [chatbot, task_history], outputs=[chatbot], show_progress=True)
        regen_btn.click(regenerate, [chatbot, task_history], [chatbot], show_progress=True)

        gr.Markdown("""\
<font size=2>Note: This demo is governed by the original license of Qwen. \
We strongly advise users not to knowingly generate or allow others to knowingly generate harmful content, \
including hate speech, violence, pornography, deception, etc. \
(注:本演示受Qwen的许可协议限制。我们强烈建议,用户不应传播及不应允许他人传播以下内容,\
包括但不限于仇恨言论、暴力、色情、欺诈相关的有害信息。)""")

    demo.queue().launch(
        share=args.share,
        inbrowser=args.inbrowser,
        server_port=args.server_port,
        server_name=args.server_name,
    )


def main():
    args = _get_args()

    model, tokenizer, config = _load_model_tokenizer(args)

    _launch_demo(args, model, tokenizer, config)


if __name__ == '__main__':
    main()
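
The web demo streams partial answers by making predict() a generator and registering it directly as the click handler; with demo.queue() enabled, Gradio re-renders the Chatbot after every yield. A standalone sketch of just that pattern, with a toy echo "model" standing in for model.chat_stream (not part of the commit):

# sketch only: generator-based streaming into a gr.Chatbot
import gradio as gr

def stream_reply(query, chat):
    chat.append((query, ""))
    partial = ""
    for token in query.split():      # stand-in for model.chat_stream(...)
        partial += token + " "
        chat[-1] = (query, partial)
        yield chat                   # each yield refreshes the Chatbot

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    box = gr.Textbox(lines=2)
    box.submit(stream_reply, [box, chatbot], [chatbot])

if __name__ == "__main__":
    demo.queue().launch()            # the queue is required for generator callbacks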

finetune/qwen_7b_chat_qlora_e3.py (new file, 183 lines)
@@ -0,0 +1,183 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from bitsandbytes.optim import PagedAdamW32bit
from datasets import load_dataset
from mmengine.dataset import DefaultSampler
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
                            LoggerHook, ParamSchedulerHook)
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR
from peft import LoraConfig
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          BitsAndBytesConfig)

from xtuner.dataset import process_hf_dataset
from xtuner.dataset.collate_fns import default_collate_fn
from xtuner.dataset.map_fns import oasst1_map_fn, template_map_fn_factory
from xtuner.engine import DatasetInfoHook, EvaluateChatHook
from xtuner.model import SupervisedFinetune
from xtuner.utils import PROMPT_TEMPLATE

#######################################################################
#                          PART 1  Settings                           #
#######################################################################
# Model
pretrained_model_name_or_path = 'Qwen/Qwen-7B-Chat'

# Data
data_path = './data/merge_fzt.json'
prompt_template = PROMPT_TEMPLATE.qwen_chat
max_length = 2048
pack_to_max_length = True

# Scheduler & Optimizer
batch_size = 8  # per_device
accumulative_counts = 2
dataloader_num_workers = 0
max_epochs = 3
optim_type = PagedAdamW32bit
lr = 2e-4
betas = (0.9, 0.999)
weight_decay = 0
max_norm = 1  # grad clip

# Evaluate the generation performance during the training
evaluation_freq = 500
SYSTEM = "现在你是一个心理专家,我有一些心理问题,请你用专业的知识帮我解决。"
evaluation_inputs = [
    '我最近总是感到很焦虑,尤其是在学业上。我有个特别崇拜的同学,他好像在各方面都比我优秀,我总觉得自己怎么努力也追不上他,这让我压力特别大。',
    '我知道应该理性看待,但就是忍不住会去比较。我甚至晚上会因为这个睡不着觉,总想着怎样才能像他那样出色。'
]
#######################################################################
#                      PART 2  Model & Tokenizer                      #
#######################################################################
tokenizer = dict(
    type=AutoTokenizer.from_pretrained,
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    trust_remote_code=True,
    padding_side='right',
    eos_token='<|im_end|>')

model = dict(
    type=SupervisedFinetune,
    llm=dict(
        type=AutoModelForCausalLM.from_pretrained,
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        quantization_config=dict(
            type=BitsAndBytesConfig,
            load_in_4bit=True,
            load_in_8bit=False,
            llm_int8_threshold=6.0,
            llm_int8_has_fp16_weight=False,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4')),
    lora=dict(
        type=LoraConfig,
        r=64,
        lora_alpha=16,
        lora_dropout=0.1,
        bias='none',
        task_type='CAUSAL_LM'))

#######################################################################
#                      PART 3  Dataset & Dataloader                   #
#######################################################################
train_dataset = dict(
    type=process_hf_dataset,
    dataset=dict(type=load_dataset, path='json', data_files=dict(train=data_path)),
    tokenizer=tokenizer,
    max_length=max_length,
    dataset_map_fn=None,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    remove_unused_columns=True,
    shuffle_before_pack=True,
    pack_to_max_length=pack_to_max_length)

train_dataloader = dict(
    batch_size=batch_size,
    num_workers=dataloader_num_workers,
    dataset=train_dataset,
    sampler=dict(type=DefaultSampler, shuffle=True),
    collate_fn=dict(type=default_collate_fn))

#######################################################################
#                    PART 4  Scheduler & Optimizer                    #
#######################################################################
# optimizer
optim_wrapper = dict(
    type=AmpOptimWrapper,
    optimizer=dict(
        type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
    clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
    accumulative_counts=accumulative_counts,
    loss_scale='dynamic',
    dtype='float16')

# learning policy
# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md  # noqa: E501
param_scheduler = dict(
    type=CosineAnnealingLR,
    eta_min=0.0,
    by_epoch=True,
    T_max=max_epochs,
    convert_to_iter_based=True)

# train, val, test setting
train_cfg = dict(by_epoch=True, max_epochs=max_epochs, val_interval=1)

#######################################################################
#                           PART 5  Runtime                           #
#######################################################################
# Log the dialogue periodically during the training process, optional
custom_hooks = [
    dict(type=DatasetInfoHook, tokenizer=tokenizer),
    dict(
        type=EvaluateChatHook,
        tokenizer=tokenizer,
        every_n_iters=evaluation_freq,
        stop_word='<|im_end|>',
        evaluation_inputs=evaluation_inputs,
        system=SYSTEM,
        prompt_template=prompt_template)
]

# configure default hooks
default_hooks = dict(
    # record the time of every iteration.
    timer=dict(type=IterTimerHook),
    # print log every 10 iterations.
    logger=dict(type=LoggerHook, interval=10),
    # enable the parameter scheduler.
    param_scheduler=dict(type=ParamSchedulerHook),
    # save checkpoint per epoch.
    checkpoint=dict(type=CheckpointHook, interval=1),
    # set sampler seed in distributed environment.
    sampler_seed=dict(type=DistSamplerSeedHook),
)

# configure environment
env_cfg = dict(
    # whether to enable cudnn benchmark
    cudnn_benchmark=False,
    # set multi process parameters
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    # set distributed parameters
    dist_cfg=dict(backend='nccl'),
)

# set visualizer
visualizer = None

# set log level
log_level = 'INFO'

# load from which checkpoint
load_from = None

# whether to resume training from the loaded checkpoint
resume = False

# Defaults to use random seed and disable `deterministic`
randomness = dict(seed=None, deterministic=False)
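
Nothing in this config is instantiated eagerly: every dict with a type key is a lazy spec that xtuner/mmengine build at runtime, and the file is typically passed to xtuner's train entry point rather than executed directly. A rough, illustrative sketch of that build step (not the real registry code):

# sketch only: how a dict(type=..., ...) spec is turned into an object
def build(cfg):
    cfg = dict(cfg)
    factory = cfg.pop('type')
    kwargs = {
        k: build(v) if isinstance(v, dict) and 'type' in v else v
        for k, v in cfg.items()
    }
    return factory(**kwargs)

# e.g. build(tokenizer) would resolve to roughly
# AutoTokenizer.from_pretrained(
#     pretrained_model_name_or_path='Qwen/Qwen-7B-Chat',
#     trust_remote_code=True, padding_side='right', eos_token='<|im_end|>')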