update rag/src/data_processing.py
This commit is contained in:
parent
1c5a9c081c
commit
b5af7793d6
@ -12,7 +12,7 @@ from langchain.embeddings import HuggingFaceBgeEmbeddings
|
||||
from langchain_community.document_loaders import DirectoryLoader, TextLoader, JSONLoader
|
||||
from langchain_text_splitters import CharacterTextSplitter, RecursiveCharacterTextSplitter, RecursiveJsonSplitter
|
||||
from BCEmbedding import EmbeddingModel, RerankerModel
|
||||
from util.pipeline import EmoLLMRAG
|
||||
# from util.pipeline import EmoLLMRAG
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
from langchain.document_loaders.pdf import PyPDFDirectoryLoader
|
||||
from langchain.document_loaders import UnstructuredFileLoader,DirectoryLoader
|
||||
@ -91,13 +91,13 @@ class Data_process():
|
||||
if isinstance(obj, dict):
|
||||
for key, value in obj.items():
|
||||
try:
|
||||
self.extract_text_from_json(value, content)
|
||||
content = self.extract_text_from_json(value, content)
|
||||
except Exception as e:
|
||||
print(f"Error processing value: {e}")
|
||||
elif isinstance(obj, list):
|
||||
for index, item in enumerate(obj):
|
||||
try:
|
||||
self.extract_text_from_json(item, content)
|
||||
content = self.extract_text_from_json(item, content)
|
||||
except Exception as e:
|
||||
print(f"Error processing item: {e}")
|
||||
elif isinstance(obj, str):
|
||||
@ -157,7 +157,7 @@ class Data_process():
|
||||
logger.info(f'splitting file {file_path}')
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
print(data)
|
||||
# print(data)
|
||||
for conversation in data:
|
||||
# for dialog in conversation['conversation']:
|
||||
##按qa对切分,将每一轮qa转换为langchain_core.documents.base.Document
|
||||
@ -165,6 +165,7 @@ class Data_process():
|
||||
# split_qa.append(Document(page_content = content))
|
||||
#按conversation块切分
|
||||
content = self.extract_text_from_json(conversation['conversation'], '')
|
||||
logger.info(f'content====={content}')
|
||||
split_qa.append(Document(page_content = content))
|
||||
# logger.info(f'split_qa size====={len(split_qa)}')
|
||||
return split_qa
|
||||
@ -229,9 +230,8 @@ class Data_process():
|
||||
# compression_retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=retriever)
|
||||
# compressed_docs = compression_retriever.get_relevant_documents(query)
|
||||
# return compressed_docs
|
||||
|
||||
|
||||
def rerank(self, query, docs):
|
||||
def rerank(self, query, docs):
|
||||
reranker = self.load_rerank_model()
|
||||
passages = []
|
||||
for doc in docs:
|
||||
@ -240,9 +240,41 @@ class Data_process():
|
||||
sorted_pairs = sorted(zip(passages, scores), key=lambda x: x[1], reverse=True)
|
||||
sorted_passages, sorted_scores = zip(*sorted_pairs)
|
||||
return sorted_passages, sorted_scores
|
||||
|
||||
|
||||
# def create_prompt(question, context):
|
||||
# from langchain.prompts import PromptTemplate
|
||||
# prompt_template = f"""请基于以下内容回答问题:
|
||||
|
||||
# {context}
|
||||
|
||||
# 问题: {question}
|
||||
# 回答:"""
|
||||
# prompt = PromptTemplate(
|
||||
# template=prompt_template, input_variables=["context", "question"]
|
||||
# )
|
||||
# logger.info(f'Prompt: {prompt}')
|
||||
# return prompt
|
||||
|
||||
def create_prompt(question, context):
    """Assemble the question-answering prompt and log it.

    Args:
        question: The question text, placed after the context in the prompt.
        context: The reference text the answer should be based on.

    Returns:
        The formatted prompt string.
    """
    assembled = f"""请基于以下内容: {context} 给出问题答案。问题如下: {question}。回答:"""
    # Log the full prompt so it can be inspected when running the script.
    logger.info(f'Prompt: {assembled}')
    return assembled
|
||||
|
||||
|
||||
|
||||
def test_zhipu(prompt):
    """Send ``prompt`` to the ZhipuAI chat API and print the reply message.

    The API key is taken from the ``api_key`` literal below when it has been
    filled in, otherwise from the ``ZHIPUAI_API_KEY`` environment variable, so
    the key does not have to be written into the source file.

    Args:
        prompt: Prompt text. Only its first 100 characters are sent.

    Raises:
        ValueError: If no API key is configured by either means.
    """
    import os

    api_key = ""  # Fill in your own API key, or set ZHIPUAI_API_KEY instead.
    api_key = api_key or os.environ.get("ZHIPUAI_API_KEY", "")
    if not api_key:
        raise ValueError("请填写api_key")

    # Imported inside the function so the SDK is only required when this
    # helper is actually called.
    from zhipuai import ZhipuAI

    client = ZhipuAI(api_key=api_key)
    response = client.chat.completions.create(
        model="glm-4",  # Name of the model to call.
        messages=[
            # NOTE(review): the prompt is cut to 100 characters, presumably to
            # keep the request small -- confirm this limit is intended.
            {"role": "user", "content": prompt[:100]},
        ],
    )
    print(response.choices[0].message)
|
||||
|
||||
if __name__ == "__main__":
|
||||
logger.info(data_dir)
|
||||
if not os.path.exists(data_dir):
|
||||
@ -254,7 +286,8 @@ if __name__ == "__main__":
|
||||
# query = "儿童心理学说明-内容提要-目录 《儿童心理学》1993年修订版说明 《儿童心理学》是1961年初全国高等学校文科教材会议指定朱智贤教授编 写的。1962年初版,1979年再版。"
|
||||
# query = "我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想,我现在感到非常孤独、累和迷茫。您能给我提供一些建议吗?"
|
||||
# query = "这在一定程度上限制了其思维能力,特别是辩证 逻辑思维能力的发展。随着年龄的增长,初中三年级学生逐步克服了依赖性"
|
||||
query = "我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想"
|
||||
# query = "我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想"
|
||||
query = "我现在心情非常差,有什么解决办法吗?"
|
||||
docs, retriever = dp.retrieve(query, vector_db, k=10)
|
||||
logger.info(f'Query: {query}')
|
||||
logger.info("Retrieve results:")
|
||||
@ -267,4 +300,6 @@ if __name__ == "__main__":
|
||||
logger.info("After reranking...")
|
||||
for i in range(len(scores)):
|
||||
logger.info(str(scores[i]) + '\n')
|
||||
logger.info(passages[i])
|
||||
logger.info(passages[i])
|
||||
prompt = create_prompt(query, passages[0])
|
||||
test_zhipu(prompt) ## 如果显示'Server disconnected without sending a response.'可能是由于上下文窗口限制
|
Loading…
Reference in New Issue
Block a user