split where, scatter, floor

This commit is contained in:
Exusial 2024-12-14 10:36:37 +08:00
parent 064af9d543
commit 8762352c64
13 changed files with 581 additions and 168 deletions


@@ -834,7 +834,7 @@ def change_function():
    #                              attr_code=attr_code)[0]
    #         return grad_input
    from .aclops.gather_op import GatherACL
    from .aclops.gather_scatter_op import GatherACL
    def gather_acl(input, dim, index):
        return GatherACL()(input, dim, index)
@@ -963,194 +963,201 @@ def change_function():
            inshape = inshape.shape
        return IndexACL()(inshape, dim, dtype)
    class ScatterACL(Function):
    # class ScatterACL(Function):
        def __init__(self):
            super(ScatterACL, self).__init__()
    #     def __init__(self):
    #         super(ScatterACL, self).__init__()
        def execute(self, input, dim, index, src, reduce='void'):
            self.dim = dim
            self.index = index
            self.reduce = reduce
            attr_code = f"""
            op.jt_name = "scatter";
            ScatterAttr *attr = new ScatterAttr();
            attr->axis = {dim};
            attr->reduction = {1 if reduce == 'add' else 2 if reduce == 'mul' else 0};
            op.op_attr.reset(attr);
            """
            result = acl_cmd("Scatter", [input, self.index, src],
                             output_dtypes=[input.dtype],
                             output_shapes=[input.shape],
                             attr_code=attr_code)[0]
            return result
    #     def execute(self, input, dim, index, src, reduce='void'):
    #         self.dim = dim
    #         self.index = index
    #         self.reduce = reduce
    #         attr_code = f"""
    #         op.jt_name = "scatter";
    #         ScatterAttr *attr = new ScatterAttr();
    #         attr->axis = {dim};
    #         attr->reduction = {1 if reduce == 'add' else 2 if reduce == 'mul' else 0};
    #         op.op_attr.reset(attr);
    #         """
    #         result = acl_cmd("Scatter", [input, self.index, src],
    #                          output_dtypes=[input.dtype],
    #                          output_shapes=[input.shape],
    #                          attr_code=attr_code)[0]
    #         return result
        def grad(self, grad_output):
            attr_code = f"""
            op.jt_name = "gather";
            GatherAttr *attr = new GatherAttr();
            attr->dim = {self.dim};
            op.op_attr.reset(attr);
            """
            grad_input = acl_cmd("Gather", [grad_output, self.index],
                                 output_dtypes=[grad_output.dtype],
                                 output_shapes=[self.index.shape],
                                 attr_code=attr_code)[0]
            return grad_output, None, None, grad_input
    #     def grad(self, grad_output):
    #         attr_code = f"""
    #         op.jt_name = "gather";
    #         GatherAttr *attr = new GatherAttr();
    #         attr->dim = {self.dim};
    #         op.op_attr.reset(attr);
    #         """
    #         grad_input = acl_cmd("Gather", [grad_output, self.index],
    #                              output_dtypes=[grad_output.dtype],
    #                              output_shapes=[self.index.shape],
    #                              attr_code=attr_code)[0]
    #         return grad_output, None, None, grad_input
    from .aclops.gather_scatter_op import ScatterACL
    def scatter_acl(input, dim, index, src, reduce='void'):
        return ScatterACL()(input, dim, index, src, reduce)
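    # The reduction attribute above maps reduce='add' to 1, 'mul' to 2, and
    # anything else to 0 (plain overwrite). A NumPy sketch of the intended
    # scatter semantics for dim=0 (illustration only; `scatter_ref` is a
    # hypothetical helper, not part of this diff):
    def scatter_ref(input, dim, index, src, reduce='void'):
        import numpy as np
        assert dim == 0, "sketch only covers dim=0 on 2-D arrays"
        out = np.array(input, copy=True)
        for i in range(index.shape[0]):
            for j in range(index.shape[1]):
                if reduce == 'add':
                    out[index[i, j], j] += src[i, j]
                elif reduce == 'mul':
                    out[index[i, j], j] *= src[i, j]
                else:
                    out[index[i, j], j] = src[i, j]
        return out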
    class WhereACL(Function):
    # class WhereACL(Function):
        def __init__(self):
            super(WhereACL, self).__init__()
    #     def __init__(self):
    #         super(WhereACL, self).__init__()
        def execute(self, condition, x=None, y=None):
            # case 1 (unary)
            if y is None:
                self.unary = True
    #     def execute(self, condition, x=None, y=None):
    #         # case 1 (unary)
    #         if y is None:
    #             self.unary = True
                # In this case, `condition` is the input, while `x` is the dtype
                result = nonzero_acl(condition).t()
                result = [result[i] for i in range(result.size(0))]
                return result
                # The return value should be a tuple, but even if we set it to a tuple here, it will be converted to a list in `Function.__call__`.
    #             # In this case, `condition` is the input, while `x` is the dtype
    #             result = nonzero_acl(condition).t()
    #             result = [result[i] for i in range(result.size(0))]
    #             return result
    #             # The return value should be a tuple, but even if we set it to a tuple here, it will be converted to a list in `Function.__call__`.
            # case 2 (cond ? x : y)
            else:
                self.condition = condition
    #         # case 2 (cond ? x : y)
    #         else:
    #             self.condition = condition
                if x.dtype != y.dtype:
                    if x.dtype == jt.float32:
                        y = y.float32()
                    elif y.dtype == jt.float32:
                        x = x.float32()
                    else:
                        x = x.to(y.dtype)
    #             if x.dtype != y.dtype:
    #                 if x.dtype == jt.float32:
    #                     y = y.float32()
    #                 elif y.dtype == jt.float32:
    #                     x = x.float32()
    #                 else:
    #                     x = x.to(y.dtype)
                self.x = x
                self.y = y
    #             self.x = x
    #             self.y = y
                result = acl_cmd("Where", [condition, x, y],
                                 output_dtypes=[x.dtype],
                                 output_shapes=[x.shape],
                                 attr_code="op.jt_name=\"where\";")[0]
                return result
    #             result = acl_cmd("Where", [condition, x, y],
    #                              output_dtypes=[x.dtype],
    #                              output_shapes=[x.shape],
    #                              attr_code="op.jt_name=\"where\";")[0]
    #             return result
        def grad(self, grad_output):
            if hasattr(self, 'unary') and self.unary:
                return grad_output
            else:
                tmp = jt.zeros(grad_output.shape, dtype=grad_output.dtype)
                grad_x = acl_cmd("Where", [self.condition, grad_output, tmp],
                                 output_dtypes=[self.x.dtype],
                                 output_shapes=[self.x.shape],
                                 attr_code="op.jt_name=\"where\";")[0]
    #     def grad(self, grad_output):
    #         if hasattr(self, 'unary') and self.unary:
    #             return grad_output
    #         else:
    #             tmp = jt.zeros(grad_output.shape, dtype=grad_output.dtype)
    #             grad_x = acl_cmd("Where", [self.condition, grad_output, tmp],
    #                              output_dtypes=[self.x.dtype],
    #                              output_shapes=[self.x.shape],
    #                              attr_code="op.jt_name=\"where\";")[0]
                grad_y = acl_cmd("Where", [self.condition, tmp, grad_output],
                                 output_dtypes=[self.y.dtype],
                                 output_shapes=[self.y.shape],
                                 attr_code="op.jt_name=\"where\";")[0]
                return grad_output, grad_x, grad_y
    #             grad_y = acl_cmd("Where", [self.condition, tmp, grad_output],
    #                              output_dtypes=[self.y.dtype],
    #                              output_shapes=[self.y.shape],
    #                              attr_code="op.jt_name=\"where\";")[0]
    #             return grad_output, grad_x, grad_y
    from .aclops.where_op import WhereACL
    def where_acl(condition, x=None, y=None):
        return WhereACL()(condition, x, y)
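    # The backward rule above, sketched in NumPy (illustration only;
    # `where_grad_ref` is hypothetical): grad_x keeps grad_output where the
    # condition holds, grad_y keeps it where the condition fails, which is
    # exactly what the two masked Where calls in WhereACL.grad compute.
    def where_grad_ref(condition, grad_output):
        import numpy as np
        cond = np.asarray(condition, dtype=bool)
        zeros = np.zeros_like(grad_output)
        return np.where(cond, grad_output, zeros), np.where(cond, zeros, grad_output)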
    class NonzeroACL(Function):
    # class NonzeroACL(Function):
        def __init__(self):
            super(NonzeroACL, self).__init__()
    #     def __init__(self):
    #         super(NonzeroACL, self).__init__()
        def execute(self, x):
            attr_code = f"""
            op.jt_name = "nonzero";
            """
            nonzero_cnt = (x != 0.0).sum().item()
    #     def execute(self, x):
    #         attr_code = f"""
    #         op.jt_name = "nonzero";
    #         """
    #         nonzero_cnt = (x != 0.0).sum().item()
            result = acl_cmd("Nonzero", [x],
                             output_dtypes=['int64'],
                             output_shapes=[(nonzero_cnt, x.ndim)],
                             attr_code=attr_code)[0]
    #         result = acl_cmd("Nonzero", [x],
    #                          output_dtypes=['int64'],
    #                          output_shapes=[(nonzero_cnt, x.ndim)],
    #                          attr_code=attr_code)[0]
            return result
    #         return result
        def grad(self, grad_output):
            return grad_output
    #     def grad(self, grad_output):
    #         return grad_output
    from .aclops.where_op import NonzeroACL
    def nonzero_acl(x):
        return NonzeroACL()(x)
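    # Nonzero has a data-dependent output shape, so the (x != 0.0).sum().item()
    # call above forces a host sync to size the output before the op launches.
    # Shape sanity check in NumPy terms (illustration only; `nonzero_shape_ref`
    # is hypothetical): the op returns one (cnt, ndim) index matrix rather than
    # NumPy's tuple of ndim arrays.
    def nonzero_shape_ref(x):
        import numpy as np
        x = np.asarray(x)
        return (int((x != 0).sum()), x.ndim)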
    class FloorIntACL(Function):
    # class FloorIntACL(Function):
        def __init__(self):
            super(FloorIntACL, self).__init__()
    #     def __init__(self):
    #         super(FloorIntACL, self).__init__()
        def execute(self, input):
            self.shape = input.shape
            result = acl_cmd("Floor", [input],
                             output_dtypes=[input.dtype],
                             output_shapes=[input.shape],
                             attr_code="op.jt_name=\"floor\";")[0]
            return result
    #     def execute(self, input):
    #         self.shape = input.shape
    #         result = acl_cmd("Floor", [input],
    #                          output_dtypes=[input.dtype],
    #                          output_shapes=[input.shape],
    #                          attr_code="op.jt_name=\"floor\";")[0]
    #         return result
        def grad(self, grad_output):
            return jt.zeros(self.shape, dtype=grad_output.dtype)
    #     def grad(self, grad_output):
    #         return jt.zeros(self.shape, dtype=grad_output.dtype)
    from .aclops.floor_op import FloorIntACL
    def floor_int_acl(x):
        return FloorIntACL()(x)
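    # floor() is piecewise constant, so its derivative is zero almost
    # everywhere; FloorIntACL.grad therefore returns zeros instead of
    # propagating grad_output. NumPy check (illustration only;
    # `floor_grad_ref` is hypothetical):
    def floor_grad_ref(x):
        import numpy as np
        return np.floor(x), np.zeros_like(x)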
    def caculate_shape(tensors):
        if isinstance(tensors, jt.Var):
            # tensors = tensors[0]
            return tensors.shape
        elif isinstance(tensors, (int, float)):
            return []
        elif isinstance(tensors, (list, tuple)):
            # return [caculate_shape(tensor) for tensor in tensors]
            sub_shape = caculate_shape(tensors[0])
            return [len(tensors)] + sub_shape
        else:
            assert False, f"not implemented for {type(tensors)}"
    # def caculate_shape(tensors):
    #     if isinstance(tensors, jt.Var):
    #         # tensors = tensors[0]
    #         return tensors.shape
    #     elif isinstance(tensors, (int, float)):
    #         return []
    #     elif isinstance(tensors, (list, tuple)):
    #         # return [caculate_shape(tensor) for tensor in tensors]
    #         sub_shape = caculate_shape(tensors[0])
    #         return [len(tensors)] + sub_shape
    #     else:
    #         assert False, f"not implemented for {type(tensors)}"
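    # Worked example for the helper above (illustration only): only the first
    # element of a list is inspected, so ragged inputs go undetected.
    #   caculate_shape(3.0)              -> []
    #   caculate_shape([[1, 2], [3, 4]]) -> [2, 2]
    #   caculate_shape(jt.zeros((2, 4))) -> [2, 4]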
    def can_broadcast_and_shape(shape1, shape2):
        """
        Check whether two tensors can be broadcast together and return the broadcast shape.
    # def can_broadcast_and_shape(shape1, shape2):
    #     """
    #     Check whether two tensors can be broadcast together and return the broadcast shape.
        Args:
        - shape1: shape of the first tensor (tuple or list)
        - shape2: shape of the second tensor (tuple or list)
    #     Args:
    #     - shape1: shape of the first tensor (tuple or list)
    #     - shape2: shape of the second tensor (tuple or list)
        Returns:
        - can_broadcast: bool indicating whether broadcasting is possible
        - broadcast_shape: the broadcast shape if broadcasting is possible, otherwise None
        """
        # Convert the shapes to tuples in case the inputs are lists
        shape1 = tuple(shape1)
        shape2 = tuple(shape2)
    #     Returns:
    #     - can_broadcast: bool indicating whether broadcasting is possible
    #     - broadcast_shape: the broadcast shape if broadcasting is possible, otherwise None
    #     """
    #     # Convert the shapes to tuples in case the inputs are lists
    #     shape1 = tuple(shape1)
    #     shape2 = tuple(shape2)
        # Pad the shorter shape with leading 1s so both have the same length
        len1, len2 = len(shape1), len(shape2)
        if len1 < len2:
            shape1 = (1, ) * (len2 - len1) + shape1
        elif len2 < len1:
            shape2 = (1, ) * (len1 - len2) + shape2
    #     # Pad the shorter shape with leading 1s so both have the same length
    #     len1, len2 = len(shape1), len(shape2)
    #     if len1 < len2:
    #         shape1 = (1, ) * (len2 - len1) + shape1
    #     elif len2 < len1:
    #         shape2 = (1, ) * (len1 - len2) + shape2
        broadcast_shape = []
    #     broadcast_shape = []
        # Check each dimension, starting from the last
        for dim1, dim2 in zip(shape1, shape2):
            if dim1 == dim2:
                broadcast_shape.append(dim1)
            elif dim1 == 1:
                broadcast_shape.append(dim2)
            elif dim2 == 1:
                broadcast_shape.append(dim1)
            else:
                # Incompatible at this dimension, so broadcasting is impossible
                return False, None
    #     # Check each dimension, starting from the last
    #     for dim1, dim2 in zip(shape1, shape2):
    #         if dim1 == dim2:
    #             broadcast_shape.append(dim1)
    #         elif dim1 == 1:
    #             broadcast_shape.append(dim2)
    #         elif dim2 == 1:
    #             broadcast_shape.append(dim1)
    #         else:
    #             # Incompatible at this dimension, so broadcasting is impossible
    #             return False, None
        return True, tuple(broadcast_shape)
    #     return True, tuple(broadcast_shape)
    # class GetItemACL(Function):
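For example, can_broadcast_and_shape((3, 1, 5), (4, 1)) pads the second shape to (1, 4, 1) and returns (True, (3, 4, 5)), while (3, 2) against (4,) fails on the last dimension and returns (False, None). This matches NumPy's broadcasting rule, which can serve as a cross-check (a minimal sketch, assuming NumPy >= 1.20 for broadcast_shapes):

    import numpy as np
    assert np.broadcast_shapes((3, 1, 5), (4, 1)) == (3, 4, 5)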


@@ -257,16 +257,16 @@ namespace jittor
        //     ret = it->second.getWorkspaceSizeFuncRandom(outputTensors[0], 0.0, 1.0, attr->seed, attr->offset, &workspaceSize, &executor);
        //     break;
        // }
        case 15:
        {
            ret = it->second.getWorkspaceSizeFuncUnaryNonzero(inputTensors[0], outputTensors[0], &workspaceSize, &executor);
            break;
        }
        case 17:
        {
            ret = it->second.getWorkspaceSizeFuncSelect(inputTensors[0], inputTensors[1], inputTensors[2], outputTensors[0], &workspaceSize, &executor);
            break;
        }
        // case 15:
        // {
        //     ret = it->second.getWorkspaceSizeFuncUnaryNonzero(inputTensors[0], outputTensors[0], &workspaceSize, &executor);
        //     break;
        // }
        // case 17:
        // {
        //     ret = it->second.getWorkspaceSizeFuncSelect(inputTensors[0], inputTensors[1], inputTensors[2], outputTensors[0], &workspaceSize, &executor);
        //     break;
        // }
        case 18:
        {
            auto attr = dynamic_cast<TriuAttr *>(op_attr.get());
@@ -384,11 +384,11 @@ namespace jittor
        //     ret = it->second.getWorkspaceSizeFuncScatter(inputTensors[0], attr->axis, inputTensors[1], inputTensors[2], attr->reduction, outputTensors[0], &workspaceSize, &executor);
        //     break;
        // }
        case 31:
        {
            ret = it->second.getWorkspaceSizeFuncUnaryNonzero(inputTensors[0], outputTensors[0], &workspaceSize, &executor);
            break;
        }
        // case 31:
        // {
        //     ret = it->second.getWorkspaceSizeFuncUnaryNonzero(inputTensors[0], outputTensors[0], &workspaceSize, &executor);
        //     break;
        // }
        // case 32:
        // {
        //     auto indexTensorList = aclCreateTensorList(&inputTensors[1], input_num - 1);


@@ -13,6 +13,8 @@
#include <acl/aclops/pool_op_acl.h>
#include <acl/aclops/flip_op_acl.h>
#include <acl/aclops/concat_op_acl.h>
#include <acl/aclops/gather_op_acl.h>
#include <acl/aclops/gather_scatter_op_acl.h>
#include <acl/aclops/cumsum_op_acl.h>
#include <acl/aclops/index_op_acl.h>
#include <acl/aclops/where_op_acl.h>
#include <acl/aclops/floor_op_acl.h>


@@ -0,0 +1,68 @@
import os
from jittor_utils import env_or_try_find
import jittor_utils
import ctypes
import glob
import jittor.compiler as compiler
import jittor as jt
import math
import numpy as np
from typing import Union
from collections.abc import Sequence, Iterable


def floor_cmd(name: str,
              inputs: list,
              output_dtypes: list = None,
              output_shapes: list = None,
              attr_code: str = "",
              attr_header: str = "",
              outputs: list = None):
    attr_header = "\nnamespace jittor{" + attr_header + "}\n"
    cuda_header = '''
    #include "acl/aclops/aclops.h"
    '''
    outputs_ = []
    if outputs is not None:
        outputs_ = outputs
    else:
        assert output_dtypes is not None
        assert output_shapes is not None
        assert len(output_dtypes) == len(output_shapes)
        for i in range(len(output_shapes)):
            outputs_.append(jt.empty(output_shapes[i], output_dtypes[i]))
    input_code = ''
    for i in range(len(inputs)):
        input_code += f"op.add(in{i}, true);\n"
    output_code = ''
    for i in range(len(outputs_)):
        output_code += f"op.add(out{i}, false);\n"
    return jt.code(outputs=outputs_,
                   inputs=inputs,
                   cuda_header=attr_header + cuda_header,
                   cuda_src=f"""
    // aclop
    {name}OpRunner op;
    {input_code}
    {output_code}
    {attr_code}
    op.run();""")


class FloorIntACL(jt.Function):

    def __init__(self):
        super(FloorIntACL, self).__init__()

    def execute(self, input):
        self.shape = input.shape
        result = floor_cmd("Floor", [input],
                           output_dtypes=[input.dtype],
                           output_shapes=[input.shape],
                           attr_code="op.jt_name=\"floor\";")[0]
        return result

    def grad(self, grad_output):
        return jt.zeros(self.shape, dtype=grad_output.dtype)
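Assuming an Ascend build of Jittor with the ACL backend enabled, the new op can be smoke-tested directly; the import path below is illustrative and depends on where the aclops package is installed:

    import jittor as jt
    from floor_op import FloorIntACL  # path assumed

    x = jt.array([1.7, -0.3, 2.0])
    y = FloorIntACL()(x)        # elementwise floor: [1., -1., 2.]
    g = jt.grad(y.sum(), x)     # all zeros, per FloorIntACL.grad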


@@ -0,0 +1,56 @@
#pragma once
#include <acl/acl.h>
#include <acl/acl_op_compiler.h>
#include <Python.h>
#include <pystate.h>
#include <algorithm>
#include <queue>
#include <set>
#include "common.h"
#include "op.h"
#include "acl_jittor.h"
#include "ops/random_op.h"
#include "ops/reduce_op.h"
#include "ops/binary_op.h"
#include "ops/broadcast_to_op.h"
#include "ops/transpose_op.h"
#include "ops/array_op.h"
#include "ops/code_op.h"
#include "fused_op.h"
#include "ops/unary_op.h"
#include "ops/ternary_op.h"
#include "executor.h"
#include "misc/cuda_flags.h"
#include "mem/allocator.h"
#include "op_compiler.h"
#include "ops/op_register.h"
#include "opt/tuner_manager.h"
#include "utils/str_utils.h"
#include "aclnn/aclnn.h"
#include "floor_op_acl.h"
namespace jittor
{
    FloorOpRunner::FloorOpRunner() : BaseOpRunner("Floor")
    {
    }

    void FloorOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
    {
        ret = aclnnFloorGetWorkspaceSize(inputTensors[0], outputTensors[0], &workspaceSize, &executor);
        checkRet(ret);

        if (workspaceSize > 0)
        {
            mallocWorkSpace(workspaceSize);
        }

        ret = aclnnFloor(workspaceAddr, workspaceSize, executor, aclstream);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnFloor failed. ERROR: %d\n", name.c_str(), ret); return);

        syncRun();
        return;
    }
}


@@ -0,0 +1,15 @@
#pragma once
#include "utils.h"
#include "base_op.h"
namespace jittor
{
    class FloorOpRunner : public BaseOpRunner
    {
    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        FloorOpRunner();
    };
}


@@ -11,7 +11,7 @@ import numpy as np
from typing import Union
from collections.abc import Sequence, Iterable
def gather_cmd(name: str,
def gather_scatter_cmd(name: str,
                       inputs: list,
                       output_dtypes: list = None,
                       output_shapes: list = None,
@@ -65,7 +65,7 @@ class GatherACL(jt.Function):
        attr->dim = {dim};
        op.op_attr.reset(attr);
        """
        result = gather_cmd("Gather", [input, index],
        result = gather_scatter_cmd("Gather", [input, index],
                                    output_dtypes=[input.dtype],
                                    output_shapes=[index.shape],
                                    attr_code=attr_code)[0]
@@ -80,8 +80,43 @@ class GatherACL(jt.Function):
        attr->reduction = {1};
        op.op_attr.reset(attr);
        """
        grad_input = gather_cmd("Scatter", [tmp, self.index, grad_output],
        grad_input = gather_scatter_cmd("Scatter", [tmp, self.index, grad_output],
                                        output_dtypes=[grad_output.dtype],
                                        output_shapes=[tmp.shape],
                                        attr_code=attr_code)[0]
        return grad_input
        return grad_input
class ScatterACL(jt.Function):

    def __init__(self):
        super(ScatterACL, self).__init__()

    def execute(self, input, dim, index, src, reduce='void'):
        self.dim = dim
        self.index = index
        self.reduce = reduce
        attr_code = f"""
        op.jt_name = "scatter";
        ScatterAttr *attr = new ScatterAttr();
        attr->axis = {dim};
        attr->reduction = {1 if reduce == 'add' else 2 if reduce == 'mul' else 0};
        op.op_attr.reset(attr);
        """
        result = gather_scatter_cmd("Scatter", [input, self.index, src],
                                    output_dtypes=[input.dtype],
                                    output_shapes=[input.shape],
                                    attr_code=attr_code)[0]
        return result

    def grad(self, grad_output):
        attr_code = f"""
        op.jt_name = "gather";
        GatherAttr *attr = new GatherAttr();
        attr->dim = {self.dim};
        op.op_attr.reset(attr);
        """
        grad_input = gather_scatter_cmd("Gather", [grad_output, self.index],
                                        output_dtypes=[grad_output.dtype],
                                        output_shapes=[self.index.shape],
                                        attr_code=attr_code)[0]
        return grad_output, None, None, grad_input
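Gather and scatter are adjoint to each other, which is why the two classes can share one gather_scatter_cmd helper: GatherACL.grad scatters grad_output back with reduction=1 (add), and ScatterACL.grad gathers grad_output at self.index. A NumPy sketch of the 2-D gather semantics being assumed (illustration only; `gather_ref` is hypothetical):

    import numpy as np

    def gather_ref(input, dim, index):
        # dim=0: out[i][j] = input[index[i][j]][j]
        # dim=1: out[i][j] = input[i][index[i][j]]
        out = np.empty(index.shape, dtype=input.dtype)
        for i in range(index.shape[0]):
            for j in range(index.shape[1]):
                out[i, j] = input[index[i, j], j] if dim == 0 else input[i, index[i, j]]
        return out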


@@ -27,7 +27,7 @@
#include "opt/tuner_manager.h"
#include "utils/str_utils.h"
#include "aclnn/aclnn.h"
#include "gather_op_acl.h"
#include "gather_scatter_op_acl.h"
namespace jittor
{


@@ -8,8 +8,9 @@ namespace jittor
    {
    protected:
        string name;
        string name; // specific to the random op
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        RandomOpRunner();
        RandomOpRunner(const string &name);


@@ -0,0 +1,126 @@
import os
from jittor_utils import env_or_try_find
import jittor_utils
import ctypes
import glob
import jittor.compiler as compiler
import jittor as jt
import math
import numpy as np
from typing import Union
from collections.abc import Sequence, Iterable


def where_cmd(name: str,
              inputs: list,
              output_dtypes: list = None,
              output_shapes: list = None,
              attr_code: str = "",
              attr_header: str = "",
              outputs: list = None):
    attr_header = "\nnamespace jittor{" + attr_header + "}\n"
    cuda_header = '''
    #include "acl/aclops/aclops.h"
    '''
    outputs_ = []
    if outputs is not None:
        outputs_ = outputs
    else:
        assert output_dtypes is not None
        assert output_shapes is not None
        assert len(output_dtypes) == len(output_shapes)
        for i in range(len(output_shapes)):
            outputs_.append(jt.empty(output_shapes[i], output_dtypes[i]))
    input_code = ''
    for i in range(len(inputs)):
        input_code += f"op.add(in{i}, true);\n"
    output_code = ''
    for i in range(len(outputs_)):
        output_code += f"op.add(out{i}, false);\n"
    return jt.code(outputs=outputs_,
                   inputs=inputs,
                   cuda_header=attr_header + cuda_header,
                   cuda_src=f"""
    // aclop
    {name}OpRunner op;
    {input_code}
    {output_code}
    {attr_code}
    op.run();""")


class NonzeroACL(jt.Function):

    def __init__(self):
        super(NonzeroACL, self).__init__()

    def execute(self, x):
        attr_code = f"""
        op.jt_name = "nonzero";
        """
        nonzero_cnt = (x != 0.0).sum().item()
        result = where_cmd("Nonzero", [x],
                           output_dtypes=['int64'],
                           output_shapes=[(nonzero_cnt, x.ndim)],
                           attr_code=attr_code)[0]
        return result

    def grad(self, grad_output):
        return grad_output


class WhereACL(jt.Function):

    def __init__(self):
        super(WhereACL, self).__init__()

    def execute(self, condition, x=None, y=None):
        # case 1 (unary)
        if y is None:
            self.unary = True
            # In this case, `condition` is the input, while `x` is the dtype
            result = NonzeroACL()(condition).t()
            result = [result[i] for i in range(result.size(0))]
            return result
            # The return value should be a tuple, but even if we set it to a tuple here, it will be converted to a list in `Function.__call__`.
        # case 2 (cond ? x : y)
        else:
            self.condition = condition
            if x.dtype != y.dtype:
                if x.dtype == jt.float32:
                    y = y.float32()
                elif y.dtype == jt.float32:
                    x = x.float32()
                else:
                    x = x.to(y.dtype)
            self.x = x
            self.y = y
            result = where_cmd("Where", [condition, x, y],
                               output_dtypes=[x.dtype],
                               output_shapes=[x.shape],
                               attr_code="op.jt_name=\"where\";")[0]
            return result

    def grad(self, grad_output):
        if hasattr(self, 'unary') and self.unary:
            return grad_output
        else:
            tmp = jt.zeros(grad_output.shape, dtype=grad_output.dtype)
            grad_x = where_cmd("Where", [self.condition, grad_output, tmp],
                               output_dtypes=[self.x.dtype],
                               output_shapes=[self.x.shape],
                               attr_code="op.jt_name=\"where\";")[0]
            grad_y = where_cmd("Where", [self.condition, tmp, grad_output],
                               output_dtypes=[self.y.dtype],
                               output_shapes=[self.y.shape],
                               attr_code="op.jt_name=\"where\";")[0]
            return grad_output, grad_x, grad_y
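A minimal usage sketch for both modes, assuming an ACL-enabled Jittor build (the import path is illustrative):

    import jittor as jt
    from where_op import WhereACL  # path assumed

    # ternary: out[i] = x[i] if cond[i] else y[i]
    out = WhereACL()(jt.array([True, False, True]),
                     jt.array([1.0, 2.0, 3.0]),
                     jt.array([9.0, 9.0, 9.0]))       # -> [1., 9., 3.]

    # unary: degenerates to nonzero, one index array per dimension
    idx = WhereACL()(jt.array([0.0, 5.0, 0.0, 7.0]))  # idx[0] -> [1, 3]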


@@ -0,0 +1,79 @@
#pragma once
#include <acl/acl.h>
#include <acl/acl_op_compiler.h>
#include <Python.h>
#include <pystate.h>
#include <algorithm>
#include <queue>
#include <set>
#include "common.h"
#include "op.h"
#include "acl_jittor.h"
#include "ops/random_op.h"
#include "ops/reduce_op.h"
#include "ops/binary_op.h"
#include "ops/broadcast_to_op.h"
#include "ops/transpose_op.h"
#include "ops/array_op.h"
#include "ops/code_op.h"
#include "fused_op.h"
#include "ops/unary_op.h"
#include "ops/ternary_op.h"
#include "executor.h"
#include "misc/cuda_flags.h"
#include "mem/allocator.h"
#include "op_compiler.h"
#include "ops/op_register.h"
#include "opt/tuner_manager.h"
#include "utils/str_utils.h"
#include "aclnn/aclnn.h"
#include "where_op_acl.h"
namespace jittor
{
    WhereOpRunner::WhereOpRunner() : BaseOpRunner("Where")
    {
    }

    void WhereOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
    {
        ret = aclnnSWhereGetWorkspaceSize(inputTensors[0], inputTensors[1], inputTensors[2], outputTensors[0], &workspaceSize, &executor);
        checkRet(ret);

        if (workspaceSize > 0)
        {
            mallocWorkSpace(workspaceSize);
        }

        ret = aclnnSWhere(workspaceAddr, workspaceSize, executor, aclstream);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnSWhere failed. ERROR: %d\n", name.c_str(), ret); return);

        syncRun();
        return;
    }

    NonzeroOpRunner::NonzeroOpRunner() : BaseOpRunner("Nonzero")
    {
    }

    void NonzeroOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
    {
        ret = aclnnNonzeroGetWorkspaceSize(inputTensors[0], outputTensors[0], &workspaceSize, &executor);
        checkRet(ret);

        if (workspaceSize > 0)
        {
            mallocWorkSpace(workspaceSize);
        }

        ret = aclnnNonzero(workspaceAddr, workspaceSize, executor, aclstream);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnNonzero failed. ERROR: %d\n", name.c_str(), ret); return);

        syncRun();
        return;
    }
}


@@ -0,0 +1,24 @@
#pragma once
#include "utils.h"
#include "base_op.h"
namespace jittor
{
    class WhereOpRunner : public BaseOpRunner
    {
    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        WhereOpRunner();
    };

    class NonzeroOpRunner : public BaseOpRunner
    {
    protected:
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        NonzeroOpRunner();
    };
}