mirror of https://github.com/Jittor/Jittor
polish code
This commit is contained in:
parent 093d562aeb
commit 4c2c9bc8e1
File diff suppressed because it is too large
@@ -154,9 +154,9 @@ namespace jittor
    std::queue<Op *> queue;
    for (Op *op : fop->ops)
        op_indeg[op] = 0;
    map<Op *, vector<Op *>> out_map;
    map<Var *, vector<Op *>> from;
    int len = 0;

@@ -303,15 +303,12 @@ namespace jittor
    op.op_attr.reset(attr);
    op.add(rop->y, false);
    op.run();
    aclrtSynchronizeStream(aclstream);
}
else if (op->name() == string("broadcast_to"))
{
    auto bop = (BroadcastToOp *)op;
    AclOpRunner op("Expand");
    if (bop->x->shape.size() == 1 && bop->x->shape[0] == 1)
    {
        aclrtSynchronizeStream(aclstream);
    }
    op.jt_name = "expand";
    NanoVector xshape, xshape_bk = bop->x->shape;
    NanoVector zshape = bop->z->shape;

@@ -333,7 +330,7 @@ namespace jittor
    op.add(bop->z, false);
    op.run();
    bop->x->shape = xshape_bk;
-   // aclrtSynchronizeStream(aclstream);
+   aclrtSynchronizeStream(aclstream);
}
else if (op->name() == string("fuse_transpose"))
{

@@ -224,26 +224,26 @@ namespace jittor
    // ret = it->second.getWorkspaceSizeFuncMatmul(inputTensors[0], inputTensors[1], outputTensors[0], 1, &workspaceSize, &executor);
    // break;
    // }
    // case 9:
    // {
    //     ret = it->second.getWorkspaceSizeFuncReduceSum(inputTensors[0], dim, keepdims, get_dtype(out_[0]->dtype()), outputTensors[0], &workspaceSize, &executor);
    //     break;
    // }
    // case 10:
    // {
    //     ret = it->second.getWorkspaceSizeFuncReduceSum(inputTensors[0], dim, keepdims, get_dtype(out_[0]->dtype()), outputTensors[0], &workspaceSize, &executor);
    //     break;
    // }
    // case 11:
    // {
    //     ret = it->second.getWorkspaceSizeFuncAmax(inputTensors[0], dim, keepdims, outputTensors[0], &workspaceSize, &executor);
    //     break;
    // }
    // case 12:
    // {
    //     ret = it->second.getWorkspaceSizeFuncAmax(inputTensors[0], dim, keepdims, outputTensors[0], &workspaceSize, &executor);
    //     break;
    // }

    // case 13:
    // {
@@ -623,7 +623,7 @@ namespace jittor
    // ret = aclrtSynchronizeStream(aclstream);
    // CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclrtSynchronizeStream failed. ERROR: %d\n", name.c_str(), ret); return);
    // }

    // 6. Release the aclTensor and aclScalar objects; adjust this to the specific API's interface definition
    // destroy tensor
    // for (int idx = 0; idx < input_num; idx++)
@@ -33,19 +33,20 @@ namespace jittor
    // Common functionality for adding input/output variables
    void add(Var *v, bool is_input);

    virtual void setupInputDesc();

    void cleanupDesc();

    virtual void setupOutputDesc();

    virtual void syncRun();

    void checkRet(aclnnStatus ret);

    // Base run method with common operator lookup logic
    void run();

protected:
    // Virtual method for specific operator execution
    virtual void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) = 0;

@@ -105,13 +105,13 @@ namespace jittor
    void BaseOpRunner::syncRun()
    {
-       if(sync_run) {
-           ret = aclrtSynchronizeStream(aclstream);
-           CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclrtSynchronizeStream failed. ERROR: %d\n", name.c_str(), ret); return);
+       if (sync_run)
+       {
+           // ret = aclrtSynchronizeStream(aclstream);
+           // CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclrtSynchronizeStream failed. ERROR: %d\n", name.c_str(), ret); return);
        }
    }

    void BaseOpRunner::checkRet(aclnnStatus ret)
    {
        if (ret != ACL_SUCCESS)

@@ -116,7 +116,7 @@ namespace jittor
    ret = it->second.executeFunc(workspaceAddr, workspaceSize, executor, aclstream);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnxxx failed. ERROR: %d\n", name.c_str(), ret); return);

-   //syncRun();
+   syncRun();

    aclDestroyScalar(alpha);
    return;
@@ -12,15 +12,14 @@ from typing import Union
from collections.abc import Sequence, Iterable


def acl_cmd(name: str,
            inputs: list,
            output_dtypes: list = None,
            output_shapes: list = None,
            attr_code: str = "",
            attr_header: str = "",
            outputs: list = None,
            extra_data: dict = {}):
    attr_header = "\nnamespace jittor{" + attr_header + "}\n"

    cuda_header = '''

@@ -38,7 +37,7 @@ def acl_cmd(name: str,
    input_code = ''
    for i in range(len(inputs)):
        input_code += f"op.add(in{i}, true);\n"

    return jt.code(outputs=outputs_,
                   inputs=inputs,
                   cuda_header=attr_header + cuda_header,

@@ -49,8 +48,9 @@ def acl_cmd(name: str,
    op.add(out0, false);
    {attr_code}
    op.run();""",
                   data=extra_data)


class BmmACL(jt.Function):

    def __init__(self, trans_x2=False):

@@ -59,16 +59,14 @@ class BmmACL(jt.Function):

    def execute(self, x1, x2):
        self.input = [x1, x2]
        result = acl_cmd("BatchMatMul", [x1, x2],
                         output_dtypes=[x1.dtype],
                         output_shapes=[
                             x1.shape[:-1] + x2.shape[-2:-1] if self.trans_x2
                             else x1.shape[:-1] + x2.shape[-1:]
                         ],
                         attr_code="op.jt_name=\"bmm_trans_1\";"
                         if self.trans_x2 else "op.jt_name=\"bmm\";")[0]

        return result
@@ -78,57 +76,53 @@ class BmmACL(jt.Function):
            reshape_grad_x2 = True
        else:
            reshape_grad_x2 = False
        grad_x1 = acl_cmd(
            "BatchMatMul", [grad_output, x2],
            output_dtypes=[x1.dtype],
            output_shapes=[
                grad_output.shape[:-1] + x2.shape[-2:-1] if not self.trans_x2
                else grad_output.shape[:-1] + x1.shape[-1:]
            ],
            attr_code="op.jt_name=\"bmm_trans_1\";"
            if not self.trans_x2 else "op.jt_name=\"bmm\";")[0]
        if self.trans_x2:
            if reshape_grad_x2:
                output_shape = grad_output.shape[1:-2] + grad_output.shape[
                    -1:] + x1.shape[-1:]
                grad_x2 = acl_cmd("BatchMatMul", [
                    grad_output.reshape(-1, grad_output.shape[-1]),
                    x1.reshape(-1, x1.shape[-1])
                ],
                                  output_dtypes=[x2.dtype],
                                  output_shapes=[output_shape],
                                  attr_code="op.jt_name=\"bmm_trans_0\";")[0]
            else:
                output_shape = grad_output.shape[:-2] + grad_output.shape[
                    -1:] + x1.shape[-1:]
                grad_x2 = acl_cmd("BatchMatMul", [grad_output, x1],
                                  output_dtypes=[x2.dtype],
                                  output_shapes=[output_shape],
                                  attr_code="op.jt_name=\"bmm_trans_0\";")[0]
        else:
            if reshape_grad_x2:
                output_shape = x1.shape[1:-2] + x1.shape[
                    -1:] + grad_output.shape[-1:]
                grad_x2 = acl_cmd("BatchMatMul", [
                    x1.reshape(-1, x1.shape[-1]),
                    grad_output.reshape(-1, grad_output.shape[-1])
                ],
                                  output_dtypes=[x2.dtype],
                                  output_shapes=[output_shape],
                                  attr_code="op.jt_name=\"bmm_trans_0\";")[0]
            else:
                output_shape = x1.shape[:-2] + x1.shape[
                    -1:] + grad_output.shape[-1:]
                grad_x2 = acl_cmd("BatchMatMul", [x1, grad_output],
                                  output_dtypes=[x2.dtype],
                                  output_shapes=[output_shape],
                                  attr_code="op.jt_name=\"bmm_trans_0\";")[0]
        if len(grad_x1.shape) > len(x1.shape):
            grad_x1 = grad_x1.sum(0)
        if len(grad_x2.shape) > len(x2.shape):
            grad_x2 = grad_x2.sum(0)
        return grad_x1, grad_x2
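A quick, illustrative way to exercise the BmmACL path defined above (shapes and values are made up for the example, and it assumes an Ascend build of jittor where these ACL runners are registered):

    import jittor as jt

    x1 = jt.randn(4, 8, 16)    # (batch, n, k)
    x2 = jt.randn(4, 16, 32)   # (batch, k, m)
    y = BmmACL()(x1, x2)       # output_shapes above give (4, 8, 32)

    # With trans_x2=True the second operand is treated as (batch, m, k),
    # so the output shape comes from x1.shape[:-1] + x2.shape[-2:-1].
    y_t = BmmACL(trans_x2=True)(x1, jt.randn(4, 32, 16))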
@@ -63,8 +63,8 @@ namespace jittor
    }
    void BatchMatMulOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
    {
        ret = aclnnBatchMatMulGetWorkspaceSize(inputTensors[0], inputTensors[1], outputTensors[0], 1, &workspaceSize, &executor);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnBatchMatmulGetWorkspaceSize failed. ERROR: %d\n", name.c_str(), ret); return);
        if (workspaceSize > 0)
        {

@@ -72,6 +72,6 @@ namespace jittor
        }
        ret = aclnnBatchMatMul(workspaceAddr, workspaceSize, executor, aclstream);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnbatchMatmul failed. ERROR: %d\n", name.c_str(), ret); return);
-       // syncRun();
+       syncRun();
    }
}

@@ -10,6 +10,7 @@ namespace jittor
    protected:
        void setupInputDesc() override;
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        BatchMatMulOpRunner();
    };

@@ -11,13 +11,14 @@ import numpy as np
from typing import Union
from collections.abc import Sequence, Iterable


def concat_cmd(name: str,
               inputs: list,
               output_dtypes: list = None,
               output_shapes: list = None,
               attr_code: str = "",
               attr_header: str = "",
               outputs: list = None):
    attr_header = "\nnamespace jittor{" + attr_header + "}\n"

    cuda_header = '''

@@ -51,6 +52,7 @@ def concat_cmd(name: str,
    {attr_code}
    op.run();""")


class ConcatACL(jt.Function):

    def __init__(self):

@@ -114,13 +116,11 @@ class ConcatACL(jt.Function):
        self.dim = dim
        for i in range(len(input_tensors)):
            if input_tensors[i].dtype != input_tensors[0].dtype:
                raise ValueError("All input tensors must have the same dtype")
            if input_tensors[i].shape[:dim] != input_tensors[0].shape[:dim] or input_tensors[i].shape[dim + 1:] != input_tensors[0].shape[dim + 1:]:
                raise ValueError("All input tensors must have the same shape")
        attr_code = f"""
        op.jt_name = "concat";
        ConcatAttr *attr = new ConcatAttr();

@@ -133,15 +133,13 @@ class ConcatACL(jt.Function):
            input_tensors,
            output_dtypes=[input_tensors[0].dtype],
            output_shapes=[
                jt.empty(self.calculate_output_shape(input_tensors, dim)).shape
            ],
            attr_code=attr_code)[0]
        return result

    def _grad(self, *args):
        new_args = ((args[i] if i >= 0 else None) for i in self.output_mask)
        ret = self.grad(*new_args)
        new_ret = []
        for i, r in enumerate(ret):

@@ -185,4 +183,4 @@ class ConcatACL(jt.Function):
            output_dtypes=dtypeVec,
            output_shapes=shapeVec,
            attr_code=attr_code)
        return result
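ConcatACL.execute above defers the shape computation to a calculate_output_shape helper that is not part of this diff; a rough, hypothetical sketch of the rule it has to implement (all dims except the concat dim must match, and the concat dim is the sum of sizes):

    def calculate_output_shape_sketch(input_tensors, dim):
        # hypothetical helper, for illustration only
        out = list(input_tensors[0].shape)
        out[dim] = sum(t.shape[dim] for t in input_tensors)
        return out

    # e.g. concatenating (2, 3, 4) and (2, 5, 4) along dim=1 gives (2, 8, 4)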
@@ -9,6 +9,7 @@ namespace jittor

    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        ConcatOpRunner();
    };

@@ -18,6 +19,7 @@ namespace jittor

    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        SplitWithSizeOpRunner();
    };

@@ -22,16 +22,18 @@ def _ntuple(n):

    return parse


_pair = _ntuple(2)


def conv_forward(name: str,
                 inputs: list,
                 output_dtypes: list = None,
                 output_shapes: list = None,
                 attr_code: str = "",
                 attr_header: str = "",
                 outputs: list = None,
                 extra_data: dict = {}):
    attr_header = "\nnamespace jittor{" + attr_header + "}\n"

    cuda_header = '''

@@ -49,7 +51,7 @@ def conv_forward(name: str,
    input_code = ''
    for i in range(len(inputs)):
        input_code += f"op.add(in{i}, true);\n"

    return jt.code(outputs=outputs_,
                   inputs=inputs,
                   cuda_header=attr_header + cuda_header,

@@ -60,16 +62,17 @@ def conv_forward(name: str,
    op.add(out0, false);
    {attr_code}
    op.run();""",
                   data=extra_data)


def conv_forward(name: str,
                 inputs: list,
                 output_dtypes: list = None,
                 output_shapes: list = None,
                 attr_code: str = "",
                 attr_header: str = "",
                 outputs: list = None,
                 extra_data: dict = {}):
    # TODO: not done for now
    attr_header = "\nnamespace jittor{" + attr_header + "}\n"

@@ -88,7 +91,7 @@ def conv_forward(name: str,
    input_code = ''
    for i in range(len(inputs)):
        input_code += f"op.add(in{i}, true);\n"

    return jt.code(outputs=outputs_,
                   inputs=inputs,
                   cuda_header=attr_header + cuda_header,

@@ -99,32 +102,33 @@ def conv_forward(name: str,
    op.add(out0, false);
    {attr_code}
    op.run();""",
                   data=extra_data)


class ConvACL(jt.Function):

    def execute(self,
                x,
                weight,
                bias=None,
                stride=1,
                padding=0,
                dilation=1,
                groups=1):
        self.input = x
        self.weight = weight
        self.bias = bias
        padding = _pair(padding)
        stride = _pair(stride)
        dilation = _pair(dilation)
        out_channels = weight.shape[0]
        if groups <= 0:
            raise ValueError("groups must be a positive integer")
        self.padding = padding
        self.stride = stride
        self.dilation = dilation
        self.groups = groups
        attr_code = f"""
        op.jt_name = "conv2d";
        ConvAttr *attr = new ConvAttr();
        attr->convStrides = {{ {stride[0]}, {stride[1]} }};
@@ -134,49 +138,48 @@ class ConvACL(jt.Function):
        attr->convOutPads = {{1,1}};
        op.op_attr.reset(attr);
        """
        input_height, input_width = x.shape[-2:]
        kernel_height, kernel_width = weight.shape[-2:]

        output_height = (input_height + 2 * padding[0] - dilation[0] *
                         (kernel_height - 1) - 1) // stride[0] + 1
        output_width = (input_width + 2 * padding[1] - dilation[1] *
                        (kernel_width - 1) - 1) // stride[1] + 1

        output_shape = (x.shape[0], out_channels, output_height, output_width)

        inputs = [x, weight]
        if bias is not None:
            inputs.append(bias)
        result = conv_forward(
            "Conv2d",
            inputs,
            output_dtypes=[x.dtype],
            output_shapes=[output_shape],
            attr_code=attr_code,
        )[0]
        return result
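The output_height/output_width arithmetic above is the standard convolution output-size formula; a small standalone check with illustrative numbers:

    def conv_out_size(in_size, pad, dilation, kernel, stride):
        # mirrors the expression used in ConvACL.execute
        return (in_size + 2 * pad - dilation * (kernel - 1) - 1) // stride + 1

    # 224x224 input, 3x3 kernel, stride 2, padding 1, dilation 1 -> 112
    assert conv_out_size(224, 1, 1, 3, 2) == 112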
    def grad(self, grad_output):
        x = self.input
        weight = self.weight
        bias = self.bias
        inputs = [grad_output, x, weight]
        if bias is not None:
            inputs.append(bias)
        output_shapes = [x.shape, weight.shape]
        output_dtypes = [x.dtype, weight.dtype]
        if bias is not None:
            output_shapes.append(bias.shape)
            output_dtypes.append(bias.dtype)
        else:
            output_shapes.append([1])
            output_dtypes.append(x.dtype)
        padding = self.padding
        stride = self.stride
        dilation = self.dilation
        groups = self.groups
        attr_code = f"""
        op.jt_name = "conv2dbackward";
        ConvAttr *attr = new ConvAttr();
        attr->convStrides = {{ {stride[0]}, {stride[1]} }};

@@ -186,12 +189,12 @@ class ConvACL(jt.Function):
        attr->convOutPads = {{ 1,1}};
        op.op_attr.reset(attr);
        """
        results = acl_cmd_forward("Conv2dBackward",
                                  inputs,
                                  output_dtypes=output_dtypes,
                                  output_shapes=output_shapes,
                                  attr_code=attr_code)
        if self.bias is None:
            return results[0], results[1]

        return results

@@ -34,7 +34,7 @@ namespace jittor
    {
        use_nchw = true;
    }

    void ConvOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
    {
        // for conv

@@ -66,7 +66,7 @@ namespace jittor
        ret = aclnnConvolution(workspaceAddr, workspaceSize, executor, aclstream);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnxxx failed. ERROR: %d\n", name.c_str(), ret); return);

-       // syncRun();
+       syncRun();

        aclDestroyIntArray(strides);
        aclDestroyIntArray(pads);

@@ -9,6 +9,7 @@ namespace jittor

    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        ConvOpRunner();
    };

@@ -11,13 +11,14 @@ import numpy as np
from typing import Union
from collections.abc import Sequence, Iterable


def cumsum_cmd(name: str,
               inputs: list,
               output_dtypes: list = None,
               output_shapes: list = None,
               attr_code: str = "",
               attr_header: str = "",
               outputs: list = None):
    attr_header = "\nnamespace jittor{" + attr_header + "}\n"

    cuda_header = '''

@@ -51,6 +52,7 @@ def cumsum_cmd(name: str,
    {attr_code}
    op.run();""")


class CumsumACL(jt.Function):

    def __init__(self):

@@ -85,15 +87,15 @@ class CumsumACL(jt.Function):
        op.op_attr.reset(attr);
        """
        flipped_grad_output = cumsum_cmd("Flip", [grad_output],
                                         output_dtypes=[grad_output.dtype],
                                         output_shapes=[grad_output.shape],
                                         attr_code=flip_attr_code)[0]
        cumulative_grad = cumsum_cmd("Cumsum", [flipped_grad_output],
                                     output_dtypes=[grad_output.dtype],
                                     output_shapes=[grad_output.shape],
                                     attr_code=cumsum_attr_code)[0]
        grad_input = cumsum_cmd("Flip", [cumulative_grad],
                                output_dtypes=[grad_output.dtype],
                                output_shapes=[grad_output.shape],
                                attr_code=flip_attr_code)[0]
        return grad_input
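The grad path above implements the gradient of an inclusive cumulative sum via the flip, cumsum, flip identity; a quick NumPy sanity check of that identity (illustrative only):

    import numpy as np

    g = np.array([1.0, 2.0, 3.0, 4.0])
    # d(sum_j g_j * cumsum(x)_j)/dx_i = sum over j >= i of g_j, i.e. a reversed cumulative sum
    reference = np.array([g[i:].sum() for i in range(len(g))])
    assert np.allclose(reference, np.flip(np.cumsum(np.flip(g))))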
@@ -9,6 +9,7 @@ namespace jittor

    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        CumsumOpRunner();
    };

@@ -11,13 +11,14 @@ import numpy as np
from typing import Union
from collections.abc import Sequence, Iterable


def dropout_cmd(name: str,
                inputs: list,
                output_dtypes: list = None,
                output_shapes: list = None,
                attr_code: str = "",
                attr_header: str = "",
                outputs: list = None):
    attr_header = "\nnamespace jittor{" + attr_header + "}\n"

    cuda_header = '''

@@ -51,6 +52,7 @@ def dropout_cmd(name: str,
    {attr_code}
    op.run();""")


class DropoutACL(jt.Function):

    def __init__(self):

@@ -71,9 +73,9 @@ class DropoutACL(jt.Function):
        op.op_attr.reset(attr);
        """
        result = dropout_cmd("Dropout", [x],
                             output_dtypes=[x.dtype, "uint8"],
                             output_shapes=[x.shape, mask_shape],
                             attr_code=attr_code)
        self.maskout = result[1]
        return result[0]

@@ -85,8 +87,8 @@ class DropoutACL(jt.Function):
        op.op_attr.reset(attr);
        """
        grad_input = dropout_cmd("DropoutBackward",
                                 [grad_output, self.maskout],
                                 output_dtypes=[grad_output.dtype],
                                 output_shapes=[grad_output.shape],
                                 attr_code=attr_code)[0]
        return grad_input
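DropoutBackward above reuses the uint8 mask produced by the forward pass; the semantics it relies on are the usual inverted-dropout ones (a sketch, not the ACL kernel itself):

    import numpy as np

    def dropout_reference(x, p, rng):
        keep = (rng.random(x.shape) >= p).astype(x.dtype)   # saved as the mask
        return x * keep / (1.0 - p), keep

    def dropout_backward_reference(grad_out, keep, p):
        # backward multiplies the incoming gradient by the same scaled mask
        return grad_out * keep / (1.0 - p)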
|
|||
|
||||
protected:
|
||||
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
|
||||
|
||||
public:
|
||||
DropoutOpRunner();
|
||||
};
|
||||
|
@ -18,6 +19,7 @@ namespace jittor
|
|||
|
||||
protected:
|
||||
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
|
||||
|
||||
public:
|
||||
DropoutBackwardOpRunner();
|
||||
};
|
||||
|
|
|
@ -11,13 +11,14 @@ import numpy as np
|
|||
from typing import Union
|
||||
from collections.abc import Sequence, Iterable
|
||||
|
||||
|
||||
def flashattention_cmd(name: str,
|
||||
inputs: list,
|
||||
output_dtypes: list = None,
|
||||
output_shapes: list = None,
|
||||
attr_code: str = "",
|
||||
attr_header: str = "",
|
||||
outputs: list = None):
|
||||
inputs: list,
|
||||
output_dtypes: list = None,
|
||||
output_shapes: list = None,
|
||||
attr_code: str = "",
|
||||
attr_header: str = "",
|
||||
outputs: list = None):
|
||||
attr_header = "\nnamespace jittor{" + attr_header + "}\n"
|
||||
|
||||
cuda_header = '''
|
||||
|
@ -51,21 +52,22 @@ def flashattention_cmd(name: str,
|
|||
{attr_code}
|
||||
op.run();""")
|
||||
|
||||
|
||||
class FlashAttentionACL(jt.Function):
|
||||
|
||||
def __init__(self,
|
||||
headnum,
|
||||
layout="BNSD",
|
||||
prefix=None,
|
||||
qstart=None,
|
||||
kvstart=None,
|
||||
scale=1.0,
|
||||
prob=1.0,
|
||||
pretokens=2147483647,
|
||||
nexttokens=2147483647,
|
||||
innerprecise=0,
|
||||
sparsemode=0,
|
||||
psetype=1):
|
||||
headnum,
|
||||
layout="BNSD",
|
||||
prefix=None,
|
||||
qstart=None,
|
||||
kvstart=None,
|
||||
scale=1.0,
|
||||
prob=1.0,
|
||||
pretokens=2147483647,
|
||||
nexttokens=2147483647,
|
||||
innerprecise=0,
|
||||
sparsemode=0,
|
||||
psetype=1):
|
||||
self.headnum = headnum
|
||||
self.layout = layout
|
||||
self.scale = scale
|
||||
|
@ -116,9 +118,7 @@ class FlashAttentionACL(jt.Function):
|
|||
|
||||
self.prefix = self.prefix if self.prefix else [0 for _ in range(B)]
|
||||
self.qstart = self.qstart if self.qstart else [0 for _ in range(B)]
|
||||
self.kvstart = self.kvstart if self.kvstart else [
|
||||
0 for _ in range(B)
|
||||
]
|
||||
self.kvstart = self.kvstart if self.kvstart else [0 for _ in range(B)]
|
||||
|
||||
self.hasRealshift = (not realshift == None)
|
||||
self.hasDropmask = (not dropMask == None)
|
||||
|
@@ -126,8 +126,7 @@ class FlashAttentionACL(jt.Function):
        self.hasAttenmask = (not attenMask == None)

        # To be decided; currently passed as nullptr
        self.realshift = realshift if realshift else jt.zeros(B, N, SQ, SKV)
        self.dropMask = dropMask if dropMask else jt.ones(B, N, SQ, SKV)
        self.paddingMask = paddingMask if paddingMask else jt.zeros(
            B, N, SQ, SKV)
@@ -207,4 +206,4 @@ class FlashAttentionACL(jt.Function):
            output_dtypes=[self.q.dtype, self.k.dtype, self.v.dtype],
            output_shapes=[self.q.shape, self.k.shape, self.v.shape],
            attr_code=attr_code)
        return result

@@ -58,7 +58,6 @@ namespace jittor
        return;
    }

    FlashAttentionBackwardOpRunner::FlashAttentionBackwardOpRunner() : BaseOpRunner("FlashAttentionBackward")
    {
    }

@@ -9,6 +9,7 @@ namespace jittor

    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        FlashAttentionOpRunner();
    };

@@ -18,6 +19,7 @@ namespace jittor

    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        FlashAttentionBackwardOpRunner();
    };

@@ -11,13 +11,14 @@ import numpy as np
from typing import Union
from collections.abc import Sequence, Iterable


def flip_cmd(name: str,
             inputs: list,
             output_dtypes: list = None,
             output_shapes: list = None,
             attr_code: str = "",
             attr_header: str = "",
             outputs: list = None):
    attr_header = "\nnamespace jittor{" + attr_header + "}\n"

    cuda_header = '''

@@ -51,6 +52,7 @@ def flip_cmd(name: str,
    {attr_code}
    op.run();""")


class FlipACL(jt.Function):

    def __init__(self):

@@ -70,14 +72,14 @@ class FlipACL(jt.Function):
        """
        self.attr_code = attr_code
        result = flip_cmd("Flip", [input],
                          output_dtypes=[input.dtype],
                          output_shapes=[input.shape],
                          attr_code=self.attr_code)[0]
        return result

    def grad(self, grad_output):
        grad_input = flip_cmd("Flip", [grad_output],
                              output_dtypes=[grad_output.dtype],
                              output_shapes=[grad_output.shape],
                              attr_code=self.attr_code)[0]
        return grad_input

@@ -9,6 +9,7 @@ namespace jittor

    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        FlipOpRunner();
    };

@@ -11,13 +11,14 @@ import numpy as np
from typing import Union
from collections.abc import Sequence, Iterable


def floor_cmd(name: str,
              inputs: list,
              output_dtypes: list = None,
              output_shapes: list = None,
              attr_code: str = "",
              attr_header: str = "",
              outputs: list = None):
    attr_header = "\nnamespace jittor{" + attr_header + "}\n"

    cuda_header = '''

@@ -51,6 +52,7 @@ def floor_cmd(name: str,
    {attr_code}
    op.run();""")


class FloorIntACL(jt.Function):

    def __init__(self):

@@ -59,10 +61,10 @@ class FloorIntACL(jt.Function):
    def execute(self, input):
        self.shape = input.shape
        result = floor_cmd("Floor", [input],
                           output_dtypes=[input.dtype],
                           output_shapes=[input.shape],
                           attr_code="op.jt_name=\"floor\";")[0]
        return result

    def grad(self, grad_output):
        return jt.zeros(self.shape, dtype=grad_output.dtype)

@@ -9,6 +9,7 @@ namespace jittor

    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        FloorOpRunner();
    };

@@ -11,13 +11,14 @@ import numpy as np
from typing import Union
from collections.abc import Sequence, Iterable


def gather_scatter_cmd(name: str,
                       inputs: list,
                       output_dtypes: list = None,
                       output_shapes: list = None,
                       attr_code: str = "",
                       attr_header: str = "",
                       outputs: list = None):
    attr_header = "\nnamespace jittor{" + attr_header + "}\n"

    cuda_header = '''

@@ -51,6 +52,7 @@ def gather_scatter_cmd(name: str,
    {attr_code}
    op.run();""")


class GatherACL(jt.Function):

    def __init__(self):

@@ -66,9 +68,9 @@ class GatherACL(jt.Function):
        op.op_attr.reset(attr);
        """
        result = gather_scatter_cmd("Gather", [input, index],
                                    output_dtypes=[input.dtype],
                                    output_shapes=[index.shape],
                                    attr_code=attr_code)[0]
        return result

    def grad(self, grad_output):

@@ -80,12 +82,14 @@ class GatherACL(jt.Function):
        attr->reduction = {1};
        op.op_attr.reset(attr);
        """
        grad_input = gather_scatter_cmd("Scatter",
                                        [tmp, self.index, grad_output],
                                        output_dtypes=[grad_output.dtype],
                                        output_shapes=[tmp.shape],
                                        attr_code=attr_code)[0]
        return grad_input
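GatherACL.grad above expresses the gradient of a gather as a scatter-add into a zero tensor shaped like the input; a compact NumPy illustration of that duality (dim 0 case, illustrative only):

    import numpy as np

    x = np.arange(5.0)
    index = np.array([1, 3, 3])
    gathered = x[index]                    # forward gather along dim 0

    grad_out = np.array([10.0, 20.0, 30.0])
    grad_x = np.zeros_like(x)
    np.add.at(grad_x, index, grad_out)     # scatter-add; repeated indices accumulate
    # grad_x == [0, 10, 0, 50, 0]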


class ScatterACL(jt.Function):

    def __init__(self):

@@ -103,9 +107,9 @@ class ScatterACL(jt.Function):
        op.op_attr.reset(attr);
        """
        result = gather_scatter_cmd("Scatter", [input, self.index, src],
                                    output_dtypes=[input.dtype],
                                    output_shapes=[input.shape],
                                    attr_code=attr_code)[0]
        return result

    def grad(self, grad_output):

@@ -116,7 +120,7 @@ class ScatterACL(jt.Function):
        op.op_attr.reset(attr);
        """
        grad_input = gather_scatter_cmd("Gather", [grad_output, self.index],
                                        output_dtypes=[grad_output.dtype],
                                        output_shapes=[self.index.shape],
                                        attr_code=attr_code)[0]
        return grad_output, None, None, grad_input

@@ -77,5 +77,4 @@ namespace jittor
        return;
    }

}

@@ -9,16 +9,17 @@ namespace jittor

    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        GatherOpRunner();
    };

    class ScatterOpRunner : public BaseOpRunner
    {

    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        ScatterOpRunner();
    };

@@ -11,13 +11,14 @@ import numpy as np
from typing import Union
from collections.abc import Sequence, Iterable


def getitem_cmd(name: str,
                inputs: list,
                output_dtypes: list = None,
                output_shapes: list = None,
                attr_code: str = "",
                attr_header: str = "",
                outputs: list = None):
    attr_header = "\nnamespace jittor{" + attr_header + "}\n"

    cuda_header = '''

@@ -51,6 +52,7 @@ def getitem_cmd(name: str,
    {attr_code}
    op.run();""")


def getitem_forward(name: str,
                    inputs: list,
                    output_dtypes: list = None,

@@ -90,7 +92,8 @@ def getitem_forward(name: str,
    op.add(out0, false);
    {attr_code}
    op.run();""",
                   data=extra_data)


def caculate_shape(tensors):
    if isinstance(tensors, jt.Var):
@@ -105,6 +108,7 @@ def caculate_shape(tensors):
    else:
        assert False, f"not implemented for {type(tensors)}"


def can_broadcast_and_shape(shape1, shape2):
    """
    Check whether the two tensors can be broadcast together, and return the broadcast shape.
@@ -144,6 +148,7 @@ def can_broadcast_and_shape(shape1, shape2):

    return True, tuple(broadcast_shape)
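The helper above follows the usual NumPy broadcasting rule (align trailing dims; each pair must be equal or 1). Illustrative calls, assuming the (ok, shape) return seen above:

    ok, shape = can_broadcast_and_shape((4, 1, 3), (2, 3))
    # ok is True, shape is (4, 2, 3): 3/3 match and the 1 broadcasts against 2

    ok, _ = can_broadcast_and_shape((4, 2), (3,))
    # ok is False: trailing dims 2 and 3 are neither equal nor 1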


class GetItemACL(jt.Function):

    def __init__(self):
@@ -174,9 +179,9 @@ class GetItemACL(jt.Function):
        op.jt_name = "maskedselect";
        """
        result = getitem_cmd("MaskedSelect",
                             inputs=inputs,
                             outputs=outputs,
                             attr_code=attr_code)[0]
        result = result[:output_len]
        result.sync()
        return result

@@ -194,7 +199,7 @@ class GetItemACL(jt.Function):
        contains_slice = False
        for s in slices:
            if not isinstance(s, jt.Var) and (isinstance(s, slice)
                                              or s == Ellipsis):
                contains_slice = True
                break
        if not contains_slice:

@@ -212,9 +217,9 @@ class GetItemACL(jt.Function):
            output_shape = [1]
            for ii in slices:
                indices.append(jt.Var(ii).int32())
            if isinstance(slices[0],
                          jt.Var) or isinstance(slices[0], int) or isinstance(
                              slices[0], list) or isinstance(slices[0], tuple):
                self.indices = indices
                inputs = [x] + indices
                attr_code = f"""

@@ -222,10 +227,10 @@ class GetItemACL(jt.Function):
                """
                self.type_ = 'index'
                result = getitem_cmd("Index",
                                     inputs=inputs,
                                     output_dtypes=[x.dtype],
                                     output_shapes=[output_shape],
                                     attr_code=attr_code)[0]
                result.sync()
                return result
        assert contains_slice, "slice type error"

@@ -235,8 +240,7 @@ class GetItemACL(jt.Function):
            if not isinstance(s, jt.Var) and s == Ellipsis:
                slices = slices[:slices.index(s)] + [
                    slice(None, None, None)
                ] * (x_dim - len(slices) + 1) + slices[slices.index(s) + 1:]
                break
        slices = tuple(slices)

@@ -313,11 +317,11 @@ class GetItemACL(jt.Function):
        op.op_attr.reset(attr);
        """
        result = getitem_forward("SliceV2",
                                 inputs,
                                 output_dtypes=[x.dtype],
                                 output_shapes=[jt.empty(sizes).shape],
                                 attr_code=attr_code,
                                 extra_data=extra_data)[0]
        self.squeeze_dims = squeeze_dims
        for dim in squeeze_dims[::-1]:
            result = jt.squeeze(result, dim)

@@ -334,9 +338,9 @@ class GetItemACL(jt.Function):
            outputs = [jt.zeros(self.x_shape, dtype=grad_output.dtype)]
            # breakpoint()
            result = getitem_cmd("IndexPutImplAccumulate",
                                 inputs=inputs,
                                 outputs=outputs,
                                 attr_code=attr_code)[0]
            result.sync()
            return result, None
        elif self.type_ == 'slicev2':

@@ -401,9 +405,9 @@ class GetItemACL(jt.Function):
            inputs = [grad_output]
            outputs = [jt.zeros(self.x_shape, dtype=grad_output.dtype)]
            result = getitem_cmd("StridedSliceAssignV2",
                                 inputs=inputs,
                                 outputs=outputs,
                                 attr_code=attr_code)[0]
            result.sync()
            if expand_dim:
                result = result.squeeze(-1)

@@ -412,4 +416,4 @@ class GetItemACL(jt.Function):
            return self.mask.float()
            pass
        else:
            assert False, f"grad not implemented for {self.type_}"
@@ -53,11 +53,10 @@ namespace jittor
        return;
    }

    IndexOpRunner::IndexOpRunner() : BaseOpRunner("Index")
    {
    }

    void IndexOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
    {
        auto input_num = in_.size();

@@ -81,7 +80,7 @@ namespace jittor
    SliceV2OpRunner::SliceV2OpRunner() : BaseOpRunner("SliceV2")
    {
    }

    void SliceV2OpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
    {
        auto attr = dynamic_cast<StrideAttr *>(op_attr.get());

@@ -106,11 +105,10 @@ namespace jittor
        return;
    }

    IndexPutImplAccumulateOpRunner::IndexPutImplAccumulateOpRunner() : BaseOpRunner("IndexPutImplAccumulate")
    {
    }

    void IndexPutImplAccumulateOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
    {
        auto input_num = in_.size();

@@ -137,11 +135,9 @@ namespace jittor
        return;
    }

    StridedSliceAssignV2OpRunner::StridedSliceAssignV2OpRunner() : BaseOpRunner("StridedSliceAssignV2")
    {
    }

    void StridedSliceAssignV2OpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
    {

@@ -162,9 +158,7 @@ namespace jittor
        ret = aclnnStridedSliceAssignV2(workspaceAddr, workspaceSize, executor, aclstream);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnStridedSliceAssignV2 failed. ERROR: %d\n", name.c_str(), ret); return);

-       // syncRun();
+       syncRun();

        return;
    }
@@ -9,6 +9,7 @@ namespace jittor

    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        MaskedSelectOpRunner();
    };

@@ -18,6 +19,7 @@ namespace jittor

    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        IndexOpRunner();
    };

@@ -27,6 +29,7 @@ namespace jittor

    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        SliceV2OpRunner();
    };

@@ -36,6 +39,7 @@ namespace jittor

    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        IndexPutImplAccumulateOpRunner();
    };

@@ -45,6 +49,7 @@ namespace jittor

    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        StridedSliceAssignV2OpRunner();
    };
@@ -11,14 +11,15 @@ import numpy as np
from typing import Union
from collections.abc import Sequence, Iterable


def range_forward(name: str,
                  inputs: list,
                  output_dtypes: list = None,
                  output_shapes: list = None,
                  attr_code: str = "",
                  attr_header: str = "",
                  outputs: list = None,
                  extra_data: dict = {}):
    attr_header = "\nnamespace jittor{" + attr_header + "}\n"

    cuda_header = '''

@@ -50,7 +51,8 @@ def range_forward(name: str,
    op.add(out0, false);
    {attr_code}
    op.run();""",
                   data=extra_data)


class IndexACL(jt.Function):

@@ -85,15 +87,13 @@ class IndexACL(jt.Function):
        op.op_attr.reset(attr);
        """
        result = range_forward("Range", [],
                               output_dtypes=[tmp.dtype],
                               output_shapes=[tmp.shape],
                               attr_code=range_attr_code,
                               extra_data=extra_data)[0]
        broadcast_dims = list(range(len(inshape)))
        broadcast_dims.remove(d)
        result = jt.broadcast(result, shape=inshape, dims=broadcast_dims)
        results.append(result)

        if len(results) != 1 or dim_input == None:

@@ -104,4 +104,4 @@ class IndexACL(jt.Function):
        return results

    def grad(self, grad_output):
        return grad_output

@@ -9,6 +9,7 @@ namespace jittor

    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        RangeOpRunner();
    };
@@ -12,15 +12,14 @@ from typing import Union
from collections.abc import Sequence, Iterable


def matmul_forward(name: str,
                   inputs: list,
                   output_dtypes: list = None,
                   output_shapes: list = None,
                   attr_code: str = "",
                   attr_header: str = "",
                   outputs: list = None,
                   extra_data: dict = {}):
    attr_header = "\nnamespace jittor{" + attr_header + "}\n"

    cuda_header = '''

@@ -38,7 +37,7 @@ def matmul_forward(name: str,
    input_code = ''
    for i in range(len(inputs)):
        input_code += f"op.add(in{i}, true);\n"

    return jt.code(outputs=outputs_,
                   inputs=inputs,
                   cuda_header=attr_header + cuda_header,

@@ -49,83 +48,83 @@ def matmul_forward(name: str,
    op.add(out0, false);
    {attr_code}
    op.run();""",
                   data=extra_data)


class MatmulACL(jt.Function):

    def __init__(self, trans_x2=False):
        super(MatmulACL, self).__init__()
        self.trans_x2 = trans_x2

    def execute(self, x1, x2):
        self.input = [x1, x2]
        result = matmul_forward(
            "MatMul", [x1, x2],
            output_dtypes=[x1.dtype],
            output_shapes=[
                x1.shape[:-1] +
                x2.shape[-2:-1] if self.trans_x2 else x1.shape[:-1] +
                x2.shape[-1:]
            ],
            attr_code="op.jt_name=\"matmul_trans_1\";"
            if self.trans_x2 else "op.jt_name=\"matmul\";")[0]
        return result

    def grad(self, grad_output):
        x1, x2 = self.input
        if len(x1) != len(x2):
            reshape_grad_x2 = True
        else:
            reshape_grad_x2 = False
        grad_x1 = matmul_forward(
            "MatMul", [grad_output, x2],
            output_dtypes=[x1.dtype],
            output_shapes=[
                grad_output.shape[:-1] + x2.shape[-2:-1]
                if not self.trans_x2 else grad_output.shape[:-1] +
                x2.shape[-1:]
            ],
            attr_code="op.jt_name=\"matmul_trans_1\";"
            if not self.trans_x2 else "op.jt_name=\"matmul\";")[0]

        if self.trans_x2:
            if reshape_grad_x2:
                output_shape = grad_output.shape[1:-2] + grad_output.shape[
                    -1:] + x1.shape[-1:]
                grad_x2 = matmul_forward(
                    "MatMul", [
                        grad_output.reshape(-1, grad_output.shape[-1]),
                        x1.reshape(-1, x1.shape[-1])
                    ],
                    output_dtypes=[x2.dtype],
                    output_shapes=[output_shape],
                    attr_code="op.jt_name=\"matmul_trans_0\";")[0]
            else:
                output_shape = grad_output.shape[:-2] + grad_output.shape[
                    -1:] + x1.shape[-1:]
                grad_x2 = matmul_forward(
                    "MatMul", [grad_output, x1],
                    output_dtypes=[x2.dtype],
                    output_shapes=[output_shape],
                    attr_code="op.jt_name=\"matmul_trans_0\";")[0]
        else:
            if reshape_grad_x2:
                output_shape = x1.shape[1:-2] + x1.shape[
                    -1:] + grad_output.shape[-1:]
                grad_x2 = matmul_forward(
                    "MatMul", [
                        x1.reshape(-1, x1.shape[-1]),
                        grad_output.reshape(-1, grad_output.shape[-1])
                    ],
                    output_dtypes=[x2.dtype],
                    output_shapes=[output_shape],
                    attr_code="op.jt_name=\"matmul_trans_0\";")[0]
            else:
                output_shape = x1.shape[:-2] + x1.shape[
                    -1:] + grad_output.shape[-1:]
                grad_x2 = matmul_forward(
                    "MatMul", [x1, grad_output],
                    output_dtypes=[x2.dtype],
                    output_shapes=[output_shape],
                    attr_code="op.jt_name=\"matmul_trans_0\";")[0]
        return grad_x1, grad_x2
@@ -49,7 +49,7 @@ namespace jittor
        for (int idx = 0; idx < input_num; idx++)
        {
            inputTensors.push_back(nullptr);
-           if ((jt_name == "matmul_trans_1" && idx == 1) || (jt_name == "matmul_trans_0" && idx == 0) )
+           if ((jt_name == "matmul_trans_1" && idx == 1) || (jt_name == "matmul_trans_0" && idx == 0))
            {
                auto ret = CreateFakeTransAclTensor(inputShapes[idx], in_[idx]->mem_ptr, in_[idx]->size, get_dtype(in_[idx]->dtype()), &inputTensors[idx], use_nchw);
                CHECK_RET(ret == ACL_SUCCESS, return);

@@ -63,8 +63,8 @@ namespace jittor
    }
    void MatMulOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
    {
        ret = aclnnMatmulGetWorkspaceSize(inputTensors[0], inputTensors[1], outputTensors[0], 1, &workspaceSize, &executor);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnMatmulGetWorkspaceSize failed. ERROR: %d\n", name.c_str(), ret); return);
        if (workspaceSize > 0)
        {

@@ -72,6 +72,6 @@ namespace jittor
        }
        ret = aclnnMatmul(workspaceAddr, workspaceSize, executor, aclstream);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnMatmul failed. ERROR: %d\n", name.c_str(), ret); return);
-       // syncRun();
+       syncRun();
    }
}

@@ -10,6 +10,7 @@ namespace jittor
    protected:
        void setupInputDesc() override;
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        MatMulOpRunner();
    };
@@ -11,13 +11,14 @@ import numpy as np
from typing import Union
from collections.abc import Sequence, Iterable


def nantonum_cmd(name: str,
                 inputs: list,
                 output_dtypes: list = None,
                 output_shapes: list = None,
                 attr_code: str = "",
                 attr_header: str = "",
                 outputs: list = None):
    attr_header = "\nnamespace jittor{" + attr_header + "}\n"

    cuda_header = '''

@@ -51,6 +52,7 @@ def nantonum_cmd(name: str,
    {attr_code}
    op.run();""")


class NanToNumACL(jt.Function):

    def __init__(self):

@@ -67,7 +69,7 @@ class NanToNumACL(jt.Function):
        """
        self.attr_code = attr_code
        result = nantonum_cmd("NanToNum", [input],
                              output_dtypes=[input[0].dtype],
                              output_shapes=[input.shape],
                              attr_code=self.attr_code)[0]
        return result

@@ -9,6 +9,7 @@ namespace jittor

    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        NanToNumOpRunner();
    };
@@ -11,13 +11,14 @@ import numpy as np
from typing import Union
from collections.abc import Sequence, Iterable


def pool_cmd(name: str,
             inputs: list,
             output_dtypes: list = None,
             output_shapes: list = None,
             attr_code: str = "",
             attr_header: str = "",
             outputs: list = None):
    attr_header = "\nnamespace jittor{" + attr_header + "}\n"

    cuda_header = '''

@@ -51,33 +52,32 @@ def pool_cmd(name: str,
    {attr_code}
    op.run();""")


class PoolACL(jt.Function):

    def __init__(self,
                 kernel_size,
                 stride=None,
                 padding=0,
                 dilation=None,
                 return_indices=None,
                 ceil_mode=False,
                 count_include_pad=True,
                 op='maximum'):
        self.kernel_size = kernel_size if isinstance(
            kernel_size, tuple) else (kernel_size, kernel_size)
        stride = stride if stride else kernel_size
        self.stride = stride if isinstance(stride, tuple) else (stride, stride)
        self.padding = padding if isinstance(padding, tuple) else (padding,
                                                                   padding)
        dilation = dilation if dilation else 1
        assert dilation == 1
        self.dilation = dilation if isinstance(dilation, tuple) else (dilation,
                                                                      dilation)
        for item in self.kernel_size:
            if item <= 0:
                raise RuntimeError(
                    f"kernel_size must be greater than zero, but got {item}")
        for item in self.stride:
            if item <= 0:
                raise RuntimeError(
@ -108,7 +108,7 @@ class PoolACL(jt.Function):
|
|||
kernel_height, kernel_width = self.kernel_size[-2:]
|
||||
|
||||
output_height = (input_height + 2 * self.padding[0] -
|
||||
(kernel_height - 1) - 1) // self.stride[0] + 1
|
||||
(kernel_height - 1) - 1) // self.stride[0] + 1
|
||||
output_width = (input_width + 2 * self.padding[1] -
|
||||
(kernel_width - 1) - 1) // self.stride[1] + 1
|
||||
|
||||
|
@ -161,16 +161,16 @@ class PoolACL(jt.Function):
|
|||
output_dtypes = [input.dtype]
|
||||
if self.op == 'maximum':
|
||||
result = pool_cmd("MaxpoolBackward",
|
||||
inputs=[grad_output, input, self.index],
|
||||
output_dtypes=output_dtypes,
|
||||
output_shapes=output_shapes,
|
||||
attr_code=attr_code)[0]
|
||||
inputs=[grad_output, input, self.index],
|
||||
output_dtypes=output_dtypes,
|
||||
output_shapes=output_shapes,
|
||||
attr_code=attr_code)[0]
|
||||
elif self.op == 'mean':
|
||||
result = pool_cmd("AvgpoolBackward",
|
||||
inputs=[grad_output, input],
|
||||
output_dtypes=output_dtypes,
|
||||
output_shapes=output_shapes,
|
||||
attr_code=attr_code)[0]
|
||||
inputs=[grad_output, input],
|
||||
output_dtypes=output_dtypes,
|
||||
output_shapes=output_shapes,
|
||||
attr_code=attr_code)[0]
|
||||
else:
|
||||
raise ValueError('no this type pool')
|
||||
return result
|
||||
return result
|
||||
|
|
|
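The output spatial size computed above is the standard floor-division pooling formula. A small self-contained check of that formula (the argument names are illustrative):

def pool_out_size(in_size, kernel, stride, padding):
    # floor((in + 2*pad - (kernel - 1) - 1) / stride) + 1, as in PoolACL above
    return (in_size + 2 * padding - (kernel - 1) - 1) // stride + 1

assert pool_out_size(32, kernel=2, stride=2, padding=0) == 16
assert pool_out_size(7, kernel=3, stride=2, padding=1) == 4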
@ -71,7 +71,6 @@ namespace jittor
        return;
    }

    AvgpoolOpRunner::AvgpoolOpRunner() : BaseOpRunner("Avgpool")
    {
        use_nchw = true;

@ -109,7 +108,6 @@ namespace jittor
        return;
    }

    MaxpoolBackwardOpRunner::MaxpoolBackwardOpRunner() : BaseOpRunner("MaxpoolBackward")
    {
        use_nchw = true;

@ -150,8 +148,6 @@ namespace jittor
        return;
    }

    AvgpoolBackwardOpRunner::AvgpoolBackwardOpRunner() : BaseOpRunner("AvgpoolBackward")
    {
        use_nchw = true;

@ -9,6 +9,7 @@ namespace jittor

    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        MaxpoolOpRunner();
    };

@ -18,16 +19,17 @@ namespace jittor

    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        AvgpoolOpRunner();
    };

    class MaxpoolBackwardOpRunner : public BaseOpRunner
    {

    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        MaxpoolBackwardOpRunner();
    };

@ -37,6 +39,7 @@ namespace jittor

    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        AvgpoolBackwardOpRunner();
    };
@ -44,7 +44,7 @@ namespace jittor
    void RandomOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
    {
        auto attr = dynamic_cast<RandomAttr *>(op_attr.get());
        if (name == "RandomUniform")
        {
            ret = aclnnInplaceUniformGetWorkspaceSize(outputTensors[0], 0.0, 1.0, attr->seed, attr->offset, &workspaceSize, &executor);

@ -58,7 +58,7 @@ namespace jittor
            ret = aclnnInplaceUniform(workspaceAddr, workspaceSize, executor, aclstream);
            CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnInplaceUniform failed. ERROR: %d\n", name.c_str(), ret); return);
        }
        else if (name == "RandomNormal")
        {
            ret = aclnnInplaceNormalGetWorkspaceSize(outputTensors[0], 0.0, 1.0, attr->seed, attr->offset, &workspaceSize, &executor);

@ -76,7 +76,7 @@ namespace jittor
        {
            LOGf << "Not supported random type : " << name;
        }
        syncRun();
        return;
    }
}
@ -28,7 +28,6 @@
#include "aclnn/aclnn.h"
#include "reduce_op_acl.h"

namespace jittor
{
    ReduceOpRunner::ReduceOpRunner() : BaseOpRunner("reduce")
@ -11,13 +11,14 @@ import numpy as np
from typing import Union
from collections.abc import Sequence, Iterable


def relu_cmd(name: str,
             inputs: list,
             output_dtypes: list = None,
             output_shapes: list = None,
             attr_code: str = "",
             attr_header: str = "",
             outputs: list = None):
    attr_header = "\nnamespace jittor{" + attr_header + "}\n"

    cuda_header = '''

@ -51,6 +52,7 @@ def relu_cmd(name: str,
    {attr_code}
    op.run();""")


class ReLUACL(jt.Function):

    def __init__(self):

@ -60,9 +62,9 @@ class ReLUACL(jt.Function):
        x = x.float32()
        self.input = x
        result = relu_cmd("ReLU", [x],
                          output_dtypes=[x.dtype],
                          output_shapes=[x.shape],
                          attr_code="op.jt_name=\"unary\";")[0]
        return result

    def grad(self, grad_output):

@ -72,11 +74,11 @@ class ReLUACL(jt.Function):
                        output_shapes=[self.input.shape],
                        attr_code="op.jt_name=\"binary\";")[0]
        grad_input = relu_cmd("Mul", [grad_output, mask],
                              output_dtypes=[grad_output.dtype],
                              output_shapes=[grad_output.shape],
                              attr_code="op.jt_name=\"binary\";")[0]
        return grad_input


class LeakyReLUACL(jt.Function):


@ -93,9 +95,9 @@ class LeakyReLUACL(jt.Function):
        op.op_attr.reset(attr);
        """
        result = relu_cmd("LeakyReLU", [x],
                          output_dtypes=[x.dtype],
                          output_shapes=[x.shape],
                          attr_code=attr_code)[0]
        self.negative_slope = negative_slope
        return result

@ -107,9 +109,8 @@ class LeakyReLUACL(jt.Function):
        attr->selfIsResult = false;
        op.op_attr.reset(attr);
        """
        grad_input = relu_cmd("LeakyReLUBackward", [grad_output, self.input],
                              output_dtypes=[grad_output.dtype],
                              output_shapes=[grad_output.shape],
                              attr_code=attr_code)[0]
        return grad_input


@ -9,6 +9,7 @@ namespace jittor

    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        LeakyReLUOpRunner();
    };

@ -18,6 +19,7 @@ namespace jittor

    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        LeakyReLUBackwardOpRunner();
    };
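The ReLU gradient above is composed from two ACL calls: a mask over the saved input (the mask-producing call sits outside this hunk) and an elementwise "Mul" of grad_output with that mask. A plain NumPy rendering of the same composition:

import numpy as np

def relu_grad(grad_output, x):
    mask = (x > 0).astype(grad_output.dtype)  # 0/1 mask over the saved forward input
    return grad_output * mask                 # the "Mul" of grad_output and mask

x = np.array([-1.0, 0.0, 2.0])
print(relu_grad(np.ones_like(x), x))  # [0. 0. 1.]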
@ -11,13 +11,14 @@ import numpy as np
from typing import Union
from collections.abc import Sequence, Iterable


def rope_cmd(name: str,
             inputs: list,
             output_dtypes: list = None,
             output_shapes: list = None,
             attr_code: str = "",
             attr_header: str = "",
             outputs: list = None):
    attr_header = "\nnamespace jittor{" + attr_header + "}\n"

    cuda_header = '''

@ -51,6 +52,7 @@ def rope_cmd(name: str,
    {attr_code}
    op.run();""")


class RopeACL(jt.Function):

    def __init__(self):

@ -67,16 +69,16 @@ class RopeACL(jt.Function):
        assert freq_cos is not None and freq_sin is not None
        inputs = [xq, xk, freq_cos, freq_sin]
        results = rope_cmd("RotaryPosEmb",
                           inputs,
                           output_dtypes=[
                               xq.dtype,
                           ],
                           output_shapes=[
                               xq.shape,
                           ],
                           attr_code=attr_code)
        results[0].sync()
        return inputs[0], inputs[1]

    def grad(self, grad_output):
        return grad_output


@ -9,6 +9,7 @@ namespace jittor

    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        RotaryPosEmbOpRunner();
    };
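RotaryPosEmb rotates query/key features in place using position-dependent cos/sin tables. As a rough reference, the conventional interleaved-pair formulation of rotary embedding is sketched below; this is an assumption for orientation only, since the exact pairing and layout used by the ACL kernel are not visible in this diff:

import numpy as np

def rope_reference(x, cos, sin):
    # x: (..., d) with d even; cos/sin broadcast over the d/2 feature pairs.
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

x = np.random.rand(2, 8)
cos = np.cos(np.linspace(0.0, 1.0, 4))
sin = np.sin(np.linspace(0.0, 1.0, 4))
print(rope_reference(x, cos, sin).shape)  # (2, 8)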
@ -11,13 +11,14 @@ import numpy as np
from typing import Union
from collections.abc import Sequence, Iterable


def setitem_cmd(name: str,
                inputs: list,
                output_dtypes: list = None,
                output_shapes: list = None,
                attr_code: str = "",
                attr_header: str = "",
                outputs: list = None):
    attr_header = "\nnamespace jittor{" + attr_header + "}\n"

    cuda_header = '''

@ -51,6 +52,7 @@ def setitem_cmd(name: str,
    {attr_code}
    op.run();""")


def setitem_forward(name: str,
                    inputs: list,
                    output_dtypes: list = None,

@ -90,7 +92,8 @@ def setitem_forward(name: str,
    op.add(out0, false);
    {attr_code}
    op.run();""",
                              data=extra_data)


def caculate_shape(tensors):
    if isinstance(tensors, jt.Var):

@ -105,6 +108,7 @@ def caculate_shape(tensors):
    else:
        assert False, f"not implemented for {type(tensors)}"


def can_broadcast_and_shape(shape1, shape2):
    """
    Check whether the two shapes can be broadcast together and return the broadcast shape.

@ -144,6 +148,7 @@ def can_broadcast_and_shape(shape1, shape2):

    return True, tuple(broadcast_shape)


class SetItemACL(jt.Function):

    def __init__(self):

@ -180,9 +185,9 @@ class SetItemACL(jt.Function):
            op.jt_name = "inplacemaskedscatter";
            """
            result = setitem_cmd("InplaceMaskedScatter",
                                 inputs=inputs,
                                 outputs=outputs,
                                 attr_code=attr_code)[0]
            return result

        # assert isinstance(value,jt.Var), "value must be jt.Var"

@ -199,7 +204,7 @@ class SetItemACL(jt.Function):
        contains_slice = False
        for s in slices:
            if not isinstance(s, jt.Var) and (isinstance(s, slice)
                                              or s == Ellipsis):
                contains_slice = True
                break
        if not contains_slice:

@ -220,9 +225,9 @@ class SetItemACL(jt.Function):
            self.value_shape = value_shape
            for ii in slices:
                indices.append(jt.Var(ii).int32())
            if isinstance(slices[0],
                          jt.Var) or isinstance(slices[0], int) or isinstance(
                              slices[0], list) or isinstance(slices[0], tuple):
                self.indices = indices
                self.type_ = 'index'
                attr_code = f"""

@ -231,9 +236,9 @@ class SetItemACL(jt.Function):
                inputs = [value] + indices
                outputs = [x.clone()]
                result = setitem_cmd("IndexPutImpl",
                                     inputs=inputs,
                                     outputs=outputs,
                                     attr_code=attr_code)[0]
                # result.sync()
                return result
            assert "not support"

@ -244,8 +249,7 @@ class SetItemACL(jt.Function):
            if not isinstance(s, jt.Var) and s == Ellipsis:
                slices = slices[:slices.index(s)] + [
                    slice(None, None, None)
                ] * (x_dim - len(slices) + 1) + slices[slices.index(s) + 1:]
                break
        slices = tuple(slices)
        self.input_slice = slices

@ -335,10 +339,10 @@ class SetItemACL(jt.Function):
        inputs = [value]
        outputs = [x.clone()]
        result = setitem_forward("StridedSliceAssignV2",
                                 inputs=inputs,
                                 outputs=outputs,
                                 attr_code=attr_code,
                                 extra_data=extra_data)[0]
        if expand_dim:
            result = result.squeeze(-1)
        # result.sync()

@ -349,4 +353,4 @@ class SetItemACL(jt.Function):
        if self.value_var:
            value_grad = grad_output[self.input_slice]
            grad_output[self.input_slice] = jt.zeros(self.value_shape)
        return grad_output, None, value_grad
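The broadcast check used by the setitem path follows the standard NumPy rule: align the two shapes from the trailing dimension, and each pair of dimensions must be equal or contain a 1. A self-contained restatement of that rule (the function name here is illustrative, mirroring can_broadcast_and_shape above):

def broadcast_shape(shape1, shape2):
    n = max(len(shape1), len(shape2))
    s1 = (1,) * (n - len(shape1)) + tuple(shape1)   # left-pad the shorter shape with 1s
    s2 = (1,) * (n - len(shape2)) + tuple(shape2)
    out = []
    for d1, d2 in zip(s1, s2):
        if d1 != d2 and d1 != 1 and d2 != 1:
            return False, None
        out.append(max(d1, d2))
    return True, tuple(out)

assert broadcast_shape((3, 1, 5), (4, 5)) == (True, (3, 4, 5))
assert broadcast_shape((2, 3), (4, 3)) == (False, None)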
@ -33,7 +33,7 @@ namespace jittor
    InplaceMaskedScatterOpRunner::InplaceMaskedScatterOpRunner() : BaseOpRunner("InplaceMaskedScatter")
    {
    }

    void InplaceMaskedScatterOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
    {
        ret = aclnnInplaceMaskedScatterGetWorkspaceSize(outputTensors[0], inputTensors[0], inputTensors[1], &workspaceSize, &executor);

@ -55,7 +55,7 @@ namespace jittor
    IndexPutImplOpRunner::IndexPutImplOpRunner() : BaseOpRunner("IndexPutImpl")
    {
    }

    void IndexPutImplOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
    {
        auto input_num = in_.size();

@ -77,7 +77,7 @@ namespace jittor
        ret = aclnnIndexPutImpl(workspaceAddr, workspaceSize, executor, aclstream);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnIndexPutImpl failed. ERROR: %d\n", name.c_str(), ret); return);

        syncRun();
        return;
    }


@ -9,6 +9,7 @@ namespace jittor

    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        InplaceMaskedScatterOpRunner();
    };

@ -18,6 +19,7 @@ namespace jittor

    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        IndexPutImplOpRunner();
    };
@ -11,13 +11,14 @@ import numpy as np
from typing import Union
from collections.abc import Sequence, Iterable


def sigmoid_cmd(name: str,
                inputs: list,
                output_dtypes: list = None,
                output_shapes: list = None,
                attr_code: str = "",
                attr_header: str = "",
                outputs: list = None):
    attr_header = "\nnamespace jittor{" + attr_header + "}\n"

    cuda_header = '''

@ -51,6 +52,7 @@ def sigmoid_cmd(name: str,
    {attr_code}
    op.run();""")


class SigmoidACL(jt.Function):

    def __init__(self):

@ -64,9 +66,9 @@ class SigmoidACL(jt.Function):
        op.jt_name = "sigmoid";
        """
        result = sigmoid_cmd("Sigmoid",
                             inputs=inputs,
                             outputs=outputs,
                             attr_code=attr_code)[0]
        self.output = result
        return result


@ -77,7 +79,7 @@ class SigmoidACL(jt.Function):
        inputs = [grad_output, self.output]
        outputs = [jt.empty(grad_output.shape, grad_output.dtype)]
        grad_input = sigmoid_cmd("SigmoidBackward",
                                 inputs=inputs,
                                 outputs=outputs,
                                 attr_code=attr_code)[0]
        return grad_input


@ -9,6 +9,7 @@ namespace jittor

    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        SigmoidOpRunner();
    };

@ -18,6 +19,7 @@ namespace jittor

    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        SigmoidBackwardOpRunner();
    };
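The backward pass above only needs the saved forward output, because the sigmoid derivative can be written entirely in terms of it: dσ/dx = σ(x)·(1 − σ(x)), so grad_input = grad_output · y · (1 − y). A quick NumPy check of that identity:

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

x = np.linspace(-3, 3, 601)
y = sigmoid(x)                                   # the output saved as self.output
analytic = y * (1.0 - y)                         # derivative expressed via the output
numeric = (sigmoid(x + 1e-5) - sigmoid(x - 1e-5)) / 2e-5
assert np.allclose(analytic, numeric, atol=1e-6)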
@ -11,13 +11,14 @@ import numpy as np
from typing import Union
from collections.abc import Sequence, Iterable


def silu_cmd(name: str,
             inputs: list,
             output_dtypes: list = None,
             output_shapes: list = None,
             attr_code: str = "",
             attr_header: str = "",
             outputs: list = None):
    attr_header = "\nnamespace jittor{" + attr_header + "}\n"

    cuda_header = '''

@ -51,6 +52,7 @@ def silu_cmd(name: str,
    {attr_code}
    op.run();""")


class SiLUACL(jt.Function):

    def __init__(self):

@ -65,9 +67,9 @@ class SiLUACL(jt.Function):
        op.jt_name = "silu";
        """
        result = silu_cmd("SiLU",
                          inputs=inputs,
                          outputs=outputs,
                          attr_code=attr_code)[0]
        return result

    def grad(self, grad_output):

@ -77,7 +79,7 @@ class SiLUACL(jt.Function):
        inputs = [grad_output, self.input]
        outputs = [jt.empty(grad_output.shape, grad_output.dtype)]
        grad_input = silu_cmd("SiLUBackward",
                              inputs=inputs,
                              outputs=outputs,
                              attr_code=attr_code)[0]
        return grad_input


@ -9,6 +9,7 @@ namespace jittor

    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        SiLUOpRunner();
    };

@ -18,6 +19,7 @@ namespace jittor

    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        SiLUBackwardOpRunner();
    };
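Unlike sigmoid, SiLUBackward is fed the saved input rather than the output, which matches the derivative of x·σ(x): d/dx [x·σ(x)] = σ(x)·(1 + x·(1 − σ(x))). A NumPy check of that identity:

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def silu(x):
    return x * sigmoid(x)

x = np.linspace(-4, 4, 801)
s = sigmoid(x)
analytic = s * (1.0 + x * (1.0 - s))             # derivative from the saved input
numeric = (silu(x + 1e-5) - silu(x - 1e-5)) / 2e-5
assert np.allclose(analytic, numeric, atol=1e-6)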
@ -11,13 +11,14 @@ import numpy as np
from typing import Union
from collections.abc import Sequence, Iterable


def softmax_cmd(name: str,
                inputs: list,
                output_dtypes: list = None,
                output_shapes: list = None,
                attr_code: str = "",
                attr_header: str = "",
                outputs: list = None):
    attr_header = "\nnamespace jittor{" + attr_header + "}\n"

    cuda_header = '''

@ -51,6 +52,7 @@ def softmax_cmd(name: str,
    {attr_code}
    op.run();""")


class SoftmaxACL(jt.Function):

    def __init__(self):

@ -68,9 +70,9 @@ class SoftmaxACL(jt.Function):
        op.op_attr.reset(attr);
        """
        result = softmax_cmd("Softmax",
                             inputs=inputs,
                             outputs=outputs,
                             attr_code=attr_code)[0]
        self.output = result
        return result

@ -84,7 +86,7 @@ class SoftmaxACL(jt.Function):
        inputs = [grad_output, self.output]
        outputs = [jt.empty(grad_output.shape)]
        grad_input = softmax_cmd("SoftmaxBackward",
                                 inputs=inputs,
                                 outputs=outputs,
                                 attr_code=attr_code)[0]
        return grad_input


@ -9,6 +9,7 @@ namespace jittor

    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        SoftmaxOpRunner();
    };

@ -18,6 +19,7 @@ namespace jittor

    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        SoftmaxBackwardOpRunner();
    };
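As with sigmoid, the softmax backward pass is driven by the saved output y. The standard softmax vector-Jacobian product along the softmax axis is grad_input = (grad_output − sum(grad_output·y)) · y; a NumPy check against the explicit Jacobian for a small vector:

import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

x = np.array([0.3, -1.2, 2.0, 0.5])
g = np.array([1.0, 0.0, -0.5, 2.0])          # incoming grad_output
y = softmax(x)

grad_input = (g - (g * y).sum()) * y          # the standard softmax VJP from (g, y)

J = np.diag(y) - np.outer(y, y)               # explicit softmax Jacobian
assert np.allclose(grad_input, J @ g)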
@ -11,13 +11,14 @@ import numpy as np
from typing import Union
from collections.abc import Sequence, Iterable


def stack_cmd(name: str,
              inputs: list,
              output_dtypes: list = None,
              output_shapes: list = None,
              attr_code: str = "",
              attr_header: str = "",
              outputs: list = None):
    attr_header = "\nnamespace jittor{" + attr_header + "}\n"

    cuda_header = '''

@ -51,6 +52,7 @@ def stack_cmd(name: str,
    {attr_code}
    op.run();""")


class StackACL(jt.Function):

    def __init__(self):

@ -60,15 +62,12 @@ class StackACL(jt.Function):
        if type(input_tensors) is tuple:
            input_tensors = list(input_tensors)
        assert type(input_tensors) is list
        assert -1 * len(input_tensors) - 1 <= dim and dim <= len(input_tensors)
        for i in range(len(input_tensors)):
            if input_tensors[i].dtype != input_tensors[0].dtype:
                raise ValueError("All input tensors must have the same dtype")
            if input_tensors[i].shape != input_tensors[0].shape:
                raise ValueError("All input tensors must have the same shape")
        self.input = input_tensors
        input_shape = list(input_tensors[0].shape)
        output_shape = input_shape[:dim] + [len(input_tensors)

@ -82,10 +81,10 @@ class StackACL(jt.Function):
        """
        self.attr_code = attr_code
        result = stack_cmd("Stack",
                           input_tensors,
                           output_dtypes=[input_tensors[0].dtype],
                           output_shapes=[output_shape],
                           attr_code=self.attr_code)[0]
        return result

    def grad(self, grad_output):

@ -110,7 +109,7 @@ class StackACL(jt.Function):
        """

        result = stack_cmd("SplitWithSize", [grad_output],
                           output_dtypes=dtypeVec,
                           output_shapes=shapeVec,
                           attr_code=attr_code)
        return result


@ -9,6 +9,7 @@ namespace jittor

    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        StackOpRunner();
    };
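Stack inserts a new axis of length len(input_tensors) at dim, and its gradient is simply a split back along that axis, which is why the backward pass above dispatches to SplitWithSize. In NumPy terms:

import numpy as np

tensors = [np.full((2, 3), float(i)) for i in range(4)]
dim = 1
out = np.stack(tensors, axis=dim)                      # (2, 4, 3): new axis at dim
assert out.shape == (2, 4, 3)

grads = np.split(out, out.shape[dim], axis=dim)        # backward: one slice per input
assert [g.squeeze(dim).shape for g in grads] == [(2, 3)] * 4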
@ -48,7 +48,7 @@ namespace jittor

        ret = aclnnSWhere(workspaceAddr, workspaceSize, executor, aclstream);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnxxx failed. ERROR: %d\n", name.c_str(), ret); return);
        // syncRun();
        return;
    }
}

@ -7,7 +7,7 @@ namespace jittor
    struct TernaryOpRunner : public BaseOpRunner
    {
        TernaryOpRunner();

    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
    };
@ -11,13 +11,14 @@ import numpy as np
from typing import Union
from collections.abc import Sequence, Iterable


def transpose_cmd(name: str,
                  inputs: list,
                  output_dtypes: list = None,
                  output_shapes: list = None,
                  attr_code: str = "",
                  attr_header: str = "",
                  outputs: list = None):
    attr_header = "\nnamespace jittor{" + attr_header + "}\n"

    cuda_header = '''

@ -51,6 +52,7 @@ def transpose_cmd(name: str,
    {attr_code}
    op.run();""")


class TransPoseACL(jt.Function):

    def __init__(self):

@ -75,9 +77,9 @@ class TransPoseACL(jt.Function):
        # calculate output shape
        output_shape = [x.shape[i] for i in dim]
        output = transpose_cmd("Transpose", [x],
                               output_dtypes=[x.dtype],
                               output_shapes=[jt.empty(output_shape).shape],
                               attr_code=attr_code)[0]
        self.dim = dim
        return output


@ -93,7 +95,7 @@ class TransPoseACL(jt.Function):
        op.op_attr.reset(attr);
        """
        output = transpose_cmd("Transpose", [grad_output],
                               output_dtypes=[grad_output.dtype],
                               output_shapes=[jt.empty(output_shape).shape],
                               attr_code=attr_code)[0]
        return output
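The gradient of a permutation of axes is the inverse permutation applied to grad_output, which is why the backward pass above is just another Transpose. A NumPy check of the shapes and the round trip:

import numpy as np

x = np.random.rand(2, 3, 4)
perm = (2, 0, 1)
inv_perm = tuple(np.argsort(perm))            # the permutation the backward pass applies

y = np.transpose(x, perm)                     # forward: shape (4, 2, 3)
assert y.shape == (4, 2, 3)
assert np.array_equal(np.transpose(y, inv_perm), x)   # the inverse recovers x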
@ -9,6 +9,7 @@ namespace jittor

    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        TransposeOpRunner();
    };

@ -53,7 +53,7 @@ namespace jittor

        ret = it->second.executeFunc(workspaceAddr, workspaceSize, executor, aclstream);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnxxx failed. ERROR: %d\n", name.c_str(), ret); return);
        // syncRun();
        return;
    }
}
@ -11,13 +11,14 @@ import numpy as np
from typing import Union
from collections.abc import Sequence, Iterable


def where_cmd(name: str,
              inputs: list,
              output_dtypes: list = None,
              output_shapes: list = None,
              attr_code: str = "",
              attr_header: str = "",
              outputs: list = None):
    attr_header = "\nnamespace jittor{" + attr_header + "}\n"

    cuda_header = '''

@ -51,6 +52,7 @@ def where_cmd(name: str,
    {attr_code}
    op.run();""")


class NonzeroACL(jt.Function):

    def __init__(self):

@ -63,15 +65,16 @@ class NonzeroACL(jt.Function):
        nonzero_cnt = (x != 0.0).sum().item()

        result = where_cmd("Nonzero", [x],
                           output_dtypes=['int64'],
                           output_shapes=[(nonzero_cnt, x.ndim)],
                           attr_code=attr_code)[0]

        return result

    def grad(self, grad_output):
        return grad_output


class WhereACL(jt.Function):

    def __init__(self):

@ -104,9 +107,9 @@ class WhereACL(jt.Function):
            self.y = y

            result = where_cmd("Where", [condition, x, y],
                               output_dtypes=[x.dtype],
                               output_shapes=[x.shape],
                               attr_code="op.jt_name=\"where\";")[0]
            return result

    def grad(self, grad_output):

@ -115,12 +118,12 @@ class WhereACL(jt.Function):
        else:
            tmp = jt.zeros(grad_output.shape, dtype=grad_output.dtype)
            grad_x = where_cmd("Where", [self.condition, grad_output, tmp],
                               output_dtypes=[self.x.dtype],
                               output_shapes=[self.x.shape],
                               attr_code="op.jt_name=\"where\";")[0]

            grad_y = where_cmd("Where", [self.condition, tmp, grad_output],
                               output_dtypes=[self.y.dtype],
                               output_shapes=[self.y.shape],
                               attr_code="op.jt_name=\"where\";")[0]
            return grad_output, grad_x, grad_y
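The two backward Where calls above route grad_output to x where the condition holds and to y elsewhere; the same masking expressed in NumPy:

import numpy as np

cond = np.array([True, False, True, False])
grad_output = np.array([1.0, 2.0, 3.0, 4.0])
zeros = np.zeros_like(grad_output)

grad_x = np.where(cond, grad_output, zeros)   # [1. 0. 3. 0.]
grad_y = np.where(cond, zeros, grad_output)   # [0. 2. 0. 4.]
assert np.array_equal(grad_x + grad_y, grad_output)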
@ -53,7 +53,6 @@ namespace jittor
        return;
    }

    NonzeroOpRunner::NonzeroOpRunner() : BaseOpRunner("Nonzero")
    {
    }

@ -9,6 +9,7 @@ namespace jittor

    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        WhereOpRunner();
    };

@ -18,6 +19,7 @@ namespace jittor

    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        NonzeroOpRunner();
    };