mirror of https://github.com/Jittor/Jittor
add negtive dim support
This commit is contained in:
parent d83389117f
commit a62b45d6ca
@@ -7,7 +7,7 @@
 # This file is subject to the terms and conditions defined in
 # file 'LICENSE.txt', which is part of this source code package.
 # ***************************************************************
-__version__ = '1.1.6.7'
+__version__ = '1.1.6.8'
 from . import lock
 with lock.lock_scope():
     from . import compiler

@@ -120,5 +120,19 @@ class TestBroadcastToOp2Cuda(TestBroadcastToOp):
     def tearDown(self):
         jt.flags.use_cuda = 0
 
+
+class TestBroadcastToOpMisc(unittest.TestCase):
+    def test_negtive_dim(self):
+        a = jt.array([1,2])
+        assert (a.broadcast([2,2], [-1]).data == [[1,1],[2,2]]).all()
+        assert (a.broadcast([2,2], [-2]).data == [[1,2],[1,2]]).all()
+
+    def test_negtive_dim2(self):
+        a = jt.array([1,2])
+        b = jt.zeros((2,2))
+        assert (a.broadcast(b, [-1]).data == [[1,1],[2,2]]).all()
+        assert (a.broadcast(b, [-2]).data == [[1,2],[1,2]]).all()
+
+
 if __name__ == "__main__":
     unittest.main()

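For reference, a negative dim counts from the end of the target shape, so for a 2-D target -1 refers to axis 1 and -2 to axis 0. A minimal sketch of that equivalence, using only the jt.array/broadcast calls that appear in the tests above (the class and test names below are hypothetical, not part of this commit):

import unittest
import jittor as jt

class TestNegativeDimEquivalence(unittest.TestCase):
    # Negative dims should behave exactly like their positive
    # counterparts: -1 -> 1 and -2 -> 0 for a 2-D target shape.
    def test_equivalence(self):
        a = jt.array([1,2])
        assert (a.broadcast([2,2], [-1]).data == a.broadcast([2,2], [1]).data).all()
        assert (a.broadcast([2,2], [-2]).data == a.broadcast([2,2], [0]).data).all()

if __name__ == "__main__":
    unittest.main()
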
@@ -70,5 +70,12 @@ class TestReduceOpCuda2(TestReduceOp):
     def tearDown(self):
         jt.flags.use_cuda = 0
 
+
+class TestReduceOpMisc(unittest.TestCase):
+    def test_negtive_dim(self):
+        a = jt.array([[1,2],[3,4]])
+        assert (a.sum(-1).data == [3,7]).all()
+        assert (a.sum(-2).data == [4,6]).all()
+
 if __name__ == "__main__":
     unittest.main()

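The same convention applies to reductions: on a 2-D array, sum(-1) should resolve to sum(1) and sum(-2) to sum(0). A small usage sketch reusing the array from the test above (assuming the same jt.array/sum API the tests use):

import jittor as jt

a = jt.array([[1,2],[3,4]])
# -1 and -2 should resolve to axes 1 and 0 respectively.
assert (a.sum(-1).data == a.sum(1).data).all()   # reduce over the last axis -> [3,7]
assert (a.sum(-2).data == a.sum(0).data).all()   # reduce over the first axis -> [4,6]
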
@@ -27,8 +27,13 @@ BroadcastToOp::BroadcastToOp(Var* x, Var* y, NanoVector dims) : x(x), y(y) {
     z = create_output(NanoVector(), x->dtype());
     bcast_mask = 0;
     keepdims = 0;
+    auto ydim = y->shape.size();
     if (dims.size()) {
-        for (auto a : dims) bcast_mask |= 1 << a;
+        for (auto dim : dims) {
+            if (dim<0) dim += ydim;
+            CHECK(dim>=0 && dim<ydim) << "Wrong dims number:" << dims;
+            bcast_mask |= 1 << dim;
+        }
     } else
         keepdims = 1;
 }

@@ -62,8 +67,13 @@ BroadcastToOp::BroadcastToOp(Var* x, NanoVector shape, NanoVector dims) : x(x),
     z = create_output(nullptr, x->dtype());
     bcast_mask = 0;
     keepdims = 0;
+    auto ydim = shape.size();
     if (dims.size()) {
-        for (auto a : dims) bcast_mask |= 1 << a;
+        for (auto dim : dims) {
+            if (dim<0) dim += ydim;
+            CHECK(dim>=0 && dim<ydim) << "Wrong dims number:" << dims;
+            bcast_mask |= 1 << dim;
+        }
     } else
         keepdims = 1;
 }

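Both BroadcastToOp constructors now normalize each requested dim before OR-ing it into bcast_mask. A minimal Python sketch of that logic (build_bcast_mask is an illustrative helper, not part of the Jittor source):

def build_bcast_mask(dims, ydim):
    # Mirrors the updated C++ loop: a negative dim wraps around once,
    # and anything still outside [0, ydim) is rejected.
    mask = 0
    for dim in dims:
        if dim < 0:
            dim += ydim
        assert 0 <= dim < ydim, f"Wrong dims number: {dims}"
        mask |= 1 << dim
    return mask

assert build_bcast_mask([-1], 2) == build_bcast_mask([1], 2) == 0b10
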
@@ -57,7 +57,8 @@ ReduceOp::ReduceOp(Var* x, NanoString op, NanoVector dims, bool keepdims)
     } else {
         reduce_mask = 0;
         for (auto dim : dims) {
-            CHECKop(dim,<,xdim) << "Wrong dims number:" << dims;
+            if (dim<0) dim += xdim;
+            CHECK(dim>=0 && dim<xdim) << "Wrong dims number:" << dims;
             reduce_mask |= 1<<dim;
         }
     }

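The ReduceOp change follows the same pattern: the one-sided CHECKop(dim,<,xdim) is replaced by a wrap-then-range check, so a dim that is still out of range after wrapping (e.g. -3 on a 2-D input) is rejected. A hedged Python sketch of the new validation (normalize_dim is an illustrative helper, not a Jittor API):

def normalize_dim(dim, xdim):
    # Negative dims are shifted into range first; the subsequent
    # two-sided check catches values that remain outside [0, xdim).
    if dim < 0:
        dim += xdim
    assert 0 <= dim < xdim, f"Wrong dims number: {dim}"
    return dim

assert normalize_dim(-1, 2) == 1
assert normalize_dim(-2, 2) == 0
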