update llama3 webdemo

This commit is contained in:
parent f314baf75b
commit 5c4d4de9d7

app.py (6 changed lines)
@@ -1,7 +1,8 @@
import os

#model = "EmoLLM_aiwei"
model = "EmoLLM_Model"
# model = "EmoLLM_Model"
model = "Llama3_Model"

if model == "EmoLLM_aiwei":
    os.system("python download_model.py ajupyter/EmoLLM_aiwei")
@@ -9,5 +10,8 @@ if model == "EmoLLM_aiwei":
elif model == "EmoLLM_Model":
    os.system("python download_model.py jujimeizuo/EmoLLM_Model")
    os.system('streamlit run web_internlm2.py --server.address=0.0.0.0 --server.port 7860')
elif model == "Llama3_Model":
    os.system("python download_model.py chg0901/EmoLLM-Llama3-8B-Instruct2.0")
    os.system('streamlit run web_demo-Llama3_online.py --server.address=0.0.0.0 --server.port 7860')
else:
    print("Please select one model")
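In the hunks above, where two variants of a line appear back to back (model = "EmoLLM_Model" followed by its commented-out replacement), read the first as the removed line and the second as the added one: the active model switches to "Llama3_Model", and a new branch downloads chg0901/EmoLLM-Llama3-8B-Instruct2.0 and launches web_demo-Llama3_online.py. A hedged reconstruction of app.py after this commit, reassembled from the hunks (indentation and blank lines are inferred; the launch line of the EmoLLM_aiwei branch falls outside the shown diff context):

    import os

    # model = "EmoLLM_aiwei"
    # model = "EmoLLM_Model"
    model = "Llama3_Model"

    if model == "EmoLLM_aiwei":
        os.system("python download_model.py ajupyter/EmoLLM_aiwei")
        # ... (launch command for this branch is outside the diff context)
    elif model == "EmoLLM_Model":
        os.system("python download_model.py jujimeizuo/EmoLLM_Model")
        os.system('streamlit run web_internlm2.py --server.address=0.0.0.0 --server.port 7860')
    elif model == "Llama3_Model":
        os.system("python download_model.py chg0901/EmoLLM-Llama3-8B-Instruct2.0")
        os.system('streamlit run web_demo-Llama3_online.py --server.address=0.0.0.0 --server.port 7860')
    else:
        print("Please select one model")

The second set of hunks below belongs to a larger script whose name is not shown in this capture; its content (Streamlit demo, GenerationConfig, combine_history) matches the web_demo-Llama3_online.py that app.py now launches.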
@@ -1,6 +1,7 @@

# isort: skip_file
import copy
import os
import warnings
from dataclasses import asdict, dataclass
from typing import Callable, List, Optional
@@ -15,17 +16,22 @@ from transformers.utils import logging
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig  # isort: skip
from peft import PeftModel

import warnings

warnings.filterwarnings("ignore")
warnings.filterwarnings("ignore", category=DeprecationWarning)
logger = logging.get_logger(__name__)

if not os.path.isdir("model"):
    print("[ERROR] not find model dir")
    exit(0)

online = True
if online:
    from openxlab.model import download
    download(model_repo='chg0901/EmoLLM-Llama3-8B-Instruct2.0',
             output='model')

## running on local to test online function
# if online:
#     from openxlab.model import download
#     download(model_repo='chg0901/EmoLLM-Llama3-8B-Instruct2.0',
#              output='model')

@dataclass
class GenerationConfig:
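One thing to note about the online gate above: Streamlit re-executes the whole script on every user interaction, so this module-level download call is reached on each rerun. A hedged sketch of one way to guard it (the ensure_model helper is hypothetical, not part of this commit; the repo ID and output directory are taken from the diff, and openxlab may or may not already skip existing files on its own):

    import os

    MODEL_DIR = "model"
    MODEL_REPO = "chg0901/EmoLLM-Llama3-8B-Instruct2.0"

    def ensure_model(model_dir: str = MODEL_DIR, repo: str = MODEL_REPO) -> None:
        # Only download when the weights are not already on disk, so Streamlit
        # reruns (which re-execute the whole script) skip the network call.
        if not os.path.isdir(model_dir) or not os.listdir(model_dir):
            from openxlab.model import download
            download(model_repo=repo, output=model_dir)

    ensure_model()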
@@ -275,17 +281,17 @@ def combine_history(prompt)
    return total_prompt


def main(arg1):
def main(arg1=None):


    if online:
        model_name_or_path = 'model'
        adapter_name_or_path = None
    else:
        # model_name_or_path = "/root/StableCascade/emollm2/EmoLLM/xtuner_config/merged_Llama3_8b_instruct_e3"
        # adapter_name_or_path = '/root/StableCascade/emollm2/EmoLLM/xtuner_config/hf_llama3_e1_sc2'
        # model_name_or_path = "./xtuner_config/merged_Llama3_8b_instruct_e3"
        # adapter_name_or_path = './xtuner_config/hf_llama3_e1_sc2'

        model_name_or_path = "/root/StableCascade/emollm2/EmoLLM/xtuner_config/merged_Llama3_8b_instruct_e1_sc"
        model_name_or_path = "./xtuner_config/merged_Llama3_8b_instruct_e1_sc"
        adapter_name_or_path = None

    # Enabling 4-bit inference saves a lot of GPU memory, but output quality may degrade
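The translated comment above refers to optional 4-bit loading, and the file already imports BitsAndBytesConfig and PeftModel for exactly this. A hedged sketch of how those pieces are typically combined; the load_model helper and its argument values are illustrative assumptions, not the code in this script:

    import torch
    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

    def load_model(model_name_or_path, adapter_name_or_path=None, load_in_4bit=False):
        # 4-bit NF4 quantization: large GPU memory savings, possible quality drop
        quantization_config = (
            BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
            )
            if load_in_4bit
            else None
        )

        model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            torch_dtype=torch.float16,
            device_map="auto",
            quantization_config=quantization_config,
        )
        if adapter_name_or_path is not None:
            # stack a LoRA adapter on top of the base weights
            model = PeftModel.from_pretrained(model, adapter_name_or_path)
        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        return model, tokenizer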
@@ -364,6 +370,9 @@ def main(arg1)

if __name__ == '__main__':

    if online:
        main()
    else:
        import sys
        arg1 = sys.argv[1]
        main(arg1)
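Since main now defaults arg1 to None, the offline branch could tolerate a missing path argument as well instead of raising IndexError; a small hedged variant of the new entry point (not part of this commit):

    if __name__ == '__main__':
        if online:
            main()
        else:
            import sys
            # fall back to main()'s new default when no checkpoint path is given
            arg1 = sys.argv[1] if len(sys.argv) > 1 else None
            main(arg1)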