mirror of https://github.com/Jittor/Jittor
split relu,dropout,transpose,flashattention
This commit is contained in: parent 2a67644b0d, commit bfe1ceb82b
@ -1867,213 +1867,215 @@ def change_function():
def matmul_transpose_acl(x1, x2):
|
||||
return MatmulACL(True)(x1, x2)
|
||||
|
||||
class TransPoseACL(Function):
|
||||
# class TransPoseACL(Function):
|
||||
|
||||
def __init__(self):
|
||||
super(TransPoseACL, self).__init__()
|
||||
# def __init__(self):
|
||||
# super(TransPoseACL, self).__init__()
|
||||
|
||||
def execute(self, x, *dim):
|
||||
self.input = x
|
||||
if len(dim) == 1 and isinstance(dim[0], Sequence):
|
||||
dim = dim[0]
|
||||
elif len(dim) == 2:
|
||||
axes = list(range(x.ndim))
|
||||
a, b = dim
|
||||
axes[a], axes[b] = axes[b], axes[a]
|
||||
dim = axes
|
||||
# def execute(self, x, *dim):
|
||||
# self.input = x
|
||||
# if len(dim) == 1 and isinstance(dim[0], Sequence):
|
||||
# dim = dim[0]
|
||||
# elif len(dim) == 2:
|
||||
# axes = list(range(x.ndim))
|
||||
# a, b = dim
|
||||
# axes[a], axes[b] = axes[b], axes[a]
|
||||
# dim = axes
|
||||
|
||||
attr_code = f"""
|
||||
op.jt_name = "transpose";
|
||||
ReduceAttr *attr = new ReduceAttr();
|
||||
attr->axes = {{ {", ".join(map(str, dim))} }};
|
||||
op.op_attr.reset(attr);
|
||||
"""
|
||||
# calculate output shape
|
||||
output_shape = [x.shape[i] for i in dim]
|
||||
output = acl_cmd("Transpose", [x],
|
||||
output_dtypes=[x.dtype],
|
||||
output_shapes=[jt.empty(output_shape).shape],
|
||||
attr_code=attr_code)[0]
|
||||
self.dim = dim
|
||||
return output
|
||||
# attr_code = f"""
|
||||
# op.jt_name = "transpose";
|
||||
# ReduceAttr *attr = new ReduceAttr();
|
||||
# attr->axes = {{ {", ".join(map(str, dim))} }};
|
||||
# op.op_attr.reset(attr);
|
||||
# """
|
||||
# # calculate output shape
|
||||
# output_shape = [x.shape[i] for i in dim]
|
||||
# output = acl_cmd("Transpose", [x],
|
||||
# output_dtypes=[x.dtype],
|
||||
# output_shapes=[jt.empty(output_shape).shape],
|
||||
# attr_code=attr_code)[0]
|
||||
# self.dim = dim
|
||||
# return output
|
||||
|
||||
def grad(self, grad_output):
|
||||
dim = list(range(grad_output.ndim))
|
||||
for i, p in enumerate(self.dim):
|
||||
dim[p] = i
|
||||
output_shape = [grad_output.shape[i] for i in dim]
|
||||
attr_code = f"""
|
||||
op.jt_name = "transpose";
|
||||
ReduceAttr *attr = new ReduceAttr();
|
||||
attr->axes = {{ {", ".join(map(str, dim))} }};
|
||||
op.op_attr.reset(attr);
|
||||
"""
|
||||
output = acl_cmd("Transpose", [grad_output],
|
||||
output_dtypes=[grad_output.dtype],
|
||||
output_shapes=[jt.empty(output_shape).shape],
|
||||
attr_code=attr_code)[0]
|
||||
return output
|
||||
# def grad(self, grad_output):
|
||||
# dim = list(range(grad_output.ndim))
|
||||
# for i, p in enumerate(self.dim):
|
||||
# dim[p] = i
|
||||
# output_shape = [grad_output.shape[i] for i in dim]
|
||||
# attr_code = f"""
|
||||
# op.jt_name = "transpose";
|
||||
# ReduceAttr *attr = new ReduceAttr();
|
||||
# attr->axes = {{ {", ".join(map(str, dim))} }};
|
||||
# op.op_attr.reset(attr);
|
||||
# """
|
||||
# output = acl_cmd("Transpose", [grad_output],
|
||||
# output_dtypes=[grad_output.dtype],
|
||||
# output_shapes=[jt.empty(output_shape).shape],
|
||||
# attr_code=attr_code)[0]
|
||||
# return output
|
||||
|
||||
from .aclops.transpose_op import TransPoseACL
|
||||
|
||||
def transpose_acl(x, *dim):
|
||||
return TransPoseACL()(x, *dim)
|
||||
|
||||
class FlashAttentionACL(Function):
|
||||
# class FlashAttentionACL(Function):
|
||||
|
||||
def __init__(self,
|
||||
headnum,
|
||||
layout="BNSD",
|
||||
prefix=None,
|
||||
qstart=None,
|
||||
kvstart=None,
|
||||
scale=1.0,
|
||||
prob=1.0,
|
||||
pretokens=2147483647,
|
||||
nexttokens=2147483647,
|
||||
innerprecise=0,
|
||||
sparsemode=0,
|
||||
psetype=1):
|
||||
self.headnum = headnum
|
||||
self.layout = layout
|
||||
self.scale = scale
|
||||
self.prob = prob
|
||||
self.pretokens = pretokens
|
||||
self.nexttokens = nexttokens
|
||||
self.innerprecise = innerprecise
|
||||
self.sparsemode = sparsemode
|
||||
self.psetype = psetype
|
||||
self.prefix = prefix
|
||||
self.qstart = qstart
|
||||
self.kvstart = kvstart
|
||||
# def __init__(self,
|
||||
# headnum,
|
||||
# layout="BNSD",
|
||||
# prefix=None,
|
||||
# qstart=None,
|
||||
# kvstart=None,
|
||||
# scale=1.0,
|
||||
# prob=1.0,
|
||||
# pretokens=2147483647,
|
||||
# nexttokens=2147483647,
|
||||
# innerprecise=0,
|
||||
# sparsemode=0,
|
||||
# psetype=1):
|
||||
# self.headnum = headnum
|
||||
# self.layout = layout
|
||||
# self.scale = scale
|
||||
# self.prob = prob
|
||||
# self.pretokens = pretokens
|
||||
# self.nexttokens = nexttokens
|
||||
# self.innerprecise = innerprecise
|
||||
# self.sparsemode = sparsemode
|
||||
# self.psetype = psetype
|
||||
# self.prefix = prefix
|
||||
# self.qstart = qstart
|
||||
# self.kvstart = kvstart
|
||||
|
||||
def execute(
|
||||
self,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
realshift=None,
|
||||
dropMask=None,
|
||||
paddingMask=None,
|
||||
attenMask=None,
|
||||
):
|
||||
if self.layout == 'BSH':
|
||||
B, SQ, H = q.shape
|
||||
SKV = k.shape[1]
|
||||
N = self.headnum
|
||||
D = H / N
|
||||
elif self.layout == 'SBH':
|
||||
SQ, B, H = q.shape
|
||||
SKV = k.shape[0]
|
||||
N = self.headnum
|
||||
D = H / N
|
||||
elif self.layout == 'BSND':
|
||||
B, SQ, N, D = q.shape
|
||||
SKV = k.shape[1]
|
||||
elif self.layout == 'BNSD':
|
||||
B, N, SQ, D = q.shape
|
||||
SKV = k.shape[2]
|
||||
else:
|
||||
raise ValueError(f"got invalid input layout {self.layout}")
|
||||
# def execute(
|
||||
# self,
|
||||
# q,
|
||||
# k,
|
||||
# v,
|
||||
# realshift=None,
|
||||
# dropMask=None,
|
||||
# paddingMask=None,
|
||||
# attenMask=None,
|
||||
# ):
|
||||
# if self.layout == 'BSH':
|
||||
# B, SQ, H = q.shape
|
||||
# SKV = k.shape[1]
|
||||
# N = self.headnum
|
||||
# D = H / N
|
||||
# elif self.layout == 'SBH':
|
||||
# SQ, B, H = q.shape
|
||||
# SKV = k.shape[0]
|
||||
# N = self.headnum
|
||||
# D = H / N
|
||||
# elif self.layout == 'BSND':
|
||||
# B, SQ, N, D = q.shape
|
||||
# SKV = k.shape[1]
|
||||
# elif self.layout == 'BNSD':
|
||||
# B, N, SQ, D = q.shape
|
||||
# SKV = k.shape[2]
|
||||
# else:
|
||||
# raise ValueError(f"got invalid input layout {self.layout}")
|
||||
|
||||
output_shape = (B, N, SQ, 8)
|
||||
# output_shape = (B, N, SQ, 8)
|
||||
|
||||
self.q = q
|
||||
self.k = k
|
||||
self.v = v
|
||||
# self.q = q
|
||||
# self.k = k
|
||||
# self.v = v
|
||||
|
||||
self.prefix = self.prefix if self.prefix else [0 for _ in range(B)]
|
||||
self.qstart = self.qstart if self.qstart else [0 for _ in range(B)]
|
||||
self.kvstart = self.kvstart if self.kvstart else [
|
||||
0 for _ in range(B)
|
||||
]
|
||||
# self.prefix = self.prefix if self.prefix else [0 for _ in range(B)]
|
||||
# self.qstart = self.qstart if self.qstart else [0 for _ in range(B)]
|
||||
# self.kvstart = self.kvstart if self.kvstart else [
|
||||
# 0 for _ in range(B)
|
||||
# ]
|
||||
|
||||
self.hasRealshift = (not realshift == None)
|
||||
self.hasDropmask = (not dropMask == None)
|
||||
self.hasPaddingmask = (not paddingMask == None)
|
||||
self.hasAttenmask = (not attenMask == None)
|
||||
# self.hasRealshift = (not realshift == None)
|
||||
# self.hasDropmask = (not dropMask == None)
|
||||
# self.hasPaddingmask = (not paddingMask == None)
|
||||
# self.hasAttenmask = (not attenMask == None)
|
||||
|
||||
# To be decided; currently set to nullptr
|
||||
self.realshift = realshift if realshift else jt.zeros(
|
||||
B, N, SQ, SKV)
|
||||
self.dropMask = dropMask if dropMask else jt.ones(B, N, SQ, SKV)
|
||||
self.paddingMask = paddingMask if paddingMask else jt.zeros(
|
||||
B, N, SQ, SKV)
|
||||
self.attenMask = attenMask if attenMask else jt.zeros(SQ, SKV)
|
||||
# # To be decided; currently set to nullptr
|
||||
# self.realshift = realshift if realshift else jt.zeros(
|
||||
# B, N, SQ, SKV)
|
||||
# self.dropMask = dropMask if dropMask else jt.ones(B, N, SQ, SKV)
|
||||
# self.paddingMask = paddingMask if paddingMask else jt.zeros(
|
||||
# B, N, SQ, SKV)
|
||||
# self.attenMask = attenMask if attenMask else jt.zeros(SQ, SKV)
|
||||
|
||||
attr_code = f"""
|
||||
op.jt_name = "flashattention";
|
||||
FlashAttentionAttr *attr = new FlashAttentionAttr();
|
||||
attr->scale = {self.scale};
|
||||
attr->keepProb = {self.prob};
|
||||
attr->preToken = {self.pretokens};
|
||||
attr->nextToken = {self.nexttokens};
|
||||
attr->headNum = {self.headnum};
|
||||
attr->inputLayout = "{self.layout}";
|
||||
attr->innerPrecise = {self.innerprecise};
|
||||
attr->sparseMode = {self.sparsemode};
|
||||
attr->psetype = {self.psetype};
|
||||
attr->prefix = {{ {", ".join(map(str, self.prefix))} }};
|
||||
attr->qStartIdx = {{ {", ".join(map(str, self.qstart))} }};
|
||||
attr->kvStartIdx = {{ {", ".join(map(str, self.kvstart))} }};
|
||||
attr->hasRealshift = {"true" if self.hasRealshift else "false"};
|
||||
attr->hasDropmask = {"true" if self.hasDropmask else "false"};
|
||||
attr->hasPaddingmask = {"true" if self.hasPaddingmask else "false"};
|
||||
attr->hasAttentmask = {"true" if self.hasAttenmask else "false"};
|
||||
op.op_attr.reset(attr);
|
||||
"""
|
||||
# attr_code = f"""
|
||||
# op.jt_name = "flashattention";
|
||||
# FlashAttentionAttr *attr = new FlashAttentionAttr();
|
||||
# attr->scale = {self.scale};
|
||||
# attr->keepProb = {self.prob};
|
||||
# attr->preToken = {self.pretokens};
|
||||
# attr->nextToken = {self.nexttokens};
|
||||
# attr->headNum = {self.headnum};
|
||||
# attr->inputLayout = "{self.layout}";
|
||||
# attr->innerPrecise = {self.innerprecise};
|
||||
# attr->sparseMode = {self.sparsemode};
|
||||
# attr->psetype = {self.psetype};
|
||||
# attr->prefix = {{ {", ".join(map(str, self.prefix))} }};
|
||||
# attr->qStartIdx = {{ {", ".join(map(str, self.qstart))} }};
|
||||
# attr->kvStartIdx = {{ {", ".join(map(str, self.kvstart))} }};
|
||||
# attr->hasRealshift = {"true" if self.hasRealshift else "false"};
|
||||
# attr->hasDropmask = {"true" if self.hasDropmask else "false"};
|
||||
# attr->hasPaddingmask = {"true" if self.hasPaddingmask else "false"};
|
||||
# attr->hasAttentmask = {"true" if self.hasAttenmask else "false"};
|
||||
# op.op_attr.reset(attr);
|
||||
# """
|
||||
|
||||
inputs = [
|
||||
q, k, v, self.realshift, self.dropMask, self.paddingMask,
|
||||
self.attenMask
|
||||
]
|
||||
# inputs = [
|
||||
# q, k, v, self.realshift, self.dropMask, self.paddingMask,
|
||||
# self.attenMask
|
||||
# ]
|
||||
|
||||
result = acl_cmd(
|
||||
"FlashAttention",
|
||||
inputs,
|
||||
output_dtypes=["float", "float", q.dtype],
|
||||
output_shapes=[output_shape, output_shape, q.shape],
|
||||
attr_code=attr_code)
|
||||
# result = acl_cmd(
|
||||
# "FlashAttention",
|
||||
# inputs,
|
||||
# output_dtypes=["float", "float", q.dtype],
|
||||
# output_shapes=[output_shape, output_shape, q.shape],
|
||||
# attr_code=attr_code)
|
||||
|
||||
self.maxout = result[0]
|
||||
self.sumout = result[1]
|
||||
self.attenout = result[2]
|
||||
# self.maxout = result[0]
|
||||
# self.sumout = result[1]
|
||||
# self.attenout = result[2]
|
||||
|
||||
return self.attenout
|
||||
# return self.attenout
|
||||
|
||||
def grad(self, dy):
|
||||
attr_code = f"""
|
||||
op.jt_name = "flashattentionbackward";
|
||||
FlashAttentionAttr *attr = new FlashAttentionAttr();
|
||||
attr->scale = {self.scale};
|
||||
attr->keepProb = {self.prob};
|
||||
attr->preToken = {self.pretokens};
|
||||
attr->nextToken = {self.nexttokens};
|
||||
attr->headNum = {self.headnum};
|
||||
attr->inputLayout = "{self.layout}";
|
||||
attr->innerPrecise = {self.innerprecise};
|
||||
attr->sparseMode = {self.sparsemode};
|
||||
attr->psetype = {self.psetype};
|
||||
attr->prefix = {{ {", ".join(map(str, self.prefix))} }};
|
||||
attr->qStartIdx = {{ {", ".join(map(str, self.qstart))} }};
|
||||
attr->kvStartIdx = {{ {", ".join(map(str, self.kvstart))} }};
|
||||
attr->hasRealshift = {"true" if self.hasRealshift else "false"};
|
||||
attr->hasDropmask = {"true" if self.hasDropmask else "false"};
|
||||
attr->hasPaddingmask = {"true" if self.hasPaddingmask else "false"};
|
||||
attr->hasAttentmask = {"true" if self.hasAttenmask else "false"};
|
||||
op.op_attr.reset(attr);
|
||||
"""
|
||||
inputs = [
|
||||
self.q, self.k, self.v, dy, self.realshift, self.dropMask,
|
||||
self.paddingMask, self.attenMask, self.maxout, self.sumout,
|
||||
self.attenout
|
||||
]
|
||||
# def grad(self, dy):
|
||||
# attr_code = f"""
|
||||
# op.jt_name = "flashattentionbackward";
|
||||
# FlashAttentionAttr *attr = new FlashAttentionAttr();
|
||||
# attr->scale = {self.scale};
|
||||
# attr->keepProb = {self.prob};
|
||||
# attr->preToken = {self.pretokens};
|
||||
# attr->nextToken = {self.nexttokens};
|
||||
# attr->headNum = {self.headnum};
|
||||
# attr->inputLayout = "{self.layout}";
|
||||
# attr->innerPrecise = {self.innerprecise};
|
||||
# attr->sparseMode = {self.sparsemode};
|
||||
# attr->psetype = {self.psetype};
|
||||
# attr->prefix = {{ {", ".join(map(str, self.prefix))} }};
|
||||
# attr->qStartIdx = {{ {", ".join(map(str, self.qstart))} }};
|
||||
# attr->kvStartIdx = {{ {", ".join(map(str, self.kvstart))} }};
|
||||
# attr->hasRealshift = {"true" if self.hasRealshift else "false"};
|
||||
# attr->hasDropmask = {"true" if self.hasDropmask else "false"};
|
||||
# attr->hasPaddingmask = {"true" if self.hasPaddingmask else "false"};
|
||||
# attr->hasAttentmask = {"true" if self.hasAttenmask else "false"};
|
||||
# op.op_attr.reset(attr);
|
||||
# """
|
||||
# inputs = [
|
||||
# self.q, self.k, self.v, dy, self.realshift, self.dropMask,
|
||||
# self.paddingMask, self.attenMask, self.maxout, self.sumout,
|
||||
# self.attenout
|
||||
# ]
|
||||
|
||||
result = acl_cmd(
|
||||
"FlashAttentionBackward",
|
||||
inputs,
|
||||
output_dtypes=[self.q.dtype, self.k.dtype, self.v.dtype],
|
||||
output_shapes=[self.q.shape, self.k.shape, self.v.shape],
|
||||
attr_code=attr_code)
|
||||
return result
|
||||
# result = acl_cmd(
|
||||
# "FlashAttentionBackward",
|
||||
# inputs,
|
||||
# output_dtypes=[self.q.dtype, self.k.dtype, self.v.dtype],
|
||||
# output_shapes=[self.q.shape, self.k.shape, self.v.shape],
|
||||
# attr_code=attr_code)
|
||||
# return result
|
||||
|
||||
class ReLUACL(Function):
|
||||
|
||||
|
@ -2112,42 +2114,43 @@ def change_function():
|
|||
def relu(x):
|
||||
return ReLUACL()(x)
|
||||
|
||||
class LeakyReLUACL(Function):
|
||||
# class LeakyReLUACL(Function):
|
||||
|
||||
def __init__(self):
|
||||
super(LeakyReLUACL, self).__init__()
|
||||
# def __init__(self):
|
||||
# super(LeakyReLUACL, self).__init__()
|
||||
|
||||
def execute(self, x, negative_slope=0.01):
|
||||
x = x.float32()
|
||||
self.input = x
|
||||
attr_code = f"""
|
||||
op.jt_name = "leakyrelu";
|
||||
LeakyReluAttr *attr = new LeakyReluAttr();
|
||||
attr->negativeSlope = {negative_slope};
|
||||
op.op_attr.reset(attr);
|
||||
"""
|
||||
result = acl_cmd("LeakyReLU", [x],
|
||||
output_dtypes=[x.dtype],
|
||||
output_shapes=[x.shape],
|
||||
attr_code=attr_code)[0]
|
||||
self.negative_slope = negative_slope
|
||||
return result
|
||||
# def execute(self, x, negative_slope=0.01):
|
||||
# x = x.float32()
|
||||
# self.input = x
|
||||
# attr_code = f"""
|
||||
# op.jt_name = "leakyrelu";
|
||||
# LeakyReluAttr *attr = new LeakyReluAttr();
|
||||
# attr->negativeSlope = {negative_slope};
|
||||
# op.op_attr.reset(attr);
|
||||
# """
|
||||
# result = acl_cmd("LeakyReLU", [x],
|
||||
# output_dtypes=[x.dtype],
|
||||
# output_shapes=[x.shape],
|
||||
# attr_code=attr_code)[0]
|
||||
# self.negative_slope = negative_slope
|
||||
# return result
|
||||
|
||||
def grad(self, grad_output):
|
||||
attr_code = f"""
|
||||
op.jt_name = "leakyrelubackward";
|
||||
LeakyReluAttr *attr = new LeakyReluAttr();
|
||||
attr->negativeSlope = {self.negative_slope};
|
||||
attr->selfIsResult = false;
|
||||
op.op_attr.reset(attr);
|
||||
"""
|
||||
grad_input = acl_cmd("LeakyReLUBackward",
|
||||
[grad_output, self.input],
|
||||
output_dtypes=[grad_output.dtype],
|
||||
output_shapes=[grad_output.shape],
|
||||
attr_code=attr_code)[0]
|
||||
return grad_input
|
||||
# def grad(self, grad_output):
|
||||
# attr_code = f"""
|
||||
# op.jt_name = "leakyrelubackward";
|
||||
# LeakyReluAttr *attr = new LeakyReluAttr();
|
||||
# attr->negativeSlope = {self.negative_slope};
|
||||
# attr->selfIsResult = false;
|
||||
# op.op_attr.reset(attr);
|
||||
# """
|
||||
# grad_input = acl_cmd("LeakyReLUBackward",
|
||||
# [grad_output, self.input],
|
||||
# output_dtypes=[grad_output.dtype],
|
||||
# output_shapes=[grad_output.shape],
|
||||
# attr_code=attr_code)[0]
|
||||
# return grad_input
|
||||
|
||||
from .aclops.relu_op import LeakyReLUACL
|
||||
class LeakyReLU(jt.nn.Module):
|
||||
|
||||
def __init__(self, negative_slope=0.01):
|
||||
|
@ -2160,45 +2163,45 @@ def change_function():
|
|||
def leaky_relu(x, scale=0.01):
|
||||
return LeakyReLUACL()(x, scale)
|
||||
|
||||
class DropoutACL(Function):
|
||||
# class DropoutACL(Function):
|
||||
|
||||
def __init__(self):
|
||||
super(DropoutACL, self).__init__()
|
||||
# def __init__(self):
|
||||
# super(DropoutACL, self).__init__()
|
||||
|
||||
def execute(self, x, p=0.5, is_train=False):
|
||||
self.input = x
|
||||
num_elements = x.numel()
|
||||
aligned_elements = (num_elements + 127) // 128 * 128
|
||||
mask_shape = (aligned_elements // 8, )
|
||||
attr_code = f"""
|
||||
op.jt_name = "dropout";
|
||||
DropoutAttr *attr = new DropoutAttr();
|
||||
attr->p = {p};
|
||||
attr->train = {"true" if is_train else "false"};
|
||||
attr->seed = 0;
|
||||
attr->offset = 0;
|
||||
op.op_attr.reset(attr);
|
||||
"""
|
||||
result = acl_cmd("Dropout", [x],
|
||||
output_dtypes=[x.dtype, "uint8"],
|
||||
output_shapes=[x.shape, mask_shape],
|
||||
attr_code=attr_code)
|
||||
self.maskout = result[1]
|
||||
return result[0]
|
||||
# def execute(self, x, p=0.5, is_train=False):
|
||||
# self.input = x
|
||||
# num_elements = x.numel()
|
||||
# aligned_elements = (num_elements + 127) // 128 * 128
|
||||
# mask_shape = (aligned_elements // 8, )
|
||||
# attr_code = f"""
|
||||
# op.jt_name = "dropout";
|
||||
# DropoutAttr *attr = new DropoutAttr();
|
||||
# attr->p = {p};
|
||||
# attr->train = {"true" if is_train else "false"};
|
||||
# attr->seed = 0;
|
||||
# attr->offset = 0;
|
||||
# op.op_attr.reset(attr);
|
||||
# """
|
||||
# result = acl_cmd("Dropout", [x],
|
||||
# output_dtypes=[x.dtype, "uint8"],
|
||||
# output_shapes=[x.shape, mask_shape],
|
||||
# attr_code=attr_code)
|
||||
# self.maskout = result[1]
|
||||
# return result[0]
|
||||
|
||||
def grad(self, grad_output):
|
||||
attr_code = f"""
|
||||
op.jt_name = "dropoutbackward";
|
||||
DropoutAttr *attr = new DropoutAttr();
|
||||
attr->scale = 1.0;
|
||||
op.op_attr.reset(attr);
|
||||
"""
|
||||
grad_input = acl_cmd("DropoutBackward",
|
||||
[grad_output, self.maskout],
|
||||
output_dtypes=[grad_output.dtype],
|
||||
output_shapes=[grad_output.shape],
|
||||
attr_code=attr_code)[0]
|
||||
return grad_input
|
||||
# def grad(self, grad_output):
|
||||
# attr_code = f"""
|
||||
# op.jt_name = "dropoutbackward";
|
||||
# DropoutAttr *attr = new DropoutAttr();
|
||||
# attr->scale = 1.0;
|
||||
# op.op_attr.reset(attr);
|
||||
# """
|
||||
# grad_input = acl_cmd("DropoutBackward",
|
||||
# [grad_output, self.maskout],
|
||||
# output_dtypes=[grad_output.dtype],
|
||||
# output_shapes=[grad_output.shape],
|
||||
# attr_code=attr_code)[0]
|
||||
# return grad_input
|
||||
|
||||
class SiLUACL(Function):
|
||||
|
||||
|
@ -2342,6 +2345,7 @@ def change_function():
res = embedding_acl(x, self.weight)
return res

from .aclops.dropout_op import DropoutACL
class Dropout(jt.nn.Module):

def __init__(self, p=0.5, is_train=False):
@ -2795,6 +2799,7 @@ def change_function():

# jt.nn.BatchNorm = warp(jt.nn.BatchNorm, BatchNormACL)
# jt.nn.LayerNorm = warp(jt.nn.LayerNorm, LayerNormACL)
from .aclops.flashattention_op import FlashAttentionACL
jt.nn.FlashAttention = warp(jt.nn.FlashAttention, FlashAttentionACL)
jt.isnan = warp(jt.isnan, isnan_acl)
jt.isinf = warp(jt.isinf, isinf_acl)
@ -273,11 +273,11 @@ namespace jittor
ret = it->second.getWorkspaceSizeFuncCast(inputTensors[0], aclDataType(attr->diagonal), outputTensors[0], &workspaceSize, &executor);
break;
}
case 19:
{
ret = it->second.getWorkspaceSizeFuncExpand(inputTensors[0], dim, outputTensors[0], &workspaceSize, &executor);
break;
}
// case 19:
// {
// ret = it->second.getWorkspaceSizeFuncExpand(inputTensors[0], dim, outputTensors[0], &workspaceSize, &executor);
// break;
// }
case 20:
{
auto attr = dynamic_cast<ConvAttr *>(op_attr.get());
@ -442,32 +442,32 @@ namespace jittor
|
|||
// ret = it->second.getWorkspaceSizeFuncRange(start, end, step, outputTensors[0], &workspaceSize, &executor);
|
||||
// break;
|
||||
// }
|
||||
case 38:
|
||||
{
|
||||
auto attr = dynamic_cast<LeakyReluAttr *>(op_attr.get());
|
||||
negativeSlope = aclCreateScalar(&attr->negativeSlope, aclDataType::ACL_FLOAT);
|
||||
ret = it->second.getWorkspaceSizeFuncLeakyRelu(inputTensors[0], negativeSlope, outputTensors[0], &workspaceSize, &executor);
|
||||
break;
|
||||
}
|
||||
case 39:
|
||||
{
|
||||
auto attr = dynamic_cast<LeakyReluAttr *>(op_attr.get());
|
||||
negativeSlope = aclCreateScalar(&attr->negativeSlope, aclDataType::ACL_FLOAT);
|
||||
ret = it->second.getWorkspaceSizeFuncLeakyReluBackward(inputTensors[0], inputTensors[1], negativeSlope, attr->selfIsResult, outputTensors[0], &workspaceSize, &executor);
|
||||
break;
|
||||
}
|
||||
case 40:
|
||||
{
|
||||
auto attr = dynamic_cast<DropoutAttr *>(op_attr.get());
|
||||
ret = it->second.getWorkspaceSizeFuncDropout(inputTensors[0], attr->p, attr->train, attr->seed, attr->offset, outputTensors[0], outputTensors[1], &workspaceSize, &executor);
|
||||
break;
|
||||
}
|
||||
case 41:
|
||||
{
|
||||
auto attr = dynamic_cast<DropoutAttr *>(op_attr.get());
|
||||
ret = it->second.getWorkspaceSizeFuncDropoutBackward(inputTensors[0], inputTensors[1], attr->scale, outputTensors[0], &workspaceSize, &executor);
|
||||
break;
|
||||
}
|
||||
// case 38:
|
||||
// {
|
||||
// auto attr = dynamic_cast<LeakyReluAttr *>(op_attr.get());
|
||||
// negativeSlope = aclCreateScalar(&attr->negativeSlope, aclDataType::ACL_FLOAT);
|
||||
// ret = it->second.getWorkspaceSizeFuncLeakyRelu(inputTensors[0], negativeSlope, outputTensors[0], &workspaceSize, &executor);
|
||||
// break;
|
||||
// }
|
||||
// case 39:
|
||||
// {
|
||||
// auto attr = dynamic_cast<LeakyReluAttr *>(op_attr.get());
|
||||
// negativeSlope = aclCreateScalar(&attr->negativeSlope, aclDataType::ACL_FLOAT);
|
||||
// ret = it->second.getWorkspaceSizeFuncLeakyReluBackward(inputTensors[0], inputTensors[1], negativeSlope, attr->selfIsResult, outputTensors[0], &workspaceSize, &executor);
|
||||
// break;
|
||||
// }
|
||||
// case 40:
|
||||
// {
|
||||
// auto attr = dynamic_cast<DropoutAttr *>(op_attr.get());
|
||||
// ret = it->second.getWorkspaceSizeFuncDropout(inputTensors[0], attr->p, attr->train, attr->seed, attr->offset, outputTensors[0], outputTensors[1], &workspaceSize, &executor);
|
||||
// break;
|
||||
// }
|
||||
// case 41:
|
||||
// {
|
||||
// auto attr = dynamic_cast<DropoutAttr *>(op_attr.get());
|
||||
// ret = it->second.getWorkspaceSizeFuncDropoutBackward(inputTensors[0], inputTensors[1], attr->scale, outputTensors[0], &workspaceSize, &executor);
|
||||
// break;
|
||||
// }
|
||||
case 42:
|
||||
{
|
||||
ret = it->second.getWorkspaceSizeFuncUnaryNonzero(inputTensors[0], outputTensors[0], &workspaceSize, &executor);
|
||||
|
@ -518,26 +518,26 @@ namespace jittor
|
|||
// ret = it->second.getWorkspaceSizeFuncSplitWithSize(inputTensors[0], splitSize, attr->dim, tensorList, &workspaceSize, &executor);
|
||||
// break;
|
||||
// }
|
||||
case 51:
|
||||
{
|
||||
auto attr = dynamic_cast<FlashAttentionAttr *>(op_attr.get());
|
||||
auto prefix = aclCreateIntArray(attr->prefix.data(), attr->prefix.size());
|
||||
auto qstart = aclCreateIntArray(attr->qStartIdx.data(), attr->qStartIdx.size());
|
||||
auto kvstart = aclCreateIntArray(attr->kvStartIdx.data(), attr->kvStartIdx.size());
|
||||
char *layout = const_cast<char *>(attr->inputLayout.data());
|
||||
ret = it->second.getWorkspaceSizeFuncFalshAttention(inputTensors[0], inputTensors[1], inputTensors[2], attr->hasRealshift ? inputTensors[3] : nullptr, attr->hasDropmask ? inputTensors[4] : nullptr, nullptr, attr->hasAttentmask ? inputTensors[6] : nullptr, prefix, qstart, kvstart, attr->scale, attr->keepProb, attr->preToken, attr->nextToken, attr->headNum, layout, attr->innerPrecise, attr->sparseMode, attr->psetype, outputTensors[0], outputTensors[1], nullptr, outputTensors[2], &workspaceSize, &executor);
|
||||
break;
|
||||
}
|
||||
case 52:
|
||||
{
|
||||
auto attr = dynamic_cast<FlashAttentionAttr *>(op_attr.get());
|
||||
auto prefix = aclCreateIntArray(attr->prefix.data(), attr->prefix.size());
|
||||
auto qstart = aclCreateIntArray(attr->qStartIdx.data(), attr->qStartIdx.size());
|
||||
auto kvstart = aclCreateIntArray(attr->kvStartIdx.data(), attr->kvStartIdx.size());
|
||||
char *layout = const_cast<char *>(attr->inputLayout.data());
|
||||
ret = it->second.getWorkspaceSizeFuncFalshAttentionBackward(inputTensors[0], inputTensors[1], inputTensors[2], inputTensors[3], attr->hasRealshift ? inputTensors[4] : nullptr, attr->hasDropmask ? inputTensors[5] : nullptr, nullptr, attr->hasAttentmask ? inputTensors[7] : nullptr, inputTensors[8], inputTensors[9], nullptr, inputTensors[10], prefix, qstart, kvstart, attr->scale, attr->keepProb, attr->preToken, attr->nextToken, attr->headNum, layout, attr->innerPrecise, attr->sparseMode, attr->psetype, outputTensors[0], outputTensors[1], outputTensors[2], nullptr, &workspaceSize, &executor);
|
||||
break;
|
||||
}
|
||||
// case 51:
|
||||
// {
|
||||
// auto attr = dynamic_cast<FlashAttentionAttr *>(op_attr.get());
|
||||
// auto prefix = aclCreateIntArray(attr->prefix.data(), attr->prefix.size());
|
||||
// auto qstart = aclCreateIntArray(attr->qStartIdx.data(), attr->qStartIdx.size());
|
||||
// auto kvstart = aclCreateIntArray(attr->kvStartIdx.data(), attr->kvStartIdx.size());
|
||||
// char *layout = const_cast<char *>(attr->inputLayout.data());
|
||||
// ret = it->second.getWorkspaceSizeFuncFalshAttention(inputTensors[0], inputTensors[1], inputTensors[2], attr->hasRealshift ? inputTensors[3] : nullptr, attr->hasDropmask ? inputTensors[4] : nullptr, nullptr, attr->hasAttentmask ? inputTensors[6] : nullptr, prefix, qstart, kvstart, attr->scale, attr->keepProb, attr->preToken, attr->nextToken, attr->headNum, layout, attr->innerPrecise, attr->sparseMode, attr->psetype, outputTensors[0], outputTensors[1], nullptr, outputTensors[2], &workspaceSize, &executor);
|
||||
// break;
|
||||
// }
|
||||
// case 52:
|
||||
// {
|
||||
// auto attr = dynamic_cast<FlashAttentionAttr *>(op_attr.get());
|
||||
// auto prefix = aclCreateIntArray(attr->prefix.data(), attr->prefix.size());
|
||||
// auto qstart = aclCreateIntArray(attr->qStartIdx.data(), attr->qStartIdx.size());
|
||||
// auto kvstart = aclCreateIntArray(attr->kvStartIdx.data(), attr->kvStartIdx.size());
|
||||
// char *layout = const_cast<char *>(attr->inputLayout.data());
|
||||
// ret = it->second.getWorkspaceSizeFuncFalshAttentionBackward(inputTensors[0], inputTensors[1], inputTensors[2], inputTensors[3], attr->hasRealshift ? inputTensors[4] : nullptr, attr->hasDropmask ? inputTensors[5] : nullptr, nullptr, attr->hasAttentmask ? inputTensors[7] : nullptr, inputTensors[8], inputTensors[9], nullptr, inputTensors[10], prefix, qstart, kvstart, attr->scale, attr->keepProb, attr->preToken, attr->nextToken, attr->headNum, layout, attr->innerPrecise, attr->sparseMode, attr->psetype, outputTensors[0], outputTensors[1], outputTensors[2], nullptr, &workspaceSize, &executor);
|
||||
// break;
|
||||
// }
|
||||
case 53:
|
||||
{
|
||||
auto attr = dynamic_cast<SoftmaxAttr *>(op_attr.get());
|
||||
|
|
|
@ -18,3 +18,7 @@
#include <acl/aclops/index_op_acl.h>
#include <acl/aclops/where_op_acl.h>
#include <acl/aclops/floor_op_acl.h>
#include <acl/aclops/transpose_op_acl.h>
#include <acl/aclops/flashattention_op_acl.h>
#include <acl/aclops/relu_op_acl.h>
#include <acl/aclops/dropout_op_acl.h>
@ -0,0 +1,92 @@
|
|||
import os
|
||||
from jittor_utils import env_or_try_find
|
||||
import jittor_utils
|
||||
import ctypes
|
||||
import glob
|
||||
import jittor.compiler as compiler
|
||||
import jittor as jt
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
from typing import Union
|
||||
from collections.abc import Sequence, Iterable
|
||||
|
||||
def dropout_cmd(name: str,
|
||||
inputs: list,
|
||||
output_dtypes: list = None,
|
||||
output_shapes: list = None,
|
||||
attr_code: str = "",
|
||||
attr_header: str = "",
|
||||
outputs: list = None):
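# Shared helper: wraps an ACL op invocation in jt.code. Each input/output is
# registered on the generated {name}OpRunner via op.add(...), attr_code sets the
# op's attributes, and op.run() dispatches the ACL kernel.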
|
||||
attr_header = "\nnamespace jittor{" + attr_header + "}\n"
|
||||
|
||||
cuda_header = '''
|
||||
#include "acl/aclops/aclops.h"
|
||||
'''
|
||||
outputs_ = []
|
||||
if outputs is not None:
|
||||
outputs_ = outputs
|
||||
else:
|
||||
assert output_dtypes is not None
|
||||
assert output_shapes is not None
|
||||
assert len(output_dtypes) == len(output_shapes)
|
||||
for i in range(len(output_shapes)):
|
||||
outputs_.append(jt.empty(output_shapes[i], output_dtypes[i]))
|
||||
input_code = ''
|
||||
for i in range(len(inputs)):
|
||||
input_code += f"op.add(in{i}, true);\n"
|
||||
|
||||
output_code = ''
|
||||
for i in range(len(outputs_)):
|
||||
output_code += f"op.add(out{i}, false);\n"
|
||||
return jt.code(outputs=outputs_,
|
||||
inputs=inputs,
|
||||
cuda_header=attr_header + cuda_header,
|
||||
cuda_src=f"""
|
||||
|
||||
// aclop
|
||||
{name}OpRunner op;
|
||||
{input_code}
|
||||
{output_code}
|
||||
{attr_code}
|
||||
op.run();""")
|
||||
|
||||
class DropoutACL(jt.Function):
|
||||
|
||||
def __init__(self):
|
||||
super(DropoutACL, self).__init__()
|
||||
|
||||
def execute(self, x, p=0.5, is_train=False):
|
||||
self.input = x
|
||||
num_elements = x.numel()
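# Mask buffer sizing: round the element count up to a multiple of 128, then
# allocate aligned_elements // 8 uint8 bytes (presumably one mask bit per element).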
|
||||
aligned_elements = (num_elements + 127) // 128 * 128
|
||||
mask_shape = (aligned_elements // 8, )
|
||||
attr_code = f"""
|
||||
op.jt_name = "dropout";
|
||||
DropoutAttr *attr = new DropoutAttr();
|
||||
attr->p = {p};
|
||||
attr->train = {"true" if is_train else "false"};
|
||||
attr->seed = 0;
|
||||
attr->offset = 0;
|
||||
op.op_attr.reset(attr);
|
||||
"""
|
||||
result = dropout_cmd("Dropout", [x],
|
||||
output_dtypes=[x.dtype, "uint8"],
|
||||
output_shapes=[x.shape, mask_shape],
|
||||
attr_code=attr_code)
|
||||
self.maskout = result[1]
|
||||
return result[0]
|
||||
|
||||
def grad(self, grad_output):
|
||||
attr_code = f"""
|
||||
op.jt_name = "dropoutbackward";
|
||||
DropoutAttr *attr = new DropoutAttr();
|
||||
attr->scale = 1.0;
|
||||
op.op_attr.reset(attr);
|
||||
"""
|
||||
grad_input = dropout_cmd("DropoutBackward",
|
||||
[grad_output, self.maskout],
|
||||
output_dtypes=[grad_output.dtype],
|
||||
output_shapes=[grad_output.shape],
|
||||
attr_code=attr_code)[0]
|
||||
return grad_input
|
|
@ -0,0 +1,82 @@
|
|||
#pragma once
|
||||
#include <acl/acl.h>
|
||||
#include <acl/acl_op_compiler.h>
|
||||
#include <Python.h>
|
||||
#include <pystate.h>
|
||||
#include <algorithm>
|
||||
#include <queue>
|
||||
#include <set>
|
||||
#include "common.h"
|
||||
#include "op.h"
|
||||
#include "acl_jittor.h"
|
||||
#include "ops/random_op.h"
|
||||
#include "ops/reduce_op.h"
|
||||
#include "ops/binary_op.h"
|
||||
#include "ops/broadcast_to_op.h"
|
||||
#include "ops/transpose_op.h"
|
||||
#include "ops/array_op.h"
|
||||
#include "ops/code_op.h"
|
||||
#include "fused_op.h"
|
||||
#include "ops/unary_op.h"
|
||||
#include "ops/ternary_op.h"
|
||||
#include "executor.h"
|
||||
#include "misc/cuda_flags.h"
|
||||
#include "mem/allocator.h"
|
||||
#include "op_compiler.h"
|
||||
#include "ops/op_register.h"
|
||||
#include "opt/tuner_manager.h"
|
||||
#include "utils/str_utils.h"
|
||||
#include "aclnn/aclnn.h"
|
||||
#include "dropout_op_acl.h"
|
||||
|
||||
namespace jittor
|
||||
{
|
||||
DropoutOpRunner::DropoutOpRunner() : BaseOpRunner("Dropout")
|
||||
{
|
||||
}
|
||||
|
||||
void DropoutOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
|
||||
{
|
||||
auto attr = dynamic_cast<DropoutAttr *>(op_attr.get());
|
||||
ret = aclnnDropoutGetWorkspaceSize(inputTensors[0], attr->p, attr->train, attr->seed, attr->offset, outputTensors[0], outputTensors[1], &workspaceSize, &executor);
|
||||
|
||||
checkRet(ret);
|
||||
|
||||
if (workspaceSize > 0)
|
||||
{
|
||||
mallocWorkSpace(workspaceSize);
|
||||
}
|
||||
|
||||
ret = aclnnDropout(workspaceAddr, workspaceSize, executor, aclstream);
|
||||
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnDropout failed. ERROR: %d\n", name.c_str(), ret); return);
|
||||
|
||||
syncRun();
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
DropoutBackwardOpRunner::DropoutBackwardOpRunner() : BaseOpRunner("DropoutBackward")
|
||||
{
|
||||
}
|
||||
|
||||
void DropoutBackwardOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
|
||||
{
|
||||
auto attr = dynamic_cast<DropoutAttr *>(op_attr.get());
|
||||
ret = aclnnDropoutBackwardGetWorkspaceSize(inputTensors[0], inputTensors[1], attr->scale, outputTensors[0], &workspaceSize, &executor);
|
||||
|
||||
checkRet(ret);
|
||||
|
||||
if (workspaceSize > 0)
|
||||
{
|
||||
mallocWorkSpace(workspaceSize);
|
||||
}
|
||||
|
||||
ret = aclnnDropoutBackward(workspaceAddr, workspaceSize, executor, aclstream);
|
||||
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnDropoutBackward failed. ERROR: %d\n", name.c_str(), ret); return);
|
||||
|
||||
syncRun();
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,25 @@
#pragma once
#include "utils.h"
#include "base_op.h"

namespace jittor
{
class DropoutOpRunner : public BaseOpRunner
{

protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
DropoutOpRunner();
};

class DropoutBackwardOpRunner : public BaseOpRunner
{

protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
DropoutBackwardOpRunner();
};

}
@ -0,0 +1,210 @@
|
|||
import os
|
||||
from jittor_utils import env_or_try_find
|
||||
import jittor_utils
|
||||
import ctypes
|
||||
import glob
|
||||
import jittor.compiler as compiler
|
||||
import jittor as jt
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
from typing import Union
|
||||
from collections.abc import Sequence, Iterable
|
||||
|
||||
def flashattention_cmd(name: str,
|
||||
inputs: list,
|
||||
output_dtypes: list = None,
|
||||
output_shapes: list = None,
|
||||
attr_code: str = "",
|
||||
attr_header: str = "",
|
||||
outputs: list = None):
|
||||
attr_header = "\nnamespace jittor{" + attr_header + "}\n"
|
||||
|
||||
cuda_header = '''
|
||||
#include "acl/aclops/aclops.h"
|
||||
'''
|
||||
outputs_ = []
|
||||
if outputs is not None:
|
||||
outputs_ = outputs
|
||||
else:
|
||||
assert output_dtypes is not None
|
||||
assert output_shapes is not None
|
||||
assert len(output_dtypes) == len(output_shapes)
|
||||
for i in range(len(output_shapes)):
|
||||
outputs_.append(jt.empty(output_shapes[i], output_dtypes[i]))
|
||||
input_code = ''
|
||||
for i in range(len(inputs)):
|
||||
input_code += f"op.add(in{i}, true);\n"
|
||||
|
||||
output_code = ''
|
||||
for i in range(len(outputs_)):
|
||||
output_code += f"op.add(out{i}, false);\n"
|
||||
return jt.code(outputs=outputs_,
|
||||
inputs=inputs,
|
||||
cuda_header=attr_header + cuda_header,
|
||||
cuda_src=f"""
|
||||
|
||||
// aclop
|
||||
{name}OpRunner op;
|
||||
{input_code}
|
||||
{output_code}
|
||||
{attr_code}
|
||||
op.run();""")
|
||||
|
||||
class FlashAttentionACL(jt.Function):
|
||||
|
||||
def __init__(self,
|
||||
headnum,
|
||||
layout="BNSD",
|
||||
prefix=None,
|
||||
qstart=None,
|
||||
kvstart=None,
|
||||
scale=1.0,
|
||||
prob=1.0,
|
||||
pretokens=2147483647,
|
||||
nexttokens=2147483647,
|
||||
innerprecise=0,
|
||||
sparsemode=0,
|
||||
psetype=1):
|
||||
self.headnum = headnum
|
||||
self.layout = layout
|
||||
self.scale = scale
|
||||
self.prob = prob
|
||||
self.pretokens = pretokens
|
||||
self.nexttokens = nexttokens
|
||||
self.innerprecise = innerprecise
|
||||
self.sparsemode = sparsemode
|
||||
self.psetype = psetype
|
||||
self.prefix = prefix
|
||||
self.qstart = qstart
|
||||
self.kvstart = kvstart
|
||||
|
||||
def execute(
|
||||
self,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
realshift=None,
|
||||
dropMask=None,
|
||||
paddingMask=None,
|
||||
attenMask=None,
|
||||
):
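# Derive batch size B, query length SQ, key/value length SKV, head count N and
# head dim D from the declared input layout.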
|
||||
if self.layout == 'BSH':
|
||||
B, SQ, H = q.shape
|
||||
SKV = k.shape[1]
|
||||
N = self.headnum
|
||||
D = H / N
|
||||
elif self.layout == 'SBH':
|
||||
SQ, B, H = q.shape
|
||||
SKV = k.shape[0]
|
||||
N = self.headnum
|
||||
D = H / N
|
||||
elif self.layout == 'BSND':
|
||||
B, SQ, N, D = q.shape
|
||||
SKV = k.shape[1]
|
||||
elif self.layout == 'BNSD':
|
||||
B, N, SQ, D = q.shape
|
||||
SKV = k.shape[2]
|
||||
else:
|
||||
raise ValueError(f"got invalid input layout {self.layout}")
|
||||
|
||||
output_shape = (B, N, SQ, 8)
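# Shape of the softmax max/sum statistics returned by the kernel; the trailing 8
# appears to be an alignment requirement of the ACL flash-attention op.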
|
||||
|
||||
self.q = q
|
||||
self.k = k
|
||||
self.v = v
|
||||
|
||||
self.prefix = self.prefix if self.prefix else [0 for _ in range(B)]
|
||||
self.qstart = self.qstart if self.qstart else [0 for _ in range(B)]
|
||||
self.kvstart = self.kvstart if self.kvstart else [
|
||||
0 for _ in range(B)
|
||||
]
|
||||
|
||||
self.hasRealshift = (not realshift == None)
|
||||
self.hasDropmask = (not dropMask == None)
|
||||
self.hasPaddingmask = (not paddingMask == None)
|
||||
self.hasAttenmask = (not attenMask == None)
|
||||
|
||||
# To be decided; currently set to nullptr
|
||||
self.realshift = realshift if realshift else jt.zeros(
|
||||
B, N, SQ, SKV)
|
||||
self.dropMask = dropMask if dropMask else jt.ones(B, N, SQ, SKV)
|
||||
self.paddingMask = paddingMask if paddingMask else jt.zeros(
|
||||
B, N, SQ, SKV)
|
||||
self.attenMask = attenMask if attenMask else jt.zeros(SQ, SKV)
|
||||
|
||||
attr_code = f"""
|
||||
op.jt_name = "flashattention";
|
||||
FlashAttentionAttr *attr = new FlashAttentionAttr();
|
||||
attr->scale = {self.scale};
|
||||
attr->keepProb = {self.prob};
|
||||
attr->preToken = {self.pretokens};
|
||||
attr->nextToken = {self.nexttokens};
|
||||
attr->headNum = {self.headnum};
|
||||
attr->inputLayout = "{self.layout}";
|
||||
attr->innerPrecise = {self.innerprecise};
|
||||
attr->sparseMode = {self.sparsemode};
|
||||
attr->psetype = {self.psetype};
|
||||
attr->prefix = {{ {", ".join(map(str, self.prefix))} }};
|
||||
attr->qStartIdx = {{ {", ".join(map(str, self.qstart))} }};
|
||||
attr->kvStartIdx = {{ {", ".join(map(str, self.kvstart))} }};
|
||||
attr->hasRealshift = {"true" if self.hasRealshift else "false"};
|
||||
attr->hasDropmask = {"true" if self.hasDropmask else "false"};
|
||||
attr->hasPaddingmask = {"true" if self.hasPaddingmask else "false"};
|
||||
attr->hasAttentmask = {"true" if self.hasAttenmask else "false"};
|
||||
op.op_attr.reset(attr);
|
||||
"""
|
||||
|
||||
inputs = [
|
||||
q, k, v, self.realshift, self.dropMask, self.paddingMask,
|
||||
self.attenMask
|
||||
]
|
||||
|
||||
result = flashattention_cmd(
|
||||
"FlashAttention",
|
||||
inputs,
|
||||
output_dtypes=["float", "float", q.dtype],
|
||||
output_shapes=[output_shape, output_shape, q.shape],
|
||||
attr_code=attr_code)
|
||||
|
||||
self.maxout = result[0]
|
||||
self.sumout = result[1]
|
||||
self.attenout = result[2]
|
||||
|
||||
return self.attenout
|
||||
|
||||
def grad(self, dy):
|
||||
attr_code = f"""
|
||||
op.jt_name = "flashattentionbackward";
|
||||
FlashAttentionAttr *attr = new FlashAttentionAttr();
|
||||
attr->scale = {self.scale};
|
||||
attr->keepProb = {self.prob};
|
||||
attr->preToken = {self.pretokens};
|
||||
attr->nextToken = {self.nexttokens};
|
||||
attr->headNum = {self.headnum};
|
||||
attr->inputLayout = "{self.layout}";
|
||||
attr->innerPrecise = {self.innerprecise};
|
||||
attr->sparseMode = {self.sparsemode};
|
||||
attr->psetype = {self.psetype};
|
||||
attr->prefix = {{ {", ".join(map(str, self.prefix))} }};
|
||||
attr->qStartIdx = {{ {", ".join(map(str, self.qstart))} }};
|
||||
attr->kvStartIdx = {{ {", ".join(map(str, self.kvstart))} }};
|
||||
attr->hasRealshift = {"true" if self.hasRealshift else "false"};
|
||||
attr->hasDropmask = {"true" if self.hasDropmask else "false"};
|
||||
attr->hasPaddingmask = {"true" if self.hasPaddingmask else "false"};
|
||||
attr->hasAttentmask = {"true" if self.hasAttenmask else "false"};
|
||||
op.op_attr.reset(attr);
|
||||
"""
|
||||
inputs = [
|
||||
self.q, self.k, self.v, dy, self.realshift, self.dropMask,
|
||||
self.paddingMask, self.attenMask, self.maxout, self.sumout,
|
||||
self.attenout
|
||||
]
|
||||
|
||||
result = flashattention_cmd(
|
||||
"FlashAttentionBackward",
|
||||
inputs,
|
||||
output_dtypes=[self.q.dtype, self.k.dtype, self.v.dtype],
|
||||
output_shapes=[self.q.shape, self.k.shape, self.v.shape],
|
||||
attr_code=attr_code)
|
||||
return result
|
|
@ -0,0 +1,89 @@
|
|||
#pragma once
|
||||
#include <acl/acl.h>
|
||||
#include <acl/acl_op_compiler.h>
|
||||
#include <Python.h>
|
||||
#include <pystate.h>
|
||||
#include <algorithm>
|
||||
#include <queue>
|
||||
#include <set>
|
||||
#include "common.h"
|
||||
#include "op.h"
|
||||
#include "acl_jittor.h"
|
||||
#include "ops/random_op.h"
|
||||
#include "ops/reduce_op.h"
|
||||
#include "ops/binary_op.h"
|
||||
#include "ops/broadcast_to_op.h"
|
||||
#include "ops/transpose_op.h"
|
||||
#include "ops/array_op.h"
|
||||
#include "ops/code_op.h"
|
||||
#include "fused_op.h"
|
||||
#include "ops/unary_op.h"
|
||||
#include "ops/ternary_op.h"
|
||||
#include "executor.h"
|
||||
#include "misc/cuda_flags.h"
|
||||
#include "mem/allocator.h"
|
||||
#include "op_compiler.h"
|
||||
#include "ops/op_register.h"
|
||||
#include "opt/tuner_manager.h"
|
||||
#include "utils/str_utils.h"
|
||||
#include "aclnn/aclnn.h"
|
||||
#include "flashattention_op_acl.h"
|
||||
|
||||
namespace jittor
|
||||
{
|
||||
FlashAttentionOpRunner::FlashAttentionOpRunner() : BaseOpRunner("FlashAttention")
|
||||
{
|
||||
}
|
||||
|
||||
void FlashAttentionOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
|
||||
{
|
||||
auto attr = dynamic_cast<FlashAttentionAttr *>(op_attr.get());
|
||||
auto prefix = aclCreateIntArray(attr->prefix.data(), attr->prefix.size());
|
||||
auto qstart = aclCreateIntArray(attr->qStartIdx.data(), attr->qStartIdx.size());
|
||||
auto kvstart = aclCreateIntArray(attr->kvStartIdx.data(), attr->kvStartIdx.size());
|
||||
char *layout = const_cast<char *>(attr->inputLayout.data());
|
||||
ret = aclnnFlashAttentionScoreV2GetWorkspaceSize(inputTensors[0], inputTensors[1], inputTensors[2], attr->hasRealshift ? inputTensors[3] : nullptr, attr->hasDropmask ? inputTensors[4] : nullptr, nullptr, attr->hasAttentmask ? inputTensors[6] : nullptr, prefix, qstart, kvstart, attr->scale, attr->keepProb, attr->preToken, attr->nextToken, attr->headNum, layout, attr->innerPrecise, attr->sparseMode, attr->psetype, outputTensors[0], outputTensors[1], nullptr, outputTensors[2], &workspaceSize, &executor);
|
||||
|
||||
checkRet(ret);
|
||||
|
||||
if (workspaceSize > 0)
|
||||
{
|
||||
mallocWorkSpace(workspaceSize);
|
||||
}
|
||||
|
||||
ret = aclnnFlashAttentionScoreV2(workspaceAddr, workspaceSize, executor, aclstream);
|
||||
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnFlashAttentionScoreV2 failed. ERROR: %d\n", name.c_str(), ret); return);
|
||||
|
||||
syncRun();
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
FlashAttentionBackwardOpRunner::FlashAttentionBackwardOpRunner() : BaseOpRunner("FlashAttentionBackward")
|
||||
{
|
||||
}
|
||||
|
||||
void FlashAttentionBackwardOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
|
||||
{
|
||||
auto attr = dynamic_cast<FlashAttentionAttr *>(op_attr.get());
|
||||
auto prefix = aclCreateIntArray(attr->prefix.data(), attr->prefix.size());
|
||||
auto qstart = aclCreateIntArray(attr->qStartIdx.data(), attr->qStartIdx.size());
|
||||
auto kvstart = aclCreateIntArray(attr->kvStartIdx.data(), attr->kvStartIdx.size());
|
||||
char *layout = const_cast<char *>(attr->inputLayout.data());
|
||||
ret = aclnnFlashAttentionScoreGradV2GetWorkspaceSize(inputTensors[0], inputTensors[1], inputTensors[2], inputTensors[3], attr->hasRealshift ? inputTensors[4] : nullptr, attr->hasDropmask ? inputTensors[5] : nullptr, nullptr, attr->hasAttentmask ? inputTensors[7] : nullptr, inputTensors[8], inputTensors[9], nullptr, inputTensors[10], prefix, qstart, kvstart, attr->scale, attr->keepProb, attr->preToken, attr->nextToken, attr->headNum, layout, attr->innerPrecise, attr->sparseMode, attr->psetype, outputTensors[0], outputTensors[1], outputTensors[2], nullptr, &workspaceSize, &executor);
|
||||
|
||||
checkRet(ret);
|
||||
|
||||
if (workspaceSize > 0)
|
||||
{
|
||||
mallocWorkSpace(workspaceSize);
|
||||
}
|
||||
|
||||
ret = aclnnFlashAttentionScoreGradV2(workspaceAddr, workspaceSize, executor, aclstream);
|
||||
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnFlashAttentionScoreGradV2 failed. ERROR: %d\n", name.c_str(), ret); return);
|
||||
|
||||
syncRun();
|
||||
return;
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,25 @@
#pragma once
#include "utils.h"
#include "base_op.h"

namespace jittor
{
class FlashAttentionOpRunner : public BaseOpRunner
{

protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
FlashAttentionOpRunner();
};

class FlashAttentionBackwardOpRunner : public BaseOpRunner
{

protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
FlashAttentionBackwardOpRunner();
};

}
@ -0,0 +1,115 @@
|
|||
import os
|
||||
from jittor_utils import env_or_try_find
|
||||
import jittor_utils
|
||||
import ctypes
|
||||
import glob
|
||||
import jittor.compiler as compiler
|
||||
import jittor as jt
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
from typing import Union
|
||||
from collections.abc import Sequence, Iterable
|
||||
|
||||
def relu_cmd(name: str,
|
||||
inputs: list,
|
||||
output_dtypes: list = None,
|
||||
output_shapes: list = None,
|
||||
attr_code: str = "",
|
||||
attr_header: str = "",
|
||||
outputs: list = None):
|
||||
attr_header = "\nnamespace jittor{" + attr_header + "}\n"
|
||||
|
||||
cuda_header = '''
|
||||
#include "acl/aclops/aclops.h"
|
||||
'''
|
||||
outputs_ = []
|
||||
if outputs is not None:
|
||||
outputs_ = outputs
|
||||
else:
|
||||
assert output_dtypes is not None
|
||||
assert output_shapes is not None
|
||||
assert len(output_dtypes) == len(output_shapes)
|
||||
for i in range(len(output_shapes)):
|
||||
outputs_.append(jt.empty(output_shapes[i], output_dtypes[i]))
|
||||
input_code = ''
|
||||
for i in range(len(inputs)):
|
||||
input_code += f"op.add(in{i}, true);\n"
|
||||
|
||||
output_code = ''
|
||||
for i in range(len(outputs_)):
|
||||
output_code += f"op.add(out{i}, false);\n"
|
||||
return jt.code(outputs=outputs_,
|
||||
inputs=inputs,
|
||||
cuda_header=attr_header + cuda_header,
|
||||
cuda_src=f"""
|
||||
|
||||
// aclop
|
||||
{name}OpRunner op;
|
||||
{input_code}
|
||||
{output_code}
|
||||
{attr_code}
|
||||
op.run();""")
|
||||
|
||||
class ReLUACL(jt.Function):
|
||||
|
||||
def __init__(self):
|
||||
super(ReLUACL, self).__init__()
|
||||
|
||||
def execute(self, x):
|
||||
x = x.float32()
|
||||
self.input = x
|
||||
result = relu_cmd("ReLU", [x],
|
||||
output_dtypes=[x.dtype],
|
||||
output_shapes=[x.shape],
|
||||
attr_code="op.jt_name=\"unary\";")[0]
|
||||
return result
|
||||
|
||||
def grad(self, grad_output):
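# ReLU backward is composed from two element-wise ACL ops: a Greater mask
# (input > 0) multiplied with the incoming gradient.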
|
||||
mask = relu_cmd("Greater",
|
||||
[self.input, jt.zeros(self.input.shape)],
|
||||
output_dtypes=[self.input.dtype],
|
||||
output_shapes=[self.input.shape],
|
||||
attr_code="op.jt_name=\"binary\";")[0]
|
||||
grad_input = relu_cmd("Mul", [grad_output, mask],
|
||||
output_dtypes=[grad_output.dtype],
|
||||
output_shapes=[grad_output.shape],
|
||||
attr_code="op.jt_name=\"binary\";")[0]
|
||||
return grad_input
|
||||
|
||||
|
||||
class LeakyReLUACL(jt.Function):
|
||||
|
||||
def __init__(self):
|
||||
super(LeakyReLUACL, self).__init__()
|
||||
|
||||
def execute(self, x, negative_slope=0.01):
|
||||
x = x.float32()
|
||||
self.input = x
|
||||
attr_code = f"""
|
||||
op.jt_name = "leakyrelu";
|
||||
LeakyReluAttr *attr = new LeakyReluAttr();
|
||||
attr->negativeSlope = {negative_slope};
|
||||
op.op_attr.reset(attr);
|
||||
"""
|
||||
result = relu_cmd("LeakyReLU", [x],
|
||||
output_dtypes=[x.dtype],
|
||||
output_shapes=[x.shape],
|
||||
attr_code=attr_code)[0]
|
||||
self.negative_slope = negative_slope
|
||||
return result
|
||||
|
||||
def grad(self, grad_output):
|
||||
attr_code = f"""
|
||||
op.jt_name = "leakyrelubackward";
|
||||
LeakyReluAttr *attr = new LeakyReluAttr();
|
||||
attr->negativeSlope = {self.negative_slope};
|
||||
attr->selfIsResult = false;
|
||||
op.op_attr.reset(attr);
|
||||
"""
|
||||
grad_input = relu_cmd("LeakyReLUBackward",
|
||||
[grad_output, self.input],
|
||||
output_dtypes=[grad_output.dtype],
|
||||
output_shapes=[grad_output.shape],
|
||||
attr_code=attr_code)[0]
|
||||
return grad_input
|
|
@ -0,0 +1,90 @@
|
|||
#pragma once
|
||||
#include <acl/acl.h>
|
||||
#include <acl/acl_op_compiler.h>
|
||||
#include <Python.h>
|
||||
#include <pystate.h>
|
||||
#include <algorithm>
|
||||
#include <queue>
|
||||
#include <set>
|
||||
#include "common.h"
|
||||
#include "op.h"
|
||||
#include "acl_jittor.h"
|
||||
#include "ops/random_op.h"
|
||||
#include "ops/reduce_op.h"
|
||||
#include "ops/binary_op.h"
|
||||
#include "ops/broadcast_to_op.h"
|
||||
#include "ops/transpose_op.h"
|
||||
#include "ops/array_op.h"
|
||||
#include "ops/code_op.h"
|
||||
#include "fused_op.h"
|
||||
#include "ops/unary_op.h"
|
||||
#include "ops/ternary_op.h"
|
||||
#include "executor.h"
|
||||
#include "misc/cuda_flags.h"
|
||||
#include "mem/allocator.h"
|
||||
#include "op_compiler.h"
|
||||
#include "ops/op_register.h"
|
||||
#include "opt/tuner_manager.h"
|
||||
#include "utils/str_utils.h"
|
||||
#include "aclnn/aclnn.h"
|
||||
#include "relu_op_acl.h"
|
||||
|
||||
namespace jittor
|
||||
{
|
||||
LeakyReLUOpRunner::LeakyReLUOpRunner() : BaseOpRunner("LeakyReLU")
|
||||
{
|
||||
}
|
||||
|
||||
void LeakyReLUOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
|
||||
{
|
||||
aclScalar *negativeSlope = nullptr;
|
||||
|
||||
auto attr = dynamic_cast<LeakyReluAttr *>(op_attr.get());
|
||||
negativeSlope = aclCreateScalar(&attr->negativeSlope, aclDataType::ACL_FLOAT);
|
||||
ret = aclnnLeakyReluGetWorkspaceSize(inputTensors[0], negativeSlope, outputTensors[0], &workspaceSize, &executor);
|
||||
|
||||
checkRet(ret);
|
||||
|
||||
if (workspaceSize > 0)
|
||||
{
|
||||
mallocWorkSpace(workspaceSize);
|
||||
}
|
||||
|
||||
ret = aclnnLeakyRelu(workspaceAddr, workspaceSize, executor, aclstream);
|
||||
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnLeakyRelu failed. ERROR: %d\n", name.c_str(), ret); return);
|
||||
|
||||
syncRun();
|
||||
|
||||
aclDestroyScalar(negativeSlope);
|
||||
return;
|
||||
}
|
||||
|
||||
LeakyReLUBackwardOpRunner::LeakyReLUBackwardOpRunner() : BaseOpRunner("LeakyReLUBackward")
|
||||
{
|
||||
}
|
||||
|
||||
void LeakyReLUBackwardOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
|
||||
{
|
||||
aclScalar *negativeSlope = nullptr;
|
||||
|
||||
auto attr = dynamic_cast<LeakyReluAttr *>(op_attr.get());
|
||||
negativeSlope = aclCreateScalar(&attr->negativeSlope, aclDataType::ACL_FLOAT);
|
||||
ret = aclnnLeakyReluBackwardGetWorkspaceSize(inputTensors[0], inputTensors[1], negativeSlope, attr->selfIsResult, outputTensors[0], &workspaceSize, &executor);
|
||||
|
||||
checkRet(ret);
|
||||
|
||||
if (workspaceSize > 0)
|
||||
{
|
||||
mallocWorkSpace(workspaceSize);
|
||||
}
|
||||
|
||||
ret = aclnnLeakyReluBackward(workspaceAddr, workspaceSize, executor, aclstream);
|
||||
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnLeakyReluBackward failed. ERROR: %d\n", name.c_str(), ret); return);
|
||||
|
||||
syncRun();
|
||||
|
||||
aclDestroyScalar(negativeSlope);
|
||||
return;
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,25 @@
#pragma once
#include "utils.h"
#include "base_op.h"

namespace jittor
{
class LeakyReLUOpRunner : public BaseOpRunner
{

protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
LeakyReLUOpRunner();
};

class LeakyReLUBackwardOpRunner : public BaseOpRunner
{

protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
LeakyReLUBackwardOpRunner();
};

}
@ -0,0 +1,99 @@
|
|||
import os
|
||||
from jittor_utils import env_or_try_find
|
||||
import jittor_utils
|
||||
import ctypes
|
||||
import glob
|
||||
import jittor.compiler as compiler
|
||||
import jittor as jt
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
from typing import Union
|
||||
from collections.abc import Sequence, Iterable
|
||||
|
||||
def transpose_cmd(name: str,
|
||||
inputs: list,
|
||||
output_dtypes: list = None,
|
||||
output_shapes: list = None,
|
||||
attr_code: str = "",
|
||||
attr_header: str = "",
|
||||
outputs: list = None):
|
||||
attr_header = "\nnamespace jittor{" + attr_header + "}\n"
|
||||
|
||||
cuda_header = '''
|
||||
#include "acl/aclops/aclops.h"
|
||||
'''
|
||||
outputs_ = []
|
||||
if outputs is not None:
|
||||
outputs_ = outputs
|
||||
else:
|
||||
assert output_dtypes is not None
|
||||
assert output_shapes is not None
|
||||
assert len(output_dtypes) == len(output_shapes)
|
||||
for i in range(len(output_shapes)):
|
||||
outputs_.append(jt.empty(output_shapes[i], output_dtypes[i]))
|
||||
input_code = ''
|
||||
for i in range(len(inputs)):
|
||||
input_code += f"op.add(in{i}, true);\n"
|
||||
|
||||
output_code = ''
|
||||
for i in range(len(outputs_)):
|
||||
output_code += f"op.add(out{i}, false);\n"
|
||||
return jt.code(outputs=outputs_,
|
||||
inputs=inputs,
|
||||
cuda_header=attr_header + cuda_header,
|
||||
cuda_src=f"""
|
||||
|
||||
// aclop
|
||||
{name}OpRunner op;
|
||||
{input_code}
|
||||
{output_code}
|
||||
{attr_code}
|
||||
op.run();""")
|
||||
|
||||
class TransPoseACL(jt.Function):
|
||||
|
||||
def __init__(self):
|
||||
super(TransPoseACL, self).__init__()
|
||||
|
||||
def execute(self, x, *dim):
|
||||
self.input = x
|
||||
if len(dim) == 1 and isinstance(dim[0], Sequence):
|
||||
dim = dim[0]
|
||||
elif len(dim) == 2:
|
||||
axes = list(range(x.ndim))
|
||||
a, b = dim
|
||||
axes[a], axes[b] = axes[b], axes[a]
|
||||
dim = axes
|
||||
|
||||
attr_code = f"""
|
||||
op.jt_name = "transpose";
|
||||
ReduceAttr *attr = new ReduceAttr();
|
||||
attr->axes = {{ {", ".join(map(str, dim))} }};
|
||||
op.op_attr.reset(attr);
|
||||
"""
|
||||
# calculate output shape
|
||||
output_shape = [x.shape[i] for i in dim]
|
||||
output = transpose_cmd("Transpose", [x],
|
||||
output_dtypes=[x.dtype],
|
||||
output_shapes=[jt.empty(output_shape).shape],
|
||||
attr_code=attr_code)[0]
|
||||
self.dim = dim
|
||||
return output
|
||||
|
||||
def grad(self, grad_output):
|
||||
dim = list(range(grad_output.ndim))
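# Invert the forward permutation: output axis i came from input axis p = self.dim[i],
# so the gradient's permutation maps p back to position i.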
|
||||
for i, p in enumerate(self.dim):
|
||||
dim[p] = i
|
||||
output_shape = [grad_output.shape[i] for i in dim]
|
||||
attr_code = f"""
|
||||
op.jt_name = "transpose";
|
||||
ReduceAttr *attr = new ReduceAttr();
|
||||
attr->axes = {{ {", ".join(map(str, dim))} }};
|
||||
op.op_attr.reset(attr);
|
||||
"""
|
||||
output = transpose_cmd("Transpose", [grad_output],
|
||||
output_dtypes=[grad_output.dtype],
|
||||
output_shapes=[jt.empty(output_shape).shape],
|
||||
attr_code=attr_code)[0]
|
||||
return output
|
|
@ -0,0 +1,66 @@
|
|||
#pragma once
|
||||
#include <acl/acl.h>
|
||||
#include <acl/acl_op_compiler.h>
|
||||
#include <Python.h>
|
||||
#include <pystate.h>
|
||||
#include <algorithm>
|
||||
#include <queue>
|
||||
#include <set>
|
||||
#include "common.h"
|
||||
#include "op.h"
|
||||
#include "acl_jittor.h"
|
||||
#include "ops/random_op.h"
|
||||
#include "ops/reduce_op.h"
|
||||
#include "ops/binary_op.h"
|
||||
#include "ops/broadcast_to_op.h"
|
||||
#include "ops/transpose_op.h"
|
||||
#include "ops/array_op.h"
|
||||
#include "ops/code_op.h"
|
||||
#include "fused_op.h"
|
||||
#include "ops/unary_op.h"
|
||||
#include "ops/ternary_op.h"
|
||||
#include "executor.h"
|
||||
#include "misc/cuda_flags.h"
|
||||
#include "mem/allocator.h"
|
||||
#include "op_compiler.h"
|
||||
#include "ops/op_register.h"
|
||||
#include "opt/tuner_manager.h"
|
||||
#include "utils/str_utils.h"
|
||||
#include "aclnn/aclnn.h"
|
||||
#include "transpose_op_acl.h"
|
||||
|
||||
namespace jittor
|
||||
{
|
||||
TransposeOpRunner::TransposeOpRunner() : BaseOpRunner("Transpose")
|
||||
{
|
||||
}
|
||||
|
||||
void TransposeOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
|
||||
{
|
||||
auto attr = dynamic_cast<ReduceAttr *>(op_attr.get());
|
||||
|
||||
aclIntArray *dim = nullptr;
|
||||
|
||||
dim = aclCreateIntArray(attr->axes.data(), attr->axes.size());
|
||||
bool keepdims = attr->keepdims;
|
||||
|
||||
ret = aclnnPermuteGetWorkspaceSize(inputTensors[0], dim, outputTensors[0], &workspaceSize, &executor);
|
||||
|
||||
checkRet(ret);
|
||||
|
||||
if (workspaceSize > 0)
|
||||
{
|
||||
mallocWorkSpace(workspaceSize);
|
||||
}
|
||||
|
||||
ret = aclnnPermute(workspaceAddr, workspaceSize, executor, aclstream);
|
||||
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnPermute failed. ERROR: %d\n", name.c_str(), ret); return);
|
||||
|
||||
syncRun();
|
||||
|
||||
aclDestroyIntArray(dim);
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,16 @@
#pragma once
#include "utils.h"
#include "base_op.h"

namespace jittor
{
class TransposeOpRunner : public BaseOpRunner
{

protected:
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
public:
TransposeOpRunner();
};

}