commit
						f156a9c42c
					
				
							
								
								
									
										3
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @ -1,4 +1,5 @@ | ||||
| ESConv.json | ||||
| .DS_Store | ||||
| __pycache__/ | ||||
| tmp/ | ||||
| tmp/ | ||||
| data/zhipuai/ | ||||
| @ -1,5 +1,8 @@ | ||||
| # EmoLLM | ||||
| 
 | ||||
| 
 | ||||
| ## 🌟 Contributors | ||||
| 
 | ||||
| [Contributors](https://github.com/aJupyter/EmoLLM/graphs/contributors) | ||||
| 
 | ||||
| 
 | ||||
|  | ||||
| @ -3,38 +3,38 @@ import os | ||||
| 
 | ||||
| 
 | ||||
| def save_merge_json(data_lis, file_path): | ||||
|     import json | ||||
| 
 | ||||
|     with open(file_path, 'wt', encoding='utf-8') as file: | ||||
|         json.dump(data_lis, file, ensure_ascii=False) | ||||
|         json.dump(data_lis, file, indent=4, ensure_ascii=False) | ||||
| 
 | ||||
| 
 | ||||
| def get_all_file_paths(folder_path): | ||||
|     # 确保传入的是一个目录 | ||||
|     if not os.path.isdir(folder_path): | ||||
|         raise ValueError(f"{folder_path} is not a valid directory") | ||||
| 
 | ||||
|     # 获取文件夹下所有文件的路径 | ||||
|     file_paths = [os.path.join(folder_path, file) for file in os.listdir( | ||||
|         folder_path) if os.path.isfile(os.path.join(folder_path, file))] | ||||
|     return file_paths | ||||
|     files = os.listdir(folder_path) | ||||
|     path = [] | ||||
|     for file in files: | ||||
|         file_path = os.path.join(folder_path, file) | ||||
|         if os.path.isdir(file_path): | ||||
|             path.extend(get_all_file_paths(file_path)) | ||||
|         else: | ||||
|             path.append(file_path) | ||||
|     return path | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == '__main__': | ||||
|     conversion_lis = [] | ||||
|     folder_path = '' # input | ||||
|     merge_path = '' # input | ||||
|     paths = get_all_file_paths(folder_path=folder_path) | ||||
| 
 | ||||
|     for path in get_all_file_paths('res/'): | ||||
|     for path in paths: | ||||
|         print(path) | ||||
| 
 | ||||
|         with open(path, 'rt', encoding='utf-8') as file: | ||||
|             for line in file: | ||||
|         with open(path, 'rt', encoding='utf-8') as lines: | ||||
|             for line in lines: | ||||
|                 # 移除行尾的换行符 | ||||
|                 line = line.rstrip('\n') | ||||
|                 line.rstrip('\n') | ||||
|                 # 解析JSON | ||||
|                 try: | ||||
|                     data = json.loads(line) | ||||
|                     conversion_lis.append(data) | ||||
|                 except json.JSONDecodeError as e: | ||||
|                     print(f"Error decoding JSON: {e}") | ||||
|          | ||||
|     save_merge_json(data_lis=conversion_lis, file_path='merge.json') | ||||
|         save_merge_json(data_lis=conversion_lis, file_path=merge_path) | ||||
| @ -1,4 +1,5 @@ | ||||
| import os | ||||
| import random | ||||
| import json | ||||
| from tqdm import tqdm | ||||
| from dotenv import load_dotenv | ||||
| @ -22,10 +23,12 @@ def zhipu_api(data, emo): | ||||
| 医生:医生的安抚和建议 | ||||
|     ''' | ||||
|      | ||||
|     top_p = round(random.uniform(0.1, 0.9), 2) | ||||
|     messages = getText('user', prompt) | ||||
|     response = client.chat.completions.create( | ||||
|         model='glm-4', | ||||
|         messages=messages, | ||||
|         top_p=top_p, | ||||
|     ) | ||||
| 
 | ||||
|     return response.choices[0].message.content | ||||
| @ -47,6 +50,8 @@ def convert(conversation): | ||||
| 
 | ||||
| 
 | ||||
| def save_jsonl(data_lis, file_path): | ||||
|     if not os.path.exists(os.path.dirname(file_path)): | ||||
|         os.makedirs(os.path.dirname(file_path)) | ||||
|     with open(file_path, 'w', encoding='utf-8') as f: | ||||
|         for item in data_lis: | ||||
|             f.write(json.dumps(item, ensure_ascii=False) + '\n') | ||||
| @ -67,7 +72,7 @@ if __name__ == '__main__': | ||||
|         "渴望", | ||||
|         "厌恶", | ||||
|         "同情", | ||||
|         "痛苦" | ||||
|         "痛苦", | ||||
|         "着迷", | ||||
|         "嫉妒", | ||||
|         "兴奋", | ||||
| @ -80,7 +85,6 @@ if __name__ == '__main__': | ||||
|         "悲伤", | ||||
|         "满意", | ||||
|         "性欲", | ||||
|         "同情", | ||||
|         "满足" | ||||
|     ] | ||||
|     areas_of_life = [ | ||||
| @ -103,22 +107,18 @@ if __name__ == '__main__': | ||||
|     ] | ||||
| 
 | ||||
|     conversation_lis = [] | ||||
|     idx = 0 | ||||
|     for area in areas_of_life: | ||||
|         j = 0 | ||||
|         for idx in tqdm(range(len(emotions_lis)), desc=f'data:{area}, emo:{emotions_lis[j]}'): | ||||
|             emo = emotions_lis[j] | ||||
|             res = zhipu_api(area, emo) | ||||
|             print(res) | ||||
|             if res == 'null': | ||||
|                 print(area, emo, 'error') | ||||
|     for emo in emotions_lis: | ||||
|         for area in areas_of_life: | ||||
|             if os.path.exists(f'./zhipuai/{area}/{emo}.jsonl'): | ||||
|                 print(f'./zhipuai/{area}/{emo}.jsonl exists') | ||||
|                 continue | ||||
|             conversation_lis.append(convert(res)) | ||||
|             if idx % 2 == 1: | ||||
|                 save_jsonl(conversation_lis, f'./zhipuai_{idx}.jsonl') | ||||
|                 conversation_lis = [] | ||||
|                 idx += 1 | ||||
|             j += 1 | ||||
|         if len(conversation_lis) > 0: | ||||
|             save_jsonl(conversation_lis, f'./zhipuai.jsonl') | ||||
|             conversation_lis = [] | ||||
|             for i in tqdm(range(5), desc='{emo}, {area}'.format(emo=emo, area=area)): | ||||
|                 res = zhipu_api(area, emo) | ||||
|                 print(res) | ||||
|                 if res == 'null': | ||||
|                     print(area, emo, 'error') | ||||
|                     continue | ||||
|                 conversation_lis.append(convert(res)) | ||||
|             save_jsonl(conversation_lis, f'./zhipuai/{area}/{emo}.jsonl') | ||||
|             print(f'generate ./zhipuai/{area}/{emo}.jsonl') | ||||
|             conversation_lis = [] | ||||
|  | ||||
							
								
								
									
										194
									
								
								finetune/ft_config.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										194
									
								
								finetune/ft_config.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,194 @@ | ||||
| # Copyright (c) OpenMMLab. All rights reserved. | ||||
| import torch | ||||
| from datasets import load_dataset | ||||
| from mmengine.dataset import DefaultSampler | ||||
| from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, | ||||
|                             LoggerHook, ParamSchedulerHook) | ||||
| from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR | ||||
| from peft import LoraConfig | ||||
| from torch.optim import AdamW | ||||
| from transformers import (AutoModelForCausalLM, AutoTokenizer, | ||||
|                           BitsAndBytesConfig) | ||||
| 
 | ||||
| from xtuner.dataset import process_hf_dataset | ||||
| from xtuner.dataset.collate_fns import default_collate_fn | ||||
| from xtuner.dataset.map_fns import template_map_fn_factory | ||||
| from xtuner.engine import DatasetInfoHook, EvaluateChatHook | ||||
| from xtuner.model import SupervisedFinetune | ||||
| from xtuner.utils import PROMPT_TEMPLATE, SYSTEM_TEMPLATE | ||||
| 
 | ||||
| ####################################################################### | ||||
| #                          PART 1  Settings                           # | ||||
| ####################################################################### | ||||
| # Model | ||||
| pretrained_model_name_or_path = '/root/share/model_repos/internlm2-chat-7b' | ||||
| 
 | ||||
| 
 | ||||
| # Data | ||||
| data_path = 'merge.json' | ||||
| prompt_template = PROMPT_TEMPLATE.internlm2_chat | ||||
| max_length = 2048 | ||||
| pack_to_max_length = True | ||||
| 
 | ||||
| # Scheduler & Optimizer | ||||
| batch_size = 8 # per_device | ||||
| accumulative_counts = 2 | ||||
| dataloader_num_workers = 0 | ||||
| max_epochs = 3 | ||||
| optim_type = AdamW | ||||
| lr = 2e-4 | ||||
| betas = (0.9, 0.999) | ||||
| weight_decay = 0 | ||||
| max_norm = 1  # grad clip | ||||
| warmup_ratio = 0.03 | ||||
| 
 | ||||
| # Evaluate the generation performance during the training | ||||
| evaluation_freq = 500 | ||||
| SYSTEM = "现在你是一个心理专家,我有一些心理问题,请你用专业的知识帮我解决。" | ||||
| evaluation_inputs = [ | ||||
|     '我最近总是感到很焦虑,尤其是在学业上。我有个特别崇拜的同学,他好像在各方面都比我优秀,我总觉得自己怎么努力也追不上他,这让我压力特别大。', '我知道应该理性看待,但就是忍不住会去比较。我甚至晚上会因为这个睡不着觉,总想着怎样才能像他那样出色。' | ||||
| ] | ||||
| 
 | ||||
| ####################################################################### | ||||
| #                      PART 2  Model & Tokenizer                      # | ||||
| ####################################################################### | ||||
| tokenizer = dict( | ||||
|     type=AutoTokenizer.from_pretrained, | ||||
|     pretrained_model_name_or_path=pretrained_model_name_or_path, | ||||
|     trust_remote_code=True, | ||||
|     padding_side='right') | ||||
| 
 | ||||
| model = dict( | ||||
|     type=SupervisedFinetune, | ||||
|     llm=dict( | ||||
|         type=AutoModelForCausalLM.from_pretrained, | ||||
|         pretrained_model_name_or_path=pretrained_model_name_or_path, | ||||
|         trust_remote_code=True, | ||||
|         torch_dtype=torch.float16, | ||||
|         quantization_config=dict( | ||||
|             type=BitsAndBytesConfig, | ||||
|             load_in_4bit=True, | ||||
|             load_in_8bit=False, | ||||
|             llm_int8_threshold=6.0, | ||||
|             llm_int8_has_fp16_weight=False, | ||||
|             bnb_4bit_compute_dtype=torch.float16, | ||||
|             bnb_4bit_use_double_quant=True, | ||||
|             bnb_4bit_quant_type='nf4')), | ||||
|     lora=dict( | ||||
|         type=LoraConfig, | ||||
|         r=64, | ||||
|         lora_alpha=16, | ||||
|         lora_dropout=0.1, | ||||
|         bias='none', | ||||
|         task_type='CAUSAL_LM')) | ||||
| 
 | ||||
| ####################################################################### | ||||
| #                      PART 3  Dataset & Dataloader                   # | ||||
| ####################################################################### | ||||
| alpaca_en = dict( | ||||
|     type=process_hf_dataset, | ||||
|     dataset=dict(type=load_dataset, path='json', data_files=dict(train=data_path)), | ||||
|     tokenizer=tokenizer, | ||||
|     max_length=max_length, | ||||
|     dataset_map_fn=None, | ||||
|     template_map_fn=dict( | ||||
|         type=template_map_fn_factory, template=prompt_template), | ||||
|     remove_unused_columns=True, | ||||
|     shuffle_before_pack=True, | ||||
|     pack_to_max_length=pack_to_max_length) | ||||
| 
 | ||||
| train_dataloader = dict( | ||||
|     batch_size=batch_size, | ||||
|     num_workers=dataloader_num_workers, | ||||
|     dataset=alpaca_en, | ||||
|     sampler=dict(type=DefaultSampler, shuffle=True), | ||||
|     collate_fn=dict(type=default_collate_fn)) | ||||
| 
 | ||||
| ####################################################################### | ||||
| #                    PART 4  Scheduler & Optimizer                    # | ||||
| ####################################################################### | ||||
| # optimizer | ||||
| optim_wrapper = dict( | ||||
|     type=AmpOptimWrapper, | ||||
|     optimizer=dict( | ||||
|         type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), | ||||
|     clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), | ||||
|     accumulative_counts=accumulative_counts, | ||||
|     loss_scale='dynamic', | ||||
|     dtype='float16') | ||||
| 
 | ||||
| # learning policy | ||||
| # More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md  # noqa: E501 | ||||
| param_scheduler = [ | ||||
|     dict( | ||||
|         type=LinearLR, | ||||
|         start_factor=1e-5, | ||||
|         by_epoch=True, | ||||
|         begin=0, | ||||
|         end=warmup_ratio * max_epochs, | ||||
|         convert_to_iter_based=True), | ||||
|     dict( | ||||
|         type=CosineAnnealingLR, | ||||
|         eta_min=0.0, | ||||
|         by_epoch=True, | ||||
|         begin=warmup_ratio * max_epochs, | ||||
|         T_max=max_epochs, | ||||
|         convert_to_iter_based=True) | ||||
| ] | ||||
| 
 | ||||
| # train, val, test setting | ||||
| train_cfg = dict(by_epoch=True, max_epochs=max_epochs, val_interval=1) | ||||
| 
 | ||||
| ####################################################################### | ||||
| #                           PART 5  Runtime                           # | ||||
| ####################################################################### | ||||
| # Log the dialogue periodically during the training process, optional | ||||
| custom_hooks = [ | ||||
|     dict(type=DatasetInfoHook, tokenizer=tokenizer), | ||||
|     dict( | ||||
|         type=EvaluateChatHook, | ||||
|         tokenizer=tokenizer, | ||||
|         every_n_iters=evaluation_freq, | ||||
|         evaluation_inputs=evaluation_inputs, | ||||
|         system=SYSTEM, | ||||
|         prompt_template=prompt_template) | ||||
| ] | ||||
| 
 | ||||
| # configure default hooks | ||||
| default_hooks = dict( | ||||
|     # record the time of every iteration. | ||||
|     timer=dict(type=IterTimerHook), | ||||
|     # print log every 10 iterations. | ||||
|     logger=dict(type=LoggerHook, interval=10), | ||||
|     # enable the parameter scheduler. | ||||
|     param_scheduler=dict(type=ParamSchedulerHook), | ||||
|     # save checkpoint per epoch. | ||||
|     checkpoint=dict(type=CheckpointHook, interval=1), | ||||
|     # set sampler seed in distributed environment. | ||||
|     sampler_seed=dict(type=DistSamplerSeedHook), | ||||
| ) | ||||
| 
 | ||||
| # configure environment | ||||
| env_cfg = dict( | ||||
|     # whether to enable cudnn benchmark | ||||
|     cudnn_benchmark=False, | ||||
|     # set multi process parameters | ||||
|     mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), | ||||
|     # set distributed parameters | ||||
|     dist_cfg=dict(backend='nccl'), | ||||
| ) | ||||
| 
 | ||||
| # set visualizer | ||||
| visualizer = None | ||||
| 
 | ||||
| # set log level | ||||
| log_level = 'INFO' | ||||
| 
 | ||||
| # path of the checkpoint to load from, if any | ||||
| load_from = None | ||||
| 
 | ||||
| # whether to resume training from the loaded checkpoint | ||||
| resume = False | ||||
| 
 | ||||
| # Defaults to use random seed and disable `deterministic` | ||||
| randomness = dict(seed=None, deterministic=False) | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user
	 xzwang
						xzwang