mirror of https://github.com/Jittor/Jittor
fix conv, relu, split expand
This commit is contained in:
parent 9e1e764958
commit 981391ea6d
@@ -519,32 +519,7 @@ def change_function():
     def transpose_acl(x, *dim):
         return TransPoseACL()(x, *dim)

-    class ReLUACL(Function):
-
-        def __init__(self):
-            super(ReLUACL, self).__init__()
-
-        def execute(self, x):
-            x = x.float32()
-            self.input = x
-            result = acl_cmd("ReLU", [x],
-                             output_dtypes=[x.dtype],
-                             output_shapes=[x.shape],
-                             attr_code="op.jt_name=\"unary\";")[0]
-            return result
-
-        def grad(self, grad_output):
-            mask = acl_cmd("Greater",
-                           [self.input, jt.zeros(self.input.shape)],
-                           output_dtypes=[self.input.dtype],
-                           output_shapes=[self.input.shape],
-                           attr_code="op.jt_name=\"binary\";")[0]
-            grad_input = acl_cmd("Mul", [grad_output, mask],
-                                 output_dtypes=[grad_output.dtype],
-                                 output_shapes=[grad_output.shape],
-                                 attr_code="op.jt_name=\"binary\";")[0]
-            return grad_input
-
+    from .aclops.relu_op import ReLUACL
     class ReLU(jt.nn.Module):

         def __init__(self):
@@ -274,7 +274,6 @@ namespace jittor
         else if (op->name() == string("reduce"))
         {
             auto rop = (ReduceOp *)op;
-            // AclOpRunner op("");
             ReduceOpRunner op;
             if (rop->ns == ns_add)
                 op.op_idx = 9;
@@ -308,7 +307,7 @@ namespace jittor
         else if (op->name() == string("broadcast_to"))
         {
             auto bop = (BroadcastToOp *)op;
-            AclOpRunner op("Expand");
+            ExpandOpRunner op;
             op.jt_name = "expand";
             NanoVector xshape, xshape_bk = bop->x->shape;
             NanoVector zshape = bop->z->shape;
@@ -191,73 +191,6 @@ namespace jittor

         // LOGir << name << " " << jt_name;
         // LOGir<<op_idx;
-        switch (op_idx)
-        {
-        case 3:
-        {
-            size = aclCreateIntArray(&outputShapes[0][0], outputShapes[0].size());
-            ret = it->second.getWorkspaceSizeFuncExpand(inputTensors[0], size, outputTensors[0], &workspaceSize, &executor);
-            break;
-        }
-        case 4:
-        {
-            ret = it->second.getWorkspaceSizeFuncCast(inputTensors[0], get_dtype(out_[0]->dtype()), outputTensors[0], &workspaceSize, &executor);
-            break;
-        }
-        case 5:
-        {
-            ret = it->second.getWorkspaceSizeFuncUnaryNonzero(inputTensors[0], outputTensors[0], &workspaceSize, &executor);
-            break;
-        }
-        case 6:
-        {
-            ret = it->second.getWorkspaceSizeFuncBinary(inputTensors[0], inputTensors[1], outputTensors[0], &workspaceSize, &executor);
-            break;
-        }
-        case 20:
-        {
-            auto attr = dynamic_cast<ConvAttr *>(op_attr.get());
-            strides = aclCreateIntArray(attr->convStrides.data(), 2);
-            pads = aclCreateIntArray(attr->convPads.data(), 2);
-            outPads = aclCreateIntArray(attr->convOutPads.data(), 2);
-            dilations = aclCreateIntArray(attr->convDilations.data(), 2);
-            aclTensor *bias = nullptr;
-            if (input_num == 3)
-                bias = inputTensors[2];
-
-            ret = it->second.getWorkspaceSizeFuncConv(inputTensors[0], inputTensors[1], bias, strides, pads, dilations, false, outPads, attr->group, outputTensors[0], 0, &workspaceSize, &executor);
-            break;
-        }
-        case 21:
-        {
-            auto attr = dynamic_cast<ConvAttr *>(op_attr.get());
-            strides = aclCreateIntArray(attr->convStrides.data(), 2);
-            pads = aclCreateIntArray(attr->convPads.data(), 2);
-            outPads = aclCreateIntArray(attr->convOutPads.data(), 2);
-            dilations = aclCreateIntArray(attr->convDilations.data(), 2);
-            bool outputMask[3] = {true, true, true};
-            if (input_num == 3)
-            {
-                outputMask[2] = false;
-            }
-            aclBoolArray *outMask = aclCreateBoolArray(outputMask, 3);
-            auto biasSizes = aclCreateIntArray(&outputShapes[2][0], outputShapes[2].size());
-            ret = it->second.getWorkspaceSizeFuncConvBackward(inputTensors[0], inputTensors[1], inputTensors[2], biasSizes, strides, pads, dilations, false, outPads, attr->group, outMask, 0, outputTensors[0], outputTensors[1], outputTensors[2], &workspaceSize, &executor);
-            break;
-        }
-        default:
-        {
-            LOGir << "not supported op: " << name;
-            break;
-        }
-        // for debug
-        if (ret != ACL_SUCCESS)
-        {
-            auto tmp_err_msg = aclGetRecentErrMsg();
-            LOGir << name << ", " << tmp_err_msg;
-        }
-        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnxxxGetWorkspaceSize failed. ERROR: %d\n", name.c_str(), ret); return);
-        }

         // 4. Allocate device memory according to the workspaceSize computed by the first-stage (GetWorkspaceSize) call
         if (workspaceSize > 0)
@@ -5,6 +5,7 @@
 #include <acl/aclops/conv_op_acl.h>
 #include <acl/aclops/ternary_op_acl.h>
 #include <acl/aclops/reduce_op_acl.h>
+#include <acl/aclops/expand_op_acl.h>
 #include <acl/aclops/getitem_op_acl.h>
 #include <acl/aclops/setitem_op_acl.h>
 #include <acl/aclops/matmul_op_acl.h>
@@ -5,7 +5,6 @@ import ctypes
 import glob
 import jittor as jt
 import jittor.compiler as compiler
-from jittor.extern.acl.acl_compiler import acl_cmd_forward
 import math
 import numpy as np

@@ -26,14 +25,13 @@ def _ntuple(n):
 _pair = _ntuple(2)


-def conv_forward(name: str,
+def conv_cmd(name: str,
              inputs: list,
              output_dtypes: list = None,
              output_shapes: list = None,
              attr_code: str = "",
              attr_header: str = "",
-             outputs: list = None,
-             extra_data: dict = {}):
+             outputs: list = None):
     attr_header = "\nnamespace jittor{" + attr_header + "}\n"

     cuda_header = '''
@@ -52,58 +50,20 @@ def conv_forward(name: str,
     for i in range(len(inputs)):
         input_code += f"op.add(in{i}, true);\n"

+    output_code = ''
+    for i in range(len(outputs_)):
+        output_code += f"op.add(out{i}, false);\n"
     return jt.code(outputs=outputs_,
                    inputs=inputs,
                    cuda_header=attr_header + cuda_header,
                    cuda_src=f"""

    // aclop
-    ConvOpRunner op;
+    {name}OpRunner op;
    {input_code}
-    op.add(out0, false);
+    {output_code}
    {attr_code}
-    op.run();""",
-                   data=extra_data)
-
-
-def conv_forward(name: str,
-                 inputs: list,
-                 output_dtypes: list = None,
-                 output_shapes: list = None,
-                 attr_code: str = "",
-                 attr_header: str = "",
-                 outputs: list = None,
-                 extra_data: dict = {}):
-    # TODO: not done for now
-    attr_header = "\nnamespace jittor{" + attr_header + "}\n"
-
-    cuda_header = '''
-    #include "acl/aclops/aclops.h"
-    '''
-    outputs_ = []
-    if outputs is not None:
-        outputs_ = outputs
-    else:
-        assert output_dtypes is not None
-        assert output_shapes is not None
-        assert len(output_dtypes) == len(output_shapes)
-        for i in range(len(output_shapes)):
-            outputs_.append(jt.empty(output_shapes[i], output_dtypes[i]))
-    input_code = ''
-    for i in range(len(inputs)):
-        input_code += f"op.add(in{i}, true);\n"
-
-    return jt.code(outputs=outputs_,
-                   inputs=inputs,
-                   cuda_header=attr_header + cuda_header,
-                   cuda_src=f"""
-    // aclop
-    ConvOpRunner op;
-    {input_code}
-    op.add(out0, false);
-    {attr_code}
-    op.run();""",
-                   data=extra_data)
+    op.run();""")


 class ConvACL(jt.Function):

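Note (not part of the diff): the renamed conv_cmd helper now derives both the runner class and the output wiring from its arguments, so a single helper serves Conv2d and Conv2dBackward. A minimal pure-Python sketch of that f-string templating, using stand-in values rather than anything taken from this commit:

    name = "Conv2dBackward"                      # any {name}OpRunner registered on the C++ side
    outputs_ = [None, None, None]                # stand-ins for grad_x, grad_w, grad_bias
    input_code = "".join(f"op.add(in{i}, true);\n" for i in range(3))
    output_code = "".join(f"op.add(out{i}, false);\n" for i in range(len(outputs_)))
    cuda_src = f"""
    // aclop
    {name}OpRunner op;
    {input_code}{output_code}op.run();"""
    print(cuda_src)                              # declares Conv2dBackwardOpRunner, adds in0..in2 and out0..out2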
@@ -151,7 +111,7 @@ class ConvACL(jt.Function):
         inputs = [x, weight]
         if bias is not None:
             inputs.append(bias)
-        result = conv_forward(
+        result = conv_cmd(
             "Conv2d",
             inputs,
             output_dtypes=[x.dtype],
@@ -189,7 +149,7 @@ class ConvACL(jt.Function):
         attr->convOutPads = {{ 1,1}};
         op.op_attr.reset(attr);
         """
-        results = acl_cmd_forward("Conv2dBackward",
+        results = conv_cmd("Conv2dBackward",
                            inputs,
                            output_dtypes=output_dtypes,
                            output_shapes=output_shapes,
@@ -30,12 +30,12 @@

 namespace jittor
 {
-    ConvOpRunner::ConvOpRunner() : BaseOpRunner("Conv2d")
+    Conv2dOpRunner::Conv2dOpRunner() : BaseOpRunner("Conv2d")
     {
         use_nchw = true;
     }

-    void ConvOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
+    void Conv2dOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
     {
         // for conv
         aclIntArray *strides = nullptr;
@@ -64,7 +64,54 @@ namespace jittor
         }

         ret = aclnnConvolution(workspaceAddr, workspaceSize, executor, aclstream);
-        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnxxx failed. ERROR: %d\n", name.c_str(), ret); return);
+        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnConvolution failed. ERROR: %d\n", name.c_str(), ret); return);

+        syncRun();
+
+        aclDestroyIntArray(strides);
+        aclDestroyIntArray(pads);
+        aclDestroyIntArray(outPads);
+        aclDestroyIntArray(dilations);
+        return;
+    }
+
+    Conv2dBackwardOpRunner::Conv2dBackwardOpRunner() : BaseOpRunner("Conv2dBackward")
+    {
+        use_nchw = true;
+    }
+
+    void Conv2dBackwardOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
+    {
+        // for conv
+        aclIntArray *strides = nullptr;
+        aclIntArray *pads = nullptr;
+        aclIntArray *outPads = nullptr;
+        aclIntArray *dilations = nullptr;
+        auto attr = dynamic_cast<ConvAttr *>(op_attr.get());
+        strides = aclCreateIntArray(attr->convStrides.data(), 2);
+        pads = aclCreateIntArray(attr->convPads.data(), 2);
+        outPads = aclCreateIntArray(attr->convOutPads.data(), 2);
+        dilations = aclCreateIntArray(attr->convDilations.data(), 2);
+        bool outputMask[3] = {true, true, true};
+        auto input_num = in_.size();
+        if (input_num == 3)
+        {
+            outputMask[2] = false;
+        }
+        aclBoolArray *outMask = aclCreateBoolArray(outputMask, 3);
+        auto biasSizes = aclCreateIntArray(&outputShapes[2][0], outputShapes[2].size());
+        ret = aclnnConvolutionBackwardGetWorkspaceSize(inputTensors[0], inputTensors[1], inputTensors[2], biasSizes, strides, pads, dilations, false, outPads, attr->group, outMask, 0, outputTensors[0], outputTensors[1], outputTensors[2], &workspaceSize, &executor);
+
+        checkRet(ret);
+
+        if (workspaceSize > 0)
+        {
+            mallocWorkSpace(workspaceSize);
+        }
+
+        ret = aclnnConvolutionBackward(workspaceAddr, workspaceSize, executor, aclstream);
+        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnConvolutionBackward failed. ERROR: %d\n", name.c_str(), ret); return);
+
         syncRun();

@@ -4,13 +4,23 @@

 namespace jittor
 {
-    class ConvOpRunner : public BaseOpRunner
+    class Conv2dOpRunner : public BaseOpRunner
     {

     protected:
         void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;

     public:
-        ConvOpRunner();
+        Conv2dOpRunner();
+    };
+
+    class Conv2dBackwardOpRunner : public BaseOpRunner
+    {
+
+    protected:
+        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
+
+    public:
+        Conv2dBackwardOpRunner();
     };
 }
@@ -0,0 +1,58 @@
+#include <acl/acl.h>
+#include <acl/acl_op_compiler.h>
+#include <Python.h>
+#include <pystate.h>
+#include <algorithm>
+#include <queue>
+#include <set>
+#include "common.h"
+#include "op.h"
+#include "acl_jittor.h"
+#include "ops/random_op.h"
+#include "ops/reduce_op.h"
+#include "ops/binary_op.h"
+#include "ops/broadcast_to_op.h"
+#include "ops/transpose_op.h"
+#include "ops/array_op.h"
+#include "ops/code_op.h"
+#include "fused_op.h"
+#include "ops/unary_op.h"
+#include "ops/ternary_op.h"
+#include "executor.h"
+#include "misc/cuda_flags.h"
+#include "mem/allocator.h"
+#include "op_compiler.h"
+#include "ops/op_register.h"
+#include "opt/tuner_manager.h"
+#include "utils/str_utils.h"
+#include "aclnn/aclnn.h"
+#include "expand_op_acl.h"
+
+namespace jittor
+{
+    ExpandOpRunner::ExpandOpRunner() : BaseOpRunner("ternary")
+    {
+        use_nchw = false;
+    }
+
+    void ExpandOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
+    {
+        aclIntArray *size = nullptr;
+        size = aclCreateIntArray(&outputShapes[0][0], outputShapes[0].size());
+        ret = aclnnExpandGetWorkspaceSize(inputTensors[0], size, outputTensors[0], &workspaceSize, &executor);
+
+        checkRet(ret);
+
+        if (workspaceSize > 0)
+        {
+            mallocWorkSpace(workspaceSize);
+        }
+
+        ret = aclnnExpand(workspaceAddr, workspaceSize, executor, aclstream);
+        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnExpand failed. ERROR: %d\n", name.c_str(), ret); return);
+
+        aclDestroyIntArray(size);
+
+        return;
+    }
+}
@@ -0,0 +1,14 @@
+#pragma once
+#include "utils.h"
+#include "base_op.h"
+
+namespace jittor
+{
+    struct ExpandOpRunner : public BaseOpRunner
+    {
+        ExpandOpRunner();
+
+    protected:
+        void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
+    };
+}
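Note (not part of the diff): with broadcast_to now dispatched to the dedicated ExpandOpRunner, a broadcast on the ACL backend should land in aclnnExpand. A hedged smoke-test sketch; jt.flags.use_acl is assumed here to be the switch for the Ascend backend:

    import jittor as jt

    jt.flags.use_acl = 1                 # assumption: enables the Ascend/ACL backend
    a = jt.ones((1, 3))
    b = jt.broadcast(a, (4, 3))          # lowered as a broadcast_to op -> ExpandOpRunner
    print(b.shape)                       # expected: [4, 3]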
@@ -61,25 +61,24 @@ class ReLUACL(jt.Function):
     def execute(self, x):
         x = x.float32()
         self.input = x
-        result = relu_cmd("ReLU", [x],
+        result = relu_cmd("Unary", [x],
                           output_dtypes=[x.dtype],
                           output_shapes=[x.shape],
-                          attr_code="op.jt_name=\"unary\";")[0]
+                          attr_code="op.name=\"ReLU\";")[0]
         return result

     def grad(self, grad_output):
-        mask = relu_cmd("Greater",
+        mask = relu_cmd("Binary",
                         [self.input, jt.zeros(self.input.shape)],
                         output_dtypes=[self.input.dtype],
                         output_shapes=[self.input.shape],
-                        attr_code="op.jt_name=\"binary\";")[0]
-        grad_input = relu_cmd("Mul", [grad_output, mask],
+                        attr_code="op.name=\"Greater\";")[0]
+        grad_input = relu_cmd("Binary", [grad_output, mask],
                               output_dtypes=[grad_output.dtype],
                               output_shapes=[grad_output.shape],
-                              attr_code="op.jt_name=\"binary\";")[0]
+                              attr_code="op.name=\"Mul\";")[0]
         return grad_input

-
 class LeakyReLUACL(jt.Function):

     def __init__(self):
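Note (not part of the diff): a minimal usage sketch for the relocated ReLUACL. It assumes an Ascend device with the ACL backend active and that the new module resolves to jittor.extern.acl.aclops.relu_op, matching the import added in the first hunk:

    import jittor as jt
    from jittor.extern.acl.aclops.relu_op import ReLUACL

    x = jt.array([[-1.0, 0.5], [2.0, -3.0]])
    y = ReLUACL()(x)                     # forward: relu_cmd("Unary", ...) with op.name="ReLU"
    gx = jt.grad(y.sum(), x)             # backward: Greater mask, then Mul, both via "Binary"
    print(y.numpy(), gx.numpy())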