mirror of https://github.com/Jittor/Jittor
fix batch matmul bugs (grad)
This commit is contained in:
parent
d11c3ad40b
commit
f54701a71c
|
@ -39,11 +39,17 @@ VarPtr CublasBatchedMatmulOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
|||
// a [b,n,m] b [b,m,k], c[b,n,k]
|
||||
// c = a*b
|
||||
if (v_index == 0) {
|
||||
// da = dc*b^T
|
||||
return make_cublas_batched_matmul(dout, b, trans_a^0, trans_b^1);
|
||||
if (trans_a)
|
||||
return make_cublas_batched_matmul(b, dout, trans_b, 1);
|
||||
else
|
||||
// da = dc*b^T
|
||||
return make_cublas_batched_matmul(dout, b, 0, trans_b^1);
|
||||
} else {
|
||||
// db = a^T*dc
|
||||
return make_cublas_batched_matmul(a, dout, trans_a^1, trans_b^0);
|
||||
if (trans_b)
|
||||
return make_cublas_batched_matmul(dout, a, 1, trans_a);
|
||||
else
|
||||
// db = a^T*dc
|
||||
return make_cublas_batched_matmul(a, dout, trans_a^1, 0);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue