# -*- coding: utf-8 -*-
# @Time    : 2024/10/24 11:10
# @Author  : 黄子寒
# @Email   : 1064071566@qq.com
# @File    : LDArec.py
# @Project : EmoLLM
import json

import jieba
from gensim import corpora
from gensim.models.ldamodel import LdaModel


# Load the question-answer pairs from a JSONL file (one JSON object per line)
def load_qa_data(file_path):
    qa_pairs = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            qa_pairs.append(json.loads(line.strip()))
    return qa_pairs
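# A minimal sketch of the expected record shape, assuming each JSONL line
# carries at least the "input" field read by build_lda_model below:
# {"input": "<Chinese question text>", "output": "<answer text>"}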


# Load the Chinese stopwords, one word per line
def load_stopwords(file_path):
    with open(file_path, 'r', encoding='utf-8') as f:
        return set(line.strip() for line in f)
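# The stopword file is assumed to be a plain-text list, e.g. a standard
# Chinese stopword collection such as the HIT or Baidu lists.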


# Segment Chinese text with jieba and remove stopwords
def preprocess_text(text, stopwords):
    words = jieba.lcut(text)  # jieba performs the Chinese word segmentation
    # Drop stopwords and single-character tokens, which carry little topic signal
    words = [word for word in words if word not in stopwords and len(word) > 1]
    return words
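# Hypothetical example (the exact tokens depend on jieba's dictionary):
# preprocess_text("最近总是失眠焦虑", stopwords) might yield
# ["最近", "失眠", "焦虑"] after stopwords and single characters are dropped.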


# Build the LDA topic model from the question texts
def build_lda_model(qa_pairs, stopwords, num_topics=5):
    # Preprocess all question texts
    questions = [qa['input'] for qa in qa_pairs]
    processed_questions = [preprocess_text(question, stopwords) for question in questions]

    # Build the dictionary and the bag-of-words corpus
    dictionary = corpora.Dictionary(processed_questions)
    corpus = [dictionary.doc2bow(text) for text in processed_questions]

    # Train the LDA model
    lda_model = LdaModel(corpus, num_topics=num_topics, id2word=dictionary, passes=15)
    return lda_model, dictionary, corpus
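# Sketch: the trained model can also score an unseen question against the
# topics via gensim's document-topic inference, e.g.
#   bow = dictionary.doc2bow(preprocess_text(new_question, stopwords))
#   topic_scores = lda_model.get_document_topics(bow)
# (new_question is a hypothetical variable, not defined in this script.)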


# Print the top keywords of each topic
def print_topics(lda_model, num_words=10):
    for idx, topic in lda_model.print_topics(num_words=num_words):
        print(f"Topic {idx}: {topic}")


if __name__ == '__main__':
    qa_file = "output/train_optimized_multiple.jsonl"  # question-answer pair file
    stopwords_file = "chinese_stopwords.txt"  # stopword file

    # Load the question-answer pairs
    qa_pairs = load_qa_data(qa_file)

    # Load the stopwords
    stopwords = load_stopwords(stopwords_file)

    # Build the LDA topic model
    lda_model, dictionary, corpus = build_lda_model(qa_pairs, stopwords, num_topics=20)

    # Print each topic with its keywords
    print_topics(lda_model)