polish backend

This commit is contained in:
Dun Liang 2022-10-26 14:29:53 +08:00
parent 335a2e5c1d
commit 003cdf6b16
12 changed files with 146 additions and 20 deletions

View File

@ -1889,3 +1889,9 @@ from . import sparse
from . import optim
from . import dataset
from . import init
import jittor_utils
# Run the optional post-import hook of every backend listed in
# jittor_utils.backends; backends that do not define one are skipped.
for backend in jittor_utils.backends:
    if not hasattr(backend, "post_process"):
        continue
    backend.post_process()

View File

@ -263,6 +263,8 @@ def setup_cuda_lib(lib_name, link=True, extra_flags=""):
cuda_include_name = search_file([cuda_include, extra_include_path, "/usr/include"], lib_name+".h")
# cuda11 prefer cudnn 8
nvcc_version = get_int_version(nvcc_path)
if has_corex:
nvcc_version = (10,2,89)
prefer_version = ()
if nvcc_version[0] == 11:
prefer_version = ("8",)

View File

@ -234,7 +234,7 @@ def gen_jit_flags():
jit_declares.append(f"DECLARE_FLAG({type}, {name});")
alias = []
if name == "use_cuda":
alias = ["use_device", "use_acl", "use_rocm"]
alias = ["use_device", "use_acl", "use_rocm", "use_corex"]
elif name == "auto_mixed_precision_level":
alias = ["amp_level"]
get_names = ",".join(["__get__"+a for a in [name]+alias])
@ -867,7 +867,7 @@ def check_cuda():
cc_flags += f" -lcudart -L\"{cuda_lib}\" "
# ctypes.CDLL(cuda_lib+"/libcudart.so", import_flags)
ctypes.CDLL(cuda_lib+"/libcudart.so", dlopen_flags)
has_cuda = 1
is_cuda = has_cuda = 1
def check_cache_compile():
files = [
@ -1181,7 +1181,7 @@ check_cache_compile()
LOG.v(f"Get cache_compile: {jit_utils.cc}")
# check cuda
has_cuda = 0
is_cuda = has_cuda = 0
check_cuda()
nvcc_flags = os.environ.get("nvcc_flags", "")
if has_cuda:
@ -1233,6 +1233,8 @@ from .extern.acl import acl_compiler
jit_utils.add_backend(acl_compiler)
from .extern.rocm import rocm_compiler
jit_utils.add_backend(rocm_compiler)
from .extern.corex import corex_compiler
jit_utils.add_backend(corex_compiler)
for mod in jit_utils.backends:
mod.check()
@ -1346,7 +1348,7 @@ with jit_utils.import_scope(import_flags):
flags = core.Flags()
if has_cuda:
if has_cuda and is_cuda:
nvcc_flags = " " + os.environ.get("nvcc_flags", "") + " "
nvcc_flags += convert_nvcc_flags(cc_flags)
nvcc_version = list(jit_utils.get_int_version(nvcc_path))

View File

@ -9,11 +9,13 @@ from jittor_utils import env_or_try_find
import jittor_utils
import ctypes
import glob
import jittor.compiler as compiler
has_acl = 0
cc_flags = ""
tikcc_path = env_or_try_find('tikcc_path', 'tikcc')
dlopen_flags = os.RTLD_NOW | os.RTLD_GLOBAL
compiler.has_acl = has_acl
def install():
import jittor.compiler as compiler
@ -64,3 +66,7 @@ def check():
compiler.nvcc_flags = compiler.cc_flags.replace("-std=c++14","")
return True
def post_process():
    """Turn off the code-op pooling path when the ACL backend is active.

    Does nothing unless the module-level ``has_acl`` flag is truthy; in that
    case ``jittor.pool.pool_use_code_op`` is set to ``False``.
    """
    if not has_acl:
        return
    # Imported lazily so the pool module is only touched when ACL is in use.
    from jittor import pool
    pool.pool_use_code_op = False

View File

@ -0,0 +1,85 @@
# ***************************************************************
# Copyright (c) 2022 Jittor. All Rights Reserved.
# Maintainers: Dun Liang <randonlang@gmail.com>.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import os
from jittor_utils import env_or_try_find
import jittor_utils
import ctypes
import glob
import jittor.compiler as compiler
# Backend availability flag; install() sets it to 1 once COREX is configured
# and check() resets it to 0 when installation raises.
has_corex = 0
# Extra flags appended to the command line used to build the source-processing
# helper module in install().
cc_flags = ""
# Mirror the initial (disabled) state onto jittor.compiler so the attribute
# always exists there, even when COREX is never installed.
compiler.has_corex = has_corex
def install():
    """Build the COREX source-rewriting helper and retarget the compiler.

    Steps, in order:

    1. Compile a small helper module (together with every ``.cc`` file found
       recursively next to this file) that exposes a ``process`` function,
       and hand that function to ``jittor_utils.process_jittor_source`` under
       the ``"corex"`` name.
    2. Mark the backend as available, both in this module and on
       ``jittor.compiler``.
    3. Point ``jittor.compiler`` at the COREX clang++ under
       ``/usr/local/corex``, drop ``-fopenmp`` from the C++ flags, derive the
       device-compile flags from them, and disable NVCC-specific handling
       (``is_cuda = 0``, identity ``convert_nvcc_flags``).

    Any exception raised along the way propagates to the caller.
    """
    import jittor.compiler as compiler
    global has_corex, cc_flags

    backend_dir = os.path.dirname(__file__)
    extra_sources = sorted(glob.glob(backend_dir+"/**/*.cc", recursive=True))
    jittor_utils.LOG.i("COREX detected")

    # Helper sources are passed on the same command line as the flags.
    build_flags = compiler.cc_flags + " " + " ".join(extra_sources) + cc_flags
    helper = jittor_utils.compile_module('''
#include "common.h"
#include "utils/str_utils.h"
namespace jittor {
// @pyjt(process)
string process_acl(const string& src, const string& name, const map<string,string>& kargs) {
auto new_src = src;
new_src = replace(new_src, "helper_cuda.h", "../inc/helper_cuda.h");
if (name == "string_view_map.h")
new_src = replace(new_src, "using std::string_view;", "using string_view = string;");
if (name == "nan_checker.cu")
new_src = replace(new_src, "__trap()", "assert(0)");
if (name == "jit_compiler.cc") {
// remove asm tuner
new_src = token_replace_all(new_src, "cmd = python_path$1;", "");
new_src = token_replace_all(new_src,
"if (is_cuda_op && $1 != string::npos)",
"if (is_cuda_op)");
}
return new_src;
}
}''', build_flags)
    jittor_utils.process_jittor_source("corex", helper.process)

    # From here on the backend counts as available.
    has_corex = 1
    compiler.has_corex = has_corex

    # COREX ships a clang++ that is used for both host and device code.
    corex_home = "/usr/local/corex"
    corex_clang = corex_home + "/bin/clang++"
    compiler.nvcc_path = corex_clang
    compiler.cc_path = compiler.nvcc_path

    # Strip -fopenmp first so the device flags below are derived without it.
    host_flags = compiler.cc_flags.replace("-fopenmp", "")
    compiler.cc_flags = host_flags
    compiler.nvcc_flags = compiler.cc_flags + " -x cu -Ofast -DNO_ATOMIC64 "

    # The flags are already in their final form, so conversion is the identity.
    compiler.convert_nvcc_flags = lambda x:x
    compiler.is_cuda = 0
    os.environ["use_cutt"] = "0"
    compiler.cc_type = "clang"
def install_extern():
    """Report that this backend has no external libraries to install.

    Returns:
        Always ``False``.
    """
    return False
def check():
    """Try to enable the COREX backend and report whether it is available.

    ``install()`` is attempted only when ``/usr/local/corex`` exists. A
    failure there is logged as a warning and leaves the backend disabled
    instead of propagating.

    Returns:
        ``True`` when ``has_corex`` is truthy after the attempt, else ``False``.
    """
    global has_corex, cc_flags
    if os.path.isdir("/usr/local/corex"):
        try:
            install()
        except Exception as e:
            jittor_utils.LOG.w(f"load COREX failed, exception: {e}")
            has_corex = 0
    return bool(has_corex)
def post_process():
    """Remove ``-fopenmp`` from the runtime compiler flags when COREX is on.

    Does nothing unless the module-level ``has_corex`` flag is truthy.
    """
    if has_corex:
        import jittor.compiler as compiler
        runtime_flags = compiler.flags
        runtime_flags.cc_flags = runtime_flags.cc_flags.replace("-fopenmp", "")

View File

@ -11,6 +11,7 @@ import tarfile
import jittor_utils
from jittor_utils import env_or_try_find, run_cmd, cache_path, LOG
import jittor.compiler as compiler
has_rocm = 0
@ -18,6 +19,7 @@ cc_flags = ""
hipcc_path = env_or_try_find('hipcc_path', 'hipcc')
rocm_home = ""
dlopen_flags = os.RTLD_NOW | os.RTLD_GLOBAL
compiler.has_rocm = has_rocm
def check_gcc_use_cxx11_abi():

View File

@ -15,6 +15,8 @@ from jittor import init, Module
import numpy as np
import math
pool_use_code_op = True
class Pool(Module):
def __init__(self, kernel_size, stride=None, padding=0, dilation=None, return_indices=None, ceil_mode=False, count_include_pad=True, op="maximum"):
assert dilation == None
@ -40,7 +42,7 @@ class Pool(Module):
w = (W+self.padding[1]*2-self.kernel_size[1] + self.stride[1] - 1)//self.stride[1]+1
use_code_op = self.op in ['maximum', 'minimum', 'mean']
if use_code_op:
if use_code_op and pool_use_code_op:
if self.op == 'mean':
if self.count_include_pad:
count = f"int count = {self.kernel_size[0]*self.kernel_size[1]};"
@ -218,7 +220,7 @@ class Pool3d(Module):
w = (W+self.padding[2]*2-self.kernel_size[2] + self.stride[2] - 1)//self.stride[2]+1
use_code_op = self.op in ['maximum', 'minimum', 'mean']
if use_code_op:
if use_code_op and pool_use_code_op:
if self.op == 'mean':
if self.count_include_pad:
count = f"int count = {self.kernel_size[0]*self.kernel_size[1]*self.kernel_size[2]};"

View File

@ -212,7 +212,7 @@ jit_op_entry_t compile(const string& jit_key, const string& src, const bool is_c
// compiler do not allowed filename too long
CHECK(cc_path.size());
string jit_src_path;
if(is_cuda_op && extra_flags.find("-dc") != string::npos)
if (is_cuda_op && extra_flags.find("-dc") != string::npos)
jit_src_path = Op::get_filename_from_jit_key(jit_key, ".cu");
else
jit_src_path = Op::get_filename_from_jit_key(jit_key, ".cc");

View File

@ -55,10 +55,12 @@ inline float cuda_atomic_max(float* a, float b) {
return orderedIntToFloat(atomicMax((int *)a, floatToOrderedInt(b)));
}
// Double-precision specialization of cuda_atomic_max. It is compiled out when
// NO_ATOMIC64 is defined, because it relies on atomicMax over `long long`.
// NOTE(review): the value goes through floatToOrderedInt/orderedIntToFloat;
// their double/long long overloads are not visible here - confirm they exist.
#ifndef NO_ATOMIC64
template<> __device__
inline double cuda_atomic_max(double* a, double b) {
return orderedIntToFloat(atomicMax((long long *)a, floatToOrderedInt(b)));
}
#endif
template<class T> __device__
T cuda_atomic_min(T* a, T b) {
@ -70,10 +72,12 @@ inline float cuda_atomic_min(float* a, float b) {
return orderedIntToFloat(atomicMin((int *)a, floatToOrderedInt(b)));
}
// Double-precision specialization of cuda_atomic_min. It is compiled out when
// NO_ATOMIC64 is defined, because it relies on atomicMin over `long long`.
// NOTE(review): the value goes through floatToOrderedInt/orderedIntToFloat;
// their double/long long overloads are not visible here - confirm they exist.
#ifndef NO_ATOMIC64
template<> __device__
inline double cuda_atomic_min(double* a, double b) {
return orderedIntToFloat(atomicMin((long long *)a, floatToOrderedInt(b)));
}
#endif
template <class T> struct int_mapper {
typedef T src;

View File

@ -374,7 +374,10 @@ DEF_IS(ArrayArgs, T) from_py_object(PyObject* obj) {
// VarHolder
struct VarHolder;
EXTERN_LIB PyHeapTypeObject PyjtVarHolder;
namespace jit_op_maker { EXTERN_LIB VarHolder* array_(ArrayArgs&& args); }
namespace jit_op_maker {
EXTERN_LIB VarHolder* array_(ArrayArgs&&);
EXTERN_LIB VarHolder* array__(PyObject* obj);
}
DEF_IS(VarHolder*, bool) is_type(PyObject* obj) {
return Py_TYPE(obj) == &PyjtVarHolder.ht_type ||
is_type<ArrayArgs>(obj);
@ -397,8 +400,7 @@ DEF_IS(VarHolder*, T) from_py_object(PyObject* obj) {
DEF_IS(VarHolder*, T) from_py_object(PyObject* obj, unique_ptr<VarHolder>& holder) {
if (Py_TYPE(obj) == &PyjtVarHolder.ht_type)
return GET_RAW_PTR(VarHolder, obj);
auto args = from_py_object<ArrayArgs>(obj);
holder.reset(jit_op_maker::array_(move(args)));
holder.reset(jit_op_maker::array__(obj));
return holder.get();
}

View File

@ -40,32 +40,32 @@ typename std::enable_if<0<nbyte,void>::type
vload(T* __restrict__ a, T* __restrict__ b) {
if (nbyte<=0) return;
if (nbyte>=16) {
auto __restrict__ aa = (float4* __restrict__)a;
auto __restrict__ bb = (float4* __restrict__)b;
auto* __restrict__ aa = (float4* __restrict__)a;
auto* __restrict__ bb = (float4* __restrict__)b;
aa[0] = bb[0];
return vload<nbyte-16>(aa+1, bb+1);
}
if (nbyte>=8) {
auto __restrict__ aa = (float2* __restrict__)a;
auto __restrict__ bb = (float2* __restrict__)b;
auto* __restrict__ aa = (float2* __restrict__)a;
auto* __restrict__ bb = (float2* __restrict__)b;
aa[0] = bb[0];
return vload<nbyte-8>(aa+1, bb+1);
}
if (nbyte>=4) {
auto __restrict__ aa = (float* __restrict__)a;
auto __restrict__ bb = (float* __restrict__)b;
auto* __restrict__ aa = (float* __restrict__)a;
auto* __restrict__ bb = (float* __restrict__)b;
aa[0] = bb[0];
return vload<nbyte-4>(aa+1, bb+1);
}
if (nbyte>=2) {
auto __restrict__ aa = (__half* __restrict__)a;
auto __restrict__ bb = (__half* __restrict__)b;
auto* __restrict__ aa = (__half* __restrict__)a;
auto* __restrict__ bb = (__half* __restrict__)b;
aa[0] = bb[0];
return vload<nbyte-2>(aa+1, bb+1);
}
if (nbyte>=1) {
auto __restrict__ aa = (int8_t* __restrict__)a;
auto __restrict__ bb = (int8_t* __restrict__)b;
auto* __restrict__ aa = (int8_t* __restrict__)a;
auto* __restrict__ bb = (int8_t* __restrict__)b;
aa[0] = bb[0];
return vload<nbyte-1>(aa+1, bb+1);
}

View File

@ -81,6 +81,21 @@ class TestACL(unittest.TestCase):
np.testing.assert_allclose(y.numpy(), ny)
# y.sync(True)
@jt.flag_scope(use_acl=1)
def test_max(self):
    """Row-wise max computed by jittor must match numpy on the same data."""
    src = jt.rand(10,10)
    # Reduce through jittor first, then through numpy on the fetched array.
    got = src.max(1).data
    expected = src.data.max(1)
    np.testing.assert_allclose(got, expected)
@jt.flag_scope(use_acl=1)
def test_resnet(self):
    """Smoke test: a resnet50 forward pass on random input runs and syncs."""
    from jittor.models import resnet50
    model = resnet50()
    batch = jt.rand(2,3,224,224)
    out = model(batch)
    out.sync()
def matmul(a, b):