[Code] update rag (#122)
Commit ad7329d113
@@ -33,5 +33,5 @@ prompt_template = """
 {system_prompt}
 根据下面检索回来的信息,回答问题。
 {content}
-问题:{question}
+问题:{query}
 """
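
The functional change in this hunk is renaming the {question} placeholder to {query}, aligning it with the variable name the calling script actually passes; with str.format, any unmatched placeholder name raises KeyError. A minimal sketch of the intended substitution (the format call is an assumption, not part of this diff):

    from config.config import system_prompt, prompt_template

    # Each placeholder must be supplied by keyword; {query} now matches the caller.
    prompt = prompt_template.format(
        system_prompt=system_prompt,
        content="...retrieved passages...",
        query="我现在心情非常差,有什么解决办法吗?",
    )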
@@ -12,7 +12,7 @@ from langchain.embeddings import HuggingFaceBgeEmbeddings
 from langchain_community.document_loaders import DirectoryLoader, TextLoader, JSONLoader
 from langchain_text_splitters import CharacterTextSplitter, RecursiveCharacterTextSplitter, RecursiveJsonSplitter
 from BCEmbedding import EmbeddingModel, RerankerModel
-from util.pipeline import EmoLLMRAG
+# from util.pipeline import EmoLLMRAG
 from transformers import AutoTokenizer, AutoModelForCausalLM
 from langchain.document_loaders.pdf import PyPDFDirectoryLoader
 from langchain.document_loaders import UnstructuredFileLoader,DirectoryLoader
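
Commenting out the EmoLLMRAG import presumably breaks a circular dependency: a later hunk in this commit has pipeline.py import Data_process from this module, so the two files can no longer import each other at load time. Schematically (the motivation is an inference, not stated in the commit):

    # data_processing.py
    # from util.pipeline import EmoLLMRAG      # would pull in pipeline.py ...

    # pipeline.py
    from data_processing import Data_process   # ... which imports data_processing again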
@@ -91,13 +91,13 @@ class Data_process():
         if isinstance(obj, dict):
             for key, value in obj.items():
                 try:
-                    self.extract_text_from_json(value, content)
+                    content = self.extract_text_from_json(value, content)
                 except Exception as e:
                     print(f"Error processing value: {e}")
         elif isinstance(obj, list):
             for index, item in enumerate(obj):
                 try:
-                    self.extract_text_from_json(item, content)
+                    content = self.extract_text_from_json(item, content)
                 except Exception as e:
                     print(f"Error processing item: {e}")
         elif isinstance(obj, str):
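
Capturing the recursive return value is the actual bug fix here: Python strings are immutable, so rebinding `content` inside a recursive call never reaches the caller; the accumulated text must be threaded back through return values. A self-contained sketch of the pattern (the function body is reconstructed from the branches visible above):

    def extract_text_from_json(obj, content: str = '') -> str:
        # Walk a nested JSON structure, accumulating every string leaf.
        # The accumulator has to be returned and re-assigned, because
        # rebinding the str parameter is purely local to each call.
        if isinstance(obj, dict):
            for value in obj.values():
                content = extract_text_from_json(value, content)
        elif isinstance(obj, list):
            for item in obj:
                content = extract_text_from_json(item, content)
        elif isinstance(obj, str):
            content += obj
        return content

Before the fix, the recursive results were discarded, so nested strings never made it into the returned content.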
@@ -157,7 +157,7 @@ class Data_process():
         logger.info(f'splitting file {file_path}')
         with open(file_path, 'r', encoding='utf-8') as f:
             data = json.load(f)
-            print(data)
+            # print(data)
         for conversation in data:
             # for dialog in conversation['conversation']:
             ## Split by QA pair: convert each QA round into a langchain_core.documents.base.Document
@@ -165,6 +165,7 @@ class Data_process():
             # split_qa.append(Document(page_content = content))
             # Split by conversation block
             content = self.extract_text_from_json(conversation['conversation'], '')
+            logger.info(f'content====={content}')
             split_qa.append(Document(page_content = content))
         # logger.info(f'split_qa size====={len(split_qa)}')
         return split_qa
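
This loop now emits one Document per conversation block rather than per QA turn, and the new logger.info line traces the extracted text. Condensed, the flow reads (the file path and surrounding wiring are assumed for illustration):

    import json
    from langchain_core.documents.base import Document

    split_qa = []
    with open('data/sample.json', 'r', encoding='utf-8') as f:  # hypothetical path
        data = json.load(f)
    for conversation in data:
        # one Document per conversation block, not per QA pair
        text = extract_text_from_json(conversation['conversation'], '')
        split_qa.append(Document(page_content=text))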
@@ -229,9 +230,8 @@ class Data_process():
         # compression_retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=retriever)
         # compressed_docs = compression_retriever.get_relevant_documents(query)
         # return compressed_docs

-
     def rerank(self, query, docs):
         reranker = self.load_rerank_model()
         passages = []
         for doc in docs:
@@ -240,9 +240,41 @@ class Data_process():
         sorted_pairs = sorted(zip(passages, scores), key=lambda x: x[1], reverse=True)
         sorted_passages, sorted_scores = zip(*sorted_pairs)
         return sorted_passages, sorted_scores


+# def create_prompt(question, context):
+#     from langchain.prompts import PromptTemplate
+#     prompt_template = f"""请基于以下内容回答问题:
+
+#     {context}
+
+#     问题: {question}
+#     回答:"""
+#     prompt = PromptTemplate(
+#         template=prompt_template, input_variables=["context", "question"]
+#     )
+#     logger.info(f'Prompt: {prompt}')
+#     return prompt
+
+def create_prompt(question, context):
+    prompt = f"""请基于以下内容: {context} 给出问题答案。问题如下: {question}。回答:"""
+    logger.info(f'Prompt: {prompt}')
+    return prompt
+
+def test_zhipu(prompt):
+    from zhipuai import ZhipuAI
+    api_key = ""  # Fill in your own API key
+    if api_key == "":
+        raise ValueError("请填写api_key")
+    client = ZhipuAI(api_key=api_key)
+    response = client.chat.completions.create(
+        model="glm-4",  # Name of the model to call
+        messages=[
+            {"role": "user", "content": prompt[:100]}
+        ],
+    )
+    print(response.choices[0].message)

 if __name__ == "__main__":
     logger.info(data_dir)
     if not os.path.exists(data_dir):
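
For context, the rerank method above sorts passages by reranker score, descending. A hedged sketch of the scoring step with BCEmbedding's RerankerModel (the compute_score call and model name are assumptions; the diff only shows the sorting and unpacking):

    from BCEmbedding import RerankerModel

    def rerank(query, docs):
        # assumed: load_rerank_model() wraps this constructor
        reranker = RerankerModel(model_name_or_path="maidalun1020/bce-reranker-base_v1")
        passages = [doc.page_content for doc in docs]
        # score each (query, passage) pair, then sort passages by score
        scores = reranker.compute_score([[query, p] for p in passages])
        sorted_pairs = sorted(zip(passages, scores), key=lambda x: x[1], reverse=True)
        sorted_passages, sorted_scores = zip(*sorted_pairs)
        return sorted_passages, sorted_scores

Note that rerank returns plain strings plus scores, not Document objects; the pipeline change later in this commit depends on that.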
@@ -254,7 +286,8 @@ if __name__ == "__main__":
     # query = "儿童心理学说明-内容提要-目录 《儿童心理学》1993年修订版说明 《儿童心理学》是1961年初全国高等学校文科教材会议指定朱智贤教授编 写的。1962年初版,1979年再版。"
     # query = "我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想,我现在感到非常孤独、累和迷茫。您能给我提供一些建议吗?"
     # query = "这在一定程度上限制了其思维能力,特别是辩证 逻辑思维能力的发展。随着年龄的增长,初中三年级学生逐步克服了依赖性"
-    query = "我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想"
+    # query = "我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想"
+    query = "我现在心情非常差,有什么解决办法吗?"
     docs, retriever = dp.retrieve(query, vector_db, k=10)
     logger.info(f'Query: {query}')
     logger.info("Retrieve results:")
@@ -267,4 +300,6 @@ if __name__ == "__main__":
     logger.info("After reranking...")
     for i in range(len(scores)):
         logger.info(str(scores[i]) + '\n')
         logger.info(passages[i])
+    prompt = create_prompt(query, passages[0])
+    test_zhipu(prompt)  ## If this prints 'Server disconnected without sending a response.', it may be due to the context-window limit
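
Taken together, the __main__ flow is now: retrieve top-k from the vector DB, rerank, build a prompt from the best passage, and smoke-test it against GLM-4. A condensed sketch (load_vector_db appears in a later hunk of this commit; the exact wiring here is an assumption):

    dp = Data_process()
    vector_db = dp.load_vector_db()
    query = "我现在心情非常差,有什么解决办法吗?"
    docs, retriever = dp.retrieve(query, vector_db, k=10)
    passages, scores = dp.rerank(query, docs)    # best passage first
    prompt = create_prompt(query, passages[0])
    test_zhipu(prompt)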
@@ -1,20 +1,17 @@
 import os
-import json
-import pickle
-import numpy as np
-from typing import Tuple
-from sentence_transformers import SentenceTransformer
+import time
+import jwt

-from config.config import knowledge_json_path, knowledge_pkl_path, model_repo, model_dir, base_dir
-from util.encode import load_embedding, encode_qa
-from util.pipeline import EmoLLMRAG
+from config.config import base_dir, data_dir
+from data_processing import Data_process
+from pipeline import EmoLLMRAG

+from langchain_openai import ChatOpenAI
 from loguru import logger
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
 import streamlit as st
 from openxlab.model import download
-from config.config import embedding_path, doc_dir, qa_dir, knowledge_pkl_path, data_dir
-from data_processing import Data_process
 '''
 1) Build the complete RAG pipeline: input is the user query, output is the answer
 2) Use the interface provided by embedding to vectorize the query
@@ -24,21 +21,45 @@ from data_processing import Data_process
 6) Assemble the prompt and call the model to return the result

 '''
-# download(
-#     model_repo=model_repo,
-#     output='model'
-# )
+def get_glm(temprature):
+    llm = ChatOpenAI(
+        model_name="glm-4",
+        openai_api_base="https://open.bigmodel.cn/api/paas/v4",
+        openai_api_key=generate_token("api-key"),
+        streaming=False,
+        temperature=temprature
+    )
+    return llm
+
+def generate_token(apikey: str, exp_seconds: int=100):
+    try:
+        id, secret = apikey.split(".")
+    except Exception as e:
+        raise Exception("invalid apikey", e)
+
+    payload = {
+        "api_key": id,
+        "exp": int(round(time.time() * 1000)) + exp_seconds * 1000,
+        "timestamp": int(round(time.time() * 1000)),
+    }
+
+    return jwt.encode(
+        payload,
+        secret,
+        algorithm="HS256",
+        headers={"alg": "HS256", "sign_type": "SIGN"},
+    )
+
 @st.cache_resource
 def load_model():
     model_dir = os.path.join(base_dir,'../model')
     logger.info(f'Loading model from {model_dir}')
     model = (
-        AutoModelForCausalLM.from_pretrained(model_dir, trust_remote_code=True)
+        AutoModelForCausalLM.from_pretrained('model', trust_remote_code=True)
         .to(torch.bfloat16)
         .cuda()
     )
-    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
+    tokenizer = AutoTokenizer.from_pretrained('model', trust_remote_code=True)
     return model, tokenizer

 def main(query, system_prompt=''):
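
generate_token implements ZhipuAI's JWT scheme: the API key has the form {id}.{secret}, and the token is an HS256 JWT over the id plus millisecond timestamps, signed with the secret. Note that the literal "api-key" placeholder passed in get_glm must be replaced with a real key, or apikey.split(".") will raise. A quick round-trip sanity check with PyJWT (the dummy key is illustrative only):

    import jwt  # PyJWT

    token = generate_token("my-id.my-secret", exp_seconds=100)
    claims = jwt.decode(
        token,
        "my-secret",
        algorithms=["HS256"],
        options={"verify_exp": False},  # exp is in milliseconds, not seconds
    )
    print(claims["api_key"])  # -> "my-id"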
@@ -60,4 +81,9 @@ def main(query, system_prompt=''):

 if __name__ == "__main__":
     query = "我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想"
     main(query)
+    #model = get_glm(0.7)
+    #rag_obj = EmoLLMRAG(model, 3)
+    #res = rag_obj.main(query)
+    #logger.info(res)
+
@@ -2,9 +2,8 @@ from langchain_core.output_parsers import StrOutputParser
 from langchain_core.prompts import PromptTemplate
 from transformers.utils import logging

-from data_processing import DataProcessing
-from config.config import retrieval_num, select_num, system_prompt, prompt_template
+from data_processing import Data_process
+from config.config import system_prompt, prompt_template

 logger = logging.get_logger(__name__)

-
@@ -28,10 +27,8 @@ class EmoLLMRAG(object):

         """
         self.model = model
+        self.data_processing_obj = Data_process()
         self.vectorstores = self._load_vector_db()
-        self.system_prompt = self._get_system_prompt()
-        self.prompt_template = self._get_prompt_template()
-        self.data_processing_obj = DataProcessing()
         self.system_prompt = system_prompt
         self.prompt_template = prompt_template
         self.retrieval_num = retrieval_num
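
With retrieval_num and select_num no longer imported from config.config (see the previous hunk), they presumably arrive as constructor arguments, which would match the commented-out EmoLLMRAG(model, 3) call in the main script. A hedged sketch of the implied signature (argument names and defaults are assumptions):

    class EmoLLMRAG(object):
        def __init__(self, model, retrieval_num=3, rerank_flag=False, select_num=3):
            # retrieval_num/select_num are assumed constructor args now that
            # the config.config import of those names was dropped
            self.model = model
            self.data_processing_obj = Data_process()
            self.vectorstores = self._load_vector_db()
            self.system_prompt = system_prompt
            self.prompt_template = prompt_template
            self.retrieval_num = retrieval_num
            self.rerank_flag = rerank_flag
            self.select_num = select_num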
@@ -43,8 +40,6 @@ class EmoLLMRAG(object):
         Use the interface provided by the embedding module to load the vector DB
         """
         vectorstores = self.data_processing_obj.load_vector_db()
-        if not vectorstores:
-            vectorstores = self.data_processing_obj.load_index_and_knowledge()

         return vectorstores

@@ -57,13 +52,17 @@ class EmoLLMRAG(object):
         content = ''
         documents = self.vectorstores.similarity_search(query, k=self.retrieval_num)

-        # If rerank is needed, call the interface to rerank the documents
-        if self.rerank_flag:
-            documents = self.data_processing_obj.rerank(documents, self.select_num)

         for doc in documents:
             content += doc.page_content

+        # If rerank is needed, call the interface to rerank the documents
+        if self.rerank_flag:
+            documents, _ = self.data_processing_obj.rerank(documents, self.select_num)
+
+            content = ''
+            for doc in documents:
+                content += doc
+        logger.info(f'Retrieval data: {content}')
         return content

     def generate_answer(self, query, content) -> str:
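
The reordering above fixes two issues: reranking now runs after the initial retrieval pass and its tuple return is unpacked, and content is rebuilt from the reranked results, which are plain strings rather than Documents (hence content += doc without .page_content). Assembled with reconstructed indentation, the method presumably reads as follows; note the call still passes (documents, self.select_num) although Data_process.rerank is defined as rerank(self, query, docs) in an earlier hunk, which looks like a leftover mismatch:

    def retrieve(self, query) -> str:  # hypothetical name; the diff omits the signature
        content = ''
        documents = self.vectorstores.similarity_search(query, k=self.retrieval_num)
        for doc in documents:
            content += doc.page_content

        # rerank returns (passages, scores); passages are plain strings,
        # so the content buffer is rebuilt from them directly
        if self.rerank_flag:
            documents, _ = self.data_processing_obj.rerank(documents, self.select_num)
            content = ''
            for doc in documents:
                content += doc
        logger.info(f'Retrieval data: {content}')
        return content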