mirror of https://github.com/Jittor/Jittor
Update attention.py fix parameter_name error
This commit is contained in:
parent
646a0346fb
commit
b166e4e385
|
@ -77,7 +77,7 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.
|
|||
attn_weight = query @ key.transpose(-2, -1) * scale_factor
|
||||
attn_weight += attn_bias
|
||||
attn_weight = softmax(attn_weight, dim=-1)
|
||||
attn_weight = dropout(attn_weight, dropout_p, train=True)
|
||||
attn_weight = dropout(attn_weight, dropout_p, is_train=True)
|
||||
return attn_weight @ value
|
||||
|
||||
def _mha_shape_check(query: Var, key: Var, value: Var,
|
||||
|
|
Loading…
Reference in New Issue