polish matmul

Dun Liang 2020-07-31 21:02:01 +08:00
parent 87b1933447
commit 84d9434ecf
8 changed files with 225 additions and 136 deletions

View File

@@ -48,13 +48,21 @@ VarPtr CublasBatchedMatmulOp::grad(Var* out, Var* dout, Var* v, int v_index) {
}
void CublasBatchedMatmulOp::infer_shape(){
ASSERTop(a->shape.size(),==,3);
ASSERTop(b->shape.size(),==,3);
auto adim = a->shape.size();
auto bdim = b->shape.size();
ASSERTop(adim,>=,3);
ASSERTop(bdim,>=,3);
ASSERTop(adim,==,bdim);
int batch_size = a->shape[0], n = a->shape[1], m = a->shape[2];
int batch_size_ = b->shape[0], m_ = b->shape[1], k = b->shape[2];
auto n = a->shape[adim-2], m = a->shape[adim-1];
auto m_ = b->shape[adim-2], k = b->shape[adim-1];
ASSERTop(batch_size,==,batch_size_);
NanoVector c_shape;
for (int i=0; i<adim-2; i++) {
ASSERTop(a->shape[i],==,b->shape[i]);
c_shape.push_back(a->shape[i]);
}
if (trans_a) {
swap(n, m);
}
@@ -62,8 +70,10 @@ void CublasBatchedMatmulOp::infer_shape(){
swap(m_, k);
}
ASSERTop(m,==,m_);
c_shape.push_back(n);
c_shape.push_back(k);
c->set_shape({batch_size, n, k});
c->set_shape(c_shape);
}
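The rewritten infer_shape accepts any rank of at least 3, requires a and b to share the same rank and identical leading batch dims, and builds the output shape from those leading dims plus (n, k) after applying the transpose flags; the jit_run hunk below then folds all leading dims into a single batch count for the strided-batched GEMM call. A minimal Python sketch of this shape rule (illustrative only, not Jittor code; the function name is ours):

# Sketch of the shape rule implemented by the new infer_shape, assuming both
# inputs have the same rank >= 3 and identical leading batch dims.
def batched_matmul_shape(a_shape, b_shape, trans_a=False, trans_b=False):
    assert len(a_shape) == len(b_shape) >= 3
    n, m = a_shape[-2], a_shape[-1]
    m_, k = b_shape[-2], b_shape[-1]
    if trans_a:
        n, m = m, n
    if trans_b:
        m_, k = k, m_
    assert m == m_
    batch = list(a_shape[:-2])
    assert batch == list(b_shape[:-2])
    return batch + [n, k]

# e.g. batched_matmul_shape([8, 2, 3, 4], [8, 2, 4, 5]) -> [8, 2, 3, 5]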
void CublasBatchedMatmulOp::jit_prepare() {
@@ -83,16 +93,19 @@ void CublasBatchedMatmulOp::jit_run() {
const auto& as = a->shape;
const auto& bs = b->shape;
auto adim = as.size();
auto batch_size = as[0];
auto n = as[1];
auto m = as[2];
auto k = bs[2];
for (int i=1; i<adim-2; i++)
batch_size *= as[i];
auto n = as[adim-2];
auto m = as[adim-1];
auto k = bs[adim-1];
if ('@Trans_a'=='T') {
n = as[2];
m = as[1];
n = as[adim-1];
m = as[adim-2];
}
if ('@Trans_b'=='T') {
k = bs[1];
k = bs[adim-2];
}
// a: [b,n,m], b: [b,m,k], c: [b,n,k]
checkCudaErrors(cublas@op@@gemmStridedBatched(handle_,

View File

@@ -47,43 +47,81 @@ Example::
b = jt.random((batch, m, k))
c = nn.bmm(a, b)
'''
assert len(a.shape) >= 2 and len(b.shape) >= 2
assert a.shape[-1] == b.shape[-2]
assert len(a.shape) > 2 and len(b.shape) > 2
return matmul(a, b)
if jt.flags.use_cuda:
return jt.compile_extern.cublas_ops.cublas_batched_matmul(a, b, 0, 0)
shape = list(a.shape) + [b.shape[-1]]
a = a.broadcast(shape, [len(shape)-1])
b = b.broadcast(shape, [len(shape)-3])
return (a*b).sum(len(shape)-2)
def matmul(a, b):
assert len(a.shape) >= 2 and len(b.shape) >= 2
assert a.shape[-1] == b.shape[-2]
''' matrix multiply,
Example::
a = jt.random([3])
b = jt.random([3])
c = jt.matmul(a, b)
assert c.shape == [1]
a = jt.random([3, 4])
b = jt.random([4])
c = jt.matmul(a, b)
assert c.shape == [3]
a = jt.random([10, 3, 4])
b = jt.random([4])
c = jt.matmul(a, b)
assert c.shape == [10, 3]
a = jt.random([10, 3, 4])
b = jt.random([4, 5])
c = jt.matmul(a, b)
assert c.shape == [10, 3, 5]
a = jt.random([10, 3, 4])
b = jt.random([10, 4, 5])
c = jt.matmul(a, b)
assert c.shape == [10, 3, 5]
a = jt.random([8, 1, 3, 4])
b = jt.random([10, 4, 5])
c = jt.matmul(a, b)
assert c.shape == [8, 10, 3, 5]
'''
len_a = len(a.shape)
len_b = len(b.shape)
len_max = max(len_a, len_b)
a_shape = (len_max - len_a) * [1,] + list(a.shape)
b_shape = (len_max - len_b) * [1,] + list(b.shape)
a_rep = []
b_rep = []
for i in range(len_max-2):
if a_shape[i] == 1 or b_shape[i] == 1:
a_rep.append(b_shape[i])
b_rep.append(a_shape[i])
else:
if a_shape[i] == b_shape[i]:
a_rep.append(1)
b_rep.append(1)
else:
raise(f"{a_shape[i]} and {b_shape[i]} must be same.")
a_rep += [1,1,b.shape[-1],]
b_rep += [a.shape[-2],1,1,]
a = a.unsqueeze(-1).repeat(a_rep)
b = b.unsqueeze(-3).repeat(b_rep)
return (a*b).sum(len(a.shape)-2)
if len_b == 1:
# a: [n, m], b:[m], c:[n]
return (a*b).sum(-1)
if len_a == 1:
# a: [n], b:[n,k], c:[k]
return (a.broadcast(b, [-1]) * b).sum(0)
if len_a>=3 and len_a==len_b:
# bmm
# a: [..., n, m], b: [..., m, k], c:[..., n, k]
if jt.flags.use_cuda:
return jt.compile_extern.cublas_ops.cublas_batched_matmul(a, b, 0, 0)
shape = []
len_c = max(len_a, len_b)
(n, m), (m_, k) = a.shape[-2:], b.shape[-2:]
assert m == m_, f"dimension not match, a.shape:{a.shape}, b.shape:{b.shape}"
# a: [..., n, m]
# b: [..., m, k]
# cc:[..., n, m, k]
# -->
# 012
for i in range(len_c-2):
ai = len_a-(len_c-i)
bi = len_b-(len_c-i)
an = a.shape[ai] if ai>=0 else 1
bn = b.shape[bi] if bi>=0 else 1
if an!=1 and bn!=1:
assert an == bn, f"dimension not match, a.shape:{a.shape}, b.shape:{b.shape}"
cn = max(an, bn)
shape.append(cn)
shape.extend([n, m, k])
a = a.broadcast(shape, [-1])
b = b.broadcast(shape, [-3])
return (a*b).sum(-2)
jt.Var.matmul = jt.Var.__matmul__ = matmul
jt.Var.__imatmul__ = lambda a,b: a.assign(matmul(a,b))
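The rewritten matmul dispatches on rank: a 1-D b or 1-D a takes a cheap multiply-and-sum path, equal ranks of 3 or more go to the cuBLAS batched kernel when CUDA is enabled, and everything else (including mismatched ranks) aligns the leading dims from the right and computes the product by broadcast-multiply-reduce over the shared m axis. A hedged NumPy sketch of that fallback (illustrative only, not part of the commit), checked against np.matmul:

# Expand a to [..., n, m, 1] and b to [..., 1, m, k], multiply, then sum over
# the shared m axis; NumPy broadcasting handles the leading batch dims.
import numpy as np

def matmul_by_broadcast(a, b):
    prod = a[..., :, :, None] * b[..., None, :, :]   # [..., n, m, k]
    return prod.sum(-2)                              # [..., n, k]

a = np.random.rand(8, 1, 3, 4)
b = np.random.rand(10, 4, 5)
assert np.allclose(matmul_by_broadcast(a, b), np.matmul(a, b))  # [8, 10, 3, 5]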

View File

@@ -58,19 +58,19 @@ class TestCore(unittest.TestCase):
b = np.array([[4, 1], [2, 2]]).astype("float32")
c = np.matmul(a, b)
jtc = jt.matmul(jt.array(a), jt.array(b)).data
assert np.all(jtc == c)
assert np.allclose(jtc, c)
a = np.random.random((128,3,10,20))
b = np.random.random((20,30))
c = np.matmul(a, b)
jtc = jt.matmul(jt.array(a), jt.array(b)).data
assert np.all(jtc == c)
assert np.allclose(jtc, c)
a = np.random.random((128,3,10,20))
b = np.random.random((128,3,20,30))
c = np.matmul(a, b)
jtc = jt.matmul(jt.array(a), jt.array(b)).data
assert np.all(jtc == c)
assert np.allclose(jtc, c), np.abs(jtc-c).max()
def test_var_holder(self):
jt.clean()

View File

@@ -291,5 +291,54 @@ class TestMatmul(unittest.TestCase):
assert len(logs_b)==2, len(logs_b)
jt.clean()
def test_matmul_example(self):
a = jt.random([3])
b = jt.random([3])
c = jt.matmul(a, b)
assert c.shape == [1]
a = jt.random([3, 4])
b = jt.random([4])
c = jt.matmul(a, b)
assert c.shape == [3]
a = jt.random([10, 3, 4])
b = jt.random([4])
c = jt.matmul(a, b)
assert c.shape == [10, 3]
a = jt.random([10, 3, 4])
b = jt.random([4, 5])
c = jt.matmul(a, b)
assert c.shape == [10, 3, 5]
a = jt.random([10, 3, 4])
b = jt.random([10, 4, 5])
c = jt.matmul(a, b)
assert c.shape == [10, 3, 5]
a = jt.random([8, 1, 3, 4])
b = jt.random([10, 4, 5])
c = jt.matmul(a, b)
assert c.shape == [8, 10, 3, 5]
def test_matmul_example2(self):
def check(a_shape, b_shape):
a = jt.random(a_shape)
b = jt.random(b_shape)
c = jt.matmul(a, b)
cc = np.matmul(a.data, b.data)
assert c.shape == cc.shape or (cc.shape==() and c.shape==[1]), (c.shape, cc.shape)
assert np.allclose(c.data, cc), (c.data-cc)
da, db = jt.grad(c, [a, b])
assert da.shape == a.shape
assert db.shape == b.shape
check([3], [3])
check([3,4], [4])
check([10,3,4], [4])
check([10,3,4], [4,5])
check([10,3,4], [10,4,5])
check([8,1,3,4], [10,4,5])
if __name__ == "__main__":
unittest.main()

View File

@@ -13,11 +13,12 @@ namespace jittor {
#ifndef JIT
static auto make_reduce = get_op_info("reduce")
.get_constructor<VarPtr, Var*, NanoString, uint, bool>();
.get_constructor<VarPtr, Var*, NanoString, uint, uint>();
BroadcastToOp::BroadcastToOp(Var* x, Var* y, NanoVector dims) : x(x), y(y) {
auto count = dims.size();
// forward x if broadcast is not needed
if (y->num>=0 && !need_broadcast(x, y->shape)) {
if (y->num>=0 && !count && !need_broadcast(x, y->shape)) {
forward(x);
return;
}
@@ -26,21 +27,19 @@ BroadcastToOp::BroadcastToOp(Var* x, Var* y, NanoVector dims) : x(x), y(y) {
set_type(OpType::broadcast);
z = create_output(NanoVector(), x->dtype());
bcast_mask = 0;
keepdims = 0;
auto ydim = y->shape.size();
if (dims.size()) {
for (auto dim : dims) {
if (dim<0) dim += ydim;
CHECK(dim>=0 && dim<ydim) << "Wrong dims number:" << dims;
bcast_mask |= 1 << dim;
}
} else
keepdims = 1;
keepdims_mask = 0;
auto ydim = std::max(x->shape.size(), y->shape.size()-count)+count;
for (auto dim : dims) {
if (dim<0) dim += ydim;
CHECK(dim>=0 && dim<ydim) << "Wrong dims number:" << dims;
bcast_mask |= 1 << dim;
}
}
BroadcastToOp::BroadcastToOp(Var* x, Var* y, uint dims_mask) : x(x), y(y) {
BroadcastToOp::BroadcastToOp(Var* x, Var* y, uint dims_mask, uint keepdims_mask) : x(x), y(y) {
auto count = __builtin_popcount(dims_mask);
// forward x if broadcast is not needed
if (y->num>=0 && !need_broadcast(x, y->shape)) {
if (y->num>=0 && !count && !need_broadcast(x, y->shape)) {
forward(x);
return;
}
@@ -49,12 +48,13 @@ BroadcastToOp::BroadcastToOp(Var* x, Var* y, uint dims_mask) : x(x), y(y) {
set_type(OpType::broadcast);
z = create_output(NanoVector(), x->dtype());
bcast_mask = dims_mask;
keepdims = 0;
this->keepdims_mask = keepdims_mask;
}
BroadcastToOp::BroadcastToOp(Var* x, NanoVector shape, NanoVector dims) : x(x), y(nullptr), shape(shape) {
auto count = dims.size();
// forward x if broadcast is not needed
if (!need_broadcast(x, shape)) {
if (!count && !need_broadcast(x, shape)) {
forward(x);
return;
}
@@ -66,16 +66,13 @@ BroadcastToOp::BroadcastToOp(Var* x, NanoVector shape, NanoVector dims) : x(x),
CHECKop(v,>,0u) << "Shape should be greater than 0.";
z = create_output(nullptr, x->dtype());
bcast_mask = 0;
keepdims = 0;
auto ydim = shape.size();
if (dims.size()) {
for (auto dim : dims) {
if (dim<0) dim += ydim;
CHECK(dim>=0 && dim<ydim) << "Wrong dims number:" << dims;
bcast_mask |= 1 << dim;
}
} else
keepdims = 1;
keepdims_mask = 0;
auto ydim = std::max(x->shape.size(), shape.size()-count)+count;
for (auto dim : dims) {
if (dim<0) dim += ydim;
CHECK(dim>=0 && dim<ydim) << "Wrong dims number:" << dims;
bcast_mask |= 1 << dim;
}
}
bool BroadcastToOp::need_broadcast(const Var* x, const NanoVector& shape) {
@@ -88,7 +85,7 @@ bool BroadcastToOp::need_broadcast(const Var* x, const NanoVector& shape) {
VarPtr BroadcastToOp::grad(Var* out, Var* dout, Var* v, int v_index) {
if (v_index==1) return nullptr;
if (bcast_mask==0) return dout;
VarPtr dv = make_reduce(dout, ns_add, bcast_mask, keepdims);
VarPtr dv = make_reduce(dout, ns_add, bcast_mask, keepdims_mask);
if (dv->shape.size() != v->shape.size())
dv->shape = v->shape;
return dv;
@@ -105,50 +102,39 @@ void BroadcastToOp::infer_shape() {
auto yshapes = y ? y->shape : shape;
auto xdim = x->shape.size();
auto ydim = yshapes.size();
auto zdim = std::max(xdim, ydim);
NanoVector zshape;
auto count = __builtin_popcount(bcast_mask&~keepdims_mask);
auto zdim = std::max(xdim, ydim-count) + count;
if (bcast_mask) {
uint j=0;
for (uint i=0; i<yshapes.size(); i++) {
if (bcast_mask>>i&1) {
zshape.push_back_check_overflow(yshapes[i]);
continue;
}
CHECK(j<xdim) << "Number of shape not match.";
// yshape[i] == 1 will be broadcast to xshape[j]
// use case, xshape = [-3], yshape = [1, 3], dims=[1]
// zshape -> [-3, 3]
auto zs = (yshapes[i]<=1) ? x->shape[j] : yshapes[i];
zshape.push_back_check_overflow(zs);
CHECKop(x->shape[j],==,zs) << "Shape not match.";
j++;
int64 zz[zdim];
for (int i=zdim-1, xi = xdim-1, yi = ydim-1; i>=0; i--) {
bool bx = xi>=0;
bool by = yi>=0;
auto xshape = bx ? x->shape[xi] : 1;
auto yshape = by ? yshapes[yi] : 1;
if (bcast_mask>>i&1) {
yi--;
if (keepdims_mask>>i&1) xi--;
zz[i] = yshape;
continue;
}
j += j==0;
CHECKop(j,==,xdim) << "Number of shape not match.";
z->set_shape(zshape);
LOGvvv << "Broadcast x(" >> x >> ") dims" << std::hex >>
bcast_mask << "-> z(" >> z >> ")";
return;
}
for (size_t i=0; i<zdim; i++) {
bool bx = i-zdim+xdim<xdim;
bool by = i-zdim+ydim<ydim;
auto xshape = bx ? x->shape[i-zdim+xdim] : 1;
auto yshape = by ? yshapes[i-zdim+ydim] : 1;
bcast_mask |= ((xshape==1 && (yshape!=1 || !bx) )&1) << i;
auto mask = ((xshape==1 && (yshape!=1 || !bx))&1) << i;
bcast_mask |= mask;
keepdims_mask |= mask;
int64 zs;
if ((xshape == 1 || yshape == 1) && (xshape != yshape)) {
zs = xshape * yshape;
} else if (xshape < 0 || yshape < 0) {
zs = std::min(xshape, yshape);
} else {
CHECKop(xshape,==,yshape) << "Shape not match" << x->shape << yshapes;
CHECKop(xshape,==,yshape) << "Shape not match" << x->shape << yshapes << bcast_mask;
zs = xshape;
}
zshape.push_back_check_overflow(zs);
zz[i] = zs;
xi--, yi--;
}
NanoVector zshape;
for (int i=0; i<zdim; i++) zshape.push_back(zz[i]);
z->set_shape(zshape);
LOGvvv << "Broadcast x(" >> x >> ") shape" << yshapes << "-> z(" >> z >> ")";
}
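The rewritten infer_shape walks the output dims from the right: a dim flagged in bcast_mask is taken from y, and it also consumes a dim of x only when the same bit is set in keepdims_mask (a size-1 axis being expanded rather than a new axis being inserted); unflagged dims follow ordinary NumPy-style right-aligned broadcasting. A hedged Python sketch of that walk, ignoring Jittor's negative (symbolic) dims and the on-the-fly mask updates (function name is ours):

# bit i of bcast_mask: output dim i comes from y; bit i of keepdims_mask: that
# broadcast dim still consumes a dim of x (it is a kept size-1 axis).
def broadcast_shape(xshape, yshape, bcast_mask, keepdims_mask):
    count = bin(bcast_mask & ~keepdims_mask).count("1")
    zdim = max(len(xshape), len(yshape) - count) + count
    z = [0] * zdim
    xi, yi = len(xshape) - 1, len(yshape) - 1
    for i in range(zdim - 1, -1, -1):
        xs = xshape[xi] if xi >= 0 else 1
        ys = yshape[yi] if yi >= 0 else 1
        if bcast_mask >> i & 1:
            yi -= 1
            if keepdims_mask >> i & 1:
                xi -= 1
            z[i] = ys
            continue
        assert xs == ys or xs == 1 or ys == 1
        z[i] = max(xs, ys)
        xi -= 1
        yi -= 1
    return z

# e.g. broadcast_shape([3, 4], [3, 4, 5], bcast_mask=0b100, keepdims_mask=0) -> [3, 4, 5]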

View File

@@ -13,14 +13,14 @@ struct BroadcastToOp : Op {
Var* x, * y, * z;
NanoVector shape;
uint16 bcast_mask;
bool keepdims;
uint16 keepdims_mask;
// @pybind(broadcast)
BroadcastToOp(Var* x, NanoVector shape, NanoVector dims=NanoVector());
// @pybind(broadcast,broadcast_var)
BroadcastToOp(Var* x, Var* y, NanoVector dims=NanoVector());
// @pybind(None)
BroadcastToOp(Var* x, Var* y, uint dims_mask);
BroadcastToOp(Var* x, Var* y, uint dims_mask, uint keepdims_mask);
bool need_broadcast(const Var* x, const NanoVector& shape);

View File

@@ -15,7 +15,7 @@ namespace jittor {
#ifndef JIT
static auto make_broadcast_to = get_op_info("broadcast_to")
.get_constructor<VarPtr, Var*, Var*, uint>();
.get_constructor<VarPtr, Var*, Var*, uint, uint>();
static auto make_binary = get_op_info("binary")
.get_constructor<VarPtr, Var*, Var*, NanoString>();
static auto make_ternary = get_op_info("ternary")
@@ -45,13 +45,14 @@ unordered_set<string> reduce_ops = {
};
ReduceOp::ReduceOp(Var* x, NanoString op, NanoVector dims, bool keepdims)
: x(x), keepdims(keepdims) {
: x(x) {
flags.set(NodeFlags::_cpu);
flags.set(NodeFlags::_cuda);
set_type(OpType::reduce);
ns = op;
ASSERT(ns.is_binary());
auto xdim = x->shape.size();
keepdims_mask = keepdims ? (int)-1 : (int)0;
if (!dims.size()) {
reduce_mask = (1<<xdim)-1;
} else {
@@ -68,15 +69,15 @@ ReduceOp::ReduceOp(Var* x, NanoString op, NanoVector dims, bool keepdims)
y = create_output(nullptr, binary_dtype_infer(ns, x, x));
}
ReduceOp::ReduceOp(Var* x, NanoString op, uint dims_mask, bool keepdims)
: x(x), keepdims(keepdims) {
ReduceOp::ReduceOp(Var* x, NanoString op, uint dims_mask, uint keepdims_mask)
: x(x) {
flags.set(NodeFlags::_cpu);
flags.set(NodeFlags::_cuda);
set_type(OpType::reduce);
ns = op;
ASSERT(ns.is_binary());
auto xdim = x->shape.size();
reduce_mask = dims_mask ? dims_mask : (1<<xdim)-1;
reduce_mask = dims_mask;
this->keepdims_mask = keepdims_mask;
y = create_output(nullptr, binary_dtype_infer(ns, x, x));
}
@@ -85,28 +86,31 @@ ReduceOp::ReduceOp(Var* x, NanoString op, int dim, bool keepdims)
void ReduceOp::infer_shape() {
auto xdim = x->shape.size();
NanoVector yshape;
yshape.clear();
for (int i=0; i<xdim; i++)
yshape.push_back(((reduce_mask>>i)&1) ? 1 : x->shape[i]);
if (!keepdims) {
NanoVector yshape2;
for (size_t i=0; i<xdim; i++)
if (!(reduce_mask>>i & 1))
yshape2.push_back(yshape[i]);
if (!yshape2.size()) yshape2.push_back(1);
y->set_shape(yshape2);
} else
y->set_shape(yshape);
for (int i=0; i<xdim; i++) {
if (reduce_mask>>i&1) {
if (keepdims_mask>>i&1)
yshape.push_back(1);
} else
yshape.push_back(x->shape[i]);
}
if (!yshape.size()) {
yshape.push_back(1);
// set the last bit of keepdims_mask: when every dim is reduced, keep the last dim so the output shape is [1]
keepdims_mask |= 1;
}
y->set_shape(yshape);
}
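With keepdims promoted to a per-dim mask, each reduced dim is independently either kept as size 1 or dropped, which is what lets the matmul gradient broadcast back through exactly the dims it flattened. A hedged Python sketch of the shape rule above (illustrative only, not Jittor code):

# bit i of reduce_mask marks dim i as reduced; bit i of keepdims_mask keeps it
# as a size-1 dim instead of dropping it.
def reduce_shape(xshape, reduce_mask, keepdims_mask):
    yshape = []
    for i, s in enumerate(xshape):
        if reduce_mask >> i & 1:
            if keepdims_mask >> i & 1:
                yshape.append(1)
        else:
            yshape.append(s)
    if not yshape:
        yshape.append(1)   # fully reduced: output is shaped [1]
    return yshape

# e.g. reduce_shape([10, 3, 4], reduce_mask=0b101, keepdims_mask=0b001) -> [1, 3]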
VarPtr ReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) {
uint mask=0;
if (!keepdims) mask = reduce_mask;
if (ns == ns_add)
return make_broadcast_to(dout, v, mask);
if (ns == ns_add) {
auto ret = make_broadcast_to(dout, v, reduce_mask, keepdims_mask);
return ret;
}
if (ns == ns_multiply) {
VarPtr a = make_binary(dout, out, ns_multiply);
VarPtr b = make_broadcast_to(a, v, mask);
VarPtr b = make_broadcast_to(a, v, reduce_mask, keepdims_mask);
return make_binary(b, v, ns_divide);
}
if (ns == ns_mean) {
@@ -116,15 +120,15 @@ VarPtr ReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) {
exe.run_sync({v}, 0);
ASSERT(v->num>=0);
}
VarPtr a = make_broadcast_to(dout, v, mask);
VarPtr a = make_broadcast_to(dout, v, reduce_mask, keepdims_mask);
VarPtr n = make_number(1.0f*out->num / v->num, a);
return make_binary(a, n, ns_multiply);
}
if (ns == ns_maximum || ns == ns_minimum) {
VarPtr zeros = make_number(0, v);
VarPtr a = make_broadcast_to(out, v, mask);
VarPtr a = make_broadcast_to(out, v, reduce_mask, keepdims_mask);
VarPtr cond = make_binary(v, a, ns_equal);
VarPtr dv = make_broadcast_to(dout, v, mask);
VarPtr dv = make_broadcast_to(dout, v, reduce_mask, keepdims_mask);
return make_ternary(cond, dv, zeros);
}
return nullptr;
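For the ns_add case, the gradient is simply dout broadcast back to v's shape; passing both reduce_mask and keepdims_mask tells the broadcast which dims were dropped and must be re-inserted. A hedged NumPy illustration of that identity for a plain sum (not Jittor API):

# Axis 1 is reduced and dropped, so the gradient re-inserts it before
# broadcasting dout back to the input shape.
import numpy as np

x = np.random.rand(2, 3, 4)
y = x.sum(axis=1)                                  # reduce dim 1, keepdims=False
dout = np.ones_like(y)                             # upstream gradient
dx = np.broadcast_to(dout[:, None, :], x.shape)    # grad of sum is a broadcast
assert dx.shape == x.shape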
@@ -135,7 +139,7 @@ void ReduceOp::jit_prepare() {
add_jit_define("Ty", y->dtype());
add_jit_define("Tz", y->dtype());
add_jit_define("OP", ns.to_cstring());
add_jit_define("DIM", JK::hex1(yshape.size()));
add_jit_define("DIM", JK::hex1(x->shape.size()));
add_jit_define("REDUCE", JK::hex(reduce_mask));
}

View File

@@ -10,13 +10,12 @@ namespace jittor {
struct ReduceOp : Op {
Var* x, * y;
NanoVector yshape; // keepdim shape
bool keepdims;
uint16 reduce_mask; // i-th bit is 1 if dim-i is reduced
uint16 keepdims_mask;
ReduceOp(Var* x, NanoString op, int dim, bool keepdims=false);
ReduceOp(Var* x, NanoString op, NanoVector dims=NanoVector(), bool keepdims=false);
// @pybind(None)
ReduceOp(Var* x, NanoString op, uint dims_mask, bool keepdims);
ReduceOp(Var* x, NanoString op, uint dims_mask, uint keepdims_mask);
const char* name() const override { return "reduce"; }
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;