Merge pull request #14 from CSCG-Lab/splits

split triu,embedding,batchnorm
This commit is contained in:
Yuxuan Han 2024-12-23 15:57:30 +08:00 committed by GitHub
commit 5ce5b45c58
12 changed files with 742 additions and 504 deletions

View File

@ -267,31 +267,13 @@ def change_function():
from .aclops.matmul_op import MatmulACL
from .aclops.transpose_op import TransPoseACL
class TriuACL(Function):
def __init__(self):
super(TriuACL, self).__init__()
def execute(self, input, diagonal):
attr_code = f"""
op.jt_name = "triu";
TriuAttr *attr = new TriuAttr();
attr->diagonal = {diagonal};
op.op_attr.reset(attr);
"""
result = acl_cmd("Triu", [input],
output_dtypes=[input.dtype],
output_shapes=[input.shape],
attr_code=attr_code)[0]
return result
def grad(self, grad_output):
return grad_output
from .aclops.triu_op import TriuACL
def triu_acl(x, diagonal=0):
return TriuACL()(x, diagonal)
from .aclops.conv_op import ConvACL
def conv_acl(x,
weight,
bias=None,
@ -401,12 +383,17 @@ def change_function():
self.padding, self.dilation, self.groups)
return ret
from .aclops.flip_op import FlipACL
def flip_acl(x, dim):
return FlipACL()(x, dim)
from .aclops.concat_op import ConcatACL
def concat(x, dim=0):
return ConcatACL()(x, dim)
from .aclops.gather_scatter_op import GatherACL
def gather_acl(input, dim, index):
return GatherACL()(input, dim, index)
@ -416,6 +403,8 @@ def change_function():
else:
return jt.array([False])
from .aclops.cumsum_op import CumsumACL
def cumsum_acl(input, dim=-1):
return CumsumACL()(input, dim)
@ -424,23 +413,34 @@ def change_function():
x = cumsum_acl(x, dim=dim)
return jt.exp(x)
from .aclops.index_op import IndexACL
def index_acl(inshape: Union[jt.Var, list], dim=None, dtype="int32"):
if isinstance(inshape, jt.Var):
inshape = inshape.shape
return IndexACL()(inshape, dim, dtype)
from .aclops.gather_scatter_op import ScatterACL
def scatter_acl(input, dim, index, src, reduce='void'):
return ScatterACL()(input, dim, index, src, reduce)
from .aclops.where_op import WhereACL
def where_acl(condition, x=None, y=None):
return WhereACL()(condition, x, y)
from .aclops.where_op import NonzeroACL
def nonzero_acl(x):
return NonzeroACL()(x)
from .aclops.floor_op import FloorIntACL
def floor_int_acl(x):
return FloorIntACL()(x)
from .aclops.getitem_op import GetItemACL
def getitem_acl(x, slices, return_x=None):
# Transform numpy int to int
if isinstance(slices, (np.int8, np.int16, np.int32, np.int64)):
@ -489,22 +489,33 @@ def change_function():
return result
from .aclops.setitem_op import SetItemACL
def setitem_acl(x, slices, value):
res = SetItemACL()(x, slices, value)
return x.assign(res)
from .aclops.bmm_op import BmmACL
def bmm_acl(x1, x2):
return BmmACL()(x1, x2)
def bmm_transpose_acl(x1, x2):
return BmmACL(True)(x1, x2)
from .aclops.matmul_op import MatmulACL
def matmul_acl(x1, x2):
return MatmulACL()(x1, x2)
def matmul_transpose_acl(x1, x2):
return MatmulACL(True)(x1, x2)
from .aclops.transpose_op import TransPoseACL
def transpose_acl(x, *dim):
return TransPoseACL()(x, *dim)
@ -545,6 +556,8 @@ def change_function():
def relu(x):
return ReLUACL()(x)
from .aclops.relu_op import LeakyReLUACL
class LeakyReLU(jt.nn.Module):
def __init__(self, negative_slope=0.01):
@ -557,6 +570,8 @@ def change_function():
def leaky_relu(x, scale=0.01):
return LeakyReLUACL()(x, scale)
from .aclops.dropout_op import DropoutACL
class Dropout(jt.nn.Module):
def __init__(self, p=0.5, is_train=False):
@ -570,6 +585,8 @@ def change_function():
def dropout_acl(x, p=0.5, is_train=False):
return DropoutACL()(x, p, is_train)
from .aclops.silu_op import SiLUACL
def silu_acl(x):
return SiLUACL()(x)
@ -581,6 +598,8 @@ def change_function():
def execute(self, x):
return SiLUACL()(x)
from .aclops.sigmoid_op import SigmoidACL
def sigmoid_acl(x):
return SigmoidACL()(x)
@ -592,64 +611,25 @@ def change_function():
def execute(self, x):
return SigmoidACL()(x)
class EmbeddingACL(Function):
# class Embedding(jt.nn.Module):
def __init__(self):
super(EmbeddingACL, self).__init__()
def execute(
self,
indices,
weight,
):
inputs = [weight, indices]
self.indices = indices
self.weight_shape = weight.shape
output_shape = list(indices.shape) + list(weight.shape[1:])
outputs = [jt.empty(output_shape, weight.dtype)]
attr_code = f"""
op.jt_name = "embedding";
"""
result = acl_cmd("Embedding",
inputs=inputs,
outputs=outputs,
attr_code=attr_code)[0]
return result
def grad(self, grad_output):
inputs = [grad_output, self.indices]
outputs = [jt.empty(self.weight_shape, grad_output.dtype)]
attr_code = f"""
op.jt_name = "embeddingbackward";
EmbeddingAttr *attr = new EmbeddingAttr();
attr->numEmbeddings = {self.weight_shape[0]};
op.op_attr.reset(attr);
"""
grad_weight = acl_cmd("EmbeddingBackward",
inputs=inputs,
outputs=outputs,
attr_code=attr_code)[0]
return None, grad_weight
class Embedding(jt.nn.Module):
def __init__(self,
num_embeddings,
embedding_dim,
padding_idx=None,
dtype="float32"):
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
self.padding_idx = padding_idx
self.weight = jt.init.gauss(
[self.num_embeddings, self.embedding_dim], dtype)
if padding_idx is not None:
self.weight[padding_idx] = 0
def execute(self, x):
res = embedding_acl(x, self.weight)
return res
# def __init__(self,
# num_embeddings,
# embedding_dim,
# padding_idx=None,
# dtype="float32"):
# self.num_embeddings = num_embeddings
# self.embedding_dim = embedding_dim
# self.padding_idx = padding_idx
# self.weight = jt.init.gauss(
# [self.num_embeddings, self.embedding_dim], dtype)
# if padding_idx is not None:
# self.weight[padding_idx] = 0
# def execute(self, x):
# res = embedding_acl(x, self.weight)
# return res
class Softmax(jt.nn.Module):
def __init__(self):
@ -661,89 +641,16 @@ def change_function():
def softmax_acl(x, dim):
return SoftmaxACL()(x, dim)
from .aclops.rope_op import RopeACL
def rope_acl(xq, xk, freqs_cis=None, freq_sin=None, freq_cos=None):
return RopeACL()(xq, xk, freqs_cis, freq_sin, freq_cos)
class BatchNormACL(Function):
def __init__(self,
num_features,
eps=1e-05,
momentum=0.1,
affine=True,
is_train=True,
sync=True):
self.num_features = num_features
self.eps = eps
self.momentum = momentum
self.affine = affine
self.is_train = is_train
self.sync = sync
self.weight = jt.init.constant(
(num_features, ), "float32", 1.0) if affine else 1.0
self.bias = jt.init.constant(
(num_features, ), "float32", 0.0) if affine else 0.0
self.running_mean = jt.init.constant((num_features, ), "float32",
0.0).stop_grad()
self.running_var = jt.init.constant((num_features, ), "float32",
1.0).stop_grad()
def execute(self, x):
# assert self.num_features == x.shape[-1]
self.input = x.float32()
inputs = [
self.input, self.weight, self.bias, self.running_mean,
self.running_var
]
outputs = [
jt.empty(x.shape),
jt.empty(self.num_features),
jt.empty(self.num_features)
]
attr_code = f"""
op.jt_name = "batchnorm";
BatchNormAttr *attr = new BatchNormAttr();
attr->is_train = {"true" if self.is_train else "false"};
attr->momentum = {self.momentum};
attr->eps = {self.eps};
op.op_attr.reset(attr);
"""
result = acl_cmd("BatchNorm",
inputs=inputs,
outputs=outputs,
attr_code=attr_code)
self.output = result[0]
self.saveMean = result[1]
self.saveInvstd = result[2]
return self.output
def grad(self, grad_output):
attr_code = f"""
op.jt_name = "batchnorm";
BatchNormAttr *attr = new BatchNormAttr();
attr->is_train = {"true" if self.is_train else "false"};
attr->momentum = {self.momentum};
attr->eps = {self.eps};
op.op_attr.reset(attr);
"""
inputs = [
grad_output, self.input, self.weight, self.running_mean,
self.running_var, self.saveMean, self.saveInvstd
]
outputs = [
jt.empty(self.input.shape),
jt.empty(self.num_features),
jt.empty(self.num_features)
]
grad_input = acl_cmd("SoftmaxBackward",
inputs=inputs,
outputs=outputs,
attr_code=attr_code)[0]
return grad_input
from .aclops.stack_op import StackACL
def stack_acl(x, dim=0):
return StackACL()(x, dim)
from .aclops.nantonum_op import NanToNumACL
def isnan_acl(x):
tonum = NanToNumACL()(x, -1.0)
return jt.not_equal(x, tonum).logical_and(
@ -872,6 +779,7 @@ def change_function():
jt.sigmoid = warp(jt.sigmoid, sigmoid_acl)
jt.nn.Sigmoid = warp(jt.nn.Sigmoid, Sigmoid)
# from .aclops.embedding_op import EmbeddingACL
# def embedding_acl(indices, weight):
# return EmbeddingACL()(indices, weight)
@ -882,6 +790,7 @@ def change_function():
jt.nn.softmax = warp(jt.nn.softmax, softmax_acl)
# from .aclops.norms_op import BatchNormACL,LayerNormACL
# jt.nn.BatchNorm = warp(jt.nn.BatchNorm, BatchNormACL)
# jt.nn.LayerNorm = warp(jt.nn.LayerNorm, LayerNormACL)

View File

@ -214,70 +214,6 @@ namespace jittor
ret = it->second.getWorkspaceSizeFuncBinary(inputTensors[0], inputTensors[1], outputTensors[0], &workspaceSize, &executor);
break;
}
// case 7:
// {
// ret = it->second.getWorkspaceSizeFuncMatmul(inputTensors[0], inputTensors[1], outputTensors[0], 1, &workspaceSize, &executor);
// break;
// }
// case 8:
// {
// ret = it->second.getWorkspaceSizeFuncMatmul(inputTensors[0], inputTensors[1], outputTensors[0], 1, &workspaceSize, &executor);
// break;
// }
// case 9:
// {
// ret = it->second.getWorkspaceSizeFuncReduceSum(inputTensors[0], dim, keepdims, get_dtype(out_[0]->dtype()), outputTensors[0], &workspaceSize, &executor);
// break;
// }
// case 10:
// {
// ret = it->second.getWorkspaceSizeFuncReduceSum(inputTensors[0], dim, keepdims, get_dtype(out_[0]->dtype()), outputTensors[0], &workspaceSize, &executor);
// break;
// }
// case 11:
// {
// ret = it->second.getWorkspaceSizeFuncAmax(inputTensors[0], dim, keepdims, outputTensors[0], &workspaceSize, &executor);
// break;
// }
// case 12:
// {
// ret = it->second.getWorkspaceSizeFuncAmax(inputTensors[0], dim, keepdims, outputTensors[0], &workspaceSize, &executor);
// break;
// }
// case 13:
// {
// auto attr = dynamic_cast<RandomAttr *>(op_attr.get());
// ret = it->second.getWorkspaceSizeFuncRandom(outputTensors[0], 0.0, 1.0, attr->seed, attr->offset, &workspaceSize, &executor);
// break;
// }
// case 14:
// {
// auto attr = dynamic_cast<RandomAttr *>(op_attr.get());
// ret = it->second.getWorkspaceSizeFuncRandom(outputTensors[0], 0.0, 1.0, attr->seed, attr->offset, &workspaceSize, &executor);
// break;
// }
// case 15:
// {
// ret = it->second.getWorkspaceSizeFuncUnaryNonzero(inputTensors[0], outputTensors[0], &workspaceSize, &executor);
// break;
// }
// case 17:
// {
// ret = it->second.getWorkspaceSizeFuncSelect(inputTensors[0], inputTensors[1], inputTensors[2], outputTensors[0], &workspaceSize, &executor);
// break;
// }
case 18:
{
auto attr = dynamic_cast<TriuAttr *>(op_attr.get());
ret = it->second.getWorkspaceSizeFuncCast(inputTensors[0], aclDataType(attr->diagonal), outputTensors[0], &workspaceSize, &executor);
break;
}
// case 19:
// {
// ret = it->second.getWorkspaceSizeFuncExpand(inputTensors[0], dim, outputTensors[0], &workspaceSize, &executor);
// break;
// }
case 20:
{
auto attr = dynamic_cast<ConvAttr *>(op_attr.get());
@ -309,291 +245,6 @@ namespace jittor
ret = it->second.getWorkspaceSizeFuncConvBackward(inputTensors[0], inputTensors[1], inputTensors[2], biasSizes, strides, pads, dilations, false, outPads, attr->group, outMask, 0, outputTensors[0], outputTensors[1], outputTensors[2], &workspaceSize, &executor);
break;
}
// case 22:
// {
// auto attr = dynamic_cast<PoolAttr *>(op_attr.get());
// kernel_size = aclCreateIntArray(attr->kernel_size.data(), 2);
// strides = aclCreateIntArray(attr->poolStrides.data(), 2);
// pads = aclCreateIntArray(attr->poolPads.data(), 2);
// dilations = aclCreateIntArray(attr->poolDilations.data(), 2);
// ret = it->second.getWorkspaceSizeFuncMaxPool(inputTensors[0], kernel_size, strides, pads, dilations, attr->poolCeil, outputTensors[0], outputTensors[1], &workspaceSize, &executor);
// break;
// }
// case 23:
// {
// auto attr = dynamic_cast<PoolAttr *>(op_attr.get());
// kernel_size = aclCreateIntArray(attr->kernel_size.data(), 2);
// strides = aclCreateIntArray(attr->poolStrides.data(), 2);
// pads = aclCreateIntArray(attr->poolPads.data(), 2);
// dilations = aclCreateIntArray(attr->poolDilations.data(), 2);
// ret = it->second.getWorkspaceSizeFuncMaxPoolBackward(inputTensors[0], inputTensors[1], inputTensors[2], kernel_size, strides, pads, dilations, attr->poolCeil, outputTensors[0], &workspaceSize, &executor);
// break;
// }
// case 24:
// {
// auto attr = dynamic_cast<PoolAttr *>(op_attr.get());
// kernel_size = aclCreateIntArray(attr->kernel_size.data(), 2);
// strides = aclCreateIntArray(attr->poolStrides.data(), 2);
// pads = aclCreateIntArray(attr->poolPads.data(), 2);
// ret = it->second.getWorkspaceSizeFuncAvgPool(inputTensors[0], kernel_size, strides, pads, attr->poolCeil, attr->countIncludePad, attr->divisorOverride, attr->divisorOverride, outputTensors[0], &workspaceSize, &executor);
// break;
// }
// case 25:
// {
// auto attr = dynamic_cast<PoolAttr *>(op_attr.get());
// kernel_size = aclCreateIntArray(attr->kernel_size.data(), 2);
// strides = aclCreateIntArray(attr->poolStrides.data(), 2);
// pads = aclCreateIntArray(attr->poolPads.data(), 2);
// ret = it->second.getWorkspaceSizeFuncAvgPoolBackward(inputTensors[0], inputTensors[1], kernel_size, strides, pads, attr->countIncludePad, attr->divisorOverride, attr->divisorOverride, attr->poolCeil, outputTensors[0], &workspaceSize, &executor);
// break;
// }
// case 26:
// {
// auto attr = dynamic_cast<ReduceAttr *>(op_attr.get());
// dim = aclCreateIntArray(attr->axes.data(), attr->axes.size());
// ret = it->second.getWorkspaceSizeFuncExpand(inputTensors[0], dim, outputTensors[0], &workspaceSize, &executor);
// break;
// }
// case 27:
// {
// std::vector<aclTensor *> concatTensorList = {};
// for (int i = 0; i < input_num; i++)
// {
// concatTensorList.push_back(inputTensors[i]);
// }
// auto concatTensorListInput = aclCreateTensorList(&concatTensorList[0], input_num);
// auto attr = dynamic_cast<ConcatAttr *>(op_attr.get());
// ret = it->second.getWorkspaceSizeFuncConcat(concatTensorListInput, attr->dim, outputTensors[0], &workspaceSize, &executor);
// break;
// }
// case 28:
// {
// auto attr = dynamic_cast<GatherAttr *>(op_attr.get());
// ret = it->second.getWorkspaceSizeFuncGather(inputTensors[0], attr->dim, inputTensors[1], outputTensors[0], &workspaceSize, &executor);
// break;
// }
// case 29:
// {
// auto attr = dynamic_cast<GatherAttr *>(op_attr.get());
// ret = it->second.getWorkspaceSizeFuncCumsum(inputTensors[0], attr->dim, get_dtype(out_[0]->dtype()), outputTensors[0], &workspaceSize, &executor);
// break;
// }
// case 30:
// {
// auto attr = dynamic_cast<ScatterAttr *>(op_attr.get());
// ret = it->second.getWorkspaceSizeFuncScatter(inputTensors[0], attr->axis, inputTensors[1], inputTensors[2], attr->reduction, outputTensors[0], &workspaceSize, &executor);
// break;
// }
// case 31:
// {
// ret = it->second.getWorkspaceSizeFuncUnaryNonzero(inputTensors[0], outputTensors[0], &workspaceSize, &executor);
// break;
// }
// case 32:
// {
// auto indexTensorList = aclCreateTensorList(&inputTensors[1], input_num - 1);
// ret = it->second.getWorkspaceSizeFuncIndex(inputTensors[0], indexTensorList, outputTensors[0], &workspaceSize, &executor);
// break;
// }
// case 33:
// {
// auto attr = dynamic_cast<StrideAttr *>(op_attr.get());
// auto begins = aclCreateIntArray(attr->begins.data(), attr->begins.size());
// auto ends = aclCreateIntArray(attr->ends.data(), attr->ends.size());
// auto steps = aclCreateIntArray(attr->steps.data(), attr->steps.size());
// auto axes = aclCreateIntArray(attr->axes.data(), attr->axes.size());
// ret = it->second.getWorkspaceSizeFuncSliceV2(inputTensors[0], begins, ends, axes, steps, outputTensors[0], &workspaceSize, &executor);
// break;
// }
// case 34:
// {
// std::vector<aclTensor *> indexTensorList = {};
// for (int i = 1; i < input_num; i++)
// {
// indexTensorList.push_back(inputTensors[i]);
// }
// auto indexTensorListInput = aclCreateTensorList(&indexTensorList[0], input_num - 1);
// ret = it->second.getWorkspaceSizeFuncIndexPutImpl(outputTensors[0], indexTensorListInput, inputTensors[0], false, true, &workspaceSize, &executor);
// break;
// }
// case 35:
// {
// std::vector<aclTensor *> indexTensorList = {};
// for (int i = 1; i < input_num; i++)
// {
// indexTensorList.push_back(inputTensors[i]);
// }
// auto indexTensorListInput = aclCreateTensorList(&indexTensorList[0], input_num - 1);
// ret = it->second.getWorkspaceSizeFuncIndexPutImpl(outputTensors[0], indexTensorListInput, inputTensors[0], true, true, &workspaceSize, &executor);
// break;
// }
// case 36:
// {
// auto attr = dynamic_cast<StrideAttr *>(op_attr.get());
// auto begins = aclCreateIntArray(attr->begins.data(), attr->begins.size());
// auto ends = aclCreateIntArray(attr->ends.data(), attr->ends.size());
// auto steps = aclCreateIntArray(attr->steps.data(), attr->steps.size());
// auto axes = aclCreateIntArray(attr->axes.data(), attr->axes.size());
// ret = it->second.getWorkspaceSizeFuncStridedSliceAssignV2(outputTensors[0], inputTensors[0], begins, ends, steps, axes, &workspaceSize, &executor);
// break;
// }
// case 37:
// {
// ret = it->second.getWorkspaceSizeFuncRange(start, end, step, outputTensors[0], &workspaceSize, &executor);
// break;
// }
// case 38:
// {
// auto attr = dynamic_cast<LeakyReluAttr *>(op_attr.get());
// negativeSlope = aclCreateScalar(&attr->negativeSlope, aclDataType::ACL_FLOAT);
// ret = it->second.getWorkspaceSizeFuncLeakyRelu(inputTensors[0], negativeSlope, outputTensors[0], &workspaceSize, &executor);
// break;
// }
// case 39:
// {
// auto attr = dynamic_cast<LeakyReluAttr *>(op_attr.get());
// negativeSlope = aclCreateScalar(&attr->negativeSlope, aclDataType::ACL_FLOAT);
// ret = it->second.getWorkspaceSizeFuncLeakyReluBackward(inputTensors[0], inputTensors[1], negativeSlope, attr->selfIsResult, outputTensors[0], &workspaceSize, &executor);
// break;
// }
// case 40:
// {
// auto attr = dynamic_cast<DropoutAttr *>(op_attr.get());
// ret = it->second.getWorkspaceSizeFuncDropout(inputTensors[0], attr->p, attr->train, attr->seed, attr->offset, outputTensors[0], outputTensors[1], &workspaceSize, &executor);
// break;
// }
// case 41:
// {
// auto attr = dynamic_cast<DropoutAttr *>(op_attr.get());
// ret = it->second.getWorkspaceSizeFuncDropoutBackward(inputTensors[0], inputTensors[1], attr->scale, outputTensors[0], &workspaceSize, &executor);
// break;
// }
// case 42:
// {
// ret = it->second.getWorkspaceSizeFuncUnaryNonzero(inputTensors[0], outputTensors[0], &workspaceSize, &executor);
// break;
// }
// case 43:
// {
// ret = it->second.getWorkspaceSizeFuncBinary(inputTensors[0], inputTensors[1], outputTensors[0], &workspaceSize, &executor);
// break;
// }
// case 44:
// {
// ret = it->second.getWorkspaceSizeFuncUnaryNonzero(inputTensors[0], outputTensors[0], &workspaceSize, &executor);
// break;
// }
// case 45:
// {
// ret = it->second.getWorkspaceSizeFuncBinary(inputTensors[0], inputTensors[1], outputTensors[0], &workspaceSize, &executor);
// break;
// }
case 46:
{
ret = it->second.getWorkspaceSizeFuncBinary(inputTensors[0], inputTensors[1], outputTensors[0], &workspaceSize, &executor);
break;
}
case 47:
{
auto attr = dynamic_cast<EmbeddingAttr *>(op_attr.get());
auto numEmbeddings = attr->numEmbeddings;
ret = it->second.getWorkspaceSizeFuncEmbeddingBackward(inputTensors[0], inputTensors[1], numEmbeddings, 0, false, outputTensors[0], &workspaceSize, &executor);
break;
}
// case 48:
// {
// ret = it->second.getWorkspaceSizeFuncBinary(outputTensors[0], inputTensors[1], inputTensors[0], &workspaceSize, &executor);
// break;
// }
// case 49:
// {
// ret = it->second.getWorkspaceSizeFuncBinary(inputTensors[0], inputTensors[1], outputTensors[0], &workspaceSize, &executor);
// break;
// }
// case 50:
// {
// auto attr = dynamic_cast<SplitWithSizeAttr *>(op_attr.get());
// auto splitSize = aclCreateIntArray(attr->splitSize.data(), attr->splitSize.size());
// auto tensorList = aclCreateTensorList(&outputTensors[0], output_num);
// ret = it->second.getWorkspaceSizeFuncSplitWithSize(inputTensors[0], splitSize, attr->dim, tensorList, &workspaceSize, &executor);
// break;
// }
// case 51:
// {
// auto attr = dynamic_cast<FlashAttentionAttr *>(op_attr.get());
// auto prefix = aclCreateIntArray(attr->prefix.data(), attr->prefix.size());
// auto qstart = aclCreateIntArray(attr->qStartIdx.data(), attr->qStartIdx.size());
// auto kvstart = aclCreateIntArray(attr->kvStartIdx.data(), attr->kvStartIdx.size());
// char *layout = const_cast<char *>(attr->inputLayout.data());
// ret = it->second.getWorkspaceSizeFuncFalshAttention(inputTensors[0], inputTensors[1], inputTensors[2], attr->hasRealshift ? inputTensors[3] : nullptr, attr->hasDropmask ? inputTensors[4] : nullptr, nullptr, attr->hasAttentmask ? inputTensors[6] : nullptr, prefix, qstart, kvstart, attr->scale, attr->keepProb, attr->preToken, attr->nextToken, attr->headNum, layout, attr->innerPrecise, attr->sparseMode, attr->psetype, outputTensors[0], outputTensors[1], nullptr, outputTensors[2], &workspaceSize, &executor);
// break;
// }
// case 52:
// {
// auto attr = dynamic_cast<FlashAttentionAttr *>(op_attr.get());
// auto prefix = aclCreateIntArray(attr->prefix.data(), attr->prefix.size());
// auto qstart = aclCreateIntArray(attr->qStartIdx.data(), attr->qStartIdx.size());
// auto kvstart = aclCreateIntArray(attr->kvStartIdx.data(), attr->kvStartIdx.size());
// char *layout = const_cast<char *>(attr->inputLayout.data());
// ret = it->second.getWorkspaceSizeFuncFalshAttentionBackward(inputTensors[0], inputTensors[1], inputTensors[2], inputTensors[3], attr->hasRealshift ? inputTensors[4] : nullptr, attr->hasDropmask ? inputTensors[5] : nullptr, nullptr, attr->hasAttentmask ? inputTensors[7] : nullptr, inputTensors[8], inputTensors[9], nullptr, inputTensors[10], prefix, qstart, kvstart, attr->scale, attr->keepProb, attr->preToken, attr->nextToken, attr->headNum, layout, attr->innerPrecise, attr->sparseMode, attr->psetype, outputTensors[0], outputTensors[1], outputTensors[2], nullptr, &workspaceSize, &executor);
// break;
// }
// case 53:
// {
// auto attr = dynamic_cast<SoftmaxAttr *>(op_attr.get());
// ret = it->second.getWorkspaceSizeFuncCast(inputTensors[0], aclDataType(attr->dim), outputTensors[0], &workspaceSize, &executor);
// break;
// }
// case 54:
// {
// auto attr = dynamic_cast<SoftmaxAttr *>(op_attr.get());
// ret = it->second.getWorkspaceSizeFuncDropoutBackward(inputTensors[0], inputTensors[1], attr->dim, outputTensors[0], &workspaceSize, &executor);
// break;
// }
case 55:
{
auto attr = dynamic_cast<BatchNormAttr *>(op_attr.get());
ret = it->second.getWorkspaceSizeFuncBatchNorm(inputTensors[0], inputTensors[1], inputTensors[2], inputTensors[3], inputTensors[4], attr->is_train, attr->momentum, attr->eps, outputTensors[0], outputTensors[1], outputTensors[2], &workspaceSize, &executor);
break;
}
case 56:
{
auto attr = dynamic_cast<BatchNormAttr *>(op_attr.get());
bool outputMask[3] = {true, true, true};
aclBoolArray *outMask = aclCreateBoolArray(outputMask, 3);
ret = it->second.getWorkspaceSizeFuncBatchNormBackward(inputTensors[0], inputTensors[1], inputTensors[2], inputTensors[3], inputTensors[4], inputTensors[5], inputTensors[6], attr->is_train, attr->eps, outMask, outputTensors[0], outputTensors[1], outputTensors[2], &workspaceSize, &executor);
break;
}
case 57:
{
auto attr = dynamic_cast<LayerNormAttr *>(op_attr.get());
normalizedShape = aclCreateIntArray(attr->normalizedShape.data(), attr->size);
ret = it->second.getWorkspaceSizeFuncLayerNorm(inputTensors[0], normalizedShape, inputTensors[1], inputTensors[2], attr->eps, outputTensors[0], outputTensors[1], outputTensors[2], &workspaceSize, &executor);
break;
}
// case 58:
// {
// ret = it->second.getWorkspaceSizeFuncRotaryPosEmb(inputTensors[0], inputTensors[1], inputTensors[2], inputTensors[3], (int64_t)1, &workspaceSize, &executor);
// break;
// }
// case 59:
// {
// std::vector<aclTensor *> stackTensorList = {};
// for (int i = 0; i < input_num; i++)
// {
// stackTensorList.push_back(inputTensors[i]);
// }
// auto stackTensorListInput = aclCreateTensorList(&stackTensorList[0], input_num);
// auto attr = dynamic_cast<ConcatAttr *>(op_attr.get());
// ret = it->second.getWorkspaceSizeFuncConcat(stackTensorListInput, attr->dim, outputTensors[0], &workspaceSize, &executor);
// break;
// }
// case 60:
// {
// auto attr = dynamic_cast<NanToNumAttr *>(op_attr.get());
// ret = it->second.getWorkspaceSizeFuncProdDim(inputTensors[0], attr->nan, attr->posinf, attr->neginf, outputTensors[0], &workspaceSize, &executor);
// break;
// }
default:
{
LOGir << "not supported op: " << name;

View File

@ -27,4 +27,7 @@
#include <acl/aclops/softmax_op_acl.h>
#include <acl/aclops/stack_op_acl.h>
#include <acl/aclops/nantonum_op_acl.h>
#include <acl/aclops/rope_op_acl.h>
#include <acl/aclops/rope_op_acl.h>
#include <acl/aclops/triu_op_acl.h>
#include <acl/aclops/embedding_op_acl.h>
#include <acl/aclops/norms_op_acl.h>

View File

@ -0,0 +1,91 @@
import os
from jittor_utils import env_or_try_find
import jittor_utils
import ctypes
import glob
import jittor.compiler as compiler
import jittor as jt
import math
import numpy as np
from typing import Union
from collections.abc import Sequence, Iterable
def embedding_cmd(name: str,
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None):
attr_header = "\nnamespace jittor{" + attr_header + "}\n"
cuda_header = '''
#include "acl/aclops/aclops.h"
'''
outputs_ = []
if outputs is not None:
outputs_ = outputs
else:
assert output_dtypes is not None
assert output_shapes is not None
assert len(output_dtypes) == len(output_shapes)
for i in range(len(output_shapes)):
outputs_.append(jt.empty(output_shapes[i], output_dtypes[i]))
input_code = ''
for i in range(len(inputs)):
input_code += f"op.add(in{i}, true);\n"
output_code = ''
for i in range(len(outputs_)):
output_code += f"op.add(out{i}, false);\n"
return jt.code(outputs=outputs_,
inputs=inputs,
cuda_header=attr_header + cuda_header,
cuda_src=f"""
// aclop
{name}OpRunner op;
{input_code}
{output_code}
{attr_code}
op.run();""")
class EmbeddingACL(jt.Function):
def __init__(self):
super(EmbeddingACL, self).__init__()
def execute(
self,
indices,
weight,
):
inputs = [weight, indices]
self.indices = indices
self.weight_shape = weight.shape
output_shape = list(indices.shape) + list(weight.shape[1:])
outputs = [jt.empty(output_shape, weight.dtype)]
attr_code = f"""
op.jt_name = "embedding";
"""
result = embedding_cmd("Embedding",
inputs=inputs,
outputs=outputs,
attr_code=attr_code)[0]
return result
def grad(self, grad_output):
inputs = [grad_output, self.indices]
outputs = [jt.empty(self.weight_shape, grad_output.dtype)]
attr_code = f"""
op.jt_name = "embeddingbackward";
EmbeddingAttr *attr = new EmbeddingAttr();
attr->numEmbeddings = {self.weight_shape[0]};
op.op_attr.reset(attr);
"""
grad_weight = embedding_cmd("EmbeddingBackward",
inputs=inputs,
outputs=outputs,
attr_code=attr_code)[0]
return None, grad_weight

View File

@ -0,0 +1,82 @@
#pragma once
#include <acl/acl.h>
#include <acl/acl_op_compiler.h>
#include <Python.h>
#include <pystate.h>
#include <algorithm>
#include <queue>
#include <set>
#include "common.h"
#include "op.h"
#include "acl_jittor.h"
#include "ops/random_op.h"
#include "ops/reduce_op.h"
#include "ops/binary_op.h"
#include "ops/broadcast_to_op.h"
#include "ops/transpose_op.h"
#include "ops/array_op.h"
#include "ops/code_op.h"
#include "fused_op.h"
#include "ops/unary_op.h"
#include "ops/ternary_op.h"
#include "executor.h"
#include "misc/cuda_flags.h"
#include "mem/allocator.h"
#include "op_compiler.h"
#include "ops/op_register.h"
#include "opt/tuner_manager.h"
#include "utils/str_utils.h"
#include "aclnn/aclnn.h"
#include "embedding_op_acl.h"
namespace jittor
{
EmbeddingOpRunner::EmbeddingOpRunner() : BaseOpRunner("Embedding")
{
}
void EmbeddingOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
{
ret = aclnnEmbeddingGetWorkspaceSize(inputTensors[0], inputTensors[1], outputTensors[0], &workspaceSize, &executor);
checkRet(ret);
if (workspaceSize > 0)
{
mallocWorkSpace(workspaceSize);
}
ret = aclnnEmbedding(workspaceAddr, workspaceSize, executor, aclstream);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnEmbedding failed. ERROR: %d\n", name.c_str(), ret); return);
syncRun();
return;
}
EmbeddingBackwardOpRunner::EmbeddingBackwardOpRunner() : BaseOpRunner("EmbeddingBackward")
{
}
void EmbeddingBackwardOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
{
auto attr = dynamic_cast<EmbeddingAttr *>(op_attr.get());
auto numEmbeddings = attr->numEmbeddings;
ret = aclnnEmbeddingDenseBackwardGetWorkspaceSize(inputTensors[0], inputTensors[1], numEmbeddings, 0, false, outputTensors[0], &workspaceSize, &executor);
checkRet(ret);
if (workspaceSize > 0)
{
mallocWorkSpace(workspaceSize);
}
ret = aclnnEmbeddingDenseBackward(workspaceAddr, workspaceSize, executor, aclstream);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnEmbeddingDenseBackward failed. ERROR: %d\n", name.c_str(), ret); return);
syncRun();
return;
}
}

View File

@ -0,0 +1,25 @@
#pragma once
#include "utils.h"
#include "base_op.h"
namespace jittor
{
class EmbeddingOpRunner : public BaseOpRunner
{
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
EmbeddingOpRunner();
};
class EmbeddingBackwardOpRunner : public BaseOpRunner
{
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
EmbeddingBackwardOpRunner();
};
}

View File

@ -0,0 +1,184 @@
import os
from jittor_utils import env_or_try_find
import jittor_utils
import ctypes
import glob
import jittor.compiler as compiler
import jittor as jt
import math
import numpy as np
from typing import Union
from collections.abc import Sequence, Iterable
def norms_cmd(name: str,
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None):
attr_header = "\nnamespace jittor{" + attr_header + "}\n"
cuda_header = '''
#include "acl/aclops/aclops.h"
'''
outputs_ = []
if outputs is not None:
outputs_ = outputs
else:
assert output_dtypes is not None
assert output_shapes is not None
assert len(output_dtypes) == len(output_shapes)
for i in range(len(output_shapes)):
outputs_.append(jt.empty(output_shapes[i], output_dtypes[i]))
input_code = ''
for i in range(len(inputs)):
input_code += f"op.add(in{i}, true);\n"
output_code = ''
for i in range(len(outputs_)):
output_code += f"op.add(out{i}, false);\n"
return jt.code(outputs=outputs_,
inputs=inputs,
cuda_header=attr_header + cuda_header,
cuda_src=f"""
// aclop
{name}OpRunner op;
{input_code}
{output_code}
{attr_code}
op.run();""")
class BatchNormACL(jt.Function):
def __init__(self,
num_features,
eps=1e-05,
momentum=0.1,
affine=True,
is_train=True,
sync=True):
self.num_features = num_features
self.eps = eps
self.momentum = momentum
self.affine = affine
self.is_train = is_train
self.sync = sync
self.weight = jt.init.constant(
(num_features, ), "float32", 1.0) if affine else 1.0
self.bias = jt.init.constant(
(num_features, ), "float32", 0.0) if affine else 0.0
self.running_mean = jt.init.constant((num_features, ), "float32",
0.0).stop_grad()
self.running_var = jt.init.constant((num_features, ), "float32",
1.0).stop_grad()
def execute(self, x):
# assert self.num_features == x.shape[-1]
self.input = x.float32()
inputs = [
self.input, self.weight, self.bias, self.running_mean,
self.running_var
]
outputs = [
jt.empty(x.shape),
jt.empty(self.num_features),
jt.empty(self.num_features)
]
attr_code = f"""
op.jt_name = "batchnorm";
BatchNormAttr *attr = new BatchNormAttr();
attr->is_train = {"true" if self.is_train else "false"};
attr->momentum = {self.momentum};
attr->eps = {self.eps};
op.op_attr.reset(attr);
"""
result = norms_cmd("BatchNorm",
inputs=inputs,
outputs=outputs,
attr_code=attr_code)
self.output = result[0]
self.saveMean = result[1]
self.saveInvstd = result[2]
return self.output
def grad(self, grad_output):
attr_code = f"""
op.jt_name = "batchnorm";
BatchNormAttr *attr = new BatchNormAttr();
attr->is_train = {"true" if self.is_train else "false"};
attr->momentum = {self.momentum};
attr->eps = {self.eps};
op.op_attr.reset(attr);
"""
inputs = [
grad_output, self.input, self.weight, self.running_mean,
self.running_var, self.saveMean, self.saveInvstd
]
outputs = [
jt.empty(self.input.shape),
jt.empty(self.num_features),
jt.empty(self.num_features)
]
grad_input = norms_cmd("BatchNormBackward",
inputs=inputs,
outputs=outputs,
attr_code=attr_code)[0]
return grad_input
class LayerNormACL(jt.Function):
def __init__(self,
normalized_shape,
eps: float = 1e-5,
elementwise_affine: bool = True):
if isinstance(normalized_shape, int):
normalized_shape = (normalized_shape, )
self.normalized_shape = tuple(normalized_shape)
self.eps = eps
self.elementwise_affine = elementwise_affine
self.weight = jt.init.constant(normalized_shape, "float32",
1.0) if elementwise_affine else 1.0
self.bias = jt.init.constant(normalized_shape, "float32",
0.0) if elementwise_affine else 0.0
def execute(self, x):
self.input = x.float32()
inputs = [self.input, self.weight, self.bias]
outputs = [jt.empty(x.shape), jt.empty(x.shape), jt.empty(x.shape)]
attr_code = f"""
op.jt_name = "layernorm";
LayerNormAttr *attr = new LayerNormAttr();
attr->eps = {self.eps};
attr->normalizedShape = {{{', '.join(map(str, (list(self.normalized_shape))))}}};
attr->size = {x.shape[-1]};
op.op_attr.reset(attr);
"""
result = norms_cmd("LayerNorm",
inputs=inputs,
outputs=outputs,
attr_code=attr_code)
self.output = result[0]
self.meanout = result[1]
self.rstdout = result[2]
return self.output
def grad(self, grad_output):
attr_code = f"""
op.jt_name = "batchnorm";
BatchNormAttr *attr = new BatchNormAttr();
attr->is_train = {"true" if self.is_train else "false"};
attr->momentum = {self.momentum};
attr->eps = {self.eps};
op.op_attr.reset(attr);
"""
inputs = [grad_output, self.input, self.weight, self.running_mean, self.running_var, self.saveMean, self.saveInvstd]
outputs = [jt.empty(self.input.shape), jt.empty(self.num_features), jt.empty(self.num_features)]
grad_input = norms_cmd("SoftmaxBackward",
inputs=inputs,
outputs=outputs,
attr_code=attr_code)[0]
return grad_input

View File

@ -0,0 +1,111 @@
#pragma once
#include <acl/acl.h>
#include <acl/acl_op_compiler.h>
#include <Python.h>
#include <pystate.h>
#include <algorithm>
#include <queue>
#include <set>
#include "common.h"
#include "op.h"
#include "acl_jittor.h"
#include "ops/random_op.h"
#include "ops/reduce_op.h"
#include "ops/binary_op.h"
#include "ops/broadcast_to_op.h"
#include "ops/transpose_op.h"
#include "ops/array_op.h"
#include "ops/code_op.h"
#include "fused_op.h"
#include "ops/unary_op.h"
#include "ops/ternary_op.h"
#include "executor.h"
#include "misc/cuda_flags.h"
#include "mem/allocator.h"
#include "op_compiler.h"
#include "ops/op_register.h"
#include "opt/tuner_manager.h"
#include "utils/str_utils.h"
#include "aclnn/aclnn.h"
#include "norms_op_acl.h"
namespace jittor
{
BatchNormOpRunner::BatchNormOpRunner() : BaseOpRunner("BatchNorm")
{
}
void BatchNormOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
{
auto attr = dynamic_cast<BatchNormAttr *>(op_attr.get());
ret = aclnnBatchNormGetWorkspaceSize(inputTensors[0], inputTensors[1], inputTensors[2], inputTensors[3], inputTensors[4], attr->is_train, attr->momentum, attr->eps, outputTensors[0], outputTensors[1], outputTensors[2], &workspaceSize, &executor);
checkRet(ret);
if (workspaceSize > 0)
{
mallocWorkSpace(workspaceSize);
}
ret = aclnnBatchNorm(workspaceAddr, workspaceSize, executor, aclstream);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnBatchNorm failed. ERROR: %d\n", name.c_str(), ret); return);
syncRun();
return;
}
BatchNormBackwardOpRunner::BatchNormBackwardOpRunner() : BaseOpRunner("BatchNormBackward")
{
}
void BatchNormBackwardOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
{
auto attr = dynamic_cast<BatchNormAttr *>(op_attr.get());
bool outputMask[3] = {true, true, true};
aclBoolArray *outMask = aclCreateBoolArray(outputMask, 3);
ret = aclnnBatchNormBackwardGetWorkspaceSize(inputTensors[0], inputTensors[1], inputTensors[2], inputTensors[3], inputTensors[4], inputTensors[5], inputTensors[6], attr->is_train, attr->eps, outMask, outputTensors[0], outputTensors[1], outputTensors[2], &workspaceSize, &executor);
checkRet(ret);
if (workspaceSize > 0)
{
mallocWorkSpace(workspaceSize);
}
ret = aclnnBatchNormBackward(workspaceAddr, workspaceSize, executor, aclstream);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnBatchNormBackward failed. ERROR: %d\n", name.c_str(), ret); return);
syncRun();
return;
}
LayerNormOpRunner::LayerNormOpRunner() : BaseOpRunner("LayerNorm")
{
}
void LayerNormOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
{
auto attr = dynamic_cast<LayerNormAttr *>(op_attr.get());
aclIntArray *normalizedShape = nullptr;
normalizedShape = aclCreateIntArray(attr->normalizedShape.data(), attr->size);
ret = aclnnLayerNormGetWorkspaceSize(inputTensors[0], normalizedShape, inputTensors[1], inputTensors[2], attr->eps, outputTensors[0], outputTensors[1], outputTensors[2], &workspaceSize, &executor);
checkRet(ret);
if (workspaceSize > 0)
{
mallocWorkSpace(workspaceSize);
}
ret = aclnnLayerNorm(workspaceAddr, workspaceSize, executor, aclstream);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnLayerNorm failed. ERROR: %d\n", name.c_str(), ret); return);
syncRun();
aclDestroyIntArray(normalizedShape);
return;
}
}

View File

@ -0,0 +1,34 @@
#pragma once
#include "utils.h"
#include "base_op.h"
namespace jittor
{
class BatchNormOpRunner : public BaseOpRunner
{
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
BatchNormOpRunner();
};
class BatchNormBackwardOpRunner : public BaseOpRunner
{
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
BatchNormBackwardOpRunner();
};
class LayerNormOpRunner : public BaseOpRunner
{
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
LayerNormOpRunner();
};
}

View File

@ -0,0 +1,74 @@
import os
from jittor_utils import env_or_try_find
import jittor_utils
import ctypes
import glob
import jittor.compiler as compiler
import jittor as jt
import math
import numpy as np
from typing import Union
from collections.abc import Sequence, Iterable
def triu_cmd(name: str,
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None):
attr_header = "\nnamespace jittor{" + attr_header + "}\n"
cuda_header = '''
#include "acl/aclops/aclops.h"
'''
outputs_ = []
if outputs is not None:
outputs_ = outputs
else:
assert output_dtypes is not None
assert output_shapes is not None
assert len(output_dtypes) == len(output_shapes)
for i in range(len(output_shapes)):
outputs_.append(jt.empty(output_shapes[i], output_dtypes[i]))
input_code = ''
for i in range(len(inputs)):
input_code += f"op.add(in{i}, true);\n"
output_code = ''
for i in range(len(outputs_)):
output_code += f"op.add(out{i}, false);\n"
return jt.code(outputs=outputs_,
inputs=inputs,
cuda_header=attr_header + cuda_header,
cuda_src=f"""
// aclop
{name}OpRunner op;
{input_code}
{output_code}
{attr_code}
op.run();""")
class TriuACL(jt.Function):
def __init__(self):
super(TriuACL, self).__init__()
def execute(self, input, diagonal):
attr_code = f"""
op.jt_name = "triu";
TriuAttr *attr = new TriuAttr();
attr->diagonal = {diagonal};
op.op_attr.reset(attr);
"""
result = triu_cmd("Triu", [input],
output_dtypes=[input.dtype],
output_shapes=[input.shape],
attr_code=attr_code)[0]
return result
def grad(self, grad_output):
return grad_output

View File

@ -0,0 +1,58 @@
#pragma once
#include <acl/acl.h>
#include <acl/acl_op_compiler.h>
#include <Python.h>
#include <pystate.h>
#include <algorithm>
#include <queue>
#include <set>
#include "common.h"
#include "op.h"
#include "acl_jittor.h"
#include "ops/random_op.h"
#include "ops/reduce_op.h"
#include "ops/binary_op.h"
#include "ops/broadcast_to_op.h"
#include "ops/transpose_op.h"
#include "ops/array_op.h"
#include "ops/code_op.h"
#include "fused_op.h"
#include "ops/unary_op.h"
#include "ops/ternary_op.h"
#include "executor.h"
#include "misc/cuda_flags.h"
#include "mem/allocator.h"
#include "op_compiler.h"
#include "ops/op_register.h"
#include "opt/tuner_manager.h"
#include "utils/str_utils.h"
#include "aclnn/aclnn.h"
#include "triu_op_acl.h"
namespace jittor
{
TriuOpRunner::TriuOpRunner() : BaseOpRunner("Triu")
{
}
void TriuOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
{
auto attr = dynamic_cast<TriuAttr *>(op_attr.get());
ret = aclnnTriuGetWorkspaceSize(inputTensors[0], aclDataType(attr->diagonal), outputTensors[0], &workspaceSize, &executor);
checkRet(ret);
if (workspaceSize > 0)
{
mallocWorkSpace(workspaceSize);
}
ret = aclnnTriu(workspaceAddr, workspaceSize, executor, aclstream);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnTriu failed. ERROR: %d\n", name.c_str(), ret); return);
syncRun();
return;
}
}

View File

@ -0,0 +1,16 @@
#pragma once
#include "utils.h"
#include "base_op.h"
namespace jittor
{
class TriuOpRunner : public BaseOpRunner
{
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
TriuOpRunner();
};
}