polish cutt transpose

This commit is contained in:
Dun Liang 2021-07-27 11:13:39 +08:00
parent 01974db52d
commit 0748fc1854
2 changed files with 5 additions and 3 deletions

View File

@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.2.3.81'
__version__ = '1.2.3.82'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int

View File

@ -89,7 +89,7 @@ void CuttTransposeOp::jit_run() {
reverse[i] = dim-1-new_axes[dim-1-i];
for (int i=0; i<dim; i++)
x_shape[i] = new_shape[dim-1-i];
if (dim == 1) {
if (dim == 1 || x->num==1) {
checkCudaErrors(cudaMemcpyAsync(yp, xp, x->size, cudaMemcpyDefault, 0));
return;
}
@ -105,7 +105,9 @@ void CuttTransposeOp::jit_run() {
cuttExecute(iter->second, xp, yp);
} else {
cuttHandle plan;
CHECK(0==cuttPlan(&plan, dim, x_shape.data(), reverse.data(), x->dtype().dsize(), 0));
checkCudaErrors(cudaDeviceSynchronize());
auto ret = cuttPlan(&plan, dim, x_shape.data(), reverse.data(), x->dtype().dsize(), 0);
CHECK(0==ret) << ret << jk.to_string() << x << y;
cutt_plan_cache[jk.to_string()] = plan;
cuttExecute(plan, xp, yp);
}