According to the official documentation, training is resumed by setting the resume_from_checkpoint argument to the checkpoint you want to resume from:
trainer.train(resume_from_checkpoint='finetuned/checkpoint-5000')
After resuming, however, both the train loss and the eval loss shot back up, as if training had started over from scratch.
# Training entry point
def train(
    self,
    resume_from_checkpoint: Optional[Union[str, bool]] = None,
    trial: Union["optuna.Trial", Dict[str, Any]] = None,
    ignore_keys_for_eval: Optional[List[str]] = None,
    **kwargs,
):
    if resume_from_checkpoint is not None:
        if not is_sagemaker_mp_enabled() and not self.is_deepspeed_enabled and not self.is_fsdp_enabled:
            self._load_from_checkpoint(resume_from_checkpoint)
        # In case of repeating the find_executable_batch_size, set `self._train_batch_size` properly
        state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
train() calls self._load_from_checkpoint(resume_from_checkpoint) to load the checkpoint being resumed.
TF_WEIGHTS_NAME = "model.ckpt"
FLAX_WEIGHTS_NAME = "flax_model.msgpack"
FLAX_WEIGHTS_INDEX_NAME = "flax_model.msgpack.index.json"
SAFE_WEIGHTS_NAME = "model.safetensors"
SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
CONFIG_NAME = "config.json"
FEATURE_EXTRACTOR_NAME = "preprocessor_config.json"
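# Also referenced in _load_from_checkpoint below (values from transformers.utils;
# they may differ slightly across library versions):
WEIGHTS_NAME = "pytorch_model.bin"
WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
ADAPTER_WEIGHTS_NAME = "adapter_model.bin"
ADAPTER_SAFE_WEIGHTS_NAME = "adapter_model.safetensors"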
def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
    if model is None:
        model = self.model

    config_file = os.path.join(resume_from_checkpoint, CONFIG_NAME)
    adapter_weights_file = os.path.join(resume_from_checkpoint, ADAPTER_WEIGHTS_NAME)
    adapter_safe_weights_file = os.path.join(resume_from_checkpoint, ADAPTER_SAFE_WEIGHTS_NAME)
    # WEIGHTS_NAME: pytorch_model.bin
    weights_file = os.path.join(resume_from_checkpoint, WEIGHTS_NAME)
    weights_index_file = os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME)
    safe_weights_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_NAME)
    safe_weights_index_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_INDEX_NAME)
    is_fsdp_ckpt = os.path.isdir(resume_from_checkpoint) and (
        # this checks the FSDP state dict when `SHARDED_STATE_DICT` is used
        any(
            FSDP_MODEL_NAME in folder_name
            for folder_name in os.listdir(resume_from_checkpoint)
            if os.path.isdir(os.path.join(resume_from_checkpoint, folder_name))
        )
        # this checks the FSDP state dict when `FULL_STATE_DICT` is used
        or os.path.isfile(os.path.join(resume_from_checkpoint, f"{FSDP_MODEL_NAME}.bin"))
    )
    # if multiple adapters exist, they get saved in sub directories
    adapter_subdirs = (
        [
            folder_name
            for folder_name in os.listdir(resume_from_checkpoint)
            if os.path.isdir(os.path.join(resume_from_checkpoint, folder_name))
            and (
                os.path.isfile(os.path.join(resume_from_checkpoint, folder_name, ADAPTER_WEIGHTS_NAME))
                or os.path.isfile(os.path.join(resume_from_checkpoint, folder_name, ADAPTER_SAFE_WEIGHTS_NAME))
            )
        ]
        if os.path.isdir(resume_from_checkpoint)
        else []
    )

    if not (
        any(
            os.path.isfile(f)
            for f in [
                weights_file,
                safe_weights_file,
                weights_index_file,
                safe_weights_index_file,
                adapter_weights_file,
                adapter_safe_weights_file,
            ]
        )
        or is_fsdp_ckpt
        or adapter_subdirs
    ):
        raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")

    logger.info(f"Loading model from {resume_from_checkpoint}.")
    # If pytorch_model.bin (or model.safetensors) exists, these full-model weights are restored first
    if os.path.isfile(weights_file) or os.path.isfile(safe_weights_file) or is_fsdp_ckpt:
        weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
        # If the model is on the GPU, it still works!
        if is_sagemaker_mp_enabled():
            pass
        elif self.is_fsdp_enabled:
            pass
        else:
            # We load the model state dict on the CPU to avoid an OOM error.
            if self.args.save_safetensors and os.path.isfile(safe_weights_file):
                state_dict = safetensors.torch.load_file(safe_weights_file, device="cpu")
            else:
                state_dict = torch.load(
                    weights_file,
                    map_location="cpu",
                    **weights_only_kwarg,
                )

            # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
            # which takes *args instead of **kwargs
            load_result = model.load_state_dict(state_dict, False)
            # release memory
            del state_dict
            self._issue_warnings_after_load(load_result)
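    # NOTE: the elif below is only reached when no pytorch_model.bin / model.safetensors was
    # found above, so a stray full-model file in the checkpoint shadows the LoRA adapter weights.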
    # Load adapters following PR # 24096
    elif _is_peft_model(model):
        # If train a model using PEFT & LoRA, assume that adapter have been saved properly.
        # TODO: in the future support only specific min PEFT versions
        if (hasattr(model, "active_adapter") or hasattr(model, "active_adapters")) and hasattr(
            model, "load_adapter"
        ):
            if os.path.exists(resume_from_checkpoint):
                # For BC for older PEFT versions
                if hasattr(model, "active_adapters"):
                    active_adapters = model.active_adapters
                    if len(active_adapters) > 1:
                        logger.warning("Multiple active adapters detected will only consider the first adapter")
                    active_adapter = active_adapters[0]
                else:
                    active_adapter = model.active_adapter

                if adapter_subdirs:
                    for subdir_name in adapter_subdirs:
                        peft_id = os.path.join(resume_from_checkpoint, subdir_name)
                        model.load_adapter(peft_id, subdir_name, is_trainable=(subdir_name == active_adapter))
                    model.set_adapter(active_adapter)
                else:
                    model.load_adapter(resume_from_checkpoint, active_adapter, is_trainable=True)
            else:
                logger.warning(
                    "The intermediate checkpoints of PEFT may not be saved correctly, "
                )
        else:
            logger.warning("Could not load adapter model, make sure to have `peft>=0.3.0` installed")
    else:
        # We load the sharded checkpoint
        load_result = load_sharded_checkpoint(
            model, resume_from_checkpoint, strict=is_sagemaker_mp_enabled(), prefer_safe=self.args.save_safetensors
        )
        if not is_sagemaker_mp_enabled():
            self._issue_warnings_after_load(load_result)
Reading through the source, the suspicion was that the Trainer had loaded the checkpoint's pytorch_model.bin weights and therefore never loaded the LoRA adapter weights in adapter_model.safetensors.
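A quick way to check this suspicion is to list the files in the checkpoint directory along with their sizes. A minimal sketch, using the same finetuned/checkpoint-5000 path as above:

import os

checkpoint_dir = "finetuned/checkpoint-5000"

# Print every file in the checkpoint with its size; a tiny pytorch_model.bin
# sitting next to adapter_model.safetensors is the telltale sign.
for name in sorted(os.listdir(checkpoint_dir)):
    path = os.path.join(checkpoint_dir, name)
    if os.path.isfile(path):
        print(f"{os.path.getsize(path):>12} bytes  {name}")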
Checking the weight files under checkpoint-5000 confirmed it: finetuned/checkpoint-5000 contained an 888-byte pytorch_model.bin. After renaming pytorch_model.bin to pytorch_model.bin.0 and resuming once more, the loss values were back to normal and training resumed correctly!
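For reference, a minimal sketch of that workaround, assuming the same checkpoint layout as above; the size check is only a guard so a genuine full-model file is not moved by accident, and trainer is the Trainer configured earlier:

import os

checkpoint_dir = "finetuned/checkpoint-5000"
stale_weights = os.path.join(checkpoint_dir, "pytorch_model.bin")

# Move the stray full-model file out of the way so _load_from_checkpoint
# falls through to the PEFT branch and loads the adapter weights instead.
if os.path.isfile(stale_weights) and os.path.getsize(stale_weights) < 1024:
    os.rename(stale_weights, stale_weights + ".0")

trainer.train(resume_from_checkpoint=checkpoint_dir)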