mirror of https://github.com/Jittor/Jittor
add slice broadcast
This commit is contained in:
parent
ab30a15eae
commit
0b13930ed3
|
@ -9,7 +9,7 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
__version__ = '1.3.3.7'
|
||||
__version__ = '1.3.3.8'
|
||||
from jittor_utils import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -98,7 +98,7 @@ BroadcastToOp::BroadcastToOp(Var* x, NanoVector shape, NanoVector dims) : x(x),
|
|||
bool BroadcastToOp::need_broadcast(const Var* x, const NanoVector& shape) {
|
||||
if (x->shape.size() < shape.size()) return true;
|
||||
for (uint i=shape.size()-1, j=x->shape.size()-1; i<shape.size(); i--,j--)
|
||||
if (x->shape[j]< 0 || (x->shape[j] != shape[i] && shape[i] != 1)) return true;
|
||||
if ((x->shape[j] != shape[i] && shape[i] != 1)) return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -154,8 +154,6 @@ void BroadcastToOp::infer_shape() {
|
|||
int64 zs;
|
||||
if ((xshape == 1 || yshape == 1) && (xshape != yshape)) {
|
||||
zs = xshape * yshape;
|
||||
} else if (xshape < 0 || yshape < 0) {
|
||||
zs = std::min(xshape, yshape);
|
||||
} else {
|
||||
CHECKop(xshape,==,yshape) << "Shape not match" << x->shape << yshapes << bcast_mask;
|
||||
zs = xshape;
|
||||
|
|
|
@ -108,11 +108,11 @@ void GetitemOp::infer_slices(
|
|||
for (int j=0; j<niv; j++) {
|
||||
auto iv_shape_j = iv_shape[niv-j-1];
|
||||
auto& out_shape_j = out_shape[first_oid_of_var+var_dim-j-1];
|
||||
CHECK(out_shape_j == iv_shape_j || out_shape_j == 1 || iv_shape_j == 1) << "Shape not match " >> out_shape_j >> "!="
|
||||
>> iv_shape_j << "data shape:" << in_shape <<
|
||||
"slice shape:" << iv_shape;
|
||||
if (out_shape_j == 1)
|
||||
out_shape_j = iv_shape_j;
|
||||
else
|
||||
ASSERT(out_shape_j == iv_shape_j || out_shape_j < 0 || iv_shape_j < 0)
|
||||
<< out_shape_j << iv_shape_j << out_shape;
|
||||
}
|
||||
} else
|
||||
if (s.is_ellipsis()) {
|
||||
|
|
|
@ -41,9 +41,7 @@ void ReshapeOp::infer_shape() {
|
|||
CHECK(uncertain_dim <= 1) << "max number of -1 is 1, but get" << uncertain_dim << ".";
|
||||
int64_t x_items = x->num;
|
||||
auto yshape = shape;
|
||||
if (x_items < 0) {
|
||||
// pass if input is uncertain
|
||||
} else if (uncertain_dim == 0) {
|
||||
if (uncertain_dim == 0) {
|
||||
CHECKop(x_items,==,y_items) << "reshape shape is invalid for input of size";
|
||||
} else {
|
||||
CHECK(y_items != 0 && x_items % y_items == 0) << "reshape shape is invalid for input of size " << x_items;
|
||||
|
|
|
@ -58,7 +58,6 @@ void TernaryOp::infer_shape() {
|
|||
auto shape = std::min(xshape, std::min(yshape, cshape));
|
||||
auto shape2 = std::max(xshape, std::max(yshape, cshape));
|
||||
zshape.push_back(shape2);
|
||||
if (shape < 0) continue;
|
||||
CHECK(shape==shape2) << "Shape not match" << x->shape << y->shape << cond->shape;
|
||||
}
|
||||
z->set_shape(zshape);
|
||||
|
|
|
@ -394,7 +394,39 @@ class TestSetitem(unittest.TestCase):
|
|||
jt.get_max_memory_treemap()
|
||||
|
||||
|
||||
|
||||
def test_setitem_bc(self):
|
||||
a = jt.random([10,11,12])
|
||||
b = a[jt.arange(3)[:,None],
|
||||
jt.arange(4)[None,:]]
|
||||
b.sync()
|
||||
assert (a[:3, :4] == b).all()
|
||||
|
||||
a = jt.random([10,11,12])
|
||||
b = a[jt.arange(3)[:,None],
|
||||
jt.arange(4)[None,:],
|
||||
jt.arange(4)[None,:]]
|
||||
nb = a.data[np.arange(3)[:,None],
|
||||
np.arange(4)[None,:],
|
||||
np.arange(4)[None,:]]
|
||||
np.testing.assert_allclose(nb, b.data)
|
||||
|
||||
a = jt.random([10,11,12])
|
||||
b = a[jt.arange(3)[::-1,None],
|
||||
jt.arange(4)[None,:],
|
||||
jt.arange(4)[None,:]]
|
||||
nb = a.data[np.arange(3)[::-1,None],
|
||||
np.arange(4)[None,:],
|
||||
np.arange(4)[None,:]]
|
||||
np.testing.assert_allclose(nb, b.data)
|
||||
|
||||
a = jt.random([10,11,12])
|
||||
b = a[jt.arange(3)[::-1,None],
|
||||
jt.arange(4)[None,:],
|
||||
jt.arange(4)[None,::-1]]
|
||||
nb = a.data[np.arange(3)[::-1,None],
|
||||
np.arange(4)[None,:],
|
||||
np.arange(4)[None,::-1]]
|
||||
np.testing.assert_allclose(nb, b.data)
|
||||
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue