diff --git a/web_demo-Llama3.py b/web_demo-Llama3.py
index 06c9d7e..216eb0d 100644
--- a/web_demo-Llama3.py
+++ b/web_demo-Llama3.py
@@ -188,8 +188,20 @@ def on_btn_click():
 #     return model, tokenizer
 
 
+
 @st.cache_resource
-def load_model(model_name_or_path, load_in_4bit=False, adapter_name_or_path=None):
+def load_model():
+    model = AutoModelForCausalLM.from_pretrained("model",
+                                                 device_map="auto",
+                                                 trust_remote_code=True,
+                                                 torch_dtype=torch.float16)
+    model = model.eval()
+    tokenizer = AutoTokenizer.from_pretrained("model", trust_remote_code=True)
+    return model, tokenizer
+
+
+@st.cache_resource
+def load_model0(model_name_or_path, load_in_4bit=False, adapter_name_or_path=None):
     if load_in_4bit:
         quantization_config = BitsAndBytesConfig(
             load_in_4bit=True,
@@ -305,12 +317,15 @@ def main():
     print(f'Loading model from: {model_name_or_path}')
     print(f'adapter_name_or_path: {adapter_name_or_path}')
     # model, tokenizer = load_model(arg1)
-    model, tokenizer = load_model(
-        # arg1 if arg1 is not None else model_name_or_path,
-        model_name_or_path,
-        load_in_4bit=load_in_4bit,
-        adapter_name_or_path=adapter_name_or_path
-    )
+
+    model, tokenizer = load_model()
+
+    # model, tokenizer = load_model(
+    #     # arg1 if arg1 is not None else model_name_or_path,
+    #     model_name_or_path,
+    #     load_in_4bit=load_in_4bit,
+    #     adapter_name_or_path=adapter_name_or_path
+    # )
     model.eval()
     print('load model end.')
 
@@ -343,10 +358,6 @@ def main():
             'avatar': user_avator
         })
 
-
-        # stop_token_id = tokenizer.encode('<|eot_id|>', add_special_tokens=True)
-        # assert len(stop_token_id) == 1
-        # stop_token_id = stop_token_id[0]
 
         with st.chat_message('robot', avatar=robot_avator):
             message_placeholder = st.empty()
@@ -373,4 +384,3 @@ def main():
 
 if __name__ == "__main__":
     main()
-
\ No newline at end of file
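
The change above swaps the parameterized loader for a zero-argument `load_model()` that reads a fixed local `model` directory, keeping the old implementation around as `load_model0`. The pattern rests on Streamlit's `st.cache_resource`, which memoizes the returned model/tokenizer pair so the weights are loaded once per server process rather than on every script rerun. Below is a minimal standalone sketch of that caching pattern; the `MODEL_DIR` constant is illustrative (the diff hard-codes the string "model" directly), and the sketch is not the demo's full implementation:

    import streamlit as st
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    MODEL_DIR = "model"  # illustrative constant; the diff passes "model" inline

    @st.cache_resource  # executed once per process; later reruns reuse the cached objects
    def load_model():
        # Load the causal LM in fp16; device_map="auto" lets accelerate place
        # the weights across the available devices.
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_DIR,
            device_map="auto",
            trust_remote_code=True,
            torch_dtype=torch.float16,
        ).eval()
        tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, trust_remote_code=True)
        return model, tokenizer

    model, tokenizer = load_model()
    st.write(f"Loaded model with {model.num_parameters():,} parameters")

Because `st.cache_resource` keys the cache on the function's arguments, the zero-argument signature means one global cached instance; the trade-off versus the old `load_model0` is that the checkpoint path, 4-bit quantization, and adapter options can no longer be varied without editing the code.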