mirror of https://github.com/Jittor/Jittor

commit 064af9d543 (parent 5a30cd334f)
split cumsum,gather,index
@@ -799,40 +799,42 @@ def change_function():

    def concat(x, dim=0):
        return ConcatACL()(x, dim)

    class GatherACL(Function):
    # class GatherACL(Function):

        def __init__(self):
            super(GatherACL, self).__init__()
    #     def __init__(self):
    #         super(GatherACL, self).__init__()

        def execute(self, input, dim, index):
            self.dim = dim
            self.index = index
            attr_code = f"""
            op.jt_name = "gather";
            GatherAttr *attr = new GatherAttr();
            attr->dim = {dim};
            op.op_attr.reset(attr);
            """
            result = acl_cmd("Gather", [input, index],
                             output_dtypes=[input.dtype],
                             output_shapes=[index.shape],
                             attr_code=attr_code)[0]
            return result
    #     def execute(self, input, dim, index):
    #         self.dim = dim
    #         self.index = index
    #         attr_code = f"""
    #         op.jt_name = "gather";
    #         GatherAttr *attr = new GatherAttr();
    #         attr->dim = {dim};
    #         op.op_attr.reset(attr);
    #         """
    #         result = acl_cmd("Gather", [input, index],
    #                          output_dtypes=[input.dtype],
    #                          output_shapes=[index.shape],
    #                          attr_code=attr_code)[0]
    #         return result

        def grad(self, grad_output):
            tmp = jt.zeros(self.index.shape, dtype=grad_output.dtype)
            attr_code = f"""
            op.jt_name = "scatter";
            ScatterAttr *attr = new ScatterAttr();
            attr->axis = {self.dim};
            attr->reduction = {1};
            op.op_attr.reset(attr);
            """
            grad_input = acl_cmd("Scatter", [tmp, self.index, grad_output],
                                 output_dtypes=[grad_output.dtype],
                                 output_shapes=[tmp.shape],
                                 attr_code=attr_code)[0]
            return grad_input
    #     def grad(self, grad_output):
    #         tmp = jt.zeros(self.index.shape, dtype=grad_output.dtype)
    #         attr_code = f"""
    #         op.jt_name = "scatter";
    #         ScatterAttr *attr = new ScatterAttr();
    #         attr->axis = {self.dim};
    #         attr->reduction = {1};
    #         op.op_attr.reset(attr);
    #         """
    #         grad_input = acl_cmd("Scatter", [tmp, self.index, grad_output],
    #                              output_dtypes=[grad_output.dtype],
    #                              output_shapes=[tmp.shape],
    #                              attr_code=attr_code)[0]
    #         return grad_input

    from .aclops.gather_op import GatherACL

    def gather_acl(input, dim, index):
        return GatherACL()(input, dim, index)

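For reference, here is a minimal NumPy sketch of the gather/scatter-add pairing that GatherACL's grad relies on (an editorial illustration, not part of the commit; the 2-D, dim=0 case and the add-reduction behaviour implied by attr->reduction = 1 are assumptions):

    import numpy as np

    def gather(x, index, dim=0):
        # forward: out[k, j] = x[index[k, j], j] for dim=0
        return np.take_along_axis(x, index, axis=dim)

    def gather_backward(grad_out, index, in_shape):
        # backward: scatter-add every gradient element back to the slot it was
        # read from; repeated indices accumulate, matching an "add" reduction
        grad_in = np.zeros(in_shape, dtype=grad_out.dtype)
        cols = np.broadcast_to(np.arange(index.shape[1])[None, :], index.shape)
        np.add.at(grad_in, (index, cols), grad_out)
        return grad_in

    x = np.array([[1., 2.], [3., 4.]])
    idx = np.array([[0, 0], [1, 0]])
    y = gather(x, idx)                                   # [[1., 2.], [3., 2.]]
    gx = gather_backward(np.ones_like(y), idx, x.shape)  # [[1., 2.], [1., 0.]]
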
@@ -843,52 +845,54 @@ def change_function():
        else:
            return jt.array([False])

    class CumsumACL(Function):
    # class CumsumACL(Function):

        def __init__(self):
            super(CumsumACL, self).__init__()
    #     def __init__(self):
    #         super(CumsumACL, self).__init__()

        def execute(self, input, dim=-1):
            self.dim = dim
            attr_code = f"""
            op.jt_name = "cumsum";
            GatherAttr *attr = new GatherAttr();
            attr->dim = {dim};
            op.op_attr.reset(attr);
            """
            result = acl_cmd("Cumsum", [input],
                             output_dtypes=[input.dtype],
                             output_shapes=[input.shape],
                             attr_code=attr_code)[0]
            return result
    #     def execute(self, input, dim=-1):
    #         self.dim = dim
    #         attr_code = f"""
    #         op.jt_name = "cumsum";
    #         GatherAttr *attr = new GatherAttr();
    #         attr->dim = {dim};
    #         op.op_attr.reset(attr);
    #         """
    #         result = acl_cmd("Cumsum", [input],
    #                          output_dtypes=[input.dtype],
    #                          output_shapes=[input.shape],
    #                          attr_code=attr_code)[0]
    #         return result

        def grad(self, grad_output):
            cumsum_attr_code = f"""
            op.jt_name = "cumsum";
            GatherAttr *attr = new GatherAttr();
            attr->dim = {self.dim};
            op.op_attr.reset(attr);
            """
            flip_attr_code = f"""
            op.jt_name = "flip";
            ReduceAttr *attr = new ReduceAttr();
            attr->axes = {{{self.dim}}};
            attr->prod_dim = {{{1}}};
            op.op_attr.reset(attr);
            """
            flipped_grad_output = acl_cmd("Flip", [grad_output],
                                          output_dtypes=[grad_output.dtype],
                                          output_shapes=[grad_output.shape],
                                          attr_code=flip_attr_code)[0]
            cumulative_grad = acl_cmd("Cumsum", [flipped_grad_output],
                                      output_dtypes=[grad_output.dtype],
                                      output_shapes=[grad_output.shape],
                                      attr_code=cumsum_attr_code)[0]
            grad_input = acl_cmd("Flip", [cumulative_grad],
                                 output_dtypes=[grad_output.dtype],
                                 output_shapes=[grad_output.shape],
                                 attr_code=flip_attr_code)[0]
            return grad_input
    #     def grad(self, grad_output):
    #         cumsum_attr_code = f"""
    #         op.jt_name = "cumsum";
    #         GatherAttr *attr = new GatherAttr();
    #         attr->dim = {self.dim};
    #         op.op_attr.reset(attr);
    #         """
    #         flip_attr_code = f"""
    #         op.jt_name = "flip";
    #         ReduceAttr *attr = new ReduceAttr();
    #         attr->axes = {{{self.dim}}};
    #         attr->prod_dim = {{{1}}};
    #         op.op_attr.reset(attr);
    #         """
    #         flipped_grad_output = acl_cmd("Flip", [grad_output],
    #                                       output_dtypes=[grad_output.dtype],
    #                                       output_shapes=[grad_output.shape],
    #                                       attr_code=flip_attr_code)[0]
    #         cumulative_grad = acl_cmd("Cumsum", [flipped_grad_output],
    #                                   output_dtypes=[grad_output.dtype],
    #                                   output_shapes=[grad_output.shape],
    #                                   attr_code=cumsum_attr_code)[0]
    #         grad_input = acl_cmd("Flip", [cumulative_grad],
    #                              output_dtypes=[grad_output.dtype],
    #                              output_shapes=[grad_output.shape],
    #                              attr_code=flip_attr_code)[0]
    #         return grad_input

    from .aclops.cumsum_op import CumsumACL

    def cumsum_acl(input, dim=-1):
        return CumsumACL()(input, dim)

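As a sanity check on the flip-cumsum-flip rule used in CumsumACL.grad (an editorial sketch, not part of the commit): the backward of an inclusive cumsum along a dim is a reversed cumsum of the incoming gradient along that same dim.

    import numpy as np

    def cumsum_backward(grad_out, dim=-1):
        # d(cumsum(x))/dx applied to grad_out: reverse, cumsum, reverse
        flipped = np.flip(grad_out, axis=dim)
        return np.flip(np.cumsum(flipped, axis=dim), axis=dim)

    g = np.ones((2, 3))
    print(cumsum_backward(g, dim=0))
    # [[2. 2. 2.]
    #  [1. 1. 1.]]  -- matches the expected values in test_cumsum_grad below
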
@@ -898,59 +902,61 @@ def change_function():
        x = cumsum_acl(x, dim=dim)
        return jt.exp(x)

    class IndexACL(Function):
    # class IndexACL(Function):

        def __init__(self):
            super(IndexACL, self).__init__()
    #     def __init__(self):
    #         super(IndexACL, self).__init__()

        def execute(self, inshape: list, dim=None, dtype="int32"):
            # zeros a tensor, shape is inshape, dtype is dtype
            dim_input = dim
            if dim == None:
                dim = [i for i in range(len(inshape))]
            elif type(dim) == int:
                dim = [dim]
            results = []
            extra_data = {}
            extra_data["dim_count"] = len(dim)
    #     def execute(self, inshape: list, dim=None, dtype="int32"):
    #         # zeros a tensor, shape is inshape, dtype is dtype
    #         dim_input = dim
    #         if dim == None:
    #             dim = [i for i in range(len(inshape))]
    #         elif type(dim) == int:
    #             dim = [dim]
    #         results = []
    #         extra_data = {}
    #         extra_data["dim_count"] = len(dim)

            for i, d in enumerate(dim):
                max_len = inshape[d]
    #         for i, d in enumerate(dim):
    #             max_len = inshape[d]

                extra_data[f"dim_{i}_start"] = 0
                extra_data[f"dim_{i}_end"] = max_len
                extra_data[f"dim_{i}_step"] = 1
    #             extra_data[f"dim_{i}_start"] = 0
    #             extra_data[f"dim_{i}_end"] = max_len
    #             extra_data[f"dim_{i}_step"] = 1

                tmp = jt.zeros(max_len, dtype=dtype)
                range_attr_code = f"""
                op.jt_name = "range";
                RangeAttr *attr = new RangeAttr();
                attr->start = data["dim_{i}_start"];
                attr->end = data["dim_{i}_end"];
                attr->step = data["dim_{i}_step"];
                op.op_attr.reset(attr);
                """
                result = acl_cmd_forward("Range", [],
                                         output_dtypes=[tmp.dtype],
                                         output_shapes=[tmp.shape],
                                         attr_code=range_attr_code,
                                         extra_data=extra_data)[0]
                broadcast_dims = list(range(len(inshape)))
                broadcast_dims.remove(d)
                result = jt.broadcast(result,
                                      shape=inshape,
                                      dims=broadcast_dims)
                results.append(result)
    #             tmp = jt.zeros(max_len, dtype=dtype)
    #             range_attr_code = f"""
    #             op.jt_name = "range";
    #             RangeAttr *attr = new RangeAttr();
    #             attr->start = data["dim_{i}_start"];
    #             attr->end = data["dim_{i}_end"];
    #             attr->step = data["dim_{i}_step"];
    #             op.op_attr.reset(attr);
    #             """
    #             result = acl_cmd_forward("Range", [],
    #                                      output_dtypes=[tmp.dtype],
    #                                      output_shapes=[tmp.shape],
    #                                      attr_code=range_attr_code,
    #                                      extra_data=extra_data)[0]
    #             broadcast_dims = list(range(len(inshape)))
    #             broadcast_dims.remove(d)
    #             result = jt.broadcast(result,
    #                                   shape=inshape,
    #                                   dims=broadcast_dims)
    #             results.append(result)

            if len(results) != 1 or dim_input == None:
                return tuple(results)
            elif len(results) == 1 and dim_input != None:
                return results[0]
            else:
                return results
    #         if len(results) != 1 or dim_input == None:
    #             return tuple(results)
    #         elif len(results) == 1 and dim_input != None:
    #             return results[0]
    #         else:
    #             return results

        def grad(self, grad_output):
            return grad_output
    #     def grad(self, grad_output):
    #         return grad_output

    from .aclops.index_op import IndexACL

    def index_acl(inshape: Union[jt.Var, list], dim=None, dtype="int32"):
        if isinstance(inshape, jt.Var):

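For orientation (an editorial sketch, not part of the commit): IndexACL mirrors jt.index, building, for each requested dim, a tensor of the given shape whose entries are the coordinate along that dim; the "Range" kernel produces the arange and jt.broadcast spreads it across the remaining dims. In NumPy terms:

    import numpy as np

    def index_like(shape, dim=None, dtype="int32"):
        dims = range(len(shape)) if dim is None else [dim]
        outs = []
        for d in dims:
            r = np.arange(shape[d], dtype=dtype)        # the "Range" kernel
            view = [1] * len(shape)
            view[d] = shape[d]
            outs.append(np.broadcast_to(r.reshape(view), shape))
        return tuple(outs) if dim is None else outs[0]

    print(index_like((2, 3), dim=1))
    # [[0 1 2]
    #  [0 1 2]]
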
@@ -142,16 +142,16 @@ namespace jittor
        }
    }

    if (jt_name == "range")
    {
        auto attr = dynamic_cast<RangeAttr *>(op_attr.get());
        int64_t startValue = attr->start;
        int64_t endValue = attr->end;
        int64_t stepValue = attr->step;
        start = aclCreateScalar(&startValue, aclDataType::ACL_INT64);
        end = aclCreateScalar(&endValue, aclDataType::ACL_INT64);
        step = aclCreateScalar(&stepValue, aclDataType::ACL_INT64);
    }
    // if (jt_name == "range")
    // {
    //     auto attr = dynamic_cast<RangeAttr *>(op_attr.get());
    //     int64_t startValue = attr->start;
    //     int64_t endValue = attr->end;
    //     int64_t stepValue = attr->step;
    //     start = aclCreateScalar(&startValue, aclDataType::ACL_INT64);
    //     end = aclCreateScalar(&endValue, aclDataType::ACL_INT64);
    //     step = aclCreateScalar(&stepValue, aclDataType::ACL_INT64);
    // }

    if (jt_name == "conv2dbackward")
    {

@@ -366,24 +366,24 @@ namespace jittor
        //     ret = it->second.getWorkspaceSizeFuncConcat(concatTensorListInput, attr->dim, outputTensors[0], &workspaceSize, &executor);
        //     break;
        // }
        case 28:
        {
            auto attr = dynamic_cast<GatherAttr *>(op_attr.get());
            ret = it->second.getWorkspaceSizeFuncGather(inputTensors[0], attr->dim, inputTensors[1], outputTensors[0], &workspaceSize, &executor);
            break;
        }
        case 29:
        {
            auto attr = dynamic_cast<GatherAttr *>(op_attr.get());
            ret = it->second.getWorkspaceSizeFuncCumsum(inputTensors[0], attr->dim, get_dtype(out_[0]->dtype()), outputTensors[0], &workspaceSize, &executor);
            break;
        }
        case 30:
        {
            auto attr = dynamic_cast<ScatterAttr *>(op_attr.get());
            ret = it->second.getWorkspaceSizeFuncScatter(inputTensors[0], attr->axis, inputTensors[1], inputTensors[2], attr->reduction, outputTensors[0], &workspaceSize, &executor);
            break;
        }
        // case 28:
        // {
        //     auto attr = dynamic_cast<GatherAttr *>(op_attr.get());
        //     ret = it->second.getWorkspaceSizeFuncGather(inputTensors[0], attr->dim, inputTensors[1], outputTensors[0], &workspaceSize, &executor);
        //     break;
        // }
        // case 29:
        // {
        //     auto attr = dynamic_cast<GatherAttr *>(op_attr.get());
        //     ret = it->second.getWorkspaceSizeFuncCumsum(inputTensors[0], attr->dim, get_dtype(out_[0]->dtype()), outputTensors[0], &workspaceSize, &executor);
        //     break;
        // }
        // case 30:
        // {
        //     auto attr = dynamic_cast<ScatterAttr *>(op_attr.get());
        //     ret = it->second.getWorkspaceSizeFuncScatter(inputTensors[0], attr->axis, inputTensors[1], inputTensors[2], attr->reduction, outputTensors[0], &workspaceSize, &executor);
        //     break;
        // }
        case 31:
        {
            ret = it->second.getWorkspaceSizeFuncUnaryNonzero(inputTensors[0], outputTensors[0], &workspaceSize, &executor);

@@ -437,11 +437,11 @@ namespace jittor
        //     ret = it->second.getWorkspaceSizeFuncStridedSliceAssignV2(outputTensors[0], inputTensors[0], begins, ends, steps, axes, &workspaceSize, &executor);
        //     break;
        // }
        case 37:
        {
            ret = it->second.getWorkspaceSizeFuncRange(start, end, step, outputTensors[0], &workspaceSize, &executor);
            break;
        }
        // case 37:
        // {
        //     ret = it->second.getWorkspaceSizeFuncRange(start, end, step, outputTensors[0], &workspaceSize, &executor);
        //     break;
        // }
        case 38:
        {
            auto attr = dynamic_cast<LeakyReluAttr *>(op_attr.get());

@@ -13,3 +13,6 @@
#include <acl/aclops/pool_op_acl.h>
#include <acl/aclops/flip_op_acl.h>
#include <acl/aclops/concat_op_acl.h>
#include <acl/aclops/gather_op_acl.h>
#include <acl/aclops/cumsum_op_acl.h>
#include <acl/aclops/index_op_acl.h>

new file: aclops/cumsum_op.py
@@ -0,0 +1,99 @@
import os
from jittor_utils import env_or_try_find
import jittor_utils
import ctypes
import glob
import jittor.compiler as compiler
import jittor as jt
import math
import numpy as np

from typing import Union
from collections.abc import Sequence, Iterable


def cumsum_cmd(name: str,
               inputs: list,
               output_dtypes: list = None,
               output_shapes: list = None,
               attr_code: str = "",
               attr_header: str = "",
               outputs: list = None):
    attr_header = "\nnamespace jittor{" + attr_header + "}\n"

    cuda_header = '''
#include "acl/aclops/aclops.h"
'''
    outputs_ = []
    if outputs is not None:
        outputs_ = outputs
    else:
        assert output_dtypes is not None
        assert output_shapes is not None
        assert len(output_dtypes) == len(output_shapes)
        for i in range(len(output_shapes)):
            outputs_.append(jt.empty(output_shapes[i], output_dtypes[i]))
    input_code = ''
    for i in range(len(inputs)):
        input_code += f"op.add(in{i}, true);\n"

    output_code = ''
    for i in range(len(outputs_)):
        output_code += f"op.add(out{i}, false);\n"
    return jt.code(outputs=outputs_,
                   inputs=inputs,
                   cuda_header=attr_header + cuda_header,
                   cuda_src=f"""

    // aclop
    {name}OpRunner op;
    {input_code}
    {output_code}
    {attr_code}
    op.run();""")


class CumsumACL(jt.Function):

    def __init__(self):
        super(CumsumACL, self).__init__()

    def execute(self, input, dim=-1):
        self.dim = dim
        attr_code = f"""
        op.jt_name = "cumsum";
        GatherAttr *attr = new GatherAttr();
        attr->dim = {dim};
        op.op_attr.reset(attr);
        """
        result = cumsum_cmd("Cumsum", [input],
                            output_dtypes=[input.dtype],
                            output_shapes=[input.shape],
                            attr_code=attr_code)[0]
        return result

    def grad(self, grad_output):
        cumsum_attr_code = f"""
        op.jt_name = "cumsum";
        GatherAttr *attr = new GatherAttr();
        attr->dim = {self.dim};
        op.op_attr.reset(attr);
        """
        flip_attr_code = f"""
        op.jt_name = "flip";
        ReduceAttr *attr = new ReduceAttr();
        attr->axes = {{{self.dim}}};
        attr->prod_dim = {{{1}}};
        op.op_attr.reset(attr);
        """
        flipped_grad_output = cumsum_cmd("Flip", [grad_output],
                                         output_dtypes=[grad_output.dtype],
                                         output_shapes=[grad_output.shape],
                                         attr_code=flip_attr_code)[0]
        cumulative_grad = cumsum_cmd("Cumsum", [flipped_grad_output],
                                     output_dtypes=[grad_output.dtype],
                                     output_shapes=[grad_output.shape],
                                     attr_code=cumsum_attr_code)[0]
        grad_input = cumsum_cmd("Flip", [cumulative_grad],
                                output_dtypes=[grad_output.dtype],
                                output_shapes=[grad_output.shape],
                                attr_code=flip_attr_code)[0]
        return grad_input

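A minimal usage sketch for the new module (editorial, not part of the diff; it assumes an Ascend/ACL build of Jittor and a hypothetical absolute import path for aclops/cumsum_op.py):

    import jittor as jt
    from jittor.extern.acl.aclops.cumsum_op import CumsumACL  # hypothetical path

    with jt.flag_scope(use_acl=1):
        x = jt.array([[1., 2., 3.], [4., 5., 6.]])
        y = CumsumACL()(x, 0)        # column-wise inclusive cumsum
        gx = jt.grad(y.sum(), x)     # exercises the flip-cumsum-flip backward
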
new file: aclops/cumsum_op_acl.cc
@@ -0,0 +1,57 @@
#pragma once
#include <acl/acl.h>
#include <acl/acl_op_compiler.h>
#include <Python.h>
#include <pystate.h>
#include <algorithm>
#include <queue>
#include <set>
#include "common.h"
#include "op.h"
#include "acl_jittor.h"
#include "ops/random_op.h"
#include "ops/reduce_op.h"
#include "ops/binary_op.h"
#include "ops/broadcast_to_op.h"
#include "ops/transpose_op.h"
#include "ops/array_op.h"
#include "ops/code_op.h"
#include "fused_op.h"
#include "ops/unary_op.h"
#include "ops/ternary_op.h"
#include "executor.h"
#include "misc/cuda_flags.h"
#include "mem/allocator.h"
#include "op_compiler.h"
#include "ops/op_register.h"
#include "opt/tuner_manager.h"
#include "utils/str_utils.h"
#include "aclnn/aclnn.h"
#include "cumsum_op_acl.h"

namespace jittor
{
    CumsumOpRunner::CumsumOpRunner() : BaseOpRunner("Cumsum")
    {
    }

    void CumsumOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
    {
        auto attr = dynamic_cast<GatherAttr *>(op_attr.get());
        ret = aclnnCumsumGetWorkspaceSize(inputTensors[0], attr->dim, get_dtype(out_[0]->dtype()), outputTensors[0], &workspaceSize, &executor);

        checkRet(ret);

        if (workspaceSize > 0)
        {
            mallocWorkSpace(workspaceSize);
        }

        ret = aclnnCumsum(workspaceAddr, workspaceSize, executor, aclstream);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnCumsum failed. ERROR: %d\n", name.c_str(), ret); return);

        syncRun();
        return;
    }

}

new file: aclops/cumsum_op_acl.h
@@ -0,0 +1,16 @@
#pragma once
#include "utils.h"
#include "base_op.h"

namespace jittor
{
    class CumsumOpRunner : public BaseOpRunner
    {

    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
    public:
        CumsumOpRunner();
    };

}

new file: aclops/gather_op.py
@@ -0,0 +1,87 @@
import os
from jittor_utils import env_or_try_find
import jittor_utils
import ctypes
import glob
import jittor.compiler as compiler
import jittor as jt
import math
import numpy as np

from typing import Union
from collections.abc import Sequence, Iterable


def gather_cmd(name: str,
               inputs: list,
               output_dtypes: list = None,
               output_shapes: list = None,
               attr_code: str = "",
               attr_header: str = "",
               outputs: list = None):
    attr_header = "\nnamespace jittor{" + attr_header + "}\n"

    cuda_header = '''
#include "acl/aclops/aclops.h"
'''
    outputs_ = []
    if outputs is not None:
        outputs_ = outputs
    else:
        assert output_dtypes is not None
        assert output_shapes is not None
        assert len(output_dtypes) == len(output_shapes)
        for i in range(len(output_shapes)):
            outputs_.append(jt.empty(output_shapes[i], output_dtypes[i]))
    input_code = ''
    for i in range(len(inputs)):
        input_code += f"op.add(in{i}, true);\n"

    output_code = ''
    for i in range(len(outputs_)):
        output_code += f"op.add(out{i}, false);\n"
    return jt.code(outputs=outputs_,
                   inputs=inputs,
                   cuda_header=attr_header + cuda_header,
                   cuda_src=f"""

    // aclop
    {name}OpRunner op;
    {input_code}
    {output_code}
    {attr_code}
    op.run();""")


class GatherACL(jt.Function):

    def __init__(self):
        super(GatherACL, self).__init__()

    def execute(self, input, dim, index):
        self.dim = dim
        self.index = index
        attr_code = f"""
        op.jt_name = "gather";
        GatherAttr *attr = new GatherAttr();
        attr->dim = {dim};
        op.op_attr.reset(attr);
        """
        result = gather_cmd("Gather", [input, index],
                            output_dtypes=[input.dtype],
                            output_shapes=[index.shape],
                            attr_code=attr_code)[0]
        return result

    def grad(self, grad_output):
        tmp = jt.zeros(self.index.shape, dtype=grad_output.dtype)
        attr_code = f"""
        op.jt_name = "scatter";
        ScatterAttr *attr = new ScatterAttr();
        attr->axis = {self.dim};
        attr->reduction = {1};
        op.op_attr.reset(attr);
        """
        grad_input = gather_cmd("Scatter", [tmp, self.index, grad_output],
                                output_dtypes=[grad_output.dtype],
                                output_shapes=[tmp.shape],
                                attr_code=attr_code)[0]
        return grad_input

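A minimal usage sketch (editorial, not part of the diff; same ACL-build assumption and hypothetical import path as the cumsum sketch above):

    import jittor as jt
    from jittor.extern.acl.aclops.gather_op import GatherACL  # hypothetical path

    with jt.flag_scope(use_acl=1):
        x = jt.array([[1., 2.], [3., 4.]])
        idx = jt.array([[0, 0], [1, 0]])
        y = GatherACL()(x, 0, idx)   # same semantics as jt.gather(x, 0, idx)
        gx = jt.grad(y.sum(), x)     # backward scatters with add-reduction
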
new file: aclops/gather_op_acl.cc
@@ -0,0 +1,81 @@
#pragma once
#include <acl/acl.h>
#include <acl/acl_op_compiler.h>
#include <Python.h>
#include <pystate.h>
#include <algorithm>
#include <queue>
#include <set>
#include "common.h"
#include "op.h"
#include "acl_jittor.h"
#include "ops/random_op.h"
#include "ops/reduce_op.h"
#include "ops/binary_op.h"
#include "ops/broadcast_to_op.h"
#include "ops/transpose_op.h"
#include "ops/array_op.h"
#include "ops/code_op.h"
#include "fused_op.h"
#include "ops/unary_op.h"
#include "ops/ternary_op.h"
#include "executor.h"
#include "misc/cuda_flags.h"
#include "mem/allocator.h"
#include "op_compiler.h"
#include "ops/op_register.h"
#include "opt/tuner_manager.h"
#include "utils/str_utils.h"
#include "aclnn/aclnn.h"
#include "gather_op_acl.h"

namespace jittor
{
    GatherOpRunner::GatherOpRunner() : BaseOpRunner("Gather")
    {
    }

    void GatherOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
    {
        auto attr = dynamic_cast<GatherAttr *>(op_attr.get());
        ret = aclnnGatherGetWorkspaceSize(inputTensors[0], attr->dim, inputTensors[1], outputTensors[0], &workspaceSize, &executor);

        checkRet(ret);

        if (workspaceSize > 0)
        {
            mallocWorkSpace(workspaceSize);
        }

        ret = aclnnGather(workspaceAddr, workspaceSize, executor, aclstream);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnGather failed. ERROR: %d\n", name.c_str(), ret); return);

        syncRun();
        return;
    }

    ScatterOpRunner::ScatterOpRunner() : BaseOpRunner("Scatter")
    {
    }

    void ScatterOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
    {
        auto attr = dynamic_cast<ScatterAttr *>(op_attr.get());
        ret = aclnnScatterGetWorkspaceSize(inputTensors[0], attr->axis, inputTensors[1], inputTensors[2], attr->reduction, outputTensors[0], &workspaceSize, &executor);

        checkRet(ret);

        if (workspaceSize > 0)
        {
            mallocWorkSpace(workspaceSize);
        }

        ret = aclnnScatter(workspaceAddr, workspaceSize, executor, aclstream);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnScatter failed. ERROR: %d\n", name.c_str(), ret); return);

        syncRun();
        return;
    }

}

new file: aclops/gather_op_acl.h
@@ -0,0 +1,25 @@
#pragma once
#include "utils.h"
#include "base_op.h"

namespace jittor
{
    class GatherOpRunner : public BaseOpRunner
    {

    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
    public:
        GatherOpRunner();
    };

    class ScatterOpRunner : public BaseOpRunner
    {

    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
    public:
        ScatterOpRunner();
    };
}

new file: aclops/index_op.py
@@ -0,0 +1,107 @@
import os
from jittor_utils import env_or_try_find
import jittor_utils
import ctypes
import glob
import jittor.compiler as compiler
import jittor as jt
import math
import numpy as np

from typing import Union
from collections.abc import Sequence, Iterable


def range_forward(name: str,
                  inputs: list,
                  output_dtypes: list = None,
                  output_shapes: list = None,
                  attr_code: str = "",
                  attr_header: str = "",
                  outputs: list = None,
                  extra_data: dict = {}):
    attr_header = "\nnamespace jittor{" + attr_header + "}\n"

    cuda_header = '''
#include "acl/aclops/aclops.h"
'''
    outputs_ = []
    if outputs is not None:
        outputs_ = outputs
    else:
        assert output_dtypes is not None
        assert output_shapes is not None
        assert len(output_dtypes) == len(output_shapes)
        for i in range(len(output_shapes)):
            outputs_.append(jt.empty(output_shapes[i], output_dtypes[i]))
    input_code = ''
    for i in range(len(inputs)):
        input_code += f"op.add(in{i}, true);\n"

    output_code = ''
    for i in range(len(outputs_)):
        output_code += f"op.add(out{i}, false);\n"
    return jt.code(outputs=outputs_,
                   inputs=inputs,
                   cuda_header=attr_header + cuda_header,
                   cuda_src=f"""
    // aclop
    {name}OpRunner op;
    {input_code}
    op.add(out0, false);
    {attr_code}
    op.run();""",
                   data=extra_data)


class IndexACL(jt.Function):

    def __init__(self):
        super(IndexACL, self).__init__()

    def execute(self, inshape: list, dim=None, dtype="int32"):
        # zeros a tensor, shape is inshape, dtype is dtype
        dim_input = dim
        if dim == None:
            dim = [i for i in range(len(inshape))]
        elif type(dim) == int:
            dim = [dim]
        results = []
        extra_data = {}
        extra_data["dim_count"] = len(dim)

        for i, d in enumerate(dim):
            max_len = inshape[d]

            extra_data[f"dim_{i}_start"] = 0
            extra_data[f"dim_{i}_end"] = max_len
            extra_data[f"dim_{i}_step"] = 1

            tmp = jt.zeros(max_len, dtype=dtype)
            range_attr_code = f"""
            op.jt_name = "range";
            RangeAttr *attr = new RangeAttr();
            attr->start = data["dim_{i}_start"];
            attr->end = data["dim_{i}_end"];
            attr->step = data["dim_{i}_step"];
            op.op_attr.reset(attr);
            """
            result = range_forward("Range", [],
                                   output_dtypes=[tmp.dtype],
                                   output_shapes=[tmp.shape],
                                   attr_code=range_attr_code,
                                   extra_data=extra_data)[0]
            broadcast_dims = list(range(len(inshape)))
            broadcast_dims.remove(d)
            result = jt.broadcast(result,
                                  shape=inshape,
                                  dims=broadcast_dims)
            results.append(result)

        if len(results) != 1 or dim_input == None:
            return tuple(results)
        elif len(results) == 1 and dim_input != None:
            return results[0]
        else:
            return results

    def grad(self, grad_output):
        return grad_output

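A minimal usage sketch (editorial, not part of the diff; same assumptions as the previous sketches):

    import jittor as jt
    from jittor.extern.acl.aclops.index_op import IndexACL  # hypothetical path

    with jt.flag_scope(use_acl=1):
        i, j = IndexACL()([2, 3])          # dim=None: one coordinate tensor per dim
        col = IndexACL()([2, 3], dim=1)    # [[0, 1, 2], [0, 1, 2]]
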
new file: aclops/index_op_acl.cc
@@ -0,0 +1,72 @@
#pragma once
#include <acl/acl.h>
#include <acl/acl_op_compiler.h>
#include <Python.h>
#include <pystate.h>
#include <algorithm>
#include <queue>
#include <set>
#include "common.h"
#include "op.h"
#include "acl_jittor.h"
#include "ops/random_op.h"
#include "ops/reduce_op.h"
#include "ops/binary_op.h"
#include "ops/broadcast_to_op.h"
#include "ops/transpose_op.h"
#include "ops/array_op.h"
#include "ops/code_op.h"
#include "fused_op.h"
#include "ops/unary_op.h"
#include "ops/ternary_op.h"
#include "executor.h"
#include "misc/cuda_flags.h"
#include "mem/allocator.h"
#include "op_compiler.h"
#include "ops/op_register.h"
#include "opt/tuner_manager.h"
#include "utils/str_utils.h"
#include "aclnn/aclnn.h"
#include "index_op_acl.h"

namespace jittor
{
    RangeOpRunner::RangeOpRunner() : BaseOpRunner("Range")
    {
    }

    void RangeOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
    {
        aclScalar *start = nullptr;
        aclScalar *end = nullptr;
        aclScalar *step = nullptr;

        auto attr = dynamic_cast<RangeAttr *>(op_attr.get());
        int64_t startValue = attr->start;
        int64_t endValue = attr->end;
        int64_t stepValue = attr->step;
        start = aclCreateScalar(&startValue, aclDataType::ACL_INT64);
        end = aclCreateScalar(&endValue, aclDataType::ACL_INT64);
        step = aclCreateScalar(&stepValue, aclDataType::ACL_INT64);

        ret = aclnnRangeGetWorkspaceSize(start, end, step, outputTensors[0], &workspaceSize, &executor);

        checkRet(ret);

        if (workspaceSize > 0)
        {
            mallocWorkSpace(workspaceSize);
        }

        ret = aclnnRange(workspaceAddr, workspaceSize, executor, aclstream);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnRange failed. ERROR: %d\n", name.c_str(), ret); return);

        syncRun();

        aclDestroyScalar(start);
        aclDestroyScalar(end);
        aclDestroyScalar(step);
        return;
    }

}

new file: aclops/index_op_acl.h
@@ -0,0 +1,16 @@
#pragma once
#include "utils.h"
#include "base_op.h"

namespace jittor
{
    class RangeOpRunner : public BaseOpRunner
    {

    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
    public:
        RangeOpRunner();
    };

}

@@ -503,6 +503,28 @@ class TestACL(unittest.TestCase):
        res = self.measure_time(lambda: jt.grad(b.sum(), a))
        np.testing.assert_allclose(res.numpy(), [[2, 0], [1, 1]])
        print("test gather grad success")

    @jt.flag_scope(use_acl=1)
    def test_cumsum_1(self):
        a = jt.array([1, 2, 3, 4, 5])
        b = self.measure_time(lambda: jt.cumsum(a))
        np.testing.assert_allclose(b.numpy(), [1, 3, 6, 10, 15])
        print("test cumsum (test case 1) success")

    @jt.flag_scope(use_acl=1)
    def test_cumsum_2(self):
        a = jt.array([[1, 2, 3], [4, 5, 6]])
        b = self.measure_time(lambda: jt.cumsum(a, dim=0))
        np.testing.assert_allclose(b.numpy(), [[1, 2, 3], [5, 7, 9]])
        print("test cumsum (test case 2) success")

    @jt.flag_scope(use_acl=1)
    def test_cumsum_grad(self):
        a = jt.array([[1., 2., 3.], [4., 5., 6.]])
        b = jt.cumsum(a, dim=0)
        res = self.measure_time(lambda: jt.grad(b.sum(), a))
        np.testing.assert_allclose(res.numpy(), [[2., 2., 2.], [1., 1., 1.]])
        print("test cumsum grad success")

    @jt.flag_scope(use_acl=1)
    def test_any_1(self):