Update acl_compiler.py

This commit is contained in:
Yi Zhang 2024-12-20 22:44:49 +08:00 committed by GitHub
parent 1c752fbd83
commit 572f4301c2
1 changed files with 2 additions and 0 deletions

View File

@ -2771,6 +2771,7 @@ def change_function():
jt.Var.__setitem__ = lambda x, slices, value: warp(
fake_setitem, setitem_acl, name='setitem')(x, slices, value)
fake_matmul = jt.Var.matmul
jt.nn.bmm = warp(jt.nn.bmm, bmm_acl)
jt.bmm = warp(jt.bmm, bmm_acl)
jt.nn.matmul = warp(jt.matmul, matmul_acl)
@ -2778,6 +2779,7 @@ def change_function():
jt.nn.matmul_transpose = warp(jt.nn.matmul_transpose, matmul_transpose_acl)
jt.nn.bmm_transpose = warp(jt.nn.bmm_transpose, bmm_transpose_acl)
jt.bmm_transpose = warp(jt.bmm_transpose, bmm_transpose_acl)
jt.Var.__matmul__ = lambda x, y: warp(fake_matmul, matmul_acl)(x, y)
jt.transpose = warp(jt.transpose, transpose_acl)
fake_transpose = jt.transpose