From 5c4d4de9d7a4d4e7d2c402fc534085e78d0d7148 Mon Sep 17 00:00:00 2001 From: HongCheng Date: Mon, 22 Apr 2024 00:05:11 +0900 Subject: [PATCH 01/10] update llama3 webdemo --- app.py | 6 +++++- web_demo-Llama3_online.py | 33 +++++++++++++++++++++------------ 2 files changed, 26 insertions(+), 13 deletions(-) diff --git a/app.py b/app.py index bebcaf3..dd75608 100644 --- a/app.py +++ b/app.py @@ -1,7 +1,8 @@ import os #model = "EmoLLM_aiwei" -model = "EmoLLM_Model" +# model = "EmoLLM_Model" +model = "Llama3_Model" if model == "EmoLLM_aiwei": os.system("python download_model.py ajupyter/EmoLLM_aiwei") @@ -9,5 +10,8 @@ if model == "EmoLLM_aiwei": elif model == "EmoLLM_Model": os.system("python download_model.py jujimeizuo/EmoLLM_Model") os.system('streamlit run web_internlm2.py --server.address=0.0.0.0 --server.port 7860') +elif model == "Llama3_Model": + os.system("python download_model.py chg0901/EmoLLM-Llama3-8B-Instruct2.0") + os.system('streamlit run web_demo-Llama3_online.py --server.address=0.0.0.0 --server.port 7860') else: print("Please select one model") \ No newline at end of file diff --git a/web_demo-Llama3_online.py b/web_demo-Llama3_online.py index e80f454..c2c41c2 100644 --- a/web_demo-Llama3_online.py +++ b/web_demo-Llama3_online.py @@ -1,6 +1,7 @@ # isort: skip_file import copy +import os import warnings from dataclasses import asdict, dataclass from typing import Callable, List, Optional @@ -15,17 +16,22 @@ from transformers.utils import logging from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig # isort: skip from peft import PeftModel -import warnings warnings.filterwarnings("ignore") warnings.filterwarnings("ignore", category=DeprecationWarning) logger = logging.get_logger(__name__) +if not os.path.isdir("model"): + print("[ERROR] not find model dir") + exit(0) + online = True -if online: - from openxlab.model import download - download(model_repo='chg0901/EmoLLM-Llama3-8B-Instruct2.0', - output='model') + +## running on local to test online function +# if online: +# from openxlab.model import download +# download(model_repo='chg0901/EmoLLM-Llama3-8B-Instruct2.0', +# output='model') @dataclass class GenerationConfig: @@ -275,17 +281,17 @@ def combine_history(prompt): return total_prompt -def main(arg1): +def main(arg1=None): if online: model_name_or_path = 'model' adapter_name_or_path = None else: - # model_name_or_path = "/root/StableCascade/emollm2/EmoLLM/xtuner_config/merged_Llama3_8b_instruct_e3" - # adapter_name_or_path = '/root/StableCascade/emollm2/EmoLLM/xtuner_config/hf_llama3_e1_sc2' + # model_name_or_path = "./xtuner_config/merged_Llama3_8b_instruct_e3" + # adapter_name_or_path = './xtuner_config/hf_llama3_e1_sc2' - model_name_or_path = "/root/StableCascade/emollm2/EmoLLM/xtuner_config/merged_Llama3_8b_instruct_e1_sc" + model_name_or_path = "./xtuner_config/merged_Llama3_8b_instruct_e1_sc" adapter_name_or_path = None # 若开启4bit推理能够节省很多显存,但效果可能下降 @@ -364,6 +370,9 @@ def main(arg1): if __name__ == '__main__': - import sys - arg1 = sys.argv[1] - main(arg1) + if online: + main() + else: + import sys + arg1 = sys.argv[1] + main(arg1) From 430cd7353ceabdb9789898001ff347b2f27a5726 Mon Sep 17 00:00:00 2001 From: HongCheng Date: Mon, 22 Apr 2024 00:29:10 +0900 Subject: [PATCH 02/10] update model table --- README.md | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index dfe5ce2..356c290 100644 --- a/README.md +++ b/README.md @@ -44,21 +44,22 @@
-| 模型 | 类型 | 链接 | -| :-------------------: | :------: | :------------------------------------------------------------------------------------------------------: | -| InternLM2_7B_chat | QLORA | [internlm2_7b_chat_qlora_e3.py](./xtuner_config/internlm2_7b_chat_qlora_e3.py) | -| InternLM2_7B_chat | 全量微调 | [internlm2_chat_7b_full.py](./xtuner_config/internlm2_chat_7b_full.py) | -| InternLM2_7B_base | QLORA | [internlm2_7b_base_qlora_e10_M_1e4_32_64.py](./xtuner_config/internlm2_7b_base_qlora_e10_M_1e4_32_64.py) | -| InternLM2_1_8B_chat | 全量微调 | [internlm2_1_8b_full_alpaca_e3.py](./xtuner_config/internlm2_1_8b_full_alpaca_e3.py) | -| InternLM2_20B_chat | LORA |[internlm2_20b_chat_lora_alpaca_e3.py](./xtuner_config/internlm2_20b_chat_lora_alpaca_e3.py)| -| Qwen_7b_chat | QLORA | [qwen_7b_chat_qlora_e3.py](./xtuner_config/qwen_7b_chat_qlora_e3.py) | -| Qwen1_5-0_5B-Chat | 全量微调 | [qwen1_5_0_5_B_full.py](./xtuner_config/qwen1_5_0_5_B_full.py) | -| Baichuan2_13B_chat | QLORA | [baichuan2_13b_chat_qlora_alpaca_e3.py](./xtuner_config/baichuan2_13b_chat_qlora_alpaca_e3.py) | -| ChatGLM3_6B | LORA | [chatglm3_6b_lora_alpaca_e3.py](./xtuner_config/chatglm3_6b_lora_alpaca_e3.py) | -| DeepSeek MoE_16B_chat | QLORA | [deepseek_moe_16b_chat_qlora_oasst1_e3.py](./xtuner_config/deepseek_moe_16b_chat_qlora_oasst1_e3.py) | -| Mixtral 8x7B_instruct | QLORA | [mixtral_8x7b_instruct_qlora_oasst1_e3.py](./xtuner_config/mixtral_8x7b_instruct_qlora_oasst1_e3.py) | -| LLaMA3_8b_instruct | QLORA | [aiwei_llama3_8b_instruct_qlora_e3.py](./xtuner_config/aiwei_llama3_8b_instruct_qlora_e3.py) | -| …… | …… | …… | +| 模型 | 类型 | 链接 | 模型链接 | +| :-------------------: | :------: | :------------------------------------------------------------------------------------------------------: |:------: | +| InternLM2_7B_chat | QLORA | [internlm2_7b_chat_qlora_e3.py](./xtuner_config/internlm2_7b_chat_qlora_e3.py) | | +| InternLM2_7B_chat | 全量微调 | [internlm2_chat_7b_full.py](./xtuner_config/internlm2_chat_7b_full.py) | | +| InternLM2_7B_base | QLORA | [internlm2_7b_base_qlora_e10_M_1e4_32_64.py](./xtuner_config/internlm2_7b_base_qlora_e10_M_1e4_32_64.py) |[OpenXLab](https://openxlab.org.cn/models/detail/chg0901/EmoLLM-InternLM7B-base-10e), [ModelScope](https://www.modelscope.cn/models/chg0901/EmoLLM-InternLM7B-base-10e/summary) | +| InternLM2_1_8B_chat | 全量微调 | [internlm2_1_8b_full_alpaca_e3.py](./xtuner_config/internlm2_1_8b_full_alpaca_e3.py) | | +| InternLM2_20B_chat | LORA |[internlm2_20b_chat_lora_alpaca_e3.py](./xtuner_config/internlm2_20b_chat_lora_alpaca_e3.py)| | +| Qwen_7b_chat | QLORA | [qwen_7b_chat_qlora_e3.py](./xtuner_config/qwen_7b_chat_qlora_e3.py) | | +| Qwen1_5-0_5B-Chat | 全量微调 | [qwen1_5_0_5_B_full.py](./xtuner_config/qwen1_5_0_5_B_full.py) | | +| Baichuan2_13B_chat | QLORA | [baichuan2_13b_chat_qlora_alpaca_e3.py](./xtuner_config/baichuan2_13b_chat_qlora_alpaca_e3.py) | | +| ChatGLM3_6B | LORA | [chatglm3_6b_lora_alpaca_e3.py](./xtuner_config/chatglm3_6b_lora_alpaca_e3.py) | | +| DeepSeek MoE_16B_chat | QLORA | [deepseek_moe_16b_chat_qlora_oasst1_e3.py](./xtuner_config/deepseek_moe_16b_chat_qlora_oasst1_e3.py) | | +| Mixtral 8x7B_instruct | QLORA | [mixtral_8x7b_instruct_qlora_oasst1_e3.py](./xtuner_config/mixtral_8x7b_instruct_qlora_oasst1_e3.py) | | +| LLaMA3_8b_instruct | QLORA | [aiwei_llama3_8b_instruct_qlora_e3.py](./xtuner_config/aiwei_llama3_8b_instruct_qlora_e3.py) | | +| LLaMA3_8b_instruct | QLORA | [llama3_8b_instruct_qlora_alpaca_e3_M.py](./xtuner_config/llama3_8b_instruct_qlora_alpaca_e3_M.py) |[OpenXLab](https://openxlab.org.cn/models/detail/chg0901/EmoLLM-Llama3-8B-Instruct2.0) | +| …… | …… | …… | …… |
From 7851977e13b22558d5f8b70ff0f23f336c6ad2d6 Mon Sep 17 00:00:00 2001 From: HongCheng Date: Mon, 22 Apr 2024 00:33:09 +0900 Subject: [PATCH 03/10] update model table EN --- README_EN.md | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/README_EN.md b/README_EN.md index 8b1ca0f..0a92ddf 100644 --- a/README_EN.md +++ b/README_EN.md @@ -46,22 +46,22 @@
-| Model | Type | link | -| :-------------------: | :--------------: | :---: | -| InternLM2_7B_chat | QLORA | [internlm2_7b_chat_qlora_e3.py](./xtuner_config/internlm2_7b_chat_qlora_e3.py) | -| InternLM2_7B_chat | full fine-tuning | [internlm2_chat_7b_full.py](./xtuner_config/internlm2_chat_7b_full.py) | -| InternLM2_7B_base | QLORA | [internlm2_7b_base_qlora_e10_M_1e4_32_64.py](./xtuner_config/internlm2_7b_base_qlora_e10_M_1e4_32_64.py) | -| InternLM2_1_8B_chat | full fine-tuning | [internlm2_1_8b_full_alpaca_e3.py](./xtuner_config/internlm2_1_8b_full_alpaca_e3.py) | -| InternLM2_20B_chat | LORA |[internlm2_20b_chat_lora_alpaca_e3.py](./xtuner_config/internlm2_20b_chat_lora_alpaca_e3.py)| -| Qwen_7b_chat | QLORA | [qwen_7b_chat_qlora_e3.py](./xtuner_config/qwen_7b_chat_qlora_e3.py) | -| Qwen1_5-0_5B-Chat | full fine-tuning | [qwen1_5_0_5_B_full.py](./xtuner_config/qwen1_5_0_5_B_full.py) | -| Baichuan2_13B_chat | QLORA | [baichuan2_13b_chat_qlora_alpaca_e3.py](./xtuner_config/baichuan2_13b_chat_qlora_alpaca_e3.py) | -| ChatGLM3_6B | LORA | [chatglm3_6b_lora_alpaca_e3.py](./xtuner_config/chatglm3_6b_lora_alpaca_e3.py) | -| DeepSeek MoE_16B_chat | QLORA | [deepseek_moe_16b_chat_qlora_oasst1_e3.py](./xtuner_config/deepseek_moe_16b_chat_qlora_oasst1_e3.py) | -| Mixtral 8x7B_instruct | QLORA | [mixtral_8x7b_instruct_qlora_oasst1_e3.py](./xtuner_config/mixtral_8x7b_instruct_qlora_oasst1_e3.py) | -| LLaMA3_8b_instruct | QLORA | [aiwei_llama3_8b_instruct_qlora_e3.py](./xtuner_config/aiwei_llama3_8b_instruct_qlora_e3.py) | -| -| …… | …… | …… | +| Model | Type | File Links | Model Links | +| :-------------------: | :------: | :------------------------------------------------------------------------------------------------------: |:------: | +| InternLM2_7B_chat | QLORA | [internlm2_7b_chat_qlora_e3.py](./xtuner_config/internlm2_7b_chat_qlora_e3.py) | | +| InternLM2_7B_chat | full fine-tuning | [internlm2_chat_7b_full.py](./xtuner_config/internlm2_chat_7b_full.py) | | +| InternLM2_7B_base | QLORA | [internlm2_7b_base_qlora_e10_M_1e4_32_64.py](./xtuner_config/internlm2_7b_base_qlora_e10_M_1e4_32_64.py) |[OpenXLab](https://openxlab.org.cn/models/detail/chg0901/EmoLLM-InternLM7B-base-10e), [ModelScope](https://www.modelscope.cn/models/chg0901/EmoLLM-InternLM7B-base-10e/summary) | +| InternLM2_1_8B_chat | full fine-tuning | [internlm2_1_8b_full_alpaca_e3.py](./xtuner_config/internlm2_1_8b_full_alpaca_e3.py) | | +| InternLM2_20B_chat | LORA |[internlm2_20b_chat_lora_alpaca_e3.py](./xtuner_config/internlm2_20b_chat_lora_alpaca_e3.py)| | +| Qwen_7b_chat | QLORA | [qwen_7b_chat_qlora_e3.py](./xtuner_config/qwen_7b_chat_qlora_e3.py) | | +| Qwen1_5-0_5B-Chat | full fine-tuning | [qwen1_5_0_5_B_full.py](./xtuner_config/qwen1_5_0_5_B_full.py) | | +| Baichuan2_13B_chat | QLORA | [baichuan2_13b_chat_qlora_alpaca_e3.py](./xtuner_config/baichuan2_13b_chat_qlora_alpaca_e3.py) | | +| ChatGLM3_6B | LORA | [chatglm3_6b_lora_alpaca_e3.py](./xtuner_config/chatglm3_6b_lora_alpaca_e3.py) | | +| DeepSeek MoE_16B_chat | QLORA | [deepseek_moe_16b_chat_qlora_oasst1_e3.py](./xtuner_config/deepseek_moe_16b_chat_qlora_oasst1_e3.py) | | +| Mixtral 8x7B_instruct | QLORA | [mixtral_8x7b_instruct_qlora_oasst1_e3.py](./xtuner_config/mixtral_8x7b_instruct_qlora_oasst1_e3.py) | | +| LLaMA3_8b_instruct | QLORA | [aiwei_llama3_8b_instruct_qlora_e3.py](./xtuner_config/aiwei_llama3_8b_instruct_qlora_e3.py) | | +| LLaMA3_8b_instruct | QLORA | [llama3_8b_instruct_qlora_alpaca_e3_M.py](./xtuner_config/llama3_8b_instruct_qlora_alpaca_e3_M.py) |[OpenXLab](https://openxlab.org.cn/models/detail/chg0901/EmoLLM-Llama3-8B-Instruct2.0) | +| …… | …… | …… | …… |
From 5d0f60478b7b16d6087e8ea8deba16b94a99fbb5 Mon Sep 17 00:00:00 2001 From: HongCheng Date: Mon, 22 Apr 2024 09:04:11 +0900 Subject: [PATCH 04/10] update download_model --- app.py | 1 + download_model.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/app.py b/app.py index dd75608..13458bd 100644 --- a/app.py +++ b/app.py @@ -12,6 +12,7 @@ elif model == "EmoLLM_Model": os.system('streamlit run web_internlm2.py --server.address=0.0.0.0 --server.port 7860') elif model == "Llama3_Model": os.system("python download_model.py chg0901/EmoLLM-Llama3-8B-Instruct2.0") + # os.system('streamlit run web_demo-Llama3_online.py --server.address=0.0.0.0 --server.port 7860') os.system('streamlit run web_demo-Llama3_online.py --server.address=0.0.0.0 --server.port 7860') else: print("Please select one model") \ No newline at end of file diff --git a/download_model.py b/download_model.py index 9079654..93b18d9 100644 --- a/download_model.py +++ b/download_model.py @@ -58,6 +58,6 @@ os.rmdir(temp_dir) os.remove(output_filename) -download(model_repo='jujimeizuo/EmoLLM_Model', output='model') +download(model_repo=model_repo, output='model') print("Model bin file download complete") \ No newline at end of file From bb3baadb7911e4c30fbc07e06682af0e34dadfe3 Mon Sep 17 00:00:00 2001 From: HongCheng Date: Mon, 22 Apr 2024 14:40:55 +0900 Subject: [PATCH 05/10] update port num 2024-04-22 13:19:55.349 Port 7860 is already in use --- app.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app.py b/app.py index 13458bd..604ea04 100644 --- a/app.py +++ b/app.py @@ -13,6 +13,6 @@ elif model == "EmoLLM_Model": elif model == "Llama3_Model": os.system("python download_model.py chg0901/EmoLLM-Llama3-8B-Instruct2.0") # os.system('streamlit run web_demo-Llama3_online.py --server.address=0.0.0.0 --server.port 7860') - os.system('streamlit run web_demo-Llama3_online.py --server.address=0.0.0.0 --server.port 7860') + os.system('streamlit run web_demo-Llama3_online.py --server.address=0.0.0.0 --server.port 7968') else: print("Please select one model") \ No newline at end of file From 452d3299f31b7a1b4c28ca14538b31dba5ccf3e4 Mon Sep 17 00:00:00 2001 From: HongCheng Date: Mon, 22 Apr 2024 17:22:17 +0900 Subject: [PATCH 06/10] Create web_demo-Llama3.py --- web_demo-Llama3.py | 376 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 376 insertions(+) create mode 100644 web_demo-Llama3.py diff --git a/web_demo-Llama3.py b/web_demo-Llama3.py new file mode 100644 index 0000000..06c9d7e --- /dev/null +++ b/web_demo-Llama3.py @@ -0,0 +1,376 @@ + +# isort: skip_file +import copy +import os +import warnings +from dataclasses import asdict, dataclass +from typing import Callable, List, Optional + +import streamlit as st +import torch +from torch import nn +from transformers.generation.utils import (LogitsProcessorList, + StoppingCriteriaList) +from transformers.utils import logging + +from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig # isort: skip +from peft import PeftModel + + +warnings.filterwarnings("ignore") +warnings.filterwarnings("ignore", category=DeprecationWarning) +logger = logging.get_logger(__name__) + +if not os.path.isdir("model"): + print("[ERROR] not find model dir") + exit(0) + +# online = True + +## running on local to test online function +# if online: +# from openxlab.model import download +# download(model_repo='chg0901/EmoLLM-Llama3-8B-Instruct2.0', +# output='model') + +@dataclass +class GenerationConfig: + # this config is used for chat to provide more diversity + max_length: int = 500 + top_p: float = 0.9 + temperature: float = 0.6 + do_sample: bool = True + repetition_penalty: float = 1.1 + + +@torch.inference_mode() +def generate_interactive( + model, + tokenizer, + prompt, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], + List[int]]] = None, + additional_eos_token_id: Optional[int] = None, + **kwargs, +): + inputs = tokenizer([prompt], return_tensors='pt') + input_length = len(inputs['input_ids'][0]) + for k, v in inputs.items(): + inputs[k] = v.cuda() + input_ids = inputs['input_ids'] + _, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1] + if generation_config is None: + generation_config = model.generation_config + generation_config = copy.deepcopy(generation_config) + model_kwargs = generation_config.update(**kwargs) + bos_token_id, eos_token_id = ( # noqa: F841 # pylint: disable=W0612 + generation_config.bos_token_id, + generation_config.eos_token_id, + ) + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + if additional_eos_token_id is not None: + eos_token_id.append(additional_eos_token_id) + has_default_max_length = kwargs.get( + 'max_length') is None and generation_config.max_length is not None + if has_default_max_length and generation_config.max_new_tokens is None: + warnings.warn( + f"Using 'max_length''s default ({repr(generation_config.max_length)}) \ + to control the generation length. " + 'This behaviour is deprecated and will be removed from the \ + config in v5 of Transformers -- we' + ' recommend using `max_new_tokens` to control the maximum \ + length of the generation.', + UserWarning, + ) + elif generation_config.max_new_tokens is not None: + generation_config.max_length = generation_config.max_new_tokens + \ + input_ids_seq_length + if not has_default_max_length: + logger.warn( # pylint: disable=W4902 + f"Both 'max_new_tokens' (={generation_config.max_new_tokens}) " + f"and 'max_length'(={generation_config.max_length}) seem to " + "have been set. 'max_new_tokens' will take precedence. " + 'Please refer to the documentation for more information. ' + '(https://huggingface.co/docs/transformers/main/' + 'en/main_classes/text_generation)', + UserWarning, + ) + + if input_ids_seq_length >= generation_config.max_length: + input_ids_string = 'input_ids' + logger.warning( + f"Input length of {input_ids_string} is {input_ids_seq_length}, " + f"but 'max_length' is set to {generation_config.max_length}. " + 'This can lead to unexpected behavior. You should consider' + " increasing 'max_new_tokens'.") + + # 2. Set generation parameters if not already defined + logits_processor = logits_processor if logits_processor is not None \ + else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None \ + else StoppingCriteriaList() + + logits_processor = model._get_logits_processor( + generation_config=generation_config, + input_ids_seq_length=input_ids_seq_length, + encoder_input_ids=input_ids, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + logits_processor=logits_processor, + ) + + stopping_criteria = model._get_stopping_criteria( + generation_config=generation_config, + stopping_criteria=stopping_criteria) + logits_warper = model._get_logits_warper(generation_config) + + unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) + scores = None + while True: + model_inputs = model.prepare_inputs_for_generation( + input_ids, **model_kwargs) + # forward pass to get next token + outputs = model( + **model_inputs, + return_dict=True, + output_attentions=False, + output_hidden_states=False, + ) + + next_token_logits = outputs.logits[:, -1, :] + + # pre-process distribution + next_token_scores = logits_processor(input_ids, next_token_logits) + next_token_scores = logits_warper(input_ids, next_token_scores) + + # sample + probs = nn.functional.softmax(next_token_scores, dim=-1) + if generation_config.do_sample: + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + else: + next_tokens = torch.argmax(probs, dim=-1) + + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + model_kwargs = model._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=False) + unfinished_sequences = unfinished_sequences.mul( + (min(next_tokens != i for i in eos_token_id)).long()) + + output_token_ids = input_ids[0].cpu().tolist() + output_token_ids = output_token_ids[input_length:] + for each_eos_token_id in eos_token_id: + if output_token_ids[-1] == each_eos_token_id: + output_token_ids = output_token_ids[:-1] + response = tokenizer.decode(output_token_ids) + + yield response + # stop when each sentence is finished + # or if we exceed the maximum length + if unfinished_sequences.max() == 0 or stopping_criteria( + input_ids, scores): + break + + +def on_btn_click(): + del st.session_state.messages + + +# @st.cache_resource +# def load_model(arg1): +# # model = AutoModelForCausalLM.from_pretrained(args.m).cuda() +# # tokenizer = AutoTokenizer.from_pretrained(args.m, trust_remote_code=True) +# model = AutoModelForCausalLM.from_pretrained(arg1, torch_dtype=torch.float16).cuda() +# tokenizer = AutoTokenizer.from_pretrained(arg1, trust_remote_code=True) + + +# return model, tokenizer +@st.cache_resource +def load_model(model_name_or_path, load_in_4bit=False, adapter_name_or_path=None): + if load_in_4bit: + quantization_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + ) + else: + quantization_config = None + + # 加载base model + model = AutoModelForCausalLM.from_pretrained( + model_name_or_path, + # load_in_4bit=load_in_4bit, + # # ValueError: You can't pass `load_in_4bit`or `load_in_8bit` as a kwarg when passing `quantization_config` argument at the same time. + trust_remote_code=True, + low_cpu_mem_usage=True, + torch_dtype=torch.float16, + device_map='auto', + quantization_config=quantization_config + ) + + # 加载adapter + if adapter_name_or_path is not None: + model = PeftModel.from_pretrained(model, adapter_name_or_path) + + ## 加载tokenzier + tokenizer = AutoTokenizer.from_pretrained( + model_name_or_path if adapter_name_or_path is None else adapter_name_or_path, + trust_remote_code=True, + use_fast=False + ) + + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + return model, tokenizer + + +def prepare_generation_config(): + with st.sidebar: + + # 使用 Streamlit 的 markdown 函数添加 Markdown 文本 + st.image('assets/EmoLLM_logo_L.png', width=1, caption='EmoLLM Logo', use_column_width=True) + st.markdown("[访问 **EmoLLM** 官方repo: **SmartFlowAI/EmoLLM**](https://github.com/SmartFlowAI/EmoLLM)") + + max_length = st.slider('Max Length', + min_value=8, + max_value=8192, + value=500) + top_p = st.slider('Top P', 0.0, 1.0, 0.9, step=0.01) + temperature = st.slider('Temperature', 0.0, 1.0, 0.6, step=0.01) + repetition_penalty = st.slider('Repetition penalty', 0.0, 1.5, 1.1, step=0.01) + st.button('Clear Chat History', on_click=on_btn_click) + + generation_config = GenerationConfig(max_length=max_length, + top_p=top_p, + temperature=temperature, + repetition_penalty=repetition_penalty, + do_sample=True) + + return generation_config + + +user_prompt = '<|start_header_id|>user<|end_header_id|>\n\n{user}<|eot_id|>' +robot_prompt = '<|start_header_id|>assistant<|end_header_id|>\n\n{robot}<|eot_id|>' +cur_query_prompt = '<|start_header_id|>user<|end_header_id|>\n\n{user}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n' + + +def combine_history(prompt): + messages = st.session_state.messages + + meta_instruction = ( + "你是心理健康助手EmoLLM, 由EmoLLM团队打造, 是一个研究过无数具有心理健康问题的病人与心理健康医生对话的心理专家, 在心理方面拥有广博的知识储备和丰富的研究咨询经验。你旨在通过专业心理咨询, 协助来访者完成心理诊断。请充分利用专业心理学知识与咨询技术, 一步步帮助来访者解决心理问题。。" + ) + total_prompt = f"<|start_header_id|>system<|end_header_id|>\n{meta_instruction}<|eot_id|>\n" + + for message in messages: + cur_content = message['content'] + if message['role'] == 'user': + cur_prompt = user_prompt.format(user=cur_content) + elif message['role'] == 'robot': + cur_prompt = robot_prompt.format(robot=cur_content) + else: + raise RuntimeError + total_prompt += cur_prompt + total_prompt = total_prompt + cur_query_prompt.format(user=prompt) + return total_prompt + + +def main(): + + st.markdown("我在这里,准备好倾听你的心声了。", unsafe_allow_html=True) + model_name_or_path = 'model' + adapter_name_or_path = None + # if online: + # model_name_or_path = 'model' + # adapter_name_or_path = None + # else: + # # model_name_or_path = "./xtuner_config/merged_Llama3_8b_instruct_e3" + # # adapter_name_or_path = './xtuner_config/hf_llama3_e1_sc2' + + # model_name_or_path = "./xtuner_config/merged_Llama3_8b_instruct_e1_sc" + # adapter_name_or_path = None + + # 若开启4bit推理能够节省很多显存,但效果可能下降 + load_in_4bit = False # True # 6291MiB + + # torch.cuda.empty_cache() + print('load model begin.') + # 加载模型 + print(f'Loading model from: {model_name_or_path}') + print(f'adapter_name_or_path: {adapter_name_or_path}') + # model, tokenizer = load_model(arg1) + model, tokenizer = load_model( + # arg1 if arg1 is not None else model_name_or_path, + model_name_or_path, + load_in_4bit=load_in_4bit, + adapter_name_or_path=adapter_name_or_path + ) + model.eval() + print('load model end.') + + user_avator = "assets/user.png" + robot_avator = "assets/EmoLLM.png" + + st.title('EmoLLM Llama3心理咨询室V2.0') + + generation_config = prepare_generation_config() + + # Initialize chat history + if 'messages' not in st.session_state: + st.session_state.messages = [] + + # Display chat messages from history on app rerun + for message in st.session_state.messages: + with st.chat_message(message['role'], avatar=message.get("avatar")): + st.markdown(message['content']) + + # Accept user input + if prompt := st.chat_input('你好,欢迎来到Llama3 EmoLLM 心理咨询室'): + # Display user message in chat message container + with st.chat_message('user', avatar=user_avator): + st.markdown(prompt) + real_prompt = combine_history(prompt) + # Add user message to chat history + st.session_state.messages.append({ + 'role': 'user', + 'content': prompt, + 'avatar': user_avator + }) + + + # stop_token_id = tokenizer.encode('<|eot_id|>', add_special_tokens=True) + # assert len(stop_token_id) == 1 + # stop_token_id = stop_token_id[0] + + with st.chat_message('robot', avatar=robot_avator): + message_placeholder = st.empty() + for cur_response in generate_interactive( + model=model, + tokenizer=tokenizer, + prompt=real_prompt, + additional_eos_token_id=128009, # <|eot_id|> + eos_token_id=128009, + pad_token_id=128009, + **asdict(generation_config), + ): + # Display robot response in chat message container + message_placeholder.markdown(cur_response + '▌') + message_placeholder.markdown(cur_response) + # Add robot response to chat history + st.session_state.messages.append({ + 'role': 'robot', + 'content': cur_response, # pylint: disable=undefined-loop-variable + "avatar": robot_avator, + }) + torch.cuda.empty_cache() + + +if __name__ == "__main__": + main() + \ No newline at end of file From 9f0ea20d434bef0b3ddb295d23e89dc37ca320eb Mon Sep 17 00:00:00 2001 From: HongCheng Date: Mon, 22 Apr 2024 17:34:43 +0900 Subject: [PATCH 07/10] update --- web_demo-Llama3.py | 34 ++++++++++++++++++++++------------ 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/web_demo-Llama3.py b/web_demo-Llama3.py index 06c9d7e..216eb0d 100644 --- a/web_demo-Llama3.py +++ b/web_demo-Llama3.py @@ -188,8 +188,20 @@ def on_btn_click(): # return model, tokenizer + @st.cache_resource -def load_model(model_name_or_path, load_in_4bit=False, adapter_name_or_path=None): +def load_model(): + model = AutoModelForCausalLM.from_pretrained("model", + device_map="auto", + trust_remote_code=True, + torch_dtype=torch.float16) + model = model.eval() + tokenizer = AutoTokenizer.from_pretrained("model", trust_remote_code=True) + return model, tokenizer + + +@st.cache_resource +def load_model0(model_name_or_path, load_in_4bit=False, adapter_name_or_path=None): if load_in_4bit: quantization_config = BitsAndBytesConfig( load_in_4bit=True, @@ -305,12 +317,15 @@ def main(): print(f'Loading model from: {model_name_or_path}') print(f'adapter_name_or_path: {adapter_name_or_path}') # model, tokenizer = load_model(arg1) - model, tokenizer = load_model( - # arg1 if arg1 is not None else model_name_or_path, - model_name_or_path, - load_in_4bit=load_in_4bit, - adapter_name_or_path=adapter_name_or_path - ) + + model, tokenizer = load_model() + + # model, tokenizer = load_model( + # # arg1 if arg1 is not None else model_name_or_path, + # model_name_or_path, + # load_in_4bit=load_in_4bit, + # adapter_name_or_path=adapter_name_or_path + # ) model.eval() print('load model end.') @@ -343,10 +358,6 @@ def main(): 'avatar': user_avator }) - - # stop_token_id = tokenizer.encode('<|eot_id|>', add_special_tokens=True) - # assert len(stop_token_id) == 1 - # stop_token_id = stop_token_id[0] with st.chat_message('robot', avatar=robot_avator): message_placeholder = st.empty() @@ -373,4 +384,3 @@ def main(): if __name__ == "__main__": main() - \ No newline at end of file From b7c33ca5b9e2302a6459630371f224ef41d0ff80 Mon Sep 17 00:00:00 2001 From: HongCheng Date: Mon, 22 Apr 2024 17:37:09 +0900 Subject: [PATCH 08/10] update, not use online flag and model_path para --- app.py | 2 +- web_demo-Llama3.py | 20 -------------------- 2 files changed, 1 insertion(+), 21 deletions(-) diff --git a/app.py b/app.py index 604ea04..d0b5b4b 100644 --- a/app.py +++ b/app.py @@ -13,6 +13,6 @@ elif model == "EmoLLM_Model": elif model == "Llama3_Model": os.system("python download_model.py chg0901/EmoLLM-Llama3-8B-Instruct2.0") # os.system('streamlit run web_demo-Llama3_online.py --server.address=0.0.0.0 --server.port 7860') - os.system('streamlit run web_demo-Llama3_online.py --server.address=0.0.0.0 --server.port 7968') + os.system('streamlit run web_demo-Llama3.py --server.address=0.0.0.0 --server.port 7968') else: print("Please select one model") \ No newline at end of file diff --git a/web_demo-Llama3.py b/web_demo-Llama3.py index 216eb0d..8e6501d 100644 --- a/web_demo-Llama3.py +++ b/web_demo-Llama3.py @@ -25,14 +25,6 @@ if not os.path.isdir("model"): print("[ERROR] not find model dir") exit(0) -# online = True - -## running on local to test online function -# if online: -# from openxlab.model import download -# download(model_repo='chg0901/EmoLLM-Llama3-8B-Instruct2.0', -# output='model') - @dataclass class GenerationConfig: # this config is used for chat to provide more diversity @@ -298,19 +290,7 @@ def main(): st.markdown("我在这里,准备好倾听你的心声了。", unsafe_allow_html=True) model_name_or_path = 'model' adapter_name_or_path = None - # if online: - # model_name_or_path = 'model' - # adapter_name_or_path = None - # else: - # # model_name_or_path = "./xtuner_config/merged_Llama3_8b_instruct_e3" - # # adapter_name_or_path = './xtuner_config/hf_llama3_e1_sc2' - - # model_name_or_path = "./xtuner_config/merged_Llama3_8b_instruct_e1_sc" - # adapter_name_or_path = None - # 若开启4bit推理能够节省很多显存,但效果可能下降 - load_in_4bit = False # True # 6291MiB - # torch.cuda.empty_cache() print('load model begin.') # 加载模型 From e2f5ff1c1914d6a2a71b72d34e97216809b05aa8 Mon Sep 17 00:00:00 2001 From: HongCheng Date: Mon, 22 Apr 2024 17:49:03 +0900 Subject: [PATCH 09/10] add websockts version # gradio 3.50.2 and gradio-client 0.6.1 require websockets<12.0,>=10.0, but you have websockets 12.0 which is incompatible. --- requirements.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index f618b19..95ac855 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,4 +9,5 @@ openxlab tiktoken einops oss2 -requests \ No newline at end of file +requests +websockets==11.0.3 From 2b8361ee0581269c9ef0c87fef5a1a0cdc2d3a1c Mon Sep 17 00:00:00 2001 From: HongCheng Date: Mon, 22 Apr 2024 17:55:28 +0900 Subject: [PATCH 10/10] =?UTF-8?q?=E6=9B=B4=E6=96=B0=E4=BA=8C=E7=BB=B4?= =?UTF-8?q?=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 356c290..ac71210 100644 --- a/README.md +++ b/README.md @@ -346,5 +346,5 @@ git clone https://github.com/SmartFlowAI/EmoLLM.git - 如果失效,请移步Issue区

- EmoLLM官方交流群 + EmoLLM官方交流群

\ No newline at end of file