fix transpose bugs

li-xl 2020-11-24 10:14:31 +08:00
parent 1045f04c4b
commit 8c83ba92e1
9 changed files with 68 additions and 82 deletions

View File

@@ -6,16 +6,12 @@
#include "var.h"
#include "cutt_transpose_op.h"
#include "ops/op_register.h"
#include <iostream>
#ifdef JIT
#include "cutt.h"
#endif
#include "cutt_warper.h"
#include "misc/stack_vector.h"
namespace jittor {
#ifndef JIT
static auto make_transpose = get_op_info("cutt_transpose")
.get_constructor<VarPtr, Var*, NanoVector>();
@@ -58,52 +54,49 @@ VarPtr CuttTransposeOp::grad(Var* out, Var* dout, Var* v, int v_index) {
return make_transpose(dout, reverse);
}
void CuttTransposeOp::jit_prepare(JK& jk) {
jk << _CS("[Tx:") << x->dtype();
jk << _CS("][DIM=") << JK::hex1(axes.size());
for (uint i=0; i<axes.size(); i++)
jk << _CS("][AXES") << JK::hex1(axes[i]) << '=' << JK::hex1(i);
jk << ']';
}
unordered_map<string, unsigned int> cutt_plan_cache;
#else // JIT
#ifdef JIT_cuda
extern unordered_map<string, unsigned int> cutt_plan_cache;
void CuttTransposeOp::jit_run() {
auto* __restrict__ xp = x->ptr<Tx>();
auto* __restrict__ yp = y->ptr<Tx>();
vector<int> permutation, permutation2;
vector<int> y_shape;
vector<int> x_shape;
@for(i, 0, DIM, permutation.push_back(DIM-1-AXES@i);)
@for(i, 0, DIM, permutation2.push_back(permutation[DIM-1-@i@@]);)
std::vector<int> reverse;
reverse.reserve(permutation2.size());
for (uint i=0; i<permutation2.size(); i++)
reverse[permutation2[i]] = i;
@for(i, 0, DIM, x_shape.push_back(x->shape[DIM-1-@i@@]);)
void CuttTransposeOp::run() {
auto* __restrict__ xp = x->mem_ptr;
auto* __restrict__ yp = y->mem_ptr;
StackVector<int> x_shape;
StackVector<int> new_shape, new_axes, trans, reverse;
int dim = x->shape.size();
for (int i=0; i<dim; i++) {
trans[i] = new_shape.size();
if (x->shape[i] != 1)
new_shape.push_back(x->shape[i]);
}
for (int i = 0; i < dim; ++i) {
if (x->shape[axes[i]] != 1) {
new_axes.push_back(trans[axes[i]]);
}
}
dim = new_shape.size();
for (int i=0; i<dim; i++)
reverse[i] = dim-1-new_axes[dim-1-i];
for (int i=0; i<dim; i++)
x_shape[i] = new_shape[dim-1-i];
if (dim == 1) {
checkCudaErrors(cudaMemcpyAsync(yp, xp, x->size, cudaMemcpyDefault, 0));
return;
}
jk.clear();
jk << @DIM << ",";
for (uint i=0; i<@DIM; i++) jk << x_shape[i] << ",";
for (uint i=0; i<@DIM; i++) jk << reverse[i] << ",";
jk << sizeof(Tx) << ".";
jk << dim << ',';
for (int i=0; i<dim; i++) jk << x_shape[i] << ',';
for (int i=0; i<dim; i++) jk << reverse[i] << ',';
jk << x->dtype().dsize() << '.';
auto iter = cutt_plan_cache.find(jk.to_string());
LOGvvv << "Run cutt_transpose with key:" << jk.to_string();
if (iter!=cutt_plan_cache.end()){
cuttExecute(iter->second, xp, yp);
} else {
cuttHandle plan;
cuttPlan(&plan, @DIM, x_shape.data(), reverse.data(), sizeof(Tx), 0);
cuttPlan(&plan, dim, x_shape.data(), reverse.data(), x->dtype().dsize(), 0);
cutt_plan_cache[jk.to_string()] = plan;
cuttExecute(plan, xp, yp);
}
}
#endif // JIT_cuda
#endif // JIT
} // jittor
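
The new run() folds the old reshape workaround into the operator itself: size-1 dimensions are squeezed out, the permutation is remapped onto the surviving dimensions, and both shape and permutation are reversed into cuTT's innermost-first order before planning; a squeezed 1-d tensor degenerates to the plain cudaMemcpyAsync above. A minimal standalone sketch of that squeeze-and-remap step (squeeze_for_cutt is a hypothetical helper, not jittor API):

#include <cstdio>
#include <vector>
using std::vector;

// Drop size-1 dims, remap the permutation onto the surviving dims,
// then reverse shape and permutation into cuTT's innermost-first order.
static void squeeze_for_cutt(const vector<int>& shape, const vector<int>& axes,
                             vector<int>& x_shape, vector<int>& reverse) {
    int dim = shape.size();
    vector<int> new_shape, new_axes, trans(dim);
    for (int i=0; i<dim; i++) {
        trans[i] = new_shape.size();          // where dim i lands after squeezing
        if (shape[i] != 1) new_shape.push_back(shape[i]);
    }
    for (int i=0; i<dim; i++)
        if (shape[axes[i]] != 1) new_axes.push_back(trans[axes[i]]);
    dim = new_shape.size();
    x_shape.assign(dim, 0);
    reverse.assign(dim, 0);
    for (int i=0; i<dim; i++) reverse[i] = dim-1-new_axes[dim-1-i];
    for (int i=0; i<dim; i++) x_shape[i] = new_shape[dim-1-i];
}

int main() {
    vector<int> x_shape, reverse;
    // (2,1,3,4) permuted by (2,0,1,3): the size-1 dim drops out,
    // leaving a 3-d transpose for cuTT.
    squeeze_for_cutt({2,1,3,4}, {2,0,1,3}, x_shape, reverse);
    for (int v : x_shape) printf("%d ", v);   // prints: 4 3 2
    printf("| ");
    for (int v : reverse) printf("%d ", v);   // prints: 0 2 1
    printf("\n");
}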

View File

@@ -19,7 +19,7 @@ struct CuttTransposeOp : Op {
const char* name() const override { return "cutt_transpose"; }
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
void infer_shape() override;
DECLARE_jit_run;
void run() override;
};
} // jittor

View File

@@ -33,6 +33,15 @@ def __iter__(x):
return result.__iter__()
jt.Var.__iter__ = __iter__
def all(x,dim):
return x.all_(dim).bool()
jt.Var.all = all
def any(x,dim):
return x.any_(dim).bool()
jt.Var.any = any
def repeat(x, *shape):
r'''
Repeats this var along the specified dimensions.
@@ -346,9 +355,8 @@ def unique(x):
'''
x = x.reshape(-1)
_,x = jt.argsort(x)
index2 = [i for i in range(1,x.shape[0])]
index1 = [i for i in range(x.shape[0]-1)]
y = x[1:][x[index2] != x[index1]]
index,= jt.index((x.shape[0],))
y = x[1:][x[index[1:]] != x[index[:-1]]]
x = jt.contrib.concat([x[:1],y],dim=0)
return x
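
The rewritten unique() keeps the neighbor comparison on device: jt.index produces the index var directly, so x[index[1:]] != x[index[:-1]] masks each sorted element against its predecessor without materializing the two O(n) Python index lists the old code built.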

View File

@@ -30,7 +30,7 @@ class TestCuttTransposeOp(unittest.TestCase):
for perm in perms:
with jt.log_capture_scope(
log_silent=1,
log_v=0, log_vprefix="op.cc=100"
log_v=0, log_vprefix="cutt=100"
) as raw_log:
if perm:
x = np.transpose(a, perm)
@@ -39,7 +39,7 @@ class TestCuttTransposeOp(unittest.TestCase):
x = np.transpose(a)
y = jt.transpose(a).data
self.assertEqual(x.shape, y.shape)
logs = find_log_with_re(raw_log, "(Jit op key (not )?found: " + "cutt_transpose" + ".*)")
logs = find_log_with_re(raw_log, "(Run cutt_transpose with key.*)")
if perm is None:
continue
last = -1
@@ -53,7 +53,7 @@ class TestCuttTransposeOp(unittest.TestCase):
last = perm[i]
if not in_order:
assert len(logs)==1
assert (x==y).all(), f"\n{x}\n{y}"
assert (x==y).all(), f"\n{x}\n{y}\n{perm}\n{a.shape}"
ia = [gen_data([5, 7]), gen_data([2,2,2]), gen_data([2,3,4,5]), gen_data([5,3]), gen_data([3,1,5,3,1])]
for a in ia: check(a)

View File

@@ -17,6 +17,7 @@ struct StackVector {
inline T& front() { return a[0]; }
inline T& back() { return a[n-1]; }
inline int size() { return n;}
inline T* data() { return a;}
inline StackVector(int n=0) : n(n) {}
struct Iter {
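
The new data() accessor exposes the raw backing array so CuttTransposeOp::run() can hand its stack-allocated shape and permutation buffers straight to cuttPlan, which takes plain int*. A condensed sketch of the container (simplified; the real StackVector's capacity and bounds handling are assumptions here):

template <class T, int N=5>
struct StackVectorSketch {
    int n; T a[N];
    StackVectorSketch(int n=0) : n(n) {}
    // indexing grows the logical size, matching the assign-by-index
    // pattern in run() (trans[i] = ..., reverse[i] = ...)
    T& operator[](int i) { if (i >= n) n = i+1; return a[i]; }
    T* data() { return a; }   // the accessor added in this commit
    int size() { return n; }
};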

View File

@@ -34,9 +34,9 @@ unordered_set<string> reduce_ops = {
"add",
// @pybind(prod, product, reduce_multiply)
"multiply",
// @pybind(reduce_logical_and, all)
// @pybind(reduce_logical_and, all_)
"logical_and",
// @pybind(reduce_logical_or, any)
// @pybind(reduce_logical_or, any_)
"logical_or",
"logical_xor",
"bitwise_and",
@@ -65,7 +65,8 @@ ReduceOp::ReduceOp(Var* x, NanoString op, NanoVector dims, bool keepdims)
reduce_mask |= 1<<dim;
}
}
if (x->dtype() == ns_bool && ns == ns_add)
// if (x->dtype() == ns_bool && ns == ns_add)
if (x->dtype() == ns_bool)
y = create_output(nullptr, ns_int32);
else
y = create_output(nullptr, binary_dtype_infer(ns, x, x));
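
Note the widened condition: every reduction over a bool input now yields int32, not just sums. The all/any wrappers added to misc.py above compensate by calling the renamed all_/any_ reductions and casting the int32 result back with .bool().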

View File

@@ -40,38 +40,8 @@ TransposeOp::TransposeOp(Var* x, NanoVector axes_) : x(x), axes(axes_) {
.get_constructor<VarPtr, Var*, NanoVector>();
}
if (cutt_transpose) {
bool need_reshape = false;
int dims = x->shape.size();
vector<int64> in_axes;
vector<int64> in_shape;
vector<int64> out_shape;
vector<int64> trans;
int cnt = 0;
for (int i = 0; i < dims; ++i) {
if (x->shape[i] == 1) {
need_reshape = true;
trans.push_back(-1);
} else {
trans.push_back(cnt);
cnt += 1;
in_shape.push_back(x->shape[i]);
}
out_shape.push_back(x->shape[axes[i]]);
}
for (int i = 0; i < dims; ++i) {
if (x->shape[axes[i]] != 1) {
in_axes.push_back(trans[axes[i]]);
}
}
if (need_reshape) {
auto x1 = make_reshape(x, NanoVector(in_shape));
auto x2 = cutt_transpose(x1, in_axes);
auto x3 = make_reshape(x2, NanoVector(out_shape));
forward(x3);
} else {
auto var = cutt_transpose(x, axes);
forward(var);
}
auto var = cutt_transpose(x, axes);
forward(var);
return;
}
}
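
The deleted reshape dance is now redundant: CuttTransposeOp::run() squeezes size-1 dimensions itself (see the sketch earlier), so TransposeOp can forward every permutation to cutt_transpose directly.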

View File

@@ -149,6 +149,19 @@ static vector<Stack> get_stack_info() {
(int)PyFrame_GetLineNumber(prev_f)});
}
}
if (stacks.size() == 0) {
auto m = std::min(3,n);
for (int i=0; i<m; i++) {
auto f = frames[n-m+i];
auto s = to_string(f->f_code->co_filename);
auto num = (int)PyFrame_GetLineNumber(f);
stacks.emplace_back(Stack{
s+":"+S(num),
"",
s,
num});
}
}
return stacks;
}
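
The added block is a fallback for when no frame passed the usual module filter: it records the innermost min(3, n) frames with their filename and line number, so the returned stack is never empty.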

View File

@@ -23,7 +23,7 @@ static void push_py_object_pickle(RingBuffer* rb, PyObject* obj, uint64& __restrict__
ASSERT(0 == PyBytes_AsStringAndSize(ret.obj, &s, &size));
rb->push_t<int64>(size, offset);
rb->push(size, offset);
LOGir << string(rb->get_ptr(size, offset), size);
// LOGir << string(rb->get_ptr(size, offset), size);
std::memcpy(rb->get_ptr(size, offset), s, size);
return;
}