# OliveSensorAPI/web_demo-Llama3.py

import copy
import os
import warnings
from dataclasses import asdict, dataclass
from typing import Callable, List, Optional
import streamlit as st
import torch
from torch import nn
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
from transformers.utils import logging
from transformers import AutoTokenizer, AutoModelForCausalLM  # isort: skip

warnings.filterwarnings("ignore")
logger = logging.get_logger(__name__)


@dataclass
class GenerationConfig:
    # this config is used for chat to provide more diversity
    max_length: int = 32768
    top_p: float = 0.8
    temperature: float = 0.8
    do_sample: bool = True
    repetition_penalty: float = 1.005
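
# In main(), asdict(GenerationConfig(...)) is passed to generate_interactive as **kwargs, and
# generate_interactive applies those fields via generation_config.update(**kwargs) onto the
# model's transformers GenerationConfig, so field names here must be valid generation arguments.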


@torch.inference_mode()
def generate_interactive(
    model,
    tokenizer,
    prompt,
    generation_config: Optional[GenerationConfig] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
    additional_eos_token_id: Optional[int] = None,
    **kwargs,
):
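    """Stream a chat completion token by token.

    Runs a manual decoding loop (logits processors, logits warpers, multinomial sampling)
    instead of model.generate() so each partial response can be pushed to the Streamlit UI
    as it is produced. Yields the decoded response-so-far after every generated token and
    stops on an EOS token or a stopping criterion.
    """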
    inputs = tokenizer([prompt], padding=True, return_tensors="pt")
    input_length = len(inputs["input_ids"][0])
    # move every input tensor to the GPU
    for k, v in inputs.items():
        inputs[k] = v.cuda()
    input_ids = inputs["input_ids"]
    batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]  # noqa: F841  # pylint: disable=W0612
    if generation_config is None:
        generation_config = model.generation_config
    generation_config = copy.deepcopy(generation_config)
    model_kwargs = generation_config.update(**kwargs)
    bos_token_id, eos_token_id = (  # noqa: F841  # pylint: disable=W0612
        generation_config.bos_token_id,
        generation_config.eos_token_id,
    )
    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    if additional_eos_token_id is not None:
        eos_token_id.append(additional_eos_token_id)

    has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
    if has_default_max_length and generation_config.max_new_tokens is None:
        warnings.warn(
            f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
            "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
            " recommend using `max_new_tokens` to control the maximum length of the generation.",
            UserWarning,
        )
    elif generation_config.max_new_tokens is not None:
        generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
        if not has_default_max_length:
            logger.warning(
                f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
                f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
                "Please refer to the documentation for more information. "
                "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
            )

    if input_ids_seq_length >= generation_config.max_length:
        input_ids_string = "input_ids"
        logger.warning(
            f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
            f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
            " increasing `max_new_tokens`."
        )
    # 2. Set generation parameters if not already defined
    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
    logits_processor = model._get_logits_processor(
        generation_config=generation_config,
        input_ids_seq_length=input_ids_seq_length,
        encoder_input_ids=input_ids,
        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
        logits_processor=logits_processor,
    )
    stopping_criteria = model._get_stopping_criteria(
        generation_config=generation_config, stopping_criteria=stopping_criteria
    )
    logits_warper = model._get_logits_warper(generation_config)

    unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
    scores = None
    while True:
        model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
        # forward pass to get next token
        outputs = model(
            **model_inputs,
            return_dict=True,
            output_attentions=False,
            output_hidden_states=False,
        )
        next_token_logits = outputs.logits[:, -1, :]
        # pre-process distribution
        next_token_scores = logits_processor(input_ids, next_token_logits)
        next_token_scores = logits_warper(input_ids, next_token_scores)
        # sample
        probs = nn.functional.softmax(next_token_scores, dim=-1)
        if generation_config.do_sample:
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
        else:
            next_tokens = torch.argmax(probs, dim=-1)
        # update generated ids, model inputs, and length for next step
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        model_kwargs = model._update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False)
        # mark the sequence finished once next_tokens hits any EOS id (assumes batch size 1)
        unfinished_sequences = unfinished_sequences.mul((min(next_tokens != i for i in eos_token_id)).long())

        # strip the prompt (and a trailing EOS token, if any) before decoding
        output_token_ids = input_ids[0].cpu().tolist()
        output_token_ids = output_token_ids[input_length:]
        for each_eos_token_id in eos_token_id:
            if output_token_ids[-1] == each_eos_token_id:
                output_token_ids = output_token_ids[:-1]
        response = tokenizer.decode(output_token_ids)
        yield response
        # stop when each sentence is finished, or if we exceed the maximum length
        if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
            break
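

# A minimal, hypothetical standalone use of generate_interactive outside Streamlit, assuming
# `model` and `tokenizer` are already loaded and `prompt` is a fully formatted Llama-3 chat string:
#
#     for partial in generate_interactive(
#         model=model,
#         tokenizer=tokenizer,
#         prompt=prompt,
#         additional_eos_token_id=128009,  # Llama 3 <|eot_id|>
#         **asdict(GenerationConfig()),
#     ):
#         print(partial)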


def on_btn_click():
    del st.session_state.messages


@st.cache_resource
def load_model():
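    """Download (if needed) and load the EmoLLM-Llama3-8B-Instruct2.0 model and tokenizer.

    Decorated with st.cache_resource above, so the model is loaded once per Streamlit
    process and reused across reruns of the script.
    """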
    model_name = "./EmoLLM-Llama3-8B-Instruct2.0"
    print(model_name)

    # install the dependencies needed for the ModelScope download at runtime
    print('pip install modelscope websockets')
    os.system('pip install modelscope websockets==11.0.3')
    from modelscope import snapshot_download

    # download the model weights from ModelScope, using model_name as the cache dir
    model_name = snapshot_download('chg0901/EmoLLM-Llama3-8B-Instruct2.0', cache_dir=model_name)
    print(model_name)

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        trust_remote_code=True,
        torch_dtype=torch.float16,
    ).eval()
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return model, tokenizer
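

# Note: generate_interactive moves its inputs to the GPU with .cuda() and the model is loaded
# in float16 with device_map="auto", so this demo assumes at least one CUDA device is available.
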
def prepare_generation_config():
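    """Render the sidebar controls and return a GenerationConfig built from them."""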
    with st.sidebar:
        # sidebar: logo image plus a Markdown link to the official EmoLLM repo
        st.image('assets/EmoLLM_logo_L.png', width=1, caption='EmoLLM Logo', use_column_width=True)
        st.markdown("[访问 **EmoLLM** 官方repo: **SmartFlowAI/EmoLLM**](https://github.com/SmartFlowAI/EmoLLM)")
        max_length = st.slider("Max Length", min_value=8, max_value=32768, value=32768)
        top_p = st.slider("Top P", 0.0, 1.0, 0.8, step=0.01)
        temperature = st.slider("Temperature", 0.0, 1.0, 0.7, step=0.01)
        st.button("Clear Chat History", on_click=on_btn_click)

    generation_config = GenerationConfig(max_length=max_length, top_p=top_p, temperature=temperature)
    return generation_config
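

# Llama 3 chat-template fragments: every turn is wrapped in <|start_header_id|>{role}<|end_header_id|>
# ... <|eot_id|>; cur_query_prompt ends with an open assistant header so the model continues the
# assistant turn from there.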
user_prompt = '<|start_header_id|>user<|end_header_id|>\n\n{user}<|eot_id|>'
robot_prompt = '<|start_header_id|>assistant<|end_header_id|>\n\n{robot}<|eot_id|>'
cur_query_prompt = '<|start_header_id|>user<|end_header_id|>\n\n{user}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n'


def combine_history(prompt):
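    """Assemble the full Llama-3 prompt: system instruction, prior chat turns, and the new user query."""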
    messages = st.session_state.messages
    # System prompt (kept in Chinese, as served to users). Roughly: "You are the mental-health
    # assistant EmoLLM, built by the EmoLLM team, a psychology expert who has studied countless
    # dialogues between patients and mental-health doctors. Use professional counseling knowledge
    # and techniques to help the visitor work through their problems step by step."
    meta_instruction = (
        "你是心理健康助手EmoLLM, 由EmoLLM团队打造, 是一个研究过无数具有心理健康问题的病人与心理健康医生对话的心理专家, 在心理方面拥有广博的知识储备和丰富的研究咨询经验。你旨在通过专业心理咨询, 协助来访者完成心理诊断。请充分利用专业心理学知识与咨询技术, 一步步帮助来访者解决心理问题。\n\n"
    )
    total_prompt = f"<|start_header_id|>system<|end_header_id|>\n\n{meta_instruction}<|eot_id|>\n\n"
    for message in messages:
        cur_content = message["content"]
        if message["role"] == "user":
            cur_prompt = user_prompt.format(user=cur_content)
        elif message["role"] == "robot":
            cur_prompt = robot_prompt.format(robot=cur_content)
        else:
            raise RuntimeError
        total_prompt += cur_prompt
    total_prompt = total_prompt + cur_query_prompt.format(user=prompt)
    return total_prompt
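

# Illustrative result for the first user turn "你好" (line breaks added here for readability):
#     <|start_header_id|>system<|end_header_id|>\n\n{meta_instruction}<|eot_id|>\n\n
#     <|start_header_id|>user<|end_header_id|>\n\n你好<|eot_id|>
#     <|start_header_id|>assistant<|end_header_id|>\n\n
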
def main():
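    """Streamlit entry point: load the model, render the chat UI, and stream responses."""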
    torch.cuda.empty_cache()
    print("load model begin.")
    model, tokenizer = load_model()
    print("load model end.")

    user_avator = "assets/user.png"
    robot_avator = "assets/EmoLLM.png"

    # Page title: "EmoLLM Llama3 counseling room V2.0"
    st.title("EmoLLM Llama3心理咨询室V2.0")

    generation_config = prepare_generation_config()

    # Initialize chat history
    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Display chat messages from history on app rerun
    for message in st.session_state.messages:
        with st.chat_message(message["role"], avatar=message.get("avatar")):
            st.markdown(message["content"])

    # Accept user input; placeholder text: "I'm here, ready to listen to what's on your mind."
    if prompt := st.chat_input("我在这里,准备好倾听你的心声了。"):
        # Display user message in chat message container
        with st.chat_message("user", avatar=user_avator):
            st.markdown(prompt)
        real_prompt = combine_history(prompt)
        # Add user message to chat history
        st.session_state.messages.append({"role": "user", "content": prompt, "avatar": user_avator})

        with st.chat_message("robot", avatar=robot_avator):
            message_placeholder = st.empty()
            for cur_response in generate_interactive(
                model=model,
                tokenizer=tokenizer,
                prompt=real_prompt,
                additional_eos_token_id=128009,
                **asdict(generation_config),
            ):
                # Display the partial robot response in the chat message container, with a typing cursor
                message_placeholder.markdown(cur_response + "▌")
            message_placeholder.markdown(cur_response)  # pylint: disable=undefined-loop-variable
        # Add robot response to chat history
        st.session_state.messages.append(
            {
                "role": "robot",
                "content": cur_response,  # pylint: disable=undefined-loop-variable
                "avatar": robot_avator,
            }
        )
        torch.cuda.empty_cache()


if __name__ == "__main__":
    main()