72 lines
		
	
	
		
			2.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			72 lines
		
	
	
		
			2.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # -*- coding: utf-8 -*-
 | ||
| # @Time : 2024/10/24 11:10
 | ||
| # @Author : 黄子寒
 | ||
| # @Email : 1064071566@qq.com
 | ||
| # @File : LDArec.py
 | ||
| # @Project : EmoLLM
 | ||
| import json
 | ||
| import jieba
 | ||
| from gensim import corpora
 | ||
| from gensim.models.ldamodel import LdaModel
 | ||
| from collections import defaultdict
 | ||
| 
 | ||
| 
 | ||
# Load QA-pair data
def load_qa_data(file_path):
    """Load QA pairs from a JSONL file (one JSON object per line).

    Args:
        file_path: Path to the JSONL file.

    Returns:
        list[dict]: Parsed QA records, in file order.
    """
    qa_pairs = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            # Skip blank lines so a trailing newline or padding line in the
            # JSONL file does not raise json.JSONDecodeError.
            if line:
                qa_pairs.append(json.loads(line))
    return qa_pairs
# Load Chinese stopwords
def load_stopwords(file_path):
    """Load a stopword list from a text file (one word per line).

    Args:
        file_path: Path to the stopword file.

    Returns:
        set[str]: Stopwords with surrounding whitespace stripped.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        # Set comprehension instead of set([...]) — same result, no
        # intermediate list (ruff C403).
        return {line.strip() for line in f}
# Tokenize Chinese text with jieba and remove stopwords
def preprocess_text(text, stopwords):
    """Segment *text* with jieba, dropping stopwords and single-char tokens.

    Args:
        text: Raw Chinese text to tokenize.
        stopwords: Collection of words to exclude.

    Returns:
        list[str]: Filtered tokens.
    """
    tokens = jieba.lcut(text)  # jieba Chinese word segmentation
    # Keep only multi-character tokens that are not stopwords.
    return [tok for tok in tokens if len(tok) > 1 and tok not in stopwords]
# Build the LDA topic model
def build_lda_model(qa_pairs, stopwords, num_topics=5):
    """Train an LDA topic model over the question texts of the QA pairs.

    Args:
        qa_pairs: Iterable of dicts, each carrying the question under 'input'.
        stopwords: Words to drop during tokenization.
        num_topics: Number of latent topics to fit.

    Returns:
        tuple: (lda_model, dictionary, corpus).
    """
    # Tokenize every question text.
    docs = [preprocess_text(pair['input'], stopwords) for pair in qa_pairs]

    # Build the vocabulary and the bag-of-words corpus.
    dictionary = corpora.Dictionary(docs)
    corpus = [dictionary.doc2bow(doc) for doc in docs]

    # Fit the LDA model with 15 passes over the corpus.
    lda_model = LdaModel(corpus, num_topics=num_topics, id2word=dictionary, passes=15)
    return lda_model, dictionary, corpus
# Print the top keywords of every topic
def print_topics(lda_model, num_words=10):
    """Print each LDA topic's id and its top keywords to stdout.

    Args:
        lda_model: A fitted model exposing ``print_topics(num_words=...)``.
        num_words: How many keywords to show per topic.
    """
    topics = lda_model.print_topics(num_words=num_words)
    for topic_id, keywords in topics:
        print(f"主题 {topic_id}: {keywords}")
if __name__ == '__main__':
    qa_file = "output/train_optimized_multiple.jsonl"  # QA-pair JSONL file
    stopwords_file = "chinese_stopwords.txt"  # stopword list file

    # Load the QA pairs and the stopword set.
    qa_pairs = load_qa_data(qa_file)
    stopwords = load_stopwords(stopwords_file)

    # Fit an LDA topic model over the question texts.
    lda_model, dictionary, corpus = build_lda_model(qa_pairs, stopwords, num_topics=20)

    # Show every topic with its top keywords.
    print_topics(lda_model)