update llama3 webdemo

This commit is contained in:
HongCheng 2024-04-22 00:05:11 +09:00
parent f314baf75b
commit 5c4d4de9d7
2 changed files with 26 additions and 13 deletions

6
app.py
View File

@ -1,7 +1,8 @@
import os
#model = "EmoLLM_aiwei"
model = "EmoLLM_Model"
# model = "EmoLLM_Model"
model = "Llama3_Model"
if model == "EmoLLM_aiwei":
os.system("python download_model.py ajupyter/EmoLLM_aiwei")
@ -9,5 +10,8 @@ if model == "EmoLLM_aiwei":
elif model == "EmoLLM_Model":
os.system("python download_model.py jujimeizuo/EmoLLM_Model")
os.system('streamlit run web_internlm2.py --server.address=0.0.0.0 --server.port 7860')
elif model == "Llama3_Model":
os.system("python download_model.py chg0901/EmoLLM-Llama3-8B-Instruct2.0")
os.system('streamlit run web_demo-Llama3_online.py --server.address=0.0.0.0 --server.port 7860')
else:
print("Please select one model")

View File

@ -1,6 +1,7 @@
# isort: skip_file
import copy
import os
import warnings
from dataclasses import asdict, dataclass
from typing import Callable, List, Optional
@ -15,17 +16,22 @@ from transformers.utils import logging
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig # isort: skip
from peft import PeftModel
import warnings
warnings.filterwarnings("ignore")
warnings.filterwarnings("ignore", category=DeprecationWarning)
logger = logging.get_logger(__name__)
# Abort early when the expected local "model" directory is absent.
if not os.path.isdir("model"):
    print("[ERROR] not find model dir")
    # Exit with a non-zero status so the failure is not reported as success;
    # raising SystemExit directly avoids relying on the site-provided exit().
    raise SystemExit(1)
online = True
if online:
from openxlab.model import download
download(model_repo='chg0901/EmoLLM-Llama3-8B-Instruct2.0',
output='model')
## Download below is disabled so the online code path can be tested locally
# if online:
# from openxlab.model import download
# download(model_repo='chg0901/EmoLLM-Llama3-8B-Instruct2.0',
# output='model')
@dataclass
class GenerationConfig:
@ -275,17 +281,17 @@ def combine_history(prompt):
return total_prompt
def main(arg1):
def main(arg1=None):
if online:
model_name_or_path = 'model'
adapter_name_or_path = None
else:
# model_name_or_path = "/root/StableCascade/emollm2/EmoLLM/xtuner_config/merged_Llama3_8b_instruct_e3"
# adapter_name_or_path = '/root/StableCascade/emollm2/EmoLLM/xtuner_config/hf_llama3_e1_sc2'
# model_name_or_path = "./xtuner_config/merged_Llama3_8b_instruct_e3"
# adapter_name_or_path = './xtuner_config/hf_llama3_e1_sc2'
model_name_or_path = "/root/StableCascade/emollm2/EmoLLM/xtuner_config/merged_Llama3_8b_instruct_e1_sc"
model_name_or_path = "./xtuner_config/merged_Llama3_8b_instruct_e1_sc"
adapter_name_or_path = None
# Enabling 4-bit inference saves a lot of GPU memory, but output quality may degrade
@ -364,6 +370,9 @@ def main(arg1):
if __name__ == '__main__':
    if online:
        # Online mode relies on main()'s default argument.
        main()
    else:
        import sys

        # Offline mode forwards the first command-line argument when one is
        # given; otherwise pass None rather than raising IndexError.
        arg1 = sys.argv[1] if len(sys.argv) > 1 else None
        main(arg1)