# olivebot/llm/nlp_rwkv.py

import torch
from ringrwkv.configuration_rwkv_world import RwkvConfig
from ringrwkv.rwkv_tokenizer import TRIE_TOKENIZER
from ringrwkv.modehf_world import RwkvForCausalLM
# Checkpoint name handed to from_pretrained. Other checkpoint names noted
# in this file as alternatives: "RWKV-4-World-3B", "RWKV-4-World-0.4B".
_MODEL_NAME = "RWKV-4-World-1.5B"

# Vocabulary file for the tokenizer. The path is relative, so it is
# presumably resolved against the process working directory -- confirm.
_VOCAB_PATH = './ringrwkv/rwkv_vocab_v20230424.txt'

# Both objects are created once at import time and shared by question().
model = RwkvForCausalLM.from_pretrained(_MODEL_NAME)
tokenizer = TRIE_TOKENIZER(_VOCAB_PATH)

# Text prepended to every prompt that question() builds; empty by default.
data = ""
def question(cont: str, uid=0, observation="") -> str:
    """Generate a short answer to *cont* with the module-level RWKV model.

    The question is wrapped in a ``Question: ...\\n\\nAnswer:`` template,
    prefixed with the module-level ``data`` string, tokenized with the
    module-level ``tokenizer`` and passed to ``model.generate`` for at most
    20 new tokens. ``data`` is only read here, never updated.

    Args:
        cont: The question text; leading/trailing whitespace is stripped.
        uid: Not used by this function.
        observation: Not used by this function.

    Returns:
        The decoded model output with the first occurrence of the prompt
        text removed.
    """
    prompt = data + f'Question: {cont.strip()}\n\nAnswer:'
    # Add a leading batch dimension of size 1 before calling generate.
    input_ids = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0)
    out = model.generate(input_ids, max_new_tokens=20)

    # Drop every token id 0 (presumably padding / end-of-text -- confirm).
    # Build a new list instead of calling list.remove() while iterating:
    # removing during iteration makes the iterator skip the next element,
    # so runs of consecutive zeros were only partially removed.
    token_ids = [token for token in out[0].tolist() if token != 0]

    answer = tokenizer.decode(token_ids)
    # Strip the echoed prompt so only the generated continuation is returned.
    return answer.replace(prompt, "", 1)