mirror of https://github.com/inclusionAI/AReaL
This commit is contained in:
parent
e30790db11
commit
40b5d12490
|
@ -2,7 +2,7 @@
|
|||
# Copyright 2024 Wei Fu & Zhiyu Mei
|
||||
# Licensed under the Apache License, Version 2.0 (the "License").
|
||||
|
||||
from typing import Dict
|
||||
from typing import Dict, Literal
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
@ -101,7 +101,8 @@ class SFTInterface(model_api.ModelInterface):
|
|||
stat = module.train_batch(
|
||||
input_=data,
|
||||
loss_fn=compute_packed_sft_loss,
|
||||
loss_weight_fn=lambda x: x.data["prompt_mask"].count_nonzero()
|
||||
loss_weight_fn=lambda x: x.data["prompt_mask"]
|
||||
.count_nonzero()
|
||||
.logical_not()
|
||||
.count_nonzero(),
|
||||
token_normalize_scope=self.token_normalize_scope,
|
||||
|
|
Loading…
Reference in New Issue