理解 Trainer 的两个核心点就在于
train/predict/evaluate逻辑- callback 处理
Trainer 三个核心函数的主逻辑都是流程控制 其中 train 最复杂 predict/evaluate 在默认实现中 这俩差不多 都是遍历数据集 然后执行单步操作 再触发callback
pipeline
下面代码基于 transformers==4.57.3 的源码进行拷贝和逻辑删减
流程相关标志 (TrainerControl 没有 should_predict)
| |
train
训练的核心是基础流程和状态管理 (可以通过 self.is_in_train 来判断是否在训练)
基础流程顺着 epoch - step - substep 的逻辑 同时在执行前后触发callback 流程的状态变更在callback中实现(DefaultFlowCallback) 并通过 control.should_ 标志来判断
其中 self._maybe_log_save_evaluate 包含来所有训练外的内容
| |
如果没有提供 compute_loss_func 也没有设置 label_smooth 那么认为 loss 的计算是包含在模型 forward 中 (transformers 模型也都是这样)
| |
predict
predict 和 evaluate 的逻辑差不多 都是使用一样的loop 然后执行callback
| |
evaluate
这里去掉了数据集为字典的处理
| |
eval_loop
Trainer没有evaluate_step都是使用的prediction_step
loop 中的逻辑为遍历 loader 然后执行预测步 默认的 prediction_step 逻辑为单步eval (实际逻辑跟单步train差不多) 并非sample
如果想要 sampling 的逻辑 需要使用 Seq2SeqTrainer (这里修改了 prediction_step 设置 predict_with_generate 参数即可在 predict 时 使用 HF generate)
在
evaluation_loop中可以通过 description 来区分来源 不过prediction_step无法区分
| |
目前会存在一个问题: 无法直接获取生成结果
evaluate和_evaluate只会返回output.metrics无法通过eval获得sampling结果
on_predict/on_evaluate只能获取到output.metrics
如果想尽可能复用代码 就只能使用 compute_metrics: Callable[[EvalPrediction], dict] 来将结果保存到 output.metrics 里面 然后通过 on_evaluate 获取结果
_maybe_log_save_evaluate
_evaluate 在 evaluate 的基础上包装了一下 基本就是后者的逻辑
| |
callback
Trainer的存储逻辑是直接实现的 而非通过 callback
默认的 callback 至少有两个 另外根据 report_to 获得额外的训练日志callback (一般情况下 WandbCallback 只在 on_log 记录日志)
DefaultFlowCallback: 流程控制 在 step_end/epoch_end 时控制标志位 (log/save/eval)PrinterCallbackorProgressCallback: 取决于是否使用tqdm 输出训练进度
| |
TrainerCallback 的所有回调如下
触发时机最早的是 on_init_end 在 Trainer.__init__ 时执行
所有的回调都是在动作后执行 参数列表都是 (args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs)
on_init_endon_train_beginon_train_endon_epoch_beginon_epoch_endon_step_beginon_pre_optimizer_stepon_optimizer_stepon_substep_endon_step_endon_evaluateon_predicton_saveon_logon_prediction_step
Trainer 持有 CallbackHandler 触发时顺序执行回调 下面也展示了必然存在的 kwargs 元素
| |