update llama3 webdemo

HongCheng 2024-04-22 00:05:11 +09:00
parent f314baf75b
commit 5c4d4de9d7
2 changed files with 26 additions and 13 deletions

app.py (6 changed lines)

@@ -1,7 +1,8 @@
 import os
 #model = "EmoLLM_aiwei"
-model = "EmoLLM_Model"
+# model = "EmoLLM_Model"
+model = "Llama3_Model"

 if model == "EmoLLM_aiwei":
     os.system("python download_model.py ajupyter/EmoLLM_aiwei")
@@ -9,5 +10,8 @@ if model == "EmoLLM_aiwei":
 elif model == "EmoLLM_Model":
     os.system("python download_model.py jujimeizuo/EmoLLM_Model")
     os.system('streamlit run web_internlm2.py --server.address=0.0.0.0 --server.port 7860')
+elif model == "Llama3_Model":
+    os.system("python download_model.py chg0901/EmoLLM-Llama3-8B-Instruct2.0")
+    os.system('streamlit run web_demo-Llama3_online.py --server.address=0.0.0.0 --server.port 7860')
 else:
     print("Please select one model")

web_demo-Llama3_online.py
@@ -1,6 +1,7 @@
 # isort: skip_file
 import copy
+import os
 import warnings
 from dataclasses import asdict, dataclass
 from typing import Callable, List, Optional
@@ -15,17 +16,22 @@ from transformers.utils import logging
 from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig  # isort: skip
 from peft import PeftModel
-import warnings

 warnings.filterwarnings("ignore")
 warnings.filterwarnings("ignore", category=DeprecationWarning)

 logger = logging.get_logger(__name__)

+if not os.path.isdir("model"):
+    print("[ERROR] model dir not found")
+    exit(0)
+
 online = True
-if online:
-    from openxlab.model import download
-    download(model_repo='chg0901/EmoLLM-Llama3-8B-Instruct2.0',
-             output='model')
+
+## running locally to test the online function
+# if online:
+#     from openxlab.model import download
+#     download(model_repo='chg0901/EmoLLM-Llama3-8B-Instruct2.0',
+#              output='model')

 @dataclass
 class GenerationConfig:
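With the in-file download commented out, the demo now hard-exits when ./model is missing, so app.py must run download_model.py first. An alternative (an assumption, not what the commit does) would be to download on demand instead of aborting:

# Sketch of an on-demand fallback (not in the commit): populate ./model
# only when the directory is missing, instead of exiting.
import os

if not os.path.isdir("model"):
    from openxlab.model import download
    download(model_repo='chg0901/EmoLLM-Llama3-8B-Instruct2.0', output='model')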
@@ -275,17 +281,17 @@ def combine_history(prompt):
     return total_prompt


-def main(arg1):
+def main(arg1=None):
     if online:
         model_name_or_path = 'model'
         adapter_name_or_path = None
     else:
-        # model_name_or_path = "/root/StableCascade/emollm2/EmoLLM/xtuner_config/merged_Llama3_8b_instruct_e3"
-        # adapter_name_or_path = '/root/StableCascade/emollm2/EmoLLM/xtuner_config/hf_llama3_e1_sc2'
-        model_name_or_path = "/root/StableCascade/emollm2/EmoLLM/xtuner_config/merged_Llama3_8b_instruct_e1_sc"
+        # model_name_or_path = "./xtuner_config/merged_Llama3_8b_instruct_e3"
+        # adapter_name_or_path = './xtuner_config/hf_llama3_e1_sc2'
+        model_name_or_path = "./xtuner_config/merged_Llama3_8b_instruct_e1_sc"
         adapter_name_or_path = None

     # Enabling 4-bit inference saves a lot of GPU memory, but output quality may degrade
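The 4-bit comment refers to loading the model through bitsandbytes quantization, which is why BitsAndBytesConfig is imported at the top of the file. A minimal sketch of such a load, assuming standard transformers usage rather than this repo's exact settings:

# Minimal 4-bit loading sketch (assumed settings; the repo's actual
# loading code may differ).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

quant_config = BitsAndBytesConfig(
    load_in_4bit=True,                     # quantize weights to 4 bit at load time
    bnb_4bit_quant_type='nf4',             # NormalFloat4 quantization
    bnb_4bit_compute_dtype=torch.float16,  # run matmuls in fp16
)
model = AutoModelForCausalLM.from_pretrained(
    'model',                               # the downloaded model dir
    quantization_config=quant_config,
    device_map='auto',
)
tokenizer = AutoTokenizer.from_pretrained('model')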
@@ -364,6 +370,9 @@ def main(arg1):
 if __name__ == '__main__':
-    import sys
-    arg1 = sys.argv[1]
-    main(arg1)
+    if online:
+        main()
+    else:
+        import sys
+        arg1 = sys.argv[1]
+        main(arg1)
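`adapter_name_or_path` ends up None in both branches of this commit, but the PeftModel import at the top of the file implies that, when set, the adapter is applied on top of the base model. A sketch of that usage (standard peft API, not code from this commit):

# How a LoRA adapter would be applied when adapter_name_or_path is set
# (hypothetical here; both branches of main() leave it as None).
from peft import PeftModel

if adapter_name_or_path is not None:
    model = PeftModel.from_pretrained(model, adapter_name_or_path)

Locally the demo still takes one positional argument (`arg1 = sys.argv[1]`), while the online path now calls `main()` with no argument, which is what the new `def main(arg1=None)` signature accommodates.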