Update web_demo-Llama3.py

This commit is contained in:
HongCheng 2024-05-04 10:02:10 +09:00
parent 408784289d
commit d0aaf31ff6

View File

@ -151,20 +151,29 @@ def on_btn_click():
@st.cache_resource @st.cache_resource
def load_model(): def load_model():
model_name0 = "./EmoLLM-Llama3-8B-Instruct3.0" # model_name0 = "./EmoLLM-Llama3-8B-Instruct3.0"
print(model_name0) # print(model_name0)
print('pip install modelscope websockets') # print('pip install modelscope websockets')
os.system(f'pip install modelscope websockets==11.0.3') # os.system(f'pip install modelscope websockets==11.0.3')
from modelscope import snapshot_download # from modelscope import snapshot_download
#模型下载 # #模型下载
model_name = snapshot_download('chg0901/EmoLLM-Llama3-8B-Instruct3.0',cache_dir=model_name0) # model_name = snapshot_download('chg0901/EmoLLM-Llama3-8B-Instruct3.0',cache_dir=model_name0)
print(model_name) # print(model_name)
# model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", trust_remote_code=True, torch_dtype=torch.float16).eval()
# # model.eval()
# tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
base_path = './EmoLLM-Llama3-8B-Instruct3.0'
os.system(f'git clone https://code.openxlab.org.cn/chg0901/EmoLLM-Llama3-8B-Instruct3.0.git {base_path}')
os.system(f'cd {base_path} && git lfs pull')
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", trust_remote_code=True, torch_dtype=torch.float16).eval() model = AutoModelForCausalLM.from_pretrained(base_path, device_map="auto", trust_remote_code=True, torch_dtype=torch.float16).eval()
# model.eval() # model.eval()
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained(base_path, trust_remote_code=True)
if tokenizer.pad_token is None: if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token = tokenizer.eos_token