This commit is contained in:
HongCheng 2024-04-22 17:34:43 +09:00
parent 452d3299f3
commit 9f0ea20d43

View File

@@ -188,8 +188,20 @@ def on_btn_click():
# return model, tokenizer
@st.cache_resource
def load_model(model_name_or_path, load_in_4bit=False, adapter_name_or_path=None):
def load_model(model_dir="model"):
    """Load a causal language model and its tokenizer.

    Cached via ``st.cache_resource`` (decorator above), so the expensive
    checkpoint load runs once per Streamlit server process.

    Args:
        model_dir: Local directory (or hub id) holding the checkpoint.
            Defaults to the previously hard-coded ``"model"`` directory,
            so existing ``load_model()`` callers are unaffected.

    Returns:
        ``(model, tokenizer)`` — the model in eval mode, sharded by
        ``device_map="auto"``, with float16 weights.
    """
    # NOTE(review): trust_remote_code executes code shipped with the
    # checkpoint — only use with checkpoints from a trusted source.
    model = AutoModelForCausalLM.from_pretrained(
        model_dir,
        device_map="auto",
        trust_remote_code=True,
        torch_dtype=torch.float16,
    )
    model = model.eval()  # inference only: disable dropout etc.
    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
    return model, tokenizer
@st.cache_resource
def load_model0(model_name_or_path, load_in_4bit=False, adapter_name_or_path=None):
if load_in_4bit:
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
@@ -305,12 +317,15 @@ def main():
print(f'Loading model from: {model_name_or_path}')
print(f'adapter_name_or_path: {adapter_name_or_path}')
# model, tokenizer = load_model(arg1)
model, tokenizer = load_model(
# arg1 if arg1 is not None else model_name_or_path,
model_name_or_path,
load_in_4bit=load_in_4bit,
adapter_name_or_path=adapter_name_or_path
)
model, tokenizer = load_model()
# model, tokenizer = load_model(
# # arg1 if arg1 is not None else model_name_or_path,
# model_name_or_path,
# load_in_4bit=load_in_4bit,
# adapter_name_or_path=adapter_name_or_path
# )
model.eval()
print('load model end.')
@@ -344,10 +359,6 @@ def main():
})
# stop_token_id = tokenizer.encode('<|eot_id|>', add_special_tokens=True)
# assert len(stop_token_id) == 1
# stop_token_id = stop_token_id[0]
with st.chat_message('robot', avatar=robot_avator):
message_placeholder = st.empty()
for cur_response in generate_interactive(
@@ -373,4 +384,3 @@ def main():
# Script entry point: run the Streamlit chat application.
if __name__ == "__main__":
    main()