update webdemo llama3

parent ac18474b1d
commit f6dec55c06

2 changed files

app.py
@@ -11,6 +11,6 @@ elif model == "EmoLLM_Model":
     os.system("python download_model.py jujimeizuo/EmoLLM_Model")
     os.system('streamlit run web_internlm2.py --server.address=0.0.0.0 --server.port 7860')
 elif model == "Llama3_Model":
-    os.system('streamlit run web_demo-Llama3.py --server.address=0.0.0.0 --server.port 7968')
+    os.system('streamlit run web_demo_Llama3.py --server.address=0.0.0.0 --server.port 7968')
 else:
     print("Please select one model")
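For context, a minimal sketch of the model-selection launcher this app.py hunk belongs to. The EmoLLM_Model and Llama3_Model branches are taken from the diff above; the leading if branch and the way model is set are assumptions, since those lines fall outside the hunk.

import os

model = "Llama3_Model"  # assumption: in app.py this value comes from the deployment's model selector

if model == "EmoLLM_Model":  # assumption: the diff shows this as an elif after an earlier branch
    os.system("python download_model.py jujimeizuo/EmoLLM_Model")
    os.system('streamlit run web_internlm2.py --server.address=0.0.0.0 --server.port 7860')
elif model == "Llama3_Model":
    # Launch the renamed Llama3 demo script on port 7968
    os.system('streamlit run web_demo_Llama3.py --server.address=0.0.0.0 --server.port 7968')
else:
    print("Please select one model")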
@@ -13,7 +13,7 @@ from transformers.utils import logging
 from transformers import AutoTokenizer, AutoModelForCausalLM  # isort: skip
 
 
-warnings.filterwarnings("ignore")
+# warnings.filterwarnings("ignore")
 logger = logging.get_logger(__name__)
 
 
@@ -151,7 +151,7 @@ def on_btn_click():
 @st.cache_resource
 def load_model():
 
-    model_name = "./EmoLLM-Llama3-8B-Instruct2.0"
+    model_name = "./EmoLLM-Llama3-8B-Instruct3.0"
     print(model_name)
 
     print('pip install modelscope websockets')
@@ -159,7 +159,7 @@ def load_model():
     from modelscope import snapshot_download
 
     # Download the model
-    model_name = snapshot_download('chg0901/EmoLLM-Llama3-8B-Instruct2.0', cache_dir=model_name)
+    model_name = snapshot_download('chg0901/EmoLLM-Llama3-8B-Instruct3.0', cache_dir=model_name)
     print(model_name)
 
     model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", trust_remote_code=True, torch_dtype=torch.float16).eval()
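Putting this hunk's lines together, a rough sketch of the updated load_model() flow: resolve the 3.0 weights from ModelScope into the local cache directory, then load them in fp16 with transformers. The tokenizer line and the return value are assumptions, since they are not part of this diff.

import torch
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM

@st.cache_resource
def load_model():
    model_name = "./EmoLLM-Llama3-8B-Instruct3.0"
    from modelscope import snapshot_download
    # Resolve the ModelScope repo to a local snapshot path (downloads on first run)
    model_name = snapshot_download('chg0901/EmoLLM-Llama3-8B-Instruct3.0', cache_dir=model_name)
    # Load the fp16 weights, letting accelerate place them across available devices
    model = AutoModelForCausalLM.from_pretrained(
        model_name, device_map="auto", trust_remote_code=True, torch_dtype=torch.float16
    ).eval()
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)  # assumption: not shown in this diff
    return model, tokenizer  # assumption: return shape matches the demo's caller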