split relu,dropout,transpose,flashattention

Exusial 2024-12-19 09:36:32 +08:00
parent 2a67644b0d
commit bfe1ceb82b
15 changed files with 1247 additions and 304 deletions

View File

@@ -1867,213 +1867,215 @@ def change_function():
def matmul_transpose_acl(x1, x2):
return MatmulACL(True)(x1, x2)
class TransPoseACL(Function):
# class TransPoseACL(Function):
def __init__(self):
super(TransPoseACL, self).__init__()
# def __init__(self):
# super(TransPoseACL, self).__init__()
def execute(self, x, *dim):
self.input = x
if len(dim) == 1 and isinstance(dim[0], Sequence):
dim = dim[0]
elif len(dim) == 2:
axes = list(range(x.ndim))
a, b = dim
axes[a], axes[b] = axes[b], axes[a]
dim = axes
# def execute(self, x, *dim):
# self.input = x
# if len(dim) == 1 and isinstance(dim[0], Sequence):
# dim = dim[0]
# elif len(dim) == 2:
# axes = list(range(x.ndim))
# a, b = dim
# axes[a], axes[b] = axes[b], axes[a]
# dim = axes
attr_code = f"""
op.jt_name = "transpose";
ReduceAttr *attr = new ReduceAttr();
attr->axes = {{ {", ".join(map(str, dim))} }};
op.op_attr.reset(attr);
"""
# calculate output shape
output_shape = [x.shape[i] for i in dim]
output = acl_cmd("Transpose", [x],
output_dtypes=[x.dtype],
output_shapes=[jt.empty(output_shape).shape],
attr_code=attr_code)[0]
self.dim = dim
return output
# attr_code = f"""
# op.jt_name = "transpose";
# ReduceAttr *attr = new ReduceAttr();
# attr->axes = {{ {", ".join(map(str, dim))} }};
# op.op_attr.reset(attr);
# """
# # calculate output shape
# output_shape = [x.shape[i] for i in dim]
# output = acl_cmd("Transpose", [x],
# output_dtypes=[x.dtype],
# output_shapes=[jt.empty(output_shape).shape],
# attr_code=attr_code)[0]
# self.dim = dim
# return output
def grad(self, grad_output):
dim = list(range(grad_output.ndim))
for i, p in enumerate(self.dim):
dim[p] = i
output_shape = [grad_output.shape[i] for i in dim]
attr_code = f"""
op.jt_name = "transpose";
ReduceAttr *attr = new ReduceAttr();
attr->axes = {{ {", ".join(map(str, dim))} }};
op.op_attr.reset(attr);
"""
output = acl_cmd("Transpose", [grad_output],
output_dtypes=[grad_output.dtype],
output_shapes=[jt.empty(output_shape).shape],
attr_code=attr_code)[0]
return output
# def grad(self, grad_output):
# dim = list(range(grad_output.ndim))
# for i, p in enumerate(self.dim):
# dim[p] = i
# output_shape = [grad_output.shape[i] for i in dim]
# attr_code = f"""
# op.jt_name = "transpose";
# ReduceAttr *attr = new ReduceAttr();
# attr->axes = {{ {", ".join(map(str, dim))} }};
# op.op_attr.reset(attr);
# """
# output = acl_cmd("Transpose", [grad_output],
# output_dtypes=[grad_output.dtype],
# output_shapes=[jt.empty(output_shape).shape],
# attr_code=attr_code)[0]
# return output
from .aclops.transpose_op import TransPoseACL
def transpose_acl(x, *dim):
return TransPoseACL()(x, *dim)
class FlashAttentionACL(Function):
# class FlashAttentionACL(Function):
def __init__(self,
headnum,
layout="BNSD",
prefix=None,
qstart=None,
kvstart=None,
scale=1.0,
prob=1.0,
pretokens=2147483647,
nexttokens=2147483647,
innerprecise=0,
sparsemode=0,
psetype=1):
self.headnum = headnum
self.layout = layout
self.scale = scale
self.prob = prob
self.pretokens = pretokens
self.nexttokens = nexttokens
self.innerprecise = innerprecise
self.sparsemode = sparsemode
self.psetype = psetype
self.prefix = prefix
self.qstart = qstart
self.kvstart = kvstart
# def __init__(self,
# headnum,
# layout="BNSD",
# prefix=None,
# qstart=None,
# kvstart=None,
# scale=1.0,
# prob=1.0,
# pretokens=2147483647,
# nexttokens=2147483647,
# innerprecise=0,
# sparsemode=0,
# psetype=1):
# self.headnum = headnum
# self.layout = layout
# self.scale = scale
# self.prob = prob
# self.pretokens = pretokens
# self.nexttokens = nexttokens
# self.innerprecise = innerprecise
# self.sparsemode = sparsemode
# self.psetype = psetype
# self.prefix = prefix
# self.qstart = qstart
# self.kvstart = kvstart
def execute(
self,
q,
k,
v,
realshift=None,
dropMask=None,
paddingMask=None,
attenMask=None,
):
if self.layout == 'BSH':
B, SQ, H = q.shape
SKV = k.shape[1]
N = self.headnum
D = H / N
elif self.layout == 'SBH':
SQ, B, H = q.shape
SKV = k.shape[0]
N = self.headnum
D = H / N
elif self.layout == 'BSND':
B, SQ, N, D = q.shape
SKV = k.shape[1]
elif self.layout == 'BNSD':
B, N, SQ, D = q.shape
SKV = k.shape[2]
else:
raise ValueError(f"got invalid input layout {self.layout}")
# def execute(
# self,
# q,
# k,
# v,
# realshift=None,
# dropMask=None,
# paddingMask=None,
# attenMask=None,
# ):
# if self.layout == 'BSH':
# B, SQ, H = q.shape
# SKV = k.shape[1]
# N = self.headnum
# D = H / N
# elif self.layout == 'SBH':
# SQ, B, H = q.shape
# SKV = k.shape[0]
# N = self.headnum
# D = H / N
# elif self.layout == 'BSND':
# B, SQ, N, D = q.shape
# SKV = k.shape[1]
# elif self.layout == 'BNSD':
# B, N, SQ, D = q.shape
# SKV = k.shape[2]
# else:
# raise ValueError(f"got invalid input layout {self.layout}")
output_shape = (B, N, SQ, 8)
# output_shape = (B, N, SQ, 8)
self.q = q
self.k = k
self.v = v
# self.q = q
# self.k = k
# self.v = v
self.prefix = self.prefix if self.prefix else [0 for _ in range(B)]
self.qstart = self.qstart if self.qstart else [0 for _ in range(B)]
self.kvstart = self.kvstart if self.kvstart else [
0 for _ in range(B)
]
# self.prefix = self.prefix if self.prefix else [0 for _ in range(B)]
# self.qstart = self.qstart if self.qstart else [0 for _ in range(B)]
# self.kvstart = self.kvstart if self.kvstart else [
# 0 for _ in range(B)
# ]
self.hasRealshift = (not realshift == None)
self.hasDropmask = (not dropMask == None)
self.hasPaddingmask = (not paddingMask == None)
self.hasAttenmask = (not attenMask == None)
# self.hasRealshift = (not realshift == None)
# self.hasDropmask = (not dropMask == None)
# self.hasPaddingmask = (not paddingMask == None)
# self.hasAttenmask = (not attenMask == None)
# TBD: currently set to nullptr
self.realshift = realshift if realshift else jt.zeros(
B, N, SQ, SKV)
self.dropMask = dropMask if dropMask else jt.ones(B, N, SQ, SKV)
self.paddingMask = paddingMask if paddingMask else jt.zeros(
B, N, SQ, SKV)
self.attenMask = attenMask if attenMask else jt.zeros(SQ, SKV)
# # TBD: currently set to nullptr
# self.realshift = realshift if realshift else jt.zeros(
# B, N, SQ, SKV)
# self.dropMask = dropMask if dropMask else jt.ones(B, N, SQ, SKV)
# self.paddingMask = paddingMask if paddingMask else jt.zeros(
# B, N, SQ, SKV)
# self.attenMask = attenMask if attenMask else jt.zeros(SQ, SKV)
attr_code = f"""
op.jt_name = "flashattention";
FlashAttentionAttr *attr = new FlashAttentionAttr();
attr->scale = {self.scale};
attr->keepProb = {self.prob};
attr->preToken = {self.pretokens};
attr->nextToken = {self.nexttokens};
attr->headNum = {self.headnum};
attr->inputLayout = "{self.layout}";
attr->innerPrecise = {self.innerprecise};
attr->sparseMode = {self.sparsemode};
attr->psetype = {self.psetype};
attr->prefix = {{ {", ".join(map(str, self.prefix))} }};
attr->qStartIdx = {{ {", ".join(map(str, self.qstart))} }};
attr->kvStartIdx = {{ {", ".join(map(str, self.kvstart))} }};
attr->hasRealshift = {"true" if self.hasRealshift else "false"};
attr->hasDropmask = {"true" if self.hasDropmask else "false"};
attr->hasPaddingmask = {"true" if self.hasPaddingmask else "false"};
attr->hasAttentmask = {"true" if self.hasAttenmask else "false"};
op.op_attr.reset(attr);
"""
# attr_code = f"""
# op.jt_name = "flashattention";
# FlashAttentionAttr *attr = new FlashAttentionAttr();
# attr->scale = {self.scale};
# attr->keepProb = {self.prob};
# attr->preToken = {self.pretokens};
# attr->nextToken = {self.nexttokens};
# attr->headNum = {self.headnum};
# attr->inputLayout = "{self.layout}";
# attr->innerPrecise = {self.innerprecise};
# attr->sparseMode = {self.sparsemode};
# attr->psetype = {self.psetype};
# attr->prefix = {{ {", ".join(map(str, self.prefix))} }};
# attr->qStartIdx = {{ {", ".join(map(str, self.qstart))} }};
# attr->kvStartIdx = {{ {", ".join(map(str, self.kvstart))} }};
# attr->hasRealshift = {"true" if self.hasRealshift else "false"};
# attr->hasDropmask = {"true" if self.hasDropmask else "false"};
# attr->hasPaddingmask = {"true" if self.hasPaddingmask else "false"};
# attr->hasAttentmask = {"true" if self.hasAttenmask else "false"};
# op.op_attr.reset(attr);
# """
inputs = [
q, k, v, self.realshift, self.dropMask, self.paddingMask,
self.attenMask
]
# inputs = [
# q, k, v, self.realshift, self.dropMask, self.paddingMask,
# self.attenMask
# ]
result = acl_cmd(
"FlashAttention",
inputs,
output_dtypes=["float", "float", q.dtype],
output_shapes=[output_shape, output_shape, q.shape],
attr_code=attr_code)
# result = acl_cmd(
# "FlashAttention",
# inputs,
# output_dtypes=["float", "float", q.dtype],
# output_shapes=[output_shape, output_shape, q.shape],
# attr_code=attr_code)
self.maxout = result[0]
self.sumout = result[1]
self.attenout = result[2]
# self.maxout = result[0]
# self.sumout = result[1]
# self.attenout = result[2]
return self.attenout
# return self.attenout
def grad(self, dy):
attr_code = f"""
op.jt_name = "flashattentionbackward";
FlashAttentionAttr *attr = new FlashAttentionAttr();
attr->scale = {self.scale};
attr->keepProb = {self.prob};
attr->preToken = {self.pretokens};
attr->nextToken = {self.nexttokens};
attr->headNum = {self.headnum};
attr->inputLayout = "{self.layout}";
attr->innerPrecise = {self.innerprecise};
attr->sparseMode = {self.sparsemode};
attr->psetype = {self.psetype};
attr->prefix = {{ {", ".join(map(str, self.prefix))} }};
attr->qStartIdx = {{ {", ".join(map(str, self.qstart))} }};
attr->kvStartIdx = {{ {", ".join(map(str, self.kvstart))} }};
attr->hasRealshift = {"true" if self.hasRealshift else "false"};
attr->hasDropmask = {"true" if self.hasDropmask else "false"};
attr->hasPaddingmask = {"true" if self.hasPaddingmask else "false"};
attr->hasAttentmask = {"true" if self.hasAttenmask else "false"};
op.op_attr.reset(attr);
"""
inputs = [
self.q, self.k, self.v, dy, self.realshift, self.dropMask,
self.paddingMask, self.attenMask, self.maxout, self.sumout,
self.attenout
]
# def grad(self, dy):
# attr_code = f"""
# op.jt_name = "flashattentionbackward";
# FlashAttentionAttr *attr = new FlashAttentionAttr();
# attr->scale = {self.scale};
# attr->keepProb = {self.prob};
# attr->preToken = {self.pretokens};
# attr->nextToken = {self.nexttokens};
# attr->headNum = {self.headnum};
# attr->inputLayout = "{self.layout}";
# attr->innerPrecise = {self.innerprecise};
# attr->sparseMode = {self.sparsemode};
# attr->psetype = {self.psetype};
# attr->prefix = {{ {", ".join(map(str, self.prefix))} }};
# attr->qStartIdx = {{ {", ".join(map(str, self.qstart))} }};
# attr->kvStartIdx = {{ {", ".join(map(str, self.kvstart))} }};
# attr->hasRealshift = {"true" if self.hasRealshift else "false"};
# attr->hasDropmask = {"true" if self.hasDropmask else "false"};
# attr->hasPaddingmask = {"true" if self.hasPaddingmask else "false"};
# attr->hasAttentmask = {"true" if self.hasAttenmask else "false"};
# op.op_attr.reset(attr);
# """
# inputs = [
# self.q, self.k, self.v, dy, self.realshift, self.dropMask,
# self.paddingMask, self.attenMask, self.maxout, self.sumout,
# self.attenout
# ]
result = acl_cmd(
"FlashAttentionBackward",
inputs,
output_dtypes=[self.q.dtype, self.k.dtype, self.v.dtype],
output_shapes=[self.q.shape, self.k.shape, self.v.shape],
attr_code=attr_code)
return result
# result = acl_cmd(
# "FlashAttentionBackward",
# inputs,
# output_dtypes=[self.q.dtype, self.k.dtype, self.v.dtype],
# output_shapes=[self.q.shape, self.k.shape, self.v.shape],
# attr_code=attr_code)
# return result
class ReLUACL(Function):
@@ -2112,42 +2114,43 @@ def change_function():
def relu(x):
return ReLUACL()(x)
class LeakyReLUACL(Function):
# class LeakyReLUACL(Function):
def __init__(self):
super(LeakyReLUACL, self).__init__()
# def __init__(self):
# super(LeakyReLUACL, self).__init__()
def execute(self, x, negative_slope=0.01):
x = x.float32()
self.input = x
attr_code = f"""
op.jt_name = "leakyrelu";
LeakyReluAttr *attr = new LeakyReluAttr();
attr->negativeSlope = {negative_slope};
op.op_attr.reset(attr);
"""
result = acl_cmd("LeakyReLU", [x],
output_dtypes=[x.dtype],
output_shapes=[x.shape],
attr_code=attr_code)[0]
self.negative_slope = negative_slope
return result
# def execute(self, x, negative_slope=0.01):
# x = x.float32()
# self.input = x
# attr_code = f"""
# op.jt_name = "leakyrelu";
# LeakyReluAttr *attr = new LeakyReluAttr();
# attr->negativeSlope = {negative_slope};
# op.op_attr.reset(attr);
# """
# result = acl_cmd("LeakyReLU", [x],
# output_dtypes=[x.dtype],
# output_shapes=[x.shape],
# attr_code=attr_code)[0]
# self.negative_slope = negative_slope
# return result
def grad(self, grad_output):
attr_code = f"""
op.jt_name = "leakyrelubackward";
LeakyReluAttr *attr = new LeakyReluAttr();
attr->negativeSlope = {self.negative_slope};
attr->selfIsResult = false;
op.op_attr.reset(attr);
"""
grad_input = acl_cmd("LeakyReLUBackward",
[grad_output, self.input],
output_dtypes=[grad_output.dtype],
output_shapes=[grad_output.shape],
attr_code=attr_code)[0]
return grad_input
# def grad(self, grad_output):
# attr_code = f"""
# op.jt_name = "leakyrelubackward";
# LeakyReluAttr *attr = new LeakyReluAttr();
# attr->negativeSlope = {self.negative_slope};
# attr->selfIsResult = false;
# op.op_attr.reset(attr);
# """
# grad_input = acl_cmd("LeakyReLUBackward",
# [grad_output, self.input],
# output_dtypes=[grad_output.dtype],
# output_shapes=[grad_output.shape],
# attr_code=attr_code)[0]
# return grad_input
from .aclops.relu_op import LeakyReLUACL
class LeakyReLU(jt.nn.Module):
def __init__(self, negative_slope=0.01):
@@ -2160,45 +2163,45 @@ def change_function():
def leaky_relu(x, scale=0.01):
return LeakyReLUACL()(x, scale)
class DropoutACL(Function):
# class DropoutACL(Function):
def __init__(self):
super(DropoutACL, self).__init__()
# def __init__(self):
# super(DropoutACL, self).__init__()
def execute(self, x, p=0.5, is_train=False):
self.input = x
num_elements = x.numel()
aligned_elements = (num_elements + 127) // 128 * 128
mask_shape = (aligned_elements // 8, )
attr_code = f"""
op.jt_name = "dropout";
DropoutAttr *attr = new DropoutAttr();
attr->p = {p};
attr->train = {"true" if is_train else "false"};
attr->seed = 0;
attr->offset = 0;
op.op_attr.reset(attr);
"""
result = acl_cmd("Dropout", [x],
output_dtypes=[x.dtype, "uint8"],
output_shapes=[x.shape, mask_shape],
attr_code=attr_code)
self.maskout = result[1]
return result[0]
# def execute(self, x, p=0.5, is_train=False):
# self.input = x
# num_elements = x.numel()
# aligned_elements = (num_elements + 127) // 128 * 128
# mask_shape = (aligned_elements // 8, )
# attr_code = f"""
# op.jt_name = "dropout";
# DropoutAttr *attr = new DropoutAttr();
# attr->p = {p};
# attr->train = {"true" if is_train else "false"};
# attr->seed = 0;
# attr->offset = 0;
# op.op_attr.reset(attr);
# """
# result = acl_cmd("Dropout", [x],
# output_dtypes=[x.dtype, "uint8"],
# output_shapes=[x.shape, mask_shape],
# attr_code=attr_code)
# self.maskout = result[1]
# return result[0]
def grad(self, grad_output):
attr_code = f"""
op.jt_name = "dropoutbackward";
DropoutAttr *attr = new DropoutAttr();
attr->scale = 1.0;
op.op_attr.reset(attr);
"""
grad_input = acl_cmd("DropoutBackward",
[grad_output, self.maskout],
output_dtypes=[grad_output.dtype],
output_shapes=[grad_output.shape],
attr_code=attr_code)[0]
return grad_input
# def grad(self, grad_output):
# attr_code = f"""
# op.jt_name = "dropoutbackward";
# DropoutAttr *attr = new DropoutAttr();
# attr->scale = 1.0;
# op.op_attr.reset(attr);
# """
# grad_input = acl_cmd("DropoutBackward",
# [grad_output, self.maskout],
# output_dtypes=[grad_output.dtype],
# output_shapes=[grad_output.shape],
# attr_code=attr_code)[0]
# return grad_input
class SiLUACL(Function):
@@ -2342,6 +2345,7 @@ def change_function():
res = embedding_acl(x, self.weight)
return res
from .aclops.dropout_op import DropoutACL
class Dropout(jt.nn.Module):
def __init__(self, p=0.5, is_train=False):
@@ -2795,6 +2799,7 @@ def change_function():
# jt.nn.BatchNorm = warp(jt.nn.BatchNorm, BatchNormACL)
# jt.nn.LayerNorm = warp(jt.nn.LayerNorm, LayerNormACL)
from .aclops.flashattention_op import FlashAttentionACL
jt.nn.FlashAttention = warp(jt.nn.FlashAttention, FlashAttentionACL)
jt.isnan = warp(jt.isnan, isnan_acl)
jt.isinf = warp(jt.isinf, isinf_acl)
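The assignments above route the stock Jittor entry points through warp to the new per-op ACL implementations. As a rough sketch of that pattern only — assuming warp tries the ACL path first and falls back to the original callable, which may differ from the repository's actual helper:

# Hypothetical illustration; the repository's real `warp` helper may behave differently.
def warp(original, acl_impl):
    def wrapper(*args, **kwargs):
        try:
            return acl_impl(*args, **kwargs)   # prefer the ACL (Ascend) implementation
        except Exception:
            return original(*args, **kwargs)   # fall back to the stock Jittor implementation
    return wrapper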

View File

@@ -273,11 +273,11 @@ namespace jittor
ret = it->second.getWorkspaceSizeFuncCast(inputTensors[0], aclDataType(attr->diagonal), outputTensors[0], &workspaceSize, &executor);
break;
}
case 19:
{
ret = it->second.getWorkspaceSizeFuncExpand(inputTensors[0], dim, outputTensors[0], &workspaceSize, &executor);
break;
}
// case 19:
// {
// ret = it->second.getWorkspaceSizeFuncExpand(inputTensors[0], dim, outputTensors[0], &workspaceSize, &executor);
// break;
// }
case 20:
{
auto attr = dynamic_cast<ConvAttr *>(op_attr.get());
@@ -442,32 +442,32 @@ namespace jittor
// ret = it->second.getWorkspaceSizeFuncRange(start, end, step, outputTensors[0], &workspaceSize, &executor);
// break;
// }
case 38:
{
auto attr = dynamic_cast<LeakyReluAttr *>(op_attr.get());
negativeSlope = aclCreateScalar(&attr->negativeSlope, aclDataType::ACL_FLOAT);
ret = it->second.getWorkspaceSizeFuncLeakyRelu(inputTensors[0], negativeSlope, outputTensors[0], &workspaceSize, &executor);
break;
}
case 39:
{
auto attr = dynamic_cast<LeakyReluAttr *>(op_attr.get());
negativeSlope = aclCreateScalar(&attr->negativeSlope, aclDataType::ACL_FLOAT);
ret = it->second.getWorkspaceSizeFuncLeakyReluBackward(inputTensors[0], inputTensors[1], negativeSlope, attr->selfIsResult, outputTensors[0], &workspaceSize, &executor);
break;
}
case 40:
{
auto attr = dynamic_cast<DropoutAttr *>(op_attr.get());
ret = it->second.getWorkspaceSizeFuncDropout(inputTensors[0], attr->p, attr->train, attr->seed, attr->offset, outputTensors[0], outputTensors[1], &workspaceSize, &executor);
break;
}
case 41:
{
auto attr = dynamic_cast<DropoutAttr *>(op_attr.get());
ret = it->second.getWorkspaceSizeFuncDropoutBackward(inputTensors[0], inputTensors[1], attr->scale, outputTensors[0], &workspaceSize, &executor);
break;
}
// case 38:
// {
// auto attr = dynamic_cast<LeakyReluAttr *>(op_attr.get());
// negativeSlope = aclCreateScalar(&attr->negativeSlope, aclDataType::ACL_FLOAT);
// ret = it->second.getWorkspaceSizeFuncLeakyRelu(inputTensors[0], negativeSlope, outputTensors[0], &workspaceSize, &executor);
// break;
// }
// case 39:
// {
// auto attr = dynamic_cast<LeakyReluAttr *>(op_attr.get());
// negativeSlope = aclCreateScalar(&attr->negativeSlope, aclDataType::ACL_FLOAT);
// ret = it->second.getWorkspaceSizeFuncLeakyReluBackward(inputTensors[0], inputTensors[1], negativeSlope, attr->selfIsResult, outputTensors[0], &workspaceSize, &executor);
// break;
// }
// case 40:
// {
// auto attr = dynamic_cast<DropoutAttr *>(op_attr.get());
// ret = it->second.getWorkspaceSizeFuncDropout(inputTensors[0], attr->p, attr->train, attr->seed, attr->offset, outputTensors[0], outputTensors[1], &workspaceSize, &executor);
// break;
// }
// case 41:
// {
// auto attr = dynamic_cast<DropoutAttr *>(op_attr.get());
// ret = it->second.getWorkspaceSizeFuncDropoutBackward(inputTensors[0], inputTensors[1], attr->scale, outputTensors[0], &workspaceSize, &executor);
// break;
// }
case 42:
{
ret = it->second.getWorkspaceSizeFuncUnaryNonzero(inputTensors[0], outputTensors[0], &workspaceSize, &executor);
@@ -518,26 +518,26 @@ namespace jittor
// ret = it->second.getWorkspaceSizeFuncSplitWithSize(inputTensors[0], splitSize, attr->dim, tensorList, &workspaceSize, &executor);
// break;
// }
case 51:
{
auto attr = dynamic_cast<FlashAttentionAttr *>(op_attr.get());
auto prefix = aclCreateIntArray(attr->prefix.data(), attr->prefix.size());
auto qstart = aclCreateIntArray(attr->qStartIdx.data(), attr->qStartIdx.size());
auto kvstart = aclCreateIntArray(attr->kvStartIdx.data(), attr->kvStartIdx.size());
char *layout = const_cast<char *>(attr->inputLayout.data());
ret = it->second.getWorkspaceSizeFuncFalshAttention(inputTensors[0], inputTensors[1], inputTensors[2], attr->hasRealshift ? inputTensors[3] : nullptr, attr->hasDropmask ? inputTensors[4] : nullptr, nullptr, attr->hasAttentmask ? inputTensors[6] : nullptr, prefix, qstart, kvstart, attr->scale, attr->keepProb, attr->preToken, attr->nextToken, attr->headNum, layout, attr->innerPrecise, attr->sparseMode, attr->psetype, outputTensors[0], outputTensors[1], nullptr, outputTensors[2], &workspaceSize, &executor);
break;
}
case 52:
{
auto attr = dynamic_cast<FlashAttentionAttr *>(op_attr.get());
auto prefix = aclCreateIntArray(attr->prefix.data(), attr->prefix.size());
auto qstart = aclCreateIntArray(attr->qStartIdx.data(), attr->qStartIdx.size());
auto kvstart = aclCreateIntArray(attr->kvStartIdx.data(), attr->kvStartIdx.size());
char *layout = const_cast<char *>(attr->inputLayout.data());
ret = it->second.getWorkspaceSizeFuncFalshAttentionBackward(inputTensors[0], inputTensors[1], inputTensors[2], inputTensors[3], attr->hasRealshift ? inputTensors[4] : nullptr, attr->hasDropmask ? inputTensors[5] : nullptr, nullptr, attr->hasAttentmask ? inputTensors[7] : nullptr, inputTensors[8], inputTensors[9], nullptr, inputTensors[10], prefix, qstart, kvstart, attr->scale, attr->keepProb, attr->preToken, attr->nextToken, attr->headNum, layout, attr->innerPrecise, attr->sparseMode, attr->psetype, outputTensors[0], outputTensors[1], outputTensors[2], nullptr, &workspaceSize, &executor);
break;
}
// case 51:
// {
// auto attr = dynamic_cast<FlashAttentionAttr *>(op_attr.get());
// auto prefix = aclCreateIntArray(attr->prefix.data(), attr->prefix.size());
// auto qstart = aclCreateIntArray(attr->qStartIdx.data(), attr->qStartIdx.size());
// auto kvstart = aclCreateIntArray(attr->kvStartIdx.data(), attr->kvStartIdx.size());
// char *layout = const_cast<char *>(attr->inputLayout.data());
// ret = it->second.getWorkspaceSizeFuncFalshAttention(inputTensors[0], inputTensors[1], inputTensors[2], attr->hasRealshift ? inputTensors[3] : nullptr, attr->hasDropmask ? inputTensors[4] : nullptr, nullptr, attr->hasAttentmask ? inputTensors[6] : nullptr, prefix, qstart, kvstart, attr->scale, attr->keepProb, attr->preToken, attr->nextToken, attr->headNum, layout, attr->innerPrecise, attr->sparseMode, attr->psetype, outputTensors[0], outputTensors[1], nullptr, outputTensors[2], &workspaceSize, &executor);
// break;
// }
// case 52:
// {
// auto attr = dynamic_cast<FlashAttentionAttr *>(op_attr.get());
// auto prefix = aclCreateIntArray(attr->prefix.data(), attr->prefix.size());
// auto qstart = aclCreateIntArray(attr->qStartIdx.data(), attr->qStartIdx.size());
// auto kvstart = aclCreateIntArray(attr->kvStartIdx.data(), attr->kvStartIdx.size());
// char *layout = const_cast<char *>(attr->inputLayout.data());
// ret = it->second.getWorkspaceSizeFuncFalshAttentionBackward(inputTensors[0], inputTensors[1], inputTensors[2], inputTensors[3], attr->hasRealshift ? inputTensors[4] : nullptr, attr->hasDropmask ? inputTensors[5] : nullptr, nullptr, attr->hasAttentmask ? inputTensors[7] : nullptr, inputTensors[8], inputTensors[9], nullptr, inputTensors[10], prefix, qstart, kvstart, attr->scale, attr->keepProb, attr->preToken, attr->nextToken, attr->headNum, layout, attr->innerPrecise, attr->sparseMode, attr->psetype, outputTensors[0], outputTensors[1], outputTensors[2], nullptr, &workspaceSize, &executor);
// break;
// }
case 53:
{
auto attr = dynamic_cast<SoftmaxAttr *>(op_attr.get());

View File

@@ -18,3 +18,7 @@
#include <acl/aclops/index_op_acl.h>
#include <acl/aclops/where_op_acl.h>
#include <acl/aclops/floor_op_acl.h>
#include <acl/aclops/transpose_op_acl.h>
#include <acl/aclops/flashattention_op_acl.h>
#include <acl/aclops/relu_op_acl.h>
#include <acl/aclops/dropout_op_acl.h>

View File

@@ -0,0 +1,92 @@
import os
from jittor_utils import env_or_try_find
import jittor_utils
import ctypes
import glob
import jittor.compiler as compiler
import jittor as jt
import math
import numpy as np
from typing import Union
from collections.abc import Sequence, Iterable
def dropout_cmd(name: str,
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None):
attr_header = "\nnamespace jittor{" + attr_header + "}\n"
cuda_header = '''
#include "acl/aclops/aclops.h"
'''
outputs_ = []
if outputs is not None:
outputs_ = outputs
else:
assert output_dtypes is not None
assert output_shapes is not None
assert len(output_dtypes) == len(output_shapes)
for i in range(len(output_shapes)):
outputs_.append(jt.empty(output_shapes[i], output_dtypes[i]))
input_code = ''
for i in range(len(inputs)):
input_code += f"op.add(in{i}, true);\n"
output_code = ''
for i in range(len(outputs_)):
output_code += f"op.add(out{i}, false);\n"
return jt.code(outputs=outputs_,
inputs=inputs,
cuda_header=attr_header + cuda_header,
cuda_src=f"""
// aclop
{name}OpRunner op;
{input_code}
{output_code}
{attr_code}
op.run();""")
class DropoutACL(jt.Function):
def __init__(self):
super(DropoutACL, self).__init__()
def execute(self, x, p=0.5, is_train=False):
self.input = x
num_elements = x.numel()
aligned_elements = (num_elements + 127) // 128 * 128
mask_shape = (aligned_elements // 8, )
attr_code = f"""
op.jt_name = "dropout";
DropoutAttr *attr = new DropoutAttr();
attr->p = {p};
attr->train = {"true" if is_train else "false"};
attr->seed = 0;
attr->offset = 0;
op.op_attr.reset(attr);
"""
result = dropout_cmd("Dropout", [x],
output_dtypes=[x.dtype, "uint8"],
output_shapes=[x.shape, mask_shape],
attr_code=attr_code)
self.maskout = result[1]
return result[0]
def grad(self, grad_output):
attr_code = f"""
op.jt_name = "dropoutbackward";
DropoutAttr *attr = new DropoutAttr();
attr->scale = 1.0;
op.op_attr.reset(attr);
"""
grad_input = dropout_cmd("DropoutBackward",
[grad_output, self.maskout],
output_dtypes=[grad_output.dtype],
output_shapes=[grad_output.shape],
attr_code=attr_code)[0]
return grad_input
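For reference, a minimal usage sketch of the DropoutACL function defined above; an Ascend/ACL-enabled Jittor build is assumed and the shape is illustrative:

# Minimal sketch, assuming an ACL-enabled Jittor build.
import jittor as jt

x = jt.rand(4, 16)
y = DropoutACL()(x, 0.5, True)   # p=0.5, is_train=True; the mask is cached for grad()
print(y.shape)                   # [4, 16]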

View File

@@ -0,0 +1,82 @@
#pragma once
#include <acl/acl.h>
#include <acl/acl_op_compiler.h>
#include <Python.h>
#include <pystate.h>
#include <algorithm>
#include <queue>
#include <set>
#include "common.h"
#include "op.h"
#include "acl_jittor.h"
#include "ops/random_op.h"
#include "ops/reduce_op.h"
#include "ops/binary_op.h"
#include "ops/broadcast_to_op.h"
#include "ops/transpose_op.h"
#include "ops/array_op.h"
#include "ops/code_op.h"
#include "fused_op.h"
#include "ops/unary_op.h"
#include "ops/ternary_op.h"
#include "executor.h"
#include "misc/cuda_flags.h"
#include "mem/allocator.h"
#include "op_compiler.h"
#include "ops/op_register.h"
#include "opt/tuner_manager.h"
#include "utils/str_utils.h"
#include "aclnn/aclnn.h"
#include "dropout_op_acl.h"
namespace jittor
{
DropoutOpRunner::DropoutOpRunner() : BaseOpRunner("Dropout")
{
}
void DropoutOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
{
auto attr = dynamic_cast<DropoutAttr *>(op_attr.get());
ret = aclnnDropoutGetWorkspaceSize(inputTensors[0], attr->p, attr->train, attr->seed, attr->offset, outputTensors[0], outputTensors[1], &workspaceSize, &executor);
checkRet(ret);
if (workspaceSize > 0)
{
mallocWorkSpace(workspaceSize);
}
ret = aclnnDropout(workspaceAddr, workspaceSize, executor, aclstream);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnDropout failed. ERROR: %d\n", name.c_str(), ret); return);
syncRun();
return;
}
DropoutBackwardOpRunner::DropoutBackwardOpRunner() : BaseOpRunner("DropoutBackward")
{
}
void DropoutBackwardOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
{
auto attr = dynamic_cast<DropoutAttr *>(op_attr.get());
ret = aclnnDropoutBackwardGetWorkspaceSize(inputTensors[0], inputTensors[1], attr->scale, outputTensors[0], &workspaceSize, &executor);
checkRet(ret);
if (workspaceSize > 0)
{
mallocWorkSpace(workspaceSize);
}
ret = aclnnDropoutBackward(workspaceAddr, workspaceSize, executor, aclstream);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnDropoutBackward failed. ERROR: %d\n", name.c_str(), ret); return);
syncRun();
return;
}
}

View File

@@ -0,0 +1,25 @@
#pragma once
#include "utils.h"
#include "base_op.h"
namespace jittor
{
class DropoutOpRunner : public BaseOpRunner
{
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
DropoutOpRunner();
};
class DropoutBackwardOpRunner : public BaseOpRunner
{
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
DropoutBackwardOpRunner();
};
}

View File

@@ -0,0 +1,210 @@
import os
from jittor_utils import env_or_try_find
import jittor_utils
import ctypes
import glob
import jittor.compiler as compiler
import jittor as jt
import math
import numpy as np
from typing import Union
from collections.abc import Sequence, Iterable
def flashattention_cmd(name: str,
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None):
attr_header = "\nnamespace jittor{" + attr_header + "}\n"
cuda_header = '''
#include "acl/aclops/aclops.h"
'''
outputs_ = []
if outputs is not None:
outputs_ = outputs
else:
assert output_dtypes is not None
assert output_shapes is not None
assert len(output_dtypes) == len(output_shapes)
for i in range(len(output_shapes)):
outputs_.append(jt.empty(output_shapes[i], output_dtypes[i]))
input_code = ''
for i in range(len(inputs)):
input_code += f"op.add(in{i}, true);\n"
output_code = ''
for i in range(len(outputs_)):
output_code += f"op.add(out{i}, false);\n"
return jt.code(outputs=outputs_,
inputs=inputs,
cuda_header=attr_header + cuda_header,
cuda_src=f"""
// aclop
{name}OpRunner op;
{input_code}
{output_code}
{attr_code}
op.run();""")
class FlashAttentionACL(jt.Function):
def __init__(self,
headnum,
layout="BNSD",
prefix=None,
qstart=None,
kvstart=None,
scale=1.0,
prob=1.0,
pretokens=2147483647,
nexttokens=2147483647,
innerprecise=0,
sparsemode=0,
psetype=1):
self.headnum = headnum
self.layout = layout
self.scale = scale
self.prob = prob
self.pretokens = pretokens
self.nexttokens = nexttokens
self.innerprecise = innerprecise
self.sparsemode = sparsemode
self.psetype = psetype
self.prefix = prefix
self.qstart = qstart
self.kvstart = kvstart
def execute(
self,
q,
k,
v,
realshift=None,
dropMask=None,
paddingMask=None,
attenMask=None,
):
if self.layout == 'BSH':
B, SQ, H = q.shape
SKV = k.shape[1]
N = self.headnum
D = H / N
elif self.layout == 'SBH':
SQ, B, H = q.shape
SKV = k.shape[0]
N = self.headnum
D = H / N
elif self.layout == 'BSND':
B, SQ, N, D = q.shape
SKV = k.shape[1]
elif self.layout == 'BNSD':
B, N, SQ, D = q.shape
SKV = k.shape[2]
else:
raise ValueError(f"got invalid input layout {self.layout}")
output_shape = (B, N, SQ, 8)
self.q = q
self.k = k
self.v = v
self.prefix = self.prefix if self.prefix else [0 for _ in range(B)]
self.qstart = self.qstart if self.qstart else [0 for _ in range(B)]
self.kvstart = self.kvstart if self.kvstart else [
0 for _ in range(B)
]
self.hasRealshift = realshift is not None
self.hasDropmask = dropMask is not None
self.hasPaddingmask = paddingMask is not None
self.hasAttenmask = attenMask is not None
# TBD: currently set to nullptr
self.realshift = realshift if realshift else jt.zeros(
B, N, SQ, SKV)
self.dropMask = dropMask if dropMask else jt.ones(B, N, SQ, SKV)
self.paddingMask = paddingMask if paddingMask else jt.zeros(
B, N, SQ, SKV)
self.attenMask = attenMask if attenMask else jt.zeros(SQ, SKV)
attr_code = f"""
op.jt_name = "flashattention";
FlashAttentionAttr *attr = new FlashAttentionAttr();
attr->scale = {self.scale};
attr->keepProb = {self.prob};
attr->preToken = {self.pretokens};
attr->nextToken = {self.nexttokens};
attr->headNum = {self.headnum};
attr->inputLayout = "{self.layout}";
attr->innerPrecise = {self.innerprecise};
attr->sparseMode = {self.sparsemode};
attr->psetype = {self.psetype};
attr->prefix = {{ {", ".join(map(str, self.prefix))} }};
attr->qStartIdx = {{ {", ".join(map(str, self.qstart))} }};
attr->kvStartIdx = {{ {", ".join(map(str, self.kvstart))} }};
attr->hasRealshift = {"true" if self.hasRealshift else "false"};
attr->hasDropmask = {"true" if self.hasDropmask else "false"};
attr->hasPaddingmask = {"true" if self.hasPaddingmask else "false"};
attr->hasAttentmask = {"true" if self.hasAttenmask else "false"};
op.op_attr.reset(attr);
"""
inputs = [
q, k, v, self.realshift, self.dropMask, self.paddingMask,
self.attenMask
]
result = flashattention_cmd(
"FlashAttention",
inputs,
output_dtypes=["float", "float", q.dtype],
output_shapes=[output_shape, output_shape, q.shape],
attr_code=attr_code)
self.maxout = result[0]
self.sumout = result[1]
self.attenout = result[2]
return self.attenout
def grad(self, dy):
attr_code = f"""
op.jt_name = "flashattentionbackward";
FlashAttentionAttr *attr = new FlashAttentionAttr();
attr->scale = {self.scale};
attr->keepProb = {self.prob};
attr->preToken = {self.pretokens};
attr->nextToken = {self.nexttokens};
attr->headNum = {self.headnum};
attr->inputLayout = "{self.layout}";
attr->innerPrecise = {self.innerprecise};
attr->sparseMode = {self.sparsemode};
attr->psetype = {self.psetype};
attr->prefix = {{ {", ".join(map(str, self.prefix))} }};
attr->qStartIdx = {{ {", ".join(map(str, self.qstart))} }};
attr->kvStartIdx = {{ {", ".join(map(str, self.kvstart))} }};
attr->hasRealshift = {"true" if self.hasRealshift else "false"};
attr->hasDropmask = {"true" if self.hasDropmask else "false"};
attr->hasPaddingmask = {"true" if self.hasPaddingmask else "false"};
attr->hasAttentmask = {"true" if self.hasAttenmask else "false"};
op.op_attr.reset(attr);
"""
inputs = [
self.q, self.k, self.v, dy, self.realshift, self.dropMask,
self.paddingMask, self.attenMask, self.maxout, self.sumout,
self.attenout
]
result = flashattention_cmd(
"FlashAttentionBackward",
inputs,
output_dtypes=[self.q.dtype, self.k.dtype, self.v.dtype],
output_shapes=[self.q.shape, self.k.shape, self.v.shape],
attr_code=attr_code)
return result
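For reference, a minimal usage sketch of FlashAttentionACL with the default BNSD layout; the shapes, the softmax scale, and the ACL-enabled build are illustrative assumptions:

# Minimal sketch, assuming an ACL-enabled Jittor build; shapes are illustrative.
import jittor as jt

B, N, SQ, SKV, D = 2, 8, 128, 128, 64
q = jt.rand(B, N, SQ, D)
k = jt.rand(B, N, SKV, D)
v = jt.rand(B, N, SKV, D)

attn = FlashAttentionACL(headnum=N, layout="BNSD", scale=1.0 / D ** 0.5)
out = attn(q, k, v)              # optional masks default to zeros/ones inside execute()
print(out.shape)                 # same shape as q: [B, N, SQ, D]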

View File

@@ -0,0 +1,89 @@
#pragma once
#include <acl/acl.h>
#include <acl/acl_op_compiler.h>
#include <Python.h>
#include <pystate.h>
#include <algorithm>
#include <queue>
#include <set>
#include "common.h"
#include "op.h"
#include "acl_jittor.h"
#include "ops/random_op.h"
#include "ops/reduce_op.h"
#include "ops/binary_op.h"
#include "ops/broadcast_to_op.h"
#include "ops/transpose_op.h"
#include "ops/array_op.h"
#include "ops/code_op.h"
#include "fused_op.h"
#include "ops/unary_op.h"
#include "ops/ternary_op.h"
#include "executor.h"
#include "misc/cuda_flags.h"
#include "mem/allocator.h"
#include "op_compiler.h"
#include "ops/op_register.h"
#include "opt/tuner_manager.h"
#include "utils/str_utils.h"
#include "aclnn/aclnn.h"
#include "flashattention_op_acl.h"
namespace jittor
{
FlashAttentionOpRunner::FlashAttentionOpRunner() : BaseOpRunner("FlashAttention")
{
}
void FlashAttentionOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
{
auto attr = dynamic_cast<FlashAttentionAttr *>(op_attr.get());
auto prefix = aclCreateIntArray(attr->prefix.data(), attr->prefix.size());
auto qstart = aclCreateIntArray(attr->qStartIdx.data(), attr->qStartIdx.size());
auto kvstart = aclCreateIntArray(attr->kvStartIdx.data(), attr->kvStartIdx.size());
char *layout = const_cast<char *>(attr->inputLayout.data());
ret = aclnnFlashAttentionScoreV2GetWorkspaceSize(inputTensors[0], inputTensors[1], inputTensors[2], attr->hasRealshift ? inputTensors[3] : nullptr, attr->hasDropmask ? inputTensors[4] : nullptr, nullptr, attr->hasAttentmask ? inputTensors[6] : nullptr, prefix, qstart, kvstart, attr->scale, attr->keepProb, attr->preToken, attr->nextToken, attr->headNum, layout, attr->innerPrecise, attr->sparseMode, attr->psetype, outputTensors[0], outputTensors[1], nullptr, outputTensors[2], &workspaceSize, &executor);
checkRet(ret);
if (workspaceSize > 0)
{
mallocWorkSpace(workspaceSize);
}
ret = aclnnFlashAttentionScoreV2(workspaceAddr, workspaceSize, executor, aclstream);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnFlashAttentionScoreV2 failed. ERROR: %d\n", name.c_str(), ret); return);
syncRun();
return;
}
FlashAttentionBackwardOpRunner::FlashAttentionBackwardOpRunner() : BaseOpRunner("FlashAttentionBackward")
{
}
void FlashAttentionBackwardOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
{
auto attr = dynamic_cast<FlashAttentionAttr *>(op_attr.get());
auto prefix = aclCreateIntArray(attr->prefix.data(), attr->prefix.size());
auto qstart = aclCreateIntArray(attr->qStartIdx.data(), attr->qStartIdx.size());
auto kvstart = aclCreateIntArray(attr->kvStartIdx.data(), attr->kvStartIdx.size());
char *layout = const_cast<char *>(attr->inputLayout.data());
ret = aclnnFlashAttentionScoreGradV2GetWorkspaceSize(inputTensors[0], inputTensors[1], inputTensors[2], inputTensors[3], attr->hasRealshift ? inputTensors[4] : nullptr, attr->hasDropmask ? inputTensors[5] : nullptr, nullptr, attr->hasAttentmask ? inputTensors[7] : nullptr, inputTensors[8], inputTensors[9], nullptr, inputTensors[10], prefix, qstart, kvstart, attr->scale, attr->keepProb, attr->preToken, attr->nextToken, attr->headNum, layout, attr->innerPrecise, attr->sparseMode, attr->psetype, outputTensors[0], outputTensors[1], outputTensors[2], nullptr, &workspaceSize, &executor);
checkRet(ret);
if (workspaceSize > 0)
{
mallocWorkSpace(workspaceSize);
}
ret = aclnnFlashAttentionScoreGradV2(workspaceAddr, workspaceSize, executor, aclstream);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnFlashAttentionScoreGradV2 failed. ERROR: %d\n", name.c_str(), ret); return);
syncRun();
return;
}
}

View File

@@ -0,0 +1,25 @@
#pragma once
#include "utils.h"
#include "base_op.h"
namespace jittor
{
class FlashAttentionOpRunner : public BaseOpRunner
{
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
FlashAttentionOpRunner();
};
class FlashAttentionBackwardOpRunner : public BaseOpRunner
{
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
FlashAttentionBackwardOpRunner();
};
}

View File

@@ -0,0 +1,115 @@
import os
from jittor_utils import env_or_try_find
import jittor_utils
import ctypes
import glob
import jittor.compiler as compiler
import jittor as jt
import math
import numpy as np
from typing import Union
from collections.abc import Sequence, Iterable
def relu_cmd(name: str,
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None):
attr_header = "\nnamespace jittor{" + attr_header + "}\n"
cuda_header = '''
#include "acl/aclops/aclops.h"
'''
outputs_ = []
if outputs is not None:
outputs_ = outputs
else:
assert output_dtypes is not None
assert output_shapes is not None
assert len(output_dtypes) == len(output_shapes)
for i in range(len(output_shapes)):
outputs_.append(jt.empty(output_shapes[i], output_dtypes[i]))
input_code = ''
for i in range(len(inputs)):
input_code += f"op.add(in{i}, true);\n"
output_code = ''
for i in range(len(outputs_)):
output_code += f"op.add(out{i}, false);\n"
return jt.code(outputs=outputs_,
inputs=inputs,
cuda_header=attr_header + cuda_header,
cuda_src=f"""
// aclop
{name}OpRunner op;
{input_code}
{output_code}
{attr_code}
op.run();""")
class ReLUACL(jt.Function):
def __init__(self):
super(ReLUACL, self).__init__()
def execute(self, x):
x = x.float32()
self.input = x
result = relu_cmd("ReLU", [x],
output_dtypes=[x.dtype],
output_shapes=[x.shape],
attr_code="op.jt_name=\"unary\";")[0]
return result
def grad(self, grad_output):
mask = relu_cmd("Greater",
[self.input, jt.zeros(self.input.shape)],
output_dtypes=[self.input.dtype],
output_shapes=[self.input.shape],
attr_code="op.jt_name=\"binary\";")[0]
grad_input = relu_cmd("Mul", [grad_output, mask],
output_dtypes=[grad_output.dtype],
output_shapes=[grad_output.shape],
attr_code="op.jt_name=\"binary\";")[0]
return grad_input
class LeakyReLUACL(jt.Function):
def __init__(self):
super(LeakyReLUACL, self).__init__()
def execute(self, x, negative_slope=0.01):
x = x.float32()
self.input = x
attr_code = f"""
op.jt_name = "leakyrelu";
LeakyReluAttr *attr = new LeakyReluAttr();
attr->negativeSlope = {negative_slope};
op.op_attr.reset(attr);
"""
result = relu_cmd("LeakyReLU", [x],
output_dtypes=[x.dtype],
output_shapes=[x.shape],
attr_code=attr_code)[0]
self.negative_slope = negative_slope
return result
def grad(self, grad_output):
attr_code = f"""
op.jt_name = "leakyrelubackward";
LeakyReluAttr *attr = new LeakyReluAttr();
attr->negativeSlope = {self.negative_slope};
attr->selfIsResult = false;
op.op_attr.reset(attr);
"""
grad_input = relu_cmd("LeakyReLUBackward",
[grad_output, self.input],
output_dtypes=[grad_output.dtype],
output_shapes=[grad_output.shape],
attr_code=attr_code)[0]
return grad_input
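For reference, a minimal usage sketch of the split-out ReLUACL and LeakyReLUACL functions; the input values and the ACL-enabled build are illustrative assumptions:

# Minimal sketch, assuming an ACL-enabled Jittor build.
import jittor as jt

x = jt.array([[-1.0, 0.5], [2.0, -3.0]])
y = ReLUACL()(x)                 # -> [[0.0, 0.5], [2.0, 0.0]]
z = LeakyReLUACL()(x, 0.1)       # negative entries scaled by negative_slope=0.1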

View File

@@ -0,0 +1,90 @@
#pragma once
#include <acl/acl.h>
#include <acl/acl_op_compiler.h>
#include <Python.h>
#include <pystate.h>
#include <algorithm>
#include <queue>
#include <set>
#include "common.h"
#include "op.h"
#include "acl_jittor.h"
#include "ops/random_op.h"
#include "ops/reduce_op.h"
#include "ops/binary_op.h"
#include "ops/broadcast_to_op.h"
#include "ops/transpose_op.h"
#include "ops/array_op.h"
#include "ops/code_op.h"
#include "fused_op.h"
#include "ops/unary_op.h"
#include "ops/ternary_op.h"
#include "executor.h"
#include "misc/cuda_flags.h"
#include "mem/allocator.h"
#include "op_compiler.h"
#include "ops/op_register.h"
#include "opt/tuner_manager.h"
#include "utils/str_utils.h"
#include "aclnn/aclnn.h"
#include "relu_op_acl.h"
namespace jittor
{
LeakyReLUOpRunner::LeakyReLUOpRunner() : BaseOpRunner("LeakyReLU")
{
}
void LeakyReLUOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
{
aclScalar *negativeSlope = nullptr;
auto attr = dynamic_cast<LeakyReluAttr *>(op_attr.get());
negativeSlope = aclCreateScalar(&attr->negativeSlope, aclDataType::ACL_FLOAT);
ret = aclnnLeakyReluGetWorkspaceSize(inputTensors[0], negativeSlope, outputTensors[0], &workspaceSize, &executor);
checkRet(ret);
if (workspaceSize > 0)
{
mallocWorkSpace(workspaceSize);
}
ret = aclnnLeakyRelu(workspaceAddr, workspaceSize, executor, aclstream);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnLeakyRelu failed. ERROR: %d\n", name.c_str(), ret); return);
syncRun();
aclDestroyScalar(negativeSlope);
return;
}
LeakyReLUBackwardOpRunner::LeakyReLUBackwardOpRunner() : BaseOpRunner("LeakyReLUBackward")
{
}
void LeakyReLUBackwardOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
{
aclScalar *negativeSlope = nullptr;
auto attr = dynamic_cast<LeakyReluAttr *>(op_attr.get());
negativeSlope = aclCreateScalar(&attr->negativeSlope, aclDataType::ACL_FLOAT);
ret = aclnnLeakyReluBackwardGetWorkspaceSize(inputTensors[0], inputTensors[1], negativeSlope, attr->selfIsResult, outputTensors[0], &workspaceSize, &executor);
checkRet(ret);
if (workspaceSize > 0)
{
mallocWorkSpace(workspaceSize);
}
ret = aclnnLeakyReluBackward(workspaceAddr, workspaceSize, executor, aclstream);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnLeakyReluBackward failed. ERROR: %d\n", name.c_str(), ret); return);
syncRun();
aclDestroyScalar(negativeSlope);
return;
}
}

View File

@@ -0,0 +1,25 @@
#pragma once
#include "utils.h"
#include "base_op.h"
namespace jittor
{
class LeakyReLUOpRunner : public BaseOpRunner
{
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
LeakyReLUOpRunner();
};
class LeakyReLUBackwardOpRunner : public BaseOpRunner
{
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
LeakyReLUBackwardOpRunner();
};
}

View File

@@ -0,0 +1,99 @@
import os
from jittor_utils import env_or_try_find
import jittor_utils
import ctypes
import glob
import jittor.compiler as compiler
import jittor as jt
import math
import numpy as np
from typing import Union
from collections.abc import Sequence, Iterable
def transpose_cmd(name: str,
inputs: list,
output_dtypes: list = None,
output_shapes: list = None,
attr_code: str = "",
attr_header: str = "",
outputs: list = None):
attr_header = "\nnamespace jittor{" + attr_header + "}\n"
cuda_header = '''
#include "acl/aclops/aclops.h"
'''
outputs_ = []
if outputs is not None:
outputs_ = outputs
else:
assert output_dtypes is not None
assert output_shapes is not None
assert len(output_dtypes) == len(output_shapes)
for i in range(len(output_shapes)):
outputs_.append(jt.empty(output_shapes[i], output_dtypes[i]))
input_code = ''
for i in range(len(inputs)):
input_code += f"op.add(in{i}, true);\n"
output_code = ''
for i in range(len(outputs_)):
output_code += f"op.add(out{i}, false);\n"
return jt.code(outputs=outputs_,
inputs=inputs,
cuda_header=attr_header + cuda_header,
cuda_src=f"""
// aclop
{name}OpRunner op;
{input_code}
{output_code}
{attr_code}
op.run();""")
class TransPoseACL(jt.Function):
def __init__(self):
super(TransPoseACL, self).__init__()
def execute(self, x, *dim):
self.input = x
if len(dim) == 1 and isinstance(dim[0], Sequence):
dim = dim[0]
elif len(dim) == 2:
axes = list(range(x.ndim))
a, b = dim
axes[a], axes[b] = axes[b], axes[a]
dim = axes
attr_code = f"""
op.jt_name = "transpose";
ReduceAttr *attr = new ReduceAttr();
attr->axes = {{ {", ".join(map(str, dim))} }};
op.op_attr.reset(attr);
"""
# calculate output shape
output_shape = [x.shape[i] for i in dim]
output = transpose_cmd("Transpose", [x],
output_dtypes=[x.dtype],
output_shapes=[jt.empty(output_shape).shape],
attr_code=attr_code)[0]
self.dim = dim
return output
def grad(self, grad_output):
dim = list(range(grad_output.ndim))
for i, p in enumerate(self.dim):
dim[p] = i
output_shape = [grad_output.shape[i] for i in dim]
attr_code = f"""
op.jt_name = "transpose";
ReduceAttr *attr = new ReduceAttr();
attr->axes = {{ {", ".join(map(str, dim))} }};
op.op_attr.reset(attr);
"""
output = transpose_cmd("Transpose", [grad_output],
output_dtypes=[grad_output.dtype],
output_shapes=[jt.empty(output_shape).shape],
attr_code=attr_code)[0]
return output
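For reference, a minimal usage sketch of TransPoseACL covering both calling conventions execute() accepts (a full permutation, or a pair of axes to swap); an ACL-enabled build is assumed:

# Minimal sketch, assuming an ACL-enabled Jittor build.
import jittor as jt

x = jt.rand(2, 3, 4)
y = TransPoseACL()(x, (2, 0, 1))  # full permutation -> shape [4, 2, 3]
z = TransPoseACL()(x, 0, 2)       # swap two axes    -> shape [4, 3, 2]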

View File

@@ -0,0 +1,66 @@
#pragma once
#include <acl/acl.h>
#include <acl/acl_op_compiler.h>
#include <Python.h>
#include <pystate.h>
#include <algorithm>
#include <queue>
#include <set>
#include "common.h"
#include "op.h"
#include "acl_jittor.h"
#include "ops/random_op.h"
#include "ops/reduce_op.h"
#include "ops/binary_op.h"
#include "ops/broadcast_to_op.h"
#include "ops/transpose_op.h"
#include "ops/array_op.h"
#include "ops/code_op.h"
#include "fused_op.h"
#include "ops/unary_op.h"
#include "ops/ternary_op.h"
#include "executor.h"
#include "misc/cuda_flags.h"
#include "mem/allocator.h"
#include "op_compiler.h"
#include "ops/op_register.h"
#include "opt/tuner_manager.h"
#include "utils/str_utils.h"
#include "aclnn/aclnn.h"
#include "transpose_op_acl.h"
namespace jittor
{
TransposeOpRunner::TransposeOpRunner() : BaseOpRunner("Transpose")
{
}
void TransposeOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
{
auto attr = dynamic_cast<ReduceAttr *>(op_attr.get());
aclIntArray *dim = nullptr;
dim = aclCreateIntArray(attr->axes.data(), attr->axes.size());
bool keepdims = attr->keepdims;
ret = aclnnPermuteGetWorkspaceSize(inputTensors[0], dim, outputTensors[0], &workspaceSize, &executor);
checkRet(ret);
if (workspaceSize > 0)
{
mallocWorkSpace(workspaceSize);
}
ret = aclnnPermute(workspaceAddr, workspaceSize, executor, aclstream);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnPermute failed. ERROR: %d\n", name.c_str(), ret); return);
syncRun();
aclDestroyIntArray(dim);
return;
}
}

View File

@@ -0,0 +1,16 @@
#pragma once
#include "utils.h"
#include "base_op.h"
namespace jittor
{
class TransposeOpRunner : public BaseOpRunner
{
protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
TransposeOpRunner();
};
}