This commit is contained in:
meizhiyu.mzy 2025-02-27 15:51:17 +08:00
parent e30790db11
commit 40b5d12490
1 changed files with 3 additions and 2 deletions

View File

@ -2,7 +2,7 @@
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
from typing import Dict
from typing import Dict, Literal
import torch
import torch.distributed as dist
@ -101,7 +101,8 @@ class SFTInterface(model_api.ModelInterface):
stat = module.train_batch(
input_=data,
loss_fn=compute_packed_sft_loss,
loss_weight_fn=lambda x: x.data["prompt_mask"].count_nonzero()
loss_weight_fn=lambda x: x.data["prompt_mask"]
.count_nonzero()
.logical_not()
.count_nonzero(),
token_normalize_scope=self.token_normalize_scope,