mirror of https://github.com/Jittor/Jittor
optimize where op && fix vary shape infer
parent 9ff2cddee5
commit aa502deafb
@@ -66,6 +66,25 @@ struct ConvertOp
    }
};

__global__ static void where_kernel(
    int n,
    To* input
    @for(i, 0, NDIM, 1, ,index_t shape_@i, To* out_@i)
) {
    int tid = threadIdx.x + blockIdx.x * blockDim.x;
    int tnum = gridDim.x * blockDim.x;
    for (index_t i=tid; i<n; i+=tnum) {
        index_t x = input[i];
        @for(j, NDIM-1, 0, -1,
            index_t i@j = x % shape_@j;
            out_@j[i] = i@j;
            x /= shape_@j;
        )
        out_0[i] = x;
        (void)shape_0;
    }
}

void CubWhereOp::jit_run(){
    int N = cond->num;
    size_t temp_storage_bytes=0;
@@ -80,43 +99,30 @@ void CubWhereOp::jit_run(){
    cub::DeviceReduce::Sum(temp_storage, temp_storage_bytes, itr, (int *)num_nonzeros, N);
    exe.allocator->free(temp_storage, temp_storage_bytes, temp_storage_allocation);

    int num_nonzeros_h;
    checkCudaErrors(cudaMemcpyAsync(&num_nonzeros_h, num_nonzeros, sizeof(int), cudaMemcpyDeviceToHost, 0));
    //need to synchronize to make sure data is available on the host
    checkCudaErrors(cudaStreamSynchronize(0));
    checkCudaErrors(cudaMemcpy(&num_nonzeros_h, num_nonzeros, sizeof(int), cudaMemcpyDeviceToHost));

    size_t out_temp_allocation;
    To * out_temp = (To *) exe.allocator->alloc(num_nonzeros_h*sizeof(To), out_temp_allocation);
    To* out_temp = outs[0]->ptr<To>();

    @for(i, 0, NDIM, outs[@i]->set_shape({num_nonzeros_h});)
    if (NDIM > 0) {
        cub::CountingInputIterator<To> counting_itr(0);
        temp_storage_bytes = 0;
        cub::DeviceSelect::Flagged(nullptr, temp_storage_bytes, counting_itr, itr,out_temp, (int*)num_nonzeros, N);
        temp_storage = exe.allocator->alloc(temp_storage_bytes, temp_storage_allocation);
        cub::DeviceSelect::Flagged(temp_storage, temp_storage_bytes, counting_itr, itr,out_temp, (int*)num_nonzeros, N);
        exe.allocator->free(temp_storage, temp_storage_bytes, temp_storage_allocation);

        if (num_nonzeros_h > 0 && NDIM > 1){
            To div = 1;
            for (int dim = NDIM-1; dim >= 0; dim--){
                To dim_size = cond->shape[dim];
                thrust::transform(
                    thrust::device_ptr<To>(out_temp),
                    thrust::device_ptr<To>(out_temp) + num_nonzeros_h,
                    thrust::device_ptr<To>(outs[dim]->ptr<To>()),
                    ConvertOp<To>(div,dim_size)
                );
                div *= dim_size;
            }
        }else if (num_nonzeros_h>0 && NDIM==1){
            checkCudaErrors(cudaMemcpyAsync(outs[0]->ptr<To>(), out_temp, num_nonzeros_h*sizeof(To), cudaMemcpyDeviceToDevice, 0));
        }
    cub::CountingInputIterator<To> counting_itr(0);
    temp_storage_bytes = 0;
    cub::DeviceSelect::Flagged(nullptr, temp_storage_bytes, counting_itr, itr,out_temp, (int*)num_nonzeros, N);
    temp_storage = exe.allocator->alloc(temp_storage_bytes, temp_storage_allocation);
    cub::DeviceSelect::Flagged(temp_storage, temp_storage_bytes, counting_itr, itr,out_temp, (int*)num_nonzeros, N);
    exe.allocator->free(temp_storage, temp_storage_bytes, temp_storage_allocation);

    if (num_nonzeros_h > 0 && NDIM > 1) {
        int thread_num = std::min(1024, num_nonzeros_h);
        int block_num = std::max(1, num_nonzeros_h/1024);
        where_kernel<<<block_num, thread_num>>>(
            num_nonzeros_h,
            out_temp
            @for(i, 0, NDIM, 1, , cond->shape[@i], outs[@i]->ptr<To>())
        );
    }
    exe.allocator->free(num_nonzeros, sizeof(int), num_nonzeros_allocation);
    exe.allocator->free(out_temp, NDIM*num_nonzeros_h*sizeof(To), out_temp_allocation);
}
#endif
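A note for readers following the new kernel: for every selected element, where_kernel unravels a flat row-major index into one coordinate per dimension, which is the same arithmetic the older ConvertOp/thrust::transform path performed one dimension at a time. A minimal NumPy sketch of that decomposition (illustrative only, not the CUDA code):

import numpy as np

def unravel_row_major(flat, shape):
    # walk dims from last to first, mirroring the @for loop in where_kernel
    coords = [0] * len(shape)
    x = flat
    for j in range(len(shape) - 1, 0, -1):
        coords[j] = x % shape[j]
        x //= shape[j]
    coords[0] = x  # the remaining quotient is the leading coordinate
    return tuple(coords)

shape = (2, 3, 4)
for flat in range(2 * 3 * 4):
    assert unravel_row_major(flat, shape) == tuple(int(c) for c in np.unravel_index(flat, shape))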
@@ -143,6 +143,12 @@ def relu(x): return jt.maximum(x, 0)
def leaky_relu(x, scale=0.01): return jt.ternary(x>0, x, x*scale)
def relu6(x): return jt.minimum(jt.maximum(x, 0), 6)

def gelu(x):
    _sqrt2 = 1.4142135623730951
    erf = jt.erf(x/_sqrt2)+1
    r = erf*x*.5
    return r

class PReLU(Module):
    def __init__(self, num_parameters=1, init_=0.25):
        self.num_parameters = num_parameters
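The new gelu above is the exact erf-based form, GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2))). A standalone reference using math.erf (not Jittor code), handy for spot-checking values:

import math
import numpy as np

def gelu_ref(x):
    # 0.5 * x * (1 + erf(x / sqrt(2))), the same formula as the jt.erf-based gelu above
    return 0.5 * x * (1.0 + math.erf(x / 1.4142135623730951))

print([round(gelu_ref(float(v)), 6) for v in np.linspace(-3.0, 3.0, 7)])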
@@ -411,6 +417,29 @@ class InstanceNorm2d(Module):
        b = self.bias.broadcast(x, [0,2,3])
        return norm_x * w + b

class LayerNorm(Module):
    def __init__(self, normalized_shape, eps: float = 1e-5, elementwise_affine: bool = True) -> None:
        super(LayerNorm, self).__init__()
        if isinstance(normalized_shape, int):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = tuple(normalized_shape)
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            self.weight = init.constant(normalized_shape, "float32", 1.0)
            self.bias = init.constant(normalized_shape, "float32", 0.0)

    def execute(self,x):
        dims = [-i for i in range(len(self.normalized_shape), 0, -1)]
        mean = jt.mean(x,dims=dims,keepdims=1)
        numerator = x-mean
        variance = jt.mean(numerator.sqr(),dims=dims,keepdims=1)
        denominator = jt.sqrt(variance+self.eps)
        norm_x = numerator/denominator
        if self.elementwise_affine:
            norm_x = norm_x * self.weight+self.bias
        return norm_x

class GroupNorm(Module):
    def __init__(self, num_groups, num_channels, eps=1e-05, affine=True, is_train=True):
        self.num_groups = num_groups
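For reference, LayerNorm.execute is plain normalization over the trailing normalized_shape axes. A minimal NumPy sketch of the same computation (the affine weight/bias are left out since they start at 1 and 0; this is not the Jittor module itself):

import numpy as np

def layer_norm_ref(x, normalized_shape, eps=1e-5):
    axes = tuple(range(-len(normalized_shape), 0))      # e.g. (-1,) for LayerNorm(224)
    mean = x.mean(axis=axes, keepdims=True)
    var = ((x - mean) ** 2).mean(axis=axes, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

x = np.random.randn(16, 10, 224, 224).astype("float32")
print(layer_norm_ref(x, (224,)).shape)    # (16, 10, 224, 224)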
@@ -447,6 +476,7 @@ Leaky_relu = jt.make_module(leaky_relu, 2)
LeakyReLU = Leaky_relu
ReLU6 = jt.make_module(relu6)
Softmax = jt.make_module(softmax, 2)
GELU = jt.make_module(gelu)

class Conv(Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
@@ -114,5 +114,20 @@ class TestBatchNorm(unittest.TestCase):
        model.eval()
        check_equal_with_istrain(arr, jnn.GroupNorm(2, 10, is_train=False), model, False, False)

        # ***************************************************************
        # Test LayerNorm Layer
        # ***************************************************************
        arr = np.random.randn(16,10,224,224)

        class Model(tnn.Module):
            def __init__(self):
                super(Model, self).__init__()
                self.layer = tnn.LayerNorm(224)
            def forward(self, x):
                return self.layer(x)
        model = Model()
        model.eval()
        check_equal_with_istrain(arr, jnn.LayerNorm(224), model, False, False)

if __name__ == "__main__":
    unittest.main()
@@ -62,6 +62,12 @@ class TestRelu(unittest.TestCase):
        check_equal(arr, jnn.LeakyReLU(2), tnn.LeakyReLU(2))
        check_equal(arr, jnn.LeakyReLU(99.9), tnn.LeakyReLU(99.9))

        # ***************************************************************
        # Test GELU Layer
        # ***************************************************************
        arr = np.random.randn(16,10,224,224)
        check_equal(arr, jnn.GELU(), tnn.GELU())

        # ***************************************************************
        # Test Softplus Layer
        # ***************************************************************
@@ -8,43 +8,46 @@ import jittor as jt
import numpy as np

class TestWhereOp(unittest.TestCase):
    def setUp(self):
        self.where = jt.where

    def test(self):
        assert (jt.where([0,1,0,1])[0].data == [1,3]).all()
        a, = jt.where([0,1,0,1])
        assert (self.where([0,1,0,1])[0].data == [1,3]).all()
        a, = self.where([0,1,0,1])
        assert a.uncertain_shape==[-4]
        a.data
        assert a.uncertain_shape==[2]
        a,b = jt.where([[0,0,1],[1,0,0]])
        a,b = self.where([[0,0,1],[1,0,0]])
        assert (a.data==[0,1]).all() and (b.data==[2,0]).all()

    def test_reindex_dep(self):
        a = jt.random([10])
        b, = (a>1).where()
        b, = self.where(a>1)
        assert len(b.data)==0
        b, = (a>0.5).where()
        b, = self.where(a>0.5)
        assert (b.data==np.where(a.data>0.5)).all()
        b = a.reindex_var((a>0.5).where())
        b = a.reindex_var(self.where(a>0.5))
        assert (b.data==a.data[a.data>0.5]).all()

    def test_binary_dep(self):
        a = jt.random([10])
        b, = (a>0.5).where()
        b, = self.where(a>0.5)
        b = b+1
        assert (b.data==np.where(a.data>0.5)[0]+1).all()
        b, = (a>1).where()
        b, = self.where(a>1)
        b = b+1
        assert (b.data==np.where(a.data>1)[0]+1).all()

    def test_self_dep(self):
        a = jt.random([100])
        x = a.reindex_var((a>0.1).where())
        x = x.reindex_var((x<0.9).where())
        x = a.reindex_var(self.where(a>0.1))
        x = x.reindex_var(self.where(x<0.9))
        na = a.data
        assert (na[np.logical_and(na>0.1, na<0.9)]==x.data).all()

    def test_reduce_dep(self):
        a = jt.random([100,100])
        index = (a>0.5).where()
        index = self.where(a>0.5)
        x = a.reindex_var(index)
        xsum =x.sum()
        na = a.data

@@ -53,5 +56,32 @@ class TestWhereOp(unittest.TestCase):
    def test_doc(self):
        assert "Where Operator" in jt.where.__doc__


class TestWhereOpCuda(TestWhereOp):
    def setUp(self):
        self.where = jt.where

    @classmethod
    def setUpClass(self):
        jt.flags.use_cuda = 1

    @classmethod
    def tearDownClass(self):
        jt.flags.use_cuda = 0


class TestWhereOpCub(TestWhereOpCuda):
    def setUp(self):
        self.where = jt.compile_extern.cub_ops.cub_where

    @classmethod
    def setUpClass(self):
        jt.flags.use_cuda = 1

    @classmethod
    def tearDownClass(self):
        jt.flags.use_cuda = 0


if __name__ == "__main__":
    unittest.main()
@@ -391,10 +391,12 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
            load_fused_op(fused_op, fuse_ops, ops, ll, rr, tt);
        }
        LOGvvv << "Run" << op;
        if (!op->shape_infered()) op->infer_shape();
        ASSERT(op->shape_infered()) << "Shape of(" >> op->name() >> ") not solved.";
        for (auto* var : op->outputs())
        if (op->flags.get(NodeFlags::_has_vary_input)) op->init();
        ASSERT(!op->flags.get(NodeFlags::_has_vary_input))
            << "Shape of(" >> op->name() >> ") not solved.";
        for (auto* var : op->outputs()) {
            var->alloc(allocator);
        }
        LOGvvv << "Run" << op << "inputs:" << op->inputs() << "outputs:" << op->outputs();
        op->do_prepare();
        bool is_cuda = op->flags.get(NodeFlags::_cuda);
@@ -91,7 +91,9 @@ void FusedOp::update_ops() {
        }
    }
    vars.clear();
    bool has_vary_input = 0;
    for (Op* opi : ops) {
        has_vary_input |= opi->flags.get(NodeFlags::_has_vary_input);
        for (Var* i : opi->inputs()) {
            auto &c = i->custom_data;
            // if not visited

@@ -110,6 +112,7 @@ void FusedOp::update_ops() {
            }
        }
    }
    flags.set(NodeFlags::_has_vary_input, has_vary_input);
    LOGvvvv << "Var info" << vars;
}

@@ -136,15 +139,12 @@ FusedOp::~FusedOp() {
}

void FusedOp::infer_shape() {
    for (uint i=0; i<ops.size(); i++)
        ops[i]->infer_shape();
}

bool FusedOp::shape_infered() {
    for (uint i=0; i<ops.size(); i++)
        if (!ops[i]->shape_infered())
            return false;
    return true;
    bool has_vary_input = 0;
    for (Op* op : ops) {
        op->init();
        has_vary_input |= op->flags.get(NodeFlags::_has_vary_input);
    }
    flags.set(NodeFlags::_has_vary_input, has_vary_input);
}

void FusedOp::statistics(uint64_t& in, uint64_t& out, uint64_t& compute) {
@@ -48,7 +48,6 @@ struct FusedOp final : Op {

    const char* name() const override { return "fused"; }
    void statistics(uint64_t& in, uint64_t& out, uint64_t& compute) override;
    bool shape_infered() override;
    void infer_shape() override;
    void do_jit_prepare() override;
    void do_prepare() override;
@@ -177,7 +177,7 @@ vector<VarPtr> grad(Var* loss, vector<Var*> targets) {
            Var* dout = grads[id];
            VarPtr dvar = make_grad(op, out, dout, var, index);
            registe_node_trace_grad(dvar.ptr, op, index);
            if (dvar && var->num)
            if (dvar && dvar->num>=0 && var->num)
                ASSERT(dvar->num==var->num && dvar->shape.size()==var->shape.size())
                    << "dvar" << dvar << "var" << var;
            if (!grad)
@@ -71,6 +71,7 @@ static unordered_set<string> unary_ops = {
    "cosh",
    "acosh",
    "sigmoid",
    "erf"
};

static unordered_set<string> unary_float_ops = {
@@ -73,6 +73,7 @@ namespace jittor {
    m(acos) \
    m(cosh) \
    m(acosh) \
    m(erf) \
    m(sigmoid) \
    \
    m(uniform) \
@@ -50,6 +50,8 @@ struct NodeFlags {
        _grads=_n+6,
        // bit7: has graph optimize
        _has_gopt=_n+7,
        // bit8: has vary input
        _has_vary_input=_n+8,
    };

    inline void set(Flags f, int a=1, int nbits=1) {
src/op.cc

@@ -62,17 +62,14 @@ Var* Op::create_output(NanoVector shape, NanoString dtype) {
}

void Op::init() {
    bool has_vary_input = 0;
    for (Var* v : inputs())
        if (v->num < 0) {
            has_vary_input = 1;
            break;
        }
    flags.set(NodeFlags::_has_vary_input, has_vary_input);
    infer_shape();
    LOGvvvv << "Create" << this << "and outputs" << outputs();
    for (Var* v : outputs())
        CHECK(v->shape.size()) << "Number of dims should be solved.";
}

bool Op::shape_infered() {
    if (flags.get(NodeFlags::_vary_shape)) return true;
    for (Var* v : outputs())
        if (v->num < 0) return false;
    return true;
}

void Op::compile_optimize(string& src) {}
src/op.h

@@ -23,7 +23,6 @@ struct Op : Node {
    inline uint type() const { CHECK_EXIST; return flags.get(NodeFlags::_op_type, NodeFlags::_op_type_nbits); }
    inline void set_type(OpType t) { CHECK_EXIST; flags.set(NodeFlags::_op_type, t, NodeFlags::_op_type_nbits); }

    virtual bool shape_infered();
    Var* create_output(NanoVector shape, NanoString dtype);
    void init();
@@ -31,7 +31,7 @@ void ReshapeOp::infer_shape() {
    size_t uncertain_dim = 0;
    int64_t y_items = 1;
    for (size_t i = 0; i < shape.size(); ++i) {
        if (shape[i] == -1) {
        if (shape[i] < 0) {
            ++uncertain_dim;
        } else
            y_items *= shape[i];

@@ -39,14 +39,16 @@
    ASSERT(uncertain_dim <= 1) << "max number of -1 is 1, but get" << uncertain_dim << ".";
    int64_t x_items = x->num;
    auto yshape = shape;
    if (uncertain_dim == 0) {
        ASSERT(x_items == y_items) << "reshape shape is invalid for input of size " << x_items;
    if (x_items < 0) {
        // pass if input is uncertain
    } else if (uncertain_dim == 0) {
        ASSERTop(x_items,==,y_items) << "reshape shape is invalid for input of size";
    } else {
        ASSERT(x_items % y_items == 0) << "reshape shape is invalid for input of size " << x_items;
        uncertain_dim = x_items / y_items;
        yshape.clear();
        for (auto a : shape)
            yshape.push_back(a==-1 ? uncertain_dim : a);
            yshape.push_back(a<0 ? uncertain_dim : a);
    }
    y->set_shape(yshape);
    y->share_with(x);
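The reshape change keeps the usual -1 inference but now skips the element-count check entirely while the input size is itself still a vary shape (x_items < 0). A hedged Python sketch of that rule, with hypothetical names:

def infer_reshape(x_items, shape):
    # x_items < 0 means the input's size is still an unresolved (vary) shape
    negs = sum(1 for d in shape if d < 0)
    known = 1
    for d in shape:
        if d >= 0:
            known *= d
    assert negs <= 1, "max number of -1 is 1"
    if x_items < 0:
        return list(shape)                      # defer, resolved on a later pass
    if negs == 0:
        assert x_items == known
        return list(shape)
    assert x_items % known == 0
    fill = x_items // known
    return [fill if d < 0 else d for d in shape]

print(infer_reshape(24, (2, -1, 4)))    # [2, 3, 4]
print(infer_reshape(-1, (2, -1, 4)))    # [2, -1, 4], deferred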
@@ -65,11 +65,12 @@ void SetitemOp::infer_shape() {
    int bmask = 0;
    int bmask2 = 0;

    ASSERTop(data_dim,<=,out_shape.size()) << "Data dimension not match";
    CHECKop(data_dim,<=,out_shape.size()) << "Data dimension not match";
    for (int i=0; i<data_dim; i++) {
        int j = i - data_dim + out_shape.size();
        if (!(data_shape[i]==1 && out_shape[j]!=-1)) {
            ASSERTop(data_shape[i],==,out_shape[j]) << "Data shape not match" << data_shape << out_shape;
            CHECK(data_shape[i]<0 || data_shape[i]==out_shape[j])
                << "Data shape not match" << data_shape << out_shape;
            bmask |= 1<<j;
        }
    }
@@ -36,13 +36,18 @@ void TernaryOp::infer_shape() {
    auto ydim = y->shape.size();
    auto cdim = cond->shape.size();
    CHECK(xdim==ydim && cdim==ydim) << "Number of dims should be the same.";
    NanoVector zshape;
    for (size_t i=0; i<xdim; i++) {
        auto xshape = x->shape[i];
        auto yshape = y->shape[i];
        auto cshape = cond->shape[i];
        CHECK(xshape==yshape && cshape==yshape) << "Shape not match";
        auto shape = std::min(xshape, std::min(yshape, cshape));
        auto shape2 = std::max(xshape, std::max(yshape, cshape));
        zshape.push_back(shape2);
        if (shape < 0) continue;
        CHECK(shape==shape2) << "Shape not match" << x->shape << y->shape << cond->shape;
    }
    z->set_shape(x->shape);
    z->set_shape(zshape);
}

void TernaryOp::jit_prepare() {
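The new ternary shape rule tolerates vary dims: per dimension it emits the max of (cond, x, y) as the output extent and only enforces equality when all three extents are already known (min >= 0). A small Python sketch of that logic:

def ternary_shape(cshape, xshape, yshape):
    zshape = []
    for c, xd, yd in zip(cshape, xshape, yshape):
        lo, hi = min(c, xd, yd), max(c, xd, yd)
        zshape.append(hi)          # keep the most-resolved extent
        if lo >= 0:                # every extent known: they must agree
            assert lo == hi, "Shape not match"
    return zshape

print(ternary_shape([4, -1], [4, 8], [4, 8]))   # [4, 8]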
@@ -65,6 +65,7 @@ static unordered_set<string> unary_ops = {
    // @pybind(acosh, arccosh)
    "acosh",
    "sigmoid",
    "erf",
};

UnaryOp::UnaryOp(Var* x, NanoString op) : x(x) {

@@ -191,6 +192,16 @@ VarPtr UnaryOp::grad(Var* out, Var* dout, Var* v, int v_index) {
        r = make_binary(out, r, ns_subtract);
        return make_binary(dout, r, ns_multiply);
    }
    // derf(x) = e^(-x^2)*2/sqrt(pi)
    if (ns == ns_erf) {
        auto two_div_sqrt_pi = make_number(2/1.7724538509055159, x);
        auto two = make_number(2, x);
        auto x2 = make_binary(x, x, ns_multiply);
        x2 = make_unary(x2, ns_negative);
        auto r = make_unary(x2, ns_exp);
        r = make_binary(r, two_div_sqrt_pi, ns_multiply);
        return make_binary(dout, r, ns_multiply);
    }
    return nullptr;
}
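The constant 1.7724538509055159 is sqrt(pi), so the backward above computes d/dx erf(x) = (2/sqrt(pi)) * exp(-x^2). A quick NumPy finite-difference check of that identity (not Jittor code):

import math
import numpy as np

x = np.linspace(-2.0, 2.0, 9)
analytic = 2.0 / math.sqrt(math.pi) * np.exp(-x * x)
h = 1e-5
erf = np.vectorize(math.erf)
numeric = (erf(x + h) - erf(x - h)) / (2 * h)
print(np.max(np.abs(analytic - numeric)))   # ~1e-10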
@@ -38,6 +38,8 @@ namespace jittor {

#define sigmoid(T,x) ((T) (1.0f/(1.0f+::expf(-(x)))))

#define erf(T,x) ((T) ::erff((x)))

#else
#define abs(T,x) std::abs(x)
#define log(T,x) std::log((T)(x))

@@ -64,6 +66,8 @@ namespace jittor {

#define sigmoid(T,x) ((T) (1.0f/(1.0f+std::exp(-(x)))))

#define erf(T,x) ((T) std::erf((x)))

#endif

#define cast(T,x) ((T)(x))
@@ -22,13 +22,10 @@ WhereOp::WhereOp(Var* cond, NanoString dtype) : cond(cond) {
    auto ndim = cond->shape.size();
#ifdef HAS_CUDA
    if (use_cuda) {
        static std::vector<VarPtr>(*cub_where)(Var*, NanoString) = nullptr;
        if (!cub_where && has_op("cub_where")) {
            cub_where = get_op_info("cub_where")
                .get_constructor<std::vector<VarPtr>, Var*, NanoString>();
        }
        if (cub_where) {
            auto var = cub_where(cond,dtype);
        static auto cub_where = has_op("cub_where") ? get_op_info("cub_where")
            .get_constructor<std::vector<VarPtr>, Var*, NanoString>() : nullptr;
        if (cub_where && (ndim>1 || std::abs(cond->num)>4096)) {
            auto var = cub_where(cond, dtype);
            for(uint i=0;i<ndim;i++)
                forward(var[i]);
            return;
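The dispatch condition above routes to cub_where only when it is likely to pay off: more than one dimension, or more than 4096 elements (abs() because cond->num can still be negative for a vary shape). A tiny Python sketch of the predicate, names hypothetical:

def should_use_cub_where(ndim, num, cub_available=True):
    # num may be negative while the shape is still unresolved, hence abs()
    return cub_available and (ndim > 1 or abs(num) > 4096)

print(should_use_cub_where(1, 100), should_use_cub_where(2, -100), should_use_cub_where(1, -100000))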
@@ -1040,7 +1040,7 @@ void KernelIR::split_loop(int i, int j) {
    inner[2]->attrs["code"] = lvalue+"+="+rvalue2+";";
    push_back("for ("+dtype+" id"+sj+"=0; id"+sj+"<range"+sj+"; id"+sj+"++) {}");
    auto& sloop = children.back();
    int range, stride;
    int range=0, stride=0;
    if (get_number("range"+si, range) && get_number("stride"+si, stride) && (range%stride==0))
        push_back(dtype+" range"+sj+" = "+S(stride)+";", &inner);
    else {
@@ -14,6 +14,7 @@ namespace jittor {

// define in tracer.cc
void print_trace();
void breakpoint();

constexpr int32_t basename_index(const char * const path, const int32_t index = 0, const int32_t slash_index = -1) {
    return path[index]
@@ -72,6 +72,12 @@ void setter_gdb_path(string v) {
    setter_gdb_attach(gdb_attach);
}

void breakpoint() {
    static bool is_attached = 0;
    if (is_attached) return;
    setter_gdb_attach(1);
}

void print_trace() {
    if (gdb_path.size()) {
        // using gdb to print the stack trace
@@ -9,5 +9,6 @@
namespace jittor {

void print_trace();
void breakpoint();

} // jittor