OliveSensorAPI/IOTLLM/generate_data/EC_process/topic_model.py

# -*- coding: utf-8 -*-
# @Time : 2024/10/23 23:16
# @Author : 黄子寒
# @Email : 1064071566@qq.com
# @File : topic_model.py
# @Project : EmoLLM
import json
import gensim
from gensim import corpora
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
from collections import defaultdict
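
# NOTE: word_tokenize and the stopword list rely on NLTK's 'punkt' and
# 'stopwords' resources; if they are missing, a one-time download is needed
# (a minimal sketch, assuming network access is available):
#   import nltk
#   nltk.download('punkt')
#   nltk.download('stopwords')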
# Load question-answer pairs from a JSONL file (one JSON object per line)
def load_qa_data(file_path):
    qa_pairs = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            qa_pairs.append(json.loads(line.strip()))
    return qa_pairs
# Text preprocessing: lowercase, tokenize, and drop punctuation and stopwords
def preprocess_text(text):
    stop_words = set(stopwords.words('chinese'))
    tokens = word_tokenize(text.lower())
    tokens = [word for word in tokens if word.isalnum() and word not in stop_words]
    return tokens
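
# NOTE: word_tokenize splits on whitespace and punctuation and does not segment
# continuous Chinese text into words. If the questions are mostly Chinese, a
# segmenter such as jieba could be swapped in instead (a hedged sketch, assuming
# the optional jieba package is installed):
#   import jieba
#   tokens = [w for w in jieba.lcut(text) if w.strip() and w not in stop_words]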
# Build an LDA topic model over the question texts
def build_lda_model(qa_pairs, num_topics=5):
    # Collect all question texts
    questions = [qa['input'] for qa in qa_pairs]
    processed_questions = [preprocess_text(question) for question in questions]

    # Create the dictionary and bag-of-words corpus
    dictionary = corpora.Dictionary(processed_questions)
    corpus = [dictionary.doc2bow(text) for text in processed_questions]

    # Train the LDA model
    lda_model = gensim.models.ldamodel.LdaModel(corpus, num_topics=num_topics, id2word=dictionary, passes=15)
    return lda_model, dictionary, corpus
# Print the top keywords of each topic
def print_topics(lda_model, num_words=10):
    for idx, topic in lda_model.print_topics(num_words=num_words):
        print(f"Topic {idx}: {topic}")
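
# Optional helper (a sketch, not part of the original script): gensim's
# show_topic() returns (word, weight) pairs, which is easier to work with
# programmatically than the formatted strings from print_topics().
def get_topic_keywords(lda_model, num_words=10):
    return {idx: [word for word, _ in lda_model.show_topic(idx, topn=num_words)]
            for idx in range(lda_model.num_topics)}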
if __name__ == '__main__':
    qa_file = "output/train_optimized_multiple.jsonl"  # QA pair file

    # Load question-answer pairs
    qa_pairs = load_qa_data(qa_file)

    # Build the LDA topic model
    lda_model, dictionary, corpus = build_lda_model(qa_pairs, num_topics=5)

    # Print each topic and its keywords
    print_topics(lda_model)
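
    # Optional (a sketch): persist the trained model and dictionary for reuse
    # with gensim's built-in save(); the paths below are hypothetical.
    # lda_model.save("output/lda_topic_model")
    # dictionary.save("output/lda_dictionary")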