diff --git a/.github/workflows/reademe-contributors b/.github/workflows/reademe-contributors deleted file mode 100644 index af3fd9d..0000000 --- a/.github/workflows/reademe-contributors +++ /dev/null @@ -1,14 +0,0 @@ -on: - push: - branches: - - main -name: Generate a list of contributors -jobs: - contrib-readme-en-job: - runs-on: ubuntu-latest - name: A job to automate contrib in readme - steps: - - name: Contribute List - uses: akhilmhdh/contributors-readme-action@v2.3.4 - env: - GITHUB_TOKEN: ${{ secrets.CONTRIBUTORS_TOKEN }} diff --git a/.gitignore b/.gitignore index de34f61..df8d9c9 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,4 @@ ESConv.json .DS_Store __pycache__/ tmp/ -data/zhipuai/ \ No newline at end of file +zhipuai/ \ No newline at end of file diff --git a/README.md b/README.md index e4225ed..6302533 100644 --- a/README.md +++ b/README.md @@ -3,6 +3,6 @@ ## 🌟 Contributors -[![EmoLLM contributors](https://contrib.rocks/image?repo=aJupyter/EmoLLM&max=2000)](https://github.com/aJupyter/EmoLLM/graphs/contributors) +[![EmoLLM contributors](https://contrib.rocks/image?repo=aJupyter/EmoLLM&max=200)](https://github.com/aJupyter/EmoLLM/graphs/contributors) diff --git a/demo/cli_qwen.py b/demo/cli_qwen.py new file mode 100644 index 0000000..210b6fe --- /dev/null +++ b/demo/cli_qwen.py @@ -0,0 +1,210 @@ +# Copyright (c) Alibaba Cloud. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""A simple command-line interactive chat demo.""" + +import argparse +import os +import platform +import shutil +from copy import deepcopy + +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers.generation import GenerationConfig +from transformers.trainer_utils import set_seed + +DEFAULT_CKPT_PATH = './merged' + +_WELCOME_MSG = '''\ +Welcome to use Emo-Chat model, type text to start chat, type :h to show command help. +(欢迎使用 Emo-Chat 模型,输入内容即可进行对话,:h 显示命令帮助。) + +Note: This demo is governed by the original license of Qwen. +We strongly advise users not to knowingly generate or allow others to knowingly generate harmful content, including hate speech, violence, pornography, deception, etc. +(注:本演示受EmoLLM的许可协议限制。我们强烈建议,用户不应传播及不应允许他人传播以下内容,包括但不限于仇恨言论、暴力、色情、欺诈相关的有害信息。) +''' +_HELP_MSG = '''\ +Commands: + :help / :h Show this help message 显示帮助信息 + :exit / :quit / :q Exit the demo 退出Demo + :clear / :cl Clear screen 清屏 + :clear-his / :clh Clear history 清除对话历史 + :history / :his Show history 显示对话历史 + :seed Show current random seed 显示当前随机种子 + :seed Set random seed to 设置随机种子 + :conf Show current generation config 显示生成配置 + :conf = Change generation config 修改生成配置 + :reset-conf Reset generation config 重置生成配置 +''' + + +def _load_model_tokenizer(args): + tokenizer = AutoTokenizer.from_pretrained( + args.checkpoint_path, trust_remote_code=True, resume_download=True, + ) + + if args.cpu_only: + device_map = "cpu" + else: + device_map = "auto" + + model = AutoModelForCausalLM.from_pretrained( + args.checkpoint_path, + device_map=device_map, + trust_remote_code=True, + resume_download=True, + ).eval() + + config = GenerationConfig.from_pretrained( + args.checkpoint_path, trust_remote_code=True, resume_download=True, + ) + + return model, tokenizer, config + + +def _gc(): + import gc + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + +def _clear_screen(): + if platform.system() == "Windows": + os.system("cls") + else: + os.system("clear") + + +def _print_history(history): + terminal_width = shutil.get_terminal_size()[0] + print(f'History ({len(history)})'.center(terminal_width, '=')) + for index, (query, response) in enumerate(history): + print(f'User[{index}]: {query}') + print(f'QWen[{index}]: {response}') + print('=' * terminal_width) + + +def _get_input() -> str: + while True: + try: + message = input('User> ').strip() + except UnicodeDecodeError: + print('[ERROR] Encoding error in input') + continue + except KeyboardInterrupt: + exit(1) + if message: + return message + print('[ERROR] Query is empty') + + +def main(): + parser = argparse.ArgumentParser( + description='QWen-Chat command-line interactive chat demo.') + parser.add_argument("-c", "--checkpoint-path", type=str, default=DEFAULT_CKPT_PATH, + help="Checkpoint name or path, default to %(default)r") + parser.add_argument("-s", "--seed", type=int, default=1234, help="Random seed") + parser.add_argument("--cpu-only", action="store_true", help="Run demo with CPU only") + args = parser.parse_args() + + history, response = [], '' + + model, tokenizer, config = _load_model_tokenizer(args) + orig_gen_config = deepcopy(model.generation_config) + + _clear_screen() + print(_WELCOME_MSG) + + seed = args.seed + + while True: + query = _get_input() + + # Process commands. + if query.startswith(':'): + command_words = query[1:].strip().split() + if not command_words: + command = '' + else: + command = command_words[0] + + if command in ['exit', 'quit', 'q']: + break + elif command in ['clear', 'cl']: + _clear_screen() + print(_WELCOME_MSG) + _gc() + continue + elif command in ['clear-history', 'clh']: + print(f'[INFO] All {len(history)} history cleared') + history.clear() + _gc() + continue + elif command in ['help', 'h']: + print(_HELP_MSG) + continue + elif command in ['history', 'his']: + _print_history(history) + continue + elif command in ['seed']: + if len(command_words) == 1: + print(f'[INFO] Current random seed: {seed}') + continue + else: + new_seed_s = command_words[1] + try: + new_seed = int(new_seed_s) + except ValueError: + print(f'[WARNING] Fail to change random seed: {new_seed_s!r} is not a valid number') + else: + print(f'[INFO] Random seed changed to {new_seed}') + seed = new_seed + continue + elif command in ['conf']: + if len(command_words) == 1: + print(model.generation_config) + else: + for key_value_pairs_str in command_words[1:]: + eq_idx = key_value_pairs_str.find('=') + if eq_idx == -1: + print('[WARNING] format: =') + continue + conf_key, conf_value_str = key_value_pairs_str[:eq_idx], key_value_pairs_str[eq_idx + 1:] + try: + conf_value = eval(conf_value_str) + except Exception as e: + print(e) + continue + else: + print(f'[INFO] Change config: model.generation_config.{conf_key} = {conf_value}') + setattr(model.generation_config, conf_key, conf_value) + continue + elif command in ['reset-conf']: + print('[INFO] Reset generation config') + model.generation_config = deepcopy(orig_gen_config) + print(model.generation_config) + continue + else: + # As normal query. + pass + + # Run chat. + set_seed(seed) + try: + for response in model.chat_stream(tokenizer, query, history=history, generation_config=config): + _clear_screen() + print(f"\nUser: {query}") + print(f"\nQwen-Chat: {response}") + except KeyboardInterrupt: + print('[WARNING] Generation interrupted') + continue + + history.append((query, response)) + + +if __name__ == "__main__": + main() diff --git a/demo/requirements_qwen.txt b/demo/requirements_qwen.txt new file mode 100644 index 0000000..6a8a163 --- /dev/null +++ b/demo/requirements_qwen.txt @@ -0,0 +1,2 @@ +gradio<3.42 +mdtex2html \ No newline at end of file diff --git a/demo/web_qwen.py b/demo/web_qwen.py new file mode 100644 index 0000000..9769f81 --- /dev/null +++ b/demo/web_qwen.py @@ -0,0 +1,209 @@ +# Copyright (c) Alibaba Cloud. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""A simple web interactive chat demo based on gradio.""" +import os +from argparse import ArgumentParser + +import gradio as gr +import mdtex2html + +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers.generation import GenerationConfig + + +DEFAULT_CKPT_PATH = './merged' + + +def _get_args(): + parser = ArgumentParser() + parser.add_argument("-c", "--checkpoint-path", type=str, default=DEFAULT_CKPT_PATH, + help="Checkpoint name or path, default to %(default)r") + parser.add_argument("--cpu-only", action="store_true", help="Run demo with CPU only") + + parser.add_argument("--share", action="store_true", default=False, + help="Create a publicly shareable link for the interface.") + parser.add_argument("--inbrowser", action="store_true", default=False, + help="Automatically launch the interface in a new tab on the default browser.") + parser.add_argument("--server-port", type=int, default=6006, + help="Demo server port.") + parser.add_argument("--server-name", type=str, default="127.0.0.1", + help="Demo server name.") + + args = parser.parse_args() + return args + + +def _load_model_tokenizer(args): + tokenizer = AutoTokenizer.from_pretrained( + args.checkpoint_path, trust_remote_code=True, resume_download=True, + ) + + if args.cpu_only: + device_map = "cpu" + else: + device_map = "auto" + + model = AutoModelForCausalLM.from_pretrained( + args.checkpoint_path, + device_map=device_map, + trust_remote_code=True, + resume_download=True, + ).eval() + + config = GenerationConfig.from_pretrained( + args.checkpoint_path, trust_remote_code=True, resume_download=True, + ) + + return model, tokenizer, config + + +def postprocess(self, y): + if y is None: + return [] + for i, (message, response) in enumerate(y): + y[i] = ( + None if message is None else mdtex2html.convert(message), + None if response is None else mdtex2html.convert(response), + ) + return y + + +gr.Chatbot.postprocess = postprocess + + +def _parse_text(text): + lines = text.split("\n") + lines = [line for line in lines if line != ""] + count = 0 + for i, line in enumerate(lines): + if "```" in line: + count += 1 + items = line.split("`") + if count % 2 == 1: + lines[i] = f'
'
+            else:
+                lines[i] = f"
" + else: + if i > 0: + if count % 2 == 1: + line = line.replace("`", r"\`") + line = line.replace("<", "<") + line = line.replace(">", ">") + line = line.replace(" ", " ") + line = line.replace("*", "*") + line = line.replace("_", "_") + line = line.replace("-", "-") + line = line.replace(".", ".") + line = line.replace("!", "!") + line = line.replace("(", "(") + line = line.replace(")", ")") + line = line.replace("$", "$") + lines[i] = "
" + line + text = "".join(lines) + return text + + +def _gc(): + import gc + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + +def _launch_demo(args, model, tokenizer, config): + + def predict(_query, _chatbot, _task_history): + print(f"User: {_parse_text(_query)}") + _chatbot.append((_parse_text(_query), "")) + full_response = "" + + for response in model.chat_stream(tokenizer, _query, history=_task_history, generation_config=config): + _chatbot[-1] = (_parse_text(_query), _parse_text(response)) + + yield _chatbot + full_response = _parse_text(response) + + print(f"History: {_task_history}") + _task_history.append((_query, full_response)) + print(f"Qwen-Chat: {_parse_text(full_response)}") + + def regenerate(_chatbot, _task_history): + if not _task_history: + yield _chatbot + return + item = _task_history.pop(-1) + _chatbot.pop(-1) + yield from predict(item[0], _chatbot, _task_history) + + def reset_user_input(): + return gr.update(value="") + + def reset_state(_chatbot, _task_history): + _task_history.clear() + _chatbot.clear() + _gc() + return _chatbot + + with gr.Blocks() as demo: + gr.Markdown("""\ +

""") + gr.Markdown("""

Qwen-Chat Bot
""") + gr.Markdown( + """\ +
This WebUI is based on Qwen-Chat, developed by Alibaba Cloud. \ +(本WebUI基于Qwen-Chat打造,实现聊天机器人功能。)
""") + gr.Markdown("""\ +
+Qwen-7B 🤖 | +🤗  | +Qwen-7B-Chat 🤖 | +🤗  | +Qwen-14B 🤖 | +🤗  | +Qwen-14B-Chat 🤖 | +🤗  | + Github
""") + + chatbot = gr.Chatbot(label='Qwen-Chat', elem_classes="control-height") + query = gr.Textbox(lines=2, label='Input') + task_history = gr.State([]) + + with gr.Row(): + empty_btn = gr.Button("🧹 Clear History (清除历史)") + submit_btn = gr.Button("🚀 Submit (发送)") + regen_btn = gr.Button("🤔️ Regenerate (重试)") + + submit_btn.click(predict, [query, chatbot, task_history], [chatbot], show_progress=True) + submit_btn.click(reset_user_input, [], [query]) + empty_btn.click(reset_state, [chatbot, task_history], outputs=[chatbot], show_progress=True) + regen_btn.click(regenerate, [chatbot, task_history], [chatbot], show_progress=True) + + gr.Markdown("""\ +Note: This demo is governed by the original license of Qwen. \ +We strongly advise users not to knowingly generate or allow others to knowingly generate harmful content, \ +including hate speech, violence, pornography, deception, etc. \ +(注:本演示受Qwen的许可协议限制。我们强烈建议,用户不应传播及不应允许他人传播以下内容,\ +包括但不限于仇恨言论、暴力、色情、欺诈相关的有害信息。)""") + + demo.queue().launch( + share=args.share, + inbrowser=args.inbrowser, + server_port=args.server_port, + server_name=args.server_name, + ) + + +def main(): + args = _get_args() + + model, tokenizer, config = _load_model_tokenizer(args) + + _launch_demo(args, model, tokenizer, config) + + +if __name__ == '__main__': + main() diff --git a/finetune/qwen_7b_chat_qlora_e3.py b/finetune/qwen_7b_chat_qlora_e3.py new file mode 100644 index 0000000..25914e4 --- /dev/null +++ b/finetune/qwen_7b_chat_qlora_e3.py @@ -0,0 +1,183 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from bitsandbytes.optim import PagedAdamW32bit +from datasets import load_dataset +from mmengine.dataset import DefaultSampler +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR +from peft import LoraConfig +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig) + +from xtuner.dataset import process_hf_dataset +from xtuner.dataset.collate_fns import default_collate_fn +from xtuner.dataset.map_fns import oasst1_map_fn, template_map_fn_factory +from xtuner.engine import DatasetInfoHook, EvaluateChatHook +from xtuner.model import SupervisedFinetune +from xtuner.utils import PROMPT_TEMPLATE + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model +pretrained_model_name_or_path = 'Qwen/Qwen-7B-Chat' + +# Data +data_path = './data/merge_fzt.json' +prompt_template = PROMPT_TEMPLATE.qwen_chat +max_length = 2048 +pack_to_max_length = True + +# Scheduler & Optimizer +batch_size = 8 # per_device +accumulative_counts = 2 +dataloader_num_workers = 0 +max_epochs = 3 +optim_type = PagedAdamW32bit +lr = 2e-4 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip + +# Evaluate the generation performance during the training +evaluation_freq = 500 +SYSTEM = "现在你是一个心理专家,我有一些心理问题,请你用专业的知识帮我解决。" +evaluation_inputs = [ + '我最近总是感到很焦虑,尤其是在学业上。我有个特别崇拜的同学,他好像在各方面都比我优秀,我总觉得自己怎么努力也追不上他,这让我压力特别大。', '我知道应该理性看待,但就是忍不住会去比较。我甚至晚上会因为这个睡不着觉,总想着怎样才能像他那样出色。' +] +####################################################################### +# PART 2 Model & Tokenizer # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=pretrained_model_name_or_path, + trust_remote_code=True, + padding_side='right', + eos_token='<|im_end|>') + +model = dict( + type=SupervisedFinetune, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=pretrained_model_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + lora=dict( + type=LoraConfig, + r=64, + lora_alpha=16, + lora_dropout=0.1, + bias='none', + task_type='CAUSAL_LM')) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +train_dataset = dict( + type=process_hf_dataset, + dataset=dict(type=load_dataset, path='json', data_files=dict(train=data_path)), + tokenizer=tokenizer, + max_length=max_length, + dataset_map_fn=None, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + remove_unused_columns=True, + shuffle_before_pack=True, + pack_to_max_length=pack_to_max_length) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=train_dataset, + sampler=dict(type=DefaultSampler, shuffle=True), + collate_fn=dict(type=default_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + T_max=max_epochs, + convert_to_iter_based=True) + +# train, val, test setting +train_cfg = dict(by_epoch=True, max_epochs=max_epochs, val_interval=1) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook, tokenizer=tokenizer), + dict( + type=EvaluateChatHook, + tokenizer=tokenizer, + every_n_iters=evaluation_freq, + stop_word='<|im_end|>', + evaluation_inputs=evaluation_inputs, + system=SYSTEM, + prompt_template=prompt_template) +] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 100 iterations. + logger=dict(type=LoggerHook, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per epoch. + checkpoint=dict(type=CheckpointHook, interval=1), + # set sampler seed in distributed evrionment. + sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) diff --git a/data/.env b/scripts/.env similarity index 100% rename from data/.env rename to scripts/.env diff --git a/data/merge_json.py b/scripts/merge_json.py similarity index 100% rename from data/merge_json.py rename to scripts/merge_json.py diff --git a/data/process.py b/scripts/process.py similarity index 100% rename from data/process.py rename to scripts/process.py diff --git a/data/qwen_gen_data.py b/scripts/qwen_gen_data.py similarity index 100% rename from data/qwen_gen_data.py rename to scripts/qwen_gen_data.py diff --git a/data/run_qwen.bash b/scripts/run_qwen.bash similarity index 100% rename from data/run_qwen.bash rename to scripts/run_qwen.bash diff --git a/data/trans_process.py b/scripts/trans_process.py similarity index 100% rename from data/trans_process.py rename to scripts/trans_process.py diff --git a/data/zhipuai_gen_data.py b/scripts/zhipuai_gen_data.py similarity index 100% rename from data/zhipuai_gen_data.py rename to scripts/zhipuai_gen_data.py