ADD ft.config @aJupyter
parent 48c67f7299
commit f246532984

finetune/ft_config.py  (new file, 194 lines added)

@@ -0,0 +1,194 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from datasets import load_dataset
from mmengine.dataset import DefaultSampler
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
                            LoggerHook, ParamSchedulerHook)
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
from peft import LoraConfig
from torch.optim import AdamW
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          BitsAndBytesConfig)

from xtuner.dataset import process_hf_dataset
from xtuner.dataset.collate_fns import default_collate_fn
from xtuner.dataset.map_fns import template_map_fn_factory
from xtuner.engine import DatasetInfoHook, EvaluateChatHook
from xtuner.model import SupervisedFinetune
from xtuner.utils import PROMPT_TEMPLATE, SYSTEM_TEMPLATE
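# Note: this is an XTuner/MMEngine-style config. Every dict with a `type=` key
# below is a lazy specification; the corresponding objects are only built by
# the runner/registry when training actually starts.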

#######################################################################
#                          PART 1  Settings                           #
#######################################################################
# Model
pretrained_model_name_or_path = '/root/share/model_repos/internlm2-chat-7b'


# Data
data_path = 'merge.json'
prompt_template = PROMPT_TEMPLATE.internlm2_chat
max_length = 2048
pack_to_max_length = True

# Scheduler & Optimizer
batch_size = 8  # per_device
accumulative_counts = 2
dataloader_num_workers = 0
max_epochs = 3
optim_type = AdamW
lr = 2e-4
betas = (0.9, 0.999)
weight_decay = 0
max_norm = 1  # grad clip
warmup_ratio = 0.03
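# With batch_size = 8 and accumulative_counts = 2, each optimizer step sees an
# effective batch of 8 * 2 = 16 samples per GPU (times the number of GPUs when
# training with distributed data parallel).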

# Evaluate the generation performance during training
evaluation_freq = 500
SYSTEM = "现在你是一个心理专家,我有一些心理问题,请你用专业的知识帮我解决。"
evaluation_inputs = [
    '我最近总是感到很焦虑,尤其是在学业上。我有个特别崇拜的同学,他好像在各方面都比我优秀,我总觉得自己怎么努力也追不上他,这让我压力特别大。',
    '我知道应该理性看待,但就是忍不住会去比较。我甚至晚上会因为这个睡不着觉,总想着怎样才能像他那样出色。'
]
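# English gloss of the Chinese prompts above (kept in Chinese since that is the
# language the model is being fine-tuned for):
# SYSTEM: "You are now a psychological expert. I have some psychological
#   problems; please use your professional knowledge to help me solve them."
# evaluation_inputs[0]: "Lately I've been feeling very anxious, especially about
#   my studies. There is a classmate I really admire who seems better than me at
#   everything; I feel I can never catch up no matter how hard I try, and it
#   puts me under a lot of pressure."
# evaluation_inputs[1]: "I know I should look at it rationally, but I can't help
#   comparing myself to him. I even lose sleep over it, constantly thinking
#   about how to become as outstanding as he is."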

#######################################################################
#                      PART 2  Model & Tokenizer                      #
#######################################################################
tokenizer = dict(
    type=AutoTokenizer.from_pretrained,
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    trust_remote_code=True,
    padding_side='right')

model = dict(
    type=SupervisedFinetune,
    llm=dict(
        type=AutoModelForCausalLM.from_pretrained,
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        quantization_config=dict(
            type=BitsAndBytesConfig,
            load_in_4bit=True,
            load_in_8bit=False,
            llm_int8_threshold=6.0,
            llm_int8_has_fp16_weight=False,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4')),
    lora=dict(
        type=LoraConfig,
        r=64,
        lora_alpha=16,
        lora_dropout=0.1,
        bias='none',
        task_type='CAUSAL_LM'))
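# This is a QLoRA setup: the base LLM is loaded with 4-bit NF4 quantization
# (double quantization, fp16 compute) via bitsandbytes, and trainable LoRA
# adapters (rank 64, alpha 16, dropout 0.1) are injected on top, so only the
# adapter weights are updated during fine-tuning.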

#######################################################################
#                      PART 3  Dataset & Dataloader                   #
#######################################################################
alpaca_en = dict(
    type=process_hf_dataset,
    dataset=dict(
        type=load_dataset, path='json', data_files=dict(train=data_path)),
    tokenizer=tokenizer,
    max_length=max_length,
    dataset_map_fn=None,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    remove_unused_columns=True,
    shuffle_before_pack=True,
    pack_to_max_length=pack_to_max_length)
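# The training data is a local JSON file (data_path) loaded through the Hugging
# Face `load_dataset('json', ...)` loader. dataset_map_fn=None assumes the
# records are already in the conversation format expected by xtuner; the chat
# template is applied by template_map_fn. With pack_to_max_length=True,
# multiple (shuffled) samples are concatenated into fixed 2048-token sequences
# to improve training throughput.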

train_dataloader = dict(
    batch_size=batch_size,
    num_workers=dataloader_num_workers,
    dataset=alpaca_en,
    sampler=dict(type=DefaultSampler, shuffle=True),
    collate_fn=dict(type=default_collate_fn))

#######################################################################
#                    PART 4  Scheduler & Optimizer                    #
#######################################################################
# optimizer
optim_wrapper = dict(
    type=AmpOptimWrapper,
    optimizer=dict(
        type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
    clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
    accumulative_counts=accumulative_counts,
    loss_scale='dynamic',
    dtype='float16')

# learning policy
# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md  # noqa: E501
param_scheduler = [
    dict(
        type=LinearLR,
        start_factor=1e-5,
        by_epoch=True,
        begin=0,
        end=warmup_ratio * max_epochs,
        convert_to_iter_based=True),
    dict(
        type=CosineAnnealingLR,
        eta_min=0.0,
        by_epoch=True,
        begin=warmup_ratio * max_epochs,
        T_max=max_epochs,
        convert_to_iter_based=True)
]
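# Learning-rate schedule: a linear warmup from lr * 1e-5 up to lr over the
# first warmup_ratio * max_epochs = 0.03 * 3 = 0.09 epochs (converted to
# iterations), followed by cosine annealing down to eta_min = 0 for the rest
# of training.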

# train, val, test setting
train_cfg = dict(by_epoch=True, max_epochs=max_epochs, val_interval=1)

#######################################################################
#                           PART 5  Runtime                           #
#######################################################################
# Log the dialogue periodically during the training process, optional
custom_hooks = [
    dict(type=DatasetInfoHook, tokenizer=tokenizer),
    dict(
        type=EvaluateChatHook,
        tokenizer=tokenizer,
        every_n_iters=evaluation_freq,
        evaluation_inputs=evaluation_inputs,
        system=SYSTEM,
        prompt_template=prompt_template)
]

# configure default hooks
default_hooks = dict(
    # record the time of every iteration.
    timer=dict(type=IterTimerHook),
    # print log every 10 iterations.
    logger=dict(type=LoggerHook, interval=10),
    # enable the parameter scheduler.
    param_scheduler=dict(type=ParamSchedulerHook),
    # save a checkpoint every epoch.
    checkpoint=dict(type=CheckpointHook, interval=1),
    # set sampler seed in distributed environment.
    sampler_seed=dict(type=DistSamplerSeedHook),
)

# configure environment
env_cfg = dict(
    # whether to enable cudnn benchmark
    cudnn_benchmark=False,
    # set multi process parameters
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    # set distributed parameters
    dist_cfg=dict(backend='nccl'),
)

# set visualizer
visualizer = None

# set log level
log_level = 'INFO'

# load from which checkpoint
load_from = None

# whether to resume training from the loaded checkpoint
resume = False

# Defaults to use random seed and disable `deterministic`
randomness = dict(seed=None, deterministic=False)
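# A typical way to launch training with this config (a sketch; assumes XTuner
# is installed and the model/data paths above exist on your machine):
#
#   xtuner train ft_config.py --deepspeed deepspeed_zero2
#
# Afterwards the saved .pth checkpoint can be converted to a Hugging Face LoRA
# adapter with:
#
#   xtuner convert pth_to_hf ft_config.py <checkpoint.pth> <output_dir>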