mirror of https://github.com/inclusionAI/AReaL
PullRequest: 353 [Lite] Add gradient checkpointing to FSDPEngine
Merge branch mzy/add-gradient-ckpt of git@code.alipay.com:inclusionAI/AReaL.git into lite https://code.alipay.com/inclusionAI/AReaL/pull_requests/353 Reviewed-by: 博惟 <bowei.fw@antgroup.com> * add gradient checkpointing
This commit is contained in:
parent
8d4b8dc90f
commit
434d2f5064
|
@@ -109,6 +109,11 @@ class FSDPEngine(TrainEngine):
|
|||
attn_implementation=self.config.attn_impl,
|
||||
)
|
||||
|
||||
if self.config.gradient_checkpointing:
|
||||
model.gradient_checkpointing_enable(
|
||||
gradient_checkpointing_kwargs={"use_reentrant": False}
|
||||
)
|
||||
|
||||
# Simple auto wrap policy
|
||||
self.mixed_precision_policy = MixedPrecisionPolicy(
|
||||
param_dtype=torch.bfloat16,
|
||||
|
|
Loading…
Reference in New Issue