mirror of https://github.com/Jittor/Jittor
Update acl_compiler.py
This commit is contained in:
parent
1c752fbd83
commit
572f4301c2
|
@ -2771,6 +2771,7 @@ def change_function():
|
|||
jt.Var.__setitem__ = lambda x, slices, value: warp(
|
||||
fake_setitem, setitem_acl, name='setitem')(x, slices, value)
|
||||
|
||||
fake_matmul = jt.Var.matmul
|
||||
jt.nn.bmm = warp(jt.nn.bmm, bmm_acl)
|
||||
jt.bmm = warp(jt.bmm, bmm_acl)
|
||||
jt.nn.matmul = warp(jt.matmul, matmul_acl)
|
||||
|
@ -2778,6 +2779,7 @@ def change_function():
|
|||
jt.nn.matmul_transpose = warp(jt.nn.matmul_transpose, matmul_transpose_acl)
|
||||
jt.nn.bmm_transpose = warp(jt.nn.bmm_transpose, bmm_transpose_acl)
|
||||
jt.bmm_transpose = warp(jt.bmm_transpose, bmm_transpose_acl)
|
||||
jt.Var.__matmul__ = lambda x, y: warp(fake_matmul, matmul_acl)(x, y)
|
||||
|
||||
jt.transpose = warp(jt.transpose, transpose_acl)
|
||||
fake_transpose = jt.transpose
|
||||
|
|
Loading…
Reference in New Issue