mirror of https://github.com/Jittor/Jittor
update bmm
This commit is contained in:
parent
c1c29deaa9
commit
a4bb6318e1
|
@ -111,18 +111,4 @@ void CublasBatchedMatmulOp::jit_run() {
|
|||
|
||||
} // jittor
|
||||
|
||||
// nn.py
|
||||
// def bmm(a, b):
|
||||
// from compile_extern import cublas_ops
|
||||
// if jt.flags.use_cuda and cublas_ops:
|
||||
// return cublas_ops.cublas_batched_matmul(a, b, 0, 0)
|
||||
// assert len(a.shape) >= 2 and len(b.shape) >= 2
|
||||
// assert a.shape[-1] == b.shape[-2]
|
||||
|
||||
// shape = list(a.shape) + [b.shape[-1]]
|
||||
// a = a.broadcast(shape, [len(shape)-1])
|
||||
// b = b.broadcast(shape, [len(shape)-3])
|
||||
// return (a*b).sum(len(shape)-2)
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -23,7 +23,8 @@ def matmul_transpose(a, b):
|
|||
'''
|
||||
assert len(a.shape) >= 2 and len(b.shape) == 2
|
||||
assert a.shape[-1] == b.shape[-1]
|
||||
|
||||
if jt.flags.use_cuda:
|
||||
jt.compile_extern.cublas_ops.cublas_batched_matmul(a, b, 0, 0)
|
||||
shape = list(a.shape)[:-1] + list(b.shape)
|
||||
a = a.broadcast(shape, [len(shape)-2])
|
||||
b = b.broadcast(shape)
|
||||
|
|
Loading…
Reference in New Issue