mirror of https://github.com/Jittor/Jittor
polish backend
This commit is contained in:
parent
335a2e5c1d
commit
003cdf6b16
|
@ -1889,3 +1889,9 @@ from . import sparse
|
|||
from . import optim
|
||||
from . import dataset
|
||||
from . import init
|
||||
|
||||
import jittor_utils
|
||||
|
||||
# Invoke the optional ``post_process`` hook of every entry in
# ``jittor_utils.backends``; entries that do not define one are skipped.
for backend in jittor_utils.backends:
    if not hasattr(backend, "post_process"):
        continue
    backend.post_process()
|
||||
|
|
|
@ -263,6 +263,8 @@ def setup_cuda_lib(lib_name, link=True, extra_flags=""):
|
|||
cuda_include_name = search_file([cuda_include, extra_include_path, "/usr/include"], lib_name+".h")
|
||||
# CUDA 11 prefers cuDNN 8
|
||||
nvcc_version = get_int_version(nvcc_path)
|
||||
if has_corex:
|
||||
nvcc_version = (10,2,89)
|
||||
prefer_version = ()
|
||||
if nvcc_version[0] == 11:
|
||||
prefer_version = ("8",)
|
||||
|
|
|
@ -234,7 +234,7 @@ def gen_jit_flags():
|
|||
jit_declares.append(f"DECLARE_FLAG({type}, {name});")
|
||||
alias = []
|
||||
if name == "use_cuda":
|
||||
alias = ["use_device", "use_acl", "use_rocm"]
|
||||
alias = ["use_device", "use_acl", "use_rocm", "use_corex"]
|
||||
elif name == "auto_mixed_precision_level":
|
||||
alias = ["amp_level"]
|
||||
get_names = ",".join(["__get__"+a for a in [name]+alias])
|
||||
|
@ -867,7 +867,7 @@ def check_cuda():
|
|||
cc_flags += f" -lcudart -L\"{cuda_lib}\" "
|
||||
# ctypes.CDLL(cuda_lib+"/libcudart.so", import_flags)
|
||||
ctypes.CDLL(cuda_lib+"/libcudart.so", dlopen_flags)
|
||||
has_cuda = 1
|
||||
is_cuda = has_cuda = 1
|
||||
|
||||
def check_cache_compile():
|
||||
files = [
|
||||
|
@ -1181,7 +1181,7 @@ check_cache_compile()
|
|||
LOG.v(f"Get cache_compile: {jit_utils.cc}")
|
||||
|
||||
# check cuda
|
||||
has_cuda = 0
|
||||
is_cuda = has_cuda = 0
|
||||
check_cuda()
|
||||
nvcc_flags = os.environ.get("nvcc_flags", "")
|
||||
if has_cuda:
|
||||
|
@ -1233,6 +1233,8 @@ from .extern.acl import acl_compiler
|
|||
jit_utils.add_backend(acl_compiler)
|
||||
from .extern.rocm import rocm_compiler
|
||||
jit_utils.add_backend(rocm_compiler)
|
||||
from .extern.corex import corex_compiler
|
||||
jit_utils.add_backend(corex_compiler)
|
||||
|
||||
for mod in jit_utils.backends:
|
||||
mod.check()
|
||||
|
@ -1346,7 +1348,7 @@ with jit_utils.import_scope(import_flags):
|
|||
|
||||
flags = core.Flags()
|
||||
|
||||
if has_cuda:
|
||||
if has_cuda and is_cuda:
|
||||
nvcc_flags = " " + os.environ.get("nvcc_flags", "") + " "
|
||||
nvcc_flags += convert_nvcc_flags(cc_flags)
|
||||
nvcc_version = list(jit_utils.get_int_version(nvcc_path))
|
||||
|
|
|
@ -9,11 +9,13 @@ from jittor_utils import env_or_try_find
|
|||
import jittor_utils
|
||||
import ctypes
|
||||
import glob
|
||||
import jittor.compiler as compiler
|
||||
|
||||
has_acl = 0
|
||||
cc_flags = ""
|
||||
tikcc_path = env_or_try_find('tikcc_path', 'tikcc')
|
||||
dlopen_flags = os.RTLD_NOW | os.RTLD_GLOBAL
|
||||
compiler.has_acl = has_acl
|
||||
|
||||
def install():
|
||||
import jittor.compiler as compiler
|
||||
|
@ -64,3 +66,7 @@ def check():
|
|||
compiler.nvcc_flags = compiler.cc_flags.replace("-std=c++14","")
|
||||
return True
|
||||
|
||||
def post_process():
    """Set ``jittor.pool.pool_use_code_op`` to False when ACL is enabled.

    Does nothing (and does not import ``jittor.pool``) while the module-level
    ``has_acl`` flag is falsy.
    """
    if not has_acl:
        return
    from jittor import pool
    pool.pool_use_code_op = False
|
|
@ -0,0 +1,85 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2022 Jittor. All Rights Reserved.
|
||||
# Maintainers: Dun Liang <randonlang@gmail.com>.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import os
|
||||
from jittor_utils import env_or_try_find
|
||||
import jittor_utils
|
||||
import ctypes
|
||||
import glob
|
||||
import jittor.compiler as compiler
|
||||
|
||||
# Module state for the COREX backend. ``has_corex`` starts at 0 and is
# mirrored onto ``jittor.compiler``; ``install()`` sets both to 1.
cc_flags = ""
has_corex = compiler.has_corex = 0
|
||||
|
||||
def install():
    """Build the COREX source preprocessor and retarget the JIT toolchain.

    Steps, in order:

    1. Collect every ``*.cc`` file below this module's directory.
    2. Pass an embedded C++ snippet to ``jittor_utils.compile_module`` and
       hand its ``process`` function to
       ``jittor_utils.process_jittor_source`` under the name ``"corex"``.
       The snippet rewrites Jittor sources as text: it redirects the
       ``helper_cuda.h`` include, aliases ``string_view`` to ``string`` in
       ``string_view_map.h``, replaces ``__trap()`` with ``assert(0)`` in
       ``nan_checker.cu``, and in ``jit_compiler.cc`` removes the
       ``cmd = python_path...;`` statements and relaxes the
       ``is_cuda_op`` condition.
    3. Mark the backend as available and rewrite ``jittor.compiler``
       settings so that ``/usr/local/corex/bin/clang++`` is used as both the
       host compiler and the "nvcc" compiler, without ``-fopenmp``.

    Side effects: mutates module globals, attributes of ``jittor.compiler``
    and the ``use_cutt`` environment variable. Any exception raised by the
    helpers propagates to the caller.
    """
    import jittor.compiler as compiler
    global has_corex, cc_flags

    backend_dir = os.path.dirname(__file__)
    extra_sources = sorted(glob.glob(backend_dir+"/**/*.cc", recursive=True))
    jittor_utils.LOG.i("COREX detected")

    # C++ source of the text preprocessor; exported to Python as ``process``
    # through the ``@pyjt(process)`` annotation.
    preprocessor_src = '''
#include "common.h"
#include "utils/str_utils.h"

namespace jittor {
// @pyjt(process)
string process_acl(const string& src, const string& name, const map<string,string>& kargs) {
    auto new_src = src;
    new_src = replace(new_src, "helper_cuda.h", "../inc/helper_cuda.h");
    if (name == "string_view_map.h")
        new_src = replace(new_src, "using std::string_view;", "using string_view = string;");
    if (name == "nan_checker.cu")
        new_src = replace(new_src, "__trap()", "assert(0)");
    if (name == "jit_compiler.cc") {
        // remove asm tuner
        new_src = token_replace_all(new_src, "cmd = python_path$1;", "");
        new_src = token_replace_all(new_src,
            "if (is_cuda_op && $1 != string::npos)",
            "if (is_cuda_op)");
    }
    return new_src;
}
}'''
    build_flags = compiler.cc_flags + " " + " ".join(extra_sources) + cc_flags
    preprocessor = jittor_utils.compile_module(preprocessor_src, build_flags)
    jittor_utils.process_jittor_source("corex", preprocessor.process)

    # The preprocessor is in place: flag the backend as usable.
    has_corex = 1
    compiler.has_corex = has_corex

    # COREX ships a clang++ that serves as both host and device compiler.
    corex_home = "/usr/local/corex"
    clang_path = corex_home + "/bin/clang++"
    compiler.nvcc_path = clang_path
    compiler.cc_path = compiler.nvcc_path

    # Device flags are the host flags (minus OpenMP) plus CUDA-mode options.
    host_flags = compiler.cc_flags.replace("-fopenmp", "")
    compiler.cc_flags = host_flags
    compiler.nvcc_flags = compiler.cc_flags + " -x cu -Ofast -DNO_ATOMIC64 "
    # Flags are already in clang form, so the nvcc conversion is an identity.
    compiler.convert_nvcc_flags = lambda x: x
    compiler.is_cuda = 0
    os.environ["use_cutt"] = "0"
    compiler.cc_type = "clang"
|
||||
|
||||
|
||||
def install_extern():
    """Return False unconditionally; this backend performs no work here."""
    return False
|
||||
|
||||
|
||||
def check():
    """Try to enable the COREX backend and report whether it is active.

    ``install()`` is attempted only when the directory ``/usr/local/corex``
    exists. Any exception it raises is logged as a warning and
    ``has_corex`` is reset to 0 instead of propagating.

    Returns:
        True if ``has_corex`` is truthy afterwards, otherwise False.
    """
    global has_corex, cc_flags
    corex_present = os.path.isdir("/usr/local/corex")
    if corex_present:
        try:
            install()
        except Exception as err:
            jittor_utils.LOG.w(f"load COREX failed, exception: {err}")
            has_corex = 0
    return True if has_corex else False
|
||||
|
||||
def post_process():
    """Strip ``-fopenmp`` from ``jittor.compiler.flags.cc_flags`` for COREX.

    Does nothing while the module-level ``has_corex`` flag is falsy.
    """
    if has_corex:
        import jittor.compiler as compiler
        jit_flags = compiler.flags
        jit_flags.cc_flags = jit_flags.cc_flags.replace("-fopenmp", "")
|
|
@ -11,6 +11,7 @@ import tarfile
|
|||
|
||||
import jittor_utils
|
||||
from jittor_utils import env_or_try_find, run_cmd, cache_path, LOG
|
||||
import jittor.compiler as compiler
|
||||
|
||||
|
||||
has_rocm = 0
|
||||
|
@ -18,6 +19,7 @@ cc_flags = ""
|
|||
hipcc_path = env_or_try_find('hipcc_path', 'hipcc')
|
||||
rocm_home = ""
|
||||
dlopen_flags = os.RTLD_NOW | os.RTLD_GLOBAL
|
||||
compiler.has_rocm = has_rocm
|
||||
|
||||
|
||||
def check_gcc_use_cxx11_abi():
|
||||
|
|
|
@ -15,6 +15,8 @@ from jittor import init, Module
|
|||
import numpy as np
|
||||
import math
|
||||
|
||||
pool_use_code_op = True
|
||||
|
||||
class Pool(Module):
|
||||
def __init__(self, kernel_size, stride=None, padding=0, dilation=None, return_indices=None, ceil_mode=False, count_include_pad=True, op="maximum"):
|
||||
assert dilation == None
|
||||
|
@ -40,7 +42,7 @@ class Pool(Module):
|
|||
w = (W+self.padding[1]*2-self.kernel_size[1] + self.stride[1] - 1)//self.stride[1]+1
|
||||
use_code_op = self.op in ['maximum', 'minimum', 'mean']
|
||||
|
||||
if use_code_op:
|
||||
if use_code_op and pool_use_code_op:
|
||||
if self.op == 'mean':
|
||||
if self.count_include_pad:
|
||||
count = f"int count = {self.kernel_size[0]*self.kernel_size[1]};"
|
||||
|
@ -218,7 +220,7 @@ class Pool3d(Module):
|
|||
w = (W+self.padding[2]*2-self.kernel_size[2] + self.stride[2] - 1)//self.stride[2]+1
|
||||
use_code_op = self.op in ['maximum', 'minimum', 'mean']
|
||||
|
||||
if use_code_op:
|
||||
if use_code_op and pool_use_code_op:
|
||||
if self.op == 'mean':
|
||||
if self.count_include_pad:
|
||||
count = f"int count = {self.kernel_size[0]*self.kernel_size[1]*self.kernel_size[2]};"
|
||||
|
|
|
@ -212,7 +212,7 @@ jit_op_entry_t compile(const string& jit_key, const string& src, const bool is_c
|
|||
// the compiler does not allow file names that are too long
|
||||
CHECK(cc_path.size());
|
||||
string jit_src_path;
|
||||
if(is_cuda_op && extra_flags.find("-dc") != string::npos)
|
||||
if (is_cuda_op && extra_flags.find("-dc") != string::npos)
|
||||
jit_src_path = Op::get_filename_from_jit_key(jit_key, ".cu");
|
||||
else
|
||||
jit_src_path = Op::get_filename_from_jit_key(jit_key, ".cc");
|
||||
|
|
|
@ -55,10 +55,12 @@ inline float cuda_atomic_max(float* a, float b) {
|
|||
return orderedIntToFloat(atomicMax((int *)a, floatToOrderedInt(b)));
|
||||
}
|
||||
|
||||
#ifndef NO_ATOMIC64
|
||||
// double specialisation of cuda_atomic_max: the address is reinterpreted as
// long long*, atomicMax is applied with floatToOrderedInt(b), and the value
// atomicMax hands back is mapped through orderedIntToFloat.
template<> __device__
inline double cuda_atomic_max(double* a, double b) {
    auto* slot = (long long *)a;
    auto previous = atomicMax(slot, floatToOrderedInt(b));
    return orderedIntToFloat(previous);
}
|
||||
#endif
|
||||
|
||||
template<class T> __device__
|
||||
T cuda_atomic_min(T* a, T b) {
|
||||
|
@ -70,10 +72,12 @@ inline float cuda_atomic_min(float* a, float b) {
|
|||
return orderedIntToFloat(atomicMin((int *)a, floatToOrderedInt(b)));
|
||||
}
|
||||
|
||||
#ifndef NO_ATOMIC64
|
||||
// double specialisation of cuda_atomic_min: the address is reinterpreted as
// long long*, atomicMin is applied with floatToOrderedInt(b), and the value
// atomicMin hands back is mapped through orderedIntToFloat.
template<> __device__
inline double cuda_atomic_min(double* a, double b) {
    auto* slot = (long long *)a;
    auto previous = atomicMin(slot, floatToOrderedInt(b));
    return orderedIntToFloat(previous);
}
|
||||
#endif
|
||||
|
||||
template <class T> struct int_mapper {
|
||||
typedef T src;
|
||||
|
|
|
@ -374,7 +374,10 @@ DEF_IS(ArrayArgs, T) from_py_object(PyObject* obj) {
|
|||
// VarHolder
|
||||
struct VarHolder;
|
||||
EXTERN_LIB PyHeapTypeObject PyjtVarHolder;
|
||||
namespace jit_op_maker { EXTERN_LIB VarHolder* array_(ArrayArgs&& args); }
|
||||
namespace jit_op_maker {
|
||||
EXTERN_LIB VarHolder* array_(ArrayArgs&&);
|
||||
EXTERN_LIB VarHolder* array__(PyObject* obj);
|
||||
}
|
||||
DEF_IS(VarHolder*, bool) is_type(PyObject* obj) {
|
||||
return Py_TYPE(obj) == &PyjtVarHolder.ht_type ||
|
||||
is_type<ArrayArgs>(obj);
|
||||
|
@ -397,8 +400,7 @@ DEF_IS(VarHolder*, T) from_py_object(PyObject* obj) {
|
|||
DEF_IS(VarHolder*, T) from_py_object(PyObject* obj, unique_ptr<VarHolder>& holder) {
|
||||
if (Py_TYPE(obj) == &PyjtVarHolder.ht_type)
|
||||
return GET_RAW_PTR(VarHolder, obj);
|
||||
auto args = from_py_object<ArrayArgs>(obj);
|
||||
holder.reset(jit_op_maker::array_(move(args)));
|
||||
holder.reset(jit_op_maker::array__(obj));
|
||||
return holder.get();
|
||||
}
|
||||
|
||||
|
|
|
@ -40,32 +40,32 @@ typename std::enable_if<0<nbyte,void>::type
|
|||
vload(T* __restrict__ a, T* __restrict__ b) {
|
||||
if (nbyte<=0) return;
|
||||
if (nbyte>=16) {
|
||||
auto __restrict__ aa = (float4* __restrict__)a;
|
||||
auto __restrict__ bb = (float4* __restrict__)b;
|
||||
auto* __restrict__ aa = (float4* __restrict__)a;
|
||||
auto* __restrict__ bb = (float4* __restrict__)b;
|
||||
aa[0] = bb[0];
|
||||
return vload<nbyte-16>(aa+1, bb+1);
|
||||
}
|
||||
if (nbyte>=8) {
|
||||
auto __restrict__ aa = (float2* __restrict__)a;
|
||||
auto __restrict__ bb = (float2* __restrict__)b;
|
||||
auto* __restrict__ aa = (float2* __restrict__)a;
|
||||
auto* __restrict__ bb = (float2* __restrict__)b;
|
||||
aa[0] = bb[0];
|
||||
return vload<nbyte-8>(aa+1, bb+1);
|
||||
}
|
||||
if (nbyte>=4) {
|
||||
auto __restrict__ aa = (float* __restrict__)a;
|
||||
auto __restrict__ bb = (float* __restrict__)b;
|
||||
auto* __restrict__ aa = (float* __restrict__)a;
|
||||
auto* __restrict__ bb = (float* __restrict__)b;
|
||||
aa[0] = bb[0];
|
||||
return vload<nbyte-4>(aa+1, bb+1);
|
||||
}
|
||||
if (nbyte>=2) {
|
||||
auto __restrict__ aa = (__half* __restrict__)a;
|
||||
auto __restrict__ bb = (__half* __restrict__)b;
|
||||
auto* __restrict__ aa = (__half* __restrict__)a;
|
||||
auto* __restrict__ bb = (__half* __restrict__)b;
|
||||
aa[0] = bb[0];
|
||||
return vload<nbyte-2>(aa+1, bb+1);
|
||||
}
|
||||
if (nbyte>=1) {
|
||||
auto __restrict__ aa = (int8_t* __restrict__)a;
|
||||
auto __restrict__ bb = (int8_t* __restrict__)b;
|
||||
auto* __restrict__ aa = (int8_t* __restrict__)a;
|
||||
auto* __restrict__ bb = (int8_t* __restrict__)b;
|
||||
aa[0] = bb[0];
|
||||
return vload<nbyte-1>(aa+1, bb+1);
|
||||
}
|
||||
|
|
|
@ -81,6 +81,21 @@ class TestACL(unittest.TestCase):
|
|||
np.testing.assert_allclose(y.numpy(), ny)
|
||||
# y.sync(True)
|
||||
|
||||
@jt.flag_scope(use_acl=1)
def test_max(self):
    """Compare ``Var.max`` along axis 1 with NumPy's ``max`` on the same data."""
    x = jt.rand(10, 10)
    reduced = x.max(1).data
    reference = x.data.max(1)
    np.testing.assert_allclose(reduced, reference)
|
||||
|
||||
@jt.flag_scope(use_acl=1)
def test_resnet(self):
    """Smoke test: run a resnet50 forward pass on a (2, 3, 224, 224) input and sync."""
    from jittor.models import resnet50
    model = resnet50()
    batch = jt.rand(2, 3, 224, 224)
    output = model(batch)
    output.sync()
|
||||
|
||||
|
||||
|
||||
def matmul(a, b):
|
||||
|
|
Loading…
Reference in New Issue