feat: Update Aiwei configuration.
This commit is contained in:
commit
c696e163cd
@ -60,7 +60,7 @@
|
||||
- 评估和诊断工具:为了有效促进心理健康,需要有科学的工具来评估个体的心理状态,以及诊断可能存在的心理问题。
|
||||
|
||||
### 最近更新
|
||||
|
||||
- 【2024.2.23】推出基于InternLM2_7B_chat_qlora的 `温柔御姐心理医生艾薇`,[点击获取模型权重](https://openxlab.org.cn/models/detail/ajupyter/EmoLLM_aiwei),[配置文件](xtuner_config/aiwei-internlm2_chat_7b_qlora.py)
|
||||
- 【2024.2.23】更新[若干微调配置](/xtuner_config/),新增 [data_pro.json](/datasets/data_pro.json)(数量更多、场景更全、更丰富)和 [aiwei.json](/datasets/aiwei.json)(温柔御姐角色扮演专用,带有Emoji表情),即将推出 `温柔御姐心理医生艾薇`
|
||||
- 【2024.2.18】 [基于Qwen1_5-0_5B-Chat全量微调版本开源](https://www.modelscope.cn/models/aJupyter/EmoLLM_Qwen1_5-0_5B-Chat_full_sft/summary),算力有限的道友可以玩起来~
|
||||
- 【2024.2.6】 EmoLLM在[**Openxlab** ](https://openxlab.org.cn/models/detail/jujimeizuo/EmoLLM_Model) 平台下载量高达18.7k,欢迎大家体验!
|
||||
|
3
app.py
3
app.py
@ -1,2 +1,3 @@
|
||||
import os
|
||||
os.system('streamlit run web_internlm2.py --server.address=0.0.0.0 --server.port 7860')
|
||||
# os.system('streamlit run web_internlm2.py --server.address=0.0.0.0 --server.port 7860')
|
||||
os.system('streamlit run web_demo-aiwei.py --server.address=0.0.0.0 --server.port 7860')
|
||||
|
BIN
assets/aiwei_logo.jpg
Normal file
BIN
assets/aiwei_logo.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 80 KiB |
267
web_demo-aiwei.py
Normal file
267
web_demo-aiwei.py
Normal file
@ -0,0 +1,267 @@
|
||||
"""
|
||||
This script refers to the dialogue example of streamlit, the interactive generation code of chatglm2 and transformers.
|
||||
We mainly modified part of the code logic to adapt to the generation of our model.
|
||||
Please refer to these links below for more information:
|
||||
1. streamlit chat example: https://docs.streamlit.io/knowledge-base/tutorials/build-conversational-apps
|
||||
2. chatglm2: https://github.com/THUDM/ChatGLM2-6B
|
||||
3. transformers: https://github.com/huggingface/transformers
|
||||
Please run with the command `streamlit run path/to/web_demo.py --server.address=0.0.0.0 --server.port 7860`.
|
||||
Using `python path/to/web_demo.py` may cause unknown problems.
|
||||
"""
|
||||
import copy
|
||||
import warnings
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
import streamlit as st
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
|
||||
from transformers.utils import logging
|
||||
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM # isort: skip
|
||||
from openxlab.model import download
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
download(model_repo='ajupyter/EmoLLM_aiwei',
|
||||
output='model')
|
||||
|
||||
@dataclass
|
||||
class GenerationConfig:
|
||||
# this config is used for chat to provide more diversity
|
||||
max_length: int = 32768
|
||||
top_p: float = 0.8
|
||||
temperature: float = 0.8
|
||||
do_sample: bool = True
|
||||
repetition_penalty: float = 1.005
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def generate_interactive(
|
||||
model,
|
||||
tokenizer,
|
||||
prompt,
|
||||
generation_config: Optional[GenerationConfig] = None,
|
||||
logits_processor: Optional[LogitsProcessorList] = None,
|
||||
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
||||
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
|
||||
additional_eos_token_id: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
inputs = tokenizer([prompt], padding=True, return_tensors="pt")
|
||||
input_length = len(inputs["input_ids"][0])
|
||||
for k, v in inputs.items():
|
||||
inputs[k] = v.cuda()
|
||||
input_ids = inputs["input_ids"]
|
||||
batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1] # noqa: F841 # pylint: disable=W0612
|
||||
if generation_config is None:
|
||||
generation_config = model.generation_config
|
||||
generation_config = copy.deepcopy(generation_config)
|
||||
model_kwargs = generation_config.update(**kwargs)
|
||||
bos_token_id, eos_token_id = ( # noqa: F841 # pylint: disable=W0612
|
||||
generation_config.bos_token_id,
|
||||
generation_config.eos_token_id,
|
||||
)
|
||||
if isinstance(eos_token_id, int):
|
||||
eos_token_id = [eos_token_id]
|
||||
if additional_eos_token_id is not None:
|
||||
eos_token_id.append(additional_eos_token_id)
|
||||
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
|
||||
if has_default_max_length and generation_config.max_new_tokens is None:
|
||||
warnings.warn(
|
||||
f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
|
||||
"This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
|
||||
" recommend using `max_new_tokens` to control the maximum length of the generation.",
|
||||
UserWarning,
|
||||
)
|
||||
elif generation_config.max_new_tokens is not None:
|
||||
generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
|
||||
if not has_default_max_length:
|
||||
logger.warn( # pylint: disable=W4902
|
||||
f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
|
||||
f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
|
||||
"Please refer to the documentation for more information. "
|
||||
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
|
||||
UserWarning,
|
||||
)
|
||||
|
||||
if input_ids_seq_length >= generation_config.max_length:
|
||||
input_ids_string = "input_ids"
|
||||
logger.warning(
|
||||
f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
|
||||
f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
|
||||
" increasing `max_new_tokens`."
|
||||
)
|
||||
|
||||
# 2. Set generation parameters if not already defined
|
||||
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
||||
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
||||
|
||||
logits_processor = model._get_logits_processor(
|
||||
generation_config=generation_config,
|
||||
input_ids_seq_length=input_ids_seq_length,
|
||||
encoder_input_ids=input_ids,
|
||||
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
|
||||
logits_processor=logits_processor,
|
||||
)
|
||||
|
||||
stopping_criteria = model._get_stopping_criteria(
|
||||
generation_config=generation_config, stopping_criteria=stopping_criteria
|
||||
)
|
||||
logits_warper = model._get_logits_warper(generation_config)
|
||||
|
||||
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
|
||||
scores = None
|
||||
while True:
|
||||
model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
||||
# forward pass to get next token
|
||||
outputs = model(
|
||||
**model_inputs,
|
||||
return_dict=True,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
)
|
||||
|
||||
next_token_logits = outputs.logits[:, -1, :]
|
||||
|
||||
# pre-process distribution
|
||||
next_token_scores = logits_processor(input_ids, next_token_logits)
|
||||
next_token_scores = logits_warper(input_ids, next_token_scores)
|
||||
|
||||
# sample
|
||||
probs = nn.functional.softmax(next_token_scores, dim=-1)
|
||||
if generation_config.do_sample:
|
||||
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
|
||||
else:
|
||||
next_tokens = torch.argmax(probs, dim=-1)
|
||||
|
||||
# update generated ids, model inputs, and length for next step
|
||||
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
|
||||
model_kwargs = model._update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False)
|
||||
unfinished_sequences = unfinished_sequences.mul((min(next_tokens != i for i in eos_token_id)).long())
|
||||
|
||||
output_token_ids = input_ids[0].cpu().tolist()
|
||||
output_token_ids = output_token_ids[input_length:]
|
||||
for each_eos_token_id in eos_token_id:
|
||||
if output_token_ids[-1] == each_eos_token_id:
|
||||
output_token_ids = output_token_ids[:-1]
|
||||
response = tokenizer.decode(output_token_ids)
|
||||
|
||||
yield response
|
||||
# stop when each sentence is finished, or if we exceed the maximum length
|
||||
if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
|
||||
break
|
||||
|
||||
|
||||
def on_btn_click():
|
||||
del st.session_state.messages
|
||||
|
||||
|
||||
@st.cache_resource
|
||||
def load_model():
|
||||
model = (
|
||||
AutoModelForCausalLM.from_pretrained("model", trust_remote_code=True)
|
||||
.to(torch.bfloat16)
|
||||
.cuda()
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained("model", trust_remote_code=True)
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def prepare_generation_config():
|
||||
with st.sidebar:
|
||||
# 使用 Streamlit 的 markdown 函数添加 Markdown 文本
|
||||
st.image('assets/aiwei_logo.jpg', width=1, caption='EmoLLM-aiwei AI Logo', use_column_width=True)
|
||||
st.markdown("[访问 EmoLLM 官方repo](https://github.com/aJupyter/EmoLLM)")
|
||||
|
||||
max_length = st.slider("Max Length", min_value=8, max_value=32768, value=32768)
|
||||
top_p = st.slider("Top P", 0.0, 1.0, 0.8, step=0.01)
|
||||
temperature = st.slider("Temperature", 0.0, 1.0, 0.7, step=0.01)
|
||||
st.button("Clear Chat History", on_click=on_btn_click)
|
||||
|
||||
generation_config = GenerationConfig(max_length=max_length, top_p=top_p, temperature=temperature)
|
||||
|
||||
return generation_config
|
||||
|
||||
|
||||
user_prompt = "<|im_start|>user\n{user}<|im_end|>\n"
|
||||
robot_prompt = "<|im_start|>assistant\n{robot}<|im_end|>\n"
|
||||
cur_query_prompt = "<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n"
|
||||
|
||||
|
||||
def combine_history(prompt):
|
||||
messages = st.session_state.messages
|
||||
meta_instruction = (
|
||||
"你是一个拥有丰富心理学知识的温柔邻家温柔大姐姐艾薇,我有一些心理问题,请你用专业的知识和温柔、可爱、俏皮、的口吻帮我解决,回复中可以穿插一些可爱的Emoji表情符号或者文本符号。\n"
|
||||
)
|
||||
total_prompt = f"<s><|im_start|>system\n{meta_instruction}<|im_end|>\n"
|
||||
for message in messages:
|
||||
cur_content = message["content"]
|
||||
if message["role"] == "user":
|
||||
cur_prompt = user_prompt.format(user=cur_content)
|
||||
elif message["role"] == "robot":
|
||||
cur_prompt = robot_prompt.format(robot=cur_content)
|
||||
else:
|
||||
raise RuntimeError
|
||||
total_prompt += cur_prompt
|
||||
total_prompt = total_prompt + cur_query_prompt.format(user=prompt)
|
||||
return total_prompt
|
||||
|
||||
|
||||
def main():
|
||||
# torch.cuda.empty_cache()
|
||||
print("load model begin.")
|
||||
model, tokenizer = load_model()
|
||||
print("load model end.")
|
||||
|
||||
user_avator = "assets/user.png"
|
||||
robot_avator = "assets/robot.jpeg"
|
||||
|
||||
st.title("EmoLLM-温柔御姐艾薇(aiwei)")
|
||||
|
||||
generation_config = prepare_generation_config()
|
||||
|
||||
# Initialize chat history
|
||||
if "messages" not in st.session_state:
|
||||
st.session_state.messages = []
|
||||
|
||||
# Display chat messages from history on app rerun
|
||||
for message in st.session_state.messages:
|
||||
with st.chat_message(message["role"], avatar=message.get("avatar")):
|
||||
st.markdown(message["content"])
|
||||
|
||||
# Accept user input
|
||||
if prompt := st.chat_input("What is up?"):
|
||||
# Display user message in chat message container
|
||||
with st.chat_message("user", avatar=user_avator):
|
||||
st.markdown(prompt)
|
||||
real_prompt = combine_history(prompt)
|
||||
# Add user message to chat history
|
||||
st.session_state.messages.append({"role": "user", "content": prompt, "avatar": user_avator})
|
||||
|
||||
with st.chat_message("robot", avatar=robot_avator):
|
||||
message_placeholder = st.empty()
|
||||
for cur_response in generate_interactive(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
prompt=real_prompt,
|
||||
additional_eos_token_id=92542,
|
||||
**asdict(generation_config),
|
||||
):
|
||||
# Display robot response in chat message container
|
||||
message_placeholder.markdown(cur_response + "▌")
|
||||
message_placeholder.markdown(cur_response) # pylint: disable=undefined-loop-variable
|
||||
# Add robot response to chat history
|
||||
st.session_state.messages.append(
|
||||
{
|
||||
"role": "robot",
|
||||
"content": cur_response, # pylint: disable=undefined-loop-variable
|
||||
"avatar": robot_avator,
|
||||
}
|
||||
)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
218
xtuner_config/aiwei-internlm2_chat_7b_qlora.py
Normal file
218
xtuner_config/aiwei-internlm2_chat_7b_qlora.py
Normal file
@ -0,0 +1,218 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from mmengine.dataset import DefaultSampler
|
||||
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
|
||||
LoggerHook, ParamSchedulerHook)
|
||||
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
|
||||
from peft import LoraConfig
|
||||
from torch.optim import AdamW
|
||||
from transformers import (AutoModelForCausalLM, AutoTokenizer,
|
||||
BitsAndBytesConfig)
|
||||
|
||||
from xtuner.dataset import process_hf_dataset
|
||||
from xtuner.dataset.collate_fns import default_collate_fn
|
||||
from xtuner.dataset.map_fns import alpaca_map_fn, template_map_fn_factory
|
||||
from xtuner.engine.hooks import (DatasetInfoHook, EvaluateChatHook,
|
||||
VarlenAttnArgsToMessageHubHook)
|
||||
from xtuner.engine.runner import TrainLoop
|
||||
from xtuner.model import SupervisedFinetune
|
||||
from xtuner.utils import PROMPT_TEMPLATE, SYSTEM_TEMPLATE
|
||||
|
||||
from mmengine.visualization import Visualizer,WandbVisBackend, TensorboardVisBackend
|
||||
|
||||
#######################################################################
|
||||
# PART 1 Settings #
|
||||
#######################################################################
|
||||
# Model
|
||||
pretrained_model_name_or_path = '/root/share/model_repos/internlm2-chat-7b'
|
||||
# /root/share/model_repos/internlm2-chat-7b
|
||||
use_varlen_attn = False
|
||||
|
||||
# Data
|
||||
data_path = './aiwei.json'
|
||||
prompt_template = PROMPT_TEMPLATE.internlm2_chat
|
||||
max_length = 2048
|
||||
pack_to_max_length = True
|
||||
|
||||
# Scheduler & Optimizer
|
||||
batch_size = 16 # per_device
|
||||
accumulative_counts = 1
|
||||
dataloader_num_workers = 0
|
||||
max_epochs = 5
|
||||
optim_type = AdamW
|
||||
lr = 1e-5
|
||||
betas = (0.9, 0.999)
|
||||
weight_decay = 0.0001
|
||||
max_norm = 1 # grad clip
|
||||
warmup_ratio = 0.03
|
||||
|
||||
# Save
|
||||
save_steps = 100
|
||||
save_total_limit = 2 # Maximum checkpoints to keep (-1 means unlimited)
|
||||
|
||||
# Evaluate the generation performance during the training
|
||||
evaluation_freq = 100
|
||||
SYSTEM = "现在你是一个拥有丰富心理学知识的温柔御姐艾薇医生,我有一些心理问题,请你用专业的知识和温柔的口吻帮我解决,可以生成一些可爱的Emoji表情符号或者文本符号。"
|
||||
evaluation_inputs = [
|
||||
'我最近总是感到很焦虑,尤其是在学业上。我有个特别崇拜的同学,他好像在各方面都比我优秀,我总觉得自己怎么努力也追不上他,这让我压力特别大。', '我知道应该理性看待,但就是忍不住会去比较。我甚至晚上会因为这个睡不着觉,总想着怎样才能像他那样出色。'
|
||||
]
|
||||
|
||||
#######################################################################
|
||||
# PART 2 Model & Tokenizer #
|
||||
#######################################################################
|
||||
tokenizer = dict(
|
||||
type=AutoTokenizer.from_pretrained,
|
||||
pretrained_model_name_or_path=pretrained_model_name_or_path,
|
||||
trust_remote_code=True,
|
||||
padding_side='right')
|
||||
|
||||
model = dict(
|
||||
type=SupervisedFinetune,
|
||||
use_varlen_attn=use_varlen_attn,
|
||||
llm=dict(
|
||||
type=AutoModelForCausalLM.from_pretrained,
|
||||
pretrained_model_name_or_path=pretrained_model_name_or_path,
|
||||
trust_remote_code=True,
|
||||
torch_dtype=torch.float16,
|
||||
quantization_config=dict(
|
||||
type=BitsAndBytesConfig,
|
||||
load_in_4bit=True,
|
||||
load_in_8bit=False,
|
||||
llm_int8_threshold=6.0,
|
||||
llm_int8_has_fp16_weight=False,
|
||||
bnb_4bit_compute_dtype=torch.float16,
|
||||
bnb_4bit_use_double_quant=True,
|
||||
bnb_4bit_quant_type='nf4')),
|
||||
lora=dict(
|
||||
type=LoraConfig,
|
||||
r=64,
|
||||
lora_alpha=16,
|
||||
lora_dropout=0.1,
|
||||
bias='none',
|
||||
task_type='CAUSAL_LM'))
|
||||
|
||||
#######################################################################
|
||||
# PART 3 Dataset & Dataloader #
|
||||
#######################################################################
|
||||
alpaca_en = dict(
|
||||
type=process_hf_dataset,
|
||||
dataset=dict(type=load_dataset, path='json', data_files=dict(train=data_path)),
|
||||
tokenizer=tokenizer,
|
||||
max_length=max_length,
|
||||
dataset_map_fn=None,
|
||||
template_map_fn=dict(
|
||||
type=template_map_fn_factory, template=prompt_template),
|
||||
remove_unused_columns=True,
|
||||
shuffle_before_pack=True,
|
||||
pack_to_max_length=pack_to_max_length,
|
||||
use_varlen_attn=use_varlen_attn)
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=batch_size,
|
||||
num_workers=dataloader_num_workers,
|
||||
dataset=alpaca_en,
|
||||
sampler=dict(type=DefaultSampler, shuffle=True),
|
||||
collate_fn=dict(type=default_collate_fn, use_varlen_attn=use_varlen_attn))
|
||||
|
||||
#######################################################################
|
||||
# PART 4 Scheduler & Optimizer #
|
||||
#######################################################################
|
||||
# optimizer
|
||||
optim_wrapper = dict(
|
||||
type=AmpOptimWrapper,
|
||||
optimizer=dict(
|
||||
type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
|
||||
clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
|
||||
accumulative_counts=accumulative_counts,
|
||||
loss_scale='dynamic',
|
||||
dtype='float16')
|
||||
|
||||
# learning policy
|
||||
# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
|
||||
param_scheduler = [
|
||||
dict(
|
||||
type=LinearLR,
|
||||
start_factor=1e-5,
|
||||
by_epoch=True,
|
||||
begin=0,
|
||||
end=warmup_ratio * max_epochs,
|
||||
convert_to_iter_based=True),
|
||||
dict(
|
||||
type=CosineAnnealingLR,
|
||||
eta_min=0.0,
|
||||
by_epoch=True,
|
||||
begin=warmup_ratio * max_epochs,
|
||||
end=max_epochs,
|
||||
convert_to_iter_based=True)
|
||||
]
|
||||
|
||||
# train, val, test setting
|
||||
train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
|
||||
|
||||
#######################################################################
|
||||
# PART 5 Runtime #
|
||||
#######################################################################
|
||||
# Log the dialogue periodically during the training process, optional
|
||||
custom_hooks = [
|
||||
dict(type=DatasetInfoHook, tokenizer=tokenizer),
|
||||
dict(
|
||||
type=EvaluateChatHook,
|
||||
tokenizer=tokenizer,
|
||||
every_n_iters=evaluation_freq,
|
||||
evaluation_inputs=evaluation_inputs,
|
||||
system=SYSTEM,
|
||||
prompt_template=prompt_template)
|
||||
]
|
||||
|
||||
if use_varlen_attn:
|
||||
custom_hooks += [dict(type=VarlenAttnArgsToMessageHubHook)]
|
||||
|
||||
# configure default hooks
|
||||
default_hooks = dict(
|
||||
# record the time of every iteration.
|
||||
timer=dict(type=IterTimerHook),
|
||||
# print log every 10 iterations.
|
||||
logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
|
||||
# enable the parameter scheduler.
|
||||
param_scheduler=dict(type=ParamSchedulerHook),
|
||||
# save checkpoint per `save_steps`.
|
||||
checkpoint=dict(
|
||||
type=CheckpointHook,
|
||||
by_epoch=False,
|
||||
interval=save_steps,
|
||||
max_keep_ckpts=save_total_limit),
|
||||
# set sampler seed in distributed evrionment.
|
||||
sampler_seed=dict(type=DistSamplerSeedHook),
|
||||
)
|
||||
|
||||
# configure environment
|
||||
env_cfg = dict(
|
||||
# whether to enable cudnn benchmark
|
||||
cudnn_benchmark=False,
|
||||
# set multi process parameters
|
||||
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
|
||||
# set distributed parameters
|
||||
dist_cfg=dict(backend='nccl'),
|
||||
)
|
||||
|
||||
# set visualizer
|
||||
visualizer = dict(
|
||||
type=Visualizer,
|
||||
vis_backends=[dict(type=WandbVisBackend)]
|
||||
)
|
||||
|
||||
# set log level
|
||||
log_level = 'INFO'
|
||||
|
||||
# load from which checkpoint
|
||||
load_from = None
|
||||
|
||||
# whether to resume training from the loaded checkpoint
|
||||
resume = True
|
||||
|
||||
# Defaults to use random seed and disable `deterministic`
|
||||
randomness = dict(seed=None, deterministic=False)
|
||||
|
||||
# set log processor
|
||||
log_processor = dict(by_epoch=False)
|
Loading…
Reference in New Issue
Block a user