add negtive dim support

This commit is contained in:
Dun Liang 2020-07-29 16:58:42 +08:00
parent d83389117f
commit a62b45d6ca
5 changed files with 36 additions and 4 deletions

View File

@ -7,7 +7,7 @@
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.1.6.7'
__version__ = '1.1.6.8'
from . import lock
with lock.lock_scope():
from . import compiler

View File

@ -120,5 +120,19 @@ class TestBroadcastToOp2Cuda(TestBroadcastToOp):
def tearDown(self):
jt.flags.use_cuda = 0
class TestBroadcastToOpMisc(unittest.TestCase):
def test_negtive_dim(self):
a = jt.array([1,2])
assert (a.broadcast([2,2], [-1]).data == [[1,1],[2,2]]).all()
assert (a.broadcast([2,2], [-2]).data == [[1,2],[1,2]]).all()
def test_negtive_dim2(self):
a = jt.array([1,2])
b = jt.zeros((2,2))
assert (a.broadcast(b, [-1]).data == [[1,1],[2,2]]).all()
assert (a.broadcast(b, [-2]).data == [[1,2],[1,2]]).all()
if __name__ == "__main__":
unittest.main()

View File

@ -70,5 +70,12 @@ class TestReduceOpCuda2(TestReduceOp):
def tearDown(self):
jt.flags.use_cuda = 0
class TestReduceOpMisc(unittest.TestCase):
def test_negtive_dim(self):
a = jt.array([[1,2],[3,4]])
assert (a.sum(-1).data == [3,7]).all()
assert (a.sum(-2).data == [4,6]).all()
if __name__ == "__main__":
unittest.main()

View File

@ -27,8 +27,13 @@ BroadcastToOp::BroadcastToOp(Var* x, Var* y, NanoVector dims) : x(x), y(y) {
z = create_output(NanoVector(), x->dtype());
bcast_mask = 0;
keepdims = 0;
auto ydim = y->shape.size();
if (dims.size()) {
for (auto a : dims) bcast_mask |= 1 << a;
for (auto dim : dims) {
if (dim<0) dim += ydim;
CHECK(dim>=0 && dim<ydim) << "Wrong dims number:" << dims;
bcast_mask |= 1 << dim;
}
} else
keepdims = 1;
}
@ -62,8 +67,13 @@ BroadcastToOp::BroadcastToOp(Var* x, NanoVector shape, NanoVector dims) : x(x),
z = create_output(nullptr, x->dtype());
bcast_mask = 0;
keepdims = 0;
auto ydim = shape.size();
if (dims.size()) {
for (auto a : dims) bcast_mask |= 1 << a;
for (auto dim : dims) {
if (dim<0) dim += ydim;
CHECK(dim>=0 && dim<ydim) << "Wrong dims number:" << dims;
bcast_mask |= 1 << dim;
}
} else
keepdims = 1;
}

View File

@ -57,7 +57,8 @@ ReduceOp::ReduceOp(Var* x, NanoString op, NanoVector dims, bool keepdims)
} else {
reduce_mask = 0;
for (auto dim : dims) {
CHECKop(dim,<,xdim) << "Wrong dims number:" << dims;
if (dim<0) dim += xdim;
CHECK(dim>=0 && dim<xdim) << "Wrong dims number:" << dims;
reduce_mask |= 1<<dim;
}
}