update matmul_transpose

This commit is contained in:
Meng-Hao 2020-07-24 09:40:06 +08:00
parent fb873bc50e
commit f7d893306a
1 changed files with 2 additions and 2 deletions

View File

@ -17,14 +17,14 @@ import math
from jittor.pool import Pool, pool, AdaptiveAvgPool2d
from jittor.optim import *
def matmul_transpose(a, b):
'''
returns a * b^T
'''
assert len(a.shape) >= 2 and len(b.shape) == 2
assert a.shape[-1] == b.shape[-1]
if jt.flags.use_cuda:
jt.compile_extern.cublas_ops.cublas_batched_matmul(a, b, 0, 0)
shape = list(a.shape)[:-1] + list(b.shape)
a = a.broadcast(shape, [len(shape)-2])
b = b.broadcast(shape)