mirror of https://github.com/Jittor/Jittor
split silu,sigmoid,softmax
This commit is contained in:
parent bfe1ceb82b
commit 144b7bc57d
@@ -2203,36 +2203,53 @@ def change_function():
     #                      attr_code=attr_code)[0]
     #         return grad_input

-    class SiLUACL(Function):
-
-        def __init__(self):
-            super(SiLUACL, self).__init__()
-
-        def execute(self, x):
-            x = x.float32()
-            inputs = [x]
-            self.input = x
-            outputs = [jt.empty(x.shape, x.dtype)]
-            attr_code = f"""
-            op.jt_name = "silu";
-            """
-            result = acl_cmd("SiLU",
-                             inputs=inputs,
-                             outputs=outputs,
-                             attr_code=attr_code)[0]
-            return result
-
-        def grad(self, grad_output):
-            attr_code = f"""
-            op.jt_name = "silubackward";
-            """
-            inputs = [grad_output, self.input]
-            outputs = [jt.empty(grad_output.shape, grad_output.dtype)]
-            grad_input = acl_cmd("SiLUBackward",
-                                 inputs=inputs,
-                                 outputs=outputs,
-                                 attr_code=attr_code)[0]
-            return grad_input
+    from .aclops.dropout_op import DropoutACL
+
+    class Dropout(jt.nn.Module):
+
+        def __init__(self, p=0.5, is_train=False):
+            super(Dropout, self).__init__()
+            self.p = p
+            self.is_train = is_train
+
+        def execute(self, x):
+            return DropoutACL()(x, self.p, self.is_train)
+
+    def dropout_acl(x, p=0.5, is_train=False):
+        return DropoutACL()(x, p, is_train)
+
+    # class SiLUACL(Function):
+
+    #     def __init__(self):
+    #         super(SiLUACL, self).__init__()
+
+    #     def execute(self, x):
+    #         x = x.float32()
+    #         inputs = [x]
+    #         self.input = x
+    #         outputs = [jt.empty(x.shape, x.dtype)]
+    #         attr_code = f"""
+    #         op.jt_name = "silu";
+    #         """
+    #         result = acl_cmd("SiLU",
+    #                          inputs=inputs,
+    #                          outputs=outputs,
+    #                          attr_code=attr_code)[0]
+    #         return result
+
+    #     def grad(self, grad_output):
+    #         attr_code = f"""
+    #         op.jt_name = "silubackward";
+    #         """
+    #         inputs = [grad_output, self.input]
+    #         outputs = [jt.empty(grad_output.shape, grad_output.dtype)]
+    #         grad_input = acl_cmd("SiLUBackward",
+    #                              inputs=inputs,
+    #                              outputs=outputs,
+    #                              attr_code=attr_code)[0]
+    #         return grad_input
+
+    from .aclops.silu_op import SiLUACL

     def silu_acl(x):
         return SiLUACL()(x)
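The hunk above retires the inline SiLUACL and imports the new one from .aclops.silu_op while keeping silu_acl's call signature; the Dropout wrapper is moved up to this spot. As a quick reference for the math the new SiLU/SiLUBackward kernels are expected to reproduce, here is a sketch using only built-in Jittor ops (it runs on any backend and does not exercise the ACL path itself):

import numpy as np
import jittor as jt

# silu(x) = x * sigmoid(x);  d/dx silu(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
x = jt.array(np.random.randn(4, 8).astype(np.float32))
s = jt.sigmoid(x)
y = x * s
g = jt.grad(y.sum(), x)               # autograd through the reference formula
expected = s * (1 + x * (1 - s))      # closed-form SiLU derivative
assert np.allclose(g.numpy(), expected.numpy(), atol=1e-5)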
@@ -2245,37 +2262,40 @@ def change_function():
         def execute(self, x):
             return SiLUACL()(x)

-    class SigmoidACL(Function):
-
-        def __init__(self):
-            super(SigmoidACL, self).__init__()
-
-        def execute(self, x):
-            x = x.float32()
-            inputs = [x]
-            outputs = [jt.empty(x.shape, x.dtype)]
-            attr_code = f"""
-            op.jt_name = "sigmoid";
-            """
-            result = acl_cmd("Sigmoid",
-                             inputs=inputs,
-                             outputs=outputs,
-                             attr_code=attr_code)[0]
-            self.output = result
-            return result
-
-        def grad(self, grad_output):
-            attr_code = f"""
-            op.jt_name = "sigmoidbackward";
-            """
-            inputs = [grad_output, self.output]
-            outputs = [jt.empty(grad_output.shape, grad_output.dtype)]
-            grad_input = acl_cmd("SigmoidBackward",
-                                 inputs=inputs,
-                                 outputs=outputs,
-                                 attr_code=attr_code)[0]
-            return grad_input
+    # class SigmoidACL(Function):
+
+    #     def __init__(self):
+    #         super(SigmoidACL, self).__init__()
+
+    #     def execute(self, x):
+    #         x = x.float32()
+    #         inputs = [x]
+    #         outputs = [jt.empty(x.shape, x.dtype)]
+    #         attr_code = f"""
+    #         op.jt_name = "sigmoid";
+    #         """
+    #         result = acl_cmd("Sigmoid",
+    #                          inputs=inputs,
+    #                          outputs=outputs,
+    #                          attr_code=attr_code)[0]
+    #         self.output = result
+    #         return result
+
+    #     def grad(self, grad_output):
+    #         attr_code = f"""
+    #         op.jt_name = "sigmoidbackward";
+    #         """
+    #         inputs = [grad_output, self.output]
+    #         outputs = [jt.empty(grad_output.shape, grad_output.dtype)]
+    #         grad_input = acl_cmd("SigmoidBackward",
+    #                              inputs=inputs,
+    #                              outputs=outputs,
+    #                              attr_code=attr_code)[0]
+    #         return grad_input
+
+    from .aclops.sigmoid_op import SigmoidACL

     def sigmoid_acl(x):
         return SigmoidACL()(x)
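SigmoidACL keeps the forward output and hands (grad_output, output) to SigmoidBackward, so the backward never needs the input. The identity it relies on, checked here with plain Jittor ops as a backend-independent sketch:

import numpy as np
import jittor as jt

# d/dx sigmoid(x) = y * (1 - y), where y = sigmoid(x)
x = jt.array(np.random.randn(3, 5).astype(np.float32))
y = jt.sigmoid(x)
g = jt.grad(y.sum(), x)
assert np.allclose(g.numpy(), (y * (1 - y)).numpy(), atol=1e-5)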
@@ -2345,58 +2365,45 @@ def change_function():
             res = embedding_acl(x, self.weight)
             return res

-    from .aclops.dropout_op import DropoutACL
-
-    class Dropout(jt.nn.Module):
-
-        def __init__(self, p=0.5, is_train=False):
-            super(Dropout, self).__init__()
-            self.p = p
-            self.is_train = is_train
-
-        def execute(self, x):
-            return DropoutACL()(x, self.p, self.is_train)
-
-    def dropout_acl(x, p=0.5, is_train=False):
-        return DropoutACL()(x, p, is_train)
-
-    class SoftmaxACL(Function):
-
-        def __init__(self):
-            super(SoftmaxACL, self).__init__()
-
-        def execute(self, x, dim):
-            x = x.float32()
-            inputs = [x]
-            outputs = [jt.empty(x.shape)]
-            self.dim = dim
-            attr_code = f"""
-            op.jt_name = "softmax";
-            SoftmaxAttr *attr = new SoftmaxAttr();
-            attr->dim = {dim};
-            op.op_attr.reset(attr);
-            """
-            result = acl_cmd("Softmax",
-                             inputs=inputs,
-                             outputs=outputs,
-                             attr_code=attr_code)[0]
-            self.output = result
-            return result
-
-        def grad(self, grad_output):
-            attr_code = f"""
-            op.jt_name = "softmax";
-            SoftmaxAttr *attr = new SoftmaxAttr();
-            attr->dim = {self.dim};
-            op.op_attr.reset(attr);
-            """
-            inputs = [grad_output, self.output]
-            outputs = [jt.empty(grad_output.shape)]
-            grad_input = acl_cmd("SoftmaxBackward",
-                                 inputs=inputs,
-                                 outputs=outputs,
-                                 attr_code=attr_code)[0]
-            return grad_input
+    # class SoftmaxACL(Function):
+
+    #     def __init__(self):
+    #         super(SoftmaxACL, self).__init__()
+
+    #     def execute(self, x, dim):
+    #         x = x.float32()
+    #         inputs = [x]
+    #         outputs = [jt.empty(x.shape)]
+    #         self.dim = dim
+    #         attr_code = f"""
+    #         op.jt_name = "softmax";
+    #         SoftmaxAttr *attr = new SoftmaxAttr();
+    #         attr->dim = {dim};
+    #         op.op_attr.reset(attr);
+    #         """
+    #         result = acl_cmd("Softmax",
+    #                          inputs=inputs,
+    #                          outputs=outputs,
+    #                          attr_code=attr_code)[0]
+    #         self.output = result
+    #         return result
+
+    #     def grad(self, grad_output):
+    #         attr_code = f"""
+    #         op.jt_name = "softmax";
+    #         SoftmaxAttr *attr = new SoftmaxAttr();
+    #         attr->dim = {self.dim};
+    #         op.op_attr.reset(attr);
+    #         """
+    #         inputs = [grad_output, self.output]
+    #         outputs = [jt.empty(grad_output.shape)]
+    #         grad_input = acl_cmd("SoftmaxBackward",
+    #                              inputs=inputs,
+    #                              outputs=outputs,
+    #                              attr_code=attr_code)[0]
+    #         return grad_input
+
+    from .aclops.softmax_op import SoftmaxACL

     class Softmax(jt.nn.Module):

         def __init__(self):
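SoftmaxACL threads the reduction axis through SoftmaxAttr::dim and saves the forward output for the backward pass. The backward identity it must match, written with plain Jittor ops as a backend-independent sketch (dim=1 for the 2-D example):

import numpy as np
import jittor as jt

# For y = softmax(x, dim):  g_in = (g_out - sum(g_out * y, dim, keepdims)) * y
x = jt.array(np.random.randn(2, 6).astype(np.float32))
g_out = jt.array(np.random.randn(2, 6).astype(np.float32))
dim = 1

y = jt.nn.softmax(x, dim=dim)
assert np.allclose(y.sum(dim).numpy(), 1.0, atol=1e-5)   # rows sum to one

g_in = jt.grad((y * g_out).sum(), x)
g_ref = (g_out - (g_out * y).sum(dim, keepdims=True)) * y
assert np.allclose(g_in.numpy(), g_ref.numpy(), atol=1e-5)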
@@ -468,26 +468,26 @@ namespace jittor
         //     ret = it->second.getWorkspaceSizeFuncDropoutBackward(inputTensors[0], inputTensors[1], attr->scale, outputTensors[0], &workspaceSize, &executor);
         //     break;
         // }
-        case 42:
-        {
-            ret = it->second.getWorkspaceSizeFuncUnaryNonzero(inputTensors[0], outputTensors[0], &workspaceSize, &executor);
-            break;
-        }
-        case 43:
-        {
-            ret = it->second.getWorkspaceSizeFuncBinary(inputTensors[0], inputTensors[1], outputTensors[0], &workspaceSize, &executor);
-            break;
-        }
-        case 44:
-        {
-            ret = it->second.getWorkspaceSizeFuncUnaryNonzero(inputTensors[0], outputTensors[0], &workspaceSize, &executor);
-            break;
-        }
-        case 45:
-        {
-            ret = it->second.getWorkspaceSizeFuncBinary(inputTensors[0], inputTensors[1], outputTensors[0], &workspaceSize, &executor);
-            break;
-        }
+        // case 42:
+        // {
+        //     ret = it->second.getWorkspaceSizeFuncUnaryNonzero(inputTensors[0], outputTensors[0], &workspaceSize, &executor);
+        //     break;
+        // }
+        // case 43:
+        // {
+        //     ret = it->second.getWorkspaceSizeFuncBinary(inputTensors[0], inputTensors[1], outputTensors[0], &workspaceSize, &executor);
+        //     break;
+        // }
+        // case 44:
+        // {
+        //     ret = it->second.getWorkspaceSizeFuncUnaryNonzero(inputTensors[0], outputTensors[0], &workspaceSize, &executor);
+        //     break;
+        // }
+        // case 45:
+        // {
+        //     ret = it->second.getWorkspaceSizeFuncBinary(inputTensors[0], inputTensors[1], outputTensors[0], &workspaceSize, &executor);
+        //     break;
+        // }
         case 46:
         {
             ret = it->second.getWorkspaceSizeFuncBinary(inputTensors[0], inputTensors[1], outputTensors[0], &workspaceSize, &executor);
@@ -538,18 +538,18 @@ namespace jittor
         //     ret = it->second.getWorkspaceSizeFuncFalshAttentionBackward(inputTensors[0], inputTensors[1], inputTensors[2], inputTensors[3], attr->hasRealshift ? inputTensors[4] : nullptr, attr->hasDropmask ? inputTensors[5] : nullptr, nullptr, attr->hasAttentmask ? inputTensors[7] : nullptr, inputTensors[8], inputTensors[9], nullptr, inputTensors[10], prefix, qstart, kvstart, attr->scale, attr->keepProb, attr->preToken, attr->nextToken, attr->headNum, layout, attr->innerPrecise, attr->sparseMode, attr->psetype, outputTensors[0], outputTensors[1], outputTensors[2], nullptr, &workspaceSize, &executor);
         //     break;
         // }
-        case 53:
-        {
-            auto attr = dynamic_cast<SoftmaxAttr *>(op_attr.get());
-            ret = it->second.getWorkspaceSizeFuncCast(inputTensors[0], aclDataType(attr->dim), outputTensors[0], &workspaceSize, &executor);
-            break;
-        }
-        case 54:
-        {
-            auto attr = dynamic_cast<SoftmaxAttr *>(op_attr.get());
-            ret = it->second.getWorkspaceSizeFuncDropoutBackward(inputTensors[0], inputTensors[1], attr->dim, outputTensors[0], &workspaceSize, &executor);
-            break;
-        }
+        // case 53:
+        // {
+        //     auto attr = dynamic_cast<SoftmaxAttr *>(op_attr.get());
+        //     ret = it->second.getWorkspaceSizeFuncCast(inputTensors[0], aclDataType(attr->dim), outputTensors[0], &workspaceSize, &executor);
+        //     break;
+        // }
+        // case 54:
+        // {
+        //     auto attr = dynamic_cast<SoftmaxAttr *>(op_attr.get());
+        //     ret = it->second.getWorkspaceSizeFuncDropoutBackward(inputTensors[0], inputTensors[1], attr->dim, outputTensors[0], &workspaceSize, &executor);
+        //     break;
+        // }
         case 55:
         {
             auto attr = dynamic_cast<BatchNormAttr *>(op_attr.get());
@@ -22,3 +22,6 @@
 #include <acl/aclops/flashattention_op_acl.h>
 #include <acl/aclops/relu_op_acl.h>
 #include <acl/aclops/dropout_op_acl.h>
+#include <acl/aclops/silu_op_acl.h>
+#include <acl/aclops/sigmoid_op_acl.h>
+#include <acl/aclops/softmax_op_acl.h>
@@ -0,0 +1,83 @@
import os
from jittor_utils import env_or_try_find
import jittor_utils
import ctypes
import glob
import jittor.compiler as compiler
import jittor as jt
import math
import numpy as np

from typing import Union
from collections.abc import Sequence, Iterable


def sigmoid_cmd(name: str,
                inputs: list,
                output_dtypes: list = None,
                output_shapes: list = None,
                attr_code: str = "",
                attr_header: str = "",
                outputs: list = None):
    attr_header = "\nnamespace jittor{" + attr_header + "}\n"

    cuda_header = '''
    #include "acl/aclops/aclops.h"
    '''
    outputs_ = []
    if outputs is not None:
        outputs_ = outputs
    else:
        assert output_dtypes is not None
        assert output_shapes is not None
        assert len(output_dtypes) == len(output_shapes)
        for i in range(len(output_shapes)):
            outputs_.append(jt.empty(output_shapes[i], output_dtypes[i]))
    input_code = ''
    for i in range(len(inputs)):
        input_code += f"op.add(in{i}, true);\n"

    output_code = ''
    for i in range(len(outputs_)):
        output_code += f"op.add(out{i}, false);\n"
    return jt.code(outputs=outputs_,
                   inputs=inputs,
                   cuda_header=attr_header + cuda_header,
                   cuda_src=f"""
    // aclop
    {name}OpRunner op;
    {input_code}
    {output_code}
    {attr_code}
    op.run();""")


class SigmoidACL(jt.Function):

    def __init__(self):
        super(SigmoidACL, self).__init__()

    def execute(self, x):
        x = x.float32()
        inputs = [x]
        outputs = [jt.empty(x.shape, x.dtype)]
        attr_code = f"""
        op.jt_name = "sigmoid";
        """
        result = sigmoid_cmd("Sigmoid",
                             inputs=inputs,
                             outputs=outputs,
                             attr_code=attr_code)[0]
        self.output = result
        return result

    def grad(self, grad_output):
        attr_code = f"""
        op.jt_name = "sigmoidbackward";
        """
        inputs = [grad_output, self.output]
        outputs = [jt.empty(grad_output.shape, grad_output.dtype)]
        grad_input = sigmoid_cmd("SigmoidBackward",
                                 inputs=inputs,
                                 outputs=outputs,
                                 attr_code=attr_code)[0]
        return grad_input
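sigmoid_cmd accepts either pre-allocated outputs (as SigmoidACL does above) or output_shapes/output_dtypes, in which case it allocates the outputs itself. A sketch of the second calling convention, assuming an Ascend/ACL build where the generated cuda_src can resolve SigmoidOpRunner; the helper name sigmoid_via_shapes is illustrative only:

import jittor as jt

def sigmoid_via_shapes(x):
    # Let sigmoid_cmd allocate the output from explicit shape/dtype lists
    # instead of passing `outputs=` (same op, alternative entry point).
    x = x.float32()
    return sigmoid_cmd("Sigmoid",
                       inputs=[x],
                       output_shapes=[x.shape],
                       output_dtypes=[x.dtype],
                       attr_code='op.jt_name = "sigmoid";')[0]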
@@ -0,0 +1,80 @@
#pragma once
#include <acl/acl.h>
#include <acl/acl_op_compiler.h>
#include <Python.h>
#include <pystate.h>
#include <algorithm>
#include <queue>
#include <set>
#include "common.h"
#include "op.h"
#include "acl_jittor.h"
#include "ops/random_op.h"
#include "ops/reduce_op.h"
#include "ops/binary_op.h"
#include "ops/broadcast_to_op.h"
#include "ops/transpose_op.h"
#include "ops/array_op.h"
#include "ops/code_op.h"
#include "fused_op.h"
#include "ops/unary_op.h"
#include "ops/ternary_op.h"
#include "executor.h"
#include "misc/cuda_flags.h"
#include "mem/allocator.h"
#include "op_compiler.h"
#include "ops/op_register.h"
#include "opt/tuner_manager.h"
#include "utils/str_utils.h"
#include "aclnn/aclnn.h"
#include "sigmoid_op_acl.h"

namespace jittor
{
    SigmoidOpRunner::SigmoidOpRunner() : BaseOpRunner("Sigmoid")
    {
    }

    void SigmoidOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
    {
        ret = aclnnSigmoidGetWorkspaceSize(inputTensors[0], outputTensors[0], &workspaceSize, &executor);

        checkRet(ret);

        if (workspaceSize > 0)
        {
            mallocWorkSpace(workspaceSize);
        }

        ret = aclnnSigmoid(workspaceAddr, workspaceSize, executor, aclstream);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnSigmoid failed. ERROR: %d\n", name.c_str(), ret); return);

        syncRun();

        return;
    }

    SigmoidBackwardOpRunner::SigmoidBackwardOpRunner() : BaseOpRunner("SigmoidBackward")
    {
    }

    void SigmoidBackwardOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
    {
        ret = aclnnSigmoidBackwardGetWorkspaceSize(inputTensors[0], inputTensors[1], outputTensors[0], &workspaceSize, &executor);

        checkRet(ret);

        if (workspaceSize > 0)
        {
            mallocWorkSpace(workspaceSize);
        }

        ret = aclnnSigmoidBackward(workspaceAddr, workspaceSize, executor, aclstream);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnSigmoidBackward failed. ERROR: %d\n", name.c_str(), ret); return);

        syncRun();

        return;
    }

}
@@ -0,0 +1,25 @@
#pragma once
#include "utils.h"
#include "base_op.h"

namespace jittor
{
    class SigmoidOpRunner : public BaseOpRunner
    {

    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        SigmoidOpRunner();
    };

    class SigmoidBackwardOpRunner : public BaseOpRunner
    {

    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        SigmoidBackwardOpRunner();
    };

}
@@ -0,0 +1,83 @@
import os
from jittor_utils import env_or_try_find
import jittor_utils
import ctypes
import glob
import jittor.compiler as compiler
import jittor as jt
import math
import numpy as np

from typing import Union
from collections.abc import Sequence, Iterable


def silu_cmd(name: str,
             inputs: list,
             output_dtypes: list = None,
             output_shapes: list = None,
             attr_code: str = "",
             attr_header: str = "",
             outputs: list = None):
    attr_header = "\nnamespace jittor{" + attr_header + "}\n"

    cuda_header = '''
    #include "acl/aclops/aclops.h"
    '''
    outputs_ = []
    if outputs is not None:
        outputs_ = outputs
    else:
        assert output_dtypes is not None
        assert output_shapes is not None
        assert len(output_dtypes) == len(output_shapes)
        for i in range(len(output_shapes)):
            outputs_.append(jt.empty(output_shapes[i], output_dtypes[i]))
    input_code = ''
    for i in range(len(inputs)):
        input_code += f"op.add(in{i}, true);\n"

    output_code = ''
    for i in range(len(outputs_)):
        output_code += f"op.add(out{i}, false);\n"
    return jt.code(outputs=outputs_,
                   inputs=inputs,
                   cuda_header=attr_header + cuda_header,
                   cuda_src=f"""
    // aclop
    {name}OpRunner op;
    {input_code}
    {output_code}
    {attr_code}
    op.run();""")


class SiLUACL(jt.Function):

    def __init__(self):
        super(SiLUACL, self).__init__()

    def execute(self, x):
        x = x.float32()
        inputs = [x]
        self.input = x
        outputs = [jt.empty(x.shape, x.dtype)]
        attr_code = f"""
        op.jt_name = "silu";
        """
        result = silu_cmd("SiLU",
                          inputs=inputs,
                          outputs=outputs,
                          attr_code=attr_code)[0]
        return result

    def grad(self, grad_output):
        attr_code = f"""
        op.jt_name = "silubackward";
        """
        inputs = [grad_output, self.input]
        outputs = [jt.empty(grad_output.shape, grad_output.dtype)]
        grad_input = silu_cmd("SiLUBackward",
                              inputs=inputs,
                              outputs=outputs,
                              attr_code=attr_code)[0]
        return grad_input
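Unlike SigmoidACL, SiLUACL saves its input rather than its output, because the SiLU backward needs x. The sketch below mirrors that execute/grad pattern with built-in Jittor ops so it runs on any backend; SiLURef is illustrative only and not part of the commit, but the ACL version can be checked against the same comparison:

import numpy as np
import jittor as jt

class SiLURef(jt.Function):
    # Same save-the-input structure as SiLUACL, computed with jt ops.
    def execute(self, x):
        x = x.float32()
        self.input = x
        return x * jt.sigmoid(x)

    def grad(self, grad_output):
        x = self.input
        s = jt.sigmoid(x)
        return grad_output * s * (1 + x * (1 - s))

x = jt.array(np.random.randn(2, 3).astype(np.float32))
g_fn = jt.grad(SiLURef()(x).sum(), x)            # through the custom Function
g_ref = jt.grad((x * jt.sigmoid(x)).sum(), x)    # plain autograd on the same math
assert np.allclose(g_fn.numpy(), g_ref.numpy(), atol=1e-5)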
@@ -0,0 +1,80 @@
#pragma once
#include <acl/acl.h>
#include <acl/acl_op_compiler.h>
#include <Python.h>
#include <pystate.h>
#include <algorithm>
#include <queue>
#include <set>
#include "common.h"
#include "op.h"
#include "acl_jittor.h"
#include "ops/random_op.h"
#include "ops/reduce_op.h"
#include "ops/binary_op.h"
#include "ops/broadcast_to_op.h"
#include "ops/transpose_op.h"
#include "ops/array_op.h"
#include "ops/code_op.h"
#include "fused_op.h"
#include "ops/unary_op.h"
#include "ops/ternary_op.h"
#include "executor.h"
#include "misc/cuda_flags.h"
#include "mem/allocator.h"
#include "op_compiler.h"
#include "ops/op_register.h"
#include "opt/tuner_manager.h"
#include "utils/str_utils.h"
#include "aclnn/aclnn.h"
#include "silu_op_acl.h"

namespace jittor
{
    SiLUOpRunner::SiLUOpRunner() : BaseOpRunner("SiLU")
    {
    }

    void SiLUOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
    {
        ret = aclnnSiluGetWorkspaceSize(inputTensors[0], outputTensors[0], &workspaceSize, &executor);

        checkRet(ret);

        if (workspaceSize > 0)
        {
            mallocWorkSpace(workspaceSize);
        }

        ret = aclnnSilu(workspaceAddr, workspaceSize, executor, aclstream);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnSilu failed. ERROR: %d\n", name.c_str(), ret); return);

        syncRun();

        return;
    }

    SiLUBackwardOpRunner::SiLUBackwardOpRunner() : BaseOpRunner("SiLUBackward")
    {
    }

    void SiLUBackwardOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
    {
        ret = aclnnSiluBackwardGetWorkspaceSize(inputTensors[0], inputTensors[1], outputTensors[0], &workspaceSize, &executor);

        checkRet(ret);

        if (workspaceSize > 0)
        {
            mallocWorkSpace(workspaceSize);
        }

        ret = aclnnSiluBackward(workspaceAddr, workspaceSize, executor, aclstream);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnSiluBackward failed. ERROR: %d\n", name.c_str(), ret); return);

        syncRun();

        return;
    }

}
@@ -0,0 +1,25 @@
#pragma once
#include "utils.h"
#include "base_op.h"

namespace jittor
{
    class SiLUOpRunner : public BaseOpRunner
    {

    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        SiLUOpRunner();
    };

    class SiLUBackwardOpRunner : public BaseOpRunner
    {

    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        SiLUBackwardOpRunner();
    };

}
@@ -0,0 +1,90 @@
import os
from jittor_utils import env_or_try_find
import jittor_utils
import ctypes
import glob
import jittor.compiler as compiler
import jittor as jt
import math
import numpy as np

from typing import Union
from collections.abc import Sequence, Iterable


def softmax_cmd(name: str,
                inputs: list,
                output_dtypes: list = None,
                output_shapes: list = None,
                attr_code: str = "",
                attr_header: str = "",
                outputs: list = None):
    attr_header = "\nnamespace jittor{" + attr_header + "}\n"

    cuda_header = '''
    #include "acl/aclops/aclops.h"
    '''
    outputs_ = []
    if outputs is not None:
        outputs_ = outputs
    else:
        assert output_dtypes is not None
        assert output_shapes is not None
        assert len(output_dtypes) == len(output_shapes)
        for i in range(len(output_shapes)):
            outputs_.append(jt.empty(output_shapes[i], output_dtypes[i]))
    input_code = ''
    for i in range(len(inputs)):
        input_code += f"op.add(in{i}, true);\n"

    output_code = ''
    for i in range(len(outputs_)):
        output_code += f"op.add(out{i}, false);\n"
    return jt.code(outputs=outputs_,
                   inputs=inputs,
                   cuda_header=attr_header + cuda_header,
                   cuda_src=f"""
    // aclop
    {name}OpRunner op;
    {input_code}
    {output_code}
    {attr_code}
    op.run();""")


class SoftmaxACL(jt.Function):

    def __init__(self):
        super(SoftmaxACL, self).__init__()

    def execute(self, x, dim):
        x = x.float32()
        inputs = [x]
        outputs = [jt.empty(x.shape)]
        self.dim = dim
        attr_code = f"""
        op.jt_name = "softmax";
        SoftmaxAttr *attr = new SoftmaxAttr();
        attr->dim = {dim};
        op.op_attr.reset(attr);
        """
        result = softmax_cmd("Softmax",
                             inputs=inputs,
                             outputs=outputs,
                             attr_code=attr_code)[0]
        self.output = result
        return result

    def grad(self, grad_output):
        attr_code = f"""
        op.jt_name = "softmax";
        SoftmaxAttr *attr = new SoftmaxAttr();
        attr->dim = {self.dim};
        op.op_attr.reset(attr);
        """
        inputs = [grad_output, self.output]
        outputs = [jt.empty(grad_output.shape)]
        grad_input = softmax_cmd("SoftmaxBackward",
                                 inputs=inputs,
                                 outputs=outputs,
                                 attr_code=attr_code)[0]
        return grad_input
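SoftmaxACL takes the axis as a second positional argument and records it in SoftmaxAttr for both the forward and backward ACL calls. A hedged correctness check against jt.nn.softmax; since the module lives inside the acl extension package, the class is passed in rather than imported by a guessed path, and the check only makes sense on an ACL build:

import numpy as np
import jittor as jt

def check_softmax_acl(SoftmaxACL, shape=(2, 5), dim=1):
    # Forward should match jt.nn.softmax; backward is exercised via jt.grad.
    x = jt.array(np.random.randn(*shape).astype(np.float32))
    w = jt.array(np.random.randn(*shape).astype(np.float32))
    y = SoftmaxACL()(x, dim)
    ref = jt.nn.softmax(x, dim=dim)
    assert np.allclose(y.numpy(), ref.numpy(), atol=1e-4)
    g = jt.grad((y * w).sum(), x)
    g_ref = jt.grad((ref * w).sum(), x)
    assert np.allclose(g.numpy(), g_ref.numpy(), atol=1e-4)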
@@ -0,0 +1,82 @@
#pragma once
#include <acl/acl.h>
#include <acl/acl_op_compiler.h>
#include <Python.h>
#include <pystate.h>
#include <algorithm>
#include <queue>
#include <set>
#include "common.h"
#include "op.h"
#include "acl_jittor.h"
#include "ops/random_op.h"
#include "ops/reduce_op.h"
#include "ops/binary_op.h"
#include "ops/broadcast_to_op.h"
#include "ops/transpose_op.h"
#include "ops/array_op.h"
#include "ops/code_op.h"
#include "fused_op.h"
#include "ops/unary_op.h"
#include "ops/ternary_op.h"
#include "executor.h"
#include "misc/cuda_flags.h"
#include "mem/allocator.h"
#include "op_compiler.h"
#include "ops/op_register.h"
#include "opt/tuner_manager.h"
#include "utils/str_utils.h"
#include "aclnn/aclnn.h"
#include "softmax_op_acl.h"

namespace jittor
{
    SoftmaxOpRunner::SoftmaxOpRunner() : BaseOpRunner("Softmax")
    {
    }

    void SoftmaxOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
    {
        auto attr = dynamic_cast<SoftmaxAttr *>(op_attr.get());
        ret = aclnnSoftmaxGetWorkspaceSize(inputTensors[0], aclDataType(attr->dim), outputTensors[0], &workspaceSize, &executor);

        checkRet(ret);

        if (workspaceSize > 0)
        {
            mallocWorkSpace(workspaceSize);
        }

        ret = aclnnSoftmax(workspaceAddr, workspaceSize, executor, aclstream);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnSoftmax failed. ERROR: %d\n", name.c_str(), ret); return);

        syncRun();

        return;
    }

    SoftmaxBackwardOpRunner::SoftmaxBackwardOpRunner() : BaseOpRunner("SoftmaxBackward")
    {
    }

    void SoftmaxBackwardOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
    {
        auto attr = dynamic_cast<SoftmaxAttr *>(op_attr.get());
        ret = aclnnSoftmaxBackwardGetWorkspaceSize(inputTensors[0], inputTensors[1], attr->dim, outputTensors[0], &workspaceSize, &executor);

        checkRet(ret);

        if (workspaceSize > 0)
        {
            mallocWorkSpace(workspaceSize);
        }

        ret = aclnnSoftmaxBackward(workspaceAddr, workspaceSize, executor, aclstream);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnSoftmaxBackward failed. ERROR: %d\n", name.c_str(), ret); return);

        syncRun();

        return;
    }

}
@@ -0,0 +1,25 @@
#pragma once
#include "utils.h"
#include "base_op.h"

namespace jittor
{
    class SoftmaxOpRunner : public BaseOpRunner
    {

    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        SoftmaxOpRunner();
    };

    class SoftmaxBackwardOpRunner : public BaseOpRunner
    {

    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        SoftmaxBackwardOpRunner();
    };

}