commit 72a7746d8b

.gitignore (vendored)
@@ -3,6 +3,8 @@ ESConv.json
 tmp/
 zhipuai/
 data/
 pdf/
 .idea/
 
+*.jsonl
+*.json
@@ -220,6 +220,7 @@ git clone https://github.com/SmartFlowAI/EmoLLM.git
 | [Anooyman](https://github.com/Anooyman) | Master's student, Nanjing University of Science and Technology | | |
 | [Vicky-3021](https://github.com/Vicky-3021) | Master's student, Xidian University (Research Year 0) | | |
 | [SantiagoTOP](https://github.com/santiagoTOP) | Master's student, Taiyuan University of Technology | | |
+| [zealot52099](https://github.com/zealot52099) | AI Mover | | Data cleaning, RAG |
 
 ### Copyright Notice
@@ -244,7 +244,7 @@ This project uses Git for version control. You can see the currently available v
 | [Anooyman](https://github.com/Anooyman) | Nanjing University of Science and Technology, Master's student | | |
 | [Vicky-3021](https://github.com/Vicky-3021) | Xidian University, Master's student (Research Year 0) | | |
 | [SantiagoTOP](https://github.com/santiagoTOP) | Taiyuan University of Technology, Master's student | | |
-
+| [zealot52099](https://github.com/zealot52099) | AI Mover | | Data Processing and RAG |
 
 ### Copyright Notice
@@ -23,7 +23,14 @@ def qwen_api(data, emo):
     Patient: the patient's consultation or statement
     Doctor: the doctor's reassurance and advice
     '''
-    response = dashscope.Generation.call(
-        model='qwen-max',
-        prompt=prompt,
-        history=[],
+    try:
+        response = dashscope.Generation.call(
+            model='qwen-max',
+            prompt=prompt,
+            history=[],
+        )
+    except:
+        response = dashscope.Generation.call(
+            model='qwen-max',
+            prompt=prompt,
+            history=[],
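The bare `except:` above retries the call exactly once and swallows the reason for the first failure. A minimal sketch of a bounded-retry variant, reusing the same `dashscope.Generation.call` arguments as the diff (the `call_with_retry` helper and its parameters are hypothetical, not part of this commit):

```python
import time

import dashscope


def call_with_retry(prompt, retries=3, backoff=2.0):
    # Retry the generation call a bounded number of times,
    # backing off between attempts instead of retrying blindly.
    for attempt in range(retries):
        try:
            return dashscope.Generation.call(
                model='qwen-max',
                prompt=prompt,
                history=[],
            )
        except Exception:
            if attempt == retries - 1:
                raise
            time.sleep(backoff * (attempt + 1))
```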
@@ -55,13 +62,16 @@ if __name__ == '__main__':
     areas_of_life = configs['areas_of_life']
     ai_tool = 'qwen'
 
+    save_interval = 5
+    total_num_each_emo_area = 5
+
     conversation_lis = []
 
-    for emo in emotions_lis:
-        for area in areas_of_life:
+    for area in areas_of_life:
+        for emo in emotions_lis:
             gen_path = f'./{ai_tool}/{area}/{emo}.jsonl'
 
-            for i in tqdm(range(100), desc='{emo}, {area}'.format(emo=emo, area=area)):
+            for i in tqdm(range(total_num_each_emo_area), desc='{emo}, {area}'.format(emo=emo, area=area)):
                 one_conversation = {
                     "conversation": []
                 }
@@ -98,8 +108,7 @@ if __name__ == '__main__':
                 )
                 conversation_lis.append(one_conversation)
 
-                # save once every 10 generated conversations
-                if ((i+1) % 10 == 0):
+                if ((i+1) % save_interval == 0):
                     save_jsonl(data_lis=conversation_lis, file_path=gen_path)
                     print(f'generate {gen_path}')
                     conversation_lis = []  # reset the buffer
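`save_jsonl` is used in both generation scripts but is not shown in this diff. A minimal sketch of a compatible implementation, assuming append-mode writes (the buffer is cleared after each flush, so overwriting would lose earlier batches):

```python
import json
import os


def save_jsonl(data_lis, file_path):
    # Append each conversation as one JSON line;
    # create the parent directory on first write.
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    with open(file_path, 'a', encoding='utf-8') as f:
        for item in data_lis:
            f.write(json.dumps(item, ensure_ascii=False) + '\n')
```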
@@ -100,7 +100,10 @@
 
 5. **Dataset integration**
 
-   Before integrating the datasets, we need to check the generated data for formatting errors, type mismatches, and similar issues. We need check.py to check the data, and finally use merge_json.py to merge all the JSON files into one overall JSON file.
+   Before integrating the datasets, we need to check the generated data for formatting errors, type mismatches, and similar issues.
+
+   * First, use `check.py` to check the data.
+   * Then use `merge_json.py` to merge all the JSON files into a single overall JSON file (a rough sketch of both steps follows below).
 
 6. **Evaluation and optimization**
 
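Neither `check.py` nor `merge_json.py` appears in this diff; the following sketch only illustrates what the two steps amount to. The function name, glob pattern, and output path are assumptions, not the project's actual scripts:

```python
import glob
import json


def check_and_merge(src_glob='./*.json', out_path='./merged.json'):
    merged = []
    for path in glob.glob(src_glob):
        with open(path, 'r', encoding='utf-8') as f:
            try:
                data = json.load(f)      # format check: must parse as JSON
            except json.JSONDecodeError:
                print(f'format error, skipping {path}')
                continue
        if isinstance(data, list):       # type check: expect a list of samples
            merged.extend(data)
    with open(out_path, 'w', encoding='utf-8') as f:
        json.dump(merged, f, ensure_ascii=False, indent=2)
```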
@@ -82,12 +82,15 @@ if __name__ == '__main__':
     areas_of_life = configs['areas_of_life']
     ai_tool = 'zhipuai'
 
+    save_interval = 5
+    total_num_each_emo_area = 5
+
     conversation_lis = []
-    for emo in emotions_lis:
-        for area in areas_of_life:
+    for area in areas_of_life:
+        for emo in emotions_lis:
             gen_path = f'./{ai_tool}/{area}/{emo}.jsonl'
 
-            for i in tqdm(range(100), desc='{emo}, {area}'.format(emo=emo, area=area)):
+            for i in tqdm(range(total_num_each_emo_area), desc='{emo}, {area}'.format(emo=emo, area=area)):
                 res = zhipu_api(area, emo)
                 print(res)
                 if res == 'null':
@@ -95,7 +98,7 @@ if __name__ == '__main__':
                     continue
                 conversation_lis.append(convert(res))
 
-                if ((i+1) % 10 == 0):
+                if ((i+1) % save_interval == 0):
                     # path = f'./{args.data}.jsonl'
                     save_jsonl(data_lis=conversation_lis, file_path=gen_path)
                     print(f'generate {gen_path}')
@@ -0,0 +1,66 @@
# EmoLLM RAG

## **Module purpose**

Based on the user's question, relevant information is retrieved to make EmoLLM's answers more professional and reliable. The retrieved content includes, but is not limited to, the following:

- Psychology-related theories
- Psychology methodology
- Classic cases
- Client background knowledge

## **Datasets**

- Cleaned QA pairs: each QA pair is embedded as a single sample
- Filtered TXT texts
  - Generate embeddings for the TXT text directly (segmented by token length)
  - Filter out irrelevant information such as tables of contents, then generate embeddings for the TXT text (segmented by token length)
  - After filtering out irrelevant information such as tables of contents, segment the TXT semantically and generate embeddings
  - Split the TXT according to its table of contents and generate embeddings following that hierarchy

For details on how the data was collected and constructed, see [qa_generation_README](https://github.com/SmartFlowAI/EmoLLM/blob/ccfa75c493c4685e84073dfbc53c50c09a2988e3/scripts/qa_generation/README.md).

## **Components**

### [BCEmbedding](https://github.com/netease-youdao/BCEmbedding?tab=readme-ov-file)

- [bce-embedding-base_v1](https://hf-mirror.com/maidalun1020/bce-embedding-base_v1): embedding model, used to build the vector DB
- [bce-reranker-base_v1](https://hf-mirror.com/maidalun1020/bce-reranker-base_v1): reranking model, used to rerank the retrieved documents
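A quick sketch of how these two models are typically called, following the interface described in the BCEmbedding README (the sample texts are placeholders):

```python
from BCEmbedding import EmbeddingModel, RerankerModel

# Embed passages for the vector DB.
embedder = EmbeddingModel(model_name_or_path="maidalun1020/bce-embedding-base_v1")
vectors = embedder.encode(["A passage about cognitive behavioral therapy."])

# Rerank retrieved passages against the query.
reranker = RerankerModel(model_name_or_path="maidalun1020/bce-reranker-base_v1")
results = reranker.rerank("What is CBT?",
                          ["A passage about cognitive behavioral therapy."])
print(results["rerank_passages"][0])
```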
### [Langchain](https://python.langchain.com/docs/get_started)

LangChain is an open-source framework for building applications based on large language models (LLMs). It provides a variety of tools and abstractions to improve the customization, accuracy, and relevance of the information your models generate.

### [FAISS](https://faiss.ai/)

FAISS is a library for efficient similarity search and clustering of dense vectors. It contains algorithms that can search sets of vectors of any size. Since LangChain already integrates FAISS, this project builds on that integration rather than on the native FAISS API. See [FAISS in Langchain](https://python.langchain.com/docs/integrations/vectorstores/faiss).
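A minimal sketch of building and querying a FAISS store through LangChain, assuming the bce embedding model is served via `HuggingFaceEmbeddings` (the sample texts are placeholders):

```python
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS

embeddings = HuggingFaceEmbeddings(model_name="maidalun1020/bce-embedding-base_v1")

# Build the vector DB from cleaned QA pairs.
texts = ["Q: What is CBT? A: Cognitive behavioral therapy is ..."]
db = FAISS.from_texts(texts, embeddings)

# Retrieve the top-k most similar chunks for a user query.
docs = db.similarity_search("How does CBT work?", k=10)
```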
### [RAGAS](https://github.com/explodinggradients/ragas)

A classic RAG evaluation framework that assesses the following three aspects:

- Faithfulness: the answer given should be generated based on the given context.
- Answer Relevance: the generated answer should address the actual question asked.
- Context Relevance: the retrieved information should be highly focused and contain as little irrelevant information as possible.

More evaluation metrics were added later, such as context recall.
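A sketch of scoring a small evaluation set on those three aspects, following the ragas quickstart of that era (metric names have since evolved; the sample row is a placeholder, not project data):

```python
from datasets import Dataset
from ragas import evaluate
from ragas.metrics import answer_relevancy, context_relevancy, faithfulness

eval_set = Dataset.from_dict({
    "question": ["How does CBT treat anxiety?"],
    "answer": ["CBT treats anxiety by restructuring unhelpful thought patterns."],
    "contexts": [["CBT is a structured therapy that targets distorted thinking."]],
})

# Each metric returns a score in [0, 1]; higher is better.
scores = evaluate(eval_set, metrics=[faithfulness, answer_relevancy, context_relevancy])
print(scores)
```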
## **Details**

### RAG pipeline

- Build the vector DB from the dataset
- Embed the question entered by the user
- Search the vector DB with the embedded question
- Rerank the recalled data
- Generate the final answer from the user question and the recalled data

**Note**: the above process runs only when the user chooses to use RAG.
### Follow-up actions

- Feed RAGAS evaluation results back into the generation process, e.g. regenerate when the generated answer fails to address the user's question.
- Add web retrieval to handle cases where the relevant information cannot be found in the vector DB.
- Add multi-channel retrieval to increase the recall rate, i.e. generate multiple similar queries from the user input and retrieve with each of them (see the sketch below).
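The multi-channel retrieval idea can be sketched as: paraphrase the query several times, retrieve for each variant, and de-duplicate the union. Everything here is illustrative; in particular, `paraphrase` stands in for a future LLM-backed rewriter:

```python
def paraphrase(query: str) -> str:
    # Hypothetical stub: in practice, ask the LLM to rewrite the query.
    return query


def multi_channel_retrieve(db, query, k=10, n_variants=3):
    # Retrieve for the original query plus paraphrased variants,
    # merging results while dropping duplicate chunks.
    variants = [query] + [paraphrase(query) for _ in range(n_variants)]
    seen, merged = set(), []
    for q in variants:
        for doc in db.similarity_search(q, k=k):
            if doc.page_content not in seen:
                seen.add(doc.page_content)
                merged.append(doc)
    return merged
```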
@@ -2,3 +2,5 @@ sentence_transformers
 transformers
 numpy
 loguru
+langchain
+torch
@@ -3,6 +3,7 @@ import os
 cur_dir = os.path.dirname(os.path.abspath(__file__))    # config
 src_dir = os.path.dirname(cur_dir)                      # src
 base_dir = os.path.dirname(src_dir)                     # base
+model_repo = 'ajupyter/EmoLLM_aiwei'
 
 # model
 model_dir = os.path.join(base_dir, 'model')             # model
@@ -17,3 +18,6 @@ knowledge_pkl_path = os.path.join(data_dir, 'knowledge.pkl')  # pickle
 # log
 log_dir = os.path.join(base_dir, 'log')                 # log
 log_path = os.path.join(log_dir, 'log.log')             # file
+
+select_num = 3
+retrieval_num = 10
@@ -5,8 +5,19 @@ import numpy as np
 from typing import Tuple
 from sentence_transformers import SentenceTransformer
 
-from config.config import knowledge_json_path, knowledge_pkl_path
+from config.config import knowledge_json_path, knowledge_pkl_path, model_repo
 from util.encode import load_embedding, encode_qa
+from util.pipeline import EmoLLMRAG
 
+from transformers import AutoTokenizer, AutoModelForCausalLM
+import torch
+import streamlit as st
+from openxlab.model import download
+
+download(
+    model_repo=model_repo,
+    output='model'
+)
+
 
 """
@@ -62,6 +73,19 @@ def main():
     ## 2. Concatenate the contents into the prompt and pass it to the LLM as the {known content}
     ## 3. Ask the LLM to answer based on the known content
 
+@st.cache_resource
+def load_model():
+    model = (
+        AutoModelForCausalLM.from_pretrained("model", trust_remote_code=True)
+        .to(torch.bfloat16)
+        .cuda()
+    )
+    tokenizer = AutoTokenizer.from_pretrained("model", trust_remote_code=True)
+    return model, tokenizer
+
 if __name__ == '__main__':
-    main()
+    # main()
+    query = ''
+    model, tokenizer = load_model()
+    rag_obj = EmoLLMRAG(model)
+    response = rag_obj.main(query)
rag/src/util/pipeline.py (new file, 114 lines)
@@ -0,0 +1,114 @@
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from transformers.utils import logging

from config.config import retrieval_num, select_num

logger = logging.get_logger(__name__)


class EmoLLMRAG(object):
    """
    EmoLLM RAG pipeline
    1. Embed the query
    2. Retrieve data from the vector DB
    3. Rerank the retrieved results
    4. Pass the query and the retrieved content to the LLM
    """

    def __init__(self, model) -> None:
        """
        Initialize with the model.

        DataProcessing obj: handles data processing, including embedding/rerank
        vectorstores: load the vector DB; it should be recreated if it does not exist
        system prompt: fetch the predefined system prompt
        prompt template: define the final template fed to the LLM
        """
        self.model = model
        self.vectorstores = self._load_vector_db()
        self.system_prompt = self._get_system_prompt()
        self.prompt_template = self._get_prompt_template()

        # Waiting for the embedding team to wrap the corresponding interface
        # self.data_process_obj = DataProcessing()

    def _load_vector_db(self):
        """
        Load the vector DB through the interface provided by the embedding module
        """
        return

    def _get_system_prompt(self) -> str:
        """
        Load the system prompt
        """
        return ''

    def _get_prompt_template(self) -> str:
        """
        Load the prompt template
        """
        return ''

    def get_retrieval_content(self, query, rerank_flag=False) -> str:
        """
        Input: the user's question, and whether reranking is needed
        Output: the retrieved (and optionally reranked) content
        """
        content = ''
        documents = self.vectorstores.similarity_search(query, k=retrieval_num)

        # If reranking is required, call the interface to rerank the documents
        if rerank_flag:
            pass
            # Waiting for the interface to become available
            # documents = self.data_process_obj.rerank_documents(documents, select_num)

        for doc in documents:
            content += doc.page_content

        return content

    def generate_answer(self, query, content) -> str:
        """
        Input: the user's question and the retrieved content
        Output: the model's generated result
        """
        # Build the template
        # The first version has no history, so the system prompt goes directly into the template
        prompt = PromptTemplate(
            template=self.prompt_template,
            input_variables=["query", "content", "system_prompt"],
        )

        # Define the chain
        # The output format is a string
        rag_chain = prompt | self.model | StrOutputParser()

        # Run
        generation = rag_chain.invoke(
            {
                "query": query,
                "content": content,
                "system_prompt": self.system_prompt
            }
        )
        return generation

    def main(self, query) -> str:
        """
        Input: the user's question
        Output: the LLM's generated result

        Defines the whole RAG pipeline and orchestrates the modules
        TODO:
            add the RAGAS scoring system
        """
        content = self.get_retrieval_content(query)
        response = self.generate_answer(query, content)

        return response
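One integration gap worth noting: `rag_chain = prompt | self.model | StrOutputParser()` expects `self.model` to be a LangChain runnable, while `load_model()` in `main.py` returns a raw transformers model. A sketch of one way to bridge the two, assuming the standard LangChain HuggingFace wrapper (the `as_langchain_llm` helper and `max_new_tokens` value are assumptions):

```python
from langchain.llms import HuggingFacePipeline
from transformers import pipeline


def as_langchain_llm(model, tokenizer):
    # Wrap the raw HF model in a text-generation pipeline so it can sit
    # inside a chain like: prompt | llm | StrOutputParser().
    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer,
                    max_new_tokens=512)
    return HuggingFacePipeline(pipeline=pipe)
```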
scripts/qa_generation/QA_clean.py (new file, 111 lines)
@@ -0,0 +1,111 @@
import os
import json
import time
from tqdm import tqdm
import concurrent.futures
from datetime import datetime
import numpy as np

from config.config import result_dir, clean_dir, storage_interval, window_size, overlap_size, multi_process_num
from model.qwen import call_qwen_single_turn, call_qwen_Psychology_QA_Pairs
from util.logger import get_logger
from util.data_loader import get_jsonl_file_paths, get_file_list, get_QA_pairs, get_txt_content, capture_qa, merge_sub_qa_generation, save_to_file

logger = get_logger()


def single_thread_generate(thread_num, interval, model_caller, storage_jsonl_path, contents):

    storage_counter = 0
    judge_list = []
    for content in tqdm(contents):
        # print('content: ', content)
        try:
            # model_caller invokes the pretrained judging model, passes the
            # input content to it, and fetches the model's response
            response = model_caller(content)
            # print('response: ', response)

            if response == '1':
                content = json.loads(content)
                judge_list.append(content)
                storage_counter += 1
            else:
                continue

            # Once the specified interval is reached, save the contents of
            # judge_list to the file at storage_jsonl_path
            if storage_counter % interval == 0:
                save_to_file(storage_jsonl_path, judge_list)
                storage_counter = 0
                judge_list = []

        except Exception as exc:
            logger.error("QA generation error : %s" % (exc))

    # Finally, if anything is left in judge_list, save it to the file as well
    if judge_list:
        save_to_file(storage_jsonl_path, judge_list)
        judge_list = []


"""
Clean QA pairs
model_name: name of the callable model; only qwen is implemented for now
interval: storage interval, i.e. how many items to accumulate between file writes;
          too small an interval increases IO overhead
"""
def clean_qa(
    model_name: str = 'qwen',
    interval: int = 10,
):
    # current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    if model_name == 'qwen':
        model_caller = call_qwen_Psychology_QA_Pairs
    else:
        logger.warning('This model is currently not supported and will call the default model - qwen.')
        model_caller = call_qwen_Psychology_QA_Pairs
        model_name = 'qwen'

    logger.info(f'The called model is: {model_name}.')
    logger.info(f'The storage interval is: {interval}.')

    file_lists = get_jsonl_file_paths()  # paths of all .jsonl files under the 数据整合 folder

    for file_path in file_lists:
        # all QA pairs of one jsonl file
        contents = get_QA_pairs(file_path)
        # print(contents)

        file_name = os.path.basename(file_path)
        print(file_name)
        storage_jsonl_path = os.path.join(
            clean_dir, f'{file_name}')

        logger.info(f'The generated QA will be stored in {storage_jsonl_path}.')

        contents_array = np.array(contents)
        chunks = np.array_split(contents_array, multi_process_num)

        # build the list of parameters for the concurrent calls
        parameters_list = list()
        for thread_num, chunk in enumerate(chunks):
            parameters_list.append(
                [thread_num, interval, model_caller, storage_jsonl_path, list(chunk)]
            )

        with concurrent.futures.ThreadPoolExecutor(max_workers=multi_process_num) as executor:
            # submit single_thread_generate once per parameter set
            futures = [executor.submit(single_thread_generate, *parameters) for parameters in parameters_list]

            for future in concurrent.futures.as_completed(futures):
                try:
                    future.result()
                except Exception as exc:
                    logger.error("Thread generated an exception: %s" % (exc))

        merge_sub_qa_generation(result_dir, storage_jsonl_path)


if __name__ == '__main__':
    # create the cleaned-data folder
    os.makedirs('./data/cleaned', exist_ok=True)
    clean_qa(interval=storage_interval)
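One caveat in the script above: all worker threads flush to the same `storage_jsonl_path`, so concurrent `save_to_file` calls can interleave their writes. A sketch of a lock-guarded flush (the `_write_lock` and wrapper name are assumptions; `save_to_file` is the helper from `util/data_loader.py` shown later in this diff):

```python
import threading

_write_lock = threading.Lock()


def save_to_file_locked(storage_jsonl_path, storage_list):
    # Serialize writes so JSON lines from different threads never interleave.
    with _write_lock:
        save_to_file(storage_jsonl_path, storage_list)
```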
@@ -93,3 +93,34 @@
## **Step 4: Clean the QA pairs**

- Purpose of cleaning

  - Improve the quality of the extracted QA data and clean out QA pairs unrelated to psychology

- Cleaning method

  - Use a prompt to drive the LLM to judge each given QA pair

- **Reference prompt**

  - ```markdown
    You are an experienced psychological counselor who is familiar with psychology. Based on the QA pair I provide, judge whether this QA pair belongs to the domain of psychology.

    The criteria are as follows:

    - If the QA pair belongs to the domain of psychology, return 1
    - If the QA pair does not belong to the domain of psychology, return 0

    Below is the given psychology QA pair:
    ```

- Cleaning tools
  - Configure `DASHSCOPE_API_KEY` in `config/config.py`; see Step 3 for how to obtain the `API_KEY`
  - Use the provided cleaning script [QA_Clear](https://github.com/SmartFlowAI/EmoLLM/blob/main/scripts/qa_generation/QA_clean.py)

- Usage
  - Prepare the QA pair data to be cleaned
  - Put the data into the `data` folder at the same level as `model`
  - Modify `judge_dir` in `config/config.py` according to the folder name
    - If the folder storing the data is named `xxx`, then `judge_dir = os.path.join(data_dir, 'xxx')`
  - The cleaned QA pairs are stored in `jsonl` format under `data/cleaned`
@@ -93,3 +93,40 @@ Using books specialized in psychology to build QA knowledge pairs for RAG to pro
## **Step 4: Cleaning of QA pairs**

- Purpose of cleaning

  - Improve the quality of the extracted QA data and clean out QA pairs that are not relevant to psychology

- Cleaning method

  - Use a prompt to drive the LLM to judge each given QA pair

- **Reference prompt**

  - ```markdown
    You are an experienced counselor and are familiar with psychology. Based on the QA pair I have provided, determine whether this QA pair is psychological in nature.

    The criteria are as follows:

    - If the current QA pair belongs to the category of psychology, return 1
    - If the current QA pair does not belong to the category of psychology, return 0

    The following is the content of the given psychology QA pair:
    ```

- Cleaning tools

  - Configure `DASHSCOPE_API_KEY` in `config/config.py`; see Step 3 for how to obtain the `API_KEY`

  - Use the provided cleaning script [QA_Clear](https://github.com/SmartFlowAI/EmoLLM/blob/main/scripts/qa_generation/QA_clean.py)

- How to use

  - Prepare the QA pair data to be cleaned

  - Put the data into the `data` folder at the same level as `model`

  - Modify `judge_dir` in `config/config.py` according to the folder name

    - If the folder storing the data is named `xxx`, then `judge_dir = os.path.join(data_dir, 'xxx')`

  - The cleaned QA pairs are stored in `jsonl` format under `data/cleaned`
scripts/qa_generation/choose_prompt.md (new file, 8 lines)
@@ -0,0 +1,8 @@
You are an experienced psychological counselor who is familiar with psychology. Based on the QA pair I provide, judge whether this QA pair belongs to the domain of psychology.

The criteria are as follows:
- If the QA pair belongs to the domain of psychology, return 1
- If the QA pair does not belong to the domain of psychology, return 0


Below is the given psychology QA pair:
@@ -10,7 +10,9 @@ base_dir = os.path.dirname(cur_dir)  # ba
 model_dir = os.path.join(base_dir, 'model')              # model
 
 # data
-data_dir = os.path.join(base_dir, 'data')                # data
+data_dir = os.path.join(base_dir, 'data')
+clean_dir = os.path.join(data_dir, 'cleaned')
+judge_dir = os.path.join(data_dir, '数据整合')
 result_dir = os.path.join(data_dir, 'generated')         # result
 
 # log
@@ -18,7 +20,9 @@ log_dir = os.path.join(base_dir, 'log')  # lo
 log_file_path = os.path.join(log_dir, 'log.log')         # file
 
-# system prompt
+# prompt contents
 system_prompt_file_path = os.path.join(base_dir, 'system_prompt_v2.md')  # system prompt
+wash_prompt_file_path = os.path.join(base_dir, 'choose_prompt.md')
 
 
 """
@@ -28,7 +32,6 @@ system_prompt_file_path = os.path.join(base_dir, 'system_prompt_v2.md')  # sy
 DASHSCOPE_API_KEY = ''
 
-
 
 """
 Control parameters
 """
@@ -36,3 +39,4 @@ storage_interval = 10
 window_size = 8
 overlap_size = 2
 multi_process_num = 3
+
@@ -5,7 +5,7 @@ from dashscope.api_entities.dashscope_response import Role
 
 from config.config import DASHSCOPE_API_KEY
 from util.logger import get_logger
-from util.prompt_loader import load_system_prompt
+from util.prompt_loader import load_system_prompt, load_wash_prompt
 
 
 dashscope.api_key = DASHSCOPE_API_KEY
@@ -39,3 +39,31 @@ def call_qwen_single_turn(query: str) -> str:
             response.code, response.message
         ))
         return ""
+
+
+def call_qwen_Psychology_QA_Pairs(query: str) -> str:
+    messages = [
+        {
+            'role': Role.SYSTEM,
+            'content': load_wash_prompt()
+        },
+        {
+            'role': Role.USER,
+            'content': query
+        }
+    ]
+    response = Generation.call(
+        model='qwen-max-1201',
+        messages=messages,
+        result_format='message',
+        stream=False,
+        incremental_output=False
+    )
+    if response.status_code == HTTPStatus.OK:
+        return response.output.choices[0]['message']['content']
+    else:
+        logger.error('Request id: %s, Status code: %s, error code: %s, error message: %s' % (
+            response.request_id, response.status_code,
+            response.code, response.message
+        ))
+        return ""
@@ -4,11 +4,39 @@ import json
 import glob
 from typing import List, Dict
 
-from config.config import data_dir
+from config.config import data_dir, judge_dir
 from util.logger import get_logger
 
 logger = get_logger()
 
 
+"""
+Recursively collect all .jsonl files under the 数据整合 directory
+"""
+def get_jsonl_file_paths() -> List[str]:
+    json_file_paths = []
+
+    # walk the root directory and all of its subdirectories
+    for dirpath, dirnames, filenames in os.walk(judge_dir):
+        # inspect each file
+        for filename in filenames:
+            # match file names ending in .jsonl with a regex
+            if re.search(r'\.jsonl$', filename):
+                # build the full file path and add it to the list
+                json_file_path = os.path.join(dirpath, filename)
+                json_file_paths.append(json_file_path)
+
+    return json_file_paths
+
+def get_QA_pairs(json_path):
+    with open(json_path, 'r', encoding='utf-8') as f:
+        content = f.read().strip()
+
+    # split the string on newlines
+    QA_Pairs = content.split('\n')
+
+    return QA_Pairs
+
 """
 Recursively collect all .txt files under data_dir
 """
@@ -47,7 +75,7 @@ def get_txt_content(
     res = []
     sentences_amount = len(sentences)
     start_index, end_index = 0, sentences_amount - window_size
-    ## check length
+    # check length
     if window_size < overlap_size:
         logger.error("window_size must be greater than or equal to overlap_size")
         return None
@@ -56,7 +84,7 @@
         return ['\n'.join(sentences)]
 
     for i in range(start_index, end_index + 1, overlap_size):
-        res.append('\n'.join(sentences[i : i + window_size]))
+        res.append('\n'.join(sentences[i: i + window_size]))
     return res
 
 
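Note that the loop strides by `overlap_size`, so consecutive windows actually share `window_size - overlap_size` sentences. A tiny worked example under the config defaults (`window_size = 8`, `overlap_size = 2`), with placeholder sentences:

```python
sentences = [f's{i}' for i in range(12)]
window_size, overlap_size = 8, 2
end_index = len(sentences) - window_size  # 4

chunks = []
for i in range(0, end_index + 1, overlap_size):   # i = 0, 2, 4
    chunks.append(sentences[i: i + window_size])

# chunks[0] -> s0..s7, chunks[1] -> s2..s9, chunks[2] -> s4..s11:
# adjacent windows share 6 sentences, i.e. window_size - overlap_size.
```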
@@ -80,6 +108,7 @@ def capture_qa(content: str) -> List[Dict]:
         logger.warning("No JSON block found.")
         return None
 
+
 """
 Save storage_list to storage_jsonl_path
 """
@@ -88,6 +117,7 @@ def save_to_file(storage_jsonl_path, storage_list):
     for item in storage_list:
         f.write(json.dumps(item, ensure_ascii=False) + '\n')
 
+
 """
 Merge the files produced by the concurrent workers into a single file
 """
@@ -103,4 +133,3 @@ def merge_sub_qa_generation(directory, storage_jsonl_path):
             file_contents.append(json.loads(line))
         os.remove(file_path)
     save_to_file(storage_jsonl_path, file_contents)
-
@@ -1,7 +1,14 @@
 from config.config import system_prompt_file_path
+from config.config import wash_prompt_file_path
 
 
 def load_system_prompt() -> str:
     with open(system_prompt_file_path, 'r', encoding='utf-8') as f:
         system_prompt = f.read()
     return system_prompt
+
+
+def load_wash_prompt() -> str:
+    with open(wash_prompt_file_path, 'r', encoding='utf-8') as f:
+        wash_prompt = f.read()
+    return wash_prompt