Dun Liang 2021-10-14 16:36:04 +08:00
parent 7cfd216372
commit eee405669f
5 changed files with 38 additions and 12 deletions

View File

@@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.3.1.0'
__version__ = '1.3.1.1'
from jittor_utils import lock
with lock.lock_scope():
    ori_int = int

View File

@@ -856,14 +856,19 @@ string OpCompiler::__get_fused_src(
string arg_name = op_name + "_output";
string argp_name = op_name + "_outputp";
string T = ((ArrayOp*)ops[oi])->output->dtype().to_cstring();
fused_kernel_args += " ArrayOp* " + op_name + " = (ArrayOp*)(ops[" + S(oi) + "]);\n";
// op_name = "((ArrayOp*)(ops[" + S(oi) + "]))";
fused_kernel_args += " Var* " + arg_name + " = " + op_name + "->output;\n";
fused_kernel += " auto* " + argp_name + " = " + arg_name + "->ptr<" + T + ">();\n";
fused_kernel += " " + argp_name + "[0] = " + op_name + "->ptr<" + T + ">()[0];\n";
fused_kernel += " int " + arg_name + "shape0 = 1;\n";
fused_kernel += " int " + arg_name + "stride0 = 1;\n";
fused_kernel_args += precompile({{"oi",S(oi)}, {"T", T}}, R"(
Var* op@oi@@_output = ((ArrayOp*)(ops[@oi]))->output;
@T op@oi@@_outputv = ((ArrayOp*)(ops[@oi]))->ptr<@T>()[0];
)");
fused_kernel += precompile({{"oi",S(oi)}, {"T", T}}, R"(
@T* op@oi@@_outputp = op@oi@@_output->ptr<@T>();
op@oi@@_outputp[0] = op@oi@@_outputv;
)");
fused_includes += "#include \"ops/array_op.h\"\n";
op_members[oi].push_back(arg_name);
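
For context: precompile is Jittor's template expander, and as far as one can infer from the templates above, @oi and @T are substituted from the map while @@ merely terminates a key so an identifier can follow it directly. The toy expander below is not Jittor's real precompile, only a sketch of that idea, showing what the generated declarations look like for a hypothetical oi = 2 and a float32 array op:

    #include <cctype>
    #include <iostream>
    #include <map>
    #include <string>
    using std::string;

    // Toy illustration of the @key substitution idea (NOT Jittor's real precompile).
    static string expand(string src, const std::map<string, string>& defs) {
        for (size_t i = 0; i < src.size(); ) {
            if (src[i] != '@') { i++; continue; }
            if (i + 1 < src.size() && src[i + 1] == '@') { src.erase(i, 2); continue; }  // "@@" only separates
            size_t j = i + 1;
            while (j < src.size() && (std::isalnum((unsigned char)src[j]) || src[j] == '_')) j++;
            auto it = defs.find(src.substr(i + 1, j - i - 1));
            string rep = (it == defs.end()) ? "" : it->second;
            src.replace(i, j - i, rep);
            i += rep.size();
        }
        return src;
    }

    int main() {
        std::map<string, string> defs = {{"oi", "2"}, {"T", "float32"}};
        // the same template lines as in the diff above
        std::cout << expand("Var* op@oi@@_output = ((ArrayOp*)(ops[@oi]))->output;\n", defs)
                  << expand("@T op@oi@@_outputv = ((ArrayOp*)(ops[@oi]))->ptr<@T>()[0];\n", defs)
                  << expand("@T* op@oi@@_outputp = op@oi@@_output->ptr<@T>();\n", defs)
                  << expand("op@oi@@_outputp[0] = op@oi@@_outputv;\n", defs);
        // prints, e.g.:
        //   Var* op2_output = ((ArrayOp*)(ops[2]))->output;
        //   float32 op2_outputv = ((ArrayOp*)(ops[2]))->ptr<float32>()[0];
        //   float32* op2_outputp = op2_output->ptr<float32>();
        //   op2_outputp[0] = op2_outputv;
    }

Note how the scalar read moves from the kernel body (fused_kernel) into the setup code (fused_kernel_args): the value is captured once into op@oi_outputv instead of the kernel dereferencing the ArrayOp, which is presumably what the new CUDA test below exercises.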

View File

@@ -56,6 +56,11 @@ void LoopToFuncPass::run() {
if (d->has_attr("rvalue")) {
auto& rvalue = d->attrs["rvalue"];
auto& dtype = d->attrs["dtype"];
if (endswith(d->attrs["lvalue"], "_value") ||
endswith(d->attrs["lvalue"], "_outputv")) {
args.push_back(d.get());
continue;
}
if (rvalue.find("ops") != string::npos)
continue;
if (dtype=="Var*")
@@ -67,10 +72,6 @@ void LoopToFuncPass::run() {
args.push_back(d.get());
continue;
}
if (endswith(d->attrs["lvalue"], "_value")) {
args.push_back(d.get());
continue;
}
}
}
func->push_back(d->clone());
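
This reordering goes with the op_compiler change above: the _outputv scalars generated for array ops mention ops in their rvalue, so under the old order the rvalue.find("ops") filter dropped them before they could be promoted to arguments of the extracted kernel function. A minimal, hypothetical sketch of just these two checks (illustrative names, not the full pass, which applies further dtype tests afterwards):

    #include <string>
    using std::string;

    static bool endswith(const string& s, const string& suffix) {
        return s.size() >= suffix.size()
            && s.compare(s.size() - suffix.size(), suffix.size(), suffix) == 0;
    }

    static bool captured_as_arg(const string& lvalue, const string& rvalue) {
        // new order: the scalars generated for array ops are captured first
        if (endswith(lvalue, "_value") || endswith(lvalue, "_outputv"))
            return true;
        // before this commit the ops filter ran first and dropped them here
        if (rvalue.find("ops") != string::npos)
            return false;
        return false;  // stand-in for the remaining checks in the real pass
    }

    int main() {
        // mirrors the declaration generated in the previous file
        return captured_as_arg("op2_outputv",
                               "((ArrayOp*)(ops[2]))->ptr<float32>()[0]") ? 0 : 1;
    }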

View File

@@ -122,6 +122,9 @@ void LoopVarAnalyzePass::run() {
&& (op->outputs().front()->shape.size() != max_elm_dim ||
std::abs(op->outputs().front()->num) != max_elm_size))
continue;
if (op->name_ex() == "array")
// array op should not be loop var
continue;
Var* loop_var;
if (op->type() == OpType::broadcast || op->name_ex() == "index") {
loop_var = op->output(0);
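
This guard also follows from the op_compiler change: an array op's single-element output is now read inline as a scalar rather than iterated over, so it should never be the op that defines the fused loop's iteration space. A hypothetical sketch of the selection rule with stand-in types (names and fields are illustrative, not Jittor's real pass):

    #include <cstdint>
    #include <string>
    #include <vector>

    // Stand-in for the information the pass inspects.
    struct OpInfo {
        std::string name_ex;   // e.g. "array", "broadcast_to", "unary.negative"
        int64_t     num;       // element count of the op's output
    };

    // The loop variable must come from an op whose output covers the fused
    // iteration space, and (after this commit) never from an array op.
    static const OpInfo* pick_loop_var_op(const std::vector<OpInfo>& ops,
                                          int64_t max_elm_size) {
        for (auto& op : ops) {
            if (op.name_ex == "array") continue;      // the new guard
            if (op.num != max_elm_size) continue;     // does not span the loop, skip
            return &op;
        }
        return nullptr;
    }

    int main() {
        // mirrors test_scalar_fuse_unary below: a = jt.array([1]); b = -a
        std::vector<OpInfo> ops = { {"array", 1}, {"unary.negative", 1} };
        // both outputs hold one element, so only the guard keeps "array" from winning
        auto* op = pick_loop_var_op(ops, 1);
        return (op && op->name_ex != "array") ? 0 : 1;
    }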

View File

@@ -193,6 +193,23 @@ class TestArray(unittest.TestCase):
            assert str(c.dtype) == t
            np.testing.assert_allclose(a, c)

    def test_scalar_fuse_unary(self):
        with jt.profile_scope() as rep:
            a = jt.array([1])
            b = -a
            a = a.clone()
            b = b.clone()
            jt.sync([a, b])
            assert a.data == 1
            assert b.data == -1
        assert len(rep) == 2

    @unittest.skipIf(not jt.has_cuda, "Cuda not found")
    def test_scalar_fuse_unary_cuda(self):
        with jt.flag_scope(use_cuda=1):
            self.test_scalar_fuse_unary()

if __name__ == "__main__":
    unittest.main()