mirror of https://github.com/Jittor/Jittor
update matmul_transpose
This commit is contained in:
parent
fb873bc50e
commit
f7d893306a
|
@ -17,14 +17,14 @@ import math
|
|||
from jittor.pool import Pool, pool, AdaptiveAvgPool2d
|
||||
from jittor.optim import *
|
||||
|
||||
|
||||
def matmul_transpose(a, b):
|
||||
'''
|
||||
returns a * b^T
|
||||
'''
|
||||
assert len(a.shape) >= 2 and len(b.shape) == 2
|
||||
assert a.shape[-1] == b.shape[-1]
|
||||
if jt.flags.use_cuda:
|
||||
jt.compile_extern.cublas_ops.cublas_batched_matmul(a, b, 0, 0)
|
||||
|
||||
shape = list(a.shape)[:-1] + list(b.shape)
|
||||
a = a.broadcast(shape, [len(shape)-2])
|
||||
b = b.broadcast(shape)
|
||||
|
|
Loading…
Reference in New Issue