diff --git a/app.py b/app.py
index bebcaf3..dd75608 100644
--- a/app.py
+++ b/app.py
@@ -1,7 +1,8 @@
 import os
 
 #model = "EmoLLM_aiwei"
-model = "EmoLLM_Model"
+# model = "EmoLLM_Model"
+model = "Llama3_Model"
 
 if model == "EmoLLM_aiwei":
     os.system("python download_model.py ajupyter/EmoLLM_aiwei")
@@ -9,5 +10,8 @@ if model == "EmoLLM_aiwei":
 elif model == "EmoLLM_Model":
     os.system("python download_model.py jujimeizuo/EmoLLM_Model")
     os.system('streamlit run web_internlm2.py --server.address=0.0.0.0 --server.port 7860')
+elif model == "Llama3_Model":
+    os.system("python download_model.py chg0901/EmoLLM-Llama3-8B-Instruct2.0")
+    os.system('streamlit run web_demo-Llama3_online.py --server.address=0.0.0.0 --server.port 7860')
 else:
     print("Please select a model")
\ No newline at end of file
diff --git a/web_demo-Llama3_online.py b/web_demo-Llama3_online.py
index e80f454..c2c41c2 100644
--- a/web_demo-Llama3_online.py
+++ b/web_demo-Llama3_online.py
@@ -1,6 +1,7 @@
 # isort: skip_file
 import copy
+import os
 import warnings
 from dataclasses import asdict, dataclass
 from typing import Callable, List, Optional
 
@@ -15,17 +16,22 @@
 from transformers.utils import logging
 from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig  # isort: skip
 from peft import PeftModel
-import warnings
 warnings.filterwarnings("ignore")
 warnings.filterwarnings("ignore", category=DeprecationWarning)
 
 logger = logging.get_logger(__name__)
 
+if not os.path.isdir("model"):
+    print("[ERROR] model dir not found")
+    exit(1)
+
 online = True
-if online:
-    from openxlab.model import download
-    download(model_repo='chg0901/EmoLLM-Llama3-8B-Instruct2.0',
-             output='model')
+
+## commented out when running locally to test the online function
+# if online:
+#     from openxlab.model import download
+#     download(model_repo='chg0901/EmoLLM-Llama3-8B-Instruct2.0',
+#              output='model')
 
 @dataclass
 class GenerationConfig:
@@ -275,17 +281,17 @@
     return total_prompt
 
 
-def main(arg1):
+def main(arg1=None):
     if online:
         model_name_or_path = 'model'
         adapter_name_or_path = None
     else:
-        # model_name_or_path = "/root/StableCascade/emollm2/EmoLLM/xtuner_config/merged_Llama3_8b_instruct_e3"
-        # adapter_name_or_path = '/root/StableCascade/emollm2/EmoLLM/xtuner_config/hf_llama3_e1_sc2'
+        # model_name_or_path = "./xtuner_config/merged_Llama3_8b_instruct_e3"
+        # adapter_name_or_path = './xtuner_config/hf_llama3_e1_sc2'
 
-        model_name_or_path = "/root/StableCascade/emollm2/EmoLLM/xtuner_config/merged_Llama3_8b_instruct_e1_sc"
+        model_name_or_path = "./xtuner_config/merged_Llama3_8b_instruct_e1_sc"
         adapter_name_or_path = None
 
     # Enabling 4-bit inference saves a lot of GPU memory, but quality may degrade
@@ -364,6 +370,9 @@
 
 
 if __name__ == '__main__':
-    import sys
-    arg1 = sys.argv[1]
-    main(arg1)
+    if online:
+        main()
+    else:
+        import sys
+        arg1 = sys.argv[1]
+        main(arg1)
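
Note on the new flow: with the in-demo download commented out, web_demo-Llama3_online.py now expects ./model to already exist and exits if it does not, so the download has to happen in app.py (or some other prior step). A minimal sketch of that intended sequence, reusing only the repo id, output directory, demo filename, and port that appear in the diff above:

    # Sketch: fetch the merged Llama3 weights once, then launch the demo.
    # All names here (repo id, output dir, script name, port) are taken
    # from the diff above; nothing else is assumed.
    import os

    if not os.path.isdir("model"):
        # One-time download into ./model, i.e. the openxlab call that the
        # diff comments out inside the demo itself.
        from openxlab.model import download
        download(model_repo='chg0901/EmoLLM-Llama3-8B-Instruct2.0', output='model')

    # The demo's own guard (added in this diff) exits when ./model is
    # missing, so it must be launched only after the download step.
    os.system('streamlit run web_demo-Llama3_online.py --server.address=0.0.0.0 --server.port 7860')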