mirror of https://github.com/Jittor/Jittor

polish matmul
commit 84d9434ecf (parent 87b1933447)
@@ -48,13 +48,21 @@ VarPtr CublasBatchedMatmulOp::grad(Var* out, Var* dout, Var* v, int v_index) {
}

void CublasBatchedMatmulOp::infer_shape(){
    ASSERTop(a->shape.size(),==,3);
    ASSERTop(b->shape.size(),==,3);
    auto adim = a->shape.size();
    auto bdim = b->shape.size();
    ASSERTop(adim,>=,3);
    ASSERTop(bdim,>=,3);
    ASSERTop(adim,==,bdim);

    int batch_size = a->shape[0], n = a->shape[1], m = a->shape[2];
    int batch_size_ = b->shape[0], m_ = b->shape[1], k = b->shape[2];
    auto n = a->shape[adim-2], m = a->shape[adim-1];
    auto m_ = b->shape[adim-2], k = b->shape[adim-1];

    ASSERTop(batch_size,==,batch_size_);
    NanoVector c_shape;

    for (int i=0; i<adim-2; i++) {
        ASSERTop(a->shape[i],==,b->shape[i]);
        c_shape.push_back(a->shape[i]);
    }
    if (trans_a) {
        swap(n, m);
    }
@@ -62,8 +70,10 @@ void CublasBatchedMatmulOp::infer_shape(){
        swap(m_, k);
    }
    ASSERTop(m,==,m_);
    c_shape.push_back(n);
    c_shape.push_back(k);

    c->set_shape({batch_size, n, k});
    c->set_shape(c_shape);
}
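Note: the following sketch is an editorial addition, not part of the diff. It restates the shape rule the reworked infer_shape enforces (transposes ignored): every leading batch dim of a and b must match, and the output keeps those dims and appends [n, k].

    # illustrative values only
    a_shape, b_shape = [8, 10, 3, 4], [8, 10, 4, 5]
    n, m = a_shape[-2], a_shape[-1]
    m_, k = b_shape[-2], b_shape[-1]
    assert a_shape[:-2] == b_shape[:-2] and m == m_
    c_shape = a_shape[:-2] + [n, k]   # -> [8, 10, 3, 5]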

void CublasBatchedMatmulOp::jit_prepare() {
@@ -83,16 +93,19 @@ void CublasBatchedMatmulOp::jit_run() {

    const auto& as = a->shape;
    const auto& bs = b->shape;
    auto adim = as.size();
    auto batch_size = as[0];
    auto n = as[1];
    auto m = as[2];
    auto k = bs[2];
    for (int i=1; i<adim-2; i++)
        batch_size *= as[i];
    auto n = as[adim-2];
    auto m = as[adim-1];
    auto k = bs[adim-1];
    if ('@Trans_a'=='T') {
        n = as[2];
        m = as[1];
        n = as[adim-1];
        m = as[adim-2];
    }
    if ('@Trans_b'=='T') {
        k = bs[1];
        k = bs[adim-2];
    }
    // a: [b,n,m], b: [b,m,k], c: [b,n,k]
    checkCudaErrors(cublas@op@@gemmStridedBatched(handle_,
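Note: jit_run now folds all leading dims into a single batch count, so one strided-batched GEMM call covers the N-d case. A rough NumPy sketch of the same flattening idea, added for clarity (this is not the kernel code):

    import numpy as np
    a = np.random.rand(2, 5, 3, 4)              # [..., n, m]
    b = np.random.rand(2, 5, 4, 6)              # [..., m, k]
    batch = int(np.prod(a.shape[:-2]))          # collapse leading dims: 2*5 -> 10
    c = np.matmul(a.reshape(batch, 3, 4), b.reshape(batch, 4, 6))
    c = c.reshape(a.shape[:-2] + (3, 6))        # restore leading dims: (2, 5, 3, 6)
    assert np.allclose(c, np.matmul(a, b))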
@@ -47,43 +47,81 @@ Example::
    b = jt.random((batch, m, k))
    c = nn.bmm(a, b)
    '''
    assert len(a.shape) >= 2 and len(b.shape) >= 2
    assert a.shape[-1] == b.shape[-2]
    assert len(a.shape) > 2 and len(b.shape) > 2
    return matmul(a, b)
    if jt.flags.use_cuda:
        return jt.compile_extern.cublas_ops.cublas_batched_matmul(a, b, 0, 0)
    shape = list(a.shape) + [b.shape[-1]]
    a = a.broadcast(shape, [len(shape)-1])
    b = b.broadcast(shape, [len(shape)-3])
    return (a*b).sum(len(shape)-2)

def matmul(a, b):
    assert len(a.shape) >= 2 and len(b.shape) >= 2
    assert a.shape[-1] == b.shape[-2]
    ''' matrix multiply,

Example::

    a = jt.random([3])
    b = jt.random([3])
    c = jt.matmul(a, b)
    assert c.shape == [1]

    a = jt.random([3, 4])
    b = jt.random([4])
    c = jt.matmul(a, b)
    assert c.shape == [3]

    a = jt.random([10, 3, 4])
    b = jt.random([4])
    c = jt.matmul(a, b)
    assert c.shape == [10, 3]

    a = jt.random([10, 3, 4])
    b = jt.random([4, 5])
    c = jt.matmul(a, b)
    assert c.shape == [10, 3, 5]

    a = jt.random([10, 3, 4])
    b = jt.random([10, 4, 5])
    c = jt.matmul(a, b)
    assert c.shape == [10, 3, 5]

    a = jt.random([8, 1, 3, 4])
    b = jt.random([10, 4, 5])
    c = jt.matmul(a, b)
    assert c.shape == [8, 10, 3, 5]
    '''
    len_a = len(a.shape)
    len_b = len(b.shape)
    len_max = max(len_a, len_b)
    a_shape = (len_max - len_a) * [1,] + list(a.shape)
    b_shape = (len_max - len_b) * [1,] + list(b.shape)

    a_rep = []
    b_rep = []
    for i in range(len_max-2):
        if a_shape[i] == 1 or b_shape[i] == 1:
            a_rep.append(b_shape[i])
            b_rep.append(a_shape[i])
        else:
            if a_shape[i] == b_shape[i]:
                a_rep.append(1)
                b_rep.append(1)
            else:
                raise(f"{a_shape[i]} and {b_shape[i]} must be same.")
    a_rep += [1,1,b.shape[-1],]
    b_rep += [a.shape[-2],1,1,]
    a = a.unsqueeze(-1).repeat(a_rep)
    b = b.unsqueeze(-3).repeat(b_rep)

    return (a*b).sum(len(a.shape)-2)
    if len_b == 1:
        # a: [n, m], b:[m], c:[n]
        return (a*b).sum(-1)
    if len_a == 1:
        # a: [n], b:[n,k], c:[k]
        return (a.broadcast(b, [-1]) * b).sum(0)
    if len_a>=3 and len_a==len_b:
        # bmm
        # a: [..., n, m], b: [..., m, k], c:[..., n, k]
        if jt.flags.use_cuda:
            return jt.compile_extern.cublas_ops.cublas_batched_matmul(a, b, 0, 0)
    shape = []
    len_c = max(len_a, len_b)
    (n, m), (m_, k) = a.shape[-2:], b.shape[-2:]
    assert m == m_, f"dimension not match, a.shape:{a.shape}, b.shape:{a.shape}"
    # a: [..., n, m]
    # b: [..., m, k]
    # cc:[..., n, m, k]
    # -->
    # 012
    for i in range(len_c-2):
        ai = len_a-(len_c-i)
        bi = len_b-(len_c-i)
        an = a.shape[ai] if ai>=0 else 1
        bn = b.shape[bi] if bi>=0 else 1
        if an!=1 and bn!=1:
            assert an == bn, f"dimension not match, a.shape:{a.shape}, b.shape:{a.shape}"
        cn = max(an, bn)
        shape.append(cn)
    shape.extend([n, m, k])
    a = a.broadcast(shape, [-1])
    b = b.broadcast(shape, [-3])
    return (a*b).sum(-2)
jt.Var.matmul = jt.Var.__matmul__ = matmul
jt.Var.__imatmul__ = lambda a,b: a.assign(matmul(a,b))
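Note: outside the cuBLAS fast path, the new matmul composes existing broadcast and reduce ops: both operands are broadcast to a shared [..., n, m, k] shape and the product is summed over the m axis. A small NumPy sketch of that trick, added for clarity (illustrative only):

    import numpy as np
    a = np.random.rand(10, 3, 4)                             # [..., n, m]
    b = np.random.rand(10, 4, 5)                             # [..., m, k]
    c = (a[..., :, :, None] * b[..., None, :, :]).sum(-2)    # broadcast, then reduce over m
    assert np.allclose(c, np.matmul(a, b))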
@@ -58,19 +58,19 @@ class TestCore(unittest.TestCase):
        b = np.array([[4, 1], [2, 2]]).astype("float32")
        c = np.matmul(a, b)
        jtc = jt.matmul(jt.array(a), jt.array(b)).data
        assert np.all(jtc == c)
        assert np.allclose(jtc, c)

        a = np.random.random((128,3,10,20))
        b = np.random.random((20,30))
        c = np.matmul(a, b)
        jtc = jt.matmul(jt.array(a), jt.array(b)).data
        assert np.all(jtc == c)
        assert np.allclose(jtc, c)

        a = np.random.random((128,3,10,20))
        b = np.random.random((128,3,20,30))
        c = np.matmul(a, b)
        jtc = jt.matmul(jt.array(a), jt.array(b)).data
        assert np.all(jtc == c)
        assert np.allclose(jtc, c), np.abs(jtc-c).max()

    def test_var_holder(self):
        jt.clean()
@@ -291,5 +291,54 @@ class TestMatmul(unittest.TestCase):
        assert len(logs_b)==2, len(logs_b)
        jt.clean()

    def test_matmul_example(self):
        a = jt.random([3])
        b = jt.random([3])
        c = jt.matmul(a, b)
        assert c.shape == [1]

        a = jt.random([3, 4])
        b = jt.random([4])
        c = jt.matmul(a, b)
        assert c.shape == [3]

        a = jt.random([10, 3, 4])
        b = jt.random([4])
        c = jt.matmul(a, b)
        assert c.shape == [10, 3]

        a = jt.random([10, 3, 4])
        b = jt.random([4, 5])
        c = jt.matmul(a, b)
        assert c.shape == [10, 3, 5]

        a = jt.random([10, 3, 4])
        b = jt.random([10, 4, 5])
        c = jt.matmul(a, b)
        assert c.shape == [10, 3, 5]

        a = jt.random([8, 1, 3, 4])
        b = jt.random([10, 4, 5])
        c = jt.matmul(a, b)
        assert c.shape == [8, 10, 3, 5]

    def test_matmul_example2(self):
        def check(a_shape, b_shape):
            a = jt.random(a_shape)
            b = jt.random(b_shape)
            c = jt.matmul(a, b)
            cc = np.matmul(a.data, b.data)
            assert c.shape == cc.shape or (cc.shape==() and c.shape==[1]), (c.shape, cc.shape)
            assert np.allclose(c.data, cc), (c.data-cc)
            da, db = jt.grad(c, [a, b])
            assert da.shape == a.shape
            assert db.shape == b.shape
        check([3], [3])
        check([3,4], [4])
        check([10,3,4], [4])
        check([10,3,4], [4,5])
        check([10,3,4], [10,4,5])
        check([8,1,3,4], [10,4,5])

if __name__ == "__main__":
    unittest.main()
@@ -13,11 +13,12 @@ namespace jittor {

#ifndef JIT
static auto make_reduce = get_op_info("reduce")
    .get_constructor<VarPtr, Var*, NanoString, uint, bool>();
    .get_constructor<VarPtr, Var*, NanoString, uint, uint>();

BroadcastToOp::BroadcastToOp(Var* x, Var* y, NanoVector dims) : x(x), y(y) {
    auto count = dims.size();
    // forward x if don't need broadcast
    if (y->num>=0 && !need_broadcast(x, y->shape)) {
    if (y->num>=0 && !count && !need_broadcast(x, y->shape)) {
        forward(x);
        return;
    }
@@ -26,21 +27,19 @@ BroadcastToOp::BroadcastToOp(Var* x, Var* y, NanoVector dims) : x(x), y(y) {
    set_type(OpType::broadcast);
    z = create_output(NanoVector(), x->dtype());
    bcast_mask = 0;
    keepdims = 0;
    auto ydim = y->shape.size();
    if (dims.size()) {
        for (auto dim : dims) {
            if (dim<0) dim += ydim;
            CHECK(dim>=0 && dim<ydim) << "Wrong dims number:" << dims;
            bcast_mask |= 1 << dim;
        }
    } else
        keepdims = 1;
    keepdims_mask = 0;
    auto ydim = std::max(x->shape.size(), y->shape.size()-count)+count;
    for (auto dim : dims) {
        if (dim<0) dim += ydim;
        CHECK(dim>=0 && dim<ydim) << "Wrong dims number:" << dims;
        bcast_mask |= 1 << dim;
    }
}

BroadcastToOp::BroadcastToOp(Var* x, Var* y, uint dims_mask) : x(x), y(y) {
BroadcastToOp::BroadcastToOp(Var* x, Var* y, uint dims_mask, uint keepdims_mask) : x(x), y(y) {
    auto count = __builtin_popcount(dims_mask);
    // forward x if don't need broadcast
    if (y->num>=0 && !need_broadcast(x, y->shape)) {
    if (y->num>=0 && !count && !need_broadcast(x, y->shape)) {
        forward(x);
        return;
    }
@@ -49,12 +48,13 @@ BroadcastToOp::BroadcastToOp(Var* x, Var* y, uint dims_mask) : x(x), y(y) {
    set_type(OpType::broadcast);
    z = create_output(NanoVector(), x->dtype());
    bcast_mask = dims_mask;
    keepdims = 0;
    this->keepdims_mask = keepdims_mask;
}

BroadcastToOp::BroadcastToOp(Var* x, NanoVector shape, NanoVector dims) : x(x), y(nullptr), shape(shape) {
    auto count = dims.size();
    // forward x if don't need broadcast
    if (!need_broadcast(x, shape)) {
    if (!count && !need_broadcast(x, shape)) {
        forward(x);
        return;
    }
@@ -66,16 +66,13 @@ BroadcastToOp::BroadcastToOp(Var* x, NanoVector shape, NanoVector dims) : x(x),
        CHECKop(v,>,0u) << "Shape should greater than 0.";
    z = create_output(nullptr, x->dtype());
    bcast_mask = 0;
    keepdims = 0;
    auto ydim = shape.size();
    if (dims.size()) {
        for (auto dim : dims) {
            if (dim<0) dim += ydim;
            CHECK(dim>=0 && dim<ydim) << "Wrong dims number:" << dims;
            bcast_mask |= 1 << dim;
        }
    } else
        keepdims = 1;
    keepdims_mask = 0;
    auto ydim = std::max(x->shape.size(), shape.size()-count)+count;
    for (auto dim : dims) {
        if (dim<0) dim += ydim;
        CHECK(dim>=0 && dim<ydim) << "Wrong dims number:" << dims;
        bcast_mask |= 1 << dim;
    }
}

bool BroadcastToOp::need_broadcast(const Var* x, const NanoVector& shape) {
@@ -88,7 +85,7 @@ bool BroadcastToOp::need_broadcast(const Var* x, const NanoVector& shape) {
VarPtr BroadcastToOp::grad(Var* out, Var* dout, Var* v, int v_index) {
    if (v_index==1) return nullptr;
    if (bcast_mask==0) return dout;
    VarPtr dv = make_reduce(dout, ns_add, bcast_mask, keepdims);
    VarPtr dv = make_reduce(dout, ns_add, bcast_mask, keepdims_mask);
    if (dv->shape.size() != v->shape.size())
        dv->shape = v->shape;
    return dv;
@@ -105,50 +102,39 @@ void BroadcastToOp::infer_shape() {
    auto yshapes = y ? y->shape : shape;
    auto xdim = x->shape.size();
    auto ydim = yshapes.size();
    auto zdim = std::max(xdim, ydim);
    NanoVector zshape;
    auto count = __builtin_popcount(bcast_mask&~keepdims_mask);
    auto zdim = std::max(xdim, ydim-count) + count;

    if (bcast_mask) {
        uint j=0;
        for (uint i=0; i<yshapes.size(); i++) {
            if (bcast_mask>>i&1) {
                zshape.push_back_check_overflow(yshapes[i]);
                continue;
            }
            CHECK(j<xdim) << "Number of shape not match.";
            // yshape[i] == 1 will be broadcast to xshape[j]
            // use case, xshape = [-3], yshape = [1, 3], dims=[1]
            // zshape -> [-3, 3]
            auto zs = (yshapes[i]<=1) ? x->shape[j] : yshapes[i];
            zshape.push_back_check_overflow(zs);
            CHECKop(x->shape[j],==,zs) << "Shape not match.";
            j++;
    int64 zz[zdim];
    for (int i=zdim-1, xi = xdim-1, yi = ydim-1; i>=0; i--) {
        bool bx = xi>=0;
        bool by = yi>=0;
        auto xshape = bx ? x->shape[xi] : 1;
        auto yshape = by ? yshapes[yi] : 1;
        if (bcast_mask>>i&1) {
            yi--;
            if (keepdims_mask>>i&1) xi--;
            zz[i] = yshape;
            continue;
        }
        j += j==0;
        CHECKop(j,==,xdim) << "Number of shape not match.";
        z->set_shape(zshape);
        LOGvvv << "Broadcast x(" >> x >> ") dims" << std::hex >>
            bcast_mask << "-> z(" >> z >> ")";
        return;
    }

    for (size_t i=0; i<zdim; i++) {
        bool bx = i-zdim+xdim<xdim;
        bool by = i-zdim+ydim<ydim;
        auto xshape = bx ? x->shape[i-zdim+xdim] : 1;
        auto yshape = by ? yshapes[i-zdim+ydim] : 1;
        bcast_mask |= ((xshape==1 && (yshape!=1 || !bx) )&1) << i;
        auto mask = ((xshape==1 && (yshape!=1 || !bx))&1) << i;
        bcast_mask |= mask;
        keepdims_mask |= mask;
        int64 zs;
        if ((xshape == 1 || yshape == 1) && (xshape != yshape)) {
            zs = xshape * yshape;
        } else if (xshape < 0 || yshape < 0) {
            zs = std::min(xshape, yshape);
        } else {
            CHECKop(xshape,==,yshape) << "Shape not match" << x->shape << yshapes;
            CHECKop(xshape,==,yshape) << "Shape not match" << x->shape << yshapes << bcast_mask;
            zs = xshape;
        }
        zshape.push_back_check_overflow(zs);
        zz[i] = zs;
        xi--, yi--;
    }

    NanoVector zshape;
    for (int i=0; i<zdim; i++) zshape.push_back(zz[i]);
    z->set_shape(zshape);
    LOGvvv << "Broadcast x(" >> x >> ") shape" << yshapes << "-> z(" >> z >> ")";
}
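Note: after this change each dimension carries two bits of state: bcast_mask marks the output dims that are broadcast, and keepdims_mask marks whether such a dim also exists (as size 1) in the input, so the gradient can reduce with per-dim keepdims instead of a single bool. A user-level illustration of the behavior this supports, added for clarity (assumed usage, not the C++ internals):

    import jittor as jt
    x = jt.random([3])
    y = x.broadcast([4, 3], [0])      # insert a new broadcast dim in front
    dx, = jt.grad(y.sum(), [x])       # gradient sums over the broadcast dim
    assert dx.shape == [3]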
@@ -13,14 +13,14 @@ struct BroadcastToOp : Op {
    Var* x, * y, * z;
    NanoVector shape;
    uint16 bcast_mask;
    bool keepdims;
    uint16 keepdims_mask;

    // @pybind(broadcast)
    BroadcastToOp(Var* x, NanoVector shape, NanoVector dims=NanoVector());
    // @pybind(broadcast,broadcast_var)
    BroadcastToOp(Var* x, Var* y, NanoVector dims=NanoVector());
    // @pybind(None)
    BroadcastToOp(Var* x, Var* y, uint dims_mask);
    BroadcastToOp(Var* x, Var* y, uint dims_mask, uint keepdims_mask);

    bool need_broadcast(const Var* x, const NanoVector& shape);
@@ -15,7 +15,7 @@ namespace jittor {

#ifndef JIT
static auto make_broadcast_to = get_op_info("broadcast_to")
    .get_constructor<VarPtr, Var*, Var*, uint>();
    .get_constructor<VarPtr, Var*, Var*, uint, uint>();
static auto make_binary = get_op_info("binary")
    .get_constructor<VarPtr, Var*, Var*, NanoString>();
static auto make_ternary = get_op_info("ternary")
@@ -45,13 +45,14 @@ unordered_set<string> reduce_ops = {
};

ReduceOp::ReduceOp(Var* x, NanoString op, NanoVector dims, bool keepdims)
    : x(x), keepdims(keepdims) {
    : x(x) {
    flags.set(NodeFlags::_cpu);
    flags.set(NodeFlags::_cuda);
    set_type(OpType::reduce);
    ns = op;
    ASSERT(ns.is_binary());
    auto xdim = x->shape.size();
    keepdims_mask = keepdims ? (int)-1 : (int)0;
    if (!dims.size()) {
        reduce_mask = (1<<xdim)-1;
    } else {
@@ -68,15 +69,15 @@ ReduceOp::ReduceOp(Var* x, NanoString op, NanoVector dims, bool keepdims)
    y = create_output(nullptr, binary_dtype_infer(ns, x, x));
}

ReduceOp::ReduceOp(Var* x, NanoString op, uint dims_mask, bool keepdims)
    : x(x), keepdims(keepdims) {
ReduceOp::ReduceOp(Var* x, NanoString op, uint dims_mask, uint keepdims_mask)
    : x(x) {
    flags.set(NodeFlags::_cpu);
    flags.set(NodeFlags::_cuda);
    set_type(OpType::reduce);
    ns = op;
    ASSERT(ns.is_binary());
    auto xdim = x->shape.size();
    reduce_mask = dims_mask ? dims_mask : (1<<xdim)-1;
    reduce_mask = dims_mask;
    this->keepdims_mask = keepdims_mask;
    y = create_output(nullptr, binary_dtype_infer(ns, x, x));
}
@@ -85,28 +86,31 @@ ReduceOp::ReduceOp(Var* x, NanoString op, int dim, bool keepdims)

void ReduceOp::infer_shape() {
    auto xdim = x->shape.size();
    NanoVector yshape;
    yshape.clear();
    for (int i=0; i<xdim; i++)
        yshape.push_back(((reduce_mask>>i)&1) ? 1 : x->shape[i]);
    if (!keepdims) {
        NanoVector yshape2;
        for (size_t i=0; i<xdim; i++)
            if (!(reduce_mask>>i & 1))
                yshape2.push_back(yshape[i]);
        if (!yshape2.size()) yshape2.push_back(1);
        y->set_shape(yshape2);
    } else
        y->set_shape(yshape);
    for (int i=0; i<xdim; i++) {
        if (reduce_mask>>i&1) {
            if (keepdims_mask>>i&1)
                yshape.push_back(1);
        } else
            yshape.push_back(x->shape[i]);
    }
    if (!yshape.size()) {
        yshape.push_back(1);
        // change last bit to 1, last dim should keep dim
        keepdims_mask |= 1;
    }
    y->set_shape(yshape);
}
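Note: reduce can now keep or drop each reduced dimension individually via keepdims_mask. A small Python sketch of the shape rule implemented above, added for clarity (illustrative only):

    def reduced_shape(x_shape, reduce_mask, keepdims_mask):
        # bit i of reduce_mask: dim i is reduced; bit i of keepdims_mask: keep it as size 1
        yshape = []
        for i, s in enumerate(x_shape):
            if reduce_mask >> i & 1:
                if keepdims_mask >> i & 1:
                    yshape.append(1)
            else:
                yshape.append(s)
        return yshape or [1]

    assert reduced_shape([10, 3, 4], 0b001, 0b000) == [3, 4]
    assert reduced_shape([10, 3, 4], 0b001, 0b001) == [1, 3, 4]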

VarPtr ReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) {
    uint mask=0;
    if (!keepdims) mask = reduce_mask;
    if (ns == ns_add)
        return make_broadcast_to(dout, v, mask);
    if (ns == ns_add) {
        auto ret = make_broadcast_to(dout, v, reduce_mask, keepdims_mask);
        return ret;
    }
    if (ns == ns_multiply) {
        VarPtr a = make_binary(dout, out, ns_multiply);
        VarPtr b = make_broadcast_to(a, v, mask);
        VarPtr b = make_broadcast_to(a, v, reduce_mask, keepdims_mask);
        return make_binary(b, v, ns_divide);
    }
    if (ns == ns_mean) {
@@ -116,15 +120,15 @@ VarPtr ReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) {
            exe.run_sync({v}, 0);
            ASSERT(v->num>=0);
        }
        VarPtr a = make_broadcast_to(dout, v, mask);
        VarPtr a = make_broadcast_to(dout, v, reduce_mask, keepdims_mask);
        VarPtr n = make_number(1.0f*out->num / v->num, a);
        return make_binary(a, n, ns_multiply);
    }
    if (ns == ns_maximum || ns == ns_minimum) {
        VarPtr zeros = make_number(0, v);
        VarPtr a = make_broadcast_to(out, v, mask);
        VarPtr a = make_broadcast_to(out, v, reduce_mask, keepdims_mask);
        VarPtr cond = make_binary(v, a, ns_equal);
        VarPtr dv = make_broadcast_to(dout, v, mask);
        VarPtr dv = make_broadcast_to(dout, v, reduce_mask, keepdims_mask);
        return make_ternary(cond, dv, zeros);
    }
    return nullptr;
@@ -135,7 +139,7 @@ void ReduceOp::jit_prepare() {
    add_jit_define("Ty", y->dtype());
    add_jit_define("Tz", y->dtype());
    add_jit_define("OP", ns.to_cstring());
    add_jit_define("DIM", JK::hex1(yshape.size()));
    add_jit_define("DIM", JK::hex1(x->shape.size()));
    add_jit_define("REDUCE", JK::hex(reduce_mask));
}
@@ -10,13 +10,12 @@ namespace jittor {

struct ReduceOp : Op {
    Var* x, * y;
    NanoVector yshape; // keepdim shape
    bool keepdims;
    uint16 reduce_mask; // i-th bit is 1 if dim-i is reduced
    uint16 keepdims_mask;
    ReduceOp(Var* x, NanoString op, int dim, bool keepdims=false);
    ReduceOp(Var* x, NanoString op, NanoVector dims=NanoVector(), bool keepdims=false);
    // @pybind(None)
    ReduceOp(Var* x, NanoString op, uint dims_mask, bool keepdims);
    ReduceOp(Var* x, NanoString op, uint dims_mask, uint keepdims_mask);

    const char* name() const override { return "reduce"; }
    VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;