update bmm

This commit is contained in:
Meng-Hao 2020-07-12 17:41:46 +08:00
parent c1c29deaa9
commit a4bb6318e1
2 changed files with 2 additions and 15 deletions

View File

@@ -111,18 +111,4 @@ void CublasBatchedMatmulOp::jit_run() {
} // jittor
// nn.py
// def bmm(a, b):
// from compile_extern import cublas_ops
// if jt.flags.use_cuda and cublas_ops:
// return cublas_ops.cublas_batched_matmul(a, b, 0, 0)
// assert len(a.shape) >= 2 and len(b.shape) >= 2
// assert a.shape[-1] == b.shape[-2]
// shape = list(a.shape) + [b.shape[-1]]
// a = a.broadcast(shape, [len(shape)-1])
// b = b.broadcast(shape, [len(shape)-3])
// return (a*b).sum(len(shape)-2)

View File

@@ -23,7 +23,8 @@ def matmul_transpose(a, b):
'''
assert len(a.shape) >= 2 and len(b.shape) == 2
assert a.shape[-1] == b.shape[-1]
if jt.flags.use_cuda:
jt.compile_extern.cublas_ops.cublas_batched_matmul(a, b, 0, 0)
shape = list(a.shape)[:-1] + list(b.shape)
a = a.broadcast(shape, [len(shape)-2])
b = b.broadcast(shape)