diff --git a/README.md b/README.md
index ac71210..c6625b5 100644
--- a/README.md
+++ b/README.md
@@ -58,7 +58,7 @@
 | DeepSeek MoE_16B_chat | QLORA | [deepseek_moe_16b_chat_qlora_oasst1_e3.py](./xtuner_config/deepseek_moe_16b_chat_qlora_oasst1_e3.py) | |
 | Mixtral 8x7B_instruct | QLORA | [mixtral_8x7b_instruct_qlora_oasst1_e3.py](./xtuner_config/mixtral_8x7b_instruct_qlora_oasst1_e3.py) | |
 | LLaMA3_8b_instruct | QLORA | [aiwei_llama3_8b_instruct_qlora_e3.py](./xtuner_config/aiwei_llama3_8b_instruct_qlora_e3.py) | |
-| LLaMA3_8b_instruct | QLORA | [llama3_8b_instruct_qlora_alpaca_e3_M.py](./xtuner_config/llama3_8b_instruct_qlora_alpaca_e3_M.py) |[OpenXLab](https://openxlab.org.cn/models/detail/chg0901/EmoLLM-Llama3-8B-Instruct2.0) |
+| LLaMA3_8b_instruct | QLORA | [llama3_8b_instruct_qlora_alpaca_e3_M.py](./xtuner_config/llama3_8b_instruct_qlora_alpaca_e3_M.py) |[OpenXLab](https://openxlab.org.cn/models/detail/chg0901/EmoLLM-Llama3-8B-Instruct2.0), [ModelScope](https://modelscope.cn/models/chg0901/EmoLLM-Llama3-8B-Instruct2.0/summary) |
 | …… | …… | …… | …… |
diff --git a/README_EN.md b/README_EN.md
index 0a92ddf..2ccccf3 100644
--- a/README_EN.md
+++ b/README_EN.md
@@ -60,7 +60,7 @@
 | DeepSeek MoE_16B_chat | QLORA | [deepseek_moe_16b_chat_qlora_oasst1_e3.py](./xtuner_config/deepseek_moe_16b_chat_qlora_oasst1_e3.py) | |
 | Mixtral 8x7B_instruct | QLORA | [mixtral_8x7b_instruct_qlora_oasst1_e3.py](./xtuner_config/mixtral_8x7b_instruct_qlora_oasst1_e3.py) | |
 | LLaMA3_8b_instruct | QLORA | [aiwei_llama3_8b_instruct_qlora_e3.py](./xtuner_config/aiwei_llama3_8b_instruct_qlora_e3.py) | |
-| LLaMA3_8b_instruct | QLORA | [llama3_8b_instruct_qlora_alpaca_e3_M.py](./xtuner_config/llama3_8b_instruct_qlora_alpaca_e3_M.py) |[OpenXLab](https://openxlab.org.cn/models/detail/chg0901/EmoLLM-Llama3-8B-Instruct2.0) |
+| LLaMA3_8b_instruct | QLORA | [llama3_8b_instruct_qlora_alpaca_e3_M.py](./xtuner_config/llama3_8b_instruct_qlora_alpaca_e3_M.py) |[OpenXLab](https://openxlab.org.cn/models/detail/chg0901/EmoLLM-Llama3-8B-Instruct2.0), [ModelScope](https://modelscope.cn/models/chg0901/EmoLLM-Llama3-8B-Instruct2.0/summary) |
 | …… | …… | …… | …… |
diff --git a/app.py b/app.py
index d0b5b4b..dded4e2 100644
--- a/app.py
+++ b/app.py
@@ -11,8 +11,6 @@ elif model == "EmoLLM_Model":
     os.system("python download_model.py jujimeizuo/EmoLLM_Model")
     os.system('streamlit run web_internlm2.py --server.address=0.0.0.0 --server.port 7860')
 elif model == "Llama3_Model":
-    os.system("python download_model.py chg0901/EmoLLM-Llama3-8B-Instruct2.0")
-    # os.system('streamlit run web_demo-Llama3_online.py --server.address=0.0.0.0 --server.port 7860')
     os.system('streamlit run web_demo-Llama3.py --server.address=0.0.0.0 --server.port 7968')
 else:
     print("Please select one model")
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 95ac855..62de5a5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -10,4 +10,3 @@ tiktoken
 einops
 oss2
 requests
-websockets==11.0.3
diff --git a/web_demo-Llama3.py b/web_demo-Llama3.py
index 8e6501d..00a668f 100644
--- a/web_demo-Llama3.py
+++ b/web_demo-Llama3.py
@@ -1,5 +1,3 @@
-
-# isort: skip_file
 import copy
 import os
 import warnings
@@ -9,30 +7,24 @@ from typing import Callable, List, Optional
 import streamlit as st
 import torch
 from torch import nn
-from transformers.generation.utils import (LogitsProcessorList,
-                                           StoppingCriteriaList)
+from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
 from transformers.utils import logging
 
-from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig  # isort: skip
-from peft import PeftModel
+from transformers import AutoTokenizer, AutoModelForCausalLM  # isort: skip
 
 warnings.filterwarnings("ignore")
-warnings.filterwarnings("ignore", category=DeprecationWarning)
 logger = logging.get_logger(__name__)
 
-if not os.path.isdir("model"):
-    print("[ERROR] not find model dir")
-    exit(0)
 
 @dataclass
 class GenerationConfig:
     # this config is used for chat to provide more diversity
-    max_length: int = 500
-    top_p: float = 0.9
-    temperature: float = 0.6
+    max_length: int = 32768
+    top_p: float = 0.8
+    temperature: float = 0.8
     do_sample: bool = True
-    repetition_penalty: float = 1.1
+    repetition_penalty: float = 1.005
 
 
 @torch.inference_mode()
@@ -43,17 +35,16 @@ def generate_interactive(
     generation_config: Optional[GenerationConfig] = None,
     logits_processor: Optional[LogitsProcessorList] = None,
     stopping_criteria: Optional[StoppingCriteriaList] = None,
-    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor],
-                                                List[int]]] = None,
+    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
    additional_eos_token_id: Optional[int] = None,
     **kwargs,
 ):
-    inputs = tokenizer([prompt], return_tensors='pt')
-    input_length = len(inputs['input_ids'][0])
+    inputs = tokenizer([prompt], padding=True, return_tensors="pt")
+    input_length = len(inputs["input_ids"][0])
     for k, v in inputs.items():
         inputs[k] = v.cuda()
-    input_ids = inputs['input_ids']
-    _, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
+    input_ids = inputs["input_ids"]
+    batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]  # noqa: F841  # pylint: disable=W0612
     if generation_config is None:
         generation_config = model.generation_config
     generation_config = copy.deepcopy(generation_config)
@@ -66,45 +57,36 @@
         eos_token_id = [eos_token_id]
     if additional_eos_token_id is not None:
         eos_token_id.append(additional_eos_token_id)
-    has_default_max_length = kwargs.get(
-        'max_length') is None and generation_config.max_length is not None
+    has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
     if has_default_max_length and generation_config.max_new_tokens is None:
         warnings.warn(
-            f"Using 'max_length''s default ({repr(generation_config.max_length)}) \
-                to control the generation length. "
-            'This behaviour is deprecated and will be removed from the \
-                config in v5 of Transformers -- we'
-            ' recommend using `max_new_tokens` to control the maximum \
-                length of the generation.',
+            f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
+            "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
+            " recommend using `max_new_tokens` to control the maximum length of the generation.",
             UserWarning,
         )
     elif generation_config.max_new_tokens is not None:
-        generation_config.max_length = generation_config.max_new_tokens + \
-            input_ids_seq_length
+        generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
         if not has_default_max_length:
             logger.warn(  # pylint: disable=W4902
-                f"Both 'max_new_tokens' (={generation_config.max_new_tokens}) "
-                f"and 'max_length'(={generation_config.max_length}) seem to "
-                "have been set. 'max_new_tokens' will take precedence. "
-                'Please refer to the documentation for more information. '
-                '(https://huggingface.co/docs/transformers/main/'
-                'en/main_classes/text_generation)',
+                f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
+                f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
+                "Please refer to the documentation for more information. "
+                "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
                 UserWarning,
             )
 
     if input_ids_seq_length >= generation_config.max_length:
-        input_ids_string = 'input_ids'
+        input_ids_string = "input_ids"
         logger.warning(
-            f"Input length of {input_ids_string} is {input_ids_seq_length}, "
-            f"but 'max_length' is set to {generation_config.max_length}. "
-            'This can lead to unexpected behavior. You should consider'
-            " increasing 'max_new_tokens'.")
+            f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
+            f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
+            " increasing `max_new_tokens`."
+        )
 
     # 2. Set generation parameters if not already defined
-    logits_processor = logits_processor if logits_processor is not None \
-        else LogitsProcessorList()
-    stopping_criteria = stopping_criteria if stopping_criteria is not None \
-        else StoppingCriteriaList()
+    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
 
     logits_processor = model._get_logits_processor(
         generation_config=generation_config,
@@ -115,15 +97,14 @@
     )
 
     stopping_criteria = model._get_stopping_criteria(
-        generation_config=generation_config,
-        stopping_criteria=stopping_criteria)
+        generation_config=generation_config, stopping_criteria=stopping_criteria
+    )
     logits_warper = model._get_logits_warper(generation_config)
 
     unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
     scores = None
     while True:
-        model_inputs = model.prepare_inputs_for_generation(
-            input_ids, **model_kwargs)
+        model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
         # forward pass to get next token
         outputs = model(
             **model_inputs,
@@ -147,10 +128,8 @@
 
         # update generated ids, model inputs, and length for next step
         input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
-        model_kwargs = model._update_model_kwargs_for_generation(
-            outputs, model_kwargs, is_encoder_decoder=False)
-        unfinished_sequences = unfinished_sequences.mul(
-            (min(next_tokens != i for i in eos_token_id)).long())
+        model_kwargs = model._update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False)
+        unfinished_sequences = unfinished_sequences.mul((min(next_tokens != i for i in eos_token_id)).long())
 
         output_token_ids = input_ids[0].cpu().tolist()
         output_token_ids = output_token_ids[input_length:]
@@ -160,10 +139,8 @@
             response = tokenizer.decode(output_token_ids)
 
         yield response
-        # stop when each sentence is finished
-        # or if we exceed the maximum length
-        if unfinished_sequences.max() == 0 or stopping_criteria(
-                input_ids, scores):
+        # stop when each sentence is finished, or if we exceed the maximum length
+        if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
             break
 
 
@@ -171,90 +148,42 @@ def on_btn_click():
     del st.session_state.messages
 
 
-# @st.cache_resource
-# def load_model(arg1):
-#     # model = AutoModelForCausalLM.from_pretrained(args.m).cuda()
-#     # tokenizer = AutoTokenizer.from_pretrained(args.m, trust_remote_code=True)
-#     model = AutoModelForCausalLM.from_pretrained(arg1, torch_dtype=torch.float16).cuda()
-#     tokenizer = AutoTokenizer.from_pretrained(arg1, trust_remote_code=True)
-
-
-#     return model, tokenizer
-
 @st.cache_resource
 def load_model():
-    model = AutoModelForCausalLM.from_pretrained("model",
-                                                 device_map="auto",
-                                                 trust_remote_code=True,
-                                                 torch_dtype=torch.float16)
-    model = model.eval()
-    tokenizer = AutoTokenizer.from_pretrained("model", trust_remote_code=True)
-    return model, tokenizer
+
+    model_name = "./EmoLLM-Llama3-8B-Instruct2.0"
+    print(model_name)
+    print('pip install modelscope websockets')
+    os.system(f'pip install modelscope websockets==11.0.3')
+    from modelscope import snapshot_download
 
-
-@st.cache_resource
-def load_model0(model_name_or_path, load_in_4bit=False, adapter_name_or_path=None):
-    if load_in_4bit:
-        quantization_config = BitsAndBytesConfig(
-            load_in_4bit=True,
-            bnb_4bit_compute_dtype=torch.float16,
-            bnb_4bit_use_double_quant=True,
-            bnb_4bit_quant_type="nf4",
-            llm_int8_threshold=6.0,
-            llm_int8_has_fp16_weight=False,
-        )
-    else:
-        quantization_config = None
-
-    # 加载base model
-    model = AutoModelForCausalLM.from_pretrained(
-        model_name_or_path,
-        # load_in_4bit=load_in_4bit,
-        # # ValueError: You can't pass `load_in_4bit`or `load_in_8bit` as a kwarg when passing `quantization_config` argument at the same time.
-        trust_remote_code=True,
-        low_cpu_mem_usage=True,
-        torch_dtype=torch.float16,
-        device_map='auto',
-        quantization_config=quantization_config
-    )
-
-    # 加载adapter
-    if adapter_name_or_path is not None:
-        model = PeftModel.from_pretrained(model, adapter_name_or_path)
-
-    ## 加载tokenzier
-    tokenizer = AutoTokenizer.from_pretrained(
-        model_name_or_path if adapter_name_or_path is None else adapter_name_or_path,
-        trust_remote_code=True,
-        use_fast=False
-    )
+    #模型下载
+    model_name = snapshot_download('chg0901/EmoLLM-Llama3-8B-Instruct2.0',cache_dir=model_name)
+    print(model_name)
+    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", trust_remote_code=True, torch_dtype=torch.float16).eval()
+    # model.eval()
+    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
+
     return model, tokenizer
 
 
 def prepare_generation_config():
     with st.sidebar:
-        # 使用 Streamlit 的 markdown 函数添加 Markdown 文本
         st.image('assets/EmoLLM_logo_L.png', width=1, caption='EmoLLM Logo', use_column_width=True)
         st.markdown("[访问 **EmoLLM** 官方repo: **SmartFlowAI/EmoLLM**](https://github.com/SmartFlowAI/EmoLLM)")
-
-        max_length = st.slider('Max Length',
-                               min_value=8,
-                               max_value=8192,
-                               value=500)
-        top_p = st.slider('Top P', 0.0, 1.0, 0.9, step=0.01)
-        temperature = st.slider('Temperature', 0.0, 1.0, 0.6, step=0.01)
-        repetition_penalty = st.slider('Repetition penalty', 0.0, 1.5, 1.1, step=0.01)
-        st.button('Clear Chat History', on_click=on_btn_click)
-
-    generation_config = GenerationConfig(max_length=max_length,
-                                         top_p=top_p,
-                                         temperature=temperature,
-                                         repetition_penalty=repetition_penalty,
-                                         do_sample=True)
+        max_length = st.slider("Max Length", min_value=8, max_value=32768, value=32768)
+        top_p = st.slider("Top P", 0.0, 1.0, 0.8, step=0.01)
+        temperature = st.slider("Temperature", 0.0, 1.0, 0.7, step=0.01)
+        st.button("Clear Chat History", on_click=on_btn_click)
+
+    generation_config = GenerationConfig(max_length=max_length, top_p=top_p, temperature=temperature)
 
     return generation_config
 
@@ -266,17 +195,15 @@ cur_query_prompt = '<|start_header_id|>user<|end_header_id|>\n\n{user}<|eot_id|>
 def combine_history(prompt):
     messages = st.session_state.messages
-
     meta_instruction = (
-        "你是心理健康助手EmoLLM, 由EmoLLM团队打造, 是一个研究过无数具有心理健康问题的病人与心理健康医生对话的心理专家, 在心理方面拥有广博的知识储备和丰富的研究咨询经验。你旨在通过专业心理咨询, 协助来访者完成心理诊断。请充分利用专业心理学知识与咨询技术, 一步步帮助来访者解决心理问题。。"
+        "你是心理健康助手EmoLLM, 由EmoLLM团队打造, 是一个研究过无数具有心理健康问题的病人与心理健康医生对话的心理专家, 在心理方面拥有广博的知识储备和丰富的研究咨询经验。你旨在通过专业心理咨询, 协助来访者完成心理诊断。请充分利用专业心理学知识与咨询技术, 一步步帮助来访者解决心理问题。\n\n"
     )
-    total_prompt = f"<|start_header_id|>system<|end_header_id|>\n{meta_instruction}<|eot_id|>\n"
-
+    total_prompt =f"<|start_header_id|>system<|end_header_id|>\n\n{meta_instruction}<|eot_id|>\n\n"
     for message in messages:
-        cur_content = message['content']
-        if message['role'] == 'user':
+        cur_content = message["content"]
+        if message["role"] == "user":
             cur_prompt = user_prompt.format(user=cur_content)
-        elif message['role'] == 'robot':
+        elif message["role"] == "robot":
             cur_prompt = robot_prompt.format(robot=cur_content)
         else:
             raise RuntimeError
@@ -286,79 +213,58 @@ def combine_history(prompt):
 
 
 def main():
-
-    st.markdown("我在这里,准备好倾听你的心声了。", unsafe_allow_html=True)
-    model_name_or_path = 'model'
-    adapter_name_or_path = None
-    # torch.cuda.empty_cache()
-    print('load model begin.')
-    # 加载模型
-    print(f'Loading model from: {model_name_or_path}')
-    print(f'adapter_name_or_path: {adapter_name_or_path}')
-    # model, tokenizer = load_model(arg1)
-
+    torch.cuda.empty_cache()
+    print("load model begin.")
     model, tokenizer = load_model()
-
-    # model, tokenizer = load_model(
-    #     # arg1 if arg1 is not None else model_name_or_path,
-    #     model_name_or_path,
-    #     load_in_4bit=load_in_4bit,
-    #     adapter_name_or_path=adapter_name_or_path
-    # )
-
     model.eval()
-    print('load model end.')
-
+    print("load model end.")
+
     user_avator = "assets/user.png"
     robot_avator = "assets/EmoLLM.png"
 
-    st.title('EmoLLM Llama3心理咨询室V2.0')
+    st.title("EmoLLM Llama3心理咨询室V2.0")
 
     generation_config = prepare_generation_config()
 
     # Initialize chat history
-    if 'messages' not in st.session_state:
+    if "messages" not in st.session_state:
         st.session_state.messages = []
 
     # Display chat messages from history on app rerun
     for message in st.session_state.messages:
-        with st.chat_message(message['role'], avatar=message.get("avatar")):
-            st.markdown(message['content'])
+        with st.chat_message(message["role"], avatar=message.get("avatar")):
+            st.markdown(message["content"])
 
     # Accept user input
-    if prompt := st.chat_input('你好,欢迎来到Llama3 EmoLLM 心理咨询室'):
+    if prompt := st.chat_input("我在这里,准备好倾听你的心声了。"):
         # Display user message in chat message container
-        with st.chat_message('user', avatar=user_avator):
+        with st.chat_message("user", avatar=user_avator):
             st.markdown(prompt)
+
         real_prompt = combine_history(prompt)
         # Add user message to chat history
-        st.session_state.messages.append({
-            'role': 'user',
-            'content': prompt,
-            'avatar': user_avator
-        })
 
-
-        with st.chat_message('robot', avatar=robot_avator):
+        st.session_state.messages.append({"role": "user", "content": prompt, "avatar": user_avator})
+
+        with st.chat_message("robot", avatar=robot_avator):
             message_placeholder = st.empty()
             for cur_response in generate_interactive(
-                    model=model,
-                    tokenizer=tokenizer,
-                    prompt=real_prompt,
-                    additional_eos_token_id=128009,  # <|eot_id|>
-                    eos_token_id=128009,
-                    pad_token_id=128009,
-                    **asdict(generation_config),
+                model=model,
+                tokenizer=tokenizer,
+                prompt=real_prompt,
+                additional_eos_token_id=128009,
+                **asdict(generation_config),
             ):
                 # Display robot response in chat message container
-                message_placeholder.markdown(cur_response + '▌')
-            message_placeholder.markdown(cur_response)
+                message_placeholder.markdown(cur_response + "▌")
+            message_placeholder.markdown(cur_response)  # pylint: disable=undefined-loop-variable
             # Add robot response to chat history
-            st.session_state.messages.append({
-                'role': 'robot',
-                'content': cur_response,  # pylint: disable=undefined-loop-variable
-                "avatar": robot_avator,
-            })
+            st.session_state.messages.append(
+                {
+                    "role": "robot",
+                    "content": cur_response,  # pylint: disable=undefined-loop-variable
+                    "avatar": robot_avator,
+                }
+            )
         torch.cuda.empty_cache()
diff --git a/web_demo-Llama3_online.py b/web_demo-Llama3_online.py
deleted file mode 100644
index c2c41c2..0000000
--- a/web_demo-Llama3_online.py
+++ /dev/null
@@ -1,378 +0,0 @@
-
-# isort: skip_file
-import copy
-import os
-import warnings
-from dataclasses import asdict, dataclass
-from typing import Callable, List, Optional
-
-import streamlit as st
-import torch
-from torch import nn
-from transformers.generation.utils import (LogitsProcessorList,
-                                           StoppingCriteriaList)
-from transformers.utils import logging
-
-from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig  # isort: skip
-from peft import PeftModel
-
-
-warnings.filterwarnings("ignore")
-warnings.filterwarnings("ignore", category=DeprecationWarning)
-logger = logging.get_logger(__name__)
-
-if not os.path.isdir("model"):
-    print("[ERROR] not find model dir")
-    exit(0)
-
-online = True
-
-## running on local to test online function
-# if online:
-#     from openxlab.model import download
-#     download(model_repo='chg0901/EmoLLM-Llama3-8B-Instruct2.0',
-#              output='model')
-
-@dataclass
-class GenerationConfig:
-    # this config is used for chat to provide more diversity
-    max_length: int = 500
-    top_p: float = 0.9
-    temperature: float = 0.6
-    do_sample: bool = True
-    repetition_penalty: float = 1.1
-
-
-@torch.inference_mode()
-def generate_interactive(
-    model,
-    tokenizer,
-    prompt,
-    generation_config: Optional[GenerationConfig] = None,
-    logits_processor: Optional[LogitsProcessorList] = None,
-    stopping_criteria: Optional[StoppingCriteriaList] = None,
-    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor],
-                                                List[int]]] = None,
-    additional_eos_token_id: Optional[int] = None,
-    **kwargs,
-):
-    inputs = tokenizer([prompt], return_tensors='pt')
-    input_length = len(inputs['input_ids'][0])
-    for k, v in inputs.items():
-        inputs[k] = v.cuda()
-    input_ids = inputs['input_ids']
-    _, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
-    if generation_config is None:
-        generation_config = model.generation_config
-    generation_config = copy.deepcopy(generation_config)
-    model_kwargs = generation_config.update(**kwargs)
-    bos_token_id, eos_token_id = (  # noqa: F841  # pylint: disable=W0612
-        generation_config.bos_token_id,
-        generation_config.eos_token_id,
-    )
-    if isinstance(eos_token_id, int):
-        eos_token_id = [eos_token_id]
-    if additional_eos_token_id is not None:
-        eos_token_id.append(additional_eos_token_id)
-    has_default_max_length = kwargs.get(
-        'max_length') is None and generation_config.max_length is not None
-    if has_default_max_length and generation_config.max_new_tokens is None:
-        warnings.warn(
-            f"Using 'max_length''s default ({repr(generation_config.max_length)}) \
-                to control the generation length. "
-            'This behaviour is deprecated and will be removed from the \
-                config in v5 of Transformers -- we'
-            ' recommend using `max_new_tokens` to control the maximum \
-                length of the generation.',
-            UserWarning,
-        )
-    elif generation_config.max_new_tokens is not None:
-        generation_config.max_length = generation_config.max_new_tokens + \
-            input_ids_seq_length
-        if not has_default_max_length:
-            logger.warn(  # pylint: disable=W4902
-                f"Both 'max_new_tokens' (={generation_config.max_new_tokens}) "
-                f"and 'max_length'(={generation_config.max_length}) seem to "
-                "have been set. 'max_new_tokens' will take precedence. "
-                'Please refer to the documentation for more information. '
-                '(https://huggingface.co/docs/transformers/main/'
-                'en/main_classes/text_generation)',
-                UserWarning,
-            )
-
-    if input_ids_seq_length >= generation_config.max_length:
-        input_ids_string = 'input_ids'
-        logger.warning(
-            f"Input length of {input_ids_string} is {input_ids_seq_length}, "
-            f"but 'max_length' is set to {generation_config.max_length}. "
-            'This can lead to unexpected behavior. You should consider'
-            " increasing 'max_new_tokens'.")
-
-    # 2. Set generation parameters if not already defined
-    logits_processor = logits_processor if logits_processor is not None \
-        else LogitsProcessorList()
-    stopping_criteria = stopping_criteria if stopping_criteria is not None \
-        else StoppingCriteriaList()
-
-    logits_processor = model._get_logits_processor(
-        generation_config=generation_config,
-        input_ids_seq_length=input_ids_seq_length,
-        encoder_input_ids=input_ids,
-        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
-        logits_processor=logits_processor,
-    )
-
-    stopping_criteria = model._get_stopping_criteria(
-        generation_config=generation_config,
-        stopping_criteria=stopping_criteria)
-    logits_warper = model._get_logits_warper(generation_config)
-
-    unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
-    scores = None
-    while True:
-        model_inputs = model.prepare_inputs_for_generation(
-            input_ids, **model_kwargs)
-        # forward pass to get next token
-        outputs = model(
-            **model_inputs,
-            return_dict=True,
-            output_attentions=False,
-            output_hidden_states=False,
-        )
-
-        next_token_logits = outputs.logits[:, -1, :]
-
-        # pre-process distribution
-        next_token_scores = logits_processor(input_ids, next_token_logits)
-        next_token_scores = logits_warper(input_ids, next_token_scores)
-
-        # sample
-        probs = nn.functional.softmax(next_token_scores, dim=-1)
-        if generation_config.do_sample:
-            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
-        else:
-            next_tokens = torch.argmax(probs, dim=-1)
-
-        # update generated ids, model inputs, and length for next step
-        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
-        model_kwargs = model._update_model_kwargs_for_generation(
-            outputs, model_kwargs, is_encoder_decoder=False)
-        unfinished_sequences = unfinished_sequences.mul(
-            (min(next_tokens != i for i in eos_token_id)).long())
-
-        output_token_ids = input_ids[0].cpu().tolist()
-        output_token_ids = output_token_ids[input_length:]
-        for each_eos_token_id in eos_token_id:
-            if output_token_ids[-1] == each_eos_token_id:
-                output_token_ids = output_token_ids[:-1]
-        response = tokenizer.decode(output_token_ids)
-
-        yield response
-        # stop when each sentence is finished
-        # or if we exceed the maximum length
-        if unfinished_sequences.max() == 0 or stopping_criteria(
-                input_ids, scores):
-            break
-
-
-def on_btn_click():
-    del st.session_state.messages
-
-
-# @st.cache_resource
-# def load_model(arg1):
-#     # model = AutoModelForCausalLM.from_pretrained(args.m).cuda()
-#     # tokenizer = AutoTokenizer.from_pretrained(args.m, trust_remote_code=True)
-#     model = AutoModelForCausalLM.from_pretrained(arg1, torch_dtype=torch.float16).cuda()
-#     tokenizer = AutoTokenizer.from_pretrained(arg1, trust_remote_code=True)
-
-
-#     return model, tokenizer
-@st.cache_resource
-def load_model(model_name_or_path, load_in_4bit=False, adapter_name_or_path=None):
-    if load_in_4bit:
-        quantization_config = BitsAndBytesConfig(
-            load_in_4bit=True,
-            bnb_4bit_compute_dtype=torch.float16,
-            bnb_4bit_use_double_quant=True,
-            bnb_4bit_quant_type="nf4",
-            llm_int8_threshold=6.0,
-            llm_int8_has_fp16_weight=False,
-        )
-    else:
-        quantization_config = None
-
-    # 加载base model
-    model = AutoModelForCausalLM.from_pretrained(
-        model_name_or_path,
-        # load_in_4bit=load_in_4bit,
-        # # ValueError: You can't pass `load_in_4bit`or `load_in_8bit` as a kwarg when passing `quantization_config` argument at the same time.
-        trust_remote_code=True,
-        low_cpu_mem_usage=True,
-        torch_dtype=torch.float16,
-        device_map='auto',
-        quantization_config=quantization_config
-    )
-
-    # 加载adapter
-    if adapter_name_or_path is not None:
-        model = PeftModel.from_pretrained(model, adapter_name_or_path)
-
-    ## 加载tokenzier
-    tokenizer = AutoTokenizer.from_pretrained(
-        model_name_or_path if adapter_name_or_path is None else adapter_name_or_path,
-        trust_remote_code=True,
-        use_fast=False
-    )
-
-    if tokenizer.pad_token is None:
-        tokenizer.pad_token = tokenizer.eos_token
-    return model, tokenizer
-
-
-def prepare_generation_config():
-    with st.sidebar:
-
-        # 使用 Streamlit 的 markdown 函数添加 Markdown 文本
-        st.image('assets/EmoLLM_logo_L.png', width=1, caption='EmoLLM Logo', use_column_width=True)
-        st.markdown("[访问 **EmoLLM** 官方repo: **SmartFlowAI/EmoLLM**](https://github.com/SmartFlowAI/EmoLLM)")
-
-        max_length = st.slider('Max Length',
-                               min_value=8,
-                               max_value=8192,
-                               value=500)
-        top_p = st.slider('Top P', 0.0, 1.0, 0.9, step=0.01)
-        temperature = st.slider('Temperature', 0.0, 1.0, 0.6, step=0.01)
-        repetition_penalty = st.slider('Repetition penalty', 0.0, 1.5, 1.1, step=0.01)
-        st.button('Clear Chat History', on_click=on_btn_click)
-
-    generation_config = GenerationConfig(max_length=max_length,
-                                         top_p=top_p,
-                                         temperature=temperature,
-                                         repetition_penalty=repetition_penalty,
-                                         do_sample=True)
-
-    return generation_config
-
-
-user_prompt = '<|start_header_id|>user<|end_header_id|>\n\n{user}<|eot_id|>'
-robot_prompt = '<|start_header_id|>assistant<|end_header_id|>\n\n{robot}<|eot_id|>'
-cur_query_prompt = '<|start_header_id|>user<|end_header_id|>\n\n{user}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n'
-
-
-def combine_history(prompt):
-    messages = st.session_state.messages
-
-    meta_instruction = (
-        "你是心理健康助手EmoLLM, 由EmoLLM团队打造, 是一个研究过无数具有心理健康问题的病人与心理健康医生对话的心理专家, 在心理方面拥有广博的知识储备和丰富的研究咨询经验。你旨在通过专业心理咨询, 协助来访者完成心理诊断。请充分利用专业心理学知识与咨询技术, 一步步帮助来访者解决心理问题。。"
-    )
-    total_prompt = f"<|start_header_id|>system<|end_header_id|>\n{meta_instruction}<|eot_id|>\n"
-
-    for message in messages:
-        cur_content = message['content']
-        if message['role'] == 'user':
-            cur_prompt = user_prompt.format(user=cur_content)
-        elif message['role'] == 'robot':
-            cur_prompt = robot_prompt.format(robot=cur_content)
-        else:
-            raise RuntimeError
-        total_prompt += cur_prompt
-    total_prompt = total_prompt + cur_query_prompt.format(user=prompt)
-    return total_prompt
-
-
-def main(arg1=None):
-
-
-    if online:
-        model_name_or_path = 'model'
-        adapter_name_or_path = None
-    else:
-        # model_name_or_path = "./xtuner_config/merged_Llama3_8b_instruct_e3"
-        # adapter_name_or_path = './xtuner_config/hf_llama3_e1_sc2'
-
-        model_name_or_path = "./xtuner_config/merged_Llama3_8b_instruct_e1_sc"
-        adapter_name_or_path = None
-
-    # 若开启4bit推理能够节省很多显存,但效果可能下降
-    load_in_4bit = False  # True  # 6291MiB
-
-    # torch.cuda.empty_cache()
-    print('load model begin.')
-    # 加载模型
-    print(f'Loading model from: {model_name_or_path}')
-    print(f'adapter_name_or_path: {adapter_name_or_path}')
-    # model, tokenizer = load_model(arg1)
-    model, tokenizer = load_model(
-        arg1 if arg1 is not None else model_name_or_path,
-        load_in_4bit=load_in_4bit,
-        adapter_name_or_path=adapter_name_or_path
-    )
-    model.eval()
-    print('load model end.')
-
-    user_avator = "assets/user.png"
-    robot_avator = "assets/EmoLLM.png"
-
-    st.title('EmoLLM Llama3心理咨询室V2.0')
-
-    generation_config = prepare_generation_config()
-
-    # Initialize chat history
-    if 'messages' not in st.session_state:
-        st.session_state.messages = []
-
-    # Display chat messages from history on app rerun
-    for message in st.session_state.messages:
-        with st.chat_message(message['role'], avatar=message.get("avatar")):
-            st.markdown(message['content'])
-
-    # Accept user input
-    if prompt := st.chat_input('你好,欢迎来到Llama3 EmoLLM 心理咨询室'):
-        # Display user message in chat message container
-        with st.chat_message('user', avatar=user_avator):
-            st.markdown(prompt)
-        real_prompt = combine_history(prompt)
-        # Add user message to chat history
-        st.session_state.messages.append({
-            'role': 'user',
-            'content': prompt,
-            'avatar': user_avator
-        })
-
-
-        # stop_token_id = tokenizer.encode('<|eot_id|>', add_special_tokens=True)
-        # assert len(stop_token_id) == 1
-        # stop_token_id = stop_token_id[0]
-
-        with st.chat_message('robot', avatar=robot_avator):
-            message_placeholder = st.empty()
-            for cur_response in generate_interactive(
-                    model=model,
-                    tokenizer=tokenizer,
-                    prompt=real_prompt,
-                    additional_eos_token_id=128009,  # <|eot_id|>
-                    eos_token_id=128009,
-                    pad_token_id=128009,
-                    **asdict(generation_config),
-            ):
-                # Display robot response in chat message container
-                message_placeholder.markdown(cur_response + '▌')
-            message_placeholder.markdown(cur_response)
-            # Add robot response to chat history
-            st.session_state.messages.append({
-                'role': 'robot',
-                'content': cur_response,  # pylint: disable=undefined-loop-variable
-                "avatar": robot_avator,
-            })
-        torch.cuda.empty_cache()
-
-
-if __name__ == '__main__':
-
-    if online:
-        main()
-    else:
-        import sys
-        arg1 = sys.argv[1]
-        main(arg1)
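
Editor's note on the new loading path: with this patch, web_demo-Llama3.py no longer expects a pre-downloaded model/ directory; load_model() installs modelscope and websockets==11.0.3 at runtime and pulls the merged weights from ModelScope. The snippet below is a minimal standalone sketch of that path, assuming torch, transformers, and modelscope are already installed; the helper name and cache directory default are illustrative (the cache path mirrors the one used in the patch), not part of the change itself.

import torch
from modelscope import snapshot_download
from transformers import AutoModelForCausalLM, AutoTokenizer

def load_emollm_llama3(cache_dir: str = "./EmoLLM-Llama3-8B-Instruct2.0"):
    # Download (or reuse a cached copy of) the ModelScope snapshot and get its local path.
    model_dir = snapshot_download("chg0901/EmoLLM-Llama3-8B-Instruct2.0", cache_dir=cache_dir)
    # Load the merged fp16 weights across available GPUs, as the patched load_model() does.
    model = AutoModelForCausalLM.from_pretrained(
        model_dir,
        device_map="auto",
        trust_remote_code=True,
        torch_dtype=torch.float16,
    ).eval()
    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Llama 3 ships without a pad token; reuse EOS, matching the patch.
        tokenizer.pad_token = tokenizer.eos_token
    return model, tokenizer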