update
parent 452d3299f3
commit 9f0ea20d43
@@ -188,8 +188,20 @@ def on_btn_click():
 
 # return model, tokenizer
 
 @st.cache_resource
-def load_model(model_name_or_path, load_in_4bit=False, adapter_name_or_path=None):
+def load_model():
+    model = AutoModelForCausalLM.from_pretrained("model",
+                                                 device_map="auto",
+                                                 trust_remote_code=True,
+                                                 torch_dtype=torch.float16)
+    model = model.eval()
+    tokenizer = AutoTokenizer.from_pretrained("model", trust_remote_code=True)
+    return model, tokenizer
+
+
+@st.cache_resource
+def load_model0(model_name_or_path, load_in_4bit=False, adapter_name_or_path=None):
     if load_in_4bit:
         quantization_config = BitsAndBytesConfig(
             load_in_4bit=True,
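The original loader is kept alongside the new one as load_model0, and the hunk above ends partway through its 4-bit branch. As a point of reference only, a BitsAndBytesConfig of this kind is usually completed and passed to from_pretrained roughly as sketched below; the specific bnb_4bit_* arguments are an assumption, not something shown in this commit.

# Illustrative continuation of the truncated 4-bit branch; the bnb_4bit_*
# values below are assumptions, not taken from this file.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

def load_model_4bit(model_name_or_path):
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type='nf4',
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        device_map='auto',
        trust_remote_code=True,
        quantization_config=quantization_config,
    )
    return model.eval()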
@@ -305,12 +317,15 @@ def main():
     print(f'Loading model from: {model_name_or_path}')
     print(f'adapter_name_or_path: {adapter_name_or_path}')
     # model, tokenizer = load_model(arg1)
-    model, tokenizer = load_model(
-        # arg1 if arg1 is not None else model_name_or_path,
-        model_name_or_path,
-        load_in_4bit=load_in_4bit,
-        adapter_name_or_path=adapter_name_or_path
-    )
+
+    model, tokenizer = load_model()
+
+    # model, tokenizer = load_model(
+    #     # arg1 if arg1 is not None else model_name_or_path,
+    #     model_name_or_path,
+    #     load_in_4bit=load_in_4bit,
+    #     adapter_name_or_path=adapter_name_or_path
+    # )
     model.eval()
     print('load model end.')
 
@@ -344,10 +359,6 @@ def main():
         })
 
 
-        # stop_token_id = tokenizer.encode('<|eot_id|>', add_special_tokens=True)
-        # assert len(stop_token_id) == 1
-        # stop_token_id = stop_token_id[0]
-
         with st.chat_message('robot', avatar=robot_avator):
             message_placeholder = st.empty()
             for cur_response in generate_interactive(
@@ -373,4 +384,3 @@ def main():
 
 if __name__ == "__main__":
     main()
-
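Taken together, the hunks above replace the parameterised loader with a cached, zero-argument load_model() that reads a checkpoint from a local "model" directory. A minimal sketch of how the pieces fit after this commit, with the imports (which the diff does not show) assumed to be the usual torch/streamlit/transformers ones:

# Sketch only: the imports and the local "model" directory are assumptions;
# the diff shows only the function body and the call site in main().
import torch
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer


@st.cache_resource  # cached per process, so Streamlit reruns reuse the loaded weights
def load_model():
    model = AutoModelForCausalLM.from_pretrained("model",
                                                 device_map="auto",
                                                 trust_remote_code=True,
                                                 torch_dtype=torch.float16)
    tokenizer = AutoTokenizer.from_pretrained("model", trust_remote_code=True)
    return model.eval(), tokenizer


def main():
    model, tokenizer = load_model()  # replaces the old load_model(model_name_or_path, ...) call
    ...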