mirror of https://github.com/Jittor/Jittor
Fix a broadcasting issue in acl_compiler.py and add support for item assignment (`setitem` / `__setitem__`) on jt.Var
This commit is contained in:
parent
651b24e634
commit
e47a74a497
|
@ -971,6 +971,8 @@ def change_function():
|
|||
assert dd is True, "can not broadcast"
|
||||
output_shape = boardcast_shape
|
||||
output_shape += x.shape[slices_len:]
|
||||
if output_shape == []:
|
||||
output_shape = [1]
|
||||
for ii in slices:
|
||||
indices.append(jt.Var(ii))
|
||||
if isinstance(slices[0], jt.Var) or isinstance(
|
||||
|
@ -1154,7 +1156,9 @@ def change_function():
|
|||
assert dd is True, "can not broadcast"
|
||||
value_shape = boardcast_shape
|
||||
value_shape +=x.shape[slices_len:]
|
||||
if isinstance(value,int):
|
||||
if value_shape == []:
|
||||
value_shape = [1]
|
||||
if isinstance(value,int) or isinstance(value,float):
|
||||
value = jt.full(value_shape,value)
|
||||
self.value_shape = value_shape
|
||||
for ii in slices:
|
||||
|
@ -1208,7 +1212,7 @@ def change_function():
|
|||
if not sizes:
|
||||
sizes = [1]
|
||||
steps = [1]
|
||||
if isinstance(value,int):
|
||||
if isinstance(value,int) or isinstance(value,float):
|
||||
value = jt.full(sizes,value)
|
||||
self.type_ = 'slicev2'
|
||||
attr_code = f"""
|
||||
|
@ -1231,10 +1235,7 @@ def change_function():
|
|||
result = result.squeeze(-1)
|
||||
return result
|
||||
def grad(self,grad_output):
|
||||
print("grad")
|
||||
#value_grad
|
||||
value_grad = grad_output[self.input_slice]
|
||||
#x_grad
|
||||
grad_output[self.input_slice] = jt.zeros(self.value_shape)
|
||||
return grad_output, None,value_grad
|
||||
def setitem(x, slices, value):
|
||||
|
@ -1266,13 +1267,20 @@ def change_function():
|
|||
output_dtypes=[x1.dtype],
|
||||
output_shapes=[grad_output.shape[:-1] + x1.shape[-1:]],
|
||||
attr_code="op.jt_name=\"bmm\";")[0]
|
||||
grad_x2 = acl_cmd(
|
||||
if self.trans_x2:
|
||||
output_shape = grad_output.shape[:-2] + grad_output.shape[-1:] + x1.shape[-1:]
|
||||
grad_x2 = acl_cmd(
|
||||
"BatchMatMul", [grad_output.transpose(-2, -1), x1],
|
||||
output_dtypes=[x2.dtype],
|
||||
output_shapes=[output_shape],
|
||||
attr_code="op.jt_name=\"bmm\";")[0]
|
||||
else:
|
||||
output_shape = x1.shape[:-2] + x1.shape[-1:] + grad_output.shape[-1:]
|
||||
grad_x2 = acl_cmd(
|
||||
"BatchMatMul", [x1.transpose(-2, -1), grad_output],
|
||||
output_dtypes=[x2.dtype],
|
||||
output_shapes=[x2.shape[:-1] + grad_output.shape[-1:]],
|
||||
output_shapes=[output_shape],
|
||||
attr_code="op.jt_name=\"bmm\";")[0]
|
||||
if self.trans_x2:
|
||||
grad_x2 = grad_x2.transpose(-2, -1)
|
||||
return grad_x1, grad_x2
|
||||
|
||||
def bmm(x1, x2):
|
||||
|
@ -1302,15 +1310,22 @@ def change_function():
|
|||
grad_x1 = acl_cmd(
|
||||
"MatMul", [grad_output, x2.transpose(-2, -1)],
|
||||
output_dtypes=[x1.dtype],
|
||||
output_shapes=[grad_output.shape[:-1] + x1.shape[-1:]],
|
||||
attr_code="op.jt_name=\"matmul\";")[0]
|
||||
grad_x2 = acl_cmd(
|
||||
"MatMul", [x1.transpose(-2, -1), grad_output],
|
||||
output_dtypes=[x2.dtype],
|
||||
output_shapes=[x2.shape[:-1] + grad_output.shape[-1:]],
|
||||
output_shapes=[grad_output.shape[:-1] + x2.shape[-2:-1]],
|
||||
attr_code="op.jt_name=\"matmul\";")[0]
|
||||
if self.trans_x2:
|
||||
grad_x2 = grad_x2.transpose(-2, -1)
|
||||
output_shape = grad_output.shape[:-2] + grad_output.shape[-1:] + x1.shape[-1:]
|
||||
grad_x2 = acl_cmd(
|
||||
"MatMul", [grad_output.transpose(-2, -1), x1],
|
||||
output_dtypes=[x2.dtype],
|
||||
output_shapes=[output_shape],
|
||||
attr_code="op.jt_name=\"matmul\";")[0]
|
||||
else:
|
||||
output_shape = x1.shape[:-2] + x1.shape[-1:] + grad_output.shape[-1:]
|
||||
grad_x2 = acl_cmd(
|
||||
"MatMul", [x1.transpose(-2, -1), grad_output],
|
||||
output_dtypes=[x2.dtype],
|
||||
output_shapes=[output_shape],
|
||||
attr_code="op.jt_name=\"matmul\";")[0]
|
||||
return grad_x1, grad_x2
|
||||
|
||||
def matmul(x1, x2):
|
||||
|
@ -1582,7 +1597,9 @@ def change_function():
|
|||
jt.getitem, getitem)(x, slices)
|
||||
|
||||
jt.setitem = warp(jt.setitem, setitem)
|
||||
jt.Var.setitem = lambda x, slices, value: warp(jt.setitem, setitem)(x, slices, value)
|
||||
jt.Var.setitem = lambda x, slices, value: warp(jt.Var.setitem, setitem)(x, slices, value)
|
||||
jt.Var.__setitem__ = lambda x, slices, value: warp(jt.Var.__setitem__, setitem)(x, slices, value)
|
||||
|
||||
jt.nn.bmm = warp(jt.nn.bmm, bmm)
|
||||
jt.bmm = warp(jt.bmm, bmm)
|
||||
jt.nn.matmul = warp(jt.matmul, matmul)
|
||||
|
|
Loading…
Reference in New Issue