Merge branch 'main' into splits

Commit feeee7b3e6 by Yuxuan Han, 2024-12-23 15:57:22 +08:00 (committed via GitHub)
66 changed files with 728 additions and 669 deletions

View File

@ -158,33 +158,8 @@ def acl_cmd(name: str,
attr_code: str = "",
attr_header: str = "",
outputs: list = None):
# inputs: list,
# output_dtypes: list,
# output_shapes: list,
# attr_code: str = ""):
# input_code = ''
# for i in range(len(inputs)):
# input_code += f"op.add(in{i}, true);\n"
# output_code = ''
# for i in range(len(output_dtypes)):
# output_code += f"op.add(out{i}, false);\n"
# # read the tmp_file.cpp to the cuda_header
# with open(
# "/home/ma-user/work/zy/JittorHW/python/jittor/extern/acl/tmp_file.cpp",
# "r") as f:
# cuda_header = f.read()
# import jittor as jt
# return jt.code(output_shapes,
# output_dtypes,
# inputs,
# cuda_header=cuda_header,
# cuda_src=f"""
attr_header = "\nnamespace jittor{" + attr_header + "}\n"
# read the tmp_file.cpp to the cuda_header
cuda_header = '#include "acl/aclops/aclops.h"'
import jittor as jt
outputs_ = []
@ -194,11 +169,10 @@ def acl_cmd(name: str,
assert output_dtypes is not None
assert output_shapes is not None
assert len(output_dtypes) == len(output_shapes)
# print(f'{name } output_dtypes', output_dtypes)
# print(f'{name } output_shapes', output_shapes)
for i in range(len(output_shapes)):
outputs_.append(jt.empty(output_shapes[i], output_dtypes[i]))
# print(f'{name } outputs_', outputs_)
input_code = ''
for i in range(len(inputs)):
input_code += f"op.add(in{i}, true);\n"
@ -240,11 +214,9 @@ def acl_cmd_forward(name: str,
assert output_dtypes is not None
assert output_shapes is not None
assert len(output_dtypes) == len(output_shapes)
# print(f'{name } output_dtypes', output_dtypes)
# print(f'{name } output_shapes', output_shapes)
for i in range(len(output_shapes)):
outputs_.append(jt.empty(output_shapes[i], output_dtypes[i]))
# print(f'{name } outputs_', outputs_)
input_code = ''
for i in range(len(inputs)):
input_code += f"op.add(in{i}, true);\n"
@ -269,6 +241,31 @@ def acl_cmd_forward(name: str,
def change_function():
import jittor as jt
from jittor import Function
from .aclops.flashattention_op import FlashAttentionACL
from .aclops.conv_op import ConvACL
from .aclops.pool_op import PoolACL
from .aclops.nantonum_op import NanToNumACL
from .aclops.stack_op import StackACL
from .aclops.rope_op import RopeACL
from .aclops.softmax_op import SoftmaxACL
from .aclops.sigmoid_op import SigmoidACL
from .aclops.silu_op import SiLUACL
from .aclops.dropout_op import DropoutACL
from .aclops.relu_op import LeakyReLUACL
from .aclops.flip_op import FlipACL
from .aclops.concat_op import ConcatACL
from .aclops.gather_scatter_op import GatherACL
from .aclops.cumsum_op import CumsumACL
from .aclops.index_op import IndexACL
from .aclops.gather_scatter_op import ScatterACL
from .aclops.where_op import WhereACL
from .aclops.where_op import NonzeroACL
from .aclops.floor_op import FloorIntACL
from .aclops.getitem_op import GetItemACL
from .aclops.setitem_op import SetItemACL
from .aclops.bmm_op import BmmACL
from .aclops.matmul_op import MatmulACL
from .aclops.transpose_op import TransPoseACL
from .aclops.triu_op import TriuACL
@ -276,6 +273,7 @@ def change_function():
return TriuACL()(x, diagonal)
from .aclops.conv_op import ConvACL
def conv_acl(x,
weight,
bias=None,
@ -385,6 +383,7 @@ def change_function():
self.padding, self.dilation, self.groups)
return ret
from .aclops.flip_op import FlipACL
def flip_acl(x, dim):
return FlipACL()(x, dim)
@ -441,6 +440,7 @@ def change_function():
return FloorIntACL()(x)
from .aclops.getitem_op import GetItemACL
def getitem_acl(x, slices, return_x=None):
# Transform numpy int to int
if isinstance(slices, (np.int8, np.int16, np.int32, np.int64)):
@ -489,19 +489,25 @@ def change_function():
return result
from .aclops.setitem_op import SetItemACL
def setitem_acl(x, slices, value):
res = SetItemACL()(x, slices, value)
return x.assign(res)
from .aclops.bmm_op import BmmACL
def bmm_acl(x1, x2):
return BmmACL()(x1, x2)
def bmm_transpose_acl(x1, x2):
return BmmACL(True)(x1, x2)
from .aclops.matmul_op import MatmulACL
def matmul_acl(x1, x2):
return MatmulACL()(x1, x2)
@ -512,6 +518,7 @@ def change_function():
def transpose_acl(x, *dim):
return TransPoseACL()(x, *dim)
class ReLUACL(Function):
def __init__(self):
@ -550,6 +557,7 @@ def change_function():
return ReLUACL()(x)
from .aclops.relu_op import LeakyReLUACL
class LeakyReLU(jt.nn.Module):
def __init__(self, negative_slope=0.01):
@ -563,6 +571,7 @@ def change_function():
return LeakyReLUACL()(x, scale)
from .aclops.dropout_op import DropoutACL
class Dropout(jt.nn.Module):
def __init__(self, p=0.5, is_train=False):
@ -620,8 +629,7 @@ def change_function():
# def execute(self, x):
# res = embedding_acl(x, self.weight)
# return res
from .aclops.softmax_op import SoftmaxACL
class Softmax(jt.nn.Module):
def __init__(self):
@ -642,7 +650,7 @@ def change_function():
return StackACL()(x, dim)
from .aclops.nantonum_op import NanToNumACL
def isnan_acl(x):
tonum = NanToNumACL()(x, -1.0)
return jt.not_equal(x, tonum).logical_and(
@ -694,8 +702,7 @@ def change_function():
jt.nn.conv2d = warp(jt.nn.conv2d, conv_acl)
jt.nn.Conv2d = warp(jt.nn.Conv2d, Conv2D)
jt.nn.Conv = warp(jt.nn.Conv, Conv2D)
from .aclops.pool_op import PoolACL
jt.nn.Pool = warp(jt.nn.Pool, PoolACL)
jt.flip = warp(jt.flip, flip_acl)
@ -743,6 +750,7 @@ def change_function():
jt.Var.__setitem__ = lambda x, slices, value: warp(
fake_setitem, setitem_acl, name='setitem')(x, slices, value)
fake_matmul = jt.Var.matmul
jt.nn.bmm = warp(jt.nn.bmm, bmm_acl)
jt.bmm = warp(jt.bmm, bmm_acl)
jt.nn.matmul = warp(jt.matmul, matmul_acl)
@ -750,6 +758,7 @@ def change_function():
jt.nn.matmul_transpose = warp(jt.nn.matmul_transpose, matmul_transpose_acl)
jt.nn.bmm_transpose = warp(jt.nn.bmm_transpose, bmm_transpose_acl)
jt.bmm_transpose = warp(jt.bmm_transpose, bmm_transpose_acl)
jt.Var.__matmul__ = lambda x, y: warp(fake_matmul, matmul_acl)(x, y)
jt.transpose = warp(jt.transpose, transpose_acl)
fake_transpose = jt.transpose
@ -784,7 +793,7 @@ def change_function():
# from .aclops.norms_op import BatchNormACL,LayerNormACL
# jt.nn.BatchNorm = warp(jt.nn.BatchNorm, BatchNormACL)
# jt.nn.LayerNorm = warp(jt.nn.LayerNorm, LayerNormACL)
from .aclops.flashattention_op import FlashAttentionACL
jt.nn.FlashAttention = warp(jt.nn.FlashAttention, FlashAttentionACL)
jt.isnan = warp(jt.isnan, isnan_acl)
jt.isinf = warp(jt.isinf, isinf_acl)
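
For orientation, a minimal sketch of how the rebinding above is exercised once change_function() has run on an Ascend/ACL build of Jittor (shapes are illustrative, not taken from this patch):

import jittor as jt
# after change_function(), these entry points dispatch to the ACL wrappers above
a = jt.ones((4, 8, 16))
b = jt.ones((4, 16, 32))
c = jt.bmm(a, b)     # routed through warp(jt.bmm, bmm_acl) -> BmmACL
d = jt.flip(a, 1)    # routed through warp(jt.flip, flip_acl) -> FlipACL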

View File

@ -154,9 +154,9 @@ namespace jittor
std::queue<Op *> queue;
for (Op *op : fop->ops)
op_indeg[op] = 0;
op_indeg[op] = 0;
map<Op *, vector<Op *>> out_map;
map<Op *, vector<Op *>> out_map;
map<Var *, vector<Op *>> from;
int len = 0;
@ -303,15 +303,12 @@ namespace jittor
op.op_attr.reset(attr);
op.add(rop->y, false);
op.run();
aclrtSynchronizeStream(aclstream);
}
else if (op->name() == string("broadcast_to"))
{
auto bop = (BroadcastToOp *)op;
AclOpRunner op("Expand");
if (bop->x->shape.size() == 1 && bop->x->shape[0] == 1)
{
aclrtSynchronizeStream(aclstream);
}
op.jt_name = "expand";
NanoVector xshape, xshape_bk = bop->x->shape;
NanoVector zshape = bop->z->shape;
@ -333,13 +330,13 @@ namespace jittor
op.add(bop->z, false);
op.run();
bop->x->shape = xshape_bk;
// aclrtSynchronizeStream(aclstream);
aclrtSynchronizeStream(aclstream);
}
else if (op->name() == string("fuse_transpose"))
{
// replace fuse_transpose with transpose
auto top = (TransposeOp *)op;
AclOpRunner op("Transpose");
TransposeOpRunner op;
op.add(top->x, true);
op.add(top->y, false);
op.jt_name = "transpose";
@ -497,4 +494,4 @@ namespace jittor
}
}
} // jittor
} // jittor

View File

@ -274,7 +274,7 @@ namespace jittor
// ret = aclrtSynchronizeStream(aclstream);
// CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclrtSynchronizeStream failed. ERROR: %d\n", name.c_str(), ret); return);
// }
// 6. Releasing the aclTensor and aclScalar objects must be adapted to the specific API's interface definition
// destroy tensor
// for (int idx = 0; idx < input_num; idx++)

View File

@ -33,19 +33,20 @@ namespace jittor
// Common functionality for adding input/output variables
void add(Var *v, bool is_input);
virtual void setupInputDesc();
void cleanupDesc();
virtual void setupOutputDesc();
virtual void syncRun();
void checkRet(aclnnStatus ret);
// Base run method with common operator lookup logic
void run();
protected:
// Virtual method for specific operator execution
virtual void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) = 0;

View File

@ -105,13 +105,13 @@ namespace jittor
void BaseOpRunner::syncRun()
{
if(sync_run) {
ret = aclrtSynchronizeStream(aclstream);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclrtSynchronizeStream failed. ERROR: %d\n", name.c_str(), ret); return);
if (sync_run)
{
// ret = aclrtSynchronizeStream(aclstream);
// CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclrtSynchronizeStream failed. ERROR: %d\n", name.c_str(), ret); return);
}
}
void BaseOpRunner::checkRet(aclnnStatus ret)
{
if (ret != ACL_SUCCESS)

View File

@ -116,7 +116,7 @@ namespace jittor
ret = it->second.executeFunc(workspaceAddr, workspaceSize, executor, aclstream);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnxxx failed. ERROR: %d\n", name.c_str(), ret); return);
//syncRun();
syncRun();
aclDestroyScalar(alpha);
return;

View File

@ -12,15 +12,14 @@ from typing import Union
from collections.abc import Sequence, Iterable
def acl_cmd(name: str,
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None,
extra_data: dict = {}):
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None,
extra_data: dict = {}):
attr_header = "\nnamespace jittor{" + attr_header + "}\n"
cuda_header = '''
@ -38,7 +37,7 @@ def acl_cmd(name: str,
input_code = ''
for i in range(len(inputs)):
input_code += f"op.add(in{i}, true);\n"
return jt.code(outputs=outputs_,
inputs=inputs,
cuda_header=attr_header + cuda_header,
@ -49,8 +48,9 @@ def acl_cmd(name: str,
op.add(out0, false);
{attr_code}
op.run();""",
data=extra_data)
data=extra_data)
class BmmACL(jt.Function):
def __init__(self, trans_x2=False):
@ -59,16 +59,14 @@ class BmmACL(jt.Function):
def execute(self, x1, x2):
self.input = [x1, x2]
result = acl_cmd(
"BatchMatMul", [x1, x2],
output_dtypes=[x1.dtype],
output_shapes=[
x1.shape[:-1] +
x2.shape[-2:-1] if self.trans_x2 else x1.shape[:-1] +
x2.shape[-1:]
],
attr_code="op.jt_name=\"bmm_trans_1\";"
if self.trans_x2 else "op.jt_name=\"bmm\";")[0]
result = acl_cmd("BatchMatMul", [x1, x2],
output_dtypes=[x1.dtype],
output_shapes=[
x1.shape[:-1] + x2.shape[-2:-1] if self.trans_x2
else x1.shape[:-1] + x2.shape[-1:]
],
attr_code="op.jt_name=\"bmm_trans_1\";"
if self.trans_x2 else "op.jt_name=\"bmm\";")[0]
return result
@ -78,57 +76,53 @@ class BmmACL(jt.Function):
reshape_grad_x2 = True
else:
reshape_grad_x2 = False
grad_x1 = acl_cmd("BatchMatMul", [grad_output, x2],
output_dtypes=[x1.dtype],
output_shapes=[
grad_output.shape[:-1] +
x2.shape[-2:-1] if not self.trans_x2 else
grad_output.shape[:-1] + x1.shape[-1:]
],
attr_code="op.jt_name=\"bmm_trans_1\";" if
not self.trans_x2 else "op.jt_name=\"bmm\";")[0]
grad_x1 = acl_cmd(
"BatchMatMul", [grad_output, x2],
output_dtypes=[x1.dtype],
output_shapes=[
grad_output.shape[:-1] + x2.shape[-2:-1] if not self.trans_x2
else grad_output.shape[:-1] + x1.shape[-1:]
],
attr_code="op.jt_name=\"bmm_trans_1\";"
if not self.trans_x2 else "op.jt_name=\"bmm\";")[0]
if self.trans_x2:
if reshape_grad_x2:
output_shape = grad_output.shape[1:-2] + grad_output.shape[
-1:] + x1.shape[-1:]
grad_x2 = acl_cmd(
"BatchMatMul", [
grad_output.reshape(-1, grad_output.shape[-1]),
x1.reshape(-1, x1.shape[-1])
],
output_dtypes=[x2.dtype],
output_shapes=[output_shape],
attr_code="op.jt_name=\"bmm_trans_0\";")[0]
grad_x2 = acl_cmd("BatchMatMul", [
grad_output.reshape(-1, grad_output.shape[-1]),
x1.reshape(-1, x1.shape[-1])
],
output_dtypes=[x2.dtype],
output_shapes=[output_shape],
attr_code="op.jt_name=\"bmm_trans_0\";")[0]
else:
output_shape = grad_output.shape[:-2] + grad_output.shape[
-1:] + x1.shape[-1:]
grad_x2 = acl_cmd(
"BatchMatMul", [grad_output, x1],
output_dtypes=[x2.dtype],
output_shapes=[output_shape],
attr_code="op.jt_name=\"bmm_trans_0\";")[0]
grad_x2 = acl_cmd("BatchMatMul", [grad_output, x1],
output_dtypes=[x2.dtype],
output_shapes=[output_shape],
attr_code="op.jt_name=\"bmm_trans_0\";")[0]
else:
if reshape_grad_x2:
output_shape = x1.shape[1:-2] + x1.shape[
-1:] + grad_output.shape[-1:]
grad_x2 = acl_cmd(
"BatchMatMul", [
x1.reshape(-1, x1.shape[-1]),
grad_output.reshape(-1, grad_output.shape[-1])
],
output_dtypes=[x2.dtype],
output_shapes=[output_shape],
attr_code="op.jt_name=\"bmm_trans_0\";")[0]
grad_x2 = acl_cmd("BatchMatMul", [
x1.reshape(-1, x1.shape[-1]),
grad_output.reshape(-1, grad_output.shape[-1])
],
output_dtypes=[x2.dtype],
output_shapes=[output_shape],
attr_code="op.jt_name=\"bmm_trans_0\";")[0]
else:
output_shape = x1.shape[:-2] + x1.shape[
-1:] + grad_output.shape[-1:]
grad_x2 = acl_cmd(
"BatchMatMul", [x1, grad_output],
output_dtypes=[x2.dtype],
output_shapes=[output_shape],
attr_code="op.jt_name=\"bmm_trans_0\";")[0]
grad_x2 = acl_cmd("BatchMatMul", [x1, grad_output],
output_dtypes=[x2.dtype],
output_shapes=[output_shape],
attr_code="op.jt_name=\"bmm_trans_0\";")[0]
if len(grad_x1.shape) > len(x1.shape):
grad_x1 = grad_x1.sum(0)
if len(grad_x2.shape) > len(x2.shape):
grad_x2 = grad_x2.sum(0)
return grad_x1, grad_x2
return grad_x1, grad_x2
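
A small sketch of driving BmmACL directly, assuming an ACL device and that this module is importable; the expected shapes follow the output_shapes logic in execute() above:

import jittor as jt
x1 = jt.ones((2, 3, 4))
x2 = jt.ones((2, 4, 5))
y = BmmACL()(x1, x2)                        # output shape (2, 3, 5)
yt = BmmACL(True)(x1, jt.ones((2, 5, 4)))   # trans_x2=True: x2 supplied as (..., n, k), still (2, 3, 5)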

View File

@ -63,8 +63,8 @@ namespace jittor
}
void BatchMatMulOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
{
ret = aclnnBatchMatMulGetWorkspaceSize(inputTensors[0], inputTensors[1], outputTensors[0], 1, &workspaceSize, &executor);
ret = aclnnBatchMatMulGetWorkspaceSize(inputTensors[0], inputTensors[1], outputTensors[0], 1, &workspaceSize, &executor);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnBatchMatmulGetWorkspaceSize failed. ERROR: %d\n", name.c_str(), ret); return);
if (workspaceSize > 0)
{
@ -72,6 +72,6 @@ namespace jittor
}
ret = aclnnBatchMatMul(workspaceAddr, workspaceSize, executor, aclstream);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnbatchMatmul failed. ERROR: %d\n", name.c_str(), ret); return);
// syncRun();
syncRun();
}
}

View File

@ -10,6 +10,7 @@ namespace jittor
protected:
void setupInputDesc() override;
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
BatchMatMulOpRunner();
};

View File

@ -11,13 +11,14 @@ import numpy as np
from typing import Union
from collections.abc import Sequence, Iterable
def concat_cmd(name: str,
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None):
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None):
attr_header = "\nnamespace jittor{" + attr_header + "}\n"
cuda_header = '''
@ -51,6 +52,7 @@ def concat_cmd(name: str,
{attr_code}
op.run();""")
class ConcatACL(jt.Function):
def __init__(self):
@ -114,13 +116,11 @@ class ConcatACL(jt.Function):
self.dim = dim
for i in range(len(input_tensors)):
if input_tensors[i].dtype != input_tensors[0].dtype:
raise ValueError(
"All input tensors must have the same dtype")
raise ValueError("All input tensors must have the same dtype")
if input_tensors[i].shape[:dim] != input_tensors[
0].shape[:dim] or input_tensors[i].shape[
dim + 1:] != input_tensors[0].shape[dim + 1:]:
raise ValueError(
"All input tensors must have the same shape")
raise ValueError("All input tensors must have the same shape")
attr_code = f"""
op.jt_name = "concat";
ConcatAttr *attr = new ConcatAttr();
@ -133,15 +133,13 @@ class ConcatACL(jt.Function):
input_tensors,
output_dtypes=[input_tensors[0].dtype],
output_shapes=[
jt.empty(self.calculate_output_shape(input_tensors,
dim)).shape
jt.empty(self.calculate_output_shape(input_tensors, dim)).shape
],
attr_code=attr_code)[0]
return result
def _grad(self, *args):
new_args = ((args[i] if i >= 0 else None)
for i in self.output_mask)
new_args = ((args[i] if i >= 0 else None) for i in self.output_mask)
ret = self.grad(*new_args)
new_ret = []
for i, r in enumerate(ret):
@ -185,4 +183,4 @@ class ConcatACL(jt.Function):
output_dtypes=dtypeVec,
output_shapes=shapeVec,
attr_code=attr_code)
return result
return result
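
A minimal sketch of the wrapper above, assuming an ACL device; inputs must share dtype and agree on every dimension except the concat dim:

import jittor as jt
a = jt.ones((2, 3))
b = jt.ones((2, 5))
c = ConcatACL()([a, b], 1)   # shape (2, 8); mismatched dtypes or shapes raise ValueError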

View File

@ -9,6 +9,7 @@ namespace jittor
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
ConcatOpRunner();
};
@ -18,6 +19,7 @@ namespace jittor
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
SplitWithSizeOpRunner();
};

View File

@ -22,16 +22,18 @@ def _ntuple(n):
return parse
_pair = _ntuple(2)
def conv_forward(name: str,
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None,
extra_data: dict = {}):
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None,
extra_data: dict = {}):
attr_header = "\nnamespace jittor{" + attr_header + "}\n"
cuda_header = '''
@ -49,7 +51,7 @@ def conv_forward(name: str,
input_code = ''
for i in range(len(inputs)):
input_code += f"op.add(in{i}, true);\n"
return jt.code(outputs=outputs_,
inputs=inputs,
cuda_header=attr_header + cuda_header,
@ -60,16 +62,17 @@ def conv_forward(name: str,
op.add(out0, false);
{attr_code}
op.run();""",
data=extra_data)
data=extra_data)
def conv_forward(name: str,
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None,
extra_data: dict = {}):
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None,
extra_data: dict = {}):
# TODO: not done for now
attr_header = "\nnamespace jittor{" + attr_header + "}\n"
@ -88,7 +91,7 @@ def conv_forward(name: str,
input_code = ''
for i in range(len(inputs)):
input_code += f"op.add(in{i}, true);\n"
return jt.code(outputs=outputs_,
inputs=inputs,
cuda_header=attr_header + cuda_header,
@ -99,32 +102,33 @@ def conv_forward(name: str,
op.add(out0, false);
{attr_code}
op.run();""",
data=extra_data)
data=extra_data)
class ConvACL(jt.Function):
def execute(self,
x,
weight,
bias=None,
stride=1,
padding=0,
dilation=1,
groups=1):
self.input = x
self.weight = weight
self.bias = bias
padding = _pair(padding)
stride = _pair(stride)
dilation = _pair(dilation)
out_channels = weight.shape[0]
if groups <= 0:
raise ValueError("groups must be a positive integer")
self.padding = padding
self.stride = stride
self.dilation = dilation
self.groups = groups
attr_code = f"""
def execute(self,
x,
weight,
bias=None,
stride=1,
padding=0,
dilation=1,
groups=1):
self.input = x
self.weight = weight
self.bias = bias
padding = _pair(padding)
stride = _pair(stride)
dilation = _pair(dilation)
out_channels = weight.shape[0]
if groups <= 0:
raise ValueError("groups must be a positive integer")
self.padding = padding
self.stride = stride
self.dilation = dilation
self.groups = groups
attr_code = f"""
op.jt_name = "conv2d";
ConvAttr *attr = new ConvAttr();
attr->convStrides = {{ {stride[0]}, {stride[1]} }};
@ -134,49 +138,48 @@ class ConvACL(jt.Function):
attr->convOutPads = {{1,1}};
op.op_attr.reset(attr);
"""
input_height, input_width = x.shape[-2:]
kernel_height, kernel_width = weight.shape[-2:]
input_height, input_width = x.shape[-2:]
kernel_height, kernel_width = weight.shape[-2:]
output_height = (input_height + 2 * padding[0] - dilation[0] *
(kernel_height - 1) - 1) // stride[0] + 1
output_width = (input_width + 2 * padding[1] - dilation[1] *
(kernel_width - 1) - 1) // stride[1] + 1
output_height = (input_height + 2 * padding[0] - dilation[0] *
(kernel_height - 1) - 1) // stride[0] + 1
output_width = (input_width + 2 * padding[1] - dilation[1] *
(kernel_width - 1) - 1) // stride[1] + 1
output_shape = (x.shape[0], out_channels, output_height,
output_width)
output_shape = (x.shape[0], out_channels, output_height, output_width)
inputs = [x, weight]
if bias is not None:
inputs.append(bias)
result = conv_forward(
"Conv2d",
inputs,
output_dtypes=[x.dtype],
output_shapes=[output_shape],
attr_code=attr_code,
)[0]
return result
inputs = [x, weight]
if bias is not None:
inputs.append(bias)
result = conv_forward(
"Conv2d",
inputs,
output_dtypes=[x.dtype],
output_shapes=[output_shape],
attr_code=attr_code,
)[0]
return result
def grad(self, grad_output):
x = self.input
weight = self.weight
bias = self.bias
inputs = [grad_output, x, weight]
if bias is not None:
inputs.append(bias)
output_shapes = [x.shape, weight.shape]
output_dtypes = [x.dtype, weight.dtype]
if bias is not None:
output_shapes.append(bias.shape)
output_dtypes.append(bias.dtype)
else:
output_shapes.append([1])
output_dtypes.append(x.dtype)
padding = self.padding
stride = self.stride
dilation = self.dilation
groups = self.groups
attr_code = f"""
def grad(self, grad_output):
x = self.input
weight = self.weight
bias = self.bias
inputs = [grad_output, x, weight]
if bias is not None:
inputs.append(bias)
output_shapes = [x.shape, weight.shape]
output_dtypes = [x.dtype, weight.dtype]
if bias is not None:
output_shapes.append(bias.shape)
output_dtypes.append(bias.dtype)
else:
output_shapes.append([1])
output_dtypes.append(x.dtype)
padding = self.padding
stride = self.stride
dilation = self.dilation
groups = self.groups
attr_code = f"""
op.jt_name = "conv2dbackward";
ConvAttr *attr = new ConvAttr();
attr->convStrides = {{ {stride[0]}, {stride[1]} }};
@ -186,12 +189,12 @@ class ConvACL(jt.Function):
attr->convOutPads = {{ 1,1}};
op.op_attr.reset(attr);
"""
results = acl_cmd_forward("Conv2dBackward",
inputs,
output_dtypes=output_dtypes,
output_shapes=output_shapes,
attr_code=attr_code)
if self.bias is None:
return results[0], results[1]
results = acl_cmd_forward("Conv2dBackward",
inputs,
output_dtypes=output_dtypes,
output_shapes=output_shapes,
attr_code=attr_code)
if self.bias is None:
return results[0], results[1]
return results
return results
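
A quick, device-free check of the output-size arithmetic used in execute() above (the helper name below is ours, purely illustrative):

def conv_out_size(n, k, s=1, p=0, d=1):
    # same formula as execute(): (n + 2*p - d*(k-1) - 1) // s + 1
    return (n + 2 * p - d * (k - 1) - 1) // s + 1

assert conv_out_size(32, 3, s=1, p=1) == 32   # "same"-style padding at stride 1
assert conv_out_size(32, 3, s=2, p=1) == 16   # stride 2 halves the spatial size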

View File

@ -34,7 +34,7 @@ namespace jittor
{
use_nchw = true;
}
void ConvOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
{
// for conv
@ -66,7 +66,7 @@ namespace jittor
ret = aclnnConvolution(workspaceAddr, workspaceSize, executor, aclstream);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnxxx failed. ERROR: %d\n", name.c_str(), ret); return);
// syncRun();
syncRun();
aclDestroyIntArray(strides);
aclDestroyIntArray(pads);

View File

@ -9,6 +9,7 @@ namespace jittor
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
ConvOpRunner();
};

View File

@ -11,13 +11,14 @@ import numpy as np
from typing import Union
from collections.abc import Sequence, Iterable
def cumsum_cmd(name: str,
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None):
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None):
attr_header = "\nnamespace jittor{" + attr_header + "}\n"
cuda_header = '''
@ -51,6 +52,7 @@ def cumsum_cmd(name: str,
{attr_code}
op.run();""")
class CumsumACL(jt.Function):
def __init__(self):
@ -85,15 +87,15 @@ class CumsumACL(jt.Function):
op.op_attr.reset(attr);
"""
flipped_grad_output = cumsum_cmd("Flip", [grad_output],
output_dtypes=[grad_output.dtype],
output_shapes=[grad_output.shape],
attr_code=flip_attr_code)[0]
output_dtypes=[grad_output.dtype],
output_shapes=[grad_output.shape],
attr_code=flip_attr_code)[0]
cumulative_grad = cumsum_cmd("Cumsum", [flipped_grad_output],
output_dtypes=[grad_output.dtype],
output_shapes=[grad_output.shape],
attr_code=cumsum_attr_code)[0]
output_dtypes=[grad_output.dtype],
output_shapes=[grad_output.shape],
attr_code=cumsum_attr_code)[0]
grad_input = cumsum_cmd("Flip", [cumulative_grad],
output_dtypes=[grad_output.dtype],
output_shapes=[grad_output.shape],
attr_code=flip_attr_code)[0]
return grad_input
return grad_input
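
The grad() above uses the usual flip-cumsum-flip identity for the cumsum backward pass; a quick, device-free numpy sanity check of that identity:

import numpy as np
g = np.array([1., 2., 3., 4.])                    # incoming gradient
ref = np.array([g[i:].sum() for i in range(4)])   # expected backward of cumsum: suffix sums
assert np.allclose(np.flip(np.cumsum(np.flip(g))), ref)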

View File

@ -9,6 +9,7 @@ namespace jittor
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
CumsumOpRunner();
};

View File

@ -11,13 +11,14 @@ import numpy as np
from typing import Union
from collections.abc import Sequence, Iterable
def dropout_cmd(name: str,
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None):
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None):
attr_header = "\nnamespace jittor{" + attr_header + "}\n"
cuda_header = '''
@ -51,6 +52,7 @@ def dropout_cmd(name: str,
{attr_code}
op.run();""")
class DropoutACL(jt.Function):
def __init__(self):
@ -71,9 +73,9 @@ class DropoutACL(jt.Function):
op.op_attr.reset(attr);
"""
result = dropout_cmd("Dropout", [x],
output_dtypes=[x.dtype, "uint8"],
output_shapes=[x.shape, mask_shape],
attr_code=attr_code)
output_dtypes=[x.dtype, "uint8"],
output_shapes=[x.shape, mask_shape],
attr_code=attr_code)
self.maskout = result[1]
return result[0]
@ -85,8 +87,8 @@ class DropoutACL(jt.Function):
op.op_attr.reset(attr);
"""
grad_input = dropout_cmd("DropoutBackward",
[grad_output, self.maskout],
output_dtypes=[grad_output.dtype],
output_shapes=[grad_output.shape],
attr_code=attr_code)[0]
return grad_input
[grad_output, self.maskout],
output_dtypes=[grad_output.dtype],
output_shapes=[grad_output.shape],
attr_code=attr_code)[0]
return grad_input

View File

@ -9,6 +9,7 @@ namespace jittor
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
DropoutOpRunner();
};
@ -18,6 +19,7 @@ namespace jittor
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
DropoutBackwardOpRunner();
};

View File

@ -11,13 +11,14 @@ import numpy as np
from typing import Union
from collections.abc import Sequence, Iterable
def flashattention_cmd(name: str,
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None):
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None):
attr_header = "\nnamespace jittor{" + attr_header + "}\n"
cuda_header = '''
@ -51,21 +52,22 @@ def flashattention_cmd(name: str,
{attr_code}
op.run();""")
class FlashAttentionACL(jt.Function):
def __init__(self,
headnum,
layout="BNSD",
prefix=None,
qstart=None,
kvstart=None,
scale=1.0,
prob=1.0,
pretokens=2147483647,
nexttokens=2147483647,
innerprecise=0,
sparsemode=0,
psetype=1):
headnum,
layout="BNSD",
prefix=None,
qstart=None,
kvstart=None,
scale=1.0,
prob=1.0,
pretokens=2147483647,
nexttokens=2147483647,
innerprecise=0,
sparsemode=0,
psetype=1):
self.headnum = headnum
self.layout = layout
self.scale = scale
@ -116,9 +118,7 @@ class FlashAttentionACL(jt.Function):
self.prefix = self.prefix if self.prefix else [0 for _ in range(B)]
self.qstart = self.qstart if self.qstart else [0 for _ in range(B)]
self.kvstart = self.kvstart if self.kvstart else [
0 for _ in range(B)
]
self.kvstart = self.kvstart if self.kvstart else [0 for _ in range(B)]
self.hasRealshift = (not realshift == None)
self.hasDropmask = (not dropMask == None)
@ -126,8 +126,7 @@ class FlashAttentionACL(jt.Function):
self.hasAttenmask = (not attenMask == None)
# TBD: currently set to nullptr
self.realshift = realshift if realshift else jt.zeros(
B, N, SQ, SKV)
self.realshift = realshift if realshift else jt.zeros(B, N, SQ, SKV)
self.dropMask = dropMask if dropMask else jt.ones(B, N, SQ, SKV)
self.paddingMask = paddingMask if paddingMask else jt.zeros(
B, N, SQ, SKV)
@ -207,4 +206,4 @@ class FlashAttentionACL(jt.Function):
output_dtypes=[self.q.dtype, self.k.dtype, self.v.dtype],
output_shapes=[self.q.shape, self.k.shape, self.v.shape],
attr_code=attr_code)
return result
return result

View File

@ -58,7 +58,6 @@ namespace jittor
return;
}
FlashAttentionBackwardOpRunner::FlashAttentionBackwardOpRunner() : BaseOpRunner("FlashAttentionBackward")
{
}

View File

@ -9,6 +9,7 @@ namespace jittor
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
FlashAttentionOpRunner();
};
@ -18,6 +19,7 @@ namespace jittor
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
FlashAttentionBackwardOpRunner();
};

View File

@ -11,13 +11,14 @@ import numpy as np
from typing import Union
from collections.abc import Sequence, Iterable
def flip_cmd(name: str,
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None):
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None):
attr_header = "\nnamespace jittor{" + attr_header + "}\n"
cuda_header = '''
@ -51,6 +52,7 @@ def flip_cmd(name: str,
{attr_code}
op.run();""")
class FlipACL(jt.Function):
def __init__(self):
@ -70,14 +72,14 @@ class FlipACL(jt.Function):
"""
self.attr_code = attr_code
result = flip_cmd("Flip", [input],
output_dtypes=[input.dtype],
output_shapes=[input.shape],
attr_code=self.attr_code)[0]
output_dtypes=[input.dtype],
output_shapes=[input.shape],
attr_code=self.attr_code)[0]
return result
def grad(self, grad_output):
grad_input = flip_cmd("Flip", [grad_output],
output_dtypes=[grad_output.dtype],
output_shapes=[grad_output.shape],
attr_code=self.attr_code)[0]
return grad_input
output_dtypes=[grad_output.dtype],
output_shapes=[grad_output.shape],
attr_code=self.attr_code)[0]
return grad_input

View File

@ -9,6 +9,7 @@ namespace jittor
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
FlipOpRunner();
};

View File

@ -11,13 +11,14 @@ import numpy as np
from typing import Union
from collections.abc import Sequence, Iterable
def floor_cmd(name: str,
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None):
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None):
attr_header = "\nnamespace jittor{" + attr_header + "}\n"
cuda_header = '''
@ -51,6 +52,7 @@ def floor_cmd(name: str,
{attr_code}
op.run();""")
class FloorIntACL(jt.Function):
def __init__(self):
@ -59,10 +61,10 @@ class FloorIntACL(jt.Function):
def execute(self, input):
self.shape = input.shape
result = floor_cmd("Floor", [input],
output_dtypes=[input.dtype],
output_shapes=[input.shape],
attr_code="op.jt_name=\"floor\";")[0]
output_dtypes=[input.dtype],
output_shapes=[input.shape],
attr_code="op.jt_name=\"floor\";")[0]
return result
def grad(self, grad_output):
return jt.zeros(self.shape, dtype=grad_output.dtype)
return jt.zeros(self.shape, dtype=grad_output.dtype)
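
A one-line sketch of the wrapper above, assuming an ACL device:

import jittor as jt
y = FloorIntACL()(jt.array([1.7, -0.3]))   # elementwise Floor; output keeps the input dtype, grad() is zeros by design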

View File

@ -9,6 +9,7 @@ namespace jittor
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
FloorOpRunner();
};

View File

@ -11,13 +11,14 @@ import numpy as np
from typing import Union
from collections.abc import Sequence, Iterable
def gather_scatter_cmd(name: str,
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None):
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None):
attr_header = "\nnamespace jittor{" + attr_header + "}\n"
cuda_header = '''
@ -51,6 +52,7 @@ def gather_scatter_cmd(name: str,
{attr_code}
op.run();""")
class GatherACL(jt.Function):
def __init__(self):
@ -66,9 +68,9 @@ class GatherACL(jt.Function):
op.op_attr.reset(attr);
"""
result = gather_scatter_cmd("Gather", [input, index],
output_dtypes=[input.dtype],
output_shapes=[index.shape],
attr_code=attr_code)[0]
output_dtypes=[input.dtype],
output_shapes=[index.shape],
attr_code=attr_code)[0]
return result
def grad(self, grad_output):
@ -80,12 +82,14 @@ class GatherACL(jt.Function):
attr->reduction = {1};
op.op_attr.reset(attr);
"""
grad_input = gather_scatter_cmd("Scatter", [tmp, self.index, grad_output],
output_dtypes=[grad_output.dtype],
output_shapes=[tmp.shape],
attr_code=attr_code)[0]
grad_input = gather_scatter_cmd("Scatter",
[tmp, self.index, grad_output],
output_dtypes=[grad_output.dtype],
output_shapes=[tmp.shape],
attr_code=attr_code)[0]
return grad_input
class ScatterACL(jt.Function):
def __init__(self):
@ -103,9 +107,9 @@ class ScatterACL(jt.Function):
op.op_attr.reset(attr);
"""
result = gather_scatter_cmd("Scatter", [input, self.index, src],
output_dtypes=[input.dtype],
output_shapes=[input.shape],
attr_code=attr_code)[0]
output_dtypes=[input.dtype],
output_shapes=[input.shape],
attr_code=attr_code)[0]
return result
def grad(self, grad_output):
@ -116,7 +120,7 @@ class ScatterACL(jt.Function):
op.op_attr.reset(attr);
"""
grad_input = gather_scatter_cmd("Gather", [grad_output, self.index],
output_dtypes=[grad_output.dtype],
output_shapes=[self.index.shape],
attr_code=attr_code)[0]
return grad_output, None, None, grad_input
output_dtypes=[grad_output.dtype],
output_shapes=[self.index.shape],
attr_code=attr_code)[0]
return grad_output, None, None, grad_input

View File

@ -77,5 +77,4 @@ namespace jittor
return;
}
}

View File

@ -9,16 +9,17 @@ namespace jittor
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
GatherOpRunner();
};
class ScatterOpRunner : public BaseOpRunner
{
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
ScatterOpRunner();
};

View File

@ -11,13 +11,14 @@ import numpy as np
from typing import Union
from collections.abc import Sequence, Iterable
def getitem_cmd(name: str,
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None):
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None):
attr_header = "\nnamespace jittor{" + attr_header + "}\n"
cuda_header = '''
@ -51,6 +52,7 @@ def getitem_cmd(name: str,
{attr_code}
op.run();""")
def getitem_forward(name: str,
inputs: list,
output_dtypes: list = None,
@ -90,7 +92,8 @@ def getitem_forward(name: str,
op.add(out0, false);
{attr_code}
op.run();""",
data=extra_data)
data=extra_data)
def caculate_shape(tensors):
if isinstance(tensors, jt.Var):
@ -105,6 +108,7 @@ def caculate_shape(tensors):
else:
assert False, f"not implemented for {type(tensors)}"
def can_broadcast_and_shape(shape1, shape2):
"""
Check whether two tensors can be broadcast together and return the broadcast shape.
@ -144,6 +148,7 @@ def can_broadcast_and_shape(shape1, shape2):
return True, tuple(broadcast_shape)
class GetItemACL(jt.Function):
def __init__(self):
@ -174,9 +179,9 @@ class GetItemACL(jt.Function):
op.jt_name = "maskedselect";
"""
result = getitem_cmd("MaskedSelect",
inputs=inputs,
outputs=outputs,
attr_code=attr_code)[0]
inputs=inputs,
outputs=outputs,
attr_code=attr_code)[0]
result = result[:output_len]
result.sync()
return result
@ -194,7 +199,7 @@ class GetItemACL(jt.Function):
contains_slice = False
for s in slices:
if not isinstance(s, jt.Var) and (isinstance(s, slice)
or s == Ellipsis):
or s == Ellipsis):
contains_slice = True
break
if not contains_slice:
@ -212,9 +217,9 @@ class GetItemACL(jt.Function):
output_shape = [1]
for ii in slices:
indices.append(jt.Var(ii).int32())
if isinstance(slices[0], jt.Var) or isinstance(
slices[0], int) or isinstance(
slices[0], list) or isinstance(slices[0], tuple):
if isinstance(slices[0],
jt.Var) or isinstance(slices[0], int) or isinstance(
slices[0], list) or isinstance(slices[0], tuple):
self.indices = indices
inputs = [x] + indices
attr_code = f"""
@ -222,10 +227,10 @@ class GetItemACL(jt.Function):
"""
self.type_ = 'index'
result = getitem_cmd("Index",
inputs=inputs,
output_dtypes=[x.dtype],
output_shapes=[output_shape],
attr_code=attr_code)[0]
inputs=inputs,
output_dtypes=[x.dtype],
output_shapes=[output_shape],
attr_code=attr_code)[0]
result.sync()
return result
assert contains_slice, "slice type error"
@ -235,8 +240,7 @@ class GetItemACL(jt.Function):
if not isinstance(s, jt.Var) and s == Ellipsis:
slices = slices[:slices.index(s)] + [
slice(None, None, None)
] * (x_dim - len(slices) + 1) + slices[slices.index(s) +
1:]
] * (x_dim - len(slices) + 1) + slices[slices.index(s) + 1:]
break
slices = tuple(slices)
@ -313,11 +317,11 @@ class GetItemACL(jt.Function):
op.op_attr.reset(attr);
"""
result = getitem_forward("SliceV2",
inputs,
output_dtypes=[x.dtype],
output_shapes=[jt.empty(sizes).shape],
attr_code=attr_code,
extra_data=extra_data)[0]
inputs,
output_dtypes=[x.dtype],
output_shapes=[jt.empty(sizes).shape],
attr_code=attr_code,
extra_data=extra_data)[0]
self.squeeze_dims = squeeze_dims
for dim in squeeze_dims[::-1]:
result = jt.squeeze(result, dim)
@ -334,9 +338,9 @@ class GetItemACL(jt.Function):
outputs = [jt.zeros(self.x_shape, dtype=grad_output.dtype)]
# breakpoint()
result = getitem_cmd("IndexPutImplAccumulate",
inputs=inputs,
outputs=outputs,
attr_code=attr_code)[0]
inputs=inputs,
outputs=outputs,
attr_code=attr_code)[0]
result.sync()
return result, None
elif self.type_ == 'slicev2':
@ -401,9 +405,9 @@ class GetItemACL(jt.Function):
inputs = [grad_output]
outputs = [jt.zeros(self.x_shape, dtype=grad_output.dtype)]
result = getitem_cmd("StridedSliceAssignV2",
inputs=inputs,
outputs=outputs,
attr_code=attr_code)[0]
inputs=inputs,
outputs=outputs,
attr_code=attr_code)[0]
result.sync()
if expand_dim:
result = result.squeeze(-1)
@ -412,4 +416,4 @@ class GetItemACL(jt.Function):
return self.mask.float()
pass
else:
assert False, f"grad not implemented for {self.type_}"
assert False, f"grad not implemented for {self.type_}"
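
For reference, the broadcasting helper defined earlier in this file appears (per its docstring) to follow numpy-style rules; a small sketch of the expected result, illustrative only:

ok, shape = can_broadcast_and_shape((3, 1, 5), (4, 5))
# numpy-style broadcasting would give ok == True and shape == (3, 4, 5)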

View File

@ -53,11 +53,10 @@ namespace jittor
return;
}
IndexOpRunner::IndexOpRunner() : BaseOpRunner("Index")
{
}
void IndexOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
{
auto input_num = in_.size();
@ -81,7 +80,7 @@ namespace jittor
SliceV2OpRunner::SliceV2OpRunner() : BaseOpRunner("SliceV2")
{
}
void SliceV2OpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
{
auto attr = dynamic_cast<StrideAttr *>(op_attr.get());
@ -106,11 +105,10 @@ namespace jittor
return;
}
IndexPutImplAccumulateOpRunner::IndexPutImplAccumulateOpRunner() : BaseOpRunner("IndexPutImplAccumulate")
{
}
void IndexPutImplAccumulateOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
{
auto input_num = in_.size();
@ -137,11 +135,9 @@ namespace jittor
return;
}
StridedSliceAssignV2OpRunner::StridedSliceAssignV2OpRunner() : BaseOpRunner("StridedSliceAssignV2")
{
}
void StridedSliceAssignV2OpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
{
@ -162,9 +158,7 @@ namespace jittor
ret = aclnnStridedSliceAssignV2(workspaceAddr, workspaceSize, executor, aclstream);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnStridedSliceAssignV2 failed. ERROR: %d\n", name.c_str(), ret); return);
// syncRun();
syncRun();
return;
}

View File

@ -9,6 +9,7 @@ namespace jittor
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
MaskedSelectOpRunner();
};
@ -18,6 +19,7 @@ namespace jittor
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
IndexOpRunner();
};
@ -27,6 +29,7 @@ namespace jittor
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
SliceV2OpRunner();
};
@ -36,6 +39,7 @@ namespace jittor
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
IndexPutImplAccumulateOpRunner();
};
@ -45,6 +49,7 @@ namespace jittor
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
StridedSliceAssignV2OpRunner();
};

View File

@ -11,14 +11,15 @@ import numpy as np
from typing import Union
from collections.abc import Sequence, Iterable
def range_forward(name: str,
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None,
extra_data: dict = {}):
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None,
extra_data: dict = {}):
attr_header = "\nnamespace jittor{" + attr_header + "}\n"
cuda_header = '''
@ -50,7 +51,8 @@ def range_forward(name: str,
op.add(out0, false);
{attr_code}
op.run();""",
data=extra_data)
data=extra_data)
class IndexACL(jt.Function):
@ -85,15 +87,13 @@ class IndexACL(jt.Function):
op.op_attr.reset(attr);
"""
result = range_forward("Range", [],
output_dtypes=[tmp.dtype],
output_shapes=[tmp.shape],
attr_code=range_attr_code,
extra_data=extra_data)[0]
output_dtypes=[tmp.dtype],
output_shapes=[tmp.shape],
attr_code=range_attr_code,
extra_data=extra_data)[0]
broadcast_dims = list(range(len(inshape)))
broadcast_dims.remove(d)
result = jt.broadcast(result,
shape=inshape,
dims=broadcast_dims)
result = jt.broadcast(result, shape=inshape, dims=broadcast_dims)
results.append(result)
if len(results) != 1 or dim_input == None:
@ -104,4 +104,4 @@ class IndexACL(jt.Function):
return results
def grad(self, grad_output):
return grad_output
return grad_output

View File

@ -9,6 +9,7 @@ namespace jittor
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
RangeOpRunner();
};

View File

@ -12,15 +12,14 @@ from typing import Union
from collections.abc import Sequence, Iterable
def matmul_forward(name: str,
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None,
extra_data: dict = {}):
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None,
extra_data: dict = {}):
attr_header = "\nnamespace jittor{" + attr_header + "}\n"
cuda_header = '''
@ -38,7 +37,7 @@ def matmul_forward(name: str,
input_code = ''
for i in range(len(inputs)):
input_code += f"op.add(in{i}, true);\n"
return jt.code(outputs=outputs_,
inputs=inputs,
cuda_header=attr_header + cuda_header,
@ -49,83 +48,83 @@ def matmul_forward(name: str,
op.add(out0, false);
{attr_code}
op.run();""",
data=extra_data)
data=extra_data)
class MatmulACL(jt.Function):
def __init__(self, trans_x2=False):
super(MatmulACL, self).__init__()
self.trans_x2 = trans_x2
def __init__(self, trans_x2=False):
super(MatmulACL, self).__init__()
self.trans_x2 = trans_x2
def execute(self, x1, x2):
self.input = [x1, x2]
result = matmul_forward(
"MatMul", [x1, x2],
output_dtypes=[x1.dtype],
output_shapes=[
x1.shape[:-1] +
x2.shape[-2:-1] if self.trans_x2 else x1.shape[:-1] +
x2.shape[-1:]
],
attr_code="op.jt_name=\"matmul_trans_1\";"
if self.trans_x2 else "op.jt_name=\"matmul\";")[0]
return result
def execute(self, x1, x2):
self.input = [x1, x2]
result = matmul_forward(
"MatMul", [x1, x2],
output_dtypes=[x1.dtype],
output_shapes=[
x1.shape[:-1] +
x2.shape[-2:-1] if self.trans_x2 else x1.shape[:-1] +
x2.shape[-1:]
],
attr_code="op.jt_name=\"matmul_trans_1\";"
if self.trans_x2 else "op.jt_name=\"matmul\";")[0]
return result
def grad(self, grad_output):
x1, x2 = self.input
if len(x1) != len(x2):
reshape_grad_x2 = True
def grad(self, grad_output):
x1, x2 = self.input
if len(x1) != len(x2):
reshape_grad_x2 = True
else:
reshape_grad_x2 = False
grad_x1 = matmul_forward(
"MatMul", [grad_output, x2],
output_dtypes=[x1.dtype],
output_shapes=[
grad_output.shape[:-1] + x2.shape[-2:-1] if not self.trans_x2
else grad_output.shape[:-1] + x2.shape[-1:]
],
attr_code="op.jt_name=\"matmul_trans_1\";"
if not self.trans_x2 else "op.jt_name=\"matmul\";")[0]
if self.trans_x2:
if reshape_grad_x2:
output_shape = grad_output.shape[1:-2] + grad_output.shape[
-1:] + x1.shape[-1:]
grad_x2 = matmul_forward(
"MatMul", [
grad_output.reshape(-1, grad_output.shape[-1]),
x1.reshape(-1, x1.shape[-1])
],
output_dtypes=[x2.dtype],
output_shapes=[output_shape],
attr_code="op.jt_name=\"matmul_trans_0\";")[0]
else:
reshape_grad_x2 = False
grad_x1 = matmul_forward(
"MatMul", [grad_output, x2],
output_dtypes=[x1.dtype],
output_shapes=[
grad_output.shape[:-1] + x2.shape[-2:-1]
if not self.trans_x2 else grad_output.shape[:-1] +
x2.shape[-1:]
],
attr_code="op.jt_name=\"matmul_trans_1\";"
if not self.trans_x2 else "op.jt_name=\"matmul\";")[0]
if self.trans_x2:
if reshape_grad_x2:
output_shape = grad_output.shape[1:-2] + grad_output.shape[
-1:] + x1.shape[-1:]
grad_x2 = matmul_forward(
"MatMul", [
grad_output.reshape(-1, grad_output.shape[-1]),
x1.reshape(-1, x1.shape[-1])
],
output_dtypes=[x2.dtype],
output_shapes=[output_shape],
attr_code="op.jt_name=\"matmul_trans_0\";")[0]
else:
output_shape = grad_output.shape[:-2] + grad_output.shape[
-1:] + x1.shape[-1:]
grad_x2 = matmul_forward(
"MatMul", [grad_output, x1],
output_dtypes=[x2.dtype],
output_shapes=[output_shape],
attr_code="op.jt_name=\"matmul_trans_0\";")[0]
output_shape = grad_output.shape[:-2] + grad_output.shape[
-1:] + x1.shape[-1:]
grad_x2 = matmul_forward(
"MatMul", [grad_output, x1],
output_dtypes=[x2.dtype],
output_shapes=[output_shape],
attr_code="op.jt_name=\"matmul_trans_0\";")[0]
else:
if reshape_grad_x2:
output_shape = x1.shape[1:-2] + x1.shape[
-1:] + grad_output.shape[-1:]
grad_x2 = matmul_forward(
"MatMul", [
x1.reshape(-1, x1.shape[-1]),
grad_output.reshape(-1, grad_output.shape[-1])
],
output_dtypes=[x2.dtype],
output_shapes=[output_shape],
attr_code="op.jt_name=\"matmul_trans_0\";")[0]
else:
if reshape_grad_x2:
output_shape = x1.shape[1:-2] + x1.shape[
-1:] + grad_output.shape[-1:]
grad_x2 = matmul_forward(
"MatMul", [
x1.reshape(-1, x1.shape[-1]),
grad_output.reshape(-1, grad_output.shape[-1])
],
output_dtypes=[x2.dtype],
output_shapes=[output_shape],
attr_code="op.jt_name=\"matmul_trans_0\";")[0]
else:
output_shape = x1.shape[:-2] + x1.shape[
-1:] + grad_output.shape[-1:]
grad_x2 = matmul_forward(
"MatMul", [x1, grad_output],
output_dtypes=[x2.dtype],
output_shapes=[output_shape],
attr_code="op.jt_name=\"matmul_trans_0\";")[0]
return grad_x1, grad_x2
output_shape = x1.shape[:-2] + x1.shape[
-1:] + grad_output.shape[-1:]
grad_x2 = matmul_forward(
"MatMul", [x1, grad_output],
output_dtypes=[x2.dtype],
output_shapes=[output_shape],
attr_code="op.jt_name=\"matmul_trans_0\";")[0]
return grad_x1, grad_x2
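
The same call pattern as BmmACL, for 2-D operands; a minimal sketch assuming an ACL device:

import jittor as jt
m = MatmulACL()(jt.ones((5, 6)), jt.ones((6, 7)))        # (5, 7)
mt = MatmulACL(True)(jt.ones((5, 6)), jt.ones((7, 6)))   # trans_x2=True: second operand given as (n, k), also (5, 7)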

View File

@ -49,7 +49,7 @@ namespace jittor
for (int idx = 0; idx < input_num; idx++)
{
inputTensors.push_back(nullptr);
if ((jt_name == "matmul_trans_1" && idx == 1) || (jt_name == "matmul_trans_0" && idx == 0) )
if ((jt_name == "matmul_trans_1" && idx == 1) || (jt_name == "matmul_trans_0" && idx == 0))
{
auto ret = CreateFakeTransAclTensor(inputShapes[idx], in_[idx]->mem_ptr, in_[idx]->size, get_dtype(in_[idx]->dtype()), &inputTensors[idx], use_nchw);
CHECK_RET(ret == ACL_SUCCESS, return);
@ -63,8 +63,8 @@ namespace jittor
}
void MatMulOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
{
ret = aclnnMatmulGetWorkspaceSize(inputTensors[0], inputTensors[1], outputTensors[0], 1, &workspaceSize, &executor);
ret = aclnnMatmulGetWorkspaceSize(inputTensors[0], inputTensors[1], outputTensors[0], 1, &workspaceSize, &executor);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnMatmulGetWorkspaceSize failed. ERROR: %d\n", name.c_str(), ret); return);
if (workspaceSize > 0)
{
@ -72,6 +72,6 @@ namespace jittor
}
ret = aclnnMatmul(workspaceAddr, workspaceSize, executor, aclstream);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnMatmul failed. ERROR: %d\n", name.c_str(), ret); return);
// syncRun();
syncRun();
}
}

View File

@ -10,6 +10,7 @@ namespace jittor
protected:
void setupInputDesc() override;
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
MatMulOpRunner();
};

View File

@ -11,13 +11,14 @@ import numpy as np
from typing import Union
from collections.abc import Sequence, Iterable
def nantonum_cmd(name: str,
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None):
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None):
attr_header = "\nnamespace jittor{" + attr_header + "}\n"
cuda_header = '''
@ -51,6 +52,7 @@ def nantonum_cmd(name: str,
{attr_code}
op.run();""")
class NanToNumACL(jt.Function):
def __init__(self):
@ -67,7 +69,7 @@ class NanToNumACL(jt.Function):
"""
self.attr_code = attr_code
result = nantonum_cmd("NanToNum", [input],
output_dtypes=[input[0].dtype],
output_shapes=[input.shape],
attr_code=self.attr_code)[0]
return result
output_dtypes=[input[0].dtype],
output_shapes=[input.shape],
attr_code=self.attr_code)[0]
return result

View File

@ -9,6 +9,7 @@ namespace jittor
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
NanToNumOpRunner();
};

View File

@ -11,13 +11,14 @@ import numpy as np
from typing import Union
from collections.abc import Sequence, Iterable
def pool_cmd(name: str,
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None):
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None):
attr_header = "\nnamespace jittor{" + attr_header + "}\n"
cuda_header = '''
@ -51,33 +52,32 @@ def pool_cmd(name: str,
{attr_code}
op.run();""")
class PoolACL(jt.Function):
def __init__(self,
kernel_size,
stride=None,
padding=0,
dilation=None,
return_indices=None,
ceil_mode=False,
count_include_pad=True,
op='maximum'):
kernel_size,
stride=None,
padding=0,
dilation=None,
return_indices=None,
ceil_mode=False,
count_include_pad=True,
op='maximum'):
self.kernel_size = kernel_size if isinstance(
kernel_size, tuple) else (kernel_size, kernel_size)
stride = stride if stride else kernel_size
self.stride = stride if isinstance(stride, tuple) else (stride,
stride)
self.stride = stride if isinstance(stride, tuple) else (stride, stride)
self.padding = padding if isinstance(padding, tuple) else (padding,
padding)
padding)
dilation = dilation if dilation else 1
assert dilation == 1
self.dilation = dilation if isinstance(
dilation, tuple) else (dilation, dilation)
self.dilation = dilation if isinstance(dilation, tuple) else (dilation,
dilation)
for item in self.kernel_size:
if item <= 0:
raise RuntimeError(
f"kernel_size must be greater than zero, but got {item}"
)
f"kernel_size must be greater than zero, but got {item}")
for item in self.stride:
if item <= 0:
raise RuntimeError(
@ -108,7 +108,7 @@ class PoolACL(jt.Function):
kernel_height, kernel_width = self.kernel_size[-2:]
output_height = (input_height + 2 * self.padding[0] -
(kernel_height - 1) - 1) // self.stride[0] + 1
(kernel_height - 1) - 1) // self.stride[0] + 1
output_width = (input_width + 2 * self.padding[1] -
(kernel_width - 1) - 1) // self.stride[1] + 1
@ -161,16 +161,16 @@ class PoolACL(jt.Function):
output_dtypes = [input.dtype]
if self.op == 'maximum':
result = pool_cmd("MaxpoolBackward",
inputs=[grad_output, input, self.index],
output_dtypes=output_dtypes,
output_shapes=output_shapes,
attr_code=attr_code)[0]
inputs=[grad_output, input, self.index],
output_dtypes=output_dtypes,
output_shapes=output_shapes,
attr_code=attr_code)[0]
elif self.op == 'mean':
result = pool_cmd("AvgpoolBackward",
inputs=[grad_output, input],
output_dtypes=output_dtypes,
output_shapes=output_shapes,
attr_code=attr_code)[0]
inputs=[grad_output, input],
output_dtypes=output_dtypes,
output_shapes=output_shapes,
attr_code=attr_code)[0]
else:
raise ValueError('no this type pool')
return result
return result
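
A quick, device-free check of the pooling output-size arithmetic used above (the helper name below is ours):

def pool_out_size(n, k, s, p=0):
    # same formula as above: (n + 2*p - (k-1) - 1) // s + 1
    return (n + 2 * p - (k - 1) - 1) // s + 1

assert pool_out_size(32, 2, 2) == 16        # 2x2 window, stride 2
assert pool_out_size(32, 3, 2, p=1) == 16   # 3x3 window, stride 2, padding 1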

View File

@ -71,7 +71,6 @@ namespace jittor
return;
}
AvgpoolOpRunner::AvgpoolOpRunner() : BaseOpRunner("Avgpool")
{
use_nchw = true;
@ -109,7 +108,6 @@ namespace jittor
return;
}
MaxpoolBackwardOpRunner::MaxpoolBackwardOpRunner() : BaseOpRunner("MaxpoolBackward")
{
use_nchw = true;
@ -150,8 +148,6 @@ namespace jittor
return;
}
AvgpoolBackwardOpRunner::AvgpoolBackwardOpRunner() : BaseOpRunner("AvgpoolBackward")
{
use_nchw = true;

View File

@ -9,6 +9,7 @@ namespace jittor
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
MaxpoolOpRunner();
};
@ -18,16 +19,17 @@ namespace jittor
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
AvgpoolOpRunner();
};
class MaxpoolBackwardOpRunner : public BaseOpRunner
{
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
MaxpoolBackwardOpRunner();
};
@ -37,6 +39,7 @@ namespace jittor
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
AvgpoolBackwardOpRunner();
};

View File

@ -44,7 +44,7 @@ namespace jittor
void RandomOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
{
auto attr = dynamic_cast<RandomAttr *>(op_attr.get());
if(name == "RandomUniform")
if (name == "RandomUniform")
{
ret = aclnnInplaceUniformGetWorkspaceSize(outputTensors[0], 0.0, 1.0, attr->seed, attr->offset, &workspaceSize, &executor);
@ -58,7 +58,7 @@ namespace jittor
ret = aclnnInplaceUniform(workspaceAddr, workspaceSize, executor, aclstream);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnInplaceUniform failed. ERROR: %d\n", name.c_str(), ret); return);
}
else if(name == "RandomNormal")
else if (name == "RandomNormal")
{
ret = aclnnInplaceNormalGetWorkspaceSize(outputTensors[0], 0.0, 1.0, attr->seed, attr->offset, &workspaceSize, &executor);
@ -76,7 +76,7 @@ namespace jittor
{
LOGf << "Not supported random type : " << name;
}
// syncRun();
syncRun();
return;
}
}

View File

@ -28,7 +28,6 @@
#include "aclnn/aclnn.h"
#include "reduce_op_acl.h"
namespace jittor
{
ReduceOpRunner::ReduceOpRunner() : BaseOpRunner("reduce")

View File

@ -11,13 +11,14 @@ import numpy as np
from typing import Union
from collections.abc import Sequence, Iterable
def relu_cmd(name: str,
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None):
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None):
attr_header = "\nnamespace jittor{" + attr_header + "}\n"
cuda_header = '''
@ -51,6 +52,7 @@ def relu_cmd(name: str,
{attr_code}
op.run();""")
class ReLUACL(jt.Function):
def __init__(self):
@ -60,9 +62,9 @@ class ReLUACL(jt.Function):
x = x.float32()
self.input = x
result = relu_cmd("ReLU", [x],
output_dtypes=[x.dtype],
output_shapes=[x.shape],
attr_code="op.jt_name=\"unary\";")[0]
output_dtypes=[x.dtype],
output_shapes=[x.shape],
attr_code="op.jt_name=\"unary\";")[0]
return result
def grad(self, grad_output):
@ -72,11 +74,11 @@ class ReLUACL(jt.Function):
output_shapes=[self.input.shape],
attr_code="op.jt_name=\"binary\";")[0]
grad_input = relu_cmd("Mul", [grad_output, mask],
output_dtypes=[grad_output.dtype],
output_shapes=[grad_output.shape],
attr_code="op.jt_name=\"binary\";")[0]
output_dtypes=[grad_output.dtype],
output_shapes=[grad_output.shape],
attr_code="op.jt_name=\"binary\";")[0]
return grad_input
class LeakyReLUACL(jt.Function):
@ -93,9 +95,9 @@ class LeakyReLUACL(jt.Function):
op.op_attr.reset(attr);
"""
result = relu_cmd("LeakyReLU", [x],
output_dtypes=[x.dtype],
output_shapes=[x.shape],
attr_code=attr_code)[0]
output_dtypes=[x.dtype],
output_shapes=[x.shape],
attr_code=attr_code)[0]
self.negative_slope = negative_slope
return result
@ -107,9 +109,8 @@ class LeakyReLUACL(jt.Function):
attr->selfIsResult = false;
op.op_attr.reset(attr);
"""
grad_input = relu_cmd("LeakyReLUBackward",
[grad_output, self.input],
output_dtypes=[grad_output.dtype],
output_shapes=[grad_output.shape],
attr_code=attr_code)[0]
return grad_input
grad_input = relu_cmd("LeakyReLUBackward", [grad_output, self.input],
output_dtypes=[grad_output.dtype],
output_shapes=[grad_output.shape],
attr_code=attr_code)[0]
return grad_input
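
A short sketch matching the call shape used in change_function() (LeakyReLUACL()(x, scale)), assuming an ACL device:

import jittor as jt
x = jt.array([-2.0, 0.5])
y = LeakyReLUACL()(x, 0.01)   # negative side scaled by 0.01, positive side passed through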

View File

@ -9,6 +9,7 @@ namespace jittor
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
LeakyReLUOpRunner();
};
@ -18,6 +19,7 @@ namespace jittor
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
LeakyReLUBackwardOpRunner();
};

View File

@ -11,13 +11,14 @@ import numpy as np
from typing import Union
from collections.abc import Sequence, Iterable
def rope_cmd(name: str,
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None):
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None):
attr_header = "\nnamespace jittor{" + attr_header + "}\n"
cuda_header = '''
@ -51,6 +52,7 @@ def rope_cmd(name: str,
{attr_code}
op.run();""")
class RopeACL(jt.Function):
def __init__(self):
@ -67,16 +69,16 @@ class RopeACL(jt.Function):
assert freq_cos is not None and freq_sin is not None
inputs = [xq, xk, freq_cos, freq_sin]
results = rope_cmd("RotaryPosEmb",
inputs,
output_dtypes=[
xq.dtype,
],
output_shapes=[
xq.shape,
],
attr_code=attr_code)
inputs,
output_dtypes=[
xq.dtype,
],
output_shapes=[
xq.shape,
],
attr_code=attr_code)
results[0].sync()
return inputs[0], inputs[1]
def grad(self, grad_output):
return grad_output
return grad_output
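
RopeACL hands xq, xk, freq_cos and freq_sin to the RotaryPosEmb kernel, which rotates the query/key tensors in place (hence the sync followed by returning inputs[0] and inputs[1]). For reference, one common rotary-embedding formulation in NumPy using the rotate-half convention; the exact pair layout the ACL kernel expects may differ, so treat this only as a semantic sketch:

import numpy as np

def apply_rope(x, cos, sin):
    # Split the last dimension in halves, rotate the pairs, then combine with cos/sin.
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    rotated = np.concatenate([-x2, x1], axis=-1)
    return x * cos + rotated * sin
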

View File

@ -9,6 +9,7 @@ namespace jittor
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
RotaryPosEmbOpRunner();
};

View File

@ -11,13 +11,14 @@ import numpy as np
from typing import Union
from collections.abc import Sequence, Iterable
def setitem_cmd(name: str,
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None):
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None):
attr_header = "\nnamespace jittor{" + attr_header + "}\n"
cuda_header = '''
@ -51,6 +52,7 @@ def setitem_cmd(name: str,
{attr_code}
op.run();""")
def setitem_forward(name: str,
inputs: list,
output_dtypes: list = None,
@ -90,7 +92,8 @@ def setitem_forward(name: str,
op.add(out0, false);
{attr_code}
op.run();""",
data=extra_data)
data=extra_data)
def caculate_shape(tensors):
if isinstance(tensors, jt.Var):
@ -105,6 +108,7 @@ def caculate_shape(tensors):
else:
assert False, f"not implemented for {type(tensors)}"
def can_broadcast_and_shape(shape1, shape2):
"""
    Check whether two tensors can be broadcast together and, if so, return the broadcast shape.
@ -144,6 +148,7 @@ def can_broadcast_and_shape(shape1, shape2):
return True, tuple(broadcast_shape)
class SetItemACL(jt.Function):
def __init__(self):
@ -180,9 +185,9 @@ class SetItemACL(jt.Function):
op.jt_name = "inplacemaskedscatter";
"""
result = setitem_cmd("InplaceMaskedScatter",
inputs=inputs,
outputs=outputs,
attr_code=attr_code)[0]
inputs=inputs,
outputs=outputs,
attr_code=attr_code)[0]
return result
# assert isinstance(value,jt.Var), "value must be jt.Var"
@ -199,7 +204,7 @@ class SetItemACL(jt.Function):
contains_slice = False
for s in slices:
if not isinstance(s, jt.Var) and (isinstance(s, slice)
or s == Ellipsis):
or s == Ellipsis):
contains_slice = True
break
if not contains_slice:
@ -220,9 +225,9 @@ class SetItemACL(jt.Function):
self.value_shape = value_shape
for ii in slices:
indices.append(jt.Var(ii).int32())
if isinstance(slices[0], jt.Var) or isinstance(
slices[0], int) or isinstance(
slices[0], list) or isinstance(slices[0], tuple):
if isinstance(slices[0],
jt.Var) or isinstance(slices[0], int) or isinstance(
slices[0], list) or isinstance(slices[0], tuple):
self.indices = indices
self.type_ = 'index'
attr_code = f"""
@ -231,9 +236,9 @@ class SetItemACL(jt.Function):
inputs = [value] + indices
outputs = [x.clone()]
result = setitem_cmd("IndexPutImpl",
inputs=inputs,
outputs=outputs,
attr_code=attr_code)[0]
inputs=inputs,
outputs=outputs,
attr_code=attr_code)[0]
# result.sync()
return result
assert "not support"
@ -244,8 +249,7 @@ class SetItemACL(jt.Function):
if not isinstance(s, jt.Var) and s == Ellipsis:
slices = slices[:slices.index(s)] + [
slice(None, None, None)
] * (x_dim - len(slices) + 1) + slices[slices.index(s) +
1:]
] * (x_dim - len(slices) + 1) + slices[slices.index(s) + 1:]
break
slices = tuple(slices)
self.input_slice = slices
@ -335,10 +339,10 @@ class SetItemACL(jt.Function):
inputs = [value]
outputs = [x.clone()]
result = setitem_forward("StridedSliceAssignV2",
inputs=inputs,
outputs=outputs,
attr_code=attr_code,
extra_data=extra_data)[0]
inputs=inputs,
outputs=outputs,
attr_code=attr_code,
extra_data=extra_data)[0]
if expand_dim:
result = result.squeeze(-1)
# result.sync()
@ -349,4 +353,4 @@ class SetItemACL(jt.Function):
if self.value_var:
value_grad = grad_output[self.input_slice]
grad_output[self.input_slice] = jt.zeros(self.value_shape)
return grad_output, None, value_grad
return grad_output, None, value_grad
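
can_broadcast_and_shape above follows the standard NumPy broadcasting rule: shapes are compared right-aligned, and a pair of dimensions is compatible when they are equal or one of them is 1. A compact equivalent, assuming that same rule (the helper name here is illustrative):

from itertools import zip_longest

def broadcast_shape(shape1, shape2):
    # Returns (True, shape) if the two shapes broadcast together, else (False, None).
    out = []
    for a, b in zip_longest(reversed(shape1), reversed(shape2), fillvalue=1):
        if a != b and a != 1 and b != 1:
            return False, None
        out.append(max(a, b))
    return True, tuple(reversed(out))
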

View File

@ -33,7 +33,7 @@ namespace jittor
InplaceMaskedScatterOpRunner::InplaceMaskedScatterOpRunner() : BaseOpRunner("InplaceMaskedScatter")
{
}
void InplaceMaskedScatterOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
{
ret = aclnnInplaceMaskedScatterGetWorkspaceSize(outputTensors[0], inputTensors[0], inputTensors[1], &workspaceSize, &executor);
@ -55,7 +55,7 @@ namespace jittor
IndexPutImplOpRunner::IndexPutImplOpRunner() : BaseOpRunner("IndexPutImpl")
{
}
void IndexPutImplOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
{
auto input_num = in_.size();
@ -77,7 +77,7 @@ namespace jittor
ret = aclnnIndexPutImpl(workspaceAddr, workspaceSize, executor, aclstream);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnIndexPutImpl failed. ERROR: %d\n", name.c_str(), ret); return);
// syncRun();
syncRun();
return;
}

View File

@ -9,6 +9,7 @@ namespace jittor
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
InplaceMaskedScatterOpRunner();
};
@ -18,6 +19,7 @@ namespace jittor
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
IndexPutImplOpRunner();
};

View File

@ -11,13 +11,14 @@ import numpy as np
from typing import Union
from collections.abc import Sequence, Iterable
def sigmoid_cmd(name: str,
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None):
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None):
attr_header = "\nnamespace jittor{" + attr_header + "}\n"
cuda_header = '''
@ -51,6 +52,7 @@ def sigmoid_cmd(name: str,
{attr_code}
op.run();""")
class SigmoidACL(jt.Function):
def __init__(self):
@ -64,9 +66,9 @@ class SigmoidACL(jt.Function):
op.jt_name = "sigmoid";
"""
result = sigmoid_cmd("Sigmoid",
inputs=inputs,
outputs=outputs,
attr_code=attr_code)[0]
inputs=inputs,
outputs=outputs,
attr_code=attr_code)[0]
self.output = result
return result
@ -77,7 +79,7 @@ class SigmoidACL(jt.Function):
inputs = [grad_output, self.output]
outputs = [jt.empty(grad_output.shape, grad_output.dtype)]
grad_input = sigmoid_cmd("SigmoidBackward",
inputs=inputs,
outputs=outputs,
attr_code=attr_code)[0]
return grad_input
inputs=inputs,
outputs=outputs,
attr_code=attr_code)[0]
return grad_input
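
SigmoidBackward above is fed the incoming gradient together with the saved forward output, which is all it needs, since d(sigmoid)/dx = y * (1 - y). A NumPy sketch of the same computation:

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def sigmoid_grad(grad_out, y):
    # y is the saved forward output, matching the (grad_output, output) inputs above.
    return grad_out * y * (1.0 - y)
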

View File

@ -9,6 +9,7 @@ namespace jittor
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
SigmoidOpRunner();
};
@ -18,6 +19,7 @@ namespace jittor
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
SigmoidBackwardOpRunner();
};

View File

@ -11,13 +11,14 @@ import numpy as np
from typing import Union
from collections.abc import Sequence, Iterable
def silu_cmd(name: str,
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None):
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None):
attr_header = "\nnamespace jittor{" + attr_header + "}\n"
cuda_header = '''
@ -51,6 +52,7 @@ def silu_cmd(name: str,
{attr_code}
op.run();""")
class SiLUACL(jt.Function):
def __init__(self):
@ -65,9 +67,9 @@ class SiLUACL(jt.Function):
op.jt_name = "silu";
"""
result = silu_cmd("SiLU",
inputs=inputs,
outputs=outputs,
attr_code=attr_code)[0]
inputs=inputs,
outputs=outputs,
attr_code=attr_code)[0]
return result
def grad(self, grad_output):
@ -77,7 +79,7 @@ class SiLUACL(jt.Function):
inputs = [grad_output, self.input]
outputs = [jt.empty(grad_output.shape, grad_output.dtype)]
grad_input = silu_cmd("SiLUBackward",
inputs=inputs,
outputs=outputs,
attr_code=attr_code)[0]
return grad_input
inputs=inputs,
outputs=outputs,
attr_code=attr_code)[0]
return grad_input
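
Unlike the sigmoid path, SiLUBackward takes the saved forward input, because d/dx [x * sigmoid(x)] = s(x) + x * s(x) * (1 - s(x)). The equivalent NumPy computation:

import numpy as np

def silu_grad(grad_out, x):
    s = 1.0 / (1.0 + np.exp(-x))               # sigmoid(x)
    return grad_out * (s + x * s * (1.0 - s))
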

View File

@ -9,6 +9,7 @@ namespace jittor
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
SiLUOpRunner();
};
@ -18,6 +19,7 @@ namespace jittor
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
SiLUBackwardOpRunner();
};

View File

@ -11,13 +11,14 @@ import numpy as np
from typing import Union
from collections.abc import Sequence, Iterable
def softmax_cmd(name: str,
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None):
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None):
attr_header = "\nnamespace jittor{" + attr_header + "}\n"
cuda_header = '''
@ -51,6 +52,7 @@ def softmax_cmd(name: str,
{attr_code}
op.run();""")
class SoftmaxACL(jt.Function):
def __init__(self):
@ -68,9 +70,9 @@ class SoftmaxACL(jt.Function):
op.op_attr.reset(attr);
"""
result = softmax_cmd("Softmax",
inputs=inputs,
outputs=outputs,
attr_code=attr_code)[0]
inputs=inputs,
outputs=outputs,
attr_code=attr_code)[0]
self.output = result
return result
@ -84,7 +86,7 @@ class SoftmaxACL(jt.Function):
inputs = [grad_output, self.output]
outputs = [jt.empty(grad_output.shape)]
grad_input = softmax_cmd("SoftmaxBackward",
inputs=inputs,
outputs=outputs,
attr_code=attr_code)[0]
return grad_input
inputs=inputs,
outputs=outputs,
attr_code=attr_code)[0]
return grad_input
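
SoftmaxBackward receives the incoming gradient and the saved softmax output; the underlying Jacobian-vector product is y * (g - sum(g * y)) along the softmax axis. A NumPy sketch (the axis argument here is illustrative; in the diff the actual dim travels through the op attr):

import numpy as np

def softmax_grad(grad_out, y, axis=-1):
    # y is the saved forward output of softmax along `axis`.
    return y * (grad_out - (grad_out * y).sum(axis=axis, keepdims=True))
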

View File

@ -9,6 +9,7 @@ namespace jittor
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
SoftmaxOpRunner();
};
@ -18,6 +19,7 @@ namespace jittor
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
SoftmaxBackwardOpRunner();
};

View File

@ -11,13 +11,14 @@ import numpy as np
from typing import Union
from collections.abc import Sequence, Iterable
def stack_cmd(name: str,
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None):
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None):
attr_header = "\nnamespace jittor{" + attr_header + "}\n"
cuda_header = '''
@ -51,6 +52,7 @@ def stack_cmd(name: str,
{attr_code}
op.run();""")
class StackACL(jt.Function):
def __init__(self):
@ -60,15 +62,12 @@ class StackACL(jt.Function):
if type(input_tensors) is tuple:
input_tensors = list(input_tensors)
assert type(input_tensors) is list
assert -1 * len(input_tensors) - 1 <= dim and dim <= len(
input_tensors)
assert -1 * len(input_tensors) - 1 <= dim and dim <= len(input_tensors)
for i in range(len(input_tensors)):
if input_tensors[i].dtype != input_tensors[0].dtype:
raise ValueError(
"All input tensors must have the same dtype")
raise ValueError("All input tensors must have the same dtype")
if input_tensors[i].shape != input_tensors[0].shape:
raise ValueError(
"All input tensors must have the same shape")
raise ValueError("All input tensors must have the same shape")
self.input = input_tensors
input_shape = list(input_tensors[0].shape)
output_shape = input_shape[:dim] + [len(input_tensors)
@ -82,10 +81,10 @@ class StackACL(jt.Function):
"""
self.attr_code = attr_code
result = stack_cmd("Stack",
input_tensors,
output_dtypes=[input_tensors[0].dtype],
output_shapes=[output_shape],
attr_code=self.attr_code)[0]
input_tensors,
output_dtypes=[input_tensors[0].dtype],
output_shapes=[output_shape],
attr_code=self.attr_code)[0]
return result
def grad(self, grad_output):
@ -110,7 +109,7 @@ class StackACL(jt.Function):
"""
result = stack_cmd("SplitWithSize", [grad_output],
output_dtypes=dtypeVec,
output_shapes=shapeVec,
attr_code=attr_code)
return result
output_dtypes=dtypeVec,
output_shapes=shapeVec,
attr_code=attr_code)
return result
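
The grad path above uses SplitWithSize to cut grad_output back into one slice per stacked input along the stacking axis. The same round trip in NumPy terms:

import numpy as np

xs = [np.full((2, 3), float(i)) for i in range(4)]
y = np.stack(xs, axis=1)                            # forward: shape (2, 4, 3)
slices = np.split(np.ones_like(y), y.shape[1], axis=1)
grads = [s.squeeze(axis=1) for s in slices]         # one (2, 3) gradient per input
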

View File

@ -9,6 +9,7 @@ namespace jittor
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
StackOpRunner();
};

View File

@ -48,7 +48,7 @@ namespace jittor
ret = aclnnSWhere(workspaceAddr, workspaceSize, executor, aclstream);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnxxx failed. ERROR: %d\n", name.c_str(), ret); return);
//syncRun();
// syncRun();
return;
}
}

View File

@ -7,7 +7,7 @@ namespace jittor
struct TernaryOpRunner : public BaseOpRunner
{
TernaryOpRunner();
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
};

View File

@ -11,13 +11,14 @@ import numpy as np
from typing import Union
from collections.abc import Sequence, Iterable
def transpose_cmd(name: str,
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None):
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None):
attr_header = "\nnamespace jittor{" + attr_header + "}\n"
cuda_header = '''
@ -51,6 +52,7 @@ def transpose_cmd(name: str,
{attr_code}
op.run();""")
class TransPoseACL(jt.Function):
def __init__(self):
@ -75,9 +77,9 @@ class TransPoseACL(jt.Function):
# calculate output shape
output_shape = [x.shape[i] for i in dim]
output = transpose_cmd("Transpose", [x],
output_dtypes=[x.dtype],
output_shapes=[jt.empty(output_shape).shape],
attr_code=attr_code)[0]
output_dtypes=[x.dtype],
output_shapes=[jt.empty(output_shape).shape],
attr_code=attr_code)[0]
self.dim = dim
return output
@ -93,7 +95,7 @@ class TransPoseACL(jt.Function):
op.op_attr.reset(attr);
"""
output = transpose_cmd("Transpose", [grad_output],
output_dtypes=[grad_output.dtype],
output_shapes=[jt.empty(output_shape).shape],
attr_code=attr_code)[0]
output_dtypes=[grad_output.dtype],
output_shapes=[jt.empty(output_shape).shape],
attr_code=attr_code)[0]
return output
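
The backward pass of a permutation is the inverse permutation: the gradient has to send every axis back to where it came from. A small NumPy sketch of that bookkeeping (the attr plumbing of the ACL Transpose call is omitted):

import numpy as np

def inverse_perm(dim):
    inv = [0] * len(dim)
    for i, d in enumerate(dim):
        inv[d] = i
    return inv

x = np.zeros((2, 3, 4))
perm = [2, 0, 1]
y = x.transpose(perm)                                # forward: shape (4, 2, 3)
gx = np.ones_like(y).transpose(inverse_perm(perm))   # backward: shape (2, 3, 4)
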

View File

@ -9,6 +9,7 @@ namespace jittor
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
TransposeOpRunner();
};

View File

@ -53,7 +53,7 @@ namespace jittor
ret = it->second.executeFunc(workspaceAddr, workspaceSize, executor, aclstream);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnxxx failed. ERROR: %d\n", name.c_str(), ret); return);
//syncRun();
// syncRun();
return;
}
}

View File

@ -11,13 +11,14 @@ import numpy as np
from typing import Union
from collections.abc import Sequence, Iterable
def where_cmd(name: str,
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None):
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None):
attr_header = "\nnamespace jittor{" + attr_header + "}\n"
cuda_header = '''
@ -51,6 +52,7 @@ def where_cmd(name: str,
{attr_code}
op.run();""")
class NonzeroACL(jt.Function):
def __init__(self):
@ -63,15 +65,16 @@ class NonzeroACL(jt.Function):
nonzero_cnt = (x != 0.0).sum().item()
result = where_cmd("Nonzero", [x],
output_dtypes=['int64'],
output_shapes=[(nonzero_cnt, x.ndim)],
attr_code=attr_code)[0]
output_dtypes=['int64'],
output_shapes=[(nonzero_cnt, x.ndim)],
attr_code=attr_code)[0]
return result
def grad(self, grad_output):
return grad_output
class WhereACL(jt.Function):
def __init__(self):
@ -104,9 +107,9 @@ class WhereACL(jt.Function):
self.y = y
result = where_cmd("Where", [condition, x, y],
output_dtypes=[x.dtype],
output_shapes=[x.shape],
attr_code="op.jt_name=\"where\";")[0]
output_dtypes=[x.dtype],
output_shapes=[x.shape],
attr_code="op.jt_name=\"where\";")[0]
return result
def grad(self, grad_output):
@ -115,12 +118,12 @@ class WhereACL(jt.Function):
else:
tmp = jt.zeros(grad_output.shape, dtype=grad_output.dtype)
grad_x = where_cmd("Where", [self.condition, grad_output, tmp],
output_dtypes=[self.x.dtype],
output_shapes=[self.x.shape],
attr_code="op.jt_name=\"where\";")[0]
output_dtypes=[self.x.dtype],
output_shapes=[self.x.shape],
attr_code="op.jt_name=\"where\";")[0]
grad_y = where_cmd("Where", [self.condition, tmp, grad_output],
output_dtypes=[self.y.dtype],
output_shapes=[self.y.shape],
attr_code="op.jt_name=\"where\";")[0]
output_dtypes=[self.y.dtype],
output_shapes=[self.y.shape],
attr_code="op.jt_name=\"where\";")[0]
return grad_output, grad_x, grad_y
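
The two backward calls above route the gradient through the same Where kernel: grad_x keeps grad_output where the condition holds and is zero elsewhere, grad_y is the complement, and the condition itself gets no meaningful gradient. In NumPy terms:

import numpy as np

cond = np.array([True, False, True])
g = np.array([1.0, 2.0, 3.0])
zeros = np.zeros_like(g)
grad_x = np.where(cond, g, zeros)   # gradient flows to x where cond is True
grad_y = np.where(cond, zeros, g)   # and to y where cond is False
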

View File

@ -53,7 +53,6 @@ namespace jittor
return;
}
NonzeroOpRunner::NonzeroOpRunner() : BaseOpRunner("Nonzero")
{
}

View File

@ -9,6 +9,7 @@ namespace jittor
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
WhereOpRunner();
};
@ -18,6 +19,7 @@ namespace jittor
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
NonzeroOpRunner();
};