mirror of https://github.com/Jittor/Jittor
polish migrate to cpu
This commit is contained in:
parent
db88d73ed1
commit
9048f3fd41
|
@ -9,7 +9,7 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
__version__ = '1.3.2.5'
|
||||
__version__ = '1.3.2.6'
|
||||
from jittor_utils import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -499,7 +499,8 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
|
|||
sync_times++;
|
||||
}
|
||||
for (Var* v : op->inputs()) {
|
||||
migrate_to_cpu(v, allocator);
|
||||
if (v->allocator->is_cuda())
|
||||
migrate_to_cpu(v, allocator);
|
||||
}
|
||||
if (!use_cuda_managed_allocator) {
|
||||
for (auto* var : op->outputs()) {
|
||||
|
|
|
@ -105,6 +105,7 @@ void migrate_to_cpu(Var* var, Allocator* allocator) {
|
|||
);
|
||||
} else
|
||||
if (!use_cuda_managed_allocator) {
|
||||
if (!var->allocator->is_cuda()) return;
|
||||
// must be a device allocator
|
||||
Allocation a(allocator, var->size);
|
||||
checkCudaErrors(cudaMemcpy(a.ptr, var->mem_ptr, var->size, cudaMemcpyDeviceToHost));
|
||||
|
|
Loading…
Reference in New Issue