mirror of https://github.com/Jittor/Jittor
fix issue #271
This commit is contained in:
parent
7cfd216372
commit
eee405669f
|
@ -9,7 +9,7 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
__version__ = '1.3.1.0'
|
||||
__version__ = '1.3.1.1'
|
||||
from jittor_utils import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -856,14 +856,19 @@ string OpCompiler::__get_fused_src(
|
|||
string arg_name = op_name + "_output";
|
||||
string argp_name = op_name + "_outputp";
|
||||
string T = ((ArrayOp*)ops[oi])->output->dtype().to_cstring();
|
||||
fused_kernel_args += " ArrayOp* " + op_name + " = (ArrayOp*)(ops[" + S(oi) + "]);\n";
|
||||
// op_name = "((ArrayOp*)(ops[" + S(oi) + "]))";
|
||||
fused_kernel_args += " Var* " + arg_name + " = " + op_name + "->output;\n";
|
||||
|
||||
fused_kernel += " auto* " + argp_name + " = " + arg_name + "->ptr<" + T + ">();\n";
|
||||
fused_kernel += " " + argp_name + "[0] = " + op_name + "->ptr<" + T + ">()[0];\n";
|
||||
fused_kernel += " int " + arg_name + "shape0 = 1;\n";
|
||||
fused_kernel += " int " + arg_name + "stride0 = 1;\n";
|
||||
fused_kernel_args += precompile({{"oi",S(oi)}, {"T", T}}, R"(
|
||||
Var* op@oi@@_output = ((ArrayOp*)(ops[@oi]))->output;
|
||||
@T op@oi@@_outputv = ((ArrayOp*)(ops[@oi]))->ptr<@T>()[0];
|
||||
)");
|
||||
|
||||
|
||||
fused_kernel += precompile({{"oi",S(oi)}, {"T", T}}, R"(
|
||||
@T* op@oi@@_outputp = op@oi@@_output->ptr<@T>();
|
||||
op@oi@@_outputp[0] = op@oi@@_outputv;
|
||||
)");
|
||||
|
||||
|
||||
|
||||
fused_includes += "#include \"ops/array_op.h\"\n";
|
||||
op_members[oi].push_back(arg_name);
|
||||
|
|
|
@ -56,6 +56,11 @@ void LoopToFuncPass::run() {
|
|||
if (d->has_attr("rvalue")) {
|
||||
auto& rvalue = d->attrs["rvalue"];
|
||||
auto& dtype = d->attrs["dtype"];
|
||||
if (endswith(d->attrs["lvalue"], "_value") ||
|
||||
endswith(d->attrs["lvalue"], "_outputv")) {
|
||||
args.push_back(d.get());
|
||||
continue;
|
||||
}
|
||||
if (rvalue.find("ops") != string::npos)
|
||||
continue;
|
||||
if (dtype=="Var*")
|
||||
|
@ -67,10 +72,6 @@ void LoopToFuncPass::run() {
|
|||
args.push_back(d.get());
|
||||
continue;
|
||||
}
|
||||
if (endswith(d->attrs["lvalue"], "_value")) {
|
||||
args.push_back(d.get());
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
func->push_back(d->clone());
|
||||
|
|
|
@ -122,6 +122,9 @@ void LoopVarAnalyzePass::run() {
|
|||
&& (op->outputs().front()->shape.size() != max_elm_dim ||
|
||||
std::abs(op->outputs().front()->num) != max_elm_size))
|
||||
continue;
|
||||
if (op->name_ex() == "array")
|
||||
// array op should not be loop var
|
||||
continue;
|
||||
Var* loop_var;
|
||||
if (op->type() == OpType::broadcast || op->name_ex() == "index") {
|
||||
loop_var = op->output(0);
|
||||
|
|
|
@ -193,6 +193,23 @@ class TestArray(unittest.TestCase):
|
|||
assert str(c.dtype) == t
|
||||
np.testing.assert_allclose(a, c)
|
||||
|
||||
def test_scalar_fuse_unary(self):
|
||||
with jt.profile_scope() as rep:
|
||||
a = jt.array([1])
|
||||
b = -a
|
||||
a = a.clone()
|
||||
b = b.clone()
|
||||
jt.sync([a, b])
|
||||
assert a.data == 1
|
||||
assert b.data == -1
|
||||
assert len(rep) == 2
|
||||
|
||||
@unittest.skipIf(not jt.has_cuda, "Cuda not found")
|
||||
def test_scalar_fuse_unary_cuda(self):
|
||||
with jt.flag_scope(use_cuda=1):
|
||||
self.test_scalar_fuse_unary()
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in New Issue