mirror of https://github.com/Jittor/Jittor
split stack,rope,nantonum
commit 021ffea7c7
parent c833f841b1
@@ -2415,36 +2415,37 @@ def change_function():
     def softmax_acl(x, dim):
         return SoftmaxACL()(x, dim)
 
-    class RopeACL(Function):
+    # class RopeACL(Function):
 
-        def __init__(self):
-            super(RopeACL, self).__init__()
+    #     def __init__(self):
+    #         super(RopeACL, self).__init__()
 
-        def execute(self, xq, xk, freqs_cis, freq_cos, freq_sin):
-            attr_code = f"""
-            op.jt_name = "RotaryPosEmb";
-            """
-            if freqs_cis is not None:
-                freq_cos = freqs_cis[..., 0]
-                freq_sin = freqs_cis[..., 1]
-            else:
-                assert freq_cos is not None and freq_sin is not None
-            inputs = [xq, xk, freq_cos, freq_sin]
-            results = acl_cmd("RotaryPosEmb",
-                              inputs,
-                              output_dtypes=[
-                                  xq.dtype,
-                              ],
-                              output_shapes=[
-                                  xq.shape,
-                              ],
-                              attr_code=attr_code)
-            results[0].sync()
-            return inputs[0], inputs[1]
+    #     def execute(self, xq, xk, freqs_cis, freq_cos, freq_sin):
+    #         attr_code = f"""
+    #         op.jt_name = "RotaryPosEmb";
+    #         """
+    #         if freqs_cis is not None:
+    #             freq_cos = freqs_cis[..., 0]
+    #             freq_sin = freqs_cis[..., 1]
+    #         else:
+    #             assert freq_cos is not None and freq_sin is not None
+    #         inputs = [xq, xk, freq_cos, freq_sin]
+    #         results = acl_cmd("RotaryPosEmb",
+    #                           inputs,
+    #                           output_dtypes=[
+    #                               xq.dtype,
+    #                           ],
+    #                           output_shapes=[
+    #                               xq.shape,
+    #                           ],
+    #                           attr_code=attr_code)
+    #         results[0].sync()
+    #         return inputs[0], inputs[1]
 
-        def grad(self, grad_output):
-            return grad_output
+    #     def grad(self, grad_output):
+    #         return grad_output
 
+    from .aclops.rope_op import RopeACL
     def rope_acl(xq, xk, freqs_cis=None, freq_sin=None, freq_cos=None):
         return RopeACL()(xq, xk, freqs_cis, freq_sin, freq_cos)
 
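The hunk above only swaps the inline RopeACL definition for the imported one; the rope_acl entry point keeps its signature. A minimal call sketch on the ACL backend, with tensor shapes that are illustrative assumptions only (the underlying aclnnApplyRotaryPosEmb kernel imposes its own layout requirements):

    import jittor as jt
    xq = jt.rand(1, 32, 4, 64)                # assumed (batch, seq, heads, head_dim) layout
    xk = jt.rand(1, 32, 4, 64)
    freqs_cis = jt.rand(1, 32, 1, 64, 2)      # packed table: [..., 0] = cos, [..., 1] = sin
    xq_r, xk_r = rope_acl(xq, xk, freqs_cis)  # execute returns inputs[0], inputs[1], i.e. the xq/xk buffers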
@@ -2579,93 +2580,96 @@ def change_function():
     # attr_code=attr_code)[0]
     # return grad_input
 
-    class StackACL(Function):
+    # class StackACL(Function):
 
-        def __init__(self):
-            super(StackACL, self).__init__()
+    #     def __init__(self):
+    #         super(StackACL, self).__init__()
 
-        def execute(self, input_tensors, dim):
-            if type(input_tensors) is tuple:
-                input_tensors = list(input_tensors)
-            assert type(input_tensors) is list
-            assert -1 * len(input_tensors) - 1 <= dim and dim <= len(
-                input_tensors)
-            for i in range(len(input_tensors)):
-                if input_tensors[i].dtype != input_tensors[0].dtype:
-                    raise ValueError(
-                        "All input tensors must have the same dtype")
-                if input_tensors[i].shape != input_tensors[0].shape:
-                    raise ValueError(
-                        "All input tensors must have the same shape")
-            self.input = input_tensors
-            input_shape = list(input_tensors[0].shape)
-            output_shape = input_shape[:dim] + [len(input_tensors)
-                                                ] + input_shape[dim:]
-            attr_code = f"""
-            op.jt_name = "stack";
-            ConcatAttr *attr = new ConcatAttr();
-            attr->tensorNum = {len(input_tensors)};
-            attr->dim = {dim};
-            op.op_attr.reset(attr);
-            """
-            self.attr_code = attr_code
-            result = acl_cmd("Stack",
-                             input_tensors,
-                             output_dtypes=[input_tensors[0].dtype],
-                             output_shapes=[output_shape],
-                             attr_code=self.attr_code)[0]
-            return result
+    #     def execute(self, input_tensors, dim):
+    #         if type(input_tensors) is tuple:
+    #             input_tensors = list(input_tensors)
+    #         assert type(input_tensors) is list
+    #         assert -1 * len(input_tensors) - 1 <= dim and dim <= len(
+    #             input_tensors)
+    #         for i in range(len(input_tensors)):
+    #             if input_tensors[i].dtype != input_tensors[0].dtype:
+    #                 raise ValueError(
+    #                     "All input tensors must have the same dtype")
+    #             if input_tensors[i].shape != input_tensors[0].shape:
+    #                 raise ValueError(
+    #                     "All input tensors must have the same shape")
+    #         self.input = input_tensors
+    #         input_shape = list(input_tensors[0].shape)
+    #         output_shape = input_shape[:dim] + [len(input_tensors)
+    #                                             ] + input_shape[dim:]
+    #         attr_code = f"""
+    #         op.jt_name = "stack";
+    #         ConcatAttr *attr = new ConcatAttr();
+    #         attr->tensorNum = {len(input_tensors)};
+    #         attr->dim = {dim};
+    #         op.op_attr.reset(attr);
+    #         """
+    #         self.attr_code = attr_code
+    #         result = acl_cmd("Stack",
+    #                          input_tensors,
+    #                          output_dtypes=[input_tensors[0].dtype],
+    #                          output_shapes=[output_shape],
+    #                          attr_code=self.attr_code)[0]
+    #         return result
 
-        def grad(self, grad_output):
-            grad_inputs = self.split_grad(grad_output, self.input, self.dim)
-            return grad_inputs
+    #     def grad(self, grad_output):
+    #         grad_inputs = self.split_grad(grad_output, self.input, self.dim)
+    #         return grad_inputs
 
-        def split_grad(self, grad_output, input_tensors, axis):
-            offset = []
-            shapeVec = []
-            dtypeVec = []
-            for tensor in input_tensors:
-                offset.append(tensor.shape[axis])
-                dtypeVec.append(tensor.dtype)
-                shapeVec.append(tensor.shape)
+    #     def split_grad(self, grad_output, input_tensors, axis):
+    #         offset = []
+    #         shapeVec = []
+    #         dtypeVec = []
+    #         for tensor in input_tensors:
+    #             offset.append(tensor.shape[axis])
+    #             dtypeVec.append(tensor.dtype)
+    #             shapeVec.append(tensor.shape)
 
-            attr_code = f"""
-            op.jt_name = "splitwithsize";
-            auto *attr = new SplitWithSizeAttr();
-            attr->splitSize = {{ {", ".join(map(str, offset))} }};
-            attr->dim = {axis};
-            op.op_attr.reset(attr);
-            """
+    #         attr_code = f"""
+    #         op.jt_name = "splitwithsize";
+    #         auto *attr = new SplitWithSizeAttr();
+    #         attr->splitSize = {{ {", ".join(map(str, offset))} }};
+    #         attr->dim = {axis};
+    #         op.op_attr.reset(attr);
+    #         """
 
-            result = acl_cmd("SplitWithSize", [grad_output],
-                             output_dtypes=dtypeVec,
-                             output_shapes=shapeVec,
-                             attr_code=attr_code)
-            return result
+    #         result = acl_cmd("SplitWithSize", [grad_output],
+    #                          output_dtypes=dtypeVec,
+    #                          output_shapes=shapeVec,
+    #                          attr_code=attr_code)
+    #         return result
 
+    from .aclops.stack_op import StackACL
     def stack_acl(x, dim=0):
         return StackACL()(x, dim)
 
-    class NanToNumACL(Function):
+    # class NanToNumACL(Function):
 
-        def __init__(self):
-            super(NanToNumACL, self).__init__()
+    #     def __init__(self):
+    #         super(NanToNumACL, self).__init__()
 
-        def execute(self, input, nan_or_inf):
-            attr_code = f"""
-            op.jt_name = "NanToNum";
-            NanToNumAttr *attr = new NanToNumAttr();
-            attr->nan = {nan_or_inf};
-            attr->posinf = {-nan_or_inf};
-            attr->neginf = {-nan_or_inf};
-            op.op_attr.reset(attr);
-            """
-            self.attr_code = attr_code
-            result = acl_cmd("NanToNum", [input],
-                             output_dtypes=[input[0].dtype],
-                             output_shapes=[input.shape],
-                             attr_code=self.attr_code)[0]
-            return result
+    #     def execute(self, input, nan_or_inf):
+    #         attr_code = f"""
+    #         op.jt_name = "NanToNum";
+    #         NanToNumAttr *attr = new NanToNumAttr();
+    #         attr->nan = {nan_or_inf};
+    #         attr->posinf = {-nan_or_inf};
+    #         attr->neginf = {-nan_or_inf};
+    #         op.op_attr.reset(attr);
+    #         """
+    #         self.attr_code = attr_code
+    #         result = acl_cmd("NanToNum", [input],
+    #                          output_dtypes=[input[0].dtype],
+    #                          output_shapes=[input.shape],
+    #                          attr_code=self.attr_code)[0]
+    #         return result
 
+    from .aclops.nantonum_op import NanToNumACL
+
     def isnan_acl(x):
         tonum = NanToNumACL()(x, -1.0)
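For completeness, the wrappers this hunk keeps unchanged are called as before; a small sketch with toy values (not part of the commit):

    import jittor as jt
    a = jt.array([[1.0, 2.0], [3.0, 4.0]])
    b = jt.array([[5.0, 6.0], [7.0, 8.0]])
    s = stack_acl([a, b], dim=0)             # output_shape = [2] + [2, 2] -> [2, 2, 2]
    x = jt.array([1.0, float('nan'), 2.0])
    m = isnan_acl(x)                         # built on NanToNumACL()(x, -1.0) as shown above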
@@ -571,29 +571,29 @@ namespace jittor
             ret = it->second.getWorkspaceSizeFuncLayerNorm(inputTensors[0], normalizedShape, inputTensors[1], inputTensors[2], attr->eps, outputTensors[0], outputTensors[1], outputTensors[2], &workspaceSize, &executor);
             break;
         }
-        case 58:
-        {
-            ret = it->second.getWorkspaceSizeFuncRotaryPosEmb(inputTensors[0], inputTensors[1], inputTensors[2], inputTensors[3], (int64_t)1, &workspaceSize, &executor);
-            break;
-        }
-        case 59:
-        {
-            std::vector<aclTensor *> stackTensorList = {};
-            for (int i = 0; i < input_num; i++)
-            {
-                stackTensorList.push_back(inputTensors[i]);
-            }
-            auto stackTensorListInput = aclCreateTensorList(&stackTensorList[0], input_num);
-            auto attr = dynamic_cast<ConcatAttr *>(op_attr.get());
-            ret = it->second.getWorkspaceSizeFuncConcat(stackTensorListInput, attr->dim, outputTensors[0], &workspaceSize, &executor);
-            break;
-        }
-        case 60:
-        {
-            auto attr = dynamic_cast<NanToNumAttr *>(op_attr.get());
-            ret = it->second.getWorkspaceSizeFuncProdDim(inputTensors[0], attr->nan, attr->posinf, attr->neginf, outputTensors[0], &workspaceSize, &executor);
-            break;
-        }
+        // case 58:
+        // {
+        //     ret = it->second.getWorkspaceSizeFuncRotaryPosEmb(inputTensors[0], inputTensors[1], inputTensors[2], inputTensors[3], (int64_t)1, &workspaceSize, &executor);
+        //     break;
+        // }
+        // case 59:
+        // {
+        //     std::vector<aclTensor *> stackTensorList = {};
+        //     for (int i = 0; i < input_num; i++)
+        //     {
+        //         stackTensorList.push_back(inputTensors[i]);
+        //     }
+        //     auto stackTensorListInput = aclCreateTensorList(&stackTensorList[0], input_num);
+        //     auto attr = dynamic_cast<ConcatAttr *>(op_attr.get());
+        //     ret = it->second.getWorkspaceSizeFuncConcat(stackTensorListInput, attr->dim, outputTensors[0], &workspaceSize, &executor);
+        //     break;
+        // }
+        // case 60:
+        // {
+        //     auto attr = dynamic_cast<NanToNumAttr *>(op_attr.get());
+        //     ret = it->second.getWorkspaceSizeFuncProdDim(inputTensors[0], attr->nan, attr->posinf, attr->neginf, outputTensors[0], &workspaceSize, &executor);
+        //     break;
+        // }
         default:
         {
             LOGir << "not supported op: " << name;
@@ -25,3 +25,6 @@
 #include <acl/aclops/silu_op_acl.h>
 #include <acl/aclops/sigmoid_op_acl.h>
 #include <acl/aclops/softmax_op_acl.h>
+#include <acl/aclops/stack_op_acl.h>
+#include <acl/aclops/nantonum_op_acl.h>
+#include <acl/aclops/rope_op_acl.h>
@@ -0,0 +1,73 @@
+import os
+from jittor_utils import env_or_try_find
+import jittor_utils
+import ctypes
+import glob
+import jittor.compiler as compiler
+import jittor as jt
+import math
+import numpy as np
+
+from typing import Union
+from collections.abc import Sequence, Iterable
+
+def nantonum_cmd(name: str,
+                 inputs: list,
+                 output_dtypes: list = None,
+                 output_shapes: list = None,
+                 attr_code: str = "",
+                 attr_header: str = "",
+                 outputs: list = None):
+    attr_header = "\nnamespace jittor{" + attr_header + "}\n"
+
+    cuda_header = '''
+    #include "acl/aclops/aclops.h"
+    '''
+    outputs_ = []
+    if outputs is not None:
+        outputs_ = outputs
+    else:
+        assert output_dtypes is not None
+        assert output_shapes is not None
+        assert len(output_dtypes) == len(output_shapes)
+        for i in range(len(output_shapes)):
+            outputs_.append(jt.empty(output_shapes[i], output_dtypes[i]))
+    input_code = ''
+    for i in range(len(inputs)):
+        input_code += f"op.add(in{i}, true);\n"
+
+    output_code = ''
+    for i in range(len(outputs_)):
+        output_code += f"op.add(out{i}, false);\n"
+    return jt.code(outputs=outputs_,
+                   inputs=inputs,
+                   cuda_header=attr_header + cuda_header,
+                   cuda_src=f"""
+
+    // aclop
+    {name}OpRunner op;
+    {input_code}
+    {output_code}
+    {attr_code}
+    op.run();""")
+
+class NanToNumACL(jt.Function):
+
+    def __init__(self):
+        super(NanToNumACL, self).__init__()
+
+    def execute(self, input, nan_or_inf):
+        attr_code = f"""
+        op.jt_name = "NanToNum";
+        NanToNumAttr *attr = new NanToNumAttr();
+        attr->nan = {nan_or_inf};
+        attr->posinf = {-nan_or_inf};
+        attr->neginf = {-nan_or_inf};
+        op.op_attr.reset(attr);
+        """
+        self.attr_code = attr_code
+        result = nantonum_cmd("NanToNum", [input],
+                              output_dtypes=[input[0].dtype],
+                              output_shapes=[input.shape],
+                              attr_code=self.attr_code)[0]
+        return result
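With a single input and a single output, the cuda_src that nantonum_cmd passes to jt.code expands to roughly the text below (reconstructed by hand from the f-strings above; the attribute values are the ones produced when isnan_acl calls NanToNumACL()(x, -1.0), so posinf and neginf become 1.0):

    // aclop
    NanToNumOpRunner op;
    op.add(in0, true);
    op.add(out0, false);
    op.jt_name = "NanToNum";
    NanToNumAttr *attr = new NanToNumAttr();
    attr->nan = -1.0;
    attr->posinf = 1.0;
    attr->neginf = 1.0;
    op.op_attr.reset(attr);
    op.run();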
@@ -0,0 +1,58 @@
+#pragma once
+#include <acl/acl.h>
+#include <acl/acl_op_compiler.h>
+#include <Python.h>
+#include <pystate.h>
+#include <algorithm>
+#include <queue>
+#include <set>
+#include "common.h"
+#include "op.h"
+#include "acl_jittor.h"
+#include "ops/random_op.h"
+#include "ops/reduce_op.h"
+#include "ops/binary_op.h"
+#include "ops/broadcast_to_op.h"
+#include "ops/transpose_op.h"
+#include "ops/array_op.h"
+#include "ops/code_op.h"
+#include "fused_op.h"
+#include "ops/unary_op.h"
+#include "ops/ternary_op.h"
+#include "executor.h"
+#include "misc/cuda_flags.h"
+#include "mem/allocator.h"
+#include "op_compiler.h"
+#include "ops/op_register.h"
+#include "opt/tuner_manager.h"
+#include "utils/str_utils.h"
+#include "aclnn/aclnn.h"
+#include "nantonum_op_acl.h"
+
+namespace jittor
+{
+    NanToNumOpRunner::NanToNumOpRunner() : BaseOpRunner("NanToNum")
+    {
+    }
+
+    void NanToNumOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
+    {
+        auto attr = dynamic_cast<NanToNumAttr *>(op_attr.get());
+        ret = aclnnNanToNumGetWorkspaceSize(inputTensors[0], attr->nan, attr->posinf, attr->neginf, outputTensors[0], &workspaceSize, &executor);
+
+        checkRet(ret);
+
+        if (workspaceSize > 0)
+        {
+            mallocWorkSpace(workspaceSize);
+        }
+
+        ret = aclnnNanToNum(workspaceAddr, workspaceSize, executor, aclstream);
+        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnNanToNum failed. ERROR: %d\n", name.c_str(), ret); return);
+
+        syncRun();
+
+        return;
+    }
+
+}
@@ -0,0 +1,16 @@
+#pragma once
+#include "utils.h"
+#include "base_op.h"
+
+namespace jittor
+{
+    class NanToNumOpRunner : public BaseOpRunner
+    {
+
+    protected:
+        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
+    public:
+        NanToNumOpRunner();
+    };
+
+}
@@ -0,0 +1,82 @@
+import os
+from jittor_utils import env_or_try_find
+import jittor_utils
+import ctypes
+import glob
+import jittor.compiler as compiler
+import jittor as jt
+import math
+import numpy as np
+
+from typing import Union
+from collections.abc import Sequence, Iterable
+
+def rope_cmd(name: str,
+             inputs: list,
+             output_dtypes: list = None,
+             output_shapes: list = None,
+             attr_code: str = "",
+             attr_header: str = "",
+             outputs: list = None):
+    attr_header = "\nnamespace jittor{" + attr_header + "}\n"
+
+    cuda_header = '''
+    #include "acl/aclops/aclops.h"
+    '''
+    outputs_ = []
+    if outputs is not None:
+        outputs_ = outputs
+    else:
+        assert output_dtypes is not None
+        assert output_shapes is not None
+        assert len(output_dtypes) == len(output_shapes)
+        for i in range(len(output_shapes)):
+            outputs_.append(jt.empty(output_shapes[i], output_dtypes[i]))
+    input_code = ''
+    for i in range(len(inputs)):
+        input_code += f"op.add(in{i}, true);\n"
+
+    output_code = ''
+    for i in range(len(outputs_)):
+        output_code += f"op.add(out{i}, false);\n"
+    return jt.code(outputs=outputs_,
+                   inputs=inputs,
+                   cuda_header=attr_header + cuda_header,
+                   cuda_src=f"""
+
+    // aclop
+    {name}OpRunner op;
+    {input_code}
+    {output_code}
+    {attr_code}
+    op.run();""")
+
+class RopeACL(jt.Function):
+
+    def __init__(self):
+        super(RopeACL, self).__init__()
+
+    def execute(self, xq, xk, freqs_cis, freq_cos, freq_sin):
+        attr_code = f"""
+        op.jt_name = "RotaryPosEmb";
+        """
+        if freqs_cis is not None:
+            freq_cos = freqs_cis[..., 0]
+            freq_sin = freqs_cis[..., 1]
+        else:
+            assert freq_cos is not None and freq_sin is not None
+        inputs = [xq, xk, freq_cos, freq_sin]
+        results = rope_cmd("RotaryPosEmb",
+                           inputs,
+                           output_dtypes=[
+                               xq.dtype,
+                           ],
+                           output_shapes=[
+                               xq.shape,
+                           ],
+                           attr_code=attr_code)
+        results[0].sync()
+        return inputs[0], inputs[1]
+
+    def grad(self, grad_output):
+        return grad_output
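One reading of the dataflow in RopeACL.execute above (an interpretation, not something the commit states): the single output allocated by rope_cmd is only a placeholder to drive execution, results[0].sync() forces the ACL op to run, and the rotated values are then read back from the original xq/xk buffers, which is why execute returns inputs[0] and inputs[1] rather than results[0]. In other words:

    # hypothetical usage; positional order matches execute(xq, xk, freqs_cis, freq_cos, freq_sin)
    out_q, out_k = RopeACL()(xq, xk, None, freq_cos, freq_sin)
    # out_q and out_k are the same Vars as xq and xk, presumably updated in place by the kernel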
@@ -0,0 +1,57 @@
+#pragma once
+#include <acl/acl.h>
+#include <acl/acl_op_compiler.h>
+#include <Python.h>
+#include <pystate.h>
+#include <algorithm>
+#include <queue>
+#include <set>
+#include "common.h"
+#include "op.h"
+#include "acl_jittor.h"
+#include "ops/random_op.h"
+#include "ops/reduce_op.h"
+#include "ops/binary_op.h"
+#include "ops/broadcast_to_op.h"
+#include "ops/transpose_op.h"
+#include "ops/array_op.h"
+#include "ops/code_op.h"
+#include "fused_op.h"
+#include "ops/unary_op.h"
+#include "ops/ternary_op.h"
+#include "executor.h"
+#include "misc/cuda_flags.h"
+#include "mem/allocator.h"
+#include "op_compiler.h"
+#include "ops/op_register.h"
+#include "opt/tuner_manager.h"
+#include "utils/str_utils.h"
+#include "aclnn/aclnn.h"
+#include "rope_op_acl.h"
+
+namespace jittor
+{
+    RotaryPosEmbOpRunner::RotaryPosEmbOpRunner() : BaseOpRunner("RotaryPosEmb")
+    {
+    }
+
+    void RotaryPosEmbOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
+    {
+        ret = aclnnApplyRotaryPosEmbGetWorkspaceSize(inputTensors[0], inputTensors[1], inputTensors[2], inputTensors[3], (int64_t)1, &workspaceSize, &executor);
+
+        checkRet(ret);
+
+        if (workspaceSize > 0)
+        {
+            mallocWorkSpace(workspaceSize);
+        }
+
+        ret = aclnnApplyRotaryPosEmb(workspaceAddr, workspaceSize, executor, aclstream);
+        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnApplyRotaryPosEmb failed. ERROR: %d\n", name.c_str(), ret); return);
+
+        syncRun();
+
+        return;
+    }
+
+}
@@ -0,0 +1,16 @@
+#pragma once
+#include "utils.h"
+#include "base_op.h"
+
+namespace jittor
+{
+    class RotaryPosEmbOpRunner : public BaseOpRunner
+    {
+
+    protected:
+        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
+    public:
+        RotaryPosEmbOpRunner();
+    };
+
+}
@@ -0,0 +1,116 @@
+import os
+from jittor_utils import env_or_try_find
+import jittor_utils
+import ctypes
+import glob
+import jittor.compiler as compiler
+import jittor as jt
+import math
+import numpy as np
+
+from typing import Union
+from collections.abc import Sequence, Iterable
+
+def stack_cmd(name: str,
+              inputs: list,
+              output_dtypes: list = None,
+              output_shapes: list = None,
+              attr_code: str = "",
+              attr_header: str = "",
+              outputs: list = None):
+    attr_header = "\nnamespace jittor{" + attr_header + "}\n"
+
+    cuda_header = '''
+    #include "acl/aclops/aclops.h"
+    '''
+    outputs_ = []
+    if outputs is not None:
+        outputs_ = outputs
+    else:
+        assert output_dtypes is not None
+        assert output_shapes is not None
+        assert len(output_dtypes) == len(output_shapes)
+        for i in range(len(output_shapes)):
+            outputs_.append(jt.empty(output_shapes[i], output_dtypes[i]))
+    input_code = ''
+    for i in range(len(inputs)):
+        input_code += f"op.add(in{i}, true);\n"
+
+    output_code = ''
+    for i in range(len(outputs_)):
+        output_code += f"op.add(out{i}, false);\n"
+    return jt.code(outputs=outputs_,
+                   inputs=inputs,
+                   cuda_header=attr_header + cuda_header,
+                   cuda_src=f"""
+
+    // aclop
+    {name}OpRunner op;
+    {input_code}
+    {output_code}
+    {attr_code}
+    op.run();""")
+
+class StackACL(jt.Function):
+
+    def __init__(self):
+        super(StackACL, self).__init__()
+
+    def execute(self, input_tensors, dim):
+        if type(input_tensors) is tuple:
+            input_tensors = list(input_tensors)
+        assert type(input_tensors) is list
+        assert -1 * len(input_tensors) - 1 <= dim and dim <= len(
+            input_tensors)
+        for i in range(len(input_tensors)):
+            if input_tensors[i].dtype != input_tensors[0].dtype:
+                raise ValueError(
+                    "All input tensors must have the same dtype")
+            if input_tensors[i].shape != input_tensors[0].shape:
+                raise ValueError(
+                    "All input tensors must have the same shape")
+        self.input = input_tensors
+        input_shape = list(input_tensors[0].shape)
+        output_shape = input_shape[:dim] + [len(input_tensors)
+                                            ] + input_shape[dim:]
+        attr_code = f"""
+        op.jt_name = "stack";
+        ConcatAttr *attr = new ConcatAttr();
+        attr->tensorNum = {len(input_tensors)};
+        attr->dim = {dim};
+        op.op_attr.reset(attr);
+        """
+        self.attr_code = attr_code
+        result = stack_cmd("Stack",
+                           input_tensors,
+                           output_dtypes=[input_tensors[0].dtype],
+                           output_shapes=[output_shape],
+                           attr_code=self.attr_code)[0]
+        return result
+
+    def grad(self, grad_output):
+        grad_inputs = self.split_grad(grad_output, self.input, self.dim)
+        return grad_inputs
+
+    def split_grad(self, grad_output, input_tensors, axis):
+        offset = []
+        shapeVec = []
+        dtypeVec = []
+        for tensor in input_tensors:
+            offset.append(tensor.shape[axis])
+            dtypeVec.append(tensor.dtype)
+            shapeVec.append(tensor.shape)
+
+        attr_code = f"""
+        op.jt_name = "splitwithsize";
+        auto *attr = new SplitWithSizeAttr();
+        attr->splitSize = {{ {", ".join(map(str, offset))} }};
+        attr->dim = {axis};
+        op.op_attr.reset(attr);
+        """
+
+        result = stack_cmd("SplitWithSize", [grad_output],
+                           output_dtypes=dtypeVec,
+                           output_shapes=shapeVec,
+                           attr_code=attr_code)
+        return result
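The shape bookkeeping in StackACL above can be checked without an ACL device, since it is plain list arithmetic (toy values, not from the commit):

    input_shape = [3, 4]        # every stacked tensor shares this shape
    n, dim = 5, 1               # five tensors stacked along dim=1
    output_shape = input_shape[:dim] + [n] + input_shape[dim:]
    assert output_shape == [3, 5, 4]
    # For the backward pass, split_grad collects tensor.shape[axis] of each input
    # into splitSize and hands it to the SplitWithSize runner via SplitWithSizeAttr.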
@@ -0,0 +1,65 @@
+#pragma once
+#include <acl/acl.h>
+#include <acl/acl_op_compiler.h>
+#include <Python.h>
+#include <pystate.h>
+#include <algorithm>
+#include <queue>
+#include <set>
+#include "common.h"
+#include "op.h"
+#include "acl_jittor.h"
+#include "ops/random_op.h"
+#include "ops/reduce_op.h"
+#include "ops/binary_op.h"
+#include "ops/broadcast_to_op.h"
+#include "ops/transpose_op.h"
+#include "ops/array_op.h"
+#include "ops/code_op.h"
+#include "fused_op.h"
+#include "ops/unary_op.h"
+#include "ops/ternary_op.h"
+#include "executor.h"
+#include "misc/cuda_flags.h"
+#include "mem/allocator.h"
+#include "op_compiler.h"
+#include "ops/op_register.h"
+#include "opt/tuner_manager.h"
+#include "utils/str_utils.h"
+#include "aclnn/aclnn.h"
+#include "stack_op_acl.h"
+
+namespace jittor
+{
+    StackOpRunner::StackOpRunner() : BaseOpRunner("Stack")
+    {
+    }
+
+    void StackOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
+    {
+        auto input_num = in_.size();
+        std::vector<aclTensor *> stackTensorList = {};
+        for (int i = 0; i < input_num; i++)
+        {
+            stackTensorList.push_back(inputTensors[i]);
+        }
+        auto stackTensorListInput = aclCreateTensorList(&stackTensorList[0], input_num);
+        auto attr = dynamic_cast<ConcatAttr *>(op_attr.get());
+        ret = aclnnStackGetWorkspaceSize(stackTensorListInput, attr->dim, outputTensors[0], &workspaceSize, &executor);
+
+        checkRet(ret);
+
+        if (workspaceSize > 0)
+        {
+            mallocWorkSpace(workspaceSize);
+        }
+
+        ret = aclnnStack(workspaceAddr, workspaceSize, executor, aclstream);
+        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnStack failed. ERROR: %d\n", name.c_str(), ret); return);
+
+        syncRun();
+
+        return;
+    }
+
+}
@@ -0,0 +1,16 @@
+#pragma once
+#include "utils.h"
+#include "base_op.h"
+
+namespace jittor
+{
+    class StackOpRunner : public BaseOpRunner
+    {
+
+    protected:
+        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
+    public:
+        StackOpRunner();
+    };
+
+}