mirror of https://github.com/Jittor/Jittor
polish cutt transpose
This commit is contained in:
parent
01974db52d
commit
0748fc1854
|
@ -9,7 +9,7 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
__version__ = '1.2.3.81'
|
||||
__version__ = '1.2.3.82'
|
||||
from jittor_utils import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -89,7 +89,7 @@ void CuttTransposeOp::jit_run() {
|
|||
reverse[i] = dim-1-new_axes[dim-1-i];
|
||||
for (int i=0; i<dim; i++)
|
||||
x_shape[i] = new_shape[dim-1-i];
|
||||
if (dim == 1) {
|
||||
if (dim == 1 || x->num==1) {
|
||||
checkCudaErrors(cudaMemcpyAsync(yp, xp, x->size, cudaMemcpyDefault, 0));
|
||||
return;
|
||||
}
|
||||
|
@ -105,7 +105,9 @@ void CuttTransposeOp::jit_run() {
|
|||
cuttExecute(iter->second, xp, yp);
|
||||
} else {
|
||||
cuttHandle plan;
|
||||
CHECK(0==cuttPlan(&plan, dim, x_shape.data(), reverse.data(), x->dtype().dsize(), 0));
|
||||
checkCudaErrors(cudaDeviceSynchronize());
|
||||
auto ret = cuttPlan(&plan, dim, x_shape.data(), reverse.data(), x->dtype().dsize(), 0);
|
||||
CHECK(0==ret) << ret << jk.to_string() << x << y;
|
||||
cutt_plan_cache[jk.to_string()] = plan;
|
||||
cuttExecute(plan, xp, yp);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue