OliveSensorAPI/IOTLLM/generate_data/EC_process/LDArec.py

72 lines
2.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
# @Time : 2024/10/24 11:10
# @Author : 黄子寒
# @Email : 1064071566@qq.com
# @File : LDArec.py
# @Project : EmoLLM
import json
import jieba
from gensim import corpora
from gensim.models.ldamodel import LdaModel
from collections import defaultdict
# Load question-answer pairs
def load_qa_data(file_path):
    """Read a JSON Lines file and return its records as a list.

    Each non-blank line is parsed with ``json.loads``. Blank or
    whitespace-only lines (for example a stray empty line between or
    after records) are skipped instead of raising
    ``json.JSONDecodeError``.

    Args:
        file_path: Path to a UTF-8 encoded JSONL file.

    Returns:
        A list with one parsed object per non-blank line, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        # Strip first so whitespace-only lines become falsy and are dropped.
        stripped_lines = (line.strip() for line in f)
        return [json.loads(line) for line in stripped_lines if line]
# Load Chinese stopwords
def load_stopwords(file_path):
    """Read a stopword file (one word per line) into a set.

    Surrounding whitespace is stripped from every line and blank lines are
    ignored, so the empty string never ends up in the result. The file is
    decoded as ``utf-8-sig`` so that an optional leading byte-order mark is
    removed rather than becoming part of the first stopword; files without
    a BOM decode exactly as plain UTF-8.

    Args:
        file_path: Path to a UTF-8 encoded text file.

    Returns:
        A set of the distinct non-empty stopwords in the file.

    Raises:
        OSError: If the file cannot be opened.
    """
    with open(file_path, 'r', encoding='utf-8-sig') as f:
        stripped_lines = (line.strip() for line in f)
        return {word for word in stripped_lines if word}
# Tokenize Chinese text with jieba and remove stopwords
def preprocess_text(text, stopwords):
    """Segment ``text`` with jieba and filter the resulting tokens.

    Args:
        text: The string to segment.
        stopwords: Collection of tokens to discard (membership-tested).

    Returns:
        A list of the tokens, in segmentation order, that are not in
        ``stopwords`` and are longer than one character.
    """
    kept_tokens = []
    for token in jieba.lcut(text):
        # Drop stopwords and single-character (or empty) tokens.
        if token in stopwords or len(token) <= 1:
            continue
        kept_tokens.append(token)
    return kept_tokens
# Build the LDA topic model
def build_lda_model(qa_pairs, stopwords, num_topics=5, passes=15,
                    random_state=None):
    """Train a gensim LDA model on the ``'input'`` text of each QA record.

    Every record's ``'input'`` value is tokenized with ``preprocess_text``,
    a gensim dictionary and bag-of-words corpus are built from the tokens,
    and an ``LdaModel`` is trained on that corpus.

    Args:
        qa_pairs: Iterable of mappings, each with an ``'input'`` key holding
            the question text.
        stopwords: Collection of tokens passed through to
            ``preprocess_text`` for filtering.
        num_topics: Number of topics to fit.
        passes: Number of training passes over the corpus. Defaults to 15,
            the value that was previously hard-coded.
        random_state: Optional seed forwarded to ``LdaModel``; set it to get
            reproducible topics between runs. ``None`` keeps the previous
            unseeded behaviour.

    Returns:
        A ``(lda_model, dictionary, corpus)`` tuple: the trained model, the
        token dictionary, and the list of bag-of-words documents.

    Raises:
        KeyError: If a record has no ``'input'`` key.
    """
    # Tokenize every question text.
    tokenized_questions = [
        preprocess_text(qa['input'], stopwords) for qa in qa_pairs
    ]

    # Build the token dictionary and the bag-of-words corpus.
    dictionary = corpora.Dictionary(tokenized_questions)
    corpus = [dictionary.doc2bow(tokens) for tokens in tokenized_questions]

    # Train the LDA model.
    lda_model = LdaModel(
        corpus,
        num_topics=num_topics,
        id2word=dictionary,
        passes=passes,
        random_state=random_state,
    )
    return lda_model, dictionary, corpus
# Print the keywords of every topic
def print_topics(lda_model, num_words=10):
    """Print one line per topic of ``lda_model`` with its top keywords.

    ``num_topics=-1`` is passed explicitly because the model's
    ``print_topics`` otherwise limits its output to 20 topics, which would
    silently hide topics of larger models.

    Args:
        lda_model: Object exposing
            ``print_topics(num_topics=..., num_words=...)`` that returns
            ``(topic_id, description)`` pairs, e.g. a gensim ``LdaModel``.
        num_words: Number of keywords to include for each topic.
    """
    for idx, topic in lda_model.print_topics(num_topics=-1, num_words=num_words):
        print(f"主题 {idx}: {topic}")
if __name__ == '__main__':
    # Input locations, resolved against the current working directory.
    qa_file = "output/train_optimized_multiple.jsonl"  # QA pairs (JSONL)
    stopwords_file = "chinese_stopwords.txt"  # stopword list

    # Read both inputs: the QA records first, then the stopwords.
    qa_pairs, stopwords = load_qa_data(qa_file), load_stopwords(stopwords_file)

    # Fit a 20-topic LDA model on the questions.
    lda_model, dictionary, corpus = build_lda_model(
        qa_pairs, stopwords, num_topics=20
    )

    # Show each topic together with its keywords.
    print_topics(lda_model)