PullRequest: 353 [Lite] Add gradient checkpointing to FSDPEngine

Merge branch mzy/add-gradient-ckpt of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/353

Reviewed-by: 博惟 <bowei.fw@antgroup.com>


* add gradient checkpointing
This commit is contained in:
晓雷 2025-07-14 10:39:15 +08:00 committed by 博惟
parent 8d4b8dc90f
commit 434d2f5064
1 changed files with 5 additions and 0 deletions

View File

@ -109,6 +109,11 @@ class FSDPEngine(TrainEngine):
attn_implementation=self.config.attn_impl,
)
if self.config.gradient_checkpointing:
model.gradient_checkpointing_enable(
gradient_checkpointing_kwargs={"use_reentrant": False}
)
# Simple auto wrap policy
self.mixed_precision_policy = MixedPrecisionPolicy(
param_dtype=torch.bfloat16,