diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py index f224021f..23124737 100644 --- a/python/jittor/__init__.py +++ b/python/jittor/__init__.py @@ -9,7 +9,7 @@ # file 'LICENSE.txt', which is part of this source code package. # *************************************************************** -__version__ = '1.3.2.5' +__version__ = '1.3.2.6' from jittor_utils import lock with lock.lock_scope(): ori_int = int diff --git a/python/jittor/src/executor.cc b/python/jittor/src/executor.cc index 186eeb0c..5deacd53 100644 --- a/python/jittor/src/executor.cc +++ b/python/jittor/src/executor.cc @@ -499,7 +499,8 @@ void Executor::run_sync(vector vars, bool device_sync) { sync_times++; } for (Var* v : op->inputs()) { - migrate_to_cpu(v, allocator); + if (v->allocator->is_cuda()) + migrate_to_cpu(v, allocator); } if (!use_cuda_managed_allocator) { for (auto* var : op->outputs()) { diff --git a/python/jittor/src/mem/allocator.cc b/python/jittor/src/mem/allocator.cc index 1e425b4f..46c6f308 100644 --- a/python/jittor/src/mem/allocator.cc +++ b/python/jittor/src/mem/allocator.cc @@ -105,6 +105,7 @@ void migrate_to_cpu(Var* var, Allocator* allocator) { ); } else if (!use_cuda_managed_allocator) { + if (!var->allocator->is_cuda()) return; // must be a device allocator Allocation a(allocator, var->size); checkCudaErrors(cudaMemcpy(a.ptr, var->mem_ptr, var->size, cudaMemcpyDeviceToHost));