add acl backend

Dun Liang 2022-03-21 17:48:50 +08:00
parent 5efb222dd3
commit e7bb2545d3
27 changed files with 785 additions and 30 deletions

View File

@@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.3.1.48'
__version__ = '1.3.1.49'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int

View File

@@ -228,13 +228,18 @@ def gen_jit_flags():
continue
visit[name] = 1
jit_declares.append(f"DECLARE_FLAG({type}, {name});")
alias = []
if name == "use_cuda":
alias = ["use_device", "use_acl"]
get_names = ",".join(["__get__"+a for a in [name]+alias])
set_names = ",".join(["__set__"+a for a in [name]+alias])
flags_defs.append(f"""
/* {name}(type:{type}, default:{default}): {doc} */
// @pyjt(__get__{name})
// @pyjt({get_names})
{type} _get_{name}() {{ return {name}; }}
// @pyjt(__set__{name})
// @pyjt({set_names})
void _set_{name}({type} v) {{ set_{name}(v); }}
{f'''// @pyjt(__set__{name})
{f'''// @pyjt({set_names})
void _set_{name}(bool v) {{ set_{name}(v); }}
''' if type=="int" else ""}
""")
@@ -843,7 +848,7 @@ def check_cuda():
# this nvcc is install by package manager
cuda_lib = "/usr/lib/x86_64-linux-gnu"
cuda_include2 = os.path.join(jittor_path, "extern","cuda","inc")
cc_flags += f" -DHAS_CUDA -I\"{cuda_include}\" -I\"{cuda_include2}\" "
cc_flags += f" -DHAS_CUDA -DIS_CUDA -I\"{cuda_include}\" -I\"{cuda_include2}\" "
if os.name == 'nt':
cuda_lib = os.path.abspath(os.path.join(cuda_dir, "..", "lib", "x64"))
# cc_flags += f" \"{cuda_lib}\\cudart.lib\" "
@@ -1212,6 +1217,14 @@ if has_cuda:
return nvcc_flags
nvcc_flags = convert_nvcc_flags(nvcc_flags)
# from .acl_compiler import check_acl
from .extern.acl import acl_compiler
jit_utils.add_backend(acl_compiler)
for mod in jit_utils.backends:
if mod.check():
break
# build core
gen_jit_flags()
gen_jit_tests()
@@ -1237,6 +1250,7 @@ files4 = [ f[len(jittor_path)+1:] for f in files4 ]
at_beginning = [
"src/ops/op_utils.cc",
"src/ops/op_register.cc",
"src/init.cc",
"src/event_queue.cc",
"src/mem/allocator/sfrl_allocator.cc",
"src/mem/allocator.cc",
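
The alias branch above makes the generated pyjt bindings expose use_device and use_acl as synonyms of the existing use_cuda flag. A minimal sketch of the resulting Python-side behavior (assuming an ACL build where setting the flag succeeds):

import jittor as jt

# use_acl and use_device are generated aliases of the use_cuda flag, so
# toggling any one of them is visible through all three names
jt.flags.use_acl = 1
assert jt.flags.use_cuda == 1 and jt.flags.use_device == 1
jt.flags.use_cuda = 0
assert jt.flags.use_acl == 0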

View File

@@ -0,0 +1,54 @@
# ***************************************************************
# Copyright (c) 2021 Jittor. All Rights Reserved.
# Maintainers: Dun Liang <randonlang@gmail.com>.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import os
from jittor_utils import env_or_try_find
import jittor_utils
import ctypes
import glob
has_acl = 0
cc_flags = ""
tikcc_path = env_or_try_find('tikcc_path', 'tikcc')
dlopen_flags = os.RTLD_NOW | os.RTLD_GLOBAL
def install():
import jittor.compiler as compiler
global has_acl, cc_flags
acl_compiler_home = os.path.dirname(__file__)
cc_files = sorted(glob.glob(acl_compiler_home+"/**/*.cc", recursive=True))
cc_flags += f" -DHAS_CUDA -DIS_ACL -I/usr/local/Ascend/runtime/include -I/usr/local/Ascend/driver/include -L/usr/local/Ascend/compiler/lib64 -L/usr/local/Ascend/runtime/lib64 -I{acl_compiler_home} -ltikc_runtime -lascendcl "
ctypes.CDLL("libascendcl.so", dlopen_flags)
jittor_utils.LOG.i("ACL detected")
mod = jittor_utils.compile_module('''
#include "common.h"
namespace jittor {
// @pyjt(process)
string process_acl(const string& src, const string& name, const map<string,string>& kargs);
}''', compiler.cc_flags + " " + " ".join(cc_files) + cc_flags)
jittor_utils.process_jittor_source("acl", mod.process)
has_acl = 1
def check():
import jittor.compiler as compiler
global has_acl, cc_flags
if tikcc_path:
try:
install()
except Exception as e:
jittor_utils.LOG.w(f"load ACL failed, exception: {e}")
has_acl = 0
compiler.has_acl = has_acl
compiler.tikcc_path = tikcc_path
if not has_acl: return False
compiler.cc_flags += cc_flags
compiler.nvcc_path = tikcc_path
compiler.nvcc_flags = compiler.cc_flags.replace("-std=c++14","")
return True
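
check() is the whole backend contract here: compiler.py registers this module via jit_utils.add_backend and calls check() on each registered backend until one returns True, at which point the backend has already patched jittor.compiler (cc_flags, nvcc_path, ...) for the rest of the build. A hypothetical third-party backend would follow the same shape (all names below are illustrative):

import jittor_utils

class dummy_backend:
    # stand-in for a backend module such as acl_compiler
    @staticmethod
    def check():
        # probe for the toolchain; on success, patch jittor.compiler
        # (cc_flags, nvcc_path, ...) and return True to stop the scan
        return False

jittor_utils.add_backend(dummy_backend)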

View File

@@ -0,0 +1,228 @@
// ***************************************************************
// Copyright (c) 2021 Jittor. All Rights Reserved.
// Maintainers: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "common.h"
using std::string;
using std::unordered_map;
typedef int aclError;
static inline unordered_map<aclError,string> gen_map(string s) {
unordered_map<aclError,string> smap;
for (int i=0; i<s.size(); i++) {
if (s[i] == ';') {
int j=s.rfind(" ", i);
int code = std::stoi(s.substr(j+1, i-j-1));
int k = s.rfind(" ", j-1);
int l = s.rfind(" ACL_", k-1);
smap[code] = s.substr(l+1, k-l-1);
}
}
return smap;
}
string acl_error_to_string(aclError error) {
static unordered_map<aclError,string> acl_error_map = gen_map(R"(
// from acl_base.h
static const int ACL_ERROR_INVALID_PARAM = 100000;
static const int ACL_ERROR_UNINITIALIZE = 100001;
static const int ACL_ERROR_REPEAT_INITIALIZE = 100002;
static const int ACL_ERROR_INVALID_FILE = 100003;
static const int ACL_ERROR_WRITE_FILE = 100004;
static const int ACL_ERROR_INVALID_FILE_SIZE = 100005;
static const int ACL_ERROR_PARSE_FILE = 100006;
static const int ACL_ERROR_FILE_MISSING_ATTR = 100007;
static const int ACL_ERROR_FILE_ATTR_INVALID = 100008;
static const int ACL_ERROR_INVALID_DUMP_CONFIG = 100009;
static const int ACL_ERROR_INVALID_PROFILING_CONFIG = 100010;
static const int ACL_ERROR_INVALID_MODEL_ID = 100011;
static const int ACL_ERROR_DESERIALIZE_MODEL = 100012;
static const int ACL_ERROR_PARSE_MODEL = 100013;
static const int ACL_ERROR_READ_MODEL_FAILURE = 100014;
static const int ACL_ERROR_MODEL_SIZE_INVALID = 100015;
static const int ACL_ERROR_MODEL_MISSING_ATTR = 100016;
static const int ACL_ERROR_MODEL_INPUT_NOT_MATCH = 100017;
static const int ACL_ERROR_MODEL_OUTPUT_NOT_MATCH = 100018;
static const int ACL_ERROR_MODEL_NOT_DYNAMIC = 100019;
static const int ACL_ERROR_OP_TYPE_NOT_MATCH = 100020;
static const int ACL_ERROR_OP_INPUT_NOT_MATCH = 100021;
static const int ACL_ERROR_OP_OUTPUT_NOT_MATCH = 100022;
static const int ACL_ERROR_OP_ATTR_NOT_MATCH = 100023;
static const int ACL_ERROR_OP_NOT_FOUND = 100024;
static const int ACL_ERROR_OP_LOAD_FAILED = 100025;
static const int ACL_ERROR_UNSUPPORTED_DATA_TYPE = 100026;
static const int ACL_ERROR_FORMAT_NOT_MATCH = 100027;
static const int ACL_ERROR_BIN_SELECTOR_NOT_REGISTERED = 100028;
static const int ACL_ERROR_KERNEL_NOT_FOUND = 100029;
static const int ACL_ERROR_BIN_SELECTOR_ALREADY_REGISTERED = 100030;
static const int ACL_ERROR_KERNEL_ALREADY_REGISTERED = 100031;
static const int ACL_ERROR_INVALID_QUEUE_ID = 100032;
static const int ACL_ERROR_REPEAT_SUBSCRIBE = 100033;
static const int ACL_ERROR_STREAM_NOT_SUBSCRIBE = 100034;
static const int ACL_ERROR_THREAD_NOT_SUBSCRIBE = 100035;
static const int ACL_ERROR_WAIT_CALLBACK_TIMEOUT = 100036;
static const int ACL_ERROR_REPEAT_FINALIZE = 100037;
static const int ACL_ERROR_NOT_STATIC_AIPP = 100038;
static const int ACL_ERROR_COMPILING_STUB_MODE = 100039;
static const int ACL_ERROR_GROUP_NOT_SET = 100040;
static const int ACL_ERROR_GROUP_NOT_CREATE = 100041;
static const int ACL_ERROR_PROF_ALREADY_RUN = 100042;
static const int ACL_ERROR_PROF_NOT_RUN = 100043;
static const int ACL_ERROR_DUMP_ALREADY_RUN = 100044;
static const int ACL_ERROR_DUMP_NOT_RUN = 100045;
static const int ACL_ERROR_PROF_REPEAT_SUBSCRIBE = 148046;
static const int ACL_ERROR_PROF_API_CONFLICT = 148047;
static const int ACL_ERROR_INVALID_MAX_OPQUEUE_NUM_CONFIG = 148048;
static const int ACL_ERROR_INVALID_OPP_PATH = 148049;
static const int ACL_ERROR_OP_UNSUPPORTED_DYNAMIC = 148050;
static const int ACL_ERROR_RELATIVE_RESOURCE_NOT_CLEARED = 148051;
static const int ACL_ERROR_BAD_ALLOC = 200000;
static const int ACL_ERROR_API_NOT_SUPPORT = 200001;
static const int ACL_ERROR_INVALID_DEVICE = 200002;
static const int ACL_ERROR_MEMORY_ADDRESS_UNALIGNED = 200003;
static const int ACL_ERROR_RESOURCE_NOT_MATCH = 200004;
static const int ACL_ERROR_INVALID_RESOURCE_HANDLE = 200005;
static const int ACL_ERROR_FEATURE_UNSUPPORTED = 200006;
static const int ACL_ERROR_PROF_MODULES_UNSUPPORTED = 200007;
static const int ACL_ERROR_STORAGE_OVER_LIMIT = 300000;
static const int ACL_ERROR_INTERNAL_ERROR = 500000;
static const int ACL_ERROR_FAILURE = 500001;
static const int ACL_ERROR_GE_FAILURE = 500002;
static const int ACL_ERROR_RT_FAILURE = 500003;
static const int ACL_ERROR_DRV_FAILURE = 500004;
static const int ACL_ERROR_PROFILING_FAILURE = 500005;
// from ge_error_codes.h
static const uint32_t ACL_ERROR_GE_PARAM_INVALID = 145000U;
static const uint32_t ACL_ERROR_GE_EXEC_NOT_INIT = 145001U;
static const uint32_t ACL_ERROR_GE_EXEC_MODEL_PATH_INVALID = 145002U;
static const uint32_t ACL_ERROR_GE_EXEC_MODEL_ID_INVALID = 145003U;
static const uint32_t ACL_ERROR_GE_EXEC_MODEL_DATA_SIZE_INVALID = 145006U;
static const uint32_t ACL_ERROR_GE_EXEC_MODEL_ADDR_INVALID = 145007U;
static const uint32_t ACL_ERROR_GE_EXEC_MODEL_QUEUE_ID_INVALID = 145008U;
static const uint32_t ACL_ERROR_GE_EXEC_LOAD_MODEL_REPEATED = 145009U;
static const uint32_t ACL_ERROR_GE_DYNAMIC_INPUT_ADDR_INVALID = 145011U;
static const uint32_t ACL_ERROR_GE_DYNAMIC_INPUT_LENGTH_INVALID = 145012U;
static const uint32_t ACL_ERROR_GE_DYNAMIC_BATCH_SIZE_INVALID = 145013U;
static const uint32_t ACL_ERROR_GE_AIPP_BATCH_EMPTY = 145014U;
static const uint32_t ACL_ERROR_GE_AIPP_NOT_EXIST = 145015U;
static const uint32_t ACL_ERROR_GE_AIPP_MODE_INVALID = 145016U;
static const uint32_t ACL_ERROR_GE_OP_TASK_TYPE_INVALID = 145017U;
static const uint32_t ACL_ERROR_GE_OP_KERNEL_TYPE_INVALID = 145018U;
static const uint32_t ACL_ERROR_GE_PLGMGR_PATH_INVALID = 145019U;
static const uint32_t ACL_ERROR_GE_FORMAT_INVALID = 145020U;
static const uint32_t ACL_ERROR_GE_SHAPE_INVALID = 145021U;
static const uint32_t ACL_ERROR_GE_DATATYPE_INVALID = 145022U;
static const uint32_t ACL_ERROR_GE_MEMORY_ALLOCATION = 245000U;
static const uint32_t ACL_ERROR_GE_MEMORY_OPERATE_FAILED = 245001U;
static const uint32_t ACL_ERROR_GE_INTERNAL_ERROR = 545000U;
static const uint32_t ACL_ERROR_GE_LOAD_MODEL = 545001U;
static const uint32_t ACL_ERROR_GE_EXEC_LOAD_MODEL_PARTITION_FAILED = 545002U;
static const uint32_t ACL_ERROR_GE_EXEC_LOAD_WEIGHT_PARTITION_FAILED = 545003U;
static const uint32_t ACL_ERROR_GE_EXEC_LOAD_TASK_PARTITION_FAILED = 545004U;
static const uint32_t ACL_ERROR_GE_EXEC_LOAD_KERNEL_PARTITION_FAILED = 545005U;
static const uint32_t ACL_ERROR_GE_EXEC_RELEASE_MODEL_DATA = 545006U;
static const uint32_t ACL_ERROR_GE_COMMAND_HANDLE = 545007U;
static const uint32_t ACL_ERROR_GE_GET_TENSOR_INFO = 545008U;
static const uint32_t ACL_ERROR_GE_UNLOAD_MODEL = 545009U;
static const int32_t ACL_ERROR_RT_PARAM_INVALID = 107000; // param invalid
static const int32_t ACL_ERROR_RT_INVALID_DEVICEID = 107001; // invalid device id
static const int32_t ACL_ERROR_RT_CONTEXT_NULL = 107002; // current context null
static const int32_t ACL_ERROR_RT_STREAM_CONTEXT = 107003; // stream not in current context
static const int32_t ACL_ERROR_RT_MODEL_CONTEXT = 107004; // model not in current context
static const int32_t ACL_ERROR_RT_STREAM_MODEL = 107005; // stream not in model
static const int32_t ACL_ERROR_RT_EVENT_TIMESTAMP_INVALID = 107006; // event timestamp invalid
static const int32_t ACL_ERROR_RT_EVENT_TIMESTAMP_REVERSAL = 107007; // event timestamp reversal
static const int32_t ACL_ERROR_RT_ADDR_UNALIGNED = 107008; // memory address unaligned
static const int32_t ACL_ERROR_RT_FILE_OPEN = 107009; // open file failed
static const int32_t ACL_ERROR_RT_FILE_WRITE = 107010; // write file failed
static const int32_t ACL_ERROR_RT_STREAM_SUBSCRIBE = 107011; // error subscribe stream
static const int32_t ACL_ERROR_RT_THREAD_SUBSCRIBE = 107012; // error subscribe thread
static const int32_t ACL_ERROR_RT_GROUP_NOT_SET = 107013; // group not set
static const int32_t ACL_ERROR_RT_GROUP_NOT_CREATE = 107014; // group not create
static const int32_t ACL_ERROR_RT_STREAM_NO_CB_REG = 107015; // callback not register to stream
static const int32_t ACL_ERROR_RT_INVALID_MEMORY_TYPE = 107016; // invalid memory type
static const int32_t ACL_ERROR_RT_INVALID_HANDLE = 107017; // invalid handle
static const int32_t ACL_ERROR_RT_INVALID_MALLOC_TYPE = 107018; // invalid malloc type
static const int32_t ACL_ERROR_RT_WAIT_TIMEOUT = 107019; // wait timeout
static const int32_t ACL_ERROR_RT_FEATURE_NOT_SUPPORT = 207000; // feature not support
static const int32_t ACL_ERROR_RT_MEMORY_ALLOCATION = 207001; // memory allocation error
static const int32_t ACL_ERROR_RT_MEMORY_FREE = 207002; // memory free error
static const int32_t ACL_ERROR_RT_AICORE_OVER_FLOW = 207003; // aicore over flow
static const int32_t ACL_ERROR_RT_NO_DEVICE = 207004; // no device
static const int32_t ACL_ERROR_RT_RESOURCE_ALLOC_FAIL = 207005; // resource alloc fail
static const int32_t ACL_ERROR_RT_NO_PERMISSION = 207006; // no permission
static const int32_t ACL_ERROR_RT_NO_EVENT_RESOURCE = 207007; // no event resource
static const int32_t ACL_ERROR_RT_NO_STREAM_RESOURCE = 207008; // no stream resource
static const int32_t ACL_ERROR_RT_NO_NOTIFY_RESOURCE = 207009; // no notify resource
static const int32_t ACL_ERROR_RT_NO_MODEL_RESOURCE = 207010; // no model resource
static const int32_t ACL_ERROR_RT_NO_CDQ_RESOURCE = 207011; // no cdq resource
static const int32_t ACL_ERROR_RT_OVER_LIMIT = 207012; // over limit
static const int32_t ACL_ERROR_RT_QUEUE_EMPTY = 207013; // queue is empty
static const int32_t ACL_ERROR_RT_QUEUE_FULL = 207014; // queue is full
static const int32_t ACL_ERROR_RT_REPEATED_INIT = 207015; // repeated init
static const int32_t ACL_ERROR_RT_AIVEC_OVER_FLOW = 207016; // aivec over flow
static const int32_t ACL_ERROR_RT_INTERNAL_ERROR = 507000; // runtime internal error
static const int32_t ACL_ERROR_RT_TS_ERROR = 507001; // ts internal error
static const int32_t ACL_ERROR_RT_STREAM_TASK_FULL = 507002; // task full in stream
static const int32_t ACL_ERROR_RT_STREAM_TASK_EMPTY = 507003; // task empty in stream
static const int32_t ACL_ERROR_RT_STREAM_NOT_COMPLETE = 507004; // stream not complete
static const int32_t ACL_ERROR_RT_END_OF_SEQUENCE = 507005; // end of sequence
static const int32_t ACL_ERROR_RT_EVENT_NOT_COMPLETE = 507006; // event not complete
static const int32_t ACL_ERROR_RT_CONTEXT_RELEASE_ERROR = 507007; // context release error
static const int32_t ACL_ERROR_RT_SOC_VERSION = 507008; // soc version error
static const int32_t ACL_ERROR_RT_TASK_TYPE_NOT_SUPPORT = 507009; // task type not support
static const int32_t ACL_ERROR_RT_LOST_HEARTBEAT = 507010; // ts lost heartbeat
static const int32_t ACL_ERROR_RT_MODEL_EXECUTE = 507011; // model execute failed
static const int32_t ACL_ERROR_RT_REPORT_TIMEOUT = 507012; // report timeout
static const int32_t ACL_ERROR_RT_SYS_DMA = 507013; // sys dma error
static const int32_t ACL_ERROR_RT_AICORE_TIMEOUT = 507014; // aicore timeout
static const int32_t ACL_ERROR_RT_AICORE_EXCEPTION = 507015; // aicore exception
static const int32_t ACL_ERROR_RT_AICORE_TRAP_EXCEPTION = 507016; // aicore trap exception
static const int32_t ACL_ERROR_RT_AICPU_TIMEOUT = 507017; // aicpu timeout
static const int32_t ACL_ERROR_RT_AICPU_EXCEPTION = 507018; // aicpu exception
static const int32_t ACL_ERROR_RT_AICPU_DATADUMP_RSP_ERR = 507019; // aicpu datadump response error
static const int32_t ACL_ERROR_RT_AICPU_MODEL_RSP_ERR = 507020; // aicpu model operate response error
static const int32_t ACL_ERROR_RT_PROFILING_ERROR = 507021; // profiling error
static const int32_t ACL_ERROR_RT_IPC_ERROR = 507022; // ipc error
static const int32_t ACL_ERROR_RT_MODEL_ABORT_NORMAL = 507023; // model abort normal
static const int32_t ACL_ERROR_RT_KERNEL_UNREGISTERING = 507024; // kernel unregistering
static const int32_t ACL_ERROR_RT_RINGBUFFER_NOT_INIT = 507025; // ringbuffer not init
static const int32_t ACL_ERROR_RT_RINGBUFFER_NO_DATA = 507026; // ringbuffer no data
static const int32_t ACL_ERROR_RT_KERNEL_LOOKUP = 507027; // kernel lookup error
static const int32_t ACL_ERROR_RT_KERNEL_DUPLICATE = 507028; // kernel register duplicate
static const int32_t ACL_ERROR_RT_DEBUG_REGISTER_FAIL = 507029; // debug register failed
static const int32_t ACL_ERROR_RT_DEBUG_UNREGISTER_FAIL = 507030; // debug unregister failed
static const int32_t ACL_ERROR_RT_LABEL_CONTEXT = 507031; // label not in current context
static const int32_t ACL_ERROR_RT_PROGRAM_USE_OUT = 507032; // program register num use out
static const int32_t ACL_ERROR_RT_DEV_SETUP_ERROR = 507033; // device setup error
static const int32_t ACL_ERROR_RT_VECTOR_CORE_TIMEOUT = 507034; // vector core timeout
static const int32_t ACL_ERROR_RT_VECTOR_CORE_EXCEPTION = 507035; // vector core exception
static const int32_t ACL_ERROR_RT_VECTOR_CORE_TRAP_EXCEPTION = 507036; // vector core trap exception
static const int32_t ACL_ERROR_RT_CDQ_BATCH_ABNORMAL = 507037; // cdq alloc batch abnormal
static const int32_t ACL_ERROR_RT_DIE_MODE_CHANGE_ERROR = 507038; // can not change die mode
static const int32_t ACL_ERROR_RT_DIE_SET_ERROR = 507039; // single die mode can not set die
static const int32_t ACL_ERROR_RT_INVALID_DIEID = 507040; // invalid die id
static const int32_t ACL_ERROR_RT_DIE_MODE_NOT_SET = 507041; // die mode not set
static const int32_t ACL_ERROR_RT_DRV_INTERNAL_ERROR = 507899; // drv internal error
static const int32_t ACL_ERROR_RT_AICPU_INTERNAL_ERROR = 507900; // aicpu internal error
static const int32_t ACL_ERROR_RT_SOCKET_CLOSE = 507901; // hdc disconnect
)");
if (acl_error_map.count(error))
return acl_error_map[error];
return "unknown " + std::to_string((int)error);
}
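
gen_map above recovers each error code and name by scanning the pasted header text for ';' and walking back across the `NAME = number;` declaration. The same parsing rule restated in Python (a sketch for clarity, not the shipped implementation):

import re

def gen_map(s):
    # map every "<ACL_NAME> = <number>[U];" declaration to number -> name
    return {int(m.group(2)): m.group(1)
            for m in re.finditer(r"(ACL_\w+)\s*=\s*(\d+)U?\s*;", s)}

codes = gen_map("static const int ACL_ERROR_INVALID_PARAM = 100000;")
assert codes[100000] == "ACL_ERROR_INVALID_PARAM"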

python/jittor/extern/acl/acl_jittor.cc (new file, vendored, 186 lines)
View File

@@ -0,0 +1,186 @@
// ***************************************************************
// Copyright (c) 2021 Jittor. All Rights Reserved.
// Maintainers: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "acl_jittor.h"
#include "utils/str_utils.h"
#include <chrono>
#include <thread>
namespace jittor {
uint64_t acl_jittor_tid;
int acl_jittor_thread_running=0;
aclrtContext acl_jittor_context;
#define CHECK_ACL(x) ASSERTop(x,==,0)
static void* acl_jittor_process_callback(void*) {
acl_jittor_thread_running = 1;
int deviceId = 0;
CHECK_ACL(aclrtSetCurrentContext(acl_jittor_context));
while (acl_jittor_thread_running) {
// LOGir << "acl_jittor_process_callback";
auto ret = aclrtProcessReport(1000);
if (ret) {
if (acl_jittor_thread_running && ret != ACL_ERROR_RT_REPORT_TIMEOUT)
LOGir << "aclrtProcessReport:" << ret << acl_error_to_string(ret);
break;
}
}
acl_jittor_thread_running = 0;
return (void*)0;
}
// void aaa(void*) {
// LOGir << "haha";
// }
struct acl_jittor_initer {
acl_jittor_initer() {
CHECK_ACL(aclInit(nullptr));
uint device_count = 0;
// get the number of available devices
CHECK_ACL(aclrtGetDeviceCount(&device_count));
LOGi << "Found ACL device number:" << device_count;
CHECK_ACL(aclrtSetDevice(0));
CHECK_ACL(aclrtCreateContext(&acl_jittor_context, 0));
CHECK_ACL(aclrtSetCurrentContext(acl_jittor_context));
pthread_create(&acl_jittor_tid, nullptr, acl_jittor_process_callback, 0);
// subscribe for default stream
CHECK_ACL(aclrtSubscribeReport(acl_jittor_tid,0));
// simple callback test
// aclrtStream stream;
// CHECK_ACL(aclrtCreateStream(&stream));
// CHECK_ACL(aclrtSubscribeReport(acl_jittor_tid,stream));
// CHECK_ACL(aclrtLaunchCallback((aclrtCallback)&aaa, 0, ACL_CALLBACK_NO_BLOCK, stream));
// CHECK_ACL(aclrtLaunchCallback((aclrtCallback)&aaa, 0, ACL_CALLBACK_NO_BLOCK, 0));
}
~acl_jittor_initer() {
acl_jittor_thread_running = 0;
CHECK_ACL(aclrtUnSubscribeReport(acl_jittor_tid,0));
CHECK_ACL(aclrtDestroyContext(acl_jittor_context));
CHECK_ACL(aclFinalize());
}
} _acl_jittor_initer;
string process_acl(const string& src, const string& name, const map<string,string>& kargs) {
auto tokens = token_split(src);
int edit = 0;
for (int i=0; i<tokens.size(); i++) {
auto& token = tokens[i];
if (token == "cuda_runtime") token = "acl_jittor", edit ++; else
if (token == "CUDA") token = "ACL", edit ++; else
if (startswith(token, "cuda")) {
if (token.size()>=5 && token[4] >= 'A' && token[4] <= 'Z') {
if (token == "cudaGetDeviceCount") {
token_replace(tokens, i, "($1);", "((uint*)$1);");
} else if (token == "cudaLaunchHostFunc") {
// ACL_CALLBACK_BLOCK for 310
token_replace(tokens, i, "LaunchHostFunc($1,$2,$3)",
"LaunchCallback($2,$3,ACL_CALLBACK_NO_BLOCK,$1)");
} else if (token == "cudaMemcpy")
token_replace(tokens, i, "cudaMemcpy($1,$2,$3,",
"aclrtMemcpy($1,$3,$2,$3,");
else if (token == "cudaMemcpyAsync")
token_replace(tokens, i, "cudaMemcpyAsync($1,$2,$3,",
"aclrtMemcpyAsync($1,$3,$2,$3,");
else if (token == "cudaMemcpyDeviceToHost") token = "ACL_MEMCPY_DEVICE_TO_HOST";
else if (token == "cudaMemcpyHostToDevice") token = "ACL_MEMCPY_HOST_TO_DEVICE";
else if (token == "cudaMemcpyDeviceToDevice") token = "ACL_MEMCPY_DEVICE_TO_DEVICE";
else if (token == "cudaMallocManaged" || token == "cudaMalloc") {
// unified address not supported
token = "aclrtMalloc";
token_replace(tokens, i, "($1,$2)",
"($1,$2,ACL_MEM_MALLOC_HUGE_FIRST)");
} else if (token == "cudaMemGetInfo")
token_replace(tokens, i, "cudaMemGetInfo($1,$2)",
"aclrtGetMemInfo(ACL_DDR_MEM,$1,$2)");
else if (token == "cudaGetLastError")
token_replace(tokens, i, "cudaGetLastError()", "0");
else if (token == "cudaStreamCreateWithFlags")
token_replace(tokens, i-1,
"(cudaStreamCreateWithFlags($1,$2));",
"(aclrtCreateStream($1)); checkAclErrors(aclrtSubscribeReport(acl_jittor_tid,*$1));");
else if (token == "cudaEventCreate")
token_replace(tokens, i,
"cudaEventCreate($1,$2)",
"aclrtCreateEvent($1)");
else if (token == "cudaDeviceSynchronize")
token = "aclrtSynchronizeDevice";
else if (token == "cudaStreamDestroy")
token_replace(tokens, i, "cudaStreamDestroy($1)",
"(aclrtUnSubscribeReport(acl_jittor_tid,$1), aclrtDestroyStream($1))");
else if (token == "cudaEventDestroy")
token = "aclrtDestroyEvent";
else if (token == "cudaEventRecord")
token = "aclrtRecordEvent";
else if (token == "cudaStreamWaitEvent")
token_replace(tokens, i,
"cudaStreamWaitEvent($1,$2,$3)",
"aclrtStreamWaitEvent($1,$2)");
if (token.size() && token[0] == 'c')
token = "aclrt" + token.substr(4);
if (endswith(token, "_t"))
token = token.substr(0, token.size()-2);
edit ++;
}
} else
if (token == "_cudaGetErrorEnum") {
token_replace(tokens, i, "_cudaGetErrorEnum($1)", "(acl_error_to_string($1))");
edit ++;
} else
if (token == "checkCudaErrors")
token = "checkAclErrors";
else if (token == "JPU") {
edit ++;
string new_code;
if (tokens[i+2] == "op_compiler")
token_replace(tokens, i,
"JPU(op_compiler($1,$2,$3))",
"acl_jittor_op_compiler($1,$2,$3)");
else if (tokens[i+2] == "header")
new_code = "#include \"acl_jittor.h\"";
if (new_code.size())
token_replace(tokens, i, "JPU($1)", new_code);
} else if (token == "use_cuda_managed_allocator" && tokens[i+1][0]==',') {
tokens[i+2] = "0"; // disable unified address
}
}
if (!edit) return src;
return join(tokens, "");
}
void acl_jittor_op_compiler(string& filename, string& src, bool is_acl) {
if (!is_acl) return;
filename = replace(filename, ".cc", ".tikcc");
// LOGir << filename;
string new_src = process_acl(src, "", {});
new_src = replace(new_src, R"(#include "misc/cuda_atomic.h")", "");
new_src = replace(new_src, R"(#include "misc/cuda_limits.h")", "");
new_src = replace(new_src, "__global__", "__ai_device_entry__");
new_src = token_replace(new_src, "__launch_bounds__($1)", "");
new_src = token_replace(new_src, "int thread_num = $1;", "int thread_num = 1;");
new_src = token_replace(new_src, "tn0=std::max(tn0, $1);", "");
new_src = token_replace(new_src, "<<<$1,$2>>>", "<<<1,0>>>");
new_src = token_replace(new_src, "int thread_id = $1;", "int thread_id = 1;");
// for inc error
new_src = token_replace(new_src, "for ($1+=$2)", "for ($1++)");
// bit op error
new_src = token_replace(new_src, "int tnum$1;", "");
new_src = token_replace(new_src, "int tid$1=$2;", "int tid$1=0;");
src = new_src;
// auto tokens = token_split(new_src);
}
}
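
The net effect of process_acl is easiest to see on a single line. A hypothetical input/output pair (variable names invented): the memcpy rule duplicates the size argument because aclrtMemcpy takes a destination capacity as well as a copy count, and the error-check macro and memcpy-kind enum are renamed alongside:

before = "checkCudaErrors(cudaMemcpy(dst, src, n, cudaMemcpyDeviceToHost));"
after  = "checkAclErrors(aclrtMemcpy(dst, n, src, n, ACL_MEMCPY_DEVICE_TO_HOST));"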

python/jittor/extern/acl/acl_jittor.h (new file, vendored, 19 lines)
View File

@@ -0,0 +1,19 @@
// ***************************************************************
// Copyright (c) 2021 Jittor. All Rights Reserved.
// Maintainers: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "common.h"
#include <acl/acl.h>
std::string acl_error_to_string(aclError error);
namespace jittor {
EXTERN_LIB uint64_t acl_jittor_tid;
void acl_jittor_op_compiler(string& filename, string& src, bool is_acl);
}

View File

@@ -90,7 +90,7 @@ void CuttTransposeOp::jit_run() {
for (int i=0; i<dim; i++)
x_shape[i] = new_shape[dim-1-i];
if (dim == 1 || x->num==1) {
checkCudaErrors(cudaMemcpyAsync(yp, xp, x->size, cudaMemcpyDefault, 0));
checkCudaErrors(cudaMemcpyAsync(yp, xp, x->size, cudaMemcpyDeviceToDevice, 0));
return;
}
JK& jk = get_jk();

View File

@@ -25,7 +25,9 @@
#include <stdlib.h>
#include <string.h>
#ifdef IS_CUDA
#include <helper_string.h>
#endif
#ifndef EXIT_WAIVED
#define EXIT_WAIVED 2
@@ -129,6 +131,9 @@ void check(T result, char const *const func, const char *const file,
}
}
#define checkCudaErrors(val) check((val), #val, __FILE__, __LINE__)
#define peekCudaErrors(val) peek((val), #val, __FILE__, __LINE__)
#ifdef __DRIVER_TYPES_H__
// This will output the proper CUDA error strings in the event
// that a CUDA host call returns an error

View File

@@ -43,7 +43,7 @@ void cleanup() {
}
static void init_cuda_devices() {
#ifdef HAS_CUDA
#ifdef IS_CUDA
if (cuda_archs.size()) return;
int count=0;
cudaGetDeviceCount(&count);

View File

@@ -20,6 +20,7 @@
#include "utils/flags.h"
#include "fused_op.h"
#include "utils/str_utils.h"
JPU(header)
namespace jittor {
@@ -204,6 +205,8 @@ jit_op_entry_t compile(const string& jit_key, const string& src, const bool is_c
// compiler do not allowed filename too long
CHECK(cc_path.size());
string jit_src_path = Op::get_filename_from_jit_key(jit_key, ".cc");
string* src2 = (string*)&src;
JPU(op_compiler(jit_src_path, *src2, is_cuda_op));
#ifdef _WIN32
string jit_lib_path = Op::get_filename_from_jit_key(jit_key, ".dll");
string jit_src_path2 = _to_winstr(jit_src_path);

View File

@@ -107,7 +107,7 @@ void migrate_to_cpu(Var* var, Allocator* allocator) {
if (!use_cuda_managed_allocator) {
// must be a device allocator
Allocation a(allocator, var->size);
checkCudaErrors(cudaMemcpy(a.ptr, var->mem_ptr, var->size, cudaMemcpyDefault));
checkCudaErrors(cudaMemcpy(a.ptr, var->mem_ptr, var->size, cudaMemcpyDeviceToHost));
var->allocator->free(var->mem_ptr, var->size, var->allocation);
var->mem_ptr = a.ptr;
var->allocation = a.allocation;

View File

@@ -62,7 +62,7 @@ void display_memory_info(const char* fileline, bool dump_var, bool red_color) {
log << "\n=== display_memory_info ===\n";
log << "total_cpu_ram:" <<
FloatOutput{(double)mem_info.total_cpu_ram, " KMG", 1024, "B"};
log << "total_cuda_ram:" <<
log << "total_device_ram:" <<
FloatOutput{(double)mem_info.total_cuda_ram, " KMG", 1024, "B"} >> "\n";
log << "hold_vars:" << hold_vars.size()
<< "lived_vars:" << Var::number_of_lived_vars
@@ -105,7 +105,7 @@ void display_memory_info(const char* fileline, bool dump_var, bool red_color) {
auto total = a->used_memory + a->unused_memory;
all_total += total;
a->is_cuda() ? gpu_total += total : cpu_total += total;
log << "name:" << a->name() << "is_cuda:" << a->is_cuda()
log << "name:" << a->name() << "is_device:" << a->is_cuda()
<< "used:" << FloatOutput{(double)a->used_memory, " KMG", 1024, "B"}
>> "(" >> std::setprecision(p) >> a->used_memory*100.0 / total >> "%)"
<< "unused:" << FloatOutput{(double)a->unused_memory, " KMG", 1024, "B"}
@@ -117,7 +117,7 @@ void display_memory_info(const char* fileline, bool dump_var, bool red_color) {
auto total = a->used_memory + a->unused_memory;
all_total += total;
a->is_cuda() ? gpu_total += total : cpu_total += total;
log << "name:" << a->name() << "is_cuda:" << a->is_cuda()
log << "name:" << a->name() << "is_device:" << a->is_cuda()
<< "used:" << FloatOutput{(double)a->used_memory, " KMG", 1024, "B"}
>> "(" >> std::setprecision(p) >> a->used_memory*100.0 / total >> "%)"
<< "unused:" << FloatOutput{(double)a->unused_memory, " KMG", 1024, "B"}
@@ -227,9 +227,9 @@ MemInfo::MemInfo() {
total_cuda_ram = 0;
#ifdef HAS_CUDA
cudaDeviceProp prop;
cudaGetDeviceProperties(&prop, 0);
total_cuda_ram = prop.totalGlobalMem;
size_t gpu_free = 0, _gpu_total = 0;
cudaMemGetInfo(&gpu_free, &_gpu_total);
total_cuda_ram = _gpu_total;
#endif
sigquit_callback.push_back(&meminfo_callback);
}

View File

@@ -24,7 +24,7 @@ inline int get_device_count() {
} // jittor
#if CUDART_VERSION < 10000
#if defined(CUDART_VERSION) && CUDART_VERSION < 10000
#define _cudaLaunchHostFunc(a,b,c) \
cudaStreamAddCallback(a,b,c,0)
#define CUDA_HOST_FUNC_ARGS cudaStream_t stream, cudaError_t status, void*

View File

@@ -17,7 +17,7 @@
namespace jittor {
#ifdef HAS_CUDA
#ifdef IS_CUDA
EXTERN_LIB void check_nan_float32(float32* ptr, int64 num);
EXTERN_LIB void check_nan_float64(float64* ptr, int64 num);
#endif
@@ -28,7 +28,7 @@ bool check_nan(Var* v) {
v->input()->name() == string("empty") ||
v->input()->name() == string("setitem")))
return true;
#ifdef HAS_CUDA
#ifdef IS_CUDA
if (v->allocator->is_cuda()) {
if (v->dtype() == ns_float32) {
check_nan_float32((float32*)v->mem_ptr, v->num);

View File

@@ -16,12 +16,12 @@
namespace jittor {
#if defined(__clang__)
#if __cplusplus < 201400L || defined(IS_ACL)
using string_view = string;
#elif defined(__clang__)
using std::string_view;
#elif defined(__GNUC__)
using std::experimental::string_view;
#elif __cplusplus < 201400L
using string_view = string;
#else
using std::string_view;
#endif

View File

@@ -91,7 +91,7 @@ void CandidateOp::jit_run() {
int n=0;
// checkCudaErrors(cudaDeviceSynchronize());
checkCudaErrors(cudaMemcpy(&n, np, 4, cudaMemcpyDefault));
checkCudaErrors(cudaMemcpy(&n, np, 4, cudaMemcpyDeviceToHost));
y->set_shape({n});
exe.temp_allocator->free(np, 4, n_allocation);
exe.temp_allocator->free(maskp, xshape0, mask_allocation);

View File

@@ -39,8 +39,13 @@ void CopyOp::run() {
auto x_ptr = x->mem_ptr;
auto y_ptr = outputs().front()->mem_ptr;
#ifdef HAS_CUDA
if (flags.get(NodeFlags::_cuda)) {
if (flags.get(NodeFlags::_cuda)) {
// TODO: check why cpu allocator in x
#ifdef IS_CUDA
checkCudaErrors(cudaMemcpyAsync(y_ptr, x_ptr, size, cudaMemcpyDefault, 0));
#else
checkCudaErrors(cudaMemcpyAsync(y_ptr, x_ptr, size, cudaMemcpyDeviceToDevice, 0));
#endif
} else
#endif
{

View File

@@ -121,13 +121,18 @@ void FetchOp::run() {
checkCudaErrors(cudaStreamWaitEvent(stream, event, 0));
new (&allocation) Allocation(&cuda_dual_allocator, v->size);
// mostly device to device
#if IS_CUDA
checkCudaErrors(cudaMemcpyAsync(
allocation.ptr, v->mem_ptr, v->size, cudaMemcpyDefault, stream));
#else
checkCudaErrors(cudaMemcpyAsync(
allocation.ptr, v->mem_ptr, v->size, cudaMemcpyDeviceToDevice, stream));
#endif
auto host_ptr = cuda_dual_allocator.get_dual_allocation(
allocation.allocation).host_ptr;
// device to host
checkCudaErrors(cudaMemcpyAsync(
host_ptr, allocation.ptr, v->size, cudaMemcpyDefault, stream));
host_ptr, allocation.ptr, v->size, cudaMemcpyDeviceToHost, stream));
allocation.ptr = host_ptr;
has_cuda_memcpy = true;
} else

View File

@@ -312,7 +312,7 @@ void SetitemOp::jit_run() {
std::memcpy(op, ip, out->size);
#else
if (op != ip)
checkCudaErrors(cudaMemcpyAsync(op, ip, out->size, cudaMemcpyDefault, 0));
checkCudaErrors(cudaMemcpyAsync(op, ip, out->size, cudaMemcpyDeviceToDevice, 0));
#endif
if (flags.get((NodeFlags::Flags(SetitemOp::_data_inplaced))) &&

View File

@@ -230,7 +230,7 @@ void WhereOp::jit_run() {
int n=0;
// checkCudaErrors(cudaDeviceSynchronize());
checkCudaErrors(cudaMemcpy(&n, np, 4, cudaMemcpyDefault));
checkCudaErrors(cudaMemcpy(&n, np, 4, cudaMemcpyDeviceToHost));
@for(i, 0, NDIM, outs[@i]->set_shape({n});)
exe.temp_allocator->free(np, 4, n_allocation);
}

View File

@@ -155,7 +155,7 @@ static void stat_peek_bandwidth(uint64 in, uint64 out, uint64 loop, uint64& peek
for (int i=0; i<warmup; i++)
#ifdef HAS_CUDA
if (use_cuda)
cudaMemcpyAsync(temp1.ptr, temp2.ptr, size, cudaMemcpyDefault, 0);
cudaMemcpyAsync(temp1.ptr, temp2.ptr, size, cudaMemcpyDeviceToDevice, 0);
else
#endif
std::memcpy(temp1.ptr, temp2.ptr, size);
@@ -167,7 +167,7 @@ static void stat_peek_bandwidth(uint64 in, uint64 out, uint64 loop, uint64& peek
for (int i=0; i<loop; i++)
#ifdef HAS_CUDA
if (use_cuda)
cudaMemcpyAsync(temp1.ptr, temp2.ptr, size, cudaMemcpyDefault, 0);
cudaMemcpyAsync(temp1.ptr, temp2.ptr, size, cudaMemcpyDeviceToDevice, 0);
else
#endif
std::memcpy(temp1.ptr, temp2.ptr, size);

View File

@@ -15,7 +15,7 @@
#include "misc/nano_string.h"
#include "misc/fast_shared_ptr.h"
#include "profiler/simple_profiler.h"
#ifdef HAS_CUDA
#ifdef IS_CUDA
#include "misc/cuda_flags.h"
#endif
@@ -652,7 +652,7 @@ DEF_IS(NumpyFunc, T) from_py_object(PyObject* obj) {
[obj](typename T::R* result) {
// import numpy
string npstr="numpy";
#ifdef HAS_CUDA
#ifdef IS_CUDA
if (use_cuda) npstr="cupy";
#endif
@@ -669,7 +669,7 @@ DEF_IS(NumpyFunc, T) from_py_object(PyObject* obj) {
PyTuple_SET_ITEM(args.obj, 0, np.release());
PyTuple_SET_ITEM(args.obj, 1, data.release());
#ifdef HAS_CUDA
#ifdef IS_CUDA
if (npstr=="cupy") {
PyObjHolder jt(PyImport_ImportModule("jittor"));
PyObjHolder pFunc(PyObject_GetAttrString(jt.obj,"numpy2cupy"));

View File

@@ -239,4 +239,6 @@ std::ostream& operator<<(std::ostream& os, const Caster<T,To>& input) {
return os << ']';
}
#define JPU(x) ;
} // jittor

View File

@@ -76,4 +76,122 @@ string replace(const string& a, const string& b, const string& c) {
return join(vs, c);
}
static inline bool isvar(char x) { return isalnum(x) || x == '_' || x == ':'; }
vector<string> token_split(const string& s) {
vector<string> ss;
if (!s.size()) return ss;
ss.push_back(string()+s[0]);
for (int i=1; i<s.size(); i++) {
if (isvar(s[i]) != isvar(s[i-1]))
ss.push_back("");
ss.back() += s[i];
}
return ss;
}
static void parse_reg(const string& src,
vector<string>& patterns,
vector<int>& arg_id) {
patterns.clear();
arg_id.clear();
patterns.push_back("");
for (int j=0; j<src.size(); j++) {
if (src[j] == '$') {
j++;
arg_id.push_back(src[j]-'0');
patterns.push_back("");
continue;
}
patterns.back() += src[j];
}
}
void token_replace(vector<string>& tokens, int i, const string& src, const string& dst) {
ASSERT(src.at(0) != '$' && src.at(src.size()-1) != '$' &&
src.at(src.size()-2) != '$') << "illegal src:" << src;
vector<string> patterns;
vector<int> arg_id;
vector<string> patterns2;
vector<int> arg_id2;
unordered_map<int, string> args;
parse_reg(src, patterns, arg_id);
parse_reg(dst, patterns2, arg_id2);
int start_i, start_pos, end_i, end_pos;
int c_i = i, c_pos = 0;
int match_i, match_pos;
string c_arg;
auto match = [&](int c_i, int c_pos, const string& pat) -> bool {
for (int i=0; i<pat.size(); i++) {
if (tokens[c_i][c_pos] != pat[i])
return false;
c_pos ++;
if (c_pos >= tokens[c_i].size()) {
c_pos = 0;
c_i ++;
if (c_i >= tokens.size())
return false;
}
}
match_i = c_i;
match_pos = c_pos;
return true;
};
for (int j=0; j<patterns.size(); j++) {
int ok = 0;
while (c_i < tokens.size()) {
while (c_pos < tokens[c_i].size()) {
if (match(c_i, c_pos, patterns[j])) {
ok = 1;
break;
}
c_arg += tokens[c_i][c_pos];
c_pos ++;
}
if (ok) break;
c_i ++;
c_pos = 0;
}
CHECK(ok) << "Pattern not match:" << patterns[j] << j;
if (j == 0) {
start_i = c_i;
start_pos = c_pos;
}
if (j) {
args[arg_id[j-1]] = c_arg;
}
c_arg = "";
c_i = match_i;
c_pos = match_pos;
if (j == patterns.size()-1) {
end_i = c_i;
end_pos = c_pos;
}
}
string new_src;
for (int j=0; j<patterns2.size(); j++) {
if (j) new_src += args[arg_id2.at(j-1)];
new_src += patterns2[j];
}
if (start_i == end_i) {
tokens[start_i] = tokens[start_i].substr(0, start_pos) +
new_src + tokens[start_i].substr(end_pos);
} else {
tokens[start_i] = tokens[start_i].substr(0, start_pos)
+ new_src;
tokens[end_i] = tokens[end_i].substr(end_pos);
for (int j=start_i+1; j<end_i; j++)
tokens[j] = "";
}
}
string token_replace(const string& s, const string& src, const string& dst) {
vector<string> ss{s};
token_replace(ss, 0, src, dst);
return join(ss, "");
}
} // jittor
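
token_replace implements a tiny `$n` pattern language: the literal fragments of src must occur in order, the shortest text between consecutive fragments is captured as $1..$9, and dst re-emits captures in any order, possibly repeated (as the aclrtMemcpy rewrite does with $3). A Python sketch of the same semantics, assuming lazy regex matching reproduces the C++ scanner's shortest-gap behavior:

import re

def token_replace(s, src, dst):
    parts = re.split(r"\$(\d)", src)    # literal, digit, literal, digit, ...
    pattern = "".join(re.escape(p) if i % 2 == 0 else f"(?P<g{p}>.*?)"
                      for i, p in enumerate(parts))
    parts2 = re.split(r"\$(\d)", dst)
    repl = "".join(p.replace("\\", "\\\\") if i % 2 == 0 else f"\\g<g{p}>"
                   for i, p in enumerate(parts2))
    return re.sub(pattern, repl, s, count=1, flags=re.S)

out = token_replace("cudaMemcpy(a,b,n,k);",
                    "cudaMemcpy($1,$2,$3,", "aclrtMemcpy($1,$3,$2,$3,")
assert out == "aclrtMemcpy(a,n,b,n,k);"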

View File

@@ -33,4 +33,10 @@ string replace(const string& a, const string& b, const string& c);
string join(const vector<string>& vs, const string& x);
vector<string> token_split(const string& s);
void token_replace(vector<string>& tokens, int i, const string& src, const string& dst);
string token_replace(const string& s, const string& src, const string& dst);
} // jittor

View File

@@ -0,0 +1,31 @@
# ***************************************************************
# Copyright (c) 2021 Jittor. All Rights Reserved.
# Maintainers: Dun Liang <randonlang@gmail.com>.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import jittor as jt
from .test_core import expect_error
import numpy as np
@unittest.skipIf(not jt.compiler.has_acl, "No ACL found")
class TestACL(unittest.TestCase):
@jt.flag_scope(use_acl=1)
def test_array(self):
print("use_acl", jt.flags.use_acl)
a = jt.array([1,2,3])
np.testing.assert_allclose(a.numpy(), [1,2,3])
@jt.flag_scope(use_acl=1)
def test_add(self):
a = jt.array([1,2,3])
b = a+a
np.testing.assert_allclose(b.numpy(), [2,4,6])
def test_meminfo(self):
jt.display_memory_info()
if __name__ == "__main__":
unittest.main()

View File

@@ -19,6 +19,7 @@ import time
from ctypes import cdll
import shutil
import urllib.request
import ctypes
if platform.system() == 'Darwin':
mp.set_start_method('fork')
@@ -386,6 +387,22 @@ def env_or_find(name, bname, silent=False):
return path
return find_exe(bname, silent=silent)
def env_or_try_find(name, bname):
if name in os.environ:
path = os.environ[name]
if path != "":
version = get_version(path)
LOG.i(f"Found {bname}{version} at {path}")
return path
return try_find_exe(bname)
def try_find_exe(*args):
try:
return find_exe(*args)
except:
LOG.v(f"{args[0]} not found.")
return ""
def get_cc_type(cc_path):
bname = os.path.basename(cc_path)
if "clang" in bname: return "clang"
@@ -526,3 +543,65 @@ if os.name == 'nt':
os.environ["PATH"] = path+';'+os.environ["PATH"]
if hasattr(os, "add_dll_directory"):
os.add_dll_directory(path)
backends = []
def add_backend(mod):
backends.append(mod)
def compile_module(source, flags):
tmp_path = os.path.join(cache_path, "tmp")
os.makedirs(tmp_path, exist_ok=True)
hash = "hash_" + get_str_hash(source)
so = get_py3_extension_suffix()
header_name = os.path.join(tmp_path, hash+".h")
source_name = os.path.join(tmp_path, hash+".cc")
lib_name = hash+so
with open(header_name, "w", encoding="utf8") as f:
f.write(source)
from jittor.pyjt_compiler import compile_single
ok = compile_single(header_name, source_name)
assert ok, "no pyjt interface found"
entry_src = f'''
static void init_module(PyModuleDef* mdef, PyObject* m) {{
mdef->m_doc = "generated py jittor_utils.compile_module";
jittor::pyjt_def_{hash}(m);
}}
PYJT_MODULE_INIT({hash});
'''
with open(source_name, "r", encoding="utf8") as f:
src = f.read()
with open(source_name, "w", encoding="utf8") as f:
f.write(src + entry_src)
jittor_path = os.path.join(os.path.dirname(__file__), "..", "jittor")
jittor_path = os.path.abspath(jittor_path)
do_compile([f"\"{cc_path}\" \"{source_name}\" \"{jittor_path}/src/pyjt/py_arg_printer.cc\" {flags} -o \"{cache_path+'/'+lib_name}\" ",
cache_path, jittor_path])
with import_scope(os.RTLD_GLOBAL | os.RTLD_NOW):
exec(f"import {hash}")
mod = locals()[hash]
return mod
def process_jittor_source(device_type, callback):
import jittor.compiler as compiler
import shutil
djittor = device_type + "_jittor"
djittor_path = os.path.join(compiler.cache_path, djittor)
os.makedirs(djittor_path, exist_ok=True)
for root, dir, files in os.walk(compiler.jittor_path):
root2 = root.replace(compiler.jittor_path, djittor_path)
os.makedirs(root2, exist_ok=True)
for name in files:
fname = os.path.join(root, name)
fname2 = os.path.join(root2, name)
if fname.endswith(".h") or fname.endswith(".cc"):
with open(fname, 'r', encoding="utf8") as f:
src = f.read()
src = callback(src, name, {"fname":fname, "fname2":fname2})
with open(fname2, 'w', encoding="utf8") as f:
f.write(src)
else:
shutil.copy(fname, fname2)
compiler.cc_flags = compiler.cc_flags.replace(compiler.jittor_path, djittor_path) + f" -I\"{djittor_path}/extern/cuda/inc\" "
compiler.jittor_path = djittor_path
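
compile_module above is what acl_compiler.install() calls: given a C++ snippet with // @pyjt(...) annotations plus compiler flags, it runs pyjt_compiler.compile_single to generate the binding, builds a shared object in the cache directory, and imports it; process_jittor_source then rewrites every .h/.cc of the source tree through the backend's callback into a device-specific copy. A minimal sketch of a compile_module call (the hello function is invented for illustration):

import jittor.compiler as compiler
import jittor_utils

# the // @pyjt(hello) annotation drives the binding generation
mod = jittor_utils.compile_module('''
#include "common.h"
namespace jittor {
// @pyjt(hello)
string hello() { return "hello from C++"; }
}''', compiler.cc_flags)
print(mod.hello())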