From 3cf5d7f2a475f81916580fea821be15097dab292 Mon Sep 17 00:00:00 2001
From: Yuxuan Han
Date: Mon, 16 Jun 2025 11:05:05 +0800
Subject: [PATCH] adjust aclnn.h reference

---
 python/jittor/compiler.py                     | 20 ++++++++++---------
 python/jittor/extern/acl/acl_jittor.h         |  1 +
 python/jittor/extern/acl/aclops/utils.cc      |  1 +
 python/jittor/extern/acl/aclops/utils.h       |  1 +
 python/jittor/extern/mpi/ops/mpi_reduce_op.cc |  3 +++
 python/jittor/src/common.h                    |  2 +-
 python/jittor/src/ops/array_op.cc             |  6 +++---
 python/jittor/src/ops/copy_op.cc              |  2 +-
 python/jittor/src/ops/fetch_op.cc             | 20 +++++++++----------
 9 files changed, 32 insertions(+), 24 deletions(-)

diff --git a/python/jittor/compiler.py b/python/jittor/compiler.py
index d6675971..af1dc65f 100644
--- a/python/jittor/compiler.py
+++ b/python/jittor/compiler.py
@@ -1191,15 +1191,17 @@
 ascend_toolkit_home = os.getenv('ASCEND_TOOLKIT_HOME')
 
 # build cache_compile
 cc_flags += f" -I\"{os.path.join(jittor_path, 'src')}\" "
-cc_flags += f" -I\"{os.path.join(jittor_path, 'extern')}\" "
-cc_flags += f" -I\"{os.path.join(ascend_toolkit_home, 'include')}\" "
-cc_flags += f" -I\"{os.path.join(ascend_toolkit_home, 'include/acl')}\" "
-cc_flags += f" -I\"{os.path.join(ascend_toolkit_home, 'include/aclnn')}\" "
-cc_flags += f" -I\"{os.path.join(ascend_toolkit_home, 'include/aclnnop')}\" "
-cc_flags += f" -L\"{os.path.join(ascend_toolkit_home, 'lib64')}\" "
-cc_flags += " -llibascendcl "
-cc_flags += " -llibnnopbase "
-cc_flags += " -llibopapi "
+
+if ascend_toolkit_home:
+    cc_flags += f" -I\"{os.path.join(jittor_path, 'extern')}\" "
+    cc_flags += f" -I\"{os.path.join(ascend_toolkit_home, 'include')}\" "
+    cc_flags += f" -I\"{os.path.join(ascend_toolkit_home, 'include/acl')}\" "
+    cc_flags += f" -I\"{os.path.join(ascend_toolkit_home, 'include/aclnn')}\" "
+    cc_flags += f" -I\"{os.path.join(ascend_toolkit_home, 'include/aclnnop')}\" "
+    cc_flags += f" -L\"{os.path.join(ascend_toolkit_home, 'lib64')}\" "
+    cc_flags += " -llibascendcl "
+    cc_flags += " -llibnnopbase "
+    cc_flags += " -llibopapi "
 cc_flags += py_include
 
diff --git a/python/jittor/extern/acl/acl_jittor.h b/python/jittor/extern/acl/acl_jittor.h
index ee9960cb..822df915 100644
--- a/python/jittor/extern/acl/acl_jittor.h
+++ b/python/jittor/extern/acl/acl_jittor.h
@@ -7,6 +7,7 @@
 #pragma once
 #include "common.h"
 #include 
+#include "aclnn.h"
 
 std::string acl_error_to_string(aclError error);

diff --git a/python/jittor/extern/acl/aclops/utils.cc b/python/jittor/extern/acl/aclops/utils.cc
index 1aac88db..de1f9724 100644
--- a/python/jittor/extern/acl/aclops/utils.cc
+++ b/python/jittor/extern/acl/aclops/utils.cc
@@ -5,6 +5,7 @@
 #include 
 #include 
 #include "utils.h"
+#include "aclnn.h"
 
 namespace jittor {

diff --git a/python/jittor/extern/acl/aclops/utils.h b/python/jittor/extern/acl/aclops/utils.h
index de2b7bc7..c5cfdfea 100644
--- a/python/jittor/extern/acl/aclops/utils.h
+++ b/python/jittor/extern/acl/aclops/utils.h
@@ -6,6 +6,7 @@
 #include 
 #include 
 #include "misc/nano_string.h"
+#include "aclnn.h"
 
 namespace jittor {

diff --git a/python/jittor/extern/mpi/ops/mpi_reduce_op.cc b/python/jittor/extern/mpi/ops/mpi_reduce_op.cc
index 78294548..62a7fa95 100644
--- a/python/jittor/extern/mpi/ops/mpi_reduce_op.cc
+++ b/python/jittor/extern/mpi/ops/mpi_reduce_op.cc
@@ -49,10 +49,13 @@ MpiReduceOp::MpiReduceOp(Var* x, NanoString op, int root) : x(x), op(op), root(r
             forward(var);
             return;
         } else if (hccl_reduce) {
+            auto var = hccl_reduce(x, "sum", root);
+            //exe.run_sync({var}, true);
             forward(var);
             return;
         }
     }
+
     #endif
     y = create_output(nullptr, x->dtype());
 }

diff --git a/python/jittor/src/common.h b/python/jittor/src/common.h
index 58c6a37a..968a1699 100644
--- a/python/jittor/src/common.h
+++ b/python/jittor/src/common.h
@@ -8,7 +8,7 @@
 #include 
 #include 
 #include "utils/log.h"
-#include "../extern/acl/aclnn/aclnn.h"
+// #include "../extern/acl/aclnn/aclnn.h"
 
 #define JIT_TEST(name) extern void jit_test_ ## name ()
 void expect_error(std::function func);

diff --git a/python/jittor/src/ops/array_op.cc b/python/jittor/src/ops/array_op.cc
index 23b0f38b..58b50215 100644
--- a/python/jittor/src/ops/array_op.cc
+++ b/python/jittor/src/ops/array_op.cc
@@ -31,9 +31,9 @@ cudaEvent_t event;
 struct Init {
 Init() {
     if (!get_device_count()) return;
-    //checkCudaErrors(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
-    //checkCudaErrors(cudaEventCreate(&event, cudaEventDisableTiming));
-    stream = aclstream;
+    checkCudaErrors(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
+    checkCudaErrors(cudaEventCreate(&event, cudaEventDisableTiming));
+    // stream = aclstream;
 }
 ~Init() {
     if (!get_device_count()) return;

diff --git a/python/jittor/src/ops/copy_op.cc b/python/jittor/src/ops/copy_op.cc
index b48e57f5..90420470 100644
--- a/python/jittor/src/ops/copy_op.cc
+++ b/python/jittor/src/ops/copy_op.cc
@@ -17,7 +17,7 @@
 
 namespace jittor {
 
-EXTERN_LIB aclrtStream aclstream;
+// EXTERN_LIB aclrtStream aclstream;
 
 CopyOp::CopyOp(Var* x) {
     flags.set(NodeFlags::_cpu);

diff --git a/python/jittor/src/ops/fetch_op.cc b/python/jittor/src/ops/fetch_op.cc
index 98de48d3..bd2e6959 100644
--- a/python/jittor/src/ops/fetch_op.cc
+++ b/python/jittor/src/ops/fetch_op.cc
@@ -47,7 +47,7 @@ Init() {
     if (!get_device_count()) return;
     checkCudaErrors(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
     checkCudaErrors(cudaEventCreate(&event, cudaEventDisableTiming));
-    stream = aclstream;
+    // stream = aclstream;
 }
 ~Init() {
     if (!get_device_count()) return;
@@ -123,11 +123,11 @@ void FetchOp::run() {
             new (&allocation) Allocation(&cuda_dual_allocator, v->size);
             // mostly device to device
             #if IS_CUDA
-            // checkCudaErrors(cudaMemcpyAsync(
-            //     allocation.ptr, v->mem_ptr, v->size, cudaMemcpyDefault, stream));
             checkCudaErrors(cudaMemcpyAsync(
-                allocation.ptr, v->size, v->mem_ptr, v->size, cudaMemcpyDefault, aclstream));
-            checkCudaErrors(aclrtSynchronizeStream(aclstream));
+                allocation.ptr, v->mem_ptr, v->size, cudaMemcpyDefault, stream));
+            // checkCudaErrors(cudaMemcpyAsync(
+            //     allocation.ptr, v->size, v->mem_ptr, v->size, cudaMemcpyDefault, aclstream));
+            // checkCudaErrors(aclrtSynchronizeStream(aclstream));
             #else
             checkCudaErrors(cudaMemcpyAsync(
                 allocation.ptr, v->mem_ptr, v->size, cudaMemcpyDeviceToDevice, stream));
@@ -135,11 +135,11 @@
             auto host_ptr = cuda_dual_allocator.get_dual_allocation(
                 allocation.allocation).host_ptr;
             // device to host
-            // checkCudaErrors(cudaMemcpyAsync(
-            //     host_ptr, allocation.ptr, v->size, cudaMemcpyDeviceToHost, stream));
-            checkCudaErrors(aclrtMemcpyAsync(
-                host_ptr, v->size, allocation.ptr, v->size, cudaMemcpyDeviceToHost, aclstream));
-            checkCudaErrors(aclrtSynchronizeStream(aclstream));
+            checkCudaErrors(cudaMemcpyAsync(
+                host_ptr, allocation.ptr, v->size, cudaMemcpyDeviceToHost, stream));
+            // checkCudaErrors(aclrtMemcpyAsync(
+            //     host_ptr, v->size, allocation.ptr, v->size, cudaMemcpyDeviceToHost, aclstream));
+            // checkCudaErrors(aclrtSynchronizeStream(aclstream));
             allocation.ptr = host_ptr;
             has_cuda_memcpy = true;
         } else
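
Note on the compiler.py hunk: os.getenv returns None when ASCEND_TOOLKIT_HOME is unset, and os.path.join(None, 'include') raises TypeError, so the previously unconditional flag lines broke every build without the Ascend toolkit; guarding them on ascend_toolkit_home restores the non-Ascend path. Below is a minimal standalone sketch of the same guard, not code from the patch: the helper name ascend_cc_flags is hypothetical, while jittor_path, the include subdirectories, and the -l names are taken from the hunk.

    import os

    def ascend_cc_flags(jittor_path: str) -> str:
        """Hypothetical helper mirroring the guarded block in compiler.py:
        emit Ascend include/link flags only when ASCEND_TOOLKIT_HOME is set."""
        flags = ""
        ascend_toolkit_home = os.getenv('ASCEND_TOOLKIT_HOME')
        if not ascend_toolkit_home:
            # No CANN toolkit installed: skip the ACL/aclnn paths entirely
            # instead of calling os.path.join(None, ...) and raising TypeError.
            return flags
        flags += f" -I\"{os.path.join(jittor_path, 'extern')}\" "
        for sub in ('include', 'include/acl', 'include/aclnn', 'include/aclnnop'):
            flags += f" -I\"{os.path.join(ascend_toolkit_home, sub)}\" "
        flags += f" -L\"{os.path.join(ascend_toolkit_home, 'lib64')}\" "
        flags += " -llibascendcl -llibnnopbase -llibopapi "
        return flags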