mirror of https://github.com/Jittor/Jittor

Commit 8762352c64: split where,scatter,floor
Parent: 064af9d543
@@ -834,7 +834,7 @@ def change_function():
     #                    attr_code=attr_code)[0]
     #     return grad_input

-    from .aclops.gather_op import GatherACL
+    from .aclops.gather_scatter_op import GatherACL

     def gather_acl(input, dim, index):
         return GatherACL()(input, dim, index)
@@ -963,194 +963,201 @@ def change_function():
         inshape = inshape.shape
         return IndexACL()(inshape, dim, dtype)

-    class ScatterACL(Function):
+    # class ScatterACL(Function):

-        def __init__(self):
-            super(ScatterACL, self).__init__()
+    #     def __init__(self):
+    #         super(ScatterACL, self).__init__()

-        def execute(self, input, dim, index, src, reduce='void'):
-            self.dim = dim
-            self.index = index
-            self.reduce = reduce
-            attr_code = f"""
-            op.jt_name = "scatter";
-            ScatterAttr *attr = new ScatterAttr();
-            attr->axis = {dim};
-            attr->reduction = {1 if reduce == 'add' else 2 if reduce == 'mul' else 0};
-            op.op_attr.reset(attr);
-            """
-            result = acl_cmd("Scatter", [input, self.index, src],
-                             output_dtypes=[input.dtype],
-                             output_shapes=[input.shape],
-                             attr_code=attr_code)[0]
-            return result
+    #     def execute(self, input, dim, index, src, reduce='void'):
+    #         self.dim = dim
+    #         self.index = index
+    #         self.reduce = reduce
+    #         attr_code = f"""
+    #         op.jt_name = "scatter";
+    #         ScatterAttr *attr = new ScatterAttr();
+    #         attr->axis = {dim};
+    #         attr->reduction = {1 if reduce == 'add' else 2 if reduce == 'mul' else 0};
+    #         op.op_attr.reset(attr);
+    #         """
+    #         result = acl_cmd("Scatter", [input, self.index, src],
+    #                          output_dtypes=[input.dtype],
+    #                          output_shapes=[input.shape],
+    #                          attr_code=attr_code)[0]
+    #         return result

-        def grad(self, grad_output):
-            attr_code = f"""
-            op.jt_name = "gather";
-            GatherAttr *attr = new GatherAttr();
-            attr->dim = {self.dim};
-            op.op_attr.reset(attr);
-            """
-            grad_input = acl_cmd("Gather", [grad_output, self.index],
-                                 output_dtypes=[grad_output.dtype],
-                                 output_shapes=[self.index.shape],
-                                 attr_code=attr_code)[0]
-            return grad_output, None, None, grad_input
+    #     def grad(self, grad_output):
+    #         attr_code = f"""
+    #         op.jt_name = "gather";
+    #         GatherAttr *attr = new GatherAttr();
+    #         attr->dim = {self.dim};
+    #         op.op_attr.reset(attr);
+    #         """
+    #         grad_input = acl_cmd("Gather", [grad_output, self.index],
+    #                              output_dtypes=[grad_output.dtype],
+    #                              output_shapes=[self.index.shape],
+    #                              attr_code=attr_code)[0]
+    #         return grad_output, None, None, grad_input

+    from .aclops.gather_scatter_op import ScatterACL

     def scatter_acl(input, dim, index, src, reduce='void'):
         return ScatterACL()(input, dim, index, src, reduce)
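For context, a minimal usage sketch of the scatter path (illustrative only: it assumes a Jittor build with the ACL backend active, where `jt.scatter` is rewritten by `change_function` to call `scatter_acl`/`ScatterACL`; the results follow standard scatter semantics):

    import jittor as jt

    x = jt.zeros((2, 4), "float32")
    src = jt.array([[1., 2., 3., 4.], [5., 6., 7., 8.]])
    index = jt.array([[3, 0, 2, 1], [2, 3, 0, 1]])

    # reduction codes above: 0 = overwrite ('void'), 1 = 'add', 2 = 'mul'
    y = jt.scatter(x, 1, index, src)                 # x[i][index[i][j]] = src[i][j]
    z = jt.scatter(x, 1, index, src, reduce='add')   # accumulate instead of overwrite

The backward pass mirrors the forward: `ScatterACL.grad` gathers `grad_output` along the saved `dim`/`index` to produce the gradient for `src`.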
-    class WhereACL(Function):
+    # class WhereACL(Function):

-        def __init__(self):
-            super(WhereACL, self).__init__()
+    #     def __init__(self):
+    #         super(WhereACL, self).__init__()

-        def execute(self, condition, x=None, y=None):
-            # case 1 (unary)
-            if y is None:
-                self.unary = True
+    #     def execute(self, condition, x=None, y=None):
+    #         # case 1 (unary)
+    #         if y is None:
+    #             self.unary = True

-                # In this case, `condition` is the input, while `x` is dtype
-                result = nonzero_acl(condition).t()
-                result = [result[i] for i in range(result.size(0))]
-                return result
-                # The return value should be a tuple, but even if we set it to a tuple here, it will be converted to a list in `Function.__call__`.
+    #             # In this case, `condition` is the input, while `x` is dtype
+    #             result = nonzero_acl(condition).t()
+    #             result = [result[i] for i in range(result.size(0))]
+    #             return result
+    #             # The return value should be a tuple, but even if we set it to a tuple here, it will be converted to a list in `Function.__call__`.

-            # case 2 (cond ? x : y)
-            else:
-                self.condition = condition
+    #         # case 2 (cond ? x : y)
+    #         else:
+    #             self.condition = condition

-                if x.dtype != y.dtype:
-                    if x.dtype == jt.float32:
-                        y = y.float32()
-                    elif y.dtype == jt.float32:
-                        x = x.float32()
-                    else:
-                        x = x.to(y.dtype)
+    #             if x.dtype != y.dtype:
+    #                 if x.dtype == jt.float32:
+    #                     y = y.float32()
+    #                 elif y.dtype == jt.float32:
+    #                     x = x.float32()
+    #                 else:
+    #                     x = x.to(y.dtype)

-                self.x = x
-                self.y = y
+    #             self.x = x
+    #             self.y = y

-                result = acl_cmd("Where", [condition, x, y],
-                                 output_dtypes=[x.dtype],
-                                 output_shapes=[x.shape],
-                                 attr_code="op.jt_name=\"where\";")[0]
-                return result
+    #             result = acl_cmd("Where", [condition, x, y],
+    #                              output_dtypes=[x.dtype],
+    #                              output_shapes=[x.shape],
+    #                              attr_code="op.jt_name=\"where\";")[0]
+    #             return result

-        def grad(self, grad_output):
-            if hasattr(self, 'unary') and self.unary:
-                return grad_output
-            else:
-                tmp = jt.zeros(grad_output.shape, dtype=grad_output.dtype)
-                grad_x = acl_cmd("Where", [self.condition, grad_output, tmp],
-                                 output_dtypes=[self.x.dtype],
-                                 output_shapes=[self.x.shape],
-                                 attr_code="op.jt_name=\"where\";")[0]
+    #     def grad(self, grad_output):
+    #         if hasattr(self, 'unary') and self.unary:
+    #             return grad_output
+    #         else:
+    #             tmp = jt.zeros(grad_output.shape, dtype=grad_output.dtype)
+    #             grad_x = acl_cmd("Where", [self.condition, grad_output, tmp],
+    #                              output_dtypes=[self.x.dtype],
+    #                              output_shapes=[self.x.shape],
+    #                              attr_code="op.jt_name=\"where\";")[0]

-                grad_y = acl_cmd("Where", [self.condition, tmp, grad_output],
-                                 output_dtypes=[self.y.dtype],
-                                 output_shapes=[self.y.shape],
-                                 attr_code="op.jt_name=\"where\";")[0]
-                return grad_output, grad_x, grad_y
+    #             grad_y = acl_cmd("Where", [self.condition, tmp, grad_output],
+    #                              output_dtypes=[self.y.dtype],
+    #                              output_shapes=[self.y.shape],
+    #                              attr_code="op.jt_name=\"where\";")[0]
+    #             return grad_output, grad_x, grad_y

+    from .aclops.where_op import WhereACL

     def where_acl(condition, x=None, y=None):
         return WhereACL()(condition, x, y)
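A sketch of the two call forms `WhereACL.execute` distinguishes (illustrative, assuming the ACL backend is active so `jt.where` dispatches to this class):

    import jittor as jt

    cond = jt.array([[1, 0], [0, 1]]) > 0

    # case 1 (unary): only the condition is given; the transposed Nonzero
    # result is split into one index array per dimension
    idx = jt.where(cond)

    # case 2 (ternary): cond ? x : y, after promoting x and y to a common dtype
    out = jt.where(cond, jt.ones((2, 2)), jt.zeros((2, 2)))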
-    class NonzeroACL(Function):
+    # class NonzeroACL(Function):

-        def __init__(self):
-            super(NonzeroACL, self).__init__()
+    #     def __init__(self):
+    #         super(NonzeroACL, self).__init__()

-        def execute(self, x):
-            attr_code = f"""
-            op.jt_name = "nonzero";
-            """
-            nonzero_cnt = (x != 0.0).sum().item()
+    #     def execute(self, x):
+    #         attr_code = f"""
+    #         op.jt_name = "nonzero";
+    #         """
+    #         nonzero_cnt = (x != 0.0).sum().item()

-            result = acl_cmd("Nonzero", [x],
-                             output_dtypes=['int64'],
-                             output_shapes=[(nonzero_cnt, x.ndim)],
-                             attr_code=attr_code)[0]
+    #         result = acl_cmd("Nonzero", [x],
+    #                          output_dtypes=['int64'],
+    #                          output_shapes=[(nonzero_cnt, x.ndim)],
+    #                          attr_code=attr_code)[0]

-            return result
+    #         return result

-        def grad(self, grad_output):
-            return grad_output
+    #     def grad(self, grad_output):
+    #         return grad_output

+    from .aclops.where_op import NonzeroACL

     def nonzero_acl(x):
         return NonzeroACL()(x)
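Note the host-side synchronization in `execute`: Nonzero has a data-dependent output shape, so `nonzero_cnt` is computed up front with `(x != 0.0).sum().item()` before the kernel is launched, giving an `(nonzero_cnt, x.ndim)` int64 coordinate matrix. A sketch of the expected contract (illustrative):

    import jittor as jt

    x = jt.array([[0., 3.], [5., 0.]])
    coords = jt.nonzero(x)   # shape (2, 2); rows are coordinates like [0, 1] and [1, 0]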
-    class FloorIntACL(Function):
+    # class FloorIntACL(Function):

-        def __init__(self):
-            super(FloorIntACL, self).__init__()
+    #     def __init__(self):
+    #         super(FloorIntACL, self).__init__()

-        def execute(self, input):
-            self.shape = input.shape
-            result = acl_cmd("Floor", [input],
-                             output_dtypes=[input.dtype],
-                             output_shapes=[input.shape],
-                             attr_code="op.jt_name=\"floor\";")[0]
-            return result
+    #     def execute(self, input):
+    #         self.shape = input.shape
+    #         result = acl_cmd("Floor", [input],
+    #                          output_dtypes=[input.dtype],
+    #                          output_shapes=[input.shape],
+    #                          attr_code="op.jt_name=\"floor\";")[0]
+    #         return result

-        def grad(self, grad_output):
-            return jt.zeros(self.shape, dtype=grad_output.dtype)
+    #     def grad(self, grad_output):
+    #         return jt.zeros(self.shape, dtype=grad_output.dtype)

+    from .aclops.floor_op import FloorIntACL

     def floor_int_acl(x):
         return FloorIntACL()(x)
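Since floor is piecewise constant, `FloorIntACL.grad` returns an all-zero tensor of the saved input shape. A sketch (illustrative, assuming `jt.floor` is routed through this class):

    import jittor as jt

    x = jt.array([1.7, -0.3, 2.0])
    y = jt.floor(x)              # [1., -1., 2.], same dtype as the input
    g = jt.grad(y.sum(), x)      # zeros: d(floor)/dx == 0 almost everywhere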
-    def caculate_shape(tensors):
-        if isinstance(tensors, jt.Var):
-            # tensors = tensors[0]
-            return tensors.shape
-        elif isinstance(tensors, (int, float)):
-            return []
-        elif isinstance(tensors, (list, tuple)):
-            # return [caculate_shape(tensor) for tensor in tensors]
-            sub_shape = caculate_shape(tensors[0])
-            return [len(tensors)] + sub_shape
-        else:
-            assert False, f"not implemented for {type(tensors)}"
+    # def caculate_shape(tensors):
+    #     if isinstance(tensors, jt.Var):
+    #         # tensors = tensors[0]
+    #         return tensors.shape
+    #     elif isinstance(tensors, (int, float)):
+    #         return []
+    #     elif isinstance(tensors, (list, tuple)):
+    #         # return [caculate_shape(tensor) for tensor in tensors]
+    #         sub_shape = caculate_shape(tensors[0])
+    #         return [len(tensors)] + sub_shape
+    #     else:
+    #         assert False, f"not implemented for {type(tensors)}"
-    def can_broadcast_and_shape(shape1, shape2):
-        """
-        Check whether two tensors can be broadcast together, and return the broadcast shape.
-
-        Args:
-        - shape1: shape of the first tensor (tuple or list)
-        - shape2: shape of the second tensor (tuple or list)
-
-        Returns:
-        - can_broadcast: bool indicating whether broadcasting is possible
-        - broadcast_shape: the broadcast shape if possible, otherwise None
-        """
-        # Convert the shapes to tuples in case the inputs are lists
-        shape1 = tuple(shape1)
-        shape2 = tuple(shape2)
-
-        # Bring both shapes to the same length by left-padding with 1s
-        len1, len2 = len(shape1), len(shape2)
-        if len1 < len2:
-            shape1 = (1, ) * (len2 - len1) + shape1
-        elif len2 < len1:
-            shape2 = (1, ) * (len1 - len2) + shape2
-
-        broadcast_shape = []
-
-        # Check each dimension, starting from the last one
-        for dim1, dim2 in zip(shape1, shape2):
-            if dim1 == dim2:
-                broadcast_shape.append(dim1)
-            elif dim1 == 1:
-                broadcast_shape.append(dim2)
-            elif dim2 == 1:
-                broadcast_shape.append(dim1)
-            else:
-                # Incompatible in this dimension: broadcasting is impossible
-                return False, None
-
-        return True, tuple(broadcast_shape)
+    # def can_broadcast_and_shape(shape1, shape2):
+    #     """
+    #     Check whether two tensors can be broadcast together, and return the broadcast shape.
+    #
+    #     Args:
+    #     - shape1: shape of the first tensor (tuple or list)
+    #     - shape2: shape of the second tensor (tuple or list)
+    #
+    #     Returns:
+    #     - can_broadcast: bool indicating whether broadcasting is possible
+    #     - broadcast_shape: the broadcast shape if possible, otherwise None
+    #     """
+    #     # Convert the shapes to tuples in case the inputs are lists
+    #     shape1 = tuple(shape1)
+    #     shape2 = tuple(shape2)
+    #
+    #     # Bring both shapes to the same length by left-padding with 1s
+    #     len1, len2 = len(shape1), len(shape2)
+    #     if len1 < len2:
+    #         shape1 = (1, ) * (len2 - len1) + shape1
+    #     elif len2 < len1:
+    #         shape2 = (1, ) * (len1 - len2) + shape2
+    #
+    #     broadcast_shape = []
+    #
+    #     # Check each dimension, starting from the last one
+    #     for dim1, dim2 in zip(shape1, shape2):
+    #         if dim1 == dim2:
+    #             broadcast_shape.append(dim1)
+    #         elif dim1 == 1:
+    #             broadcast_shape.append(dim2)
+    #         elif dim2 == 1:
+    #             broadcast_shape.append(dim1)
+    #         else:
+    #             # Incompatible in this dimension: broadcasting is impossible
+    #             return False, None
+    #
+    #     return True, tuple(broadcast_shape)

     # class GetItemACL(Function):
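A usage sketch for `can_broadcast_and_shape` (standard NumPy-style broadcasting rules; the results follow directly from the code above):

    ok, shape = can_broadcast_and_shape((8, 1, 6), (7, 6))
    # ok == True, shape == (8, 7, 6): (7, 6) is left-padded to (1, 7, 6) first

    ok, shape = can_broadcast_and_shape((3, 4), (2, 4))
    # ok == False, shape is None: 3 vs 2 is incompatible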
@@ -257,16 +257,16 @@ namespace jittor
            //     ret = it->second.getWorkspaceSizeFuncRandom(outputTensors[0], 0.0, 1.0, attr->seed, attr->offset, &workspaceSize, &executor);
            //     break;
            // }
-           case 15:
-           {
-               ret = it->second.getWorkspaceSizeFuncUnaryNonzero(inputTensors[0], outputTensors[0], &workspaceSize, &executor);
-               break;
-           }
-           case 17:
-           {
-               ret = it->second.getWorkspaceSizeFuncSelect(inputTensors[0], inputTensors[1], inputTensors[2], outputTensors[0], &workspaceSize, &executor);
-               break;
-           }
+           // case 15:
+           // {
+           //     ret = it->second.getWorkspaceSizeFuncUnaryNonzero(inputTensors[0], outputTensors[0], &workspaceSize, &executor);
+           //     break;
+           // }
+           // case 17:
+           // {
+           //     ret = it->second.getWorkspaceSizeFuncSelect(inputTensors[0], inputTensors[1], inputTensors[2], outputTensors[0], &workspaceSize, &executor);
+           //     break;
+           // }
            case 18:
            {
                auto attr = dynamic_cast<TriuAttr *>(op_attr.get());
@@ -384,11 +384,11 @@ namespace jittor
            //     ret = it->second.getWorkspaceSizeFuncScatter(inputTensors[0], attr->axis, inputTensors[1], inputTensors[2], attr->reduction, outputTensors[0], &workspaceSize, &executor);
            //     break;
            // }
-           case 31:
-           {
-               ret = it->second.getWorkspaceSizeFuncUnaryNonzero(inputTensors[0], outputTensors[0], &workspaceSize, &executor);
-               break;
-           }
+           // case 31:
+           // {
+           //     ret = it->second.getWorkspaceSizeFuncUnaryNonzero(inputTensors[0], outputTensors[0], &workspaceSize, &executor);
+           //     break;
+           // }
            // case 32:
            // {
            //     auto indexTensorList = aclCreateTensorList(&inputTensors[1], input_num - 1);
@@ -13,6 +13,8 @@
 #include <acl/aclops/pool_op_acl.h>
 #include <acl/aclops/flip_op_acl.h>
 #include <acl/aclops/concat_op_acl.h>
-#include <acl/aclops/gather_op_acl.h>
+#include <acl/aclops/gather_scatter_op_acl.h>
 #include <acl/aclops/cumsum_op_acl.h>
 #include <acl/aclops/index_op_acl.h>
+#include <acl/aclops/where_op_acl.h>
+#include <acl/aclops/floor_op_acl.h>
@@ -0,0 +1,68 @@
+import os
+from jittor_utils import env_or_try_find
+import jittor_utils
+import ctypes
+import glob
+import jittor.compiler as compiler
+import jittor as jt
+import math
+import numpy as np
+
+from typing import Union
+from collections.abc import Sequence, Iterable
+
+def floor_cmd(name: str,
+              inputs: list,
+              output_dtypes: list = None,
+              output_shapes: list = None,
+              attr_code: str = "",
+              attr_header: str = "",
+              outputs: list = None):
+    attr_header = "\nnamespace jittor{" + attr_header + "}\n"
+
+    cuda_header = '''
+    #include "acl/aclops/aclops.h"
+    '''
+    outputs_ = []
+    if outputs is not None:
+        outputs_ = outputs
+    else:
+        assert output_dtypes is not None
+        assert output_shapes is not None
+        assert len(output_dtypes) == len(output_shapes)
+        for i in range(len(output_shapes)):
+            outputs_.append(jt.empty(output_shapes[i], output_dtypes[i]))
+    input_code = ''
+    for i in range(len(inputs)):
+        input_code += f"op.add(in{i}, true);\n"
+
+    output_code = ''
+    for i in range(len(outputs_)):
+        output_code += f"op.add(out{i}, false);\n"
+    return jt.code(outputs=outputs_,
+                   inputs=inputs,
+                   cuda_header=attr_header + cuda_header,
+                   cuda_src=f"""
+   // aclop
+   {name}OpRunner op;
+   {input_code}
+   {output_code}
+   {attr_code}
+   op.run();""")
+
+class FloorIntACL(jt.Function):
+
+    def __init__(self):
+        super(FloorIntACL, self).__init__()
+
+    def execute(self, input):
+        self.shape = input.shape
+        result = floor_cmd("Floor", [input],
+                           output_dtypes=[input.dtype],
+                           output_shapes=[input.shape],
+                           attr_code="op.jt_name=\"floor\";")[0]
+        return result
+
+    def grad(self, grad_output):
+        return jt.zeros(self.shape, dtype=grad_output.dtype)
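`floor_cmd` follows the shared aclops helper pattern: allocate (or accept) the output Vars, emit one `op.add(...)` line per input and output, splice the attribute code into the generated `cuda_src`, and let `jt.code` compile a `FloorOpRunner` invocation. A direct call, mirroring what `FloorIntACL.execute` does (illustrative):

    x = jt.rand(4)
    out = floor_cmd("Floor", [x],
                    output_dtypes=[x.dtype],
                    output_shapes=[x.shape],
                    attr_code='op.jt_name="floor";')[0]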
@@ -0,0 +1,56 @@
+#pragma once
+#include <acl/acl.h>
+#include <acl/acl_op_compiler.h>
+#include <Python.h>
+#include <pystate.h>
+#include <algorithm>
+#include <queue>
+#include <set>
+#include "common.h"
+#include "op.h"
+#include "acl_jittor.h"
+#include "ops/random_op.h"
+#include "ops/reduce_op.h"
+#include "ops/binary_op.h"
+#include "ops/broadcast_to_op.h"
+#include "ops/transpose_op.h"
+#include "ops/array_op.h"
+#include "ops/code_op.h"
+#include "fused_op.h"
+#include "ops/unary_op.h"
+#include "ops/ternary_op.h"
+#include "executor.h"
+#include "misc/cuda_flags.h"
+#include "mem/allocator.h"
+#include "op_compiler.h"
+#include "ops/op_register.h"
+#include "opt/tuner_manager.h"
+#include "utils/str_utils.h"
+#include "aclnn/aclnn.h"
+#include "floor_op_acl.h"
+
+namespace jittor
+{
+    FloorOpRunner::FloorOpRunner() : BaseOpRunner("Floor")
+    {
+    }
+
+    void FloorOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
+    {
+        ret = aclnnFloorGetWorkspaceSize(inputTensors[0], outputTensors[0], &workspaceSize, &executor);
+
+        checkRet(ret);
+
+        if (workspaceSize > 0)
+        {
+            mallocWorkSpace(workspaceSize);
+        }
+
+        ret = aclnnFloor(workspaceAddr, workspaceSize, executor, aclstream);
+        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnFloor failed. ERROR: %d\n", name.c_str(), ret); return);
+
+        syncRun();
+        return;
+    }
+
+}
@@ -0,0 +1,15 @@
+#pragma once
+#include "utils.h"
+#include "base_op.h"
+
+namespace jittor
+{
+    class FloorOpRunner : public BaseOpRunner
+    {
+
+    protected:
+        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
+    public:
+        FloorOpRunner();
+    };
+}
@@ -11,7 +11,7 @@ import numpy as np
 from typing import Union
 from collections.abc import Sequence, Iterable

-def gather_cmd(name: str,
+def gather_scatter_cmd(name: str,
               inputs: list,
               output_dtypes: list = None,
               output_shapes: list = None,
@@ -65,7 +65,7 @@ class GatherACL(jt.Function):
         attr->dim = {dim};
         op.op_attr.reset(attr);
         """
-        result = gather_cmd("Gather", [input, index],
+        result = gather_scatter_cmd("Gather", [input, index],
                             output_dtypes=[input.dtype],
                             output_shapes=[index.shape],
                             attr_code=attr_code)[0]
@@ -80,8 +80,43 @@ class GatherACL(jt.Function):
         attr->reduction = {1};
         op.op_attr.reset(attr);
         """
-        grad_input = gather_cmd("Scatter", [tmp, self.index, grad_output],
+        grad_input = gather_scatter_cmd("Scatter", [tmp, self.index, grad_output],
                                 output_dtypes=[grad_output.dtype],
                                 output_shapes=[tmp.shape],
                                 attr_code=attr_code)[0]
         return grad_input

+class ScatterACL(jt.Function):
+
+    def __init__(self):
+        super(ScatterACL, self).__init__()
+
+    def execute(self, input, dim, index, src, reduce='void'):
+        self.dim = dim
+        self.index = index
+        self.reduce = reduce
+        attr_code = f"""
+        op.jt_name = "scatter";
+        ScatterAttr *attr = new ScatterAttr();
+        attr->axis = {dim};
+        attr->reduction = {1 if reduce == 'add' else 2 if reduce == 'mul' else 0};
+        op.op_attr.reset(attr);
+        """
+        result = gather_scatter_cmd("Scatter", [input, self.index, src],
+                                    output_dtypes=[input.dtype],
+                                    output_shapes=[input.shape],
+                                    attr_code=attr_code)[0]
+        return result
+
+    def grad(self, grad_output):
+        attr_code = f"""
+        op.jt_name = "gather";
+        GatherAttr *attr = new GatherAttr();
+        attr->dim = {self.dim};
+        op.op_attr.reset(attr);
+        """
+        grad_input = gather_scatter_cmd("Gather", [grad_output, self.index],
+                                        output_dtypes=[grad_output.dtype],
+                                        output_shapes=[self.index.shape],
+                                        attr_code=attr_code)[0]
+        return grad_output, None, None, grad_input
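`GatherACL` and `ScatterACL` are deliberately adjoint: gather's grad scatters with `reduction = 1` (add), and scatter's grad gathers along the saved `dim`/`index`. A round-trip sketch (illustrative, assuming `jt.gather`/`jt.scatter` dispatch to these classes under the ACL backend):

    import jittor as jt

    x = jt.array([[1., 2.], [3., 4.]])
    index = jt.array([[0, 0], [1, 0]])

    g = jt.gather(x, 1, index)    # g[i][j] = x[i][index[i][j]]
    s = jt.scatter(jt.zeros_like(x), 1, index, g, reduce='add')  # gather's grad pattern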
@@ -27,7 +27,7 @@
 #include "opt/tuner_manager.h"
 #include "utils/str_utils.h"
 #include "aclnn/aclnn.h"
-#include "gather_op_acl.h"
+#include "gather_scatter_op_acl.h"

 namespace jittor
 {
@@ -8,8 +8,9 @@ namespace jittor
 {

    protected:
-       string name;
+       string name; // specific to the random op
        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

    public:
        RandomOpRunner();
+       RandomOpRunner(const string &name);
@@ -0,0 +1,126 @@
+import os
+from jittor_utils import env_or_try_find
+import jittor_utils
+import ctypes
+import glob
+import jittor.compiler as compiler
+import jittor as jt
+import math
+import numpy as np
+
+from typing import Union
+from collections.abc import Sequence, Iterable
+
+def where_cmd(name: str,
+              inputs: list,
+              output_dtypes: list = None,
+              output_shapes: list = None,
+              attr_code: str = "",
+              attr_header: str = "",
+              outputs: list = None):
+    attr_header = "\nnamespace jittor{" + attr_header + "}\n"
+
+    cuda_header = '''
+    #include "acl/aclops/aclops.h"
+    '''
+    outputs_ = []
+    if outputs is not None:
+        outputs_ = outputs
+    else:
+        assert output_dtypes is not None
+        assert output_shapes is not None
+        assert len(output_dtypes) == len(output_shapes)
+        for i in range(len(output_shapes)):
+            outputs_.append(jt.empty(output_shapes[i], output_dtypes[i]))
+    input_code = ''
+    for i in range(len(inputs)):
+        input_code += f"op.add(in{i}, true);\n"
+
+    output_code = ''
+    for i in range(len(outputs_)):
+        output_code += f"op.add(out{i}, false);\n"
+    return jt.code(outputs=outputs_,
+                   inputs=inputs,
+                   cuda_header=attr_header + cuda_header,
+                   cuda_src=f"""
+   // aclop
+   {name}OpRunner op;
+   {input_code}
+   {output_code}
+   {attr_code}
+   op.run();""")
+
+class NonzeroACL(jt.Function):
+
+    def __init__(self):
+        super(NonzeroACL, self).__init__()
+
+    def execute(self, x):
+        attr_code = f"""
+        op.jt_name = "nonzero";
+        """
+        nonzero_cnt = (x != 0.0).sum().item()
+
+        result = where_cmd("Nonzero", [x],
+                           output_dtypes=['int64'],
+                           output_shapes=[(nonzero_cnt, x.ndim)],
+                           attr_code=attr_code)[0]
+
+        return result
+
+    def grad(self, grad_output):
+        return grad_output
+
+class WhereACL(jt.Function):
+
+    def __init__(self):
+        super(WhereACL, self).__init__()
+
+    def execute(self, condition, x=None, y=None):
+        # case 1 (unary)
+        if y is None:
+            self.unary = True
+
+            # In this case, `condition` is the input, while `x` is dtype
+            result = NonzeroACL()(condition).t()
+            result = [result[i] for i in range(result.size(0))]
+            return result
+            # The return value should be a tuple, but even if we set it to a tuple here, it will be converted to a list in `Function.__call__`.
+
+        # case 2 (cond ? x : y)
+        else:
+            self.condition = condition
+
+            if x.dtype != y.dtype:
+                if x.dtype == jt.float32:
+                    y = y.float32()
+                elif y.dtype == jt.float32:
+                    x = x.float32()
+                else:
+                    x = x.to(y.dtype)
+
+            self.x = x
+            self.y = y
+
+            result = where_cmd("Where", [condition, x, y],
+                               output_dtypes=[x.dtype],
+                               output_shapes=[x.shape],
+                               attr_code="op.jt_name=\"where\";")[0]
+            return result
+
+    def grad(self, grad_output):
+        if hasattr(self, 'unary') and self.unary:
+            return grad_output
+        else:
+            tmp = jt.zeros(grad_output.shape, dtype=grad_output.dtype)
+            grad_x = where_cmd("Where", [self.condition, grad_output, tmp],
+                               output_dtypes=[self.x.dtype],
+                               output_shapes=[self.x.shape],
+                               attr_code="op.jt_name=\"where\";")[0]
+
+            grad_y = where_cmd("Where", [self.condition, tmp, grad_output],
+                               output_dtypes=[self.y.dtype],
+                               output_shapes=[self.y.shape],
+                               attr_code="op.jt_name=\"where\";")[0]
+            return grad_output, grad_x, grad_y
@@ -0,0 +1,79 @@
+#pragma once
+#include <acl/acl.h>
+#include <acl/acl_op_compiler.h>
+#include <Python.h>
+#include <pystate.h>
+#include <algorithm>
+#include <queue>
+#include <set>
+#include "common.h"
+#include "op.h"
+#include "acl_jittor.h"
+#include "ops/random_op.h"
+#include "ops/reduce_op.h"
+#include "ops/binary_op.h"
+#include "ops/broadcast_to_op.h"
+#include "ops/transpose_op.h"
+#include "ops/array_op.h"
+#include "ops/code_op.h"
+#include "fused_op.h"
+#include "ops/unary_op.h"
+#include "ops/ternary_op.h"
+#include "executor.h"
+#include "misc/cuda_flags.h"
+#include "mem/allocator.h"
+#include "op_compiler.h"
+#include "ops/op_register.h"
+#include "opt/tuner_manager.h"
+#include "utils/str_utils.h"
+#include "aclnn/aclnn.h"
+#include "where_op_acl.h"
+
+namespace jittor
+{
+    WhereOpRunner::WhereOpRunner() : BaseOpRunner("Where")
+    {
+    }
+
+    void WhereOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
+    {
+        ret = aclnnSWhereGetWorkspaceSize(inputTensors[0], inputTensors[1], inputTensors[2], outputTensors[0], &workspaceSize, &executor);
+
+        checkRet(ret);
+
+        if (workspaceSize > 0)
+        {
+            mallocWorkSpace(workspaceSize);
+        }
+
+        ret = aclnnSWhere(workspaceAddr, workspaceSize, executor, aclstream);
+        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnSWhere failed. ERROR: %d\n", name.c_str(), ret); return);
+
+        syncRun();
+        return;
+    }
+
+    NonzeroOpRunner::NonzeroOpRunner() : BaseOpRunner("Nonzero")
+    {
+    }
+
+    void NonzeroOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
+    {
+        ret = aclnnNonzeroGetWorkspaceSize(inputTensors[0], outputTensors[0], &workspaceSize, &executor);
+
+        checkRet(ret);
+
+        if (workspaceSize > 0)
+        {
+            mallocWorkSpace(workspaceSize);
+        }
+
+        ret = aclnnNonzero(workspaceAddr, workspaceSize, executor, aclstream);
+        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnNonzero failed. ERROR: %d\n", name.c_str(), ret); return);
+
+        syncRun();
+        return;
+    }
+
+}
@@ -0,0 +1,24 @@
+#pragma once
+#include "utils.h"
+#include "base_op.h"
+
+namespace jittor
+{
+    class WhereOpRunner : public BaseOpRunner
+    {
+
+    protected:
+        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
+    public:
+        WhereOpRunner();
+    };
+
+    class NonzeroOpRunner : public BaseOpRunner
+    {
+
+    protected:
+        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
+    public:
+        NonzeroOpRunner();
+    };
+}