diff --git a/app_Llama3_Gradio.py b/app_Llama3_Gradio.py new file mode 100644 index 0000000..8cb71ed --- /dev/null +++ b/app_Llama3_Gradio.py @@ -0,0 +1,177 @@ +import gradio as gr +import os +import torch +from transformers import GemmaTokenizer, AutoModelForCausalLM +from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer +from threading import Thread + + +DESCRIPTION = ''' +
+

EmoLLM Llama3 心理咨询室 V4.0

+ +

+ + Logo + +

+ +
+ + +[![OpenXLab_Model][OpenXLab_Model-image]][OpenXLab_Model-url] + +

EmoLLM是一系列能够支持 理解用户-支持用户-帮助用户 心理健康辅导链路的 心理健康大模型 ,欢迎大家star~⭐⭐

+

https://github.com/SmartFlowAI/EmoLLM

+
+ +
+ +[OpenXLab_Model-image]: https://cdn-static.openxlab.org.cn/header/openxlab_models.svg +[OpenXLab_Model-url]: https://openxlab.org.cn/models/detail/chg0901/EmoLLM-Llama3-8B-Instruct3.0 + +''' + +LICENSE = """ +

Built with Meta Llama 3 +""" + +PLACEHOLDER = """ +

+ +
+""" + + +css = """ +h1 { + text-align: center; + display: block; +} + +""" + +# download internlm2 to the base_path directory using git tool +base_path = './EmoLLM-Llama3-8B-Instruct3.0' +os.system(f'git clone https://code.openxlab.org.cn/chg0901/EmoLLM-Llama3-8B-Instruct3.0.git {base_path}') +os.system(f'cd {base_path} && git lfs pull') + + +# Load the tokenizer and model +tokenizer = AutoTokenizer.from_pretrained(base_path,trust_remote_code=True) +model = AutoModelForCausalLM.from_pretrained(base_path,trust_remote_code=True, device_map="auto", torch_dtype=torch.float16).eval() # to("cuda:0") +terminators = [ + tokenizer.eos_token_id, + tokenizer.convert_tokens_to_ids("<|eot_id|>") +] + +def chat_llama3_8b(message: str, + history: list, + temperature: float, + max_new_tokens: int, + top_p: float + ) -> str: + """ + Generate a streaming response using the llama3-8b model. + Args: + message (str): The input message. + history (list): The conversation history used by ChatInterface. + temperature (float): The temperature for generating the response. + max_new_tokens (int): The maximum number of new tokens to generate. + Returns: + str: The generated response. + """ + conversation = [] + + for user, assistant in history: + conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}]) + conversation.append({"role": "user", "content": message}) + + input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt").to(model.device) + + streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True) + + generate_kwargs = dict( + input_ids= input_ids, + streamer=streamer, + max_new_tokens=max_new_tokens, + do_sample=True, + temperature=temperature, + top_p = top_p, + eos_token_id=terminators, + ) + # This will enforce greedy generation (do_sample=False) when the temperature is passed 0, avoiding the crash. + if temperature == 0: + generate_kwargs['do_sample'] = False + + t = Thread(target=model.generate, kwargs=generate_kwargs) + t.start() + + outputs = [] + for text in streamer: + outputs.append(text) + yield "".join(outputs) + + + +# Gradio block +chatbot=gr.Chatbot(height=450, placeholder=PLACEHOLDER, label='EmoLLM Chat') + +with gr.Blocks(fill_height=True, css=css) as demo: + + gr.Markdown(DESCRIPTION) + # gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button") + gr.ChatInterface( + fn=chat_llama3_8b, + chatbot=chatbot, + fill_height=True, + additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False), + additional_inputs=[ + gr.Slider(minimum=0, + maximum=1, + step=0.1, + value=0.95, + label="Temperature", + render=False), + gr.Slider(minimum=128, + maximum=4096, + step=1, + value=4096, + label="Max new tokens", + render=False ), + gr.Slider(minimum=0.0, + maximum=1, + step=0.01, + value=0.8, + label="Top P", + render=False ), + # gr.Slider(minimum=128, + # maximum=4096, + # step=1, + # value=512, + # label="Max new tokens", + # render=False ), + ], + examples=[ + ['请介绍你自己。'], + ['我觉得我在学校的学习压力好大啊,虽然我真的很喜欢我的专业,但最近总是担心自己无法达到自己的期望,这让我有点焦虑。'], + ['我最近总觉得自己在感情上陷入了困境,我喜欢上了我的朋友,但又害怕表达出来会破坏我们现在的关系...'], + ['我感觉自己像是被困在一个无尽的循环中。每天醒来都感到身体沉重,对日常活动提不起兴趣,工作、锻炼甚至是我曾经喜欢的事物都让我觉得厌倦'], + ['最近工作压力特别大,还有一些家庭矛盾'] + ], + cache_examples=False, + ) + + gr.Markdown(LICENSE) + +if __name__ == "__main__": + demo.launch() + + \ No newline at end of file diff --git a/app_web_demo-Llama3.py b/app_web_demo-Llama3.py index bd5a6a8..2f4f678 100644 --- a/app_web_demo-Llama3.py +++ b/app_web_demo-Llama3.py @@ -207,7 +207,7 @@ cur_query_prompt = '<|start_header_id|>user<|end_header_id|>\n\n{user}<|eot_id|> def combine_history(prompt): messages = st.session_state.messages meta_instruction = ( - "你是心理健康助手EmoLLM, 由EmoLLM团队打造, 是一个研究过无数具有心理健康问题的病人与心理健康医生对话的心理专家, 在心理方面拥有广博的知识储备和丰富的研究咨询经验。你旨在通过专业心理咨询, 协助来访者完成心理诊断。请充分利用专业心理学知识与咨询技术, 一步步帮助来访者解决心理问题。\n\n" + "你是心理健康助手EmoLLM, 由EmoLLM团队打造。\n\n" ) total_prompt =f"<|start_header_id|>system<|end_header_id|>\n\n{meta_instruction}<|eot_id|>\n\n" for message in messages: diff --git a/web_demo-Llama3.py b/web_demo-Llama3.py index 391b0b4..6cf2eee 100644 --- a/web_demo-Llama3.py +++ b/web_demo-Llama3.py @@ -205,7 +205,7 @@ cur_query_prompt = '<|start_header_id|>user<|end_header_id|>\n\n{user}<|eot_id|> def combine_history(prompt): messages = st.session_state.messages meta_instruction = ( - "你是心理健康助手EmoLLM, 由EmoLLM团队打造, 是一个研究过无数具有心理健康问题的病人与心理健康医生对话的心理专家, 在心理方面拥有广博的知识储备和丰富的研究咨询经验。你旨在通过专业心理咨询, 协助来访者完成心理诊断。请充分利用专业心理学知识与咨询技术, 一步步帮助来访者解决心理问题。\n\n" + "你是心理健康助手EmoLLM, 由EmoLLM团队打造\n\n" ) total_prompt =f"<|start_header_id|>system<|end_header_id|>\n\n{meta_instruction}<|eot_id|>\n\n" for message in messages: