Fix broadcasting issue in acl_compiler.py and add support for setting item in jt.Var

This commit is contained in:
张仪 2024-09-14 16:00:15 +08:00
parent 651b24e634
commit e47a74a497
11 changed files with 44 additions and 17 deletions

0
doc/build_doc.sh Executable file → Normal file
View File

View File

@ -971,6 +971,8 @@ def change_function():
assert dd is True, "can not broadcast"
output_shape = boardcast_shape
output_shape += x.shape[slices_len:]
if output_shape == []:
output_shape = [1]
for ii in slices:
indices.append(jt.Var(ii))
if isinstance(slices[0], jt.Var) or isinstance(
@ -1154,7 +1156,9 @@ def change_function():
assert dd is True, "can not broadcast"
value_shape = boardcast_shape
value_shape +=x.shape[slices_len:]
if isinstance(value,int):
if value_shape == []:
value_shape = [1]
if isinstance(value,int) or isinstance(value,float):
value = jt.full(value_shape,value)
self.value_shape = value_shape
for ii in slices:
@ -1208,7 +1212,7 @@ def change_function():
if not sizes:
sizes = [1]
steps = [1]
if isinstance(value,int):
if isinstance(value,int) or isinstance(value,float):
value = jt.full(sizes,value)
self.type_ = 'slicev2'
attr_code = f"""
@ -1231,10 +1235,7 @@ def change_function():
result = result.squeeze(-1)
return result
def grad(self,grad_output):
print("grad")
#value_grad
value_grad = grad_output[self.input_slice]
#x_grad
grad_output[self.input_slice] = jt.zeros(self.value_shape)
return grad_output, None,value_grad
def setitem(x, slices, value):
@ -1266,13 +1267,20 @@ def change_function():
output_dtypes=[x1.dtype],
output_shapes=[grad_output.shape[:-1] + x1.shape[-1:]],
attr_code="op.jt_name=\"bmm\";")[0]
grad_x2 = acl_cmd(
if self.trans_x2:
output_shape = grad_output.shape[:-2] + grad_output.shape[-1:] + x1.shape[-1:]
grad_x2 = acl_cmd(
"BatchMatMul", [grad_output.transpose(-2, -1), x1],
output_dtypes=[x2.dtype],
output_shapes=[output_shape],
attr_code="op.jt_name=\"bmm\";")[0]
else:
output_shape = x1.shape[:-2] + x1.shape[-1:] + grad_output.shape[-1:]
grad_x2 = acl_cmd(
"BatchMatMul", [x1.transpose(-2, -1), grad_output],
output_dtypes=[x2.dtype],
output_shapes=[x2.shape[:-1] + grad_output.shape[-1:]],
output_shapes=[output_shape],
attr_code="op.jt_name=\"bmm\";")[0]
if self.trans_x2:
grad_x2 = grad_x2.transpose(-2, -1)
return grad_x1, grad_x2
def bmm(x1, x2):
@ -1302,15 +1310,22 @@ def change_function():
grad_x1 = acl_cmd(
"MatMul", [grad_output, x2.transpose(-2, -1)],
output_dtypes=[x1.dtype],
output_shapes=[grad_output.shape[:-1] + x1.shape[-1:]],
attr_code="op.jt_name=\"matmul\";")[0]
grad_x2 = acl_cmd(
"MatMul", [x1.transpose(-2, -1), grad_output],
output_dtypes=[x2.dtype],
output_shapes=[x2.shape[:-1] + grad_output.shape[-1:]],
output_shapes=[grad_output.shape[:-1] + x2.shape[-2:-1]],
attr_code="op.jt_name=\"matmul\";")[0]
if self.trans_x2:
grad_x2 = grad_x2.transpose(-2, -1)
output_shape = grad_output.shape[:-2] + grad_output.shape[-1:] + x1.shape[-1:]
grad_x2 = acl_cmd(
"MatMul", [grad_output.transpose(-2, -1), x1],
output_dtypes=[x2.dtype],
output_shapes=[output_shape],
attr_code="op.jt_name=\"matmul\";")[0]
else:
output_shape = x1.shape[:-2] + x1.shape[-1:] + grad_output.shape[-1:]
grad_x2 = acl_cmd(
"MatMul", [x1.transpose(-2, -1), grad_output],
output_dtypes=[x2.dtype],
output_shapes=[output_shape],
attr_code="op.jt_name=\"matmul\";")[0]
return grad_x1, grad_x2
def matmul(x1, x2):
@ -1582,7 +1597,9 @@ def change_function():
jt.getitem, getitem)(x, slices)
jt.setitem = warp(jt.setitem, setitem)
jt.Var.setitem = lambda x, slices, value: warp(jt.setitem, setitem)(x, slices, value)
jt.Var.setitem = lambda x, slices, value: warp(jt.Var.setitem, setitem)(x, slices, value)
jt.Var.__setitem__ = lambda x, slices, value: warp(jt.Var.__setitem__, setitem)(x, slices, value)
jt.nn.bmm = warp(jt.nn.bmm, bmm)
jt.bmm = warp(jt.bmm, bmm)
jt.nn.matmul = warp(jt.matmul, matmul)

0
python/jittor/script/install.sh Executable file → Normal file
View File

0
python/jittor/script/install_mkl.sh Executable file → Normal file
View File

0
python/jittor/script/tmpi Executable file → Normal file
View File

0
python/jittor/script/update.sh Executable file → Normal file
View File

0
python/jittor/src/misc/miniz.cc Executable file → Normal file
View File

0
python/jittor/src/misc/miniz.h Executable file → Normal file
View File

0
python/jittor/utils/asm_tuner.py Executable file → Normal file
View File

0
python/jittor/utils/jtune.py Executable file → Normal file
View File

10
test_1.py Normal file
View File

@ -0,0 +1,10 @@
import jittor as jt
jt.flags.use_acl =1
weight = jt.zeros([200,])
# b = weight[0,0]
# g = jt.grad(b,weight)
# print(g)
# print(weight)
weight[0]=1.2
print(weight)