Update attention.py: fix parameter-name error (`train` → `is_train` in dropout call)

This commit is contained in:
DongYang Li 2025-02-08 18:03:34 +08:00 committed by GitHub
parent 646a0346fb
commit b166e4e385
1 changed file with 1 addition and 1 deletion

View File

@ -77,7 +77,7 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.
attn_weight = query @ key.transpose(-2, -1) * scale_factor
attn_weight += attn_bias
attn_weight = softmax(attn_weight, dim=-1)
attn_weight = dropout(attn_weight, dropout_p, train=True)
attn_weight = dropout(attn_weight, dropout_p, is_train=True)
return attn_weight @ value
def _mha_shape_check(query: Var, key: Var, value: Var,