mirror of https://github.com/Jittor/Jittor
72 lines
2.3 KiB
Python
72 lines
2.3 KiB
Python
# ***************************************************************
|
|
# Copyright (c) 2023 Jittor. All Rights Reserved.
|
|
# Maintainers: Dun Liang <randonlang@gmail.com>.
|
|
# This file is subject to the terms and conditions defined in
|
|
# file 'LICENSE.txt', which is part of this source code package.
|
|
# ***************************************************************
|
|
import os
|
|
from jittor_utils import env_or_try_find
|
|
import jittor_utils
|
|
import ctypes
|
|
import glob
|
|
import jittor.compiler as compiler
|
|
|
|
# --- Module state, later updated by install() / check() ---
has_acl = 0  # becomes 1 at the end of a successful install()
cc_flags = ""  # extra ACL flags; install() extends it, check() forwards it to compiler.cc_flags

# Location of the tikcc compiler. Presumably falsy when it cannot be found
# (check() only calls install() when this is truthy) -- the exact lookup rules
# live in jittor_utils.env_or_try_find.
tikcc_path = env_or_try_find('tikcc_path', 'tikcc')

# dlopen mode: resolve all symbols immediately and make them globally visible
# to libraries loaded afterwards.
dlopen_flags = os.RTLD_GLOBAL | os.RTLD_NOW

# Publish the initial (disabled) state to the compiler module before check() runs.
compiler.has_acl = has_acl
|
|
|
|
def _acl_compile_flags(acl_compiler_home, machine=None):
    """Return the extra compiler/linker flags needed to build the ACL module.

    Args:
        acl_compiler_home: directory holding this package's C++ sources; it is
            added to the include path.
        machine: CPU architecture name as reported by ``platform.machine()``.
            Defaults to the running host. ``"aarch64"`` selects the
            ``aarch64-linux`` toolkit tree; any other value keeps the
            ``x86_64-linux`` tree that was previously hard-coded.

    Returns:
        A whitespace-separated flag string with a leading space, so it can be
        concatenated directly onto another flag string.
    """
    import platform

    if machine is None:
        machine = platform.machine()
    arch = "aarch64" if machine == "aarch64" else "x86_64"
    toolkit = f"/usr/local/Ascend/ascend-toolkit/latest/{arch}-linux"
    flags = [
        "-DHAS_CUDA",
        "-DIS_ACL",
        f"-I{toolkit}/include/",
        f"-L{toolkit}/lib64",
        f"-I{acl_compiler_home}",
        "-ltikc_runtime",
        "-lascendcl",
    ]
    return " " + " ".join(flags) + " "


def install():
    """Compile and register Jittor's ACL (Ascend) source processor.

    Steps:
    1. Collect every ``*.cc`` file under this package's directory.
    2. Load ``libascendcl.so`` with ``dlopen_flags``.
    3. Compile those sources plus a small pyjt binding for ``process_acl``.
    4. Register the module's ``process`` hook for "acl" sources via
       ``jittor_utils.process_jittor_source``.

    On success the module globals are updated: ``cc_flags`` gets the ACL
    flags appended and ``has_acl`` is set to 1. If any step raises, both
    globals are left untouched, so a failed attempt (or a later retry) does
    not leave stale or duplicated flags behind. The exception propagates to
    the caller.
    """
    global has_acl, cc_flags
    acl_compiler_home = os.path.dirname(__file__)
    cc_files = sorted(glob.glob(
        os.path.join(acl_compiler_home, "**", "*.cc"), recursive=True))

    # Build the new flag set locally; it is only published once everything
    # below has succeeded.
    new_cc_flags = cc_flags + _acl_compile_flags(acl_compiler_home)

    # dlopen_flags includes RTLD_GLOBAL, which exposes libascendcl's symbols
    # to libraries loaded later (such as the module compiled below).
    ctypes.CDLL("libascendcl.so", dlopen_flags)
    jittor_utils.LOG.i("ACL detected")

    mod = jittor_utils.compile_module('''
#include "common.h"
namespace jittor {
// @pyjt(process)
string process_acl(const string& src, const string& name, const map<string,string>& kargs);
}''', compiler.cc_flags + " " + " ".join(cc_files) + new_cc_flags)
    jittor_utils.process_jittor_source("acl", mod.process)

    cc_flags = new_cc_flags
    has_acl = 1
|
|
|
|
|
|
def install_extern():
    """Report that no external libraries are installed for the ACL backend.

    Returns:
        Always False.
    """
    return False
|
|
|
|
|
|
def check():
    """Try to enable the ACL backend and wire it into ``jittor.compiler``.

    When ``tikcc_path`` is set, ``install()`` is attempted. Any exception it
    raises is logged as a warning and ACL is marked unavailable, rather than
    the exception being propagated. The resulting ``has_acl`` and
    ``tikcc_path`` are always published on ``compiler``. When ACL is
    available, tikcc is registered in the compiler's nvcc slot and the ACL
    flags are appended to ``compiler.cc_flags``.

    Returns:
        True if ACL support is enabled, otherwise False.
    """
    global has_acl

    if tikcc_path:
        try:
            install()
        except Exception as exc:
            jittor_utils.LOG.w(f"load ACL failed, exception: {exc}")
            has_acl = 0

    compiler.has_acl, compiler.tikcc_path = has_acl, tikcc_path

    if has_acl:
        compiler.cc_flags += cc_flags
        compiler.nvcc_path = tikcc_path
        # Device flags mirror the host flags, minus the "-std=c++14" pin.
        compiler.nvcc_flags = compiler.cc_flags.replace("-std=c++14", "")
        return True
    return False
|
|
|
|
def post_process():
    """Switch off ``jittor.pool``'s code-op path when ACL is active.

    This is a no-op when ``has_acl`` is falsy.
    NOTE(review): presumably the code-op pooling implementation is not
    supported on the ACL backend -- confirm.
    """
    if not has_acl:
        return
    from jittor import pool
    pool.pool_use_code_op = False