fix batch matmul bugs (grad)

This commit is contained in:
li-xl 2020-11-27 15:18:31 +08:00
parent d11c3ad40b
commit f54701a71c
1 changed files with 10 additions and 4 deletions

View File

@ -39,11 +39,17 @@ VarPtr CublasBatchedMatmulOp::grad(Var* out, Var* dout, Var* v, int v_index) {
// a [b,n,m] b [b,m,k], c[b,n,k]
// c = a*b
if (v_index == 0) {
// da = dc*b^T
return make_cublas_batched_matmul(dout, b, trans_a^0, trans_b^1);
if (trans_a)
return make_cublas_batched_matmul(b, dout, trans_b, 1);
else
// da = dc*b^T
return make_cublas_batched_matmul(dout, b, 0, trans_b^1);
} else {
// db = a^T*dc
return make_cublas_batched_matmul(a, dout, trans_a^1, trans_b^0);
if (trans_b)
return make_cublas_batched_matmul(dout, a, 1, trans_a);
else
// db = a^T*dc
return make_cublas_batched_matmul(a, dout, trans_a^1, 0);
}
}