optimize where op && fix vary shape infer

li-xl 2020-10-26 22:03:53 +08:00
parent 9ff2cddee5
commit aa502deafb
24 changed files with 198 additions and 82 deletions

View File

@@ -66,6 +66,25 @@ struct ConvertOp
}
};
__global__ static void where_kernel(
int n,
To* input
@for(i, 0, NDIM, 1, ,index_t shape_@i, To* out_@i)
) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
int tnum = gridDim.x * blockDim.x;
for (index_t i=tid; i<n; i+=tnum) {
index_t x = input[i];
@for(j, NDIM-1, 0, -1,
index_t i@j = x % shape_@j;
out_@j[i] = i@j;
x /= shape_@j;
)
out_0[i] = x;
(void)shape_0;
}
}
void CubWhereOp::jit_run(){
int N = cond->num;
size_t temp_storage_bytes=0;
@@ -80,43 +99,30 @@ void CubWhereOp::jit_run(){
cub::DeviceReduce::Sum(temp_storage, temp_storage_bytes, itr, (int *)num_nonzeros, N);
exe.allocator->free(temp_storage, temp_storage_bytes, temp_storage_allocation);
int num_nonzeros_h;
checkCudaErrors(cudaMemcpyAsync(&num_nonzeros_h, num_nonzeros, sizeof(int), cudaMemcpyDeviceToHost, 0));
//need to synchronize to make sure data is available on the host
checkCudaErrors(cudaStreamSynchronize(0));
checkCudaErrors(cudaMemcpy(&num_nonzeros_h, num_nonzeros, sizeof(int), cudaMemcpyDeviceToHost));
size_t out_temp_allocation;
To * out_temp = (To *) exe.allocator->alloc(num_nonzeros_h*sizeof(To), out_temp_allocation);
To* out_temp = outs[0]->ptr<To>();
@for(i, 0, NDIM, outs[@i]->set_shape({num_nonzeros_h});)
if (NDIM > 0) {
cub::CountingInputIterator<To> counting_itr(0);
temp_storage_bytes = 0;
cub::DeviceSelect::Flagged(nullptr, temp_storage_bytes, counting_itr, itr,out_temp, (int*)num_nonzeros, N);
temp_storage = exe.allocator->alloc(temp_storage_bytes, temp_storage_allocation);
cub::DeviceSelect::Flagged(temp_storage, temp_storage_bytes, counting_itr, itr,out_temp, (int*)num_nonzeros, N);
exe.allocator->free(temp_storage, temp_storage_bytes, temp_storage_allocation);
if (num_nonzeros_h > 0 && NDIM > 1){
To div = 1;
for (int dim = NDIM-1; dim >= 0; dim--){
To dim_size = cond->shape[dim];
thrust::transform(
thrust::device_ptr<To>(out_temp),
thrust::device_ptr<To>(out_temp) + num_nonzeros_h,
thrust::device_ptr<To>(outs[dim]->ptr<To>()),
ConvertOp<To>(div,dim_size)
);
div *= dim_size;
}
}else if (num_nonzeros_h>0 && NDIM==1){
checkCudaErrors(cudaMemcpyAsync(outs[0]->ptr<To>(), out_temp, num_nonzeros_h*sizeof(To), cudaMemcpyDeviceToDevice, 0));
}
cub::CountingInputIterator<To> counting_itr(0);
temp_storage_bytes = 0;
cub::DeviceSelect::Flagged(nullptr, temp_storage_bytes, counting_itr, itr,out_temp, (int*)num_nonzeros, N);
temp_storage = exe.allocator->alloc(temp_storage_bytes, temp_storage_allocation);
cub::DeviceSelect::Flagged(temp_storage, temp_storage_bytes, counting_itr, itr,out_temp, (int*)num_nonzeros, N);
exe.allocator->free(temp_storage, temp_storage_bytes, temp_storage_allocation);
if (num_nonzeros_h > 0 && NDIM > 1) {
int thread_num = std::min(1024, num_nonzeros_h);
int block_num = std::max(1, num_nonzeros_h/1024);
where_kernel<<<block_num, thread_num>>>(
num_nonzeros_h,
out_temp
@for(i, 0, NDIM, 1, , cond->shape[@i], outs[@i]->ptr<To>())
);
}
exe.allocator->free(num_nonzeros, sizeof(int), num_nonzeros_allocation);
exe.allocator->free(out_temp, NDIM*num_nonzeros_h*sizeof(To), out_temp_allocation);
}
#endif
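For reference, the index arithmetic in where_kernel above is the usual flat-index decomposition: cub::DeviceSelect::Flagged first compacts the flat indices of the non-zero elements, and the kernel then peels off one coordinate per dimension, innermost first. A minimal NumPy sketch of the same decomposition (array and function names here are illustrative only, not part of the diff):

import numpy as np

def decompose(flat, shape):
    # mirror of where_kernel: peel dimensions from the innermost outwards
    coords = []
    for size in reversed(shape):
        coords.append(flat % size)
        flat = flat // size
    return list(reversed(coords))

cond = np.random.rand(2, 3, 4) > 0.5
flat = np.flatnonzero(cond)          # what DeviceSelect::Flagged produces
outs = decompose(flat, cond.shape)   # what where_kernel writes into outs[0..NDIM-1]
assert all((a == b).all() for a, b in zip(outs, np.nonzero(cond)))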

View File

@@ -143,6 +143,12 @@ def relu(x): return jt.maximum(x, 0)
def leaky_relu(x, scale=0.01): return jt.ternary(x>0, x, x*scale)
def relu6(x): return jt.minimum(jt.maximum(x, 0), 6)
def gelu(x):
_sqrt2 = 1.4142135623730951
erf = jt.erf(x/_sqrt2)+1
r = erf*x*.5
return r
class PReLU(Module):
def __init__(self, num_parameters=1, init_=0.25):
self.num_parameters = num_parameters
@@ -411,6 +417,29 @@ class InstanceNorm2d(Module):
b = self.bias.broadcast(x, [0,2,3])
return norm_x * w + b
class LayerNorm(Module):
def __init__(self, normalized_shape, eps: float = 1e-5, elementwise_affine: bool = True) -> None:
super(LayerNorm, self).__init__()
if isinstance(normalized_shape, int):
normalized_shape = (normalized_shape,)
self.normalized_shape = tuple(normalized_shape)
self.eps = eps
self.elementwise_affine = elementwise_affine
if self.elementwise_affine:
self.weight = init.constant(normalized_shape, "float32", 1.0)
self.bias = init.constant(normalized_shape, "float32", 0.0)
def execute(self,x):
dims = [-i for i in range(len(self.normalized_shape), 0, -1)]
mean = jt.mean(x,dims=dims,keepdims=1)
numerator = x-mean
variance = jt.mean(numerator.sqr(),dims=dims,keepdims=1)
denominator = jt.sqrt(variance+self.eps)
norm_x = numerator/denominator
if self.elementwise_affine:
norm_x = norm_x * self.weight+self.bias
return norm_x
class GroupNorm(Module):
def __init__(self, num_groups, num_channels, eps=1e-05, affine=True, is_train=True):
self.num_groups = num_groups
@@ -447,6 +476,7 @@ Leaky_relu = jt.make_module(leaky_relu, 2)
LeakyReLU = Leaky_relu
ReLU6 = jt.make_module(relu6)
Softmax = jt.make_module(softmax, 2)
GELU = jt.make_module(gelu)
class Conv(Module):
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
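For the two layers added to nn.py above, a small NumPy reference sketch may help (array sizes are arbitrary; the formulas are the standard definitions the code follows, with the default weight=1, bias=0 affine init):

import math
import numpy as np

x = np.random.randn(4, 8).astype("float32")
erf = np.vectorize(math.erf)

# gelu: x * 0.5 * (1 + erf(x / sqrt(2))), matching the erf formulation above
gelu_ref = 0.5 * x * (1.0 + erf(x / math.sqrt(2.0)))

# LayerNorm(8): normalize over the trailing normalized_shape dims with eps=1e-5
mean = x.mean(-1, keepdims=True)
var = ((x - mean) ** 2).mean(-1, keepdims=True)
ln_ref = (x - mean) / np.sqrt(var + 1e-5)

These should agree with jt.nn.gelu and jt.nn.LayerNorm(8) up to float32 tolerance; the new tests below make the same comparison against torch.nn instead.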

View File

@@ -114,5 +114,20 @@ class TestBatchNorm(unittest.TestCase):
model.eval()
check_equal_with_istrain(arr, jnn.GroupNorm(2, 10, is_train=False), model, False, False)
# ***************************************************************
# Test LayerNorm Layer
# ***************************************************************
arr = np.random.randn(16,10,224,224)
class Model(tnn.Module):
def __init__(self):
super(Model, self).__init__()
self.layer = tnn.LayerNorm(224)
def forward(self, x):
return self.layer(x)
model = Model()
model.eval()
check_equal_with_istrain(arr, jnn.LayerNorm(224), model, False, False)
if __name__ == "__main__":
unittest.main()

View File

@@ -62,6 +62,12 @@ class TestRelu(unittest.TestCase):
check_equal(arr, jnn.LeakyReLU(2), tnn.LeakyReLU(2))
check_equal(arr, jnn.LeakyReLU(99.9), tnn.LeakyReLU(99.9))
# ***************************************************************
# Test GELU Layer
# ***************************************************************
arr = np.random.randn(16,10,224,224)
check_equal(arr, jnn.GELU(), tnn.GELU())
# ***************************************************************
# Test Softplus Layer
# ***************************************************************

View File

@@ -8,43 +8,46 @@ import jittor as jt
import numpy as np
class TestWhereOp(unittest.TestCase):
def setUp(self):
self.where = jt.where
def test(self):
assert (jt.where([0,1,0,1])[0].data == [1,3]).all()
a, = jt.where([0,1,0,1])
assert (self.where([0,1,0,1])[0].data == [1,3]).all()
a, = self.where([0,1,0,1])
assert a.uncertain_shape==[-4]
a.data
assert a.uncertain_shape==[2]
a,b = jt.where([[0,0,1],[1,0,0]])
a,b = self.where([[0,0,1],[1,0,0]])
assert (a.data==[0,1]).all() and (b.data==[2,0]).all()
def test_reindex_dep(self):
a = jt.random([10])
b, = (a>1).where()
b, = self.where(a>1)
assert len(b.data)==0
b, = (a>0.5).where()
b, = self.where(a>0.5)
assert (b.data==np.where(a.data>0.5)).all()
b = a.reindex_var((a>0.5).where())
b = a.reindex_var(self.where(a>0.5))
assert (b.data==a.data[a.data>0.5]).all()
def test_binary_dep(self):
a = jt.random([10])
b, = (a>0.5).where()
b, = self.where(a>0.5)
b = b+1
assert (b.data==np.where(a.data>0.5)[0]+1).all()
b, = (a>1).where()
b, = self.where(a>1)
b = b+1
assert (b.data==np.where(a.data>1)[0]+1).all()
def test_self_dep(self):
a = jt.random([100])
x = a.reindex_var((a>0.1).where())
x = x.reindex_var((x<0.9).where())
x = a.reindex_var(self.where(a>0.1))
x = x.reindex_var(self.where(x<0.9))
na = a.data
assert (na[np.logical_and(na>0.1, na<0.9)]==x.data).all()
def test_reduce_dep(self):
a = jt.random([100,100])
index = (a>0.5).where()
index = self.where(a>0.5)
x = a.reindex_var(index)
xsum =x.sum()
na = a.data
@@ -53,5 +56,32 @@ class TestWhereOp(unittest.TestCase):
def test_doc(self):
assert "Where Operator" in jt.where.__doc__
class TestWhereOpCuda(TestWhereOp):
def setUp(self):
self.where = jt.where
@classmethod
def setUpClass(self):
jt.flags.use_cuda = 1
@classmethod
def tearDownClass(self):
jt.flags.use_cuda = 0
class TestWhereOpCub(TestWhereOpCuda):
def setUp(self):
self.where = jt.compile_extern.cub_ops.cub_where
@classmethod
def setUpClass(self):
jt.flags.use_cuda = 1
@classmethod
def tearDownClass(self):
jt.flags.use_cuda = 0
if __name__ == "__main__":
unittest.main()

View File

@@ -391,10 +391,12 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
load_fused_op(fused_op, fuse_ops, ops, ll, rr, tt);
}
LOGvvv << "Run" << op;
if (!op->shape_infered()) op->infer_shape();
ASSERT(op->shape_infered()) << "Shape of(" >> op->name() >> ") not solved.";
for (auto* var : op->outputs())
if (op->flags.get(NodeFlags::_has_vary_input)) op->init();
ASSERT(!op->flags.get(NodeFlags::_has_vary_input))
<< "Shape of(" >> op->name() >> ") not solved.";
for (auto* var : op->outputs()) {
var->alloc(allocator);
}
LOGvvv << "Run" << op << "inputs:" << op->inputs() << "outputs:" << op->outputs();
op->do_prepare();
bool is_cuda = op->flags.get(NodeFlags::_cuda);
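The executor change above leans on the new _has_vary_input flag: when an op consumes a variable whose size is still negative (a vary shape), op->init() is called again so infer_shape() can rerun with the producer's real size, and only then are the outputs allocated. The where op is the typical producer of such shapes; a small Python sketch of the behaviour the tests above already exercise:

import jittor as jt

a, = jt.where([0, 1, 0, 1])
print(a.uncertain_shape)   # [-4]: size not yet known, 4 is only an upper bound
print(a.data)              # [1 3]; forcing execution resolves the size
print(a.uncertain_shape)   # [2]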

View File

@@ -91,7 +91,9 @@ void FusedOp::update_ops() {
}
}
vars.clear();
bool has_vary_input = 0;
for (Op* opi : ops) {
has_vary_input |= opi->flags.get(NodeFlags::_has_vary_input);
for (Var* i : opi->inputs()) {
auto &c = i->custom_data;
// if not visited
@@ -110,6 +112,7 @@ void FusedOp::update_ops() {
}
}
}
flags.set(NodeFlags::_has_vary_input, has_vary_input);
LOGvvvv << "Var info" << vars;
}
@@ -136,15 +139,12 @@ FusedOp::~FusedOp() {
}
void FusedOp::infer_shape() {
for (uint i=0; i<ops.size(); i++)
ops[i]->infer_shape();
}
bool FusedOp::shape_infered() {
for (uint i=0; i<ops.size(); i++)
if (!ops[i]->shape_infered())
return false;
return true;
bool has_vary_input = 0;
for (Op* op : ops) {
op->init();
has_vary_input |= op->flags.get(NodeFlags::_has_vary_input);
}
flags.set(NodeFlags::_has_vary_input, has_vary_input);
}
void FusedOp::statistics(uint64_t& in, uint64_t& out, uint64_t& compute) {

View File

@@ -48,7 +48,6 @@ struct FusedOp final : Op {
const char* name() const override { return "fused"; }
void statistics(uint64_t& in, uint64_t& out, uint64_t& compute) override;
bool shape_infered() override;
void infer_shape() override;
void do_jit_prepare() override;
void do_prepare() override;

View File

@@ -177,7 +177,7 @@ vector<VarPtr> grad(Var* loss, vector<Var*> targets) {
Var* dout = grads[id];
VarPtr dvar = make_grad(op, out, dout, var, index);
registe_node_trace_grad(dvar.ptr, op, index);
if (dvar && var->num)
if (dvar && dvar->num>=0 && var->num)
ASSERT(dvar->num==var->num && dvar->shape.size()==var->shape.size())
<< "dvar" << dvar << "var" << var;
if (!grad)

View File

@@ -71,6 +71,7 @@ static unordered_set<string> unary_ops = {
"cosh",
"acosh",
"sigmoid",
"erf"
};
static unordered_set<string> unary_float_ops = {

View File

@@ -73,6 +73,7 @@ namespace jittor {
m(acos) \
m(cosh) \
m(acosh) \
m(erf) \
m(sigmoid) \
\
m(uniform) \

View File

@@ -50,6 +50,8 @@ struct NodeFlags {
_grads=_n+6,
// bit7: has graph optimize
_has_gopt=_n+7,
// bit8: has vary input
_has_vary_input=_n+8,
};
inline void set(Flags f, int a=1, int nbits=1) {

View File

@@ -62,17 +62,14 @@ Var* Op::create_output(NanoVector shape, NanoString dtype) {
}
void Op::init() {
bool has_vary_input = 0;
for (Var* v : inputs())
if (v->num < 0) {
has_vary_input = 1;
break;
}
flags.set(NodeFlags::_has_vary_input, has_vary_input);
infer_shape();
LOGvvvv << "Create" << this << "and outputs" << outputs();
for (Var* v : outputs())
CHECK(v->shape.size()) << "Number of dims should be solved.";
}
bool Op::shape_infered() {
if (flags.get(NodeFlags::_vary_shape)) return true;
for (Var* v : outputs())
if (v->num < 0) return false;
return true;
}
void Op::compile_optimize(string& src) {}

View File

@@ -23,7 +23,6 @@ struct Op : Node {
inline uint type() const { CHECK_EXIST; return flags.get(NodeFlags::_op_type, NodeFlags::_op_type_nbits); }
inline void set_type(OpType t) { CHECK_EXIST; flags.set(NodeFlags::_op_type, t, NodeFlags::_op_type_nbits); }
virtual bool shape_infered();
Var* create_output(NanoVector shape, NanoString dtype);
void init();

View File

@@ -31,7 +31,7 @@ void ReshapeOp::infer_shape() {
size_t uncertain_dim = 0;
int64_t y_items = 1;
for (size_t i = 0; i < shape.size(); ++i) {
if (shape[i] == -1) {
if (shape[i] < 0) {
++uncertain_dim;
} else
y_items *= shape[i];
@@ -39,14 +39,16 @@ void ReshapeOp::infer_shape() {
ASSERT(uncertain_dim <= 1) << "max number of -1 is 1, but get" << uncertain_dim << ".";
int64_t x_items = x->num;
auto yshape = shape;
if (uncertain_dim == 0) {
ASSERT(x_items == y_items) << "reshape shape is invalid for input of size " << x_items;
if (x_items < 0) {
// pass if input is uncertain
} else if (uncertain_dim == 0) {
ASSERTop(x_items,==,y_items) << "reshape shape is invalid for input of size";
} else {
ASSERT(x_items % y_items == 0) << "reshape shape is invalid for input of size " << x_items;
uncertain_dim = x_items / y_items;
yshape.clear();
for (auto a : shape)
yshape.push_back(a==-1 ? uncertain_dim : a);
yshape.push_back(a<0 ? uncertain_dim : a);
}
y->set_shape(yshape);
y->share_with(x);
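The arithmetic in the -1 branch is unchanged: the single uncertain dimension is filled with x_items / y_items, and the new x_items < 0 guard only skips the size check while the input itself still has a vary shape. A plain-Python sketch of the inference rule (not the op code itself):

def infer_reshape(x_items, shape):
    # product of the known dims, then fill the one negative dim
    known = 1
    for s in shape:
        if s >= 0:
            known *= s
    assert x_items % known == 0, "reshape shape is invalid"
    return [x_items // known if s < 0 else s for s in shape]

print(infer_reshape(24, [2, -1, 4]))   # [2, 3, 4]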

View File

@@ -65,11 +65,12 @@ void SetitemOp::infer_shape() {
int bmask = 0;
int bmask2 = 0;
ASSERTop(data_dim,<=,out_shape.size()) << "Data dimension not match";
CHECKop(data_dim,<=,out_shape.size()) << "Data dimension not match";
for (int i=0; i<data_dim; i++) {
int j = i - data_dim + out_shape.size();
if (!(data_shape[i]==1 && out_shape[j]!=-1)) {
ASSERTop(data_shape[i],==,out_shape[j]) << "Data shape not match" << data_shape << out_shape;
CHECK(data_shape[i]<0 || data_shape[i]==out_shape[j])
<< "Data shape not match" << data_shape << out_shape;
bmask |= 1<<j;
}
}

View File

@@ -36,13 +36,18 @@ void TernaryOp::infer_shape() {
auto ydim = y->shape.size();
auto cdim = cond->shape.size();
CHECK(xdim==ydim && cdim==ydim) << "Number of dims should be the same.";
NanoVector zshape;
for (size_t i=0; i<xdim; i++) {
auto xshape = x->shape[i];
auto yshape = y->shape[i];
auto cshape = cond->shape[i];
CHECK(xshape==yshape && cshape==yshape) << "Shape not match";
auto shape = std::min(xshape, std::min(yshape, cshape));
auto shape2 = std::max(xshape, std::max(yshape, cshape));
zshape.push_back(shape2);
if (shape < 0) continue;
CHECK(shape==shape2) << "Shape not match" << x->shape << y->shape << cond->shape;
}
z->set_shape(x->shape);
z->set_shape(zshape);
}
void TernaryOp::jit_prepare() {
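The per-dimension rule introduced above keeps the largest of the three extents, so a vary dimension (stored as a negative size) never hides a known one, and the equality check is only enforced once all three sizes are known. A tiny sketch of that rule (hypothetical helper, for illustration only):

def ternary_dim(c, x, y):
    lo, hi = min(c, x, y), max(c, x, y)
    if lo >= 0:                       # all three extents known: they must agree
        assert lo == hi, "Shape not match"
    return hi                         # a negative (vary) extent defers to the known one

print(ternary_dim(5, 5, 5))    # 5
print(ternary_dim(-5, 5, 5))   # 5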

View File

@@ -65,6 +65,7 @@ static unordered_set<string> unary_ops = {
// @pybind(acosh, arccosh)
"acosh",
"sigmoid",
"erf",
};
UnaryOp::UnaryOp(Var* x, NanoString op) : x(x) {
@@ -191,6 +192,16 @@ VarPtr UnaryOp::grad(Var* out, Var* dout, Var* v, int v_index) {
r = make_binary(out, r, ns_subtract);
return make_binary(dout, r, ns_multiply);
}
// derf(x) = e^(-x^2)*2/sqrt(pi)
if (ns == ns_erf) {
auto two_div_sqrt_pi = make_number(2/1.7724538509055159, x);
auto two = make_number(2, x);
auto x2 = make_binary(x, x, ns_multiply);
x2 = make_unary(x2, ns_negative);
auto r = make_unary(x2, ns_exp);
r = make_binary(r, two_div_sqrt_pi, ns_multiply);
return make_binary(dout, r, ns_multiply);
}
return nullptr;
}
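The constant 2/1.7724538509055159 used above is exactly 2/sqrt(pi), so the new branch implements d/dx erf(x) = (2/sqrt(pi)) * exp(-x^2) multiplied into dout. A quick numerical spot check with the standard library (not part of the diff):

import math

def derf(x):
    return 2.0 / math.sqrt(math.pi) * math.exp(-x * x)

x, h = 0.7, 1e-6
finite_diff = (math.erf(x + h) - math.erf(x - h)) / (2 * h)
assert abs(finite_diff - derf(x)) < 1e-6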

View File

@@ -38,6 +38,8 @@ namespace jittor {
#define sigmoid(T,x) ((T) (1.0f/(1.0f+::expf(-(x)))))
#define erf(T,x) ((T) ::erff((x)))
#else
#define abs(T,x) std::abs(x)
#define log(T,x) std::log((T)(x))
@@ -64,6 +66,8 @@ namespace jittor {
#define sigmoid(T,x) ((T) (1.0f/(1.0f+std::exp(-(x)))))
#define erf(T,x) ((T) std::erf((x)))
#endif
#define cast(T,x) ((T)(x))

View File

@@ -22,13 +22,10 @@ WhereOp::WhereOp(Var* cond, NanoString dtype) : cond(cond) {
auto ndim = cond->shape.size();
#ifdef HAS_CUDA
if (use_cuda) {
static std::vector<VarPtr>(*cub_where)(Var*, NanoString) = nullptr;
if (!cub_where && has_op("cub_where")) {
cub_where = get_op_info("cub_where")
.get_constructor<std::vector<VarPtr>, Var*, NanoString>();
}
if (cub_where) {
auto var = cub_where(cond,dtype);
static auto cub_where = has_op("cub_where") ? get_op_info("cub_where")
.get_constructor<std::vector<VarPtr>, Var*, NanoString>() : nullptr;
if (cub_where && (ndim>1 || std::abs(cond->num)>4096)) {
auto var = cub_where(cond, dtype);
for(uint i=0;i<ndim;i++)
forward(var[i]);
return;

View File

@@ -1040,7 +1040,7 @@ void KernelIR::split_loop(int i, int j) {
inner[2]->attrs["code"] = lvalue+"+="+rvalue2+";";
push_back("for ("+dtype+" id"+sj+"=0; id"+sj+"<range"+sj+"; id"+sj+"++) {}");
auto& sloop = children.back();
int range, stride;
int range=0, stride=0;
if (get_number("range"+si, range) && get_number("stride"+si, stride) && (range%stride==0))
push_back(dtype+" range"+sj+" = "+S(stride)+";", &inner);
else {

View File

@@ -14,6 +14,7 @@ namespace jittor {
// define in tracer.cc
void print_trace();
void breakpoint();
constexpr int32_t basename_index(const char * const path, const int32_t index = 0, const int32_t slash_index = -1) {
return path[index]

View File

@@ -72,6 +72,12 @@ void setter_gdb_path(string v) {
setter_gdb_attach(gdb_attach);
}
void breakpoint() {
static bool is_attached = 0;
if (is_attached) return;
setter_gdb_attach(1);
}
void print_trace() {
if (gdb_path.size()) {
// using gdb to print the stack trace

View File

@@ -9,5 +9,6 @@
namespace jittor {
void print_trace();
void breakpoint();
} // jittor