mirror of https://github.com/Jittor/Jittor

update aclnn

This commit is contained in: parent b4244090ae, commit 21580ce80e

@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************

__version__ = '1.3.100.10'
__version__ = '1.3.200.110'
from jittor_utils import lock
with lock.lock_scope():
    ori_int = int

(File diff suppressed because it is too large.)
|
@ -149,28 +149,67 @@ def post_process():

def acl_cmd(name: str,
            inputs: list,
            output_dtypes: list,
            output_shapes: list,
            attr_code: str = ""):
            output_dtypes: list = None,
            output_shapes: list = None,
            attr_code: str = "",
            attr_header: str = "",
            outputs: list = None):
    # inputs: list,
    # output_dtypes: list,
    # output_shapes: list,
    # attr_code: str = ""):
    # input_code = ''
    # for i in range(len(inputs)):
    #     input_code += f"op.add(in{i}, true);\n"

    # output_code = ''
    # for i in range(len(output_dtypes)):
    #     output_code += f"op.add(out{i}, false);\n"

    # # read the tmp_file.cpp to the cuda_header
    # with open(
    #         "/home/ma-user/work/zy/JittorHW/python/jittor/extern/acl/tmp_file.cpp",
    #         "r") as f:
    #     cuda_header = f.read()
    # import jittor as jt
    # return jt.code(output_shapes,
    #                output_dtypes,
    #                inputs,
    #                cuda_header=cuda_header,
    #                cuda_src=f"""
    attr_header = "\nnamespace jittor{" + attr_header + "}\n"
    # print(attr_header)

    # read the tmp_file.cpp to the cuda_header
    with open(
            "/home/ma-user/work/zy/JittorHW/python/jittor/extern/acl/tmp_file.cpp",
            "r") as f:
        cuda_header = f.read()
    import jittor as jt
    outputs_ = []
    if outputs is not None:
        outputs_ = outputs
    else:
        assert output_dtypes is not None
        assert output_shapes is not None
        assert len(output_dtypes) == len(output_shapes)
        # print(f'{name} output_dtypes', output_dtypes)
        # print(f'{name} output_shapes', output_shapes)
        for i in range(len(output_shapes)):
            outputs_.append(jt.empty(output_shapes[i], output_dtypes[i]))
    # print(f'{name} outputs_', outputs_)
    input_code = ''
    for i in range(len(inputs)):
        input_code += f"op.add(in{i}, true);\n"

    output_code = ''
    for i in range(len(output_dtypes)):
    for i in range(len(outputs_)):
        output_code += f"op.add(out{i}, false);\n"

    # read the tmp_file.cpp to the cuda_header
    with open(
            "/home/ma-user/work/zy/jittor/python/jittor/extern/acl/tmp_file.cpp",
            "r") as f:
        cuda_header = f.read()
    import jittor as jt
    return jt.code(output_shapes,
                   output_dtypes,
                   inputs,
                   cuda_header=cuda_header,
    return jt.code(outputs=outputs_,
                   inputs=inputs,
                   cuda_header=attr_header + cuda_header,
                   cuda_src=f"""

    // aclop
    AclOpRunner op("{name}");
    {input_code}
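A minimal usage sketch (not part of the diff) of the two ways acl_cmd can now be called; it assumes Jittor is built with the ACL backend and mirrors the "Floor" call used by FloorIntACL later in this file.

import jittor as jt

x = jt.array([1.2, 2.7, -0.5])

# Old style: acl_cmd allocates outputs from output_dtypes / output_shapes.
y = acl_cmd("Floor", [x],
            output_dtypes=[x.dtype],
            output_shapes=[x.shape],
            attr_code='op.jt_name="floor";')[0]

# New style: pass pre-allocated outputs (used e.g. by the IndexPutImpl grad path).
out = jt.empty(x.shape, x.dtype)
y2 = acl_cmd("Floor", [x], outputs=[out],
             attr_code='op.jt_name="floor";')[0]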
@ -250,7 +289,7 @@ def change_function():
            output_shape = (x.shape[0], out_channels, output_height,
                            output_width)

            inputs = [x, weight]
            if bias is not None:
                inputs.append(bias)
@ -276,6 +315,9 @@
            if bias is not None:
                output_shapes.append(bias.shape)
                output_dtypes.append(bias.dtype)
            else:
                output_shapes.append([1])
                output_dtypes.append(x.dtype)
            padding = self.padding
            stride = self.stride
            dilation = self.dilation
@ -295,6 +337,8 @@
                              output_dtypes=output_dtypes,
                              output_shapes=output_shapes,
                              attr_code=attr_code)
            if self.bias is None:
                return results[0], results[1]

            return results

@ -398,6 +442,671 @@ def change_function():
|
|||
self.padding, self.dilation, self.groups)
|
||||
return ret
|
||||
|
||||
class PoolACL(Function):
|
||||
|
||||
def __init__(self,
|
||||
kernel_size,
|
||||
stride=None,
|
||||
padding=0,
|
||||
dilation=None,
|
||||
return_indices=None,
|
||||
ceil_mode=False,
|
||||
count_include_pad=True,
|
||||
op='maximum'):
|
||||
self.kernel_size = kernel_size if isinstance(
|
||||
kernel_size, tuple) else (kernel_size, kernel_size)
|
||||
stride = stride if stride else kernel_size
|
||||
self.stride = stride if isinstance(stride, tuple) else (stride,
|
||||
stride)
|
||||
self.padding = padding if isinstance(padding, tuple) else (padding,
|
||||
padding)
|
||||
dilation = dilation if dilation else 1
|
||||
assert dilation == 1
|
||||
self.dilation = dilation if isinstance(
|
||||
dilation, tuple) else (dilation, dilation)
|
||||
for item in self.kernel_size:
|
||||
if item <= 0:
|
||||
raise RuntimeError(
|
||||
f"kernel_size must be greater than zero, but got {item}"
|
||||
)
|
||||
for item in self.stride:
|
||||
if item <= 0:
|
||||
raise RuntimeError(
|
||||
f"stride must be greater than zero, but got {item}")
|
||||
for item in self.padding:
|
||||
if item < 0:
|
||||
raise RuntimeError(
|
||||
f"padding must be non-negative, but got {item}")
|
||||
self.op = op
|
||||
self.return_indices = return_indices
|
||||
self.ceil_mode = ceil_mode
|
||||
self.count_include_pad = count_include_pad
|
||||
|
||||
def execute(self, input):
|
||||
self.input = input
|
||||
attr_code = f"""
|
||||
op.jt_name = "maxpool";
|
||||
PoolAttr *attr = new PoolAttr();
|
||||
attr->kernel_size = {{ {self.kernel_size[0]}, {self.kernel_size[1]} }};
|
||||
attr->poolStrides = {{ {self.stride[0]}, {self.stride[1]} }};
|
||||
attr->poolPads = {{ {self.padding[0]}, {self.padding[1]} }};
|
||||
attr->poolDilations = {{ {self.dilation[0]}, {self.dilation[1]} }};
|
||||
attr->poolCeil = {"true" if self.ceil_mode else "false"};
|
||||
op.op_attr.reset(attr);
|
||||
"""
|
||||
input_height, input_width = input.shape[-2:]
|
||||
kernel_height, kernel_width = self.kernel_size[-2:]
|
||||
|
||||
output_height = (input_height + 2 * self.padding[0] -
|
||||
(kernel_height - 1) - 1) // self.stride[0] + 1
|
||||
output_width = (input_width + 2 * self.padding[1] -
|
||||
(kernel_width - 1) - 1) // self.stride[1] + 1
|
||||
|
||||
output_shape = (input.shape[0], input.shape[1], output_height,
|
||||
output_width)
|
||||
|
||||
inputs = [input]
|
||||
|
||||
if self.op == 'maximum':
|
||||
result = acl_cmd(
|
||||
"Maxpool",
|
||||
inputs,
|
||||
output_dtypes=[input.dtype, 'int32'],
|
||||
output_shapes=[output_shape, output_shape],
|
||||
attr_code=attr_code,
|
||||
)
|
||||
elif self.op == 'mean':
|
||||
result = acl_cmd(
|
||||
"Avgpool",
|
||||
inputs,
|
||||
output_dtypes=[input.dtype],
|
||||
output_shapes=[output_shape],
|
||||
attr_code=attr_code,
|
||||
)
|
||||
else:
|
||||
raise ValueError('unsupported pool type')
|
||||
|
||||
if self.op == 'maximum':
|
||||
self.index = result[1]
|
||||
|
||||
if self.return_indices:
|
||||
return result[0], result[1]
|
||||
else:
|
||||
return result[0]
|
||||
|
||||
def grad(self, grad_output, indices=None):
|
||||
input = self.input
|
||||
inputs = [grad_output, input, indices]
|
||||
attr_code = f"""
|
||||
op.jt_name = "maxpoolbackward";
|
||||
PoolAttr *attr = new PoolAttr();
|
||||
attr->kernel_size = {{ {self.kernel_size[0]}, {self.kernel_size[1]} }};
|
||||
attr->poolStrides = {{ {self.stride[0]}, {self.stride[1]} }};
|
||||
attr->poolPads = {{ {self.padding[0]}, {self.padding[1]} }};
|
||||
attr->poolDilations = {{ {self.dilation[0]}, {self.dilation[1]} }};
|
||||
attr->poolCeil = {"true" if self.ceil_mode else "false"};
|
||||
op.op_attr.reset(attr);
|
||||
"""
|
||||
output_shapes = [input.shape]
|
||||
output_dtypes = [input.dtype]
|
||||
result = acl_cmd("MaxpoolBackward",
|
||||
inputs,
|
||||
output_dtypes=output_dtypes,
|
||||
output_shapes=output_shapes,
|
||||
attr_code=attr_code)[0]
|
||||
return result
|
||||
|
||||
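A small worked example (not part of the diff) of the pooling output-size formula used in PoolACL.execute above, out = (in + 2*pad - (kernel - 1) - 1) // stride + 1, with dilation fixed to 1 as the constructor asserts:

def pool_out_size(in_size, kernel, stride, pad):
    return (in_size + 2 * pad - (kernel - 1) - 1) // stride + 1

# a 32x32 feature map, 2x2 kernel, stride 2, no padding -> 16x16
assert pool_out_size(32, 2, 2, 0) == 16
# a 7x7 map, 3x3 kernel, stride 2, padding 1 -> 4x4
assert pool_out_size(7, 3, 2, 1) == 4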
class FlipACL(Function):
|
||||
|
||||
def __init__(self):
|
||||
super(FlipACL, self).__init__()
|
||||
|
||||
def execute(self, input, dim):
|
||||
self.input = input
|
||||
attr_code = f"""
|
||||
op.jt_name = "flip";
|
||||
ReduceAttr *attr = new ReduceAttr();
|
||||
attr->axes = {{{', '.join(map(str, (list(dim))))}}};
|
||||
attr->prod_dim = {len(dim)};
|
||||
op.op_attr.reset(attr);
|
||||
"""
|
||||
self.attr_code = attr_code
|
||||
result = acl_cmd("Flip", [input],
|
||||
output_dtypes=[input.dtype],
|
||||
output_shapes=[input.shape],
|
||||
attr_code=self.attr_code)[0]
|
||||
return result
|
||||
|
||||
def grad(self, grad_output):
|
||||
grad_input = acl_cmd("Flip", [grad_output],
|
||||
output_dtypes=[grad_output.dtype],
|
||||
output_shapes=[grad_output.shape],
|
||||
attr_code=self.attr_code)[0]
|
||||
return grad_input
|
||||
|
||||
class ConcatACL(Function):
|
||||
|
||||
def __init__(self):
|
||||
super(ConcatACL, self).__init__()
|
||||
|
||||
def execute(self, input_tensors, dim=0):
|
||||
self.input = input_tensors
|
||||
self.dim = dim
|
||||
for i in range(len(input_tensors)):
|
||||
if input_tensors[i].dtype != input_tensors[0].dtype:
|
||||
raise ValueError(
|
||||
"All input tensors must have the same dtype")
|
||||
if input_tensors[i].shape[:dim] != input_tensors[
|
||||
0].shape[:dim] or input_tensors[i].shape[
|
||||
dim + 1:] != input_tensors[0].shape[dim + 1:]:
|
||||
raise ValueError(
|
||||
"All input tensors must have the same shape")
|
||||
attr_code = f"""
|
||||
op.jt_name = "concat";
|
||||
ConcatAttr *attr = new ConcatAttr();
|
||||
attr->tensorNum = {len(input_tensors)};
|
||||
attr->dim = {dim};
|
||||
op.op_attr.reset(attr);
|
||||
"""
|
||||
result = acl_cmd(
|
||||
"Concat",
|
||||
input_tensors,
|
||||
output_dtypes=[input_tensors[0].dtype],
|
||||
output_shapes=[
|
||||
jt.empty(self.calculate_output_shape(input_tensors,
|
||||
dim)).shape
|
||||
],
|
||||
attr_code=attr_code)[0]
|
||||
return result
|
||||
|
||||
"""def grad(self, grad_output):
|
||||
grad_inputs = self.split_grad(grad_output, self.input, self.axis)
|
||||
return grad_inputs"""
|
||||
|
||||
def calculate_output_shape(self, input_tensors, axis):
|
||||
shape = list(input_tensors[0].shape)
|
||||
for tensor in input_tensors[1:]:
|
||||
shape[axis] += tensor.shape[axis]
|
||||
return tuple(shape)
|
||||
|
||||
"""def split_grad(self, grad_output, input_tensors, axis):
|
||||
offset = 0
|
||||
grad_inputs = []
|
||||
for tensor in input_tensors:
|
||||
grad_input = acl_cmd("Slice", [
|
||||
grad_output, [0] * axis + [offset] + [0] *
|
||||
(len(tensor.shape) - axis - 1), tensor.shape
|
||||
])
|
||||
grad_inputs.append(grad_input)
|
||||
offset += tensor.shape[axis]
|
||||
return grad_inputs"""
|
||||
|
||||
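A tiny illustration (not part of the diff) of calculate_output_shape above: the concat dimension is summed while every other dimension must already match.

def concat_shape(shapes, axis):
    out = list(shapes[0])
    for s in shapes[1:]:
        out[axis] += s[axis]
    return tuple(out)

# concatenating (2, 3) and (4, 3) along dim 0 gives (6, 3)
assert concat_shape([(2, 3), (4, 3)], axis=0) == (6, 3)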
class GatherACL(Function):
|
||||
|
||||
def __init__(self):
|
||||
super(GatherACL, self).__init__()
|
||||
|
||||
def execute(self, input, dim, index):
|
||||
self.input = input
|
||||
self.dim = dim
|
||||
self.index = index
|
||||
attr_code = f"""
|
||||
op.jt_name = "gather";
|
||||
GatherAttr *attr = new GatherAttr();
|
||||
attr->dim = {dim};
|
||||
op.op_attr.reset(attr);
|
||||
"""
|
||||
result = acl_cmd("Gather", [input, index],
|
||||
output_dtypes=[input.dtype],
|
||||
output_shapes=[index.shape],
|
||||
attr_code=attr_code)[0]
|
||||
return result
|
||||
|
||||
def grad(self, grad_output):
|
||||
tmp = jt.zeros(self.index.shape, dtype=grad_output.dtype)
|
||||
attr_code = f"""
|
||||
op.jt_name = "scatter";
|
||||
ScatterAttr *attr = new ScatterAttr();
|
||||
attr->axis = {self.dim};
|
||||
attr->reduction = {1};
|
||||
op.op_attr.reset(attr);
|
||||
"""
|
||||
grad_input = acl_cmd("Scatter", [tmp, self.index, grad_output],
|
||||
output_dtypes=[grad_output.dtype],
|
||||
output_shapes=[tmp.shape],
|
||||
attr_code=attr_code)[0]
|
||||
return grad_input
|
||||
|
||||
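A plain-Python sketch (not part of the diff) of why GatherACL.grad scatters the upstream gradient back with reduction = 1 (add): positions that are gathered more than once must accumulate their incoming gradients.

def gather1d(x, index):
    return [x[i] for i in index]

def scatter_add1d(size, index, src):
    out = [0] * size
    for i, v in zip(index, src):
        out[i] += v
    return out

x = [10, 20, 30]
idx = [0, 2, 0]            # element 0 is gathered twice
g = [1, 1, 1]              # upstream gradient of the gathered values
assert gather1d(x, idx) == [10, 30, 10]
assert scatter_add1d(len(x), idx, g) == [2, 0, 1]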
class CumsumACL(Function):
|
||||
|
||||
def __init__(self):
|
||||
super(CumsumACL, self).__init__()
|
||||
|
||||
def execute(self, input, dim=-1):
|
||||
self.input = input
|
||||
self.dim = dim
|
||||
attr_code = f"""
|
||||
op.jt_name = "cumsum";
|
||||
GatherAttr *attr = new GatherAttr();
|
||||
attr->dim = {dim};
|
||||
op.op_attr.reset(attr);
|
||||
"""
|
||||
result = acl_cmd("Cumsum", [input],
|
||||
output_dtypes=[input.dtype],
|
||||
output_shapes=[input.shape],
|
||||
attr_code=attr_code)[0]
|
||||
return result
|
||||
|
||||
def grad(self, grad_output):
|
||||
cumsum_attr_code = f"""
|
||||
op.jt_name = "cumsum";
|
||||
GatherAttr *attr = new GatherAttr();
|
||||
attr->dim = {self.dim};
|
||||
op.op_attr.reset(attr);
|
||||
"""
|
||||
flip_attr_code = f"""
|
||||
op.jt_name = "flip";
|
||||
ReduceAttr *attr = new ReduceAttr();
|
||||
attr->axes = {{{self.dim}}};
|
||||
attr->prod_dim = {{{1}}};
|
||||
op.op_attr.reset(attr);
|
||||
"""
|
||||
flipped_grad_output = acl_cmd("Flip", [grad_output],
|
||||
output_dtypes=[grad_output.dtype],
|
||||
output_shapes=[grad_output.shape],
|
||||
attr_code=flip_attr_code)[0]
|
||||
cumulative_grad = acl_cmd("Cumsum", [flipped_grad_output],
|
||||
output_dtypes=[grad_output.dtype],
|
||||
output_shapes=[grad_output.shape],
|
||||
attr_code=cumsum_attr_code)[0]
|
||||
grad_input = acl_cmd("Flip", [cumulative_grad],
|
||||
output_dtypes=[grad_output.dtype],
|
||||
output_shapes=[grad_output.shape],
|
||||
attr_code=flip_attr_code)[0]
|
||||
return grad_input
|
||||
|
||||
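A plain-Python check (not part of the diff) of the flip / cumsum / flip identity that CumsumACL.grad implements: the gradient of cumsum applied to an upstream gradient g equals reversed(cumsum(reversed(g))), i.e. the tail sums of g.

from itertools import accumulate

def cumsum_grad(g):
    return list(accumulate(g[::-1]))[::-1]

g = [1.0, 2.0, 3.0]
# element i of the input contributes to outputs i..n-1, so its gradient is
# the tail sum of g starting at i
assert cumsum_grad(g) == [6.0, 5.0, 3.0]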
class IndexACL(Function):
|
||||
|
||||
def __init__(self):
|
||||
super(IndexACL, self).__init__()
|
||||
|
||||
def execute(self, inshape: list, dim=None, dtype="int32"):
|
||||
# zeros a tensor, shape is inshape, dtype is dtype
|
||||
if dim is None:
    dim = [i for i in range(len(inshape))]
elif isinstance(dim, int):
    dim = [dim]
|
||||
results = []
|
||||
for d in dim:
|
||||
max_len = inshape[d]
|
||||
tmp = jt.zeros(max_len, dtype=dtype)
|
||||
result = acl_cmd(
|
||||
"Index", [jt.Var(0), jt.Var(max_len),
|
||||
jt.Var(1)],
|
||||
output_dtypes=[tmp.dtype],
|
||||
output_shapes=[tmp.shape],
|
||||
attr_code="op.jt_name=\"index\";")[0]
|
||||
broadcast_dim = []
|
||||
for i in range(len(inshape)):
|
||||
if i != d:
|
||||
broadcast_dim.append(i)
|
||||
result = jt.broadcast(result,
|
||||
shape=inshape,
|
||||
dims=broadcast_dim)
|
||||
results.append(result)
|
||||
return tuple(results)
|
||||
|
||||
def grad(self, grad_output):
|
||||
return grad_output
|
||||
|
||||
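A plain-Python sketch (not part of the diff) of what IndexACL produces for a given inshape and dim: an index grid in which every entry holds its own coordinate along that dimension, here for inshape = (2, 3).

def index_grid(inshape, d):
    rows, cols = inshape
    return [[(r if d == 0 else c) for c in range(cols)] for r in range(rows)]

assert index_grid((2, 3), 0) == [[0, 0, 0], [1, 1, 1]]
assert index_grid((2, 3), 1) == [[0, 1, 2], [0, 1, 2]]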
class ScatterACL(Function):
|
||||
|
||||
def __init__(self):
|
||||
super(ScatterACL, self).__init__()
|
||||
|
||||
def __call__(self, input, dim, index, src, reduce='void'):
|
||||
return self.execute(input, dim, index, src, reduce)
|
||||
|
||||
def execute(self, input, dim, index, src, reduce='void'):
|
||||
self.input = input
|
||||
self.dim = dim
|
||||
self.index = index
|
||||
self.reduce = reduce
|
||||
attr_code = f"""
|
||||
op.jt_name = "scatter";
|
||||
ScatterAttr *attr = new ScatterAttr();
|
||||
attr->axis = {dim};
|
||||
attr->reduction = {1 if reduce == 'add' else 2 if reduce == 'mul' else 0};
|
||||
op.op_attr.reset(attr);
|
||||
"""
|
||||
result = acl_cmd("Scatter", [input, self.index, src],
|
||||
output_dtypes=[input.dtype],
|
||||
output_shapes=[input.shape],
|
||||
attr_code=attr_code)[0]
|
||||
return result
|
||||
|
||||
def grad(self, grad_output):
|
||||
attr_code = f"""
|
||||
op.jt_name = "gather";
|
||||
GatherAttr *attr = new GatherAttr();
|
||||
attr->dim = {self.dim};
|
||||
op.op_attr.reset(attr);
|
||||
"""
|
||||
grad_input = acl_cmd("Gather", [grad_output, self.index],
|
||||
output_dtypes=[grad_output.dtype],
|
||||
output_shapes=[self.index.shape],
|
||||
attr_code=attr_code)[0]
|
||||
return grad_output, None, None, grad_input
|
||||
|
||||
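A small reference (not part of the diff) for the reduction code that ScatterACL packs into ScatterAttr, matching the conditional expression in execute() above: 'void' maps to 0 (plain overwrite), 'add' to 1, 'mul' to 2.

def reduction_code(reduce):
    return 1 if reduce == 'add' else 2 if reduce == 'mul' else 0

assert reduction_code('void') == 0
assert reduction_code('add') == 1
assert reduction_code('mul') == 2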
class WhereACL(Function):
|
||||
|
||||
def __init__(self):
|
||||
super(WhereACL, self).__init__()
|
||||
|
||||
def execute(self, condition, x, y):
|
||||
self.condition = condition
|
||||
|
||||
if x.dtype != y.dtype:
|
||||
if x.dtype == jt.float32:
|
||||
y = y.float32()
|
||||
elif y.dtype == jt.float32:
|
||||
x = x.float32()
|
||||
else:
|
||||
x = x.to(y.dtype)
|
||||
|
||||
self.x = x
|
||||
self.y = y
|
||||
|
||||
result = acl_cmd("Where", [condition, x, y],
|
||||
output_dtypes=[x.dtype],
|
||||
output_shapes=[x.shape],
|
||||
attr_code="op.jt_name=\"where\";")[0]
|
||||
return result
|
||||
|
||||
def grad(self, grad_output):
|
||||
tmp = jt.zeros(grad_output.shape, dtype=grad_output.dtype)
|
||||
grad_x = acl_cmd("Where", [self.condition, grad_output, tmp],
|
||||
output_dtypes=[self.x.dtype],
|
||||
output_shapes=[self.x.shape],
|
||||
attr_code="op.jt_name=\"where\";")[0]
|
||||
|
||||
grad_y = acl_cmd("Where", [self.condition, tmp, grad_output],
|
||||
output_dtypes=[self.y.dtype],
|
||||
output_shapes=[self.y.shape],
|
||||
attr_code="op.jt_name=\"where\";")[0]
|
||||
return grad_output, grad_x, grad_y
|
||||
|
||||
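A plain-Python sketch (not part of the diff) of the masking that WhereACL.grad performs with two extra Where calls: the upstream gradient is routed to x where the condition holds and to y elsewhere.

def where_grad(cond, g):
    grad_x = [gi if c else 0.0 for c, gi in zip(cond, g)]
    grad_y = [0.0 if c else gi for c, gi in zip(cond, g)]
    return grad_x, grad_y

gx, gy = where_grad([True, False, True], [0.1, 0.2, 0.3])
assert gx == [0.1, 0.0, 0.3]
assert gy == [0.0, 0.2, 0.0]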
class FloorIntACL(Function):
|
||||
|
||||
def __init__(self):
|
||||
super(FloorIntACL, self).__init__()
|
||||
|
||||
def execute(self, input):
|
||||
self.input = input
|
||||
self.shape = input.shape
|
||||
result = acl_cmd("Floor", [input],
|
||||
output_dtypes=[input.dtype],
|
||||
output_shapes=[input.shape],
|
||||
attr_code="op.jt_name=\"floor\";")[0]
|
||||
return result
|
||||
|
||||
def grad(self, grad_output):
|
||||
return jt.zeros(self.shape, dtype=grad_output.dtype)
|
||||
|
||||
def caculate_shape(tensors):
|
||||
if isinstance(tensors, jt.Var):
|
||||
# tensors = tensors[0]
|
||||
return tensors.shape
|
||||
elif isinstance(tensors, (int, float)):
|
||||
return []
|
||||
elif isinstance(tensors, (list, tuple)):
|
||||
# return [caculate_shape(tensor) for tensor in tensors]
|
||||
sub_shape = caculate_shape(tensors[0])
|
||||
return [len(tensors)] + sub_shape
|
||||
else:
|
||||
assert False, f"not implemented for {type(tensors)}"
|
||||
|
||||
def can_broadcast_and_shape(shape1, shape2):
    """
    Check whether two tensor shapes can be broadcast together, and return the broadcast shape.

    Args:
    - shape1: shape of the first tensor (tuple or list)
    - shape2: shape of the second tensor (tuple or list)

    Returns:
    - can_broadcast: bool, whether the two shapes can be broadcast
    - broadcast_shape: the broadcast shape if broadcasting is possible, otherwise None
    """
    # Convert the shapes to tuples in case they were passed as lists.
    shape1 = tuple(shape1)
    shape2 = tuple(shape2)

    # Make both shapes the same length by left-padding the shorter one with 1s.
    len1, len2 = len(shape1), len(shape2)
    if len1 < len2:
        shape1 = (1, ) * (len2 - len1) + shape1
    elif len2 < len1:
        shape2 = (1, ) * (len1 - len2) + shape2

    broadcast_shape = []

    # Check each dimension pair.
    for dim1, dim2 in zip(shape1, shape2):
        if dim1 == dim2:
            broadcast_shape.append(dim1)
        elif dim1 == 1:
            broadcast_shape.append(dim2)
        elif dim2 == 1:
            broadcast_shape.append(dim1)
        else:
            # Incompatible along this dimension: cannot broadcast.
            return False, None

    return True, tuple(broadcast_shape)
|
||||
|
||||
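A usage sketch (not part of the diff) for the helper above, which follows the usual NumPy-style broadcasting rules:

ok, shape = can_broadcast_and_shape((3, 1, 5), (4, 5))
assert ok and shape == (3, 4, 5)

ok, shape = can_broadcast_and_shape((2, 3), (4, 3))
assert not ok and shape is None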
class GetItemACL(Function):
|
||||
|
||||
def __init__(self):
|
||||
self.type_ = 'notype'
|
||||
|
||||
def stride(self, x, dim):
|
||||
stride = 1
|
||||
for i in range(dim + 1, len(x.shape)):
|
||||
stride *= x.shape[i]
|
||||
return stride
|
||||
|
||||
def execute(self, x, slices, return_x=None):
|
||||
self.x_shape = x.shape
|
||||
if not isinstance(slices, tuple):
|
||||
slices = (slices, )
|
||||
slices_list = list(slices)
|
||||
# if not isinstance(slices[0], slice):
|
||||
# check whether slices contains any slice object
|
||||
contains_slice = False
|
||||
for s in slices:
|
||||
if isinstance(s, slice):
|
||||
contains_slice = True
|
||||
break
|
||||
if not contains_slice:
|
||||
indices = []
|
||||
output_shape = []
|
||||
slices_len = len(slices)
|
||||
boardcast_shape = caculate_shape(slices_list[0])
|
||||
for ii in range(1, len(slices)):
|
||||
dd, boardcast_shape = can_broadcast_and_shape(
|
||||
boardcast_shape, caculate_shape(slices_list[ii]))
|
||||
assert dd is True, "can not broadcast"
|
||||
output_shape = boardcast_shape
|
||||
output_shape += x.shape[slices_len:]
|
||||
for ii in slices:
|
||||
indices.append(jt.Var(ii))
|
||||
if isinstance(slices[0], jt.Var) or isinstance(
|
||||
slices[0], int) or isinstance(
|
||||
slices[0], list) or isinstance(slices[0], tuple):
|
||||
self.indices = indices
|
||||
inputs = [x] + indices
|
||||
attr_code = f"""
|
||||
op.jt_name = "index";
|
||||
"""
|
||||
self.type_ = 'index'
|
||||
result = acl_cmd("Index",
|
||||
inputs=inputs,
|
||||
output_dtypes=[x.dtype],
|
||||
output_shapes=[output_shape],
|
||||
attr_code=attr_code)[0]
|
||||
return result
|
||||
|
||||
x_dim = len(x.shape)
|
||||
if len(slices) < x_dim:
|
||||
slices += (slice(None, None, None), ) * (x_dim - len(slices))
|
||||
inputs = [x]
|
||||
sizes = []
|
||||
begins = []
|
||||
ends = []
|
||||
steps = []
|
||||
dims = []
|
||||
squeeze_dims = []
|
||||
for dim, s in enumerate(slices):
|
||||
if isinstance(s, int):
|
||||
s = slice(s, s + 1, 1)
|
||||
squeeze_dims.append(dim)
|
||||
if isinstance(s, jt.Var):
|
||||
assert False, "jt.Var not supported"
|
||||
start, stop, step = s.indices(x.size(dim))
|
||||
size = (stop - start - 1) // step + 1
|
||||
stride = self.stride(x, dim) * step
|
||||
sizes.append(size)
|
||||
steps.append(step)
|
||||
begins.append(start)
|
||||
ends.append(stop)
|
||||
dims.append(dim)
|
||||
if not sizes:
|
||||
sizes = [1]
|
||||
steps = [1]
|
||||
self.type_ = 'slicev2'
|
||||
self.begins = begins
|
||||
self.ends = ends
|
||||
self.steps = steps
|
||||
self.dims = dims
|
||||
attr_code = f"""
|
||||
op.jt_name = "slicev2";
|
||||
StrideAttr *attr = new StrideAttr();
|
||||
attr->begins = {{ {", ".join(map(str, begins))} }};
|
||||
attr->ends = {{ {", ".join(map(str, ends))} }};
|
||||
attr->steps = {{ {", ".join(map(str, steps))} }};
|
||||
attr->axes = {{ {", ".join(map(str, dims))} }};
|
||||
op.op_attr.reset(attr);
|
||||
"""
|
||||
|
||||
result = acl_cmd("SliceV2",
|
||||
inputs,
|
||||
output_dtypes=[x.dtype],
|
||||
output_shapes=[jt.empty(sizes).shape],
|
||||
attr_code=attr_code)[0]
|
||||
for dim in squeeze_dims[::-1]:
|
||||
result = jt.squeeze(result, dim)
|
||||
return result
|
||||
|
||||
def grad(self, grad_output):
|
||||
if self.type_ == 'index':
|
||||
indices = self.indices
|
||||
inputs = [grad_output] + indices
|
||||
attr_code = f"""
|
||||
op.jt_name = "indexputimpl";
|
||||
"""
|
||||
outputs = [jt.zeros(self.x_shape)]
|
||||
result = acl_cmd("IndexPutImpl",
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
attr_code=attr_code)[0]
|
||||
return result
|
||||
elif self.type_ == 'slicev2':
|
||||
# TODO: waiting for CANN update
assert False, "waiting for CANN update"
|
||||
begins = self.begins
|
||||
ends = self.ends
|
||||
steps = self.steps
|
||||
dims = self.dims
|
||||
begins = jt.Var(begins).int64()
|
||||
ends = jt.Var(ends).int64()
|
||||
steps = jt.Var(steps).int64()
|
||||
dims = jt.Var(dims).int64()
|
||||
inputs = [grad_output, begins, ends, steps, dims]
|
||||
attr_code = f"""
|
||||
op.jt_name = "stridedsliceassignv2";
|
||||
"""
|
||||
result = acl_cmd("StridedSliceAssignV2",
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
attr_code=attr_code)[0]
|
||||
return result
|
||||
else:
|
||||
assert False, f"grad not implemented for {self.type_}"
|
||||
|
||||
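A plain-Python sketch (not part of the diff) of how GetItemACL.execute maps a Python slice to the StrideAttr fields consumed by SliceV2, using slice.indices() and the same size formula as the code above.

def slice_to_attr(s, dim_size):
    start, stop, step = s.indices(dim_size)
    size = (stop - start - 1) // step + 1
    return start, stop, step, size

# x[2:9:3] on a length-10 axis -> begin 2, end 9, step 3, output size 3
assert slice_to_attr(slice(2, 9, 3), 10) == (2, 9, 3, 3)
# a full slice keeps the whole axis
assert slice_to_attr(slice(None), 4) == (0, 4, 1, 4)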
class BmmACL(Function):
|
||||
|
||||
def __init__(self, trans_x2=False):
|
||||
super(BmmACL, self).__init__()
|
||||
self.trans_x2 = trans_x2
|
||||
|
||||
def execute(self, x1, x2):
|
||||
if self.trans_x2:
|
||||
x2 = x2.transpose(-2, -1)
|
||||
self.input = [x1, x2]
|
||||
result = acl_cmd("BatchMatMul", [x1, x2],
|
||||
output_dtypes=[x1.dtype],
|
||||
output_shapes=[x1.shape[:-1] + x2.shape[-1:]],
|
||||
attr_code="op.jt_name=\"bmm\";")[0]
|
||||
|
||||
return result
|
||||
|
||||
def grad(self, grad_output):
|
||||
x1, x2 = self.input
|
||||
grad_x1 = acl_cmd(
|
||||
"BatchMatMul", [grad_output, x2.transpose(-2, -1)],
|
||||
output_dtypes=[x1.dtype],
|
||||
output_shapes=[grad_output.shape[:-1] + x1.shape[-1:]],
|
||||
attr_code="op.jt_name=\"bmm\";")[0]
|
||||
x2 = x2.transpose(-2, -1)
|
||||
grad_x2 = acl_cmd(
|
||||
"BatchMatMul", [x1.transpose(-2, -1), grad_output],
|
||||
output_dtypes=[x2.dtype],
|
||||
output_shapes=[x2.shape[:-1] + grad_output.shape[-1:]],
|
||||
attr_code="op.jt_name=\"bmm\";")[0]
|
||||
x1 = x1.transpose(-2, -1)
|
||||
return grad_x1, grad_x2
|
||||
|
||||
class MatmulACL(Function):
|
||||
|
||||
def __init__(self, trans_x2=False):
|
||||
super(MatmulACL, self).__init__()
|
||||
self.trans_x2 = trans_x2
|
||||
|
||||
def execute(self, x1, x2):
|
||||
if self.trans_x2:
|
||||
x2 = x2.transpose(-2, -1)
|
||||
self.input = [x1, x2]
|
||||
result = acl_cmd("MatMul", [x1, x2],
|
||||
output_dtypes=[x1.dtype],
|
||||
output_shapes=[x1.shape[:-1] + x2.shape[-1:]],
|
||||
attr_code="op.jt_name=\"matmul\";")[0]
|
||||
return result
|
||||
|
||||
def grad(self, grad_output):
|
||||
x1, x2 = self.input
|
||||
grad_x1 = acl_cmd(
|
||||
"MatMul", [grad_output, x2.transpose(-2, -1)],
|
||||
output_dtypes=[x1.dtype],
|
||||
output_shapes=[grad_output.shape[:-1] + x1.shape[-1:]],
|
||||
attr_code="op.jt_name=\"matmul\";")[0]
|
||||
grad_x2 = acl_cmd(
|
||||
"MatMul", [x1.transpose(-2, -1), grad_output],
|
||||
output_dtypes=[x2.dtype],
|
||||
output_shapes=[x2.shape[:-1] + grad_output.shape[-1:]],
|
||||
attr_code="op.jt_name=\"matmul\";")[0]
|
||||
return grad_x1, grad_x2
|
||||
|
||||
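A plain-Python check (not part of the diff) of the gradient identities that BmmACL.grad and MatmulACL.grad implement with two extra matmuls: dX1 = dC @ X2^T and dX2 = X1^T @ dC.

def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def transpose(a):
    return [list(r) for r in zip(*a)]

x1 = [[1.0, 2.0]]          # 1x2
x2 = [[3.0], [4.0]]        # 2x1
dc = [[1.0]]               # upstream gradient of the 1x1 product
assert matmul(dc, transpose(x2)) == [[3.0, 4.0]]    # dX1 has x1's shape
assert matmul(transpose(x1), dc) == [[1.0], [2.0]]  # dX2 has x2's shape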
def warp(origin_func, new_func):
|
||||
|
||||
def warpper(*args, **kwargs):
|
||||
|
@ -414,3 +1123,34 @@
    jt.nn.conv2d = warp(jt.nn.conv2d, ConvACL())
    jt.nn.Conv2d = warp(jt.nn.Conv2d, Conv2D)
    jt.nn.Conv = warp(jt.nn.Conv, Conv2D)
    jt.nn.Pool = warp(jt.nn.Pool, PoolACL)

    jt.flip = warp(jt.flip, FlipACL())
    jt.Var.flip = lambda x, dim_vector: warp(jt.Var.flip, FlipACL())(
        x, dim_vector)
    jt.concat = warp(jt.concat, ConcatACL())

    jt.gather = warp(jt.gather, GatherACL())

    jt.cumsum = warp(jt.cumsum, CumsumACL())
    # jt.index = warp(jt.index, IndexACL())
    # jt.Var.index = lambda x, dim=None: warp(jt.index, IndexACL())(x.shape, dim)

    jt.scatter = warp(jt.scatter, ScatterACL())
    jt.Var.scatter = lambda x, dim, index, src, reduce="void": warp(
        jt.scatter, ScatterACL())(x, dim, index, src, reduce)

    jt.floor_int = warp(jt.floor_int, FloorIntACL())
    jt.Var.floor_int = lambda x: warp(jt.floor_int, FloorIntACL())(x)

    jt.getitem = warp(jt.getitem, GetItemACL())
    jt.Var.getitem = lambda x, slices, return_x=None: warp(
        jt.getitem, GetItemACL())(x, slices)

    jt.nn.bmm = warp(jt.nn.bmm, BmmACL())
    jt.bmm = warp(jt.bmm, BmmACL())
    jt.nn.matmul = warp(jt.matmul, MatmulACL())
    jt.matmul = warp(jt.matmul, MatmulACL())
    jt.nn.matmul_transpose = warp(jt.nn.matmul_transpose, MatmulACL(True))
    jt.nn.bmm_transpose = warp(jt.nn.bmm_transpose, BmmACL(True))
    jt.bmm_transpose = warp(jt.bmm_transpose, BmmACL(True))

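Once change_function() has installed these wrappers, ordinary Jittor calls are routed through the ACL implementations above. A hedged usage sketch (not part of the diff; requires an Ascend/ACL build of Jittor):

import jittor as jt

a = jt.rand(4, 5)
b = jt.rand(5, 6)
c = jt.matmul(a, b)           # dispatched to MatmulACL via warp()
d = jt.concat([a, a], dim=0)  # dispatched to ConcatACL
print(c.shape, d.shape)       # c: (4, 6), d: (8, 5)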
@ -10,6 +10,7 @@
|
|||
#include "utils/str_utils.h"
|
||||
#include <chrono>
|
||||
#include <thread>
|
||||
#include "aclnn/aclnn.h"
|
||||
|
||||
namespace jittor
|
||||
{
|
||||
|
@ -17,9 +18,23 @@ namespace jittor
|
|||
uint64_t acl_jittor_tid;
|
||||
int acl_jittor_thread_running = 0;
|
||||
aclrtStream aclstream;
|
||||
void *workspaceAddr = nullptr;
|
||||
uint64_t nowWorkSpaceSize = 0;
|
||||
|
||||
#define CHECK_ACL(x) ASSERTop(x, ==, 0)
|
||||
|
||||
void mallocWorkSpace(uint64_t size)
|
||||
{
|
||||
uint64_t alloc_size = size + 32;
|
||||
alloc_size = ((alloc_size - 1) / 32 + 1) * 32;
|
||||
if (alloc_size > nowWorkSpaceSize)
|
||||
{
|
||||
aclrtFree(workspaceAddr);
|
||||
nowWorkSpaceSize = alloc_size;
|
||||
auto ret = aclrtMalloc(&workspaceAddr, nowWorkSpaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
|
||||
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("allocate workspace failed. ERROR: %d\n", ret); return);
|
||||
}
|
||||
}
|
||||
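A plain-Python sketch (not part of the diff) of the rounding done by mallocWorkSpace above: the requested size gets 32 bytes of headroom, is rounded up to a multiple of 32, and the cached buffer only grows, never shrinks.

def aligned_workspace(size):
    alloc = size + 32
    return ((alloc - 1) // 32 + 1) * 32

assert aligned_workspace(0) == 32
assert aligned_workspace(1) == 64
assert aligned_workspace(100) == 160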
static void *acl_jittor_process_callback(void *)
|
||||
{
|
||||
acl_jittor_thread_running = 1;
|
||||
|
@ -63,6 +78,10 @@ namespace jittor
|
|||
aclrtDestroyStream(aclstream);
|
||||
aclrtResetDevice(deviceId);
|
||||
CHECK_ACL(aclFinalize());
|
||||
if (nowWorkSpaceSize > 0)
|
||||
{
|
||||
aclrtFree(workspaceAddr);
|
||||
}
|
||||
}
|
||||
|
||||
} _acl_jittor_initer;
|
||||
|
|
|
@ -15,6 +15,9 @@ namespace jittor
|
|||
|
||||
EXTERN_LIB uint64_t acl_jittor_tid;
|
||||
EXTERN_LIB aclrtStream aclstream;
|
||||
EXTERN_LIB void *workspaceAddr;
|
||||
|
||||
void mallocWorkSpace(uint64_t size);
|
||||
|
||||
void acl_jittor_op_compiler(string &filename, string &src, bool is_acl, string &extra_flags);
|
||||
|
||||
|
@ -28,7 +31,7 @@ namespace jittor
|
|||
std::function<aclnnStatus(aclTensor *, aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncBinary;
|
||||
// for Add and Sub
|
||||
std::function<aclnnStatus(aclTensor *, aclTensor *, aclScalar *, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncAdd;
|
||||
// for Expand and permute
|
||||
// for Expand, permute, flip
|
||||
std::function<aclnnStatus(aclTensor *, aclIntArray *, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncExpand;
|
||||
// for bmm and matmul
|
||||
std::function<aclnnStatus(aclTensor *, aclTensor *, aclTensor *, int8_t, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncMatmul;
|
||||
|
@ -42,10 +45,32 @@ namespace jittor
|
|||
std::function<aclnnStatus(aclTensor *, aclTensor *, aclTensor *, aclIntArray *, aclIntArray *, aclIntArray *, aclIntArray *, bool, aclIntArray *, int, aclBoolArray *, int8_t, aclTensor *, aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncConvBackward;
|
||||
// for proddim
|
||||
std::function<aclnnStatus(aclTensor *, int64_t, bool, aclDataType, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncProdDim;
|
||||
// for select
|
||||
// for select, where
|
||||
std::function<aclnnStatus(aclTensor *, aclTensor *, aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncSelect;
|
||||
// for random_uniform and random_normal
|
||||
std::function<aclnnStatus(aclTensor *, int64_t, int64_t, int64_t, int64_t, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncRandom;
|
||||
// for maxpool
|
||||
std::function<aclnnStatus(aclTensor *, aclIntArray *, aclIntArray *, aclIntArray *, aclIntArray *, bool, aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncMaxPool;
|
||||
// for maxpool backward
|
||||
std::function<aclnnStatus(aclTensor *, aclTensor *, aclTensor *, aclIntArray *, aclIntArray *, aclIntArray *, aclIntArray *, bool, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncMaxPoolBackward;
|
||||
// for avgpool
|
||||
std::function<aclnnStatus(aclTensor *, aclIntArray *, aclIntArray *, aclIntArray *, bool, bool, int64_t, int8_t, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncAvgPool;
|
||||
// for concat
|
||||
std::function<aclnnStatus(aclTensorList *, uint64_t, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncConcat;
|
||||
// for gather
|
||||
std::function<aclnnStatus(aclTensor *, uint64_t, aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncGather;
|
||||
// for cumsum
|
||||
std::function<aclnnStatus(aclTensor *, uint64_t, aclDataType, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncCumsum;
|
||||
// for scatter
|
||||
std::function<aclnnStatus(aclTensor *, uint64_t, aclTensor *, aclTensor *, uint64_t, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncScatter;
|
||||
// for index
|
||||
std::function<aclnnStatus(aclTensor *, aclTensorList *, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncIndex;
|
||||
// for stridesliceassign
|
||||
std::function<aclnnStatus(aclTensor *, aclTensor *, aclTensor *, aclTensor *, aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncStridedSliceAssignV2;
|
||||
// for slicev2
|
||||
std::function<aclnnStatus(aclTensor *, aclIntArray *, aclIntArray *, aclIntArray *, aclIntArray *, aclTensor *, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncSliceV2;
|
||||
// for indexputimpl
|
||||
std::function<aclnnStatus(aclTensor *, aclTensorList *, aclTensor *, bool, bool, uint64_t *, aclOpExecutor **)> getWorkspaceSizeFuncIndexPutImpl;
|
||||
|
||||
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, aclrtStream)> executeFunc;
|
||||
|
||||
|
@ -71,7 +96,7 @@ namespace jittor
|
|||
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
|
||||
: getWorkspaceSizeFuncAdd(gwsf), executeFunc(execf) {}
|
||||
|
||||
// for Expand
|
||||
// for Expand, flip
|
||||
AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclIntArray *, aclTensor *, uint64_t *, aclOpExecutor **)> gwsf,
|
||||
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
|
||||
: getWorkspaceSizeFuncExpand(gwsf), executeFunc(execf) {}
|
||||
|
@ -106,7 +131,7 @@ namespace jittor
|
|||
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
|
||||
: getWorkspaceSizeFuncProdDim(gwsf), executeFunc(execf) {}
|
||||
|
||||
// for select
|
||||
// for select, where
|
||||
AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclTensor *, aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)> gwsf,
|
||||
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
|
||||
: getWorkspaceSizeFuncSelect(gwsf), executeFunc(execf) {}
|
||||
|
@ -115,6 +140,61 @@ namespace jittor
|
|||
AclOpFunctions(std::function<aclnnStatus(aclTensor *, int64_t, int64_t, int64_t, int64_t, uint64_t *, aclOpExecutor **)> gwsf,
|
||||
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
|
||||
: getWorkspaceSizeFuncRandom(gwsf), executeFunc(execf) {}
|
||||
|
||||
// for maxpool
|
||||
AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclIntArray *, aclIntArray *, aclIntArray *, aclIntArray *, bool, aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)> gwsf,
|
||||
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
|
||||
: getWorkspaceSizeFuncMaxPool(gwsf), executeFunc(execf) {}
|
||||
|
||||
// for maxpool backward
|
||||
AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclTensor *, aclTensor *, aclIntArray *, aclIntArray *, aclIntArray *, aclIntArray *, bool, aclTensor *, uint64_t *, aclOpExecutor **)> gwsf,
|
||||
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
|
||||
: getWorkspaceSizeFuncMaxPoolBackward(gwsf), executeFunc(execf) {}
|
||||
|
||||
// for avgpool
|
||||
AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclIntArray *, aclIntArray *, aclIntArray *, bool, bool, int64_t, int8_t, aclTensor *, uint64_t *, aclOpExecutor **)> gwsf,
|
||||
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
|
||||
: getWorkspaceSizeFuncAvgPool(gwsf), executeFunc(execf) {}
|
||||
|
||||
// for concat
|
||||
AclOpFunctions(std::function<aclnnStatus(aclTensorList *, int64_t, aclTensor *, uint64_t *, aclOpExecutor **)> gwsf,
|
||||
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
|
||||
: getWorkspaceSizeFuncConcat(gwsf), executeFunc(execf) {}
|
||||
|
||||
// for gather
|
||||
AclOpFunctions(std::function<aclnnStatus(aclTensor *, int64_t, aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)> gwsf,
|
||||
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
|
||||
: getWorkspaceSizeFuncGather(gwsf), executeFunc(execf) {}
|
||||
|
||||
// for cumsum
|
||||
AclOpFunctions(std::function<aclnnStatus(aclTensor *, int64_t, aclDataType, aclTensor *, uint64_t *, aclOpExecutor **)> gwsf,
|
||||
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
|
||||
: getWorkspaceSizeFuncCumsum(gwsf), executeFunc(execf) {}
|
||||
|
||||
// for scatter
|
||||
AclOpFunctions(std::function<aclnnStatus(aclTensor *, uint64_t, aclTensor *, aclTensor *, uint64_t, aclTensor *, uint64_t *, aclOpExecutor **)> gwsf,
|
||||
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
|
||||
: getWorkspaceSizeFuncScatter(gwsf), executeFunc(execf) {}
|
||||
|
||||
// for index
|
||||
AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclTensorList *, aclTensor *, uint64_t *, aclOpExecutor **)> gwsf,
|
||||
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
|
||||
: getWorkspaceSizeFuncIndex(gwsf), executeFunc(execf) {}
|
||||
|
||||
// for stridesliceassignv2
|
||||
AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclTensor *, aclTensor *, aclTensor *, aclTensor *, aclTensor *, uint64_t *, aclOpExecutor **)> gwsf,
|
||||
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
|
||||
: getWorkspaceSizeFuncStridedSliceAssignV2(gwsf), executeFunc(execf) {}
|
||||
|
||||
// for slicev2
|
||||
AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclIntArray *, aclIntArray *, aclIntArray *, aclIntArray *, aclTensor *, uint64_t *, aclOpExecutor **)> gwsf,
|
||||
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
|
||||
: getWorkspaceSizeFuncSliceV2(gwsf), executeFunc(execf) {}
|
||||
|
||||
// for indexputimpl
|
||||
AclOpFunctions(std::function<aclnnStatus(aclTensor *, aclTensorList *, aclTensor *, bool, bool, uint64_t *, aclOpExecutor **)> gwsf,
|
||||
std::function<aclnnStatus(void *, uint64_t, aclOpExecutor *, const aclrtStream)> execf)
|
||||
: getWorkspaceSizeFuncIndexPutImpl(gwsf), executeFunc(execf) {}
|
||||
};
|
||||
|
||||
static std::unordered_map<std::string, AclOpFunctions> aclOpFuncMap = {
|
||||
|
@ -180,6 +260,19 @@ namespace jittor
|
|||
{"RandomUniform", AclOpFunctions(aclnnInplaceRandomGetWorkspaceSize, aclnnInplaceRandom)},
|
||||
{"RandomNormal", AclOpFunctions(aclnnInplaceNormalGetWorkspaceSize, aclnnInplaceNormal)},
|
||||
{"Transpose", AclOpFunctions(aclnnPermuteGetWorkspaceSize, aclnnPermute)},
|
||||
{"Maxpool", AclOpFunctions(aclnnMaxPool2dWithIndicesGetWorkspaceSize, aclnnMaxPool2dWithIndices)},
|
||||
{"MaxpoolBackward", AclOpFunctions(aclnnMaxPool2dWithIndicesBackwardGetWorkspaceSize, aclnnMaxPool2dWithIndicesBackward)},
|
||||
{"Flip", AclOpFunctions(aclnnFlipGetWorkspaceSize, aclnnFlip)},
|
||||
{"Concat", AclOpFunctions(aclnnCatGetWorkspaceSize, aclnnCat)},
|
||||
{"Gather", AclOpFunctions(aclnnGatherGetWorkspaceSize, aclnnGather)},
|
||||
{"Cumsum", AclOpFunctions(aclnnCumsumGetWorkspaceSize, aclnnCumsum)},
|
||||
{"Index", AclOpFunctions(aclnnIndexGetWorkspaceSize, aclnnIndex)},
|
||||
{"Scatter", AclOpFunctions(aclnnScatterGetWorkspaceSize, aclnnScatter)},
|
||||
{"Where", AclOpFunctions(aclnnSWhereGetWorkspaceSize, aclnnSWhere)},
|
||||
{"Floor", AclOpFunctions(aclnnFloorGetWorkspaceSize, aclnnFloor)},
|
||||
{"StridedSliceAssignV2", AclOpFunctions(aclnnStridedSliceAssignV2GetWorkspaceSize, aclnnStridedSliceAssignV2)},
|
||||
{"SliceV2", AclOpFunctions(aclnnSliceV2GetWorkspaceSize, aclnnSliceV2)},
|
||||
{"IndexPutImpl", AclOpFunctions(aclnnIndexPutImplGetWorkspaceSize, aclnnIndexPutImpl)},
|
||||
};
|
||||
|
||||
struct AclOpAttr
|
||||
|
@ -238,4 +331,66 @@ namespace jittor
|
|||
}
|
||||
};
|
||||
|
||||
struct PoolAttr : AclOpAttr
|
||||
{
|
||||
vector<int64_t> kernel_size;
|
||||
vector<int64_t> poolStrides;
|
||||
vector<int64_t> poolPads;
|
||||
vector<int64_t> poolDilations;
|
||||
bool poolCeil;
|
||||
|
||||
// destructor
|
||||
~PoolAttr()
|
||||
{
|
||||
kernel_size.clear();
|
||||
poolStrides.clear();
|
||||
poolPads.clear();
|
||||
poolDilations.clear();
|
||||
}
|
||||
};
|
||||
|
||||
struct ConcatAttr : AclOpAttr
|
||||
{
|
||||
int64_t tensorNum;
|
||||
int64_t dim;
|
||||
|
||||
~ConcatAttr()
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
struct GatherAttr : AclOpAttr
|
||||
{
|
||||
int64_t dim;
|
||||
|
||||
~GatherAttr()
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
struct ScatterAttr : AclOpAttr
|
||||
{
|
||||
int64_t axis;
|
||||
int64_t reduction;
|
||||
|
||||
~ScatterAttr()
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
struct StrideAttr : AclOpAttr
|
||||
{
|
||||
vector<int64_t> begins;
|
||||
vector<int64_t> ends;
|
||||
vector<int64_t> steps;
|
||||
vector<int64_t> axes;
|
||||
~StrideAttr()
|
||||
{
|
||||
begins.clear();
|
||||
ends.clear();
|
||||
steps.clear();
|
||||
axes.clear();
|
||||
}
|
||||
};
|
||||
|
||||
}
|
|
@ -169,9 +169,6 @@ namespace jittor
|
|||
// for expand
|
||||
aclIntArray *size = nullptr;
|
||||
|
||||
// for add and sub
|
||||
float alphaValue = 1.0f;
|
||||
|
||||
// for conv
|
||||
aclIntArray *strides = nullptr;
|
||||
aclIntArray *pads = nullptr;
|
||||
|
@ -179,13 +176,74 @@ namespace jittor
|
|||
aclIntArray *dilations = nullptr;
|
||||
int ret = -1;
|
||||
|
||||
// for maxpool
|
||||
aclIntArray *kernel_size = nullptr;
|
||||
|
||||
// for concat
|
||||
aclTensorList *tensor_list = nullptr;
|
||||
|
||||
if (name == string("Add") || name == string("Sub"))
|
||||
{
|
||||
alpha = aclCreateScalar(&alphaValue, get_dtype(in_[0]->dtype()));
|
||||
|
||||
if (get_dtype(in_[0]->dtype()) == ACL_FLOAT)
|
||||
{
|
||||
float alphaValue = 1.0;
|
||||
alpha = aclCreateScalar(&alphaValue, get_dtype(in_[0]->dtype()));
|
||||
}
|
||||
else if (get_dtype(in_[0]->dtype()) == ACL_FLOAT16)
|
||||
{
|
||||
float alphaValue = 1.0;
|
||||
alpha = aclCreateScalar(&alphaValue, get_dtype(in_[0]->dtype()));
|
||||
}
|
||||
else if (get_dtype(in_[0]->dtype()) == ACL_INT64)
|
||||
{
|
||||
int64_t alphaValue = 1;
|
||||
alpha = aclCreateScalar(&alphaValue, get_dtype(in_[0]->dtype()));
|
||||
}
|
||||
else if (get_dtype(in_[0]->dtype()) == ACL_INT32)
|
||||
{
|
||||
int alphaValue = 1;
|
||||
alpha = aclCreateScalar(&alphaValue, get_dtype(in_[0]->dtype()));
|
||||
}
|
||||
else if (get_dtype(in_[0]->dtype()) == ACL_INT8)
|
||||
{
|
||||
int8_t alphaValue = 1;
|
||||
alpha = aclCreateScalar(&alphaValue, get_dtype(in_[0]->dtype()));
|
||||
}
|
||||
else if (get_dtype(in_[0]->dtype()) == ACL_INT16)
|
||||
{
|
||||
int16_t alphaValue = 1;
|
||||
alpha = aclCreateScalar(&alphaValue, get_dtype(in_[0]->dtype()));
|
||||
}
|
||||
else if (get_dtype(in_[0]->dtype()) == ACL_UINT8)
|
||||
{
|
||||
uint8_t alphaValue = 1;
|
||||
alpha = aclCreateScalar(&alphaValue, get_dtype(in_[0]->dtype()));
|
||||
}
|
||||
else if (get_dtype(in_[0]->dtype()) == ACL_UINT16)
|
||||
{
|
||||
uint16_t alphaValue = 1;
|
||||
alpha = aclCreateScalar(&alphaValue, get_dtype(in_[0]->dtype()));
|
||||
}
|
||||
else if (get_dtype(in_[0]->dtype()) == ACL_UINT32)
|
||||
{
|
||||
uint32_t alphaValue = 1;
|
||||
alpha = aclCreateScalar(&alphaValue, get_dtype(in_[0]->dtype()));
|
||||
}
|
||||
else if (get_dtype(in_[0]->dtype()) == ACL_BOOL)
|
||||
{
|
||||
bool alphaValue = true;
|
||||
alpha = aclCreateScalar(&alphaValue, get_dtype(in_[0]->dtype()));
|
||||
}
|
||||
else
|
||||
{
|
||||
LOGf << "Not supported dtype: " << in_[0]->dtype();
|
||||
}
|
||||
|
||||
CHECK_RET(alpha != nullptr, return);
|
||||
}
|
||||
|
||||
if (jt_name == "conv" || jt_name == "conv2d" || jt_name == "conv2dbackward")
|
||||
if (jt_name == "conv" || jt_name == "conv2d" || jt_name == "conv2dbackward" || jt_name == "maxpool" || jt_name == "maxpoolbackward")
|
||||
use_nchw = true;
|
||||
|
||||
for (int idx = 0; idx < input_num; idx++)
|
||||
|
@ -206,11 +264,29 @@ namespace jittor
|
|||
outputShapes[0] = {};
|
||||
}
|
||||
}
|
||||
for (int idx = 0; idx < output_num; idx++)
|
||||
if (jt_name == "conv2dbackward")
|
||||
{
|
||||
outputTensors.push_back(nullptr);
|
||||
auto ret = CreateAclTensor(outputShapes[idx], out_[idx]->mem_ptr, out_[idx]->size, get_dtype(out_[idx]->dtype()), &outputTensors[idx], use_nchw);
|
||||
CHECK_RET(ret == ACL_SUCCESS, return);
|
||||
for (int idx = 0; idx < 2; idx++)
|
||||
{
|
||||
outputTensors.push_back(nullptr);
|
||||
auto ret = CreateAclTensor(outputShapes[idx], out_[idx]->mem_ptr, out_[idx]->size, get_dtype(out_[idx]->dtype()), &outputTensors[idx], use_nchw);
|
||||
CHECK_RET(ret == ACL_SUCCESS, return);
|
||||
}
|
||||
// biasgrad nd format
|
||||
{
|
||||
outputTensors.push_back(nullptr);
|
||||
auto ret = CreateAclTensor(outputShapes[2], out_[2]->mem_ptr, out_[2]->size, get_dtype(out_[2]->dtype()), &outputTensors[2], false);
|
||||
CHECK_RET(ret == ACL_SUCCESS, return);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for (int idx = 0; idx < output_num; idx++)
|
||||
{
|
||||
outputTensors.push_back(nullptr);
|
||||
auto ret = CreateAclTensor(outputShapes[idx], out_[idx]->mem_ptr, out_[idx]->size, get_dtype(out_[idx]->dtype()), &outputTensors[idx], use_nchw);
|
||||
CHECK_RET(ret == ACL_SUCCESS, return);
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Call the CANN operator library's aclnnXxxGetWorkspaceSize interface, the first stage of the two-stage API
|
||||
|
@ -249,7 +325,7 @@ namespace jittor
|
|||
auto attr = dynamic_cast<RandomAttr *>(op_attr.get());
|
||||
ret = it->second.getWorkspaceSizeFuncRandom(outputTensors[0], int64_t(0), int64_t(1), attr->seed, attr->offset, &workspaceSize, &executor);
|
||||
}
|
||||
else if (name == string("Select"))
|
||||
else if (name == string("Select") || name == string("Where"))
|
||||
{
|
||||
ret = it->second.getWorkspaceSizeFuncSelect(inputTensors[0], inputTensors[1], inputTensors[2], outputTensors[0], &workspaceSize, &executor);
|
||||
}
|
||||
|
@ -262,28 +338,114 @@ namespace jittor
|
|||
{
|
||||
ret = it->second.getWorkspaceSizeFuncExpand(inputTensors[0], dim, outputTensors[0], &workspaceSize, &executor);
|
||||
}
|
||||
// else if (name == string("Conv2d"))
|
||||
// {
|
||||
// auto attr = dynamic_cast<ConvAttr *>(op_attr.get());
|
||||
// strides = aclCreateIntArray(attr->convStrides.data(), 2);
|
||||
// pads = aclCreateIntArray(attr->convPads.data(), 2);
|
||||
// outPads = aclCreateIntArray(attr->convOutPads.data(), 2);
|
||||
// dilations = aclCreateIntArray(attr->convDilations.data(), 2);
|
||||
else if (name == string("Conv2d"))
|
||||
{
|
||||
auto attr = dynamic_cast<ConvAttr *>(op_attr.get());
|
||||
strides = aclCreateIntArray(attr->convStrides.data(), 2);
|
||||
pads = aclCreateIntArray(attr->convPads.data(), 2);
|
||||
outPads = aclCreateIntArray(attr->convOutPads.data(), 2);
|
||||
dilations = aclCreateIntArray(attr->convDilations.data(), 2);
|
||||
aclTensor *bias = nullptr;
|
||||
if (input_num == 3)
|
||||
bias = inputTensors[2];
|
||||
|
||||
// ret = it->second.getWorkspaceSizeFuncConv(inputTensors[0], inputTensors[1], nullptr, strides, pads, dilations, false, outPads, attr->group, outputTensors[0], 0, &workspaceSize, &executor);
|
||||
// }
|
||||
// else if (name == string("Conv2dBackward"))
|
||||
// {
|
||||
// auto attr = dynamic_cast<ConvAttr *>(op_attr.get());
|
||||
// strides = aclCreateIntArray(attr->convStrides.data(), 2);
|
||||
// pads = aclCreateIntArray(attr->convPads.data(), 2);
|
||||
// outPads = aclCreateIntArray(attr->convOutPads.data(), 2);
|
||||
// dilations = aclCreateIntArray(attr->convDilations.data(), 2);
|
||||
// bool outputMask[3] = {true, true, false};
|
||||
// LOGir << attr->group;
|
||||
// aclBoolArray *outMask = aclCreateBoolArray(outputMask, 3);
|
||||
// ret = it->second.getWorkspaceSizeFuncConvBackward(inputTensors[0], inputTensors[1], inputTensors[2], nullptr, strides, pads, dilations, false, outPads, attr->group, outMask, 0, outputTensors[0], outputTensors[1], nullptr, &workspaceSize, &executor);
|
||||
// }
|
||||
ret = it->second.getWorkspaceSizeFuncConv(inputTensors[0], inputTensors[1], bias, strides, pads, dilations, false, outPads, attr->group, outputTensors[0], 0, &workspaceSize, &executor);
|
||||
}
|
||||
else if (name == string("Conv2dBackward"))
|
||||
{
|
||||
auto attr = dynamic_cast<ConvAttr *>(op_attr.get());
|
||||
strides = aclCreateIntArray(attr->convStrides.data(), 2);
|
||||
pads = aclCreateIntArray(attr->convPads.data(), 2);
|
||||
outPads = aclCreateIntArray(attr->convOutPads.data(), 2);
|
||||
dilations = aclCreateIntArray(attr->convDilations.data(), 2);
|
||||
bool outputMask[3] = {true, true, true};
|
||||
if (input_num == 3)
|
||||
{
|
||||
outputMask[2] = false;
|
||||
}
|
||||
aclBoolArray *outMask = aclCreateBoolArray(outputMask, 3);
|
||||
auto biasSizes = aclCreateIntArray(&outputShapes[2][0], outputShapes[2].size());
|
||||
ret = it->second.getWorkspaceSizeFuncConvBackward(inputTensors[0], inputTensors[1], inputTensors[2], biasSizes, strides, pads, dilations, false, outPads, attr->group, outMask, 0, outputTensors[0], outputTensors[1], outputTensors[2], &workspaceSize, &executor);
|
||||
}
|
||||
else if (name == string("Maxpool"))
|
||||
{
|
||||
auto attr = dynamic_cast<PoolAttr *>(op_attr.get());
|
||||
kernel_size = aclCreateIntArray(attr->kernel_size.data(), 2);
|
||||
strides = aclCreateIntArray(attr->poolStrides.data(), 2);
|
||||
pads = aclCreateIntArray(attr->poolPads.data(), 2);
|
||||
dilations = aclCreateIntArray(attr->poolDilations.data(), 2);
|
||||
ret = it->second.getWorkspaceSizeFuncMaxPool(inputTensors[0], kernel_size, strides, pads, dilations, attr->poolCeil, outputTensors[0], outputTensors[1], &workspaceSize, &executor);
|
||||
}
|
||||
else if (name == string("MaxpoolBackward"))
|
||||
{
|
||||
auto attr = dynamic_cast<PoolAttr *>(op_attr.get());
|
||||
kernel_size = aclCreateIntArray(attr->kernel_size.data(), 2);
|
||||
strides = aclCreateIntArray(attr->poolStrides.data(), 2);
|
||||
pads = aclCreateIntArray(attr->poolPads.data(), 2);
|
||||
dilations = aclCreateIntArray(attr->poolDilations.data(), 2);
|
||||
ret = it->second.getWorkspaceSizeFuncMaxPoolBackward(inputTensors[0], inputTensors[1], inputTensors[2], kernel_size, strides, pads, dilations, attr->poolCeil, outputTensors[0], &workspaceSize, &executor);
|
||||
}
|
||||
else if (name == string("Flip"))
|
||||
{
|
||||
auto attr = dynamic_cast<ReduceAttr *>(op_attr.get());
|
||||
dim = aclCreateIntArray(attr->axes.data(), attr->axes.size());
|
||||
ret = it->second.getWorkspaceSizeFuncExpand(inputTensors[0], dim, outputTensors[0], &workspaceSize, &executor);
|
||||
}
|
||||
else if (name == string("Concat"))
|
||||
{
|
||||
auto attr = dynamic_cast<ConcatAttr *>(op_attr.get());
|
||||
CHECK_RET(inputTensors.size() == attr->tensorNum, return);
|
||||
std::vector<const aclTensor *> constTensors(inputTensors.begin(), inputTensors.end());
|
||||
tensor_list = aclCreateTensorList(constTensors.data(), attr->tensorNum);
|
||||
ret = it->second.getWorkspaceSizeFuncConcat(tensor_list, attr->dim, outputTensors[0], &workspaceSize, &executor);
|
||||
}
|
||||
else if (name == string("Gather"))
|
||||
{
|
||||
auto attr = dynamic_cast<GatherAttr *>(op_attr.get());
|
||||
ret = it->second.getWorkspaceSizeFuncGather(inputTensors[0], attr->dim, inputTensors[1], outputTensors[0], &workspaceSize, &executor);
|
||||
}
|
||||
else if (name == string("Cumsum"))
|
||||
{
|
||||
auto attr = dynamic_cast<GatherAttr *>(op_attr.get());
|
||||
ret = it->second.getWorkspaceSizeFuncCumsum(inputTensors[0], attr->dim, get_dtype(out_[0]->dtype()), outputTensors[0], &workspaceSize, &executor);
|
||||
}
|
||||
else if (name == string("Scatter"))
|
||||
{
|
||||
auto attr = dynamic_cast<ScatterAttr *>(op_attr.get());
|
||||
ret = it->second.getWorkspaceSizeFuncScatter(inputTensors[0], attr->axis, inputTensors[1], inputTensors[2], attr->reduction, outputTensors[0], &workspaceSize, &executor);
|
||||
}
|
||||
else if (name == string("Floor"))
|
||||
{
|
||||
ret = it->second.getWorkspaceSizeFuncUnary(inputTensors[0], outputTensors[0], &workspaceSize, &executor);
|
||||
}
|
||||
else if (name == string("Index"))
|
||||
{
|
||||
auto indexTensorList = aclCreateTensorList(&inputTensors[1], input_num - 1);
|
||||
ret = it->second.getWorkspaceSizeFuncIndex(inputTensors[0], indexTensorList, outputTensors[0], &workspaceSize, &executor);
|
||||
}
|
||||
else if (name == string("SliceV2"))
|
||||
{
|
||||
auto attr = dynamic_cast<StrideAttr *>(op_attr.get());
|
||||
auto begins = aclCreateIntArray(attr->begins.data(), attr->begins.size());
|
||||
auto ends = aclCreateIntArray(attr->ends.data(), attr->ends.size());
|
||||
auto steps = aclCreateIntArray(attr->steps.data(), attr->steps.size());
|
||||
auto axes = aclCreateIntArray(attr->axes.data(), attr->axes.size());
|
||||
ret = it->second.getWorkspaceSizeFuncSliceV2(inputTensors[0], begins, ends, axes, steps, outputTensors[0], &workspaceSize, &executor);
|
||||
}
|
||||
else if (name == string("IndexPutImpl"))
|
||||
{
|
||||
std::vector<aclTensor *> indexTensorList = {};
|
||||
for (int i = 1; i < input_num; i++)
|
||||
{
|
||||
indexTensorList.push_back(inputTensors[i]);
|
||||
}
|
||||
auto indexTensorListInput = aclCreateTensorList(&indexTensorList[0], input_num - 1);
|
||||
ret = it->second.getWorkspaceSizeFuncIndexPutImpl(outputTensors[0], indexTensorListInput, inputTensors[0], false, true, &workspaceSize, &executor);
|
||||
}
|
||||
else if (name == string("StridedSliceAssignV2"))
|
||||
{
|
||||
ret = it->second.getWorkspaceSizeFuncStridedSliceAssignV2(outputTensors[0], inputTensors[0], inputTensors[1], inputTensors[2], inputTensors[3], inputTensors[4], &workspaceSize, &executor);
|
||||
}
|
||||
else
|
||||
LOGf << "not supported op " << jt_name;
|
||||
|
||||
|
@ -297,11 +459,9 @@ namespace jittor
|
|||
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnxxxGetWorkspaceSize failed. ERROR: %d\n", name.c_str(), ret); return);
|
||||
|
||||
// 4. Allocate device memory according to the workspaceSize computed by the first-stage interface
|
||||
void *workspaceAddr = nullptr;
|
||||
if (workspaceSize > 0)
|
||||
{
|
||||
ret = aclrtMalloc(&workspaceAddr, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
|
||||
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: allocate workspace failed. ERROR: %d\n", name.c_str(), ret); return);
|
||||
mallocWorkSpace(workspaceSize);
|
||||
}
|
||||
|
||||
// 5. Call the second-stage aclnnXxx interface
|
||||
|
@ -332,12 +492,9 @@ namespace jittor
|
|||
aclDestroyIntArray(pads);
|
||||
aclDestroyIntArray(outPads);
|
||||
aclDestroyIntArray(dilations);
|
||||
aclDestroyIntArray(kernel_size);
|
||||
aclDestroyTensorList(tensor_list);
|
||||
|
||||
// 8. Release device resources
|
||||
if (workspaceSize > 0)
|
||||
{
|
||||
aclrtFree(workspaceAddr);
|
||||
}
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
@ -546,7 +703,7 @@ namespace jittor
|
|||
LOGf << "op " << rop->ns << " not supported";
|
||||
op.jt_name = "reduce";
|
||||
op.add(rop->x, true);
|
||||
|
||||
|
||||
ReduceAttr *attr = new ReduceAttr();
|
||||
for (int i = 0; i < rop->x->shape.size(); i++)
|
||||
if (rop->reduce_mask & (1 << i))
|
||||
|
@ -642,130 +799,6 @@ namespace jittor
|
|||
runner.run();
|
||||
current_offset += out->numel();
|
||||
}},
|
||||
{"cublas_matmul", [&](Op *op)
|
||||
{
|
||||
struct MatmulOp : Op
|
||||
{
|
||||
Var *a, *b, *c;
|
||||
bool trans_a, trans_b;
|
||||
};
|
||||
auto _op = (MatmulOp *)op;
|
||||
AclOpRunner runner("MatMul");
|
||||
runner.jt_name = "matmul";
|
||||
runner.add(_op->a, true);
|
||||
runner.add(_op->b, true);
|
||||
runner.add(_op->c, false);
|
||||
runner.run();
|
||||
}},
|
||||
{"cublas_batched_matmul", [&](Op *op)
|
||||
{
|
||||
struct BatchedMatmulOp : Op
|
||||
{
|
||||
Var *a, *b, *c;
|
||||
bool adj_x1, adj_x2;
|
||||
};
|
||||
auto _op = (BatchedMatmulOp *)op;
|
||||
AclOpRunner runner("BatchMatMul");
|
||||
runner.jt_name = "bmm";
|
||||
runner.add(_op->a, true);
|
||||
runner.add(_op->b, true);
|
||||
runner.add(_op->c, false);
|
||||
runner.run();
|
||||
}},
|
||||
// {"cudnn_conv", [](Op *op)
|
||||
// {
|
||||
// struct ConvOp : Op
|
||||
// {
|
||||
// Var *x, *w, *y;
|
||||
// int strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups;
|
||||
// string xformat, wformat, yformat;
|
||||
// void run_acl()
|
||||
// {
|
||||
// AclOpRunner runner("Conv2D");
|
||||
// runner.jt_name = "conv";
|
||||
// runner.add(x, true);
|
||||
// runner.add(w, true);
|
||||
// runner.add(y, false);
|
||||
// ConvAttr *attr = new ConvAttr();
|
||||
|
||||
// attr->convStrides = {strideh, stridew, 1, 1};
|
||||
// attr->convPads = {paddingh, paddingh, paddingw, paddingw};
|
||||
// attr->convOutPads = {1, 1, 1, 1};
|
||||
// attr->convDilations = {dilationh, dilationw, 1, 1};
|
||||
// attr->group = groups;
|
||||
// runner.op_attr.reset(attr);
|
||||
|
||||
// runner.run();
|
||||
// }
|
||||
// };
|
||||
// auto _op = (ConvOp *)op;
|
||||
// _op->run_acl();
|
||||
// }},
|
||||
// {"cudnn_conv_backward_x", [](Op *op)
|
||||
// {
|
||||
// struct ConvBackwardXOp : Op
|
||||
// {
|
||||
// Var *w, *dy, *dx;
|
||||
// int xh, xw, strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups;
|
||||
// string xformat, wformat, yformat;
|
||||
// void run_acl()
|
||||
// {
|
||||
// /*
|
||||
// AclOpRunner runner("Conv2DBackpropInput");
|
||||
// runner.add_input_host_nv32(dx->shape); // 10,3,50,50
|
||||
// // runner.add_input_host_nv32(dy->shape); // 10,3,50,50
|
||||
// runner.add(w, true, ACL_FORMAT_NCHW); // 4,3,3,3
|
||||
// aclSetTensorDescName(runner.input_desc.back(), "filter");
|
||||
// runner.add(dy, true, ACL_FORMAT_NCHW); // 10,4,48,48
|
||||
// aclSetTensorDescName(runner.input_desc.back(), "out_backprop");
|
||||
// runner.add(dx, false, ACL_FORMAT_NCHW); // 10,3,50,50
|
||||
// aclSetTensorDescName(runner.input_desc.back(), "y");
|
||||
// runner.set_attr("strides", vector<int64_t>{1,1,strideh,stridew});
|
||||
// runner.set_attr("pads", vector<int64_t>{paddingh,paddingh,paddingw,paddingw});
|
||||
// runner.set_attr("dilations", vector<int64_t>{1,1,dilationh,dilationw});
|
||||
// runner.set_attr("groups", groups);
|
||||
// runner.set_attr("data_format", "NCHW");
|
||||
// // runner.set_attr("dataFormat", "NCHW");
|
||||
// // runner.set_attr("data_format", "NCHW");
|
||||
// ASSERT(xformat=="abcd" && yformat=="abcd" && wformat=="oihw");
|
||||
// runner.run();*/
|
||||
// }
|
||||
// };
|
||||
// auto _op = (ConvBackwardXOp *)op;
|
||||
// _op->run_acl();
|
||||
// }},
|
||||
// {"cudnn_conv_backward_w", [](Op *op)
|
||||
// {
|
||||
// struct ConvBackwardWOp : Op
|
||||
// {
|
||||
// Var *x, *dy, *dw;
|
||||
// int kh, kw, strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups;
|
||||
// string xformat, wformat, yformat;
|
||||
// void run_acl()
|
||||
// {
|
||||
// /*
|
||||
// AclOpRunner runner("Conv2DBackpropFilter");
|
||||
// runner.add(x, true, ACL_FORMAT_NCHW);
|
||||
// runner.add_input_host_nv32(dw->shape);
|
||||
// runner.add(dy, true, ACL_FORMAT_NCHW);
|
||||
// runner.add(dw, false, ACL_FORMAT_NCHW);
|
||||
// runner.set_attr("strides", vector<int64_t>{1, 1, strideh, stridew});
|
||||
// runner.set_attr("pads", vector<int64_t>{paddingh, paddingh, paddingw, paddingw});
|
||||
// runner.set_attr("dilations", vector<int64_t>{1, 1, dilationh, dilationw});
|
||||
// runner.set_attr("groups", groups);
|
||||
// runner.set_attr("data_format", "NCHW");
|
||||
// // runner.set_attr("dataFormat", "NCHW");
|
||||
// // runner.set_attr("data_format", "NCHW");
|
||||
// // runner.set_attr("data_origin_format", "NCHW");
|
||||
// ASSERT(xformat == "abcd" && yformat == "abcd" && wformat == "oihw");
|
||||
// runner.run();
|
||||
// */
|
||||
// }
|
||||
// };
|
||||
// auto _op = (ConvBackwardWOp *)op;
|
||||
// _op->run_acl();
|
||||
// }},
|
||||
// {"cub_arg_reduce", }
|
||||
};
|
||||
|
||||
static void exec_mapped_acl_ops(Op *op)
|
||||
|
|
|
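The two entries above map Jittor's cublas_matmul and cublas_batched_matmul ops onto the ACL MatMul and BatchMatMul runners. A minimal sketch of the user-level behaviour this is meant to support, assuming an Ascend device where use_acl is available and that jt.nn.bmm is the usual batched-matmul entry point:

import jittor as jt
import numpy as np

with jt.flag_scope(use_acl=1):
    # plain matmul goes through the MatMul mapping
    a = jt.rand(8, 16)
    b = jt.rand(16, 4)
    c = jt.matmul(a, b)
    np.testing.assert_allclose(c.numpy(), np.matmul(a.numpy(), b.numpy()),
                               atol=1e-3, rtol=1e-3)

    # batched matmul goes through the BatchMatMul mapping
    x = jt.rand(5, 8, 16)
    y = jt.rand(5, 16, 4)
    z = jt.nn.bmm(x, y)
    np.testing.assert_allclose(z.numpy(), np.matmul(x.numpy(), y.numpy()),
                               atol=1e-3, rtol=1e-3)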
@ -64,7 +64,19 @@
#include "aclnnop/aclnn_random.h"
#include "aclnnop/aclnn_normal.h"
#include "aclnnop/aclnn_permute.h"

#include "aclnnop/aclnn_max_pool2d_with_indices.h"
#include "aclnnop/aclnn_max_pool2d_with_indices_backward.h"
#include "aclnnop/aclnn_avgpool2d.h"
#include "aclnnop/aclnn_flip.h"
#include "aclnnop/aclnn_cat.h"
#include "aclnnop/aclnn_gather.h"
#include "aclnnop/aclnn_cumsum.h"
#include "aclnnop/aclnn_index.h"
#include "aclnnop/aclnn_scatter.h"
#include "aclnnop/aclnn_index.h"
#include "aclnnop/aclnn_strided_slice_assign_v2.h"
#include "aclnnop/aclnn_slice_v2.h"
#include "aclnnop/aclnn_index_put_impl.h"

#define CHECK_RET(cond, return_expr) \
do \
|
||||
|
|
|
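The headers added above pull in the aclnn kernels for pooling, flip, concat, gather, cumsum, scatter, indexing and slicing. A rough sketch of some of the Python-level ops these are meant to back; the jt.flip / jt.concat / jt.cumsum entry points and their dispatch onto the new aclnn calls under use_acl are assumptions here:

import jittor as jt
import numpy as np

with jt.flag_scope(use_acl=1):
    x = jt.array(np.arange(12, dtype="float32").reshape(3, 4))

    # aclnn_flip
    np.testing.assert_allclose(jt.flip(x, dim=1).numpy(), x.numpy()[:, ::-1])

    # aclnn_cat
    y = jt.concat([x, x], dim=0)
    np.testing.assert_allclose(y.numpy(),
                               np.concatenate([x.numpy(), x.numpy()], 0))

    # aclnn_cumsum
    np.testing.assert_allclose(jt.cumsum(x, dim=1).numpy(),
                               np.cumsum(x.numpy(), 1))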
@ -105,11 +105,11 @@ namespace jittor
|
|||
// for reduce
|
||||
std::vector<int64_t> axes;
|
||||
aclIntArray *dim = nullptr;
|
||||
bool keepdims;
|
||||
|
||||
bool use_nchw = false;
|
||||
|
||||
auto input_num = in_.size();
|
||||
|
||||
auto output_num = out_.size();
|
||||
|
||||
for (int input_idx = 0; input_idx < input_num; input_idx++)
|
||||
|
@ -141,9 +141,6 @@ namespace jittor
|
|||
// for expand
|
||||
aclIntArray *size = nullptr;
|
||||
|
||||
// for add and sub
|
||||
float alphaValue = 1.0f;
|
||||
|
||||
// for conv
|
||||
aclIntArray *strides = nullptr;
|
||||
aclIntArray *pads = nullptr;
|
||||
|
@ -151,13 +148,74 @@ namespace jittor
|
|||
aclIntArray *dilations = nullptr;
|
||||
int ret = -1;
|
||||
|
||||
// for maxpool
|
||||
aclIntArray *kernel_size = nullptr;
|
||||
|
||||
// for concat
|
||||
aclTensorList *tensor_list = nullptr;
|
||||
|
||||
if (name == string("Add") || name == string("Sub"))
|
||||
{
|
||||
alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT);
|
||||
|
||||
if (get_dtype(in_[0]->dtype()) == ACL_FLOAT)
|
||||
{
|
||||
float alphaValue = 1.0;
|
||||
alpha = aclCreateScalar(&alphaValue, get_dtype(in_[0]->dtype()));
|
||||
}
|
||||
else if (get_dtype(in_[0]->dtype()) == ACL_FLOAT16)
|
||||
{
|
||||
float alphaValue = 1.0;
|
||||
alpha = aclCreateScalar(&alphaValue, get_dtype(in_[0]->dtype()));
|
||||
}
|
||||
else if (get_dtype(in_[0]->dtype()) == ACL_INT64)
|
||||
{
|
||||
int64_t alphaValue = 1;
|
||||
alpha = aclCreateScalar(&alphaValue, get_dtype(in_[0]->dtype()));
|
||||
}
|
||||
else if (get_dtype(in_[0]->dtype()) == ACL_INT32)
|
||||
{
|
||||
int alphaValue = 1;
|
||||
alpha = aclCreateScalar(&alphaValue, get_dtype(in_[0]->dtype()));
|
||||
}
|
||||
else if (get_dtype(in_[0]->dtype()) == ACL_INT8)
|
||||
{
|
||||
int8_t alphaValue = 1;
|
||||
alpha = aclCreateScalar(&alphaValue, get_dtype(in_[0]->dtype()));
|
||||
}
|
||||
else if (get_dtype(in_[0]->dtype()) == ACL_INT16)
|
||||
{
|
||||
int16_t alphaValue = 1;
|
||||
alpha = aclCreateScalar(&alphaValue, get_dtype(in_[0]->dtype()));
|
||||
}
|
||||
else if (get_dtype(in_[0]->dtype()) == ACL_UINT8)
|
||||
{
|
||||
uint8_t alphaValue = 1;
|
||||
alpha = aclCreateScalar(&alphaValue, get_dtype(in_[0]->dtype()));
|
||||
}
|
||||
else if (get_dtype(in_[0]->dtype()) == ACL_UINT16)
|
||||
{
|
||||
uint16_t alphaValue = 1;
|
||||
alpha = aclCreateScalar(&alphaValue, get_dtype(in_[0]->dtype()));
|
||||
}
|
||||
else if (get_dtype(in_[0]->dtype()) == ACL_UINT32)
|
||||
{
|
||||
uint32_t alphaValue = 1;
|
||||
alpha = aclCreateScalar(&alphaValue, get_dtype(in_[0]->dtype()));
|
||||
}
|
||||
else if (get_dtype(in_[0]->dtype()) == ACL_BOOL)
|
||||
{
|
||||
bool alphaValue = true;
|
||||
alpha = aclCreateScalar(&alphaValue, get_dtype(in_[0]->dtype()));
|
||||
}
|
||||
else
|
||||
{
|
||||
LOGf << "Not supported dtype: " << in_[0]->dtype();
|
||||
}
|
||||
|
||||
CHECK_RET(alpha != nullptr, return);
|
||||
}
|
||||
|
||||
if (jt_name == "conv" || jt_name == "conv2d" || jt_name == "conv2dbackward")
|
||||
if (jt_name == "conv" || jt_name == "conv2d" || jt_name == "conv2dbackward" || jt_name == "maxpool" || jt_name == "maxpoolbackward")
|
||||
use_nchw = true;
|
||||
|
||||
for (int idx = 0; idx < input_num; idx++)
|
||||
|
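The chain above builds the Add/Sub alpha scalar with the same dtype as the first input, since aclnn takes alpha as a typed aclScalar. A small sketch of the behaviour this is meant to preserve, assuming element-wise add with alpha = 1 stays dtype-exact across the supported types:

import jittor as jt
import numpy as np

with jt.flag_scope(use_acl=1):
    for dtype in ["float32", "float16", "int32", "int64"]:
        a = jt.array(np.array([0, 1, 2], dtype=dtype))
        b = jt.array(np.array([3, 4, 5], dtype=dtype))
        c = a + b   # lowered to aclnn Add with a dtype-matched alpha scalar
        expect = np.array([3, 5, 7], dtype=dtype)
        np.testing.assert_allclose(c.numpy().astype("float32"),
                                   expect.astype("float32"))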
@ -167,22 +225,40 @@ namespace jittor
|
|||
CHECK_RET(ret == ACL_SUCCESS, return);
|
||||
}
|
||||
|
||||
if (jt_name == "reduce")
|
||||
if (jt_name == "reduce" || jt_name == "transpose")
|
||||
{
|
||||
auto attr = dynamic_cast<ReduceAttr *>(op_attr.get());
|
||||
dim = aclCreateIntArray(attr->axes.data(), attr->axes.size());
|
||||
|
||||
keepdims = attr->keepdims;
|
||||
if (name == string("ReduceMax") || name == string("ReduceMin") || name == string("ReduceMean") || name == string("ReduceProd"))
|
||||
{
|
||||
if (attr->axes.size() == in_[0]->shape.size())
|
||||
outputShapes[0] = {};
|
||||
}
|
||||
}
|
||||
for (int idx = 0; idx < output_num; idx++)
|
||||
if (jt_name == "conv2dbackward")
|
||||
{
|
||||
outputTensors.push_back(nullptr);
|
||||
auto ret = CreateAclTensor(outputShapes[idx], out_[idx]->mem_ptr, out_[idx]->size, get_dtype(out_[idx]->dtype()), &outputTensors[idx], use_nchw);
|
||||
CHECK_RET(ret == ACL_SUCCESS, return);
|
||||
for (int idx = 0; idx < 2; idx++)
|
||||
{
|
||||
outputTensors.push_back(nullptr);
|
||||
auto ret = CreateAclTensor(outputShapes[idx], out_[idx]->mem_ptr, out_[idx]->size, get_dtype(out_[idx]->dtype()), &outputTensors[idx], use_nchw);
|
||||
CHECK_RET(ret == ACL_SUCCESS, return);
|
||||
}
|
||||
// biasgrad nd format
|
||||
{
|
||||
outputTensors.push_back(nullptr);
|
||||
auto ret = CreateAclTensor(outputShapes[2], out_[2]->mem_ptr, out_[2]->size, get_dtype(out_[2]->dtype()), &outputTensors[2], false);
|
||||
CHECK_RET(ret == ACL_SUCCESS, return);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for (int idx = 0; idx < output_num; idx++)
|
||||
{
|
||||
outputTensors.push_back(nullptr);
|
||||
auto ret = CreateAclTensor(outputShapes[idx], out_[idx]->mem_ptr, out_[idx]->size, get_dtype(out_[idx]->dtype()), &outputTensors[idx], use_nchw);
|
||||
CHECK_RET(ret == ACL_SUCCESS, return);
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Call the aclnnXxxGetWorkspaceSize API from the CANN operator library, the first call of the two-phase interface
|
||||
|
@ -206,17 +282,22 @@ namespace jittor
|
|||
ret = it->second.getWorkspaceSizeFuncMatmul(inputTensors[0], inputTensors[1], outputTensors[0], 1, &workspaceSize, &executor);
|
||||
else if (name == string("ReduceSum") || name == string("ReduceMean"))
|
||||
{
|
||||
ret = it->second.getWorkspaceSizeFuncReduceSum(inputTensors[0], dim, false, get_dtype(out_[0]->dtype()), outputTensors[0], &workspaceSize, &executor);
|
||||
ret = it->second.getWorkspaceSizeFuncReduceSum(inputTensors[0], dim, keepdims, get_dtype(out_[0]->dtype()), outputTensors[0], &workspaceSize, &executor);
|
||||
}
|
||||
else if (name == string("ReduceMax") || name == string("ReduceMin"))
|
||||
{
|
||||
ret = it->second.getWorkspaceSizeFuncAmax(inputTensors[0], dim, false, outputTensors[0], &workspaceSize, &executor);
|
||||
ret = it->second.getWorkspaceSizeFuncAmax(inputTensors[0], dim, keepdims, outputTensors[0], &workspaceSize, &executor);
|
||||
}
|
||||
// else if (name == string("ReduceProd"))
|
||||
// {
|
||||
// ret = it->second.getWorkspaceSizeFuncReduceProd(inputTensors[0], dim, false, outputTensors[0], &workspaceSize, &executor);
|
||||
// }
|
||||
else if (name == string("Select"))
|
||||
else if (name == string("RandomUniform") || name == string("RandomNormal"))
|
||||
{
|
||||
auto attr = dynamic_cast<RandomAttr *>(op_attr.get());
|
||||
ret = it->second.getWorkspaceSizeFuncRandom(outputTensors[0], int64_t(0), int64_t(1), attr->seed, attr->offset, &workspaceSize, &executor);
|
||||
}
|
||||
else if (name == string("Select") || name == string("Where"))
|
||||
{
|
||||
ret = it->second.getWorkspaceSizeFuncSelect(inputTensors[0], inputTensors[1], inputTensors[2], outputTensors[0], &workspaceSize, &executor);
|
||||
}
|
||||
|
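The reduce branches above now forward the keepdims flag (previously hard-coded to false) into the ReduceSum/ReduceMean and ReduceMax/ReduceMin workspace-size calls. A hedged sketch of what this enables, assuming Jittor's reduce ops take a keepdims argument as usual:

import jittor as jt
import numpy as np

with jt.flag_scope(use_acl=1):
    x = jt.rand(2, 3, 4)
    s = x.sum(1, keepdims=True)   # ReduceSum, keepdims forwarded as true
    m = x.max(2)                  # ReduceMax, reduced dim dropped
    assert list(s.shape) == [2, 1, 4]
    assert list(m.shape) == [2, 3]
    np.testing.assert_allclose(s.numpy(), x.numpy().sum(1, keepdims=True),
                               atol=1e-5, rtol=1e-5)
    np.testing.assert_allclose(m.numpy(), x.numpy().max(2),
                               atol=1e-5, rtol=1e-5)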
@ -225,6 +306,10 @@ namespace jittor
|
|||
auto attr = dynamic_cast<TriuAttr *>(op_attr.get());
|
||||
ret = it->second.getWorkspaceSizeFuncCast(inputTensors[0], aclDataType(attr->diagonal), outputTensors[0], &workspaceSize, &executor);
|
||||
}
|
||||
else if (name == string("Transpose"))
|
||||
{
|
||||
ret = it->second.getWorkspaceSizeFuncExpand(inputTensors[0], dim, outputTensors[0], &workspaceSize, &executor);
|
||||
}
|
||||
else if (name == string("Conv2d"))
|
||||
{
|
||||
auto attr = dynamic_cast<ConvAttr *>(op_attr.get());
|
||||
|
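The new Transpose branch above passes the permutation as an IntArray through the same call shape used for expand/flip, presumably reaching the aclnn_permute kernel included earlier. A small sketch, assuming jt.transpose with an explicit axis tuple lowers to this path under use_acl:

import jittor as jt
import numpy as np

with jt.flag_scope(use_acl=1):
    x = jt.rand(2, 3, 4)
    y = jt.transpose(x, (2, 0, 1))   # permutation handed over as an IntArray
    np.testing.assert_allclose(y.numpy(), x.numpy().transpose(2, 0, 1),
                               atol=1e-6)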
@ -232,8 +317,11 @@ namespace jittor
|
|||
pads = aclCreateIntArray(attr->convPads.data(), 2);
|
||||
outPads = aclCreateIntArray(attr->convOutPads.data(), 2);
|
||||
dilations = aclCreateIntArray(attr->convDilations.data(), 2);
|
||||
aclTensor *bias = nullptr;
|
||||
if (input_num == 3)
|
||||
bias = inputTensors[2];
|
||||
|
||||
ret = it->second.getWorkspaceSizeFuncConv(inputTensors[0], inputTensors[1], nullptr, strides, pads, dilations, false, outPads, attr->group, outputTensors[0], 0, &workspaceSize, &executor);
|
||||
ret = it->second.getWorkspaceSizeFuncConv(inputTensors[0], inputTensors[1], bias, strides, pads, dilations, false, outPads, attr->group, outputTensors[0], 0, &workspaceSize, &executor);
|
||||
}
|
||||
else if (name == string("Conv2dBackward"))
|
||||
{
|
||||
|
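With the optional bias tensor now forwarded into the conv workspace-size call above (and a bias gradient produced by the backward branch below), a biased conv2d should round-trip against the CPU result. A hedged sketch, mirroring the shape of test_conv further down and assuming the standard jt.nn.conv2d entry point:

import jittor as jt
import numpy as np

x = jt.rand(2, 3, 16, 16)
w = jt.rand(4, 3, 3, 3)
b = jt.rand(4)
mask = jt.rand(2, 4, 14, 14)   # conv output is 14x14 for a 3x3 kernel, no padding

with jt.flag_scope(use_acl=1):
    y = jt.nn.conv2d(x, w, b)
    dx, dw, db = jt.grad((y * mask).sum(), [x, w, b])
    acl_out = [y.data, dx.data, dw.data, db.data]

with jt.flag_scope(use_acl=0):
    y = jt.nn.conv2d(x, w, b)
    dx, dw, db = jt.grad((y * mask).sum(), [x, w, b])
    cpu_out = [y.data, dx.data, dw.data, db.data]

for a, c in zip(acl_out, cpu_out):
    np.testing.assert_allclose(a, c, atol=1e-3, rtol=1e-3)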
@ -242,9 +330,93 @@ namespace jittor
|
|||
pads = aclCreateIntArray(attr->convPads.data(), 2);
|
||||
outPads = aclCreateIntArray(attr->convOutPads.data(), 2);
|
||||
dilations = aclCreateIntArray(attr->convDilations.data(), 2);
|
||||
bool outputMask[3] = {true, true, false};
|
||||
bool outputMask[3] = {true, true, true};
|
||||
if (input_num == 3)
|
||||
{
|
||||
outputMask[2] = false;
|
||||
}
|
||||
aclBoolArray *outMask = aclCreateBoolArray(outputMask, 3);
|
||||
ret = it->second.getWorkspaceSizeFuncConvBackward(inputTensors[0], inputTensors[1], inputTensors[2], nullptr, strides, pads, dilations, false, outPads, attr->group, outMask, 0, outputTensors[0], outputTensors[1], nullptr, &workspaceSize, &executor);
|
||||
auto biasSizes = aclCreateIntArray(&outputShapes[2][0], outputShapes[2].size());
|
||||
ret = it->second.getWorkspaceSizeFuncConvBackward(inputTensors[0], inputTensors[1], inputTensors[2], biasSizes, strides, pads, dilations, false, outPads, attr->group, outMask, 0, outputTensors[0], outputTensors[1], outputTensors[2], &workspaceSize, &executor);
|
||||
}
|
||||
else if (name == string("Maxpool"))
|
||||
{
|
||||
auto attr = dynamic_cast<PoolAttr *>(op_attr.get());
|
||||
kernel_size = aclCreateIntArray(attr->kernel_size.data(), 2);
|
||||
strides = aclCreateIntArray(attr->poolStrides.data(), 2);
|
||||
pads = aclCreateIntArray(attr->poolPads.data(), 2);
|
||||
dilations = aclCreateIntArray(attr->poolDilations.data(), 2);
|
||||
ret = it->second.getWorkspaceSizeFuncMaxPool(inputTensors[0], kernel_size, strides, pads, dilations, attr->poolCeil, outputTensors[0], outputTensors[1], &workspaceSize, &executor);
|
||||
}
|
||||
else if (name == string("MaxpoolBackward"))
|
||||
{
|
||||
auto attr = dynamic_cast<PoolAttr *>(op_attr.get());
|
||||
kernel_size = aclCreateIntArray(attr->kernel_size.data(), 2);
|
||||
strides = aclCreateIntArray(attr->poolStrides.data(), 2);
|
||||
pads = aclCreateIntArray(attr->poolPads.data(), 2);
|
||||
dilations = aclCreateIntArray(attr->poolDilations.data(), 2);
|
||||
ret = it->second.getWorkspaceSizeFuncMaxPoolBackward(inputTensors[0], inputTensors[1], inputTensors[2], kernel_size, strides, pads, dilations, attr->poolCeil, outputTensors[0], &workspaceSize, &executor);
|
||||
}
|
||||
else if (name == string("Flip"))
|
||||
{
|
||||
auto attr = dynamic_cast<ReduceAttr *>(op_attr.get());
|
||||
dim = aclCreateIntArray(attr->axes.data(), attr->axes.size());
|
||||
ret = it->second.getWorkspaceSizeFuncExpand(inputTensors[0], dim, outputTensors[0], &workspaceSize, &executor);
|
||||
}
|
||||
else if (name == string("Concat"))
|
||||
{
|
||||
auto attr = dynamic_cast<ConcatAttr *>(op_attr.get());
|
||||
CHECK_RET(inputTensors.size() == attr->tensorNum, return);
|
||||
std::vector<const aclTensor *> constTensors(inputTensors.begin(), inputTensors.end());
|
||||
tensor_list = aclCreateTensorList(constTensors.data(), attr->tensorNum);
|
||||
ret = it->second.getWorkspaceSizeFuncConcat(tensor_list, attr->dim, outputTensors[0], &workspaceSize, &executor);
|
||||
}
|
||||
else if (name == string("Gather"))
|
||||
{
|
||||
auto attr = dynamic_cast<GatherAttr *>(op_attr.get());
|
||||
ret = it->second.getWorkspaceSizeFuncGather(inputTensors[0], attr->dim, inputTensors[1], outputTensors[0], &workspaceSize, &executor);
|
||||
}
|
||||
else if (name == string("Cumsum"))
|
||||
{
|
||||
auto attr = dynamic_cast<GatherAttr *>(op_attr.get());
|
||||
ret = it->second.getWorkspaceSizeFuncCumsum(inputTensors[0], attr->dim, get_dtype(out_[0]->dtype()), outputTensors[0], &workspaceSize, &executor);
|
||||
}
|
||||
else if (name == string("Scatter"))
|
||||
{
|
||||
auto attr = dynamic_cast<ScatterAttr *>(op_attr.get());
|
||||
ret = it->second.getWorkspaceSizeFuncScatter(inputTensors[0], attr->axis, inputTensors[1], inputTensors[2], attr->reduction, outputTensors[0], &workspaceSize, &executor);
|
||||
}
|
||||
else if (name == string("Floor"))
|
||||
{
|
||||
ret = it->second.getWorkspaceSizeFuncUnary(inputTensors[0], outputTensors[0], &workspaceSize, &executor);
|
||||
}
|
||||
else if (name == string("Index"))
|
||||
{
|
||||
auto indexTensorList = aclCreateTensorList(&inputTensors[1], input_num - 1);
|
||||
ret = it->second.getWorkspaceSizeFuncIndex(inputTensors[0], indexTensorList, outputTensors[0], &workspaceSize, &executor);
|
||||
}
|
||||
else if (name == string("SliceV2"))
|
||||
{
|
||||
auto attr = dynamic_cast<StrideAttr *>(op_attr.get());
|
||||
auto begins = aclCreateIntArray(attr->begins.data(), attr->begins.size());
|
||||
auto ends = aclCreateIntArray(attr->ends.data(), attr->ends.size());
|
||||
auto steps = aclCreateIntArray(attr->steps.data(), attr->steps.size());
|
||||
auto axes = aclCreateIntArray(attr->axes.data(), attr->axes.size());
|
||||
ret = it->second.getWorkspaceSizeFuncSliceV2(inputTensors[0], begins, ends, axes, steps, outputTensors[0], &workspaceSize, &executor);
|
||||
}
|
||||
else if (name == string("IndexPutImpl"))
|
||||
{
|
||||
std::vector<aclTensor *> indexTensorList = {};
|
||||
for (int i = 1; i < input_num; i++)
|
||||
{
|
||||
indexTensorList.push_back(inputTensors[i]);
|
||||
}
|
||||
auto indexTensorListInput = aclCreateTensorList(&indexTensorList[0], input_num - 1);
|
||||
ret = it->second.getWorkspaceSizeFuncIndexPutImpl(outputTensors[0], indexTensorListInput, inputTensors[0], false, true, &workspaceSize, &executor);
|
||||
}
|
||||
else if (name == string("StridedSliceAssignV2"))
|
||||
{
|
||||
ret = it->second.getWorkspaceSizeFuncStridedSliceAssignV2(outputTensors[0], inputTensors[0], inputTensors[1], inputTensors[2], inputTensors[3], inputTensors[4], &workspaceSize, &executor);
|
||||
}
|
||||
else
|
||||
LOGf << "not supported op " << jt_name;
|
||||
|
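The remaining branches above wire up Gather, Scatter, advanced indexing (Index / IndexPutImpl) and sliced reads and writes (SliceV2 / StridedSliceAssignV2). A rough sketch of some of the Python-level operations they are intended to serve; the jt.gather signature and the exact lowering of getitem/setitem onto these kernels are assumptions:

import jittor as jt
import numpy as np

with jt.flag_scope(use_acl=1):
    x = jt.array(np.arange(12, dtype="float32").reshape(3, 4))

    # Gather along dim 1 (hypothetical index pattern)
    idx = jt.array(np.array([[0, 2], [1, 3], [2, 0]], dtype="int64"))
    g = jt.gather(x, 1, idx)
    np.testing.assert_allclose(g.numpy(),
                               np.take_along_axis(x.numpy(), idx.numpy(), 1))

    # Advanced indexing (Index) and slice read (SliceV2)
    rows = jt.array(np.array([0, 2], dtype="int64"))
    np.testing.assert_allclose(x[rows].numpy(), x.numpy()[[0, 2]])
    np.testing.assert_allclose(x[:, 1:3].numpy(), x.numpy()[:, 1:3])

    # Sliced write (StridedSliceAssignV2 / IndexPutImpl)
    y = x.clone()
    y[:, 1:3] = 0.0
    ref = x.numpy().copy()
    ref[:, 1:3] = 0.0
    np.testing.assert_allclose(y.numpy(), ref)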
@ -253,17 +425,15 @@ namespace jittor
|
|||
if (ret != ACL_SUCCESS)
{
auto tmp_err_msg = aclGetRecentErrMsg();
LOGir << tmp_err_msg;
LOGir << name << ", " << tmp_err_msg;
}

CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnxxxGetWorkspaceSize failed. ERROR: %d\n", name.c_str(), ret); return);

// 4. Allocate device memory according to the workspaceSize computed by the first-phase API
void *workspaceAddr = nullptr;
if (workspaceSize > 0)
{
ret = aclrtMalloc(&workspaceAddr, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: allocate workspace failed. ERROR: %d\n", name.c_str(), ret); return);
mallocWorkSpace(workspaceSize);
}

// 5. Call the second-phase aclnnXxx API

@ -294,12 +464,9 @@ namespace jittor

aclDestroyIntArray(pads);
aclDestroyIntArray(outPads);
aclDestroyIntArray(dilations);
aclDestroyIntArray(kernel_size);
aclDestroyTensorList(tensor_list);

// 8. Release device resources
if (workspaceSize > 0)
{
aclrtFree(workspaceAddr);
}
return;
}
};
|
||||
|
|
|
@ -110,13 +110,13 @@ class ResNet(nn.Module):
|
|||
jt.init.relu_invariant_gauss_(self.conv1.weight, mode="fan_out")
|
||||
self.bn1 = norm_layer(self.inplanes)
|
||||
self.relu = nn.Relu()
|
||||
# self.maxpool = nn.Pool(kernel_size=3, stride=2, padding=1, op='maximum')
|
||||
self.maxpool = nn.Pool(kernel_size=3, stride=2, padding=1, op='maximum')
|
||||
self.layer1 = self._make_layer(block, 64, layers[0])
|
||||
self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
|
||||
self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
|
||||
self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
|
||||
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
||||
# self.fc = nn.Linear((512 * block.expansion), num_classes)
|
||||
self.fc = nn.Linear((512 * block.expansion), num_classes)
|
||||
|
||||
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
|
||||
norm_layer = self._norm_layer
|
||||
|
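Re-enabling self.maxpool here relies on the Maxpool / MaxpoolBackward branches added to the ACL runner above. A small sketch of that path in isolation, using the same nn.Pool maximum op as the model and checking both forward and gradient against the CPU result:

import jittor as jt
import numpy as np

pool = jt.nn.Pool(kernel_size=3, stride=2, padding=1, op='maximum')

with jt.flag_scope(use_acl=1):
    x = jt.rand(1, 2, 8, 8)
    y = pool(x)
    gx, = jt.grad(y.sum(), [x])
    y_acl, gx_acl = y.data, gx.data

with jt.flag_scope(use_acl=0):
    y = pool(x)
    gx, = jt.grad(y.sum(), [x])
    y_cpu, gx_cpu = y.data, gx.data

np.testing.assert_allclose(y_acl, y_cpu, atol=1e-5)
np.testing.assert_allclose(gx_acl, gx_cpu, atol=1e-5)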
@ -138,14 +138,14 @@ class ResNet(nn.Module):
|
|||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
# x = self.maxpool(x)
|
||||
x = self.maxpool(x)
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x = self.layer4(x)
|
||||
x = self.avgpool(x).float_auto()
|
||||
x = jt.reshape(x, (x.shape[0], -1))
|
||||
# x = self.fc(x)
|
||||
x = self.fc(x)
|
||||
return x
|
||||
|
||||
def execute(self, x):
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2023 Jittor. All Rights Reserved.
|
||||
# Maintainers: Dun Liang <randonlang@gmail.com>.
|
||||
# Copyright (c) 2023 Jittor. All Rights Reserved.
|
||||
# Maintainers: Dun Liang <randonlang@gmail.com>.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
@ -17,23 +17,22 @@ class TestACL(unittest.TestCase):
|
|||
|
||||
@jt.flag_scope(use_acl=1)
|
||||
def test_array(self):
|
||||
print("use_acl", jt.flags.use_acl)
|
||||
a = jt.array([1,2,3])
|
||||
np.testing.assert_allclose(a.numpy(), [1,2,3])
|
||||
a = jt.array([1, 2, 3])
|
||||
np.testing.assert_allclose(a.numpy(), [1, 2, 3])
|
||||
print('test_array pass')
|
||||
|
||||
@jt.flag_scope(use_acl=1)
|
||||
def test_add(self):
|
||||
a = jt.array([1,2,3])
|
||||
b = a+a
|
||||
np.testing.assert_allclose(b.numpy(), [2,4,6])
|
||||
a = jt.array([1, 2, 3])
|
||||
b = a + a
|
||||
np.testing.assert_allclose(b.numpy(), [2, 4, 6])
|
||||
print('test_add pass')
|
||||
|
||||
@jt.flag_scope(use_acl=1)
|
||||
def test_add_float(self):
|
||||
a = jt.array([1.0,2.0,3.0])
|
||||
b = a+a
|
||||
np.testing.assert_allclose(b.numpy(), [2,4,6])
|
||||
a = jt.array([1.0, 2.0, 3.0])
|
||||
b = a + a
|
||||
np.testing.assert_allclose(b.numpy(), [2, 4, 6])
|
||||
print('test_add_float pass')
|
||||
|
||||
@jt.flag_scope(use_acl=1)
|
||||
|
@ -55,7 +54,7 @@ class TestACL(unittest.TestCase):
|
|||
@jt.flag_scope(use_acl=1)
|
||||
def test_rand(self):
|
||||
a = jt.rand(10)
|
||||
b = a*10
|
||||
b = a * 10
|
||||
b.sync()
|
||||
print(b)
|
||||
|
||||
|
@ -66,21 +65,21 @@ class TestACL(unittest.TestCase):
|
|||
@jt.flag_scope(use_acl=1)
|
||||
def test_conv(self):
|
||||
x = jt.rand(10, 3, 50, 50)
|
||||
w = jt.rand(4,3,3,3)
|
||||
w = jt.rand(4, 3, 3, 3)
|
||||
# x = jt.rand(2, 2, 1, 1)
|
||||
# w = jt.rand(2,2,1,1)
|
||||
y = jt.nn.conv2d(x, w)
|
||||
y.sync(True)
|
||||
y1 = y.data
|
||||
mask = jt.rand_like(y)
|
||||
dx, dw = jt.grad((y*mask).sum(), [x, w])
|
||||
dx, dw = jt.grad((y * mask).sum(), [x, w])
|
||||
dx1, dw1 = dx.data, dw.data
|
||||
# dw, = jt.grad((y*mask).sum(), [w])
|
||||
# dw1 = dw.data
|
||||
with jt.flag_scope(use_acl=0):
|
||||
y = jt.nn.conv2d(x, w)
|
||||
y2 = y.data
|
||||
dx, dw = jt.grad((y*mask).sum(), [x, w])
|
||||
dx, dw = jt.grad((y * mask).sum(), [x, w])
|
||||
dx2, dw2 = dx.data, dw.data
|
||||
# dw, = jt.grad((y*mask).sum(), [w])
|
||||
# dw2 = dw.data
|
||||
|
@ -93,8 +92,8 @@ class TestACL(unittest.TestCase):
|
|||
def test_matmul(self):
|
||||
# x = jt.rand(10, 3, 50, 50)
|
||||
# w = jt.rand(4,3,3,3)
|
||||
x = jt.rand(10,10)
|
||||
w = jt.rand(10,10)
|
||||
x = jt.rand(10, 10)
|
||||
w = jt.rand(10, 10)
|
||||
y = jt.matmul(x, w)
|
||||
ny = np.matmul(x.numpy(), w.numpy())
|
||||
np.testing.assert_allclose(y.numpy(), ny, atol=1e-3, rtol=1e-3)
|
||||
|
@ -102,7 +101,7 @@ class TestACL(unittest.TestCase):
|
|||
|
||||
@jt.flag_scope(use_acl=1)
|
||||
def test_max(self):
|
||||
x = jt.rand(3,3)
|
||||
x = jt.rand(3, 3)
|
||||
y = x.max(1).data
|
||||
ny = x.data.max(1)
|
||||
np.testing.assert_allclose(y, ny)
|
||||
|
@ -110,7 +109,7 @@ class TestACL(unittest.TestCase):
|
|||
|
||||
@jt.flag_scope(use_acl=1)
|
||||
def test_sum(self):
|
||||
x = jt.rand(3,3).float16()
|
||||
x = jt.rand(3, 3).float16()
|
||||
print(x)
|
||||
# return
|
||||
y = x.sum(1).data
|
||||
|
@ -124,14 +123,14 @@ class TestACL(unittest.TestCase):
|
|||
def test_broadcast(self):
|
||||
x = jt.rand(3)
|
||||
# print(x)
|
||||
y = x.broadcast([3,3]).data
|
||||
y = x.broadcast([3, 3]).data
|
||||
ny = np.broadcast_arrays(x.data, y)[0]
|
||||
np.testing.assert_allclose(y, ny)
|
||||
print(x, y)
|
||||
# y = x.broadcast([3,3], dims=[1]).data
|
||||
y = jt.broadcast(x, shape=(3,3), dims=[1]).data
|
||||
y = jt.broadcast(x, shape=(3, 3), dims=[1]).data
|
||||
with jt.flag_scope(use_acl=0):
|
||||
ny = jt.broadcast(x, shape=(3,3), dims=[1]).data
|
||||
ny = jt.broadcast(x, shape=(3, 3), dims=[1]).data
|
||||
# ny = np.broadcast_arrays(x.data, y)[0]
|
||||
np.testing.assert_allclose(y, ny)
|
||||
print(x, y)
|
||||
|
@ -141,44 +140,48 @@ class TestACL(unittest.TestCase):
|
|||
def test_resnet(self):
|
||||
from jittor.models import resnet50
|
||||
net = resnet50()
|
||||
x = jt.rand(2,3,224,224)
|
||||
x = jt.rand(2, 3, 224, 224)
|
||||
y = net(x)
|
||||
y.sync()
|
||||
|
||||
|
||||
|
||||
def matmul(a, b):
|
||||
(n, m), k = a.shape, b.shape[-1]
|
||||
a = a.broadcast([n,m,k], dims=[2])
|
||||
b = b.broadcast([n,m,k], dims=[0])
|
||||
return (a*b).sum(dim=1)
|
||||
|
||||
class Linear(Module):
|
||||
|
||||
def __init__(self, in_features, out_features, bias=True):
|
||||
self.w = (jt.random((in_features, out_features))-0.5) / in_features**0.5
|
||||
self.b = jt.random((out_features,))-0.5 if bias else None
|
||||
self.w = (jt.random(
|
||||
(in_features, out_features)) - 0.5) / in_features**0.5
|
||||
self.b = jt.random((out_features, )) - 0.5 if bias else None
|
||||
|
||||
def execute(self, x):
|
||||
x = matmul(x, self.w)
|
||||
if self.b is not None:
|
||||
return x+self.b
|
||||
x = jt.nn.matmul(x, self.w)
|
||||
if self.b is not None:
|
||||
return x + self.b
|
||||
return x
|
||||
|
||||
|
||||
def relu(x):
|
||||
return jt.maximum(x, 0.0)
|
||||
|
||||
|
||||
Relu = jt.make_module(relu)
|
||||
|
||||
|
||||
class Model(Module):
|
||||
|
||||
def __init__(self, input_size):
|
||||
self.linear1 = Linear(input_size, 10)
|
||||
self.relu1 = Relu()
|
||||
self.linear2 = Linear(10, 1)
|
||||
|
||||
def execute(self, x):
|
||||
x = self.linear1(x)
|
||||
x = self.relu1(x)
|
||||
return self.linear2(x)
|
||||
|
||||
|
||||
@unittest.skipIf(not jt.compiler.has_acl, "No ACL found")
|
||||
class TestExample(unittest.TestCase):
|
||||
|
||||
@jt.flag_scope(use_acl=1)
|
||||
def test1(self):
|
||||
np.random.seed(0)
|
||||
|
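The Linear module above now calls jt.nn.matmul directly instead of the broadcast-and-reduce helper, so broadcast and reduce no longer sit on this test's hot path. A quick sketch checking the two formulations agree; matmul_by_broadcast below mirrors the matmul helper defined earlier in this file:

import jittor as jt
import numpy as np

def matmul_by_broadcast(a, b):
    # broadcast both operands to (n, m, k) and reduce over the shared dim m
    (n, m), k = a.shape, b.shape[-1]
    a = a.broadcast([n, m, k], dims=[2])
    b = b.broadcast([n, m, k], dims=[0])
    return (a * b).sum(dim=1)

x = jt.rand(4, 5)
w = jt.rand(5, 3)
np.testing.assert_allclose(matmul_by_broadcast(x, w).numpy(),
                           jt.nn.matmul(x, w).numpy(),
                           atol=1e-5, rtol=1e-5)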
@ -190,27 +193,29 @@ class TestExample(unittest.TestCase):
|
|||
def get_data(n):
|
||||
for i in range(n):
|
||||
x = np.random.rand(batch_size, 1).astype("float32")
|
||||
y = x*x
|
||||
y = x * x
|
||||
yield jt.float32(x), jt.float32(y)
|
||||
|
||||
|
||||
model = Model(input_size=1)
|
||||
ps = model.parameters()
|
||||
|
||||
for i,(x,y) in enumerate(get_data(n)):
|
||||
for i, (x, y) in enumerate(get_data(n)):
|
||||
jt.sync_all(True)
|
||||
pred_y = model(x).name("pred_y")
|
||||
loss = ((pred_y - y).sqr()).name("loss")
|
||||
loss_mean = loss.mean()
|
||||
|
||||
|
||||
gs = jt.grad(loss_mean, ps)
|
||||
for p, g in zip(ps, gs):
|
||||
p -= g * lr
|
||||
|
||||
if i>2:
|
||||
assert prev == jt.liveness_info(), f"memory leak {prev} {jt.liveness_info()}"
|
||||
if i > 2:
|
||||
assert prev == jt.liveness_info(
|
||||
), f"memory leak {prev} {jt.liveness_info()}"
|
||||
prev = jt.liveness_info()
|
||||
print(f"step {i}, loss = {loss_mean.data.sum()} {jt.liveness_info()}")
|
||||
|
||||
print(
|
||||
f"step {i}, loss = {loss_mean.data.sum()} {jt.liveness_info()}"
|
||||
)
|
||||
breakpoint()
|
||||
possible_results = [
|
||||
0.0009948202641680837,
|
||||
0.001381353591568768,
|
||||
|
@ -221,5 +226,6 @@ class TestExample(unittest.TestCase):
|
|||
|
||||
jt.clean()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
@ -152,67 +152,6 @@ class TestACL(unittest.TestCase):
|
|||
[[[4, 4], [4, 4]], [[4, 4], [4, 4]], [[4, 4], [4, 4]]])
|
||||
print("test bmm grad success")
|
||||
|
||||
@jt.flag_scope(use_acl=1)
|
||||
def test_avgpool(self):
|
||||
a = jt.ones(1, 1, 4, 4)
|
||||
avg_pool = jt.nn.Pool(2, op='mean')
|
||||
b = avg_pool(a)
|
||||
np.testing.assert_allclose(b.numpy(), [[[[1, 1], [1, 1]]]])
|
||||
print("test avgpool success")
|
||||
|
||||
@jt.flag_scope(use_acl=1)
|
||||
def test_adaptive_maxpool2d(self):
|
||||
a = jt.float32([[[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12],
|
||||
[13, 14, 15, 16]]]])
|
||||
pool = jt.nn.AdaptiveMaxPool2d((2, 2))
|
||||
b = pool(a)
|
||||
np.testing.assert_allclose(b.numpy(), [[[[6, 8], [14, 16]]]])
|
||||
print("test adaptive_maxpool2d success")
|
||||
|
||||
@jt.flag_scope(use_acl=1)
|
||||
def test_adaptive_maxpool2d_grad(self):
|
||||
a = jt.float32([[[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12],
|
||||
[13, 14, 15, 16]]]])
|
||||
max_pool = jt.nn.AdaptiveMaxPool2d((2, 2))
|
||||
optimizer = jt.optim.SGD([a], 0.1)
|
||||
b = max_pool(a)
|
||||
loss = b.sum()
|
||||
optimizer.zero_grad()
|
||||
optimizer.backward(loss)
|
||||
optimizer.step()
|
||||
res = a.opt_grad(optimizer)
|
||||
np.testing.assert_allclose(
|
||||
res.numpy(),
|
||||
[[[[0, 0, 0, 0], [0, 1, 0, 1], [0, 0, 0, 0], [0, 1, 0, 1]]]])
|
||||
print("test adaptive_maxpool2d grad success")
|
||||
|
||||
@jt.flag_scope(use_acl=1)
|
||||
def test_adaptive_avgpool2d(self):
|
||||
a = jt.float32([[[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12],
|
||||
[13, 14, 15, 16]]]])
|
||||
pool = jt.nn.AdaptiveAvgPool2d((2, 2))
|
||||
b = pool(a)
|
||||
np.testing.assert_allclose(b.numpy(), [[[[3.5, 5.5], [11.5, 13.5]]]])
|
||||
print("test adaptive_avgpool2d success")
|
||||
|
||||
@jt.flag_scope(use_acl=1)
|
||||
def test_adaptive_avgpool2d_grad(self):
|
||||
a = jt.float32([[[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12],
|
||||
[13, 14, 15, 16]]]])
|
||||
avg_pool = jt.nn.AdaptiveAvgPool2d((2, 2))
|
||||
optimizer = jt.optim.SGD([a], 0.1)
|
||||
b = avg_pool(a)
|
||||
loss = b.sum()
|
||||
optimizer.zero_grad()
|
||||
optimizer.backward(loss)
|
||||
optimizer.step()
|
||||
res = a.opt_grad(optimizer)
|
||||
np.testing.assert_allclose(
|
||||
res.numpy(),
|
||||
[[[[0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25],
|
||||
[0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25]]]])
|
||||
print("test adaptive_avgpool2d grad success")
|
||||
|
||||
@jt.flag_scope(use_acl=1)
|
||||
def test_index(self):
|
||||
a = jt.ones(2, 3)
|
||||
|