[Code] update rag (#122)

xzw 2024-03-22 10:06:12 +08:00, committed by GitHub
commit ad7329d113
4 changed files with 100 additions and 40 deletions

View File

@@ -33,5 +33,5 @@ prompt_template = """
 {system_prompt}
 根据下面检索回来的信息回答问题
 {content}
-问题{question}
+问题{query}
 """

View File

@@ -12,7 +12,7 @@ from langchain.embeddings import HuggingFaceBgeEmbeddings
 from langchain_community.document_loaders import DirectoryLoader, TextLoader, JSONLoader
 from langchain_text_splitters import CharacterTextSplitter, RecursiveCharacterTextSplitter, RecursiveJsonSplitter
 from BCEmbedding import EmbeddingModel, RerankerModel
-from util.pipeline import EmoLLMRAG
+# from util.pipeline import EmoLLMRAG
 from transformers import AutoTokenizer, AutoModelForCausalLM
 from langchain.document_loaders.pdf import PyPDFDirectoryLoader
 from langchain.document_loaders import UnstructuredFileLoader,DirectoryLoader
@@ -91,13 +91,13 @@ class Data_process():
         if isinstance(obj, dict):
             for key, value in obj.items():
                 try:
-                    self.extract_text_from_json(value, content)
+                    content = self.extract_text_from_json(value, content)
                 except Exception as e:
                     print(f"Error processing value: {e}")
         elif isinstance(obj, list):
             for index, item in enumerate(obj):
                 try:
-                    self.extract_text_from_json(item, content)
+                    content = self.extract_text_from_json(item, content)
                 except Exception as e:
                     print(f"Error processing item: {e}")
         elif isinstance(obj, str):
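
The substantive fix in this hunk is assigning the recursive call's return value back to content: Python strings are immutable, so the inner calls cannot grow the caller's accumulator in place, and the text they collect is only visible through the return value. A minimal self-contained sketch of the same accumulate-by-return pattern (the sample data is illustrative):

    def collect_text(obj, content: str) -> str:
        # Walk nested dicts/lists and collect string leaves via return values.
        if isinstance(obj, dict):
            for value in obj.values():
                content = collect_text(value, content)
        elif isinstance(obj, list):
            for item in obj:
                content = collect_text(item, content)
        elif isinstance(obj, str):
            content += obj
        return content

    print(collect_text([{"q": "你好"}, {"a": "很高兴见到你"}], ""))  # prints: 你好很高兴见到你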
@@ -157,7 +157,7 @@ class Data_process():
         logger.info(f'splitting file {file_path}')
         with open(file_path, 'r', encoding='utf-8') as f:
             data = json.load(f)
-        print(data)
+        # print(data)
         for conversation in data:
             # for dialog in conversation['conversation']:
             ## split by QA pair: convert each QA round into a langchain_core.documents.base.Document
@@ -165,6 +165,7 @@ class Data_process():
             # split_qa.append(Document(page_content = content))
             # split by conversation block
             content = self.extract_text_from_json(conversation['conversation'], '')
+            logger.info(f'content====={content}')
             split_qa.append(Document(page_content = content))
         # logger.info(f'split_qa size====={len(split_qa)}')
         return split_qa
@@ -229,9 +230,8 @@ class Data_process():
         # compression_retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=retriever)
         # compressed_docs = compression_retriever.get_relevant_documents(query)
         # return compressed_docs
-

     def rerank(self, query, docs):
         reranker = self.load_rerank_model()
         passages = []
         for doc in docs:
@@ -240,9 +240,41 @@ class Data_process():
         sorted_pairs = sorted(zip(passages, scores), key=lambda x: x[1], reverse=True)
         sorted_passages, sorted_scores = zip(*sorted_pairs)
         return sorted_passages, sorted_scores
+
+# def create_prompt(question, context):
+#     from langchain.prompts import PromptTemplate
+#     prompt_template = f"""请基于以下内容回答问题:
+#     {context}
+#     问题: {question}
+#     回答:"""
+#     prompt = PromptTemplate(
+#         template=prompt_template, input_variables=["context", "question"]
+#     )
+#     logger.info(f'Prompt: {prompt}')
+#     return prompt
+
+def create_prompt(question, context):
+    prompt = f"""请基于以下内容: {context} 给出问题答案。问题如下: {question}。回答:"""
+    logger.info(f'Prompt: {prompt}')
+    return prompt
+
+def test_zhipu(prompt):
+    from zhipuai import ZhipuAI
+    api_key = ""  # fill in your own API key
+    if api_key == "":
+        raise ValueError("请填写api_key")
+    client = ZhipuAI(api_key=api_key)
+    response = client.chat.completions.create(
+        model="glm-4",  # name of the model to call
+        messages=[
+            {"role": "user", "content": prompt[:100]}
+        ],
+    )
+    print(response.choices[0].message)
+
 if __name__ == "__main__":
     logger.info(data_dir)
     if not os.path.exists(data_dir):
@@ -254,7 +286,8 @@ if __name__ == "__main__":
     # query = "儿童心理学说明-内容提要-目录 《儿童心理学》1993年修订版说明 《儿童心理学》是1961年初全国高等学校文科教材会议指定朱智贤教授编 写的。1962年初版1979年再版。"
     # query = "我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想,我现在感到非常孤独、累和迷茫。您能给我提供一些建议吗?"
     # query = "这在一定程度上限制了其思维能力,特别是辩证 逻辑思维能力的发展。随着年龄的增长,初中三年级学生逐步克服了依赖性"
-    query = "我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想"
+    # query = "我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想"
+    query = "我现在心情非常差,有什么解决办法吗?"
     docs, retriever = dp.retrieve(query, vector_db, k=10)
     logger.info(f'Query: {query}')
     logger.info("Retrieve results:")
@@ -267,4 +300,6 @@ if __name__ == "__main__":
     logger.info("After reranking...")
     for i in range(len(scores)):
         logger.info(str(scores[i]) + '\n')
         logger.info(passages[i])
+    prompt = create_prompt(query, passages[0])
+    test_zhipu(prompt)  ## if 'Server disconnected without sending a response.' appears, it may be due to the context window limit
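
Taken together, the new __main__ path wires retrieval into a live model call: retrieve, rerank, create_prompt on the top passage, then test_zhipu. A minimal end-to-end sketch of that flow (Data_process, load_vector_db, retrieve, and rerank are this module's own names; running them as a standalone snippet is an assumption):

    dp = Data_process()
    vector_db = dp.load_vector_db()
    query = "我现在心情非常差,有什么解决办法吗?"
    docs, retriever = dp.retrieve(query, vector_db, k=10)
    passages, scores = dp.rerank(query, docs)    # sorted best-first
    prompt = create_prompt(query, passages[0])   # only the top passage is used
    test_zhipu(prompt)                           # prompt is truncated to 100 chars inside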

View File

@@ -1,20 +1,17 @@
 import os
-import json
-import pickle
-import numpy as np
-from typing import Tuple
-from sentence_transformers import SentenceTransformer
-from config.config import knowledge_json_path, knowledge_pkl_path, model_repo, model_dir, base_dir
-from util.encode import load_embedding, encode_qa
-from util.pipeline import EmoLLMRAG
+import time
+import jwt
+from config.config import base_dir, data_dir
+from data_processing import Data_process
+from pipeline import EmoLLMRAG
+from langchain_openai import ChatOpenAI
 from loguru import logger
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
 import streamlit as st
 from openxlab.model import download
-from config.config import embedding_path, doc_dir, qa_dir, knowledge_pkl_path, data_dir
-from data_processing import Data_process

 '''
 1. Build the complete RAG pipeline: the input is the user query, the output is the answer
 2. Call the interface provided by the embedding module to vectorize the query
@@ -24,21 +21,45 @@ from data_processing import Data_process
 6. Assemble the prompt and call the model to return the result
 '''

-# download(
-#     model_repo=model_repo,
-#     output='model'
-# )
+def get_glm(temprature):
+    llm = ChatOpenAI(
+        model_name="glm-4",
+        openai_api_base="https://open.bigmodel.cn/api/paas/v4",
+        openai_api_key=generate_token("api-key"),
+        streaming=False,
+        temperature=temprature
+    )
+    return llm
+
+def generate_token(apikey: str, exp_seconds: int = 100):
+    try:
+        id, secret = apikey.split(".")
+    except Exception as e:
+        raise Exception("invalid apikey", e)
+    payload = {
+        "api_key": id,
+        "exp": int(round(time.time() * 1000)) + exp_seconds * 1000,
+        "timestamp": int(round(time.time() * 1000)),
+    }
+    return jwt.encode(
+        payload,
+        secret,
+        algorithm="HS256",
+        headers={"alg": "HS256", "sign_type": "SIGN"},
+    )
+
 @st.cache_resource
 def load_model():
     model_dir = os.path.join(base_dir, '../model')
     logger.info(f'Loading model from {model_dir}')
     model = (
-        AutoModelForCausalLM.from_pretrained(model_dir, trust_remote_code=True)
+        AutoModelForCausalLM.from_pretrained('model', trust_remote_code=True)
         .to(torch.bfloat16)
         .cuda()
     )
-    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
+    tokenizer = AutoTokenizer.from_pretrained('model', trust_remote_code=True)
     return model, tokenizer

 def main(query, system_prompt=''):
@@ -60,4 +81,9 @@ def main(query, system_prompt=''):
 if __name__ == "__main__":
     query = "我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想"
     main(query)
+
+    # model = get_glm(0.7)
+    # rag_obj = EmoLLMRAG(model, 3)
+    # res = rag_obj.main(query)
+    # logger.info(res)
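
generate_token implements ZhipuAI's JWT authentication: the API key has the form id.secret, and the token is an HS256 JWT whose exp and timestamp claims are in milliseconds (hence the * 1000). The literal "api-key" that get_glm passes is a placeholder; without a real key, the split on "." raises the "invalid apikey" exception. A minimal usage sketch (the key value is an illustrative placeholder, not a working credential):

    token = generate_token("abc.def", exp_seconds=3600)  # assumes the "id.secret" key format
    llm = get_glm(0.7)           # ChatOpenAI client pointed at the GLM endpoint
    # res = llm.invoke("你好")   # would issue a chat completion against glm-4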

View File

@@ -2,9 +2,8 @@ from langchain_core.output_parsers import StrOutputParser
 from langchain_core.prompts import PromptTemplate
 from transformers.utils import logging

-from data_processing import DataProcessing
-from config.config import retrieval_num, select_num, system_prompt, prompt_template
+from data_processing import Data_process
+from config.config import system_prompt, prompt_template

 logger = logging.get_logger(__name__)
@@ -28,10 +27,8 @@ class EmoLLMRAG(object):
         """
         self.model = model

+        self.data_processing_obj = Data_process()
         self.vectorstores = self._load_vector_db()
-        self.system_prompt = self._get_system_prompt()
-        self.prompt_template = self._get_prompt_template()
-        self.data_processing_obj = DataProcessing()
         self.system_prompt = system_prompt
         self.prompt_template = prompt_template
         self.retrieval_num = retrieval_num
@@ -43,8 +40,6 @@ class EmoLLMRAG(object):
         Call the interface provided by the embedding module to load the vector DB
         """
         vectorstores = self.data_processing_obj.load_vector_db()
-        if not vectorstores:
-            vectorstores = self.data_processing_obj.load_index_and_knowledge()

         return vectorstores
@@ -57,13 +52,17 @@ class EmoLLMRAG(object):
         content = ''
         documents = self.vectorstores.similarity_search(query, k=self.retrieval_num)

-        # if rerank is needed, call the interface to rerank the documents
-        if self.rerank_flag:
-            documents = self.data_processing_obj.rerank(documents, self.select_num)
-
         for doc in documents:
             content += doc.page_content

+        # if rerank is needed, call the interface to rerank the documents
+        if self.rerank_flag:
+            documents, _ = self.data_processing_obj.rerank(documents, self.select_num)
+            content = ''
+            for doc in documents:
+                content += doc
+
+        logger.info(f'Retrieval data: {content}')
         return content

     def generate_answer(self, query, content) -> str:
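
After this change rerank returns a (passages, scores) tuple of plain strings (see the rerank method in data_processing above), so the post-rerank branch rebuilds the content buffer with content += doc rather than doc.page_content. Note that the call passes (documents, self.select_num) while the rerank added in data_processing is declared as rerank(self, query, docs), so the two only line up once the signatures are reconciled. A minimal sketch of the accumulation this relies on, with illustrative data:

    # Hypothetical reranker output: passages sorted best-first, with scores.
    passages, scores = ("passage A", "passage B"), (0.92, 0.35)
    content = ""
    for doc in passages:   # plain strings now, not Document objects
        content += doc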