update llama3 webdemo
This commit is contained in:
parent
f314baf75b
commit
5c4d4de9d7
6
app.py
6
app.py
@ -1,7 +1,8 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
#model = "EmoLLM_aiwei"
|
#model = "EmoLLM_aiwei"
|
||||||
model = "EmoLLM_Model"
|
# model = "EmoLLM_Model"
|
||||||
|
model = "Llama3_Model"
|
||||||
|
|
||||||
if model == "EmoLLM_aiwei":
|
if model == "EmoLLM_aiwei":
|
||||||
os.system("python download_model.py ajupyter/EmoLLM_aiwei")
|
os.system("python download_model.py ajupyter/EmoLLM_aiwei")
|
||||||
@ -9,5 +10,8 @@ if model == "EmoLLM_aiwei":
|
|||||||
elif model == "EmoLLM_Model":
|
elif model == "EmoLLM_Model":
|
||||||
os.system("python download_model.py jujimeizuo/EmoLLM_Model")
|
os.system("python download_model.py jujimeizuo/EmoLLM_Model")
|
||||||
os.system('streamlit run web_internlm2.py --server.address=0.0.0.0 --server.port 7860')
|
os.system('streamlit run web_internlm2.py --server.address=0.0.0.0 --server.port 7860')
|
||||||
|
elif model == "Llama3_Model":
|
||||||
|
os.system("python download_model.py chg0901/EmoLLM-Llama3-8B-Instruct2.0")
|
||||||
|
os.system('streamlit run web_demo-Llama3_online.py --server.address=0.0.0.0 --server.port 7860')
|
||||||
else:
|
else:
|
||||||
print("Please select one model")
|
print("Please select one model")
|
@ -1,6 +1,7 @@
|
|||||||
|
|
||||||
# isort: skip_file
|
# isort: skip_file
|
||||||
import copy
|
import copy
|
||||||
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
from dataclasses import asdict, dataclass
|
from dataclasses import asdict, dataclass
|
||||||
from typing import Callable, List, Optional
|
from typing import Callable, List, Optional
|
||||||
@ -15,17 +16,22 @@ from transformers.utils import logging
|
|||||||
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig # isort: skip
|
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig # isort: skip
|
||||||
from peft import PeftModel
|
from peft import PeftModel
|
||||||
|
|
||||||
import warnings
|
|
||||||
|
|
||||||
warnings.filterwarnings("ignore")
|
warnings.filterwarnings("ignore")
|
||||||
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
if not os.path.isdir("model"):
|
||||||
|
print("[ERROR] not find model dir")
|
||||||
|
exit(0)
|
||||||
|
|
||||||
online = True
|
online = True
|
||||||
if online:
|
|
||||||
from openxlab.model import download
|
## running on local to test online function
|
||||||
download(model_repo='chg0901/EmoLLM-Llama3-8B-Instruct2.0',
|
# if online:
|
||||||
output='model')
|
# from openxlab.model import download
|
||||||
|
# download(model_repo='chg0901/EmoLLM-Llama3-8B-Instruct2.0',
|
||||||
|
# output='model')
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class GenerationConfig:
|
class GenerationConfig:
|
||||||
@ -275,17 +281,17 @@ def combine_history(prompt):
|
|||||||
return total_prompt
|
return total_prompt
|
||||||
|
|
||||||
|
|
||||||
def main(arg1):
|
def main(arg1=None):
|
||||||
|
|
||||||
|
|
||||||
if online:
|
if online:
|
||||||
model_name_or_path = 'model'
|
model_name_or_path = 'model'
|
||||||
adapter_name_or_path = None
|
adapter_name_or_path = None
|
||||||
else:
|
else:
|
||||||
# model_name_or_path = "/root/StableCascade/emollm2/EmoLLM/xtuner_config/merged_Llama3_8b_instruct_e3"
|
# model_name_or_path = "./xtuner_config/merged_Llama3_8b_instruct_e3"
|
||||||
# adapter_name_or_path = '/root/StableCascade/emollm2/EmoLLM/xtuner_config/hf_llama3_e1_sc2'
|
# adapter_name_or_path = './xtuner_config/hf_llama3_e1_sc2'
|
||||||
|
|
||||||
model_name_or_path = "/root/StableCascade/emollm2/EmoLLM/xtuner_config/merged_Llama3_8b_instruct_e1_sc"
|
model_name_or_path = "./xtuner_config/merged_Llama3_8b_instruct_e1_sc"
|
||||||
adapter_name_or_path = None
|
adapter_name_or_path = None
|
||||||
|
|
||||||
# 若开启4bit推理能够节省很多显存,但效果可能下降
|
# 若开启4bit推理能够节省很多显存,但效果可能下降
|
||||||
@ -364,6 +370,9 @@ def main(arg1):
|
|||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
||||||
|
if online:
|
||||||
|
main()
|
||||||
|
else:
|
||||||
import sys
|
import sys
|
||||||
arg1 = sys.argv[1]
|
arg1 = sys.argv[1]
|
||||||
main(arg1)
|
main(arg1)
|
||||||
|
Loading…
Reference in New Issue
Block a user