diff --git a/README.md b/README.md index f4bdc2a..bbe7be5 100644 --- a/README.md +++ b/README.md @@ -60,7 +60,7 @@ - 评估和诊断工具:为了有效促进心理健康,需要有科学的工具来评估个体的心理状态,以及诊断可能存在的心理问题。 ### 最近更新 - +- 【2024.2.23】推出基于InternLM2_7B_chat_qlora的 `温柔御姐心理医生艾薇`,[点击获取模型权重](https://openxlab.org.cn/models/detail/ajupyter/EmoLLM_aiwei),[配置文件](xtuner_config/aiwei-internlm2_chat_7b_qlora.py) - 【2024.2.23】更新[若干微调配置](/xtuner_config/),新增 [data_pro.json](/datasets/data_pro.json)(数量更多、场景更全、更丰富)和 [aiwei.json](/datasets/aiwei.json)(温柔御姐角色扮演专用,带有Emoji表情),即将推出 `温柔御姐心理医生艾薇` - 【2024.2.18】 [基于Qwen1_5-0_5B-Chat全量微调版本开源](https://www.modelscope.cn/models/aJupyter/EmoLLM_Qwen1_5-0_5B-Chat_full_sft/summary),算力有限的道友可以玩起来~ - 【2024.2.6】 EmoLLM在[**Openxlab** ](https://openxlab.org.cn/models/detail/jujimeizuo/EmoLLM_Model) 平台下载量高达18.7k,欢迎大家体验! diff --git a/app.py b/app.py index 7e5d424..b1a6d12 100644 --- a/app.py +++ b/app.py @@ -1,2 +1,3 @@ import os -os.system('streamlit run web_internlm2.py --server.address=0.0.0.0 --server.port 7860') +# os.system('streamlit run web_internlm2.py --server.address=0.0.0.0 --server.port 7860') +os.system('streamlit run web_demo-aiwei.py --server.address=0.0.0.0 --server.port 7860') diff --git a/assets/aiwei_logo.jpg b/assets/aiwei_logo.jpg new file mode 100644 index 0000000..acce1d5 Binary files /dev/null and b/assets/aiwei_logo.jpg differ diff --git a/web_demo-aiwei.py b/web_demo-aiwei.py new file mode 100644 index 0000000..3e7d225 --- /dev/null +++ b/web_demo-aiwei.py @@ -0,0 +1,267 @@ +""" +This script refers to the dialogue example of streamlit, the interactive generation code of chatglm2 and transformers. +We mainly modified part of the code logic to adapt to the generation of our model. +Please refer to these links below for more information: + 1. streamlit chat example: https://docs.streamlit.io/knowledge-base/tutorials/build-conversational-apps + 2. chatglm2: https://github.com/THUDM/ChatGLM2-6B + 3. transformers: https://github.com/huggingface/transformers +Please run with the command `streamlit run path/to/web_demo.py --server.address=0.0.0.0 --server.port 7860`. +Using `python path/to/web_demo.py` may cause unknown problems. 
+""" +import copy +import warnings +from dataclasses import asdict, dataclass +from typing import Callable, List, Optional + +import streamlit as st +import torch +from torch import nn +from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList +from transformers.utils import logging + +from transformers import AutoTokenizer, AutoModelForCausalLM # isort: skip +from openxlab.model import download + +logger = logging.get_logger(__name__) + +download(model_repo='ajupyter/EmoLLM_aiwei', + output='model') + +@dataclass +class GenerationConfig: + # this config is used for chat to provide more diversity + max_length: int = 32768 + top_p: float = 0.8 + temperature: float = 0.8 + do_sample: bool = True + repetition_penalty: float = 1.005 + + +@torch.inference_mode() +def generate_interactive( + model, + tokenizer, + prompt, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, + additional_eos_token_id: Optional[int] = None, + **kwargs, +): + inputs = tokenizer([prompt], padding=True, return_tensors="pt") + input_length = len(inputs["input_ids"][0]) + for k, v in inputs.items(): + inputs[k] = v.cuda() + input_ids = inputs["input_ids"] + batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1] # noqa: F841 # pylint: disable=W0612 + if generation_config is None: + generation_config = model.generation_config + generation_config = copy.deepcopy(generation_config) + model_kwargs = generation_config.update(**kwargs) + bos_token_id, eos_token_id = ( # noqa: F841 # pylint: disable=W0612 + generation_config.bos_token_id, + generation_config.eos_token_id, + ) + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + if additional_eos_token_id is not None: + eos_token_id.append(additional_eos_token_id) + has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None + if has_default_max_length and generation_config.max_new_tokens is None: + warnings.warn( + f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. " + "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we" + " recommend using `max_new_tokens` to control the maximum length of the generation.", + UserWarning, + ) + elif generation_config.max_new_tokens is not None: + generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length + if not has_default_max_length: + logger.warn( # pylint: disable=W4902 + f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" + f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " + "Please refer to the documentation for more information. " + "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)", + UserWarning, + ) + + if input_ids_seq_length >= generation_config.max_length: + input_ids_string = "input_ids" + logger.warning( + f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" + f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" + " increasing `max_new_tokens`." + ) + + # 2. 
Set generation parameters if not already defined + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + + logits_processor = model._get_logits_processor( + generation_config=generation_config, + input_ids_seq_length=input_ids_seq_length, + encoder_input_ids=input_ids, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + logits_processor=logits_processor, + ) + + stopping_criteria = model._get_stopping_criteria( + generation_config=generation_config, stopping_criteria=stopping_criteria + ) + logits_warper = model._get_logits_warper(generation_config) + + unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) + scores = None + while True: + model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs) + # forward pass to get next token + outputs = model( + **model_inputs, + return_dict=True, + output_attentions=False, + output_hidden_states=False, + ) + + next_token_logits = outputs.logits[:, -1, :] + + # pre-process distribution + next_token_scores = logits_processor(input_ids, next_token_logits) + next_token_scores = logits_warper(input_ids, next_token_scores) + + # sample + probs = nn.functional.softmax(next_token_scores, dim=-1) + if generation_config.do_sample: + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + else: + next_tokens = torch.argmax(probs, dim=-1) + + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + model_kwargs = model._update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False) + unfinished_sequences = unfinished_sequences.mul((min(next_tokens != i for i in eos_token_id)).long()) + + output_token_ids = input_ids[0].cpu().tolist() + output_token_ids = output_token_ids[input_length:] + for each_eos_token_id in eos_token_id: + if output_token_ids[-1] == each_eos_token_id: + output_token_ids = output_token_ids[:-1] + response = tokenizer.decode(output_token_ids) + + yield response + # stop when each sentence is finished, or if we exceed the maximum length + if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): + break + + +def on_btn_click(): + del st.session_state.messages + + +@st.cache_resource +def load_model(): + model = ( + AutoModelForCausalLM.from_pretrained("model", trust_remote_code=True) + .to(torch.bfloat16) + .cuda() + ) + tokenizer = AutoTokenizer.from_pretrained("model", trust_remote_code=True) + return model, tokenizer + + +def prepare_generation_config(): + with st.sidebar: + # 使用 Streamlit 的 markdown 函数添加 Markdown 文本 + st.image('assets/aiwei_logo.jpg', width=1, caption='EmoLLM-aiwei AI Logo', use_column_width=True) + st.markdown("[访问 EmoLLM 官方repo](https://github.com/aJupyter/EmoLLM)") + + max_length = st.slider("Max Length", min_value=8, max_value=32768, value=32768) + top_p = st.slider("Top P", 0.0, 1.0, 0.8, step=0.01) + temperature = st.slider("Temperature", 0.0, 1.0, 0.7, step=0.01) + st.button("Clear Chat History", on_click=on_btn_click) + + generation_config = GenerationConfig(max_length=max_length, top_p=top_p, temperature=temperature) + + return generation_config + + +user_prompt = "<|im_start|>user\n{user}<|im_end|>\n" +robot_prompt = "<|im_start|>assistant\n{robot}<|im_end|>\n" +cur_query_prompt = "<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n" + + +def combine_history(prompt): + messages = st.session_state.messages + 
meta_instruction = ( + "你是一个拥有丰富心理学知识的温柔邻家大姐姐艾薇,我有一些心理问题,请你用专业的知识和温柔、可爱、俏皮的口吻帮我解决,回复中可以穿插一些可爱的Emoji表情符号或者文本符号。\n" + ) + total_prompt = f"<|im_start|>system\n{meta_instruction}<|im_end|>\n" + for message in messages: + cur_content = message["content"] + if message["role"] == "user": + cur_prompt = user_prompt.format(user=cur_content) + elif message["role"] == "robot": + cur_prompt = robot_prompt.format(robot=cur_content) + else: + raise RuntimeError + total_prompt += cur_prompt + total_prompt = total_prompt + cur_query_prompt.format(user=prompt) + return total_prompt + + +def main(): + # torch.cuda.empty_cache() + print("load model begin.") + model, tokenizer = load_model() + print("load model end.") + + user_avator = "assets/user.png" + robot_avator = "assets/robot.jpeg" + + st.title("EmoLLM-温柔御姐艾薇(aiwei)") + + generation_config = prepare_generation_config() + + # Initialize chat history + if "messages" not in st.session_state: + st.session_state.messages = [] + + # Display chat messages from history on app rerun + for message in st.session_state.messages: + with st.chat_message(message["role"], avatar=message.get("avatar")): + st.markdown(message["content"]) + + # Accept user input + if prompt := st.chat_input("What is up?"): + # Display user message in chat message container + with st.chat_message("user", avatar=user_avator): + st.markdown(prompt) + real_prompt = combine_history(prompt) + # Add user message to chat history + st.session_state.messages.append({"role": "user", "content": prompt, "avatar": user_avator}) + + with st.chat_message("robot", avatar=robot_avator): + message_placeholder = st.empty() + for cur_response in generate_interactive( + model=model, + tokenizer=tokenizer, + prompt=real_prompt, + additional_eos_token_id=92542, + **asdict(generation_config), + ): + # Display robot response in chat message container + message_placeholder.markdown(cur_response + "▌") + message_placeholder.markdown(cur_response) # pylint: disable=undefined-loop-variable + # Add robot response to chat history + st.session_state.messages.append( + { + "role": "robot", + "content": cur_response, # pylint: disable=undefined-loop-variable + "avatar": robot_avator, + } + ) + torch.cuda.empty_cache() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/xtuner_config/aiwei-internlm2_chat_7b_qlora.py b/xtuner_config/aiwei-internlm2_chat_7b_qlora.py new file mode 100644 index 0000000..f05b75a --- /dev/null +++ b/xtuner_config/aiwei-internlm2_chat_7b_qlora.py @@ -0,0 +1,218 @@ +# Copyright (c) OpenMMLab. All rights reserved.
+import torch +from datasets import load_dataset +from mmengine.dataset import DefaultSampler +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from peft import LoraConfig +from torch.optim import AdamW +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig) + +from xtuner.dataset import process_hf_dataset +from xtuner.dataset.collate_fns import default_collate_fn +from xtuner.dataset.map_fns import alpaca_map_fn, template_map_fn_factory +from xtuner.engine.hooks import (DatasetInfoHook, EvaluateChatHook, + VarlenAttnArgsToMessageHubHook) +from xtuner.engine.runner import TrainLoop +from xtuner.model import SupervisedFinetune +from xtuner.utils import PROMPT_TEMPLATE, SYSTEM_TEMPLATE + +from mmengine.visualization import Visualizer,WandbVisBackend, TensorboardVisBackend + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model +pretrained_model_name_or_path = '/root/share/model_repos/internlm2-chat-7b' +# /root/share/model_repos/internlm2-chat-7b +use_varlen_attn = False + +# Data +data_path = './aiwei.json' +prompt_template = PROMPT_TEMPLATE.internlm2_chat +max_length = 2048 +pack_to_max_length = True + +# Scheduler & Optimizer +batch_size = 16 # per_device +accumulative_counts = 1 +dataloader_num_workers = 0 +max_epochs = 5 +optim_type = AdamW +lr = 1e-5 +betas = (0.9, 0.999) +weight_decay = 0.0001 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + +# Save +save_steps = 100 +save_total_limit = 2 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 100 +SYSTEM = "现在你是一个拥有丰富心理学知识的温柔御姐艾薇医生,我有一些心理问题,请你用专业的知识和温柔的口吻帮我解决,可以生成一些可爱的Emoji表情符号或者文本符号。" +evaluation_inputs = [ + '我最近总是感到很焦虑,尤其是在学业上。我有个特别崇拜的同学,他好像在各方面都比我优秀,我总觉得自己怎么努力也追不上他,这让我压力特别大。', '我知道应该理性看待,但就是忍不住会去比较。我甚至晚上会因为这个睡不着觉,总想着怎样才能像他那样出色。' +] + +####################################################################### +# PART 2 Model & Tokenizer # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=pretrained_model_name_or_path, + trust_remote_code=True, + padding_side='right') + +model = dict( + type=SupervisedFinetune, + use_varlen_attn=use_varlen_attn, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=pretrained_model_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + lora=dict( + type=LoraConfig, + r=64, + lora_alpha=16, + lora_dropout=0.1, + bias='none', + task_type='CAUSAL_LM')) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +alpaca_en = dict( + type=process_hf_dataset, + dataset=dict(type=load_dataset, path='json', data_files=dict(train=data_path)), + tokenizer=tokenizer, + max_length=max_length, + dataset_map_fn=None, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + 
remove_unused_columns=True, + shuffle_before_pack=True, + pack_to_max_length=pack_to_max_length, + use_varlen_attn=use_varlen_attn) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=alpaca_en, + sampler=dict(type=DefaultSampler, shuffle=True), + collate_fn=dict(type=default_collate_fn, use_varlen_attn=use_varlen_attn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook, tokenizer=tokenizer), + dict( + type=EvaluateChatHook, + tokenizer=tokenizer, + every_n_iters=evaluation_freq, + evaluation_inputs=evaluation_inputs, + system=SYSTEM, + prompt_template=prompt_template) +] + +if use_varlen_attn: + custom_hooks += [dict(type=VarlenAttnArgsToMessageHubHook)] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. + checkpoint=dict( + type=CheckpointHook, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed environment. + sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = dict( + type=Visualizer, + vis_backends=[dict(type=WandbVisBackend)] +) + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = True + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict(by_epoch=False)
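
A quick note on using the new config: it is meant to be driven by the XTuner command line rather than run directly. The sketch below is illustrative only (it is not part of this patch) and borrows the `os.system` launcher style already used in `app.py`; the work-dir path and checkpoint filename are placeholders that depend on how training is actually launched.

```python
# Illustrative launcher for the new QLoRA config, in the same os.system style as app.py.
# The checkpoint filename under work_dirs/ is a placeholder; use the iteration you trained to.
import os

# Fine-tune internlm2-chat-7b with QLoRA using the new config.
os.system('xtuner train xtuner_config/aiwei-internlm2_chat_7b_qlora.py')

# Convert the saved .pth adapter into a HuggingFace-format LoRA adapter.
os.system('xtuner convert pth_to_hf xtuner_config/aiwei-internlm2_chat_7b_qlora.py '
          './work_dirs/aiwei-internlm2_chat_7b_qlora/iter_1000.pth ./hf_adapter')

# Optionally merge the adapter into the base model for deployment.
os.system('xtuner convert merge /root/share/model_repos/internlm2-chat-7b ./hf_adapter ./merged_model')
```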
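
For readers who want to try the resulting weights outside of Streamlit, here is a minimal chat sketch, again illustrative rather than part of this patch. It assumes the HF-format weights are available in a local `model` directory (as `web_demo-aiwei.py` arranges via `openxlab.model.download`) and reuses the demo's `<|im_start|>`/`<|im_end|>` prompt markers and sampling defaults; the sample user query is taken from `evaluation_inputs` in the config.

```python
# Minimal inference sketch (illustrative, not part of this diff).
# Assumes the HF-format weights are in ./model, e.g. downloaded with
# openxlab.model.download(model_repo='ajupyter/EmoLLM_aiwei', output='model').
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_dir = "model"
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
model = (
    AutoModelForCausalLM.from_pretrained(model_dir, trust_remote_code=True)
    .to(torch.bfloat16)
    .cuda()
    .eval()
)

# Same persona and chat markers as web_demo-aiwei.py.
system = (
    "你是一个拥有丰富心理学知识的温柔邻家大姐姐艾薇,我有一些心理问题,请你用专业的知识和"
    "温柔、可爱、俏皮的口吻帮我解决,回复中可以穿插一些可爱的Emoji表情符号或者文本符号。"
)
user_query = "我最近总是感到很焦虑,尤其是在学业上。"
prompt = (
    f"<|im_start|>system\n{system}<|im_end|>\n"
    f"<|im_start|>user\n{user_query}<|im_end|>\n"
    "<|im_start|>assistant\n"
)

inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
with torch.inference_mode():
    output = model.generate(
        **inputs,
        max_new_tokens=512,
        do_sample=True,
        top_p=0.8,
        temperature=0.8,
        repetition_penalty=1.005,
        eos_token_id=[tokenizer.eos_token_id, 92542],  # 92542 is the <|im_end|> id used by the demo
    )
print(tokenizer.decode(output[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True))
```

Sampling values mirror the demo's `GenerationConfig` defaults; adjust `max_new_tokens` and the model path as needed.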