mirror of https://github.com/Jittor/Jittor
fix copy op
This commit is contained in:
parent
ad320ed2cf
commit
8bb698c225
|
@ -7,7 +7,7 @@
|
|||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
__version__ = '1.1.6.3'
|
||||
__version__ = '1.1.6.4'
|
||||
from . import lock
|
||||
with lock.lock_scope():
|
||||
from . import compiler
|
||||
|
|
|
@ -8,6 +8,10 @@
|
|||
#include "var.h"
|
||||
#include "ops/op_register.h"
|
||||
#include "ops/copy_op.h"
|
||||
#ifdef HAS_CUDA
|
||||
#include <cuda_runtime.h>
|
||||
#include <helper_cuda.h>
|
||||
#endif
|
||||
|
||||
namespace jittor {
|
||||
|
||||
|
@ -37,7 +41,7 @@ void CopyOp::run() {
|
|||
}
|
||||
#ifdef HAS_CUDA
|
||||
else {
|
||||
std::cudaMemcpyAsync(y_ptr, x_ptr, size, cudaMemcpyDefault, 0);
|
||||
checkCudaErrors(cudaMemcpyAsync(y_ptr, x_ptr, size, cudaMemcpyDefault, 0));
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue