mirror of https://github.com/Jittor/Jittor
Support AMD ROCm and HIP backend
This commit is contained in:
parent 46a03098f9
commit e316f511c3
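
A minimal usage sketch of what this commit enables (assuming hipcc and the ROCm libraries are installed; use_rocm is the flag alias and has_rocm the compiler attribute added below):

import jittor as jt

if jt.compiler.has_rocm:
    jt.flags.use_rocm = 1       # route execution through the HIP backend
    a = jt.array([1, 2, 3])
    print((a + a).numpy())      # computed on the AMD GPU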
@@ -9,7 +9,7 @@
 # file 'LICENSE.txt', which is part of this source code package.
 # ***************************************************************
 
-__version__ = '1.3.4.14'
+__version__ = '1.3.4.15'
 from jittor_utils import lock
 with lock.lock_scope():
     ori_int = int
@@ -588,3 +588,8 @@ except Exception as e:
     LOG.w("MKL install failed, msg:", e)
 
 setup_cuda_extern()
+
+# install backend extern library
+for mod in jit_utils.backends:
+    if mod.install_extern():
+        break
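
The loop above treats every registered backend as a module with a small, duck-typed interface. A minimal sketch of that protocol (dummy_backend is a hypothetical example; the real entries are the acl and rocm modules registered via jit_utils.add_backend):

class dummy_backend:
    @staticmethod
    def check():
        # probe for the toolchain; on success, patch jittor.compiler
        # flags and return True so this backend is used
        return False
    @staticmethod
    def install_extern():
        # build backend-specific extern libraries; returning True
        # stops the loop, False lets the next backend try
        return False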
@ -133,7 +133,7 @@ def compile(compiler, flags, inputs, output, combind_build=False, cuda_flags="",
|
|||
nflags = oflags
|
||||
cmd = f"{cm(input)} {nflags} {lto_flags} -c -o {cm(obj_file)}"
|
||||
if input.endswith(".cu"):
|
||||
if has_cuda:
|
||||
if has_cuda or has_rocm:
|
||||
cmd = f"\"{nvcc_path}\" {cuda_flags} {cmd}"
|
||||
cmd = convert_nvcc_flags(fix_cl_flags(cmd))
|
||||
else:
|
||||
|
@ -143,8 +143,10 @@ def compile(compiler, flags, inputs, output, combind_build=False, cuda_flags="",
|
|||
cmd = fix_cl_flags(cmd)
|
||||
if "nan_checker" in input:
|
||||
# nan checker needs to disable fast_math
|
||||
cmd = cmd.replace("--use_fast_math", "")
|
||||
cmd = cmd.replace("-Ofast", "-O2")
|
||||
if "--use_fast_math" in cmd:
|
||||
cmd = cmd.replace("--use_fast_math", "")
|
||||
if "-Ofast" in cmd:
|
||||
cmd = cmd.replace("-Ofast", "-O2")
|
||||
cmds.append(cmd)
|
||||
jit_utils.run_cmds(cmds, cache_path, jittor_path, "Compiling "+base_output)
|
||||
obj_files += ex_obj_files
|
||||
|
@@ -232,7 +234,7 @@ def gen_jit_flags():
         jit_declares.append(f"DECLARE_FLAG({type}, {name});")
         alias = []
         if name == "use_cuda":
-            alias = ["use_device", "use_acl"]
+            alias = ["use_device", "use_acl", "use_rocm"]
         elif name == "auto_mixed_precision_level":
             alias = ["amp_level"]
         get_names = ",".join(["__get__"+a for a in [name]+alias])
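
Since use_rocm is generated here as an alias of use_cuda, either name toggles the same underlying device flag; a hypothetical illustration (assuming a ROCm build):

import jittor as jt

jt.flags.use_rocm = 1        # equivalent to jt.flags.use_cuda = 1
print(jt.flags.use_cuda)     # expected: 1, both names read the same flag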
@@ -1229,6 +1231,8 @@ if has_cuda:
 # from .acl_compiler import check_acl
 from .extern.acl import acl_compiler
 jit_utils.add_backend(acl_compiler)
+from .extern.rocm import rocm_compiler
+jit_utils.add_backend(rocm_compiler)
 
 for mod in jit_utils.backends:
     if mod.check():
@@ -1252,7 +1256,7 @@ pyjt_gen_src = pyjt_compiler.compile(cache_path, jittor_path)
 # 3. op_utils
 # 4. other
 files2 = pyjt_gen_src
-ext_args = 'c[cu]' if has_cuda else 'cc'
+ext_args = 'c[cu]' if has_cuda or has_rocm else 'cc'
 files4 = glob.glob(jittor_path+"/src/**/*."+ext_args, recursive=True)
 files4 = [ f[len(jittor_path)+1:] for f in files4 ]
 # files4 = run_cmd('find -L src | grep '+grep_args, jittor_path).splitlines()
@ -35,6 +35,10 @@ string process_acl(const string& src, const string& name, const map<string,strin
|
|||
has_acl = 1
|
||||
|
||||
|
||||
def install_extern():
|
||||
return False
|
||||
|
||||
|
||||
def check():
|
||||
import jittor.compiler as compiler
|
||||
global has_acl, cc_flags
|
||||
|
|
Binary file not shown.
@@ -0,0 +1,139 @@
+# ***************************************************************
+# Copyright (c) 2021 Jittor. All Rights Reserved.
+# Maintainers: Zheng-Ning Liu <lzhengning@gmail.com>.
+# This file is subject to the terms and conditions defined in
+# file 'LICENSE.txt', which is part of this source code package.
+# ***************************************************************
+import os
+import ctypes
+import glob
+
+import jittor_utils
+from jittor_utils import env_or_try_find, run_cmd, cache_path, LOG
+
+
+has_rocm = 0
+cc_flags = ""
+hipcc_path = env_or_try_find('hipcc_path', 'hipcc')
+rocm_home = ""
+dlopen_flags = os.RTLD_NOW | os.RTLD_GLOBAL
+
+
+def install_rocm_jittor_core():
+    import jittor.compiler as compiler
+    global has_rocm, cc_flags, rocm_home
+    rocm_home = run_cmd("hipconfig -R")
+    rocm_version = run_cmd("hipconfig -v")
+
+    rocm_compiler_home = os.path.dirname(__file__)
+
+    rocm_cache_path = os.path.join(rocm_compiler_home, "rocm_cache.o")
+    rocm_cache_gz_path = os.path.join(rocm_compiler_home, "rocm_cache.gz")
+    if not os.path.exists(rocm_cache_path) and os.path.exists(rocm_cache_gz_path):
+        import gzip
+        with gzip.open(rocm_cache_gz_path, "rb") as f:
+            data = f.read()
+        with open(rocm_cache_path, "wb") as f:
+            f.write(data)
+
+    cc_files = sorted(glob.glob(rocm_compiler_home + "/**/*.cc", recursive=True))
+    o_files = sorted(glob.glob(rocm_compiler_home + "/**/*.o", recursive=True))
+    cc_flags += f" -DHAS_CUDA -DIS_ROCM -I{rocm_compiler_home} "
+    cc_flags += " " + run_cmd("hipconfig -C") + " "
+    cc_flags += ' -L"' + os.path.join(rocm_home, "lib") + '" -lamdhip64 '
+    LOG.i(f"ROCm ({rocm_version}) detected in {rocm_home}")
+
+    mod = jittor_utils.compile_module('''
+#include "common.h"
+namespace jittor {
+// @pyjt(process)
+string process_rocm(const string& src, const string& name, const map<string,string>& kargs);
+}''', compiler.cc_flags + " " + " ".join(cc_files + o_files) + cc_flags)
+    jittor_utils.process_jittor_source("rocm", mod.process)
+
+    # preload hip driver to ensure the correct initialization of hip context
+    hip_driver = ctypes.CDLL(os.path.join(rocm_home, 'lib', 'libamdhip64.so'), os.RTLD_GLOBAL | os.RTLD_NOW)
+    r = hip_driver.hipDeviceSynchronize()
+
+    has_rocm = 1
+
+
+def install_hip():
+    import jittor.compiler as compiler
+
+    LOG.vv("setup rocm extern...")
+    cache_path_cuda = os.path.join(cache_path, "cuda")
+    cuda_include = os.path.join(compiler.jittor_path, "extern", "cuda", "inc")
+    compiler.make_cache_dir(cache_path_cuda)
+    cuda_extern_src = os.path.join(compiler.jittor_path, "extern", "cuda", "src")
+    cuda_extern_files = [os.path.join(cuda_extern_src, name) for name in os.listdir(cuda_extern_src)]
+    so_name = os.path.join(cache_path_cuda, "libcuda_extern" + compiler.so)
+    compiler.compile(compiler.cc_path, compiler.cc_flags+f" -I\"{cuda_include}\" ", cuda_extern_files, so_name)
+    ctypes.CDLL(so_name, dlopen_flags)
+
+
+def install_rocm_library(lib_name, cuda_name, link=True):
+    import jittor.compiler as compiler
+    import jittor.compile_extern as compile_extern
+
+    LOG.vv(f"setup {lib_name}...")
+    rocmlib_include_path = os.path.join(rocm_home, lib_name.lower(), "include")
+
+    jt_cuda_include = os.path.join(compiler.jittor_path, "extern", "cuda", "inc")
+    jt_culib_include = os.path.join(compiler.jittor_path, "extern", "cuda", cuda_name, "inc")
+
+    culib_src_dir = os.path.join(compiler.jittor_path, "extern", "cuda", cuda_name)
+    culib_src_files = []
+    for r, _, f in os.walk(culib_src_dir):
+        for fname in f:
+            culib_src_files.append(os.path.join(r, fname))
+
+    extra_flags = f" -I\"{jt_cuda_include}\" -I\"{jt_culib_include}\" -I\"{rocmlib_include_path}\" "
+    extra_flags += f" -L\"{os.path.join(cache_path, 'cuda')}\" -llibcuda_extern "
+    if lib_name == "rccl":
+        extra_flags += compile_extern.mpi_compile_flags
+
+    if link:
+        rocmlib_lib_path = os.path.join(rocm_home, lib_name.lower(), "lib")
+        if os.path.exists(os.path.join(rocmlib_lib_path, f"lib{lib_name}.so")):
+            jittor_utils.LOG.i(f"Found {os.path.join(rocmlib_lib_path, 'lib' + lib_name + '.so')}")
+            extra_flags += f" -L{rocmlib_lib_path} -l{lib_name} "
+
+    rocmlib = compiler.compile_custom_ops(culib_src_files, return_module=True, extra_flags=extra_flags)
+    setattr(compile_extern, cuda_name, rocmlib)
+    setattr(compile_extern, cuda_name + "_ops", rocmlib.ops)
+
+
+def install_extern():
+    if has_rocm:
+        install_hip()
+        install_rocm_library("MIOpen", "cudnn")
+        install_rocm_library("rocblas", "cublas")
+        install_rocm_library("rocprim", "cub", link=False)
+        install_rocm_library("rccl", "nccl")
+        return True
+    else:
+        return False
+
+def convert_nvcc_flags(nvcc_flags):
+    return nvcc_flags
+
+def check():
+    import jittor.compiler as compiler
+    global has_rocm, cc_flags
+    if hipcc_path:
+        try:
+            install_rocm_jittor_core()
+        except Exception as e:
+            jittor_utils.LOG.w(f"load ROCm failed, exception: {e}")
+            has_rocm = 0
+    compiler.has_rocm = has_rocm
+    compiler.hipcc_path = hipcc_path
+    if not has_rocm:
+        return False
+
+    compiler.cc_flags += cc_flags
+    compiler.nvcc_path = hipcc_path
+    compiler.nvcc_flags = compiler.cc_flags.replace("-std=c++14", "-std=c++17")
+    compiler.convert_nvcc_flags = convert_nvcc_flags
+    return True
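
A sketch of what a successful install_extern() above leaves behind (assumes check() found a working ROCm setup): each ROCm library is compiled against Jittor's existing CUDA extern sources and attached to jittor.compile_extern under its CUDA name, so downstream code keeps using the CUDA attribute names:

import jittor                                  # import triggers backend check()
import jittor.compile_extern as compile_extern

# after install_rocm_library("MIOpen", "cudnn"), the MIOpen-backed module
# appears as compile_extern.cudnn and its operators as compile_extern.cudnn_ops
print(getattr(compile_extern, "cudnn", None))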
@@ -0,0 +1,14 @@
+// ***************************************************************
+// Copyright (c) 2021 Jittor. All Rights Reserved.
+// Maintainers: Zheng-Ning Liu <lzhengning@gmail.com>.
+// This file is subject to the terms and conditions defined in
+// file 'LICENSE.txt', which is part of this source code package.
+// ***************************************************************
+#pragma once
+#include "common.h"
+
+namespace jittor {
+
+void rocm_jittor_op_compiler(string& filename, string& src, bool is_rocm, string& extra_flags);
+
+}
@@ -30,7 +30,7 @@ __device__ inline static long long floatToOrderedInt(double floatVal) {
     return (intVal >= 0 ) ? intVal : intVal ^ 0x7FFFFFFFFFFFFFFF;
 }
 __device__ inline static double orderedIntToFloat(long long intVal) {
-    return __longlong_as_double((intVal >= 0) ? intVal : intVal ^ 0x7FFFFFFFFFFFFFFF);
+    return __longlong_as_double((intVal >= 0) ? intVal : intVal ^ 0x7FFFFFFFFFFFFFFF);
 }
 
 __global__ inline static void fix_float_kernel(double* x, int num) {
@@ -113,4 +113,86 @@ T cuda_atomic_mul(T* a, T b) {
     return old_f;
 }
+
+template<typename T>
+__device__ inline T shared_reduce_add(T a, T b) {
+    return a + b;
+}
+
+template<typename T>
+__device__ inline T shared_reduce_mul(T a, T b) {
+    return a * b;
+}
+
+template<typename T>
+__device__ inline T shared_reduce_max(T a, T b) {
+    return a > b ? a : b;
+}
+
+template<typename T>
+__device__ inline T shared_reduce_min(T a, T b) {
+    return a < b ? a : b;
+}
+
+template<typename T>
+__device__ inline T shared_reduce_and(T a, T b) {
+    return a & b;
+}
+
+template<typename T>
+__device__ inline T shared_reduce_or(T a, T b) {
+    return a | b;
+}
+
+template<typename T>
+__device__ inline T shared_reduce_xor(T a, T b) {
+    return a ^ b;
+}
+
+
+template<typename T, T(*op)(T, T)>
+__device__ inline void warpReduce(volatile T* sdata, int tid) {
+    if (blockDim.x >= 64)
+        sdata[tid] = op(sdata[tid], sdata[tid + 32]);
+    sdata[tid] = op(sdata[tid], sdata[tid + 16]);
+    sdata[tid] = op(sdata[tid], sdata[tid + 8]);
+    sdata[tid] = op(sdata[tid], sdata[tid + 4]);
+    sdata[tid] = op(sdata[tid], sdata[tid + 2]);
+    sdata[tid] = op(sdata[tid], sdata[tid + 1]);
+}
+
+template<typename T, T(*op)(T, T)>
+__device__ inline static T shared_reduce(T u) {
+    __shared__ T sdata[1024];
+
+    int tid = threadIdx.x;
+
+    sdata[tid] = u;
+    __syncthreads();
+
+    if (blockDim.x >= 1024 && tid < 512) {
+        sdata[tid] = u = op(u, sdata[tid + 512]);
+    }
+    __syncthreads();
+
+    if (blockDim.x >= 512 && tid < 256) {
+        sdata[tid] = u = op(u, sdata[tid + 256]);
+    }
+    __syncthreads();
+
+    if (blockDim.x >= 256 && tid < 128) {
+        sdata[tid] = u = op(u, sdata[tid + 128]);
+    }
+    __syncthreads();
+
+    if (blockDim.x >= 128 && tid < 64) {
+        sdata[tid] = u = op(u, sdata[tid + 64]);
+    }
+    __syncthreads();
+
+    if (tid < 32)
+        warpReduce<T, op>(sdata, tid);
+
+    return sdata[0];
+}
+
 } // jittor
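
The shared_reduce added above is a standard shared-memory tree reduction: each of the blockDim.x threads contributes one value, the active half combines pairs at stride 512, 256, ..., 64 with a barrier between steps, and the last 32 lanes finish warp-synchronously in warpReduce. A minimal Python model of the same dataflow (hypothetical illustration, not part of the commit; assumes a power-of-two thread count):

def shared_reduce(values, op):
    sdata = list(values)           # models the __shared__ buffer
    stride = len(sdata) // 2       # len(values) models blockDim.x
    while stride >= 1:
        for tid in range(stride):  # the threads active at this step
            sdata[tid] = op(sdata[tid], sdata[tid + stride])
        stride //= 2               # a __syncthreads() sits between steps
    return sdata[0]

assert shared_reduce(list(range(8)), lambda a, b: a + b) == 28
assert shared_reduce([3, 1, 4, 1], max) == 4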
@@ -5,8 +5,24 @@
 // file 'LICENSE.txt', which is part of this source code package.
 // ***************************************************************
 #pragma once
+
+#ifdef IS_CUDA
 #include <npp.h>
 #include <math_constants.h>
+#else
+#include <limits>
+#define NPP_MIN_32U ( 0 )
+#define NPP_MAX_32U ( 4294967295U )
+#define NPP_MIN_32S (-2147483647 - 1 )
+#define NPP_MAX_32S ( 2147483647 )
+#define NPP_MIN_64U ( 0 )
+#define NPP_MAX_64U ( 18446744073709551615ULL )
+#define NPP_MIN_64S (-9223372036854775807LL - 1)
+#define NPP_MAX_64S ( 9223372036854775807LL )
+#define CUDART_INF_F std::numeric_limits<float>::infinity()
+#define CUDART_INF std::numeric_limits<double>::infinity()
+#endif
+
 
 template<class T> __device__ T numeric_min();
 template<class T> __device__ T numeric_max();
@@ -0,0 +1,17 @@
+// ***************************************************************
+// Copyright (c) 2022 Jittor. All Rights Reserved.
+// Maintainers: Zheng-Ning Liu <lzhengning@gmail.com>.
+// This file is subject to the terms and conditions defined in
+// file 'LICENSE.txt', which is part of this source code package.
+// ***************************************************************
+#pragma once
+#include "opt/pass/pass.h"
+
+namespace jittor {
+
+struct SharedReducePass : Pass {
+    SharedReducePass() : Pass("shared_reduce") {};
+    void run() override;
+};
+
+} // jittor
@@ -27,6 +27,7 @@
 #include "opt/pass/assume_aligned_pass.h"
 #include "opt/pass/parallel_pass.h"
 #include "opt/pass/atomic_tuner_pass.h"
+#include "opt/pass/shared_reduce_pass.h"
 #include "opt/pass/float_atomic_fix_pass.h"
 #include "opt/pass/insert_profile_loop_pass.h"
 #include "opt/pass/fake_main_pass.h"
@@ -111,6 +112,7 @@ void PassManager::run_passes() {
     run_pass<AssumeAlignedPass>();
     run_pass<ParallelPass>();
     run_pass<AtomicTunerPass>();
+    run_pass<SharedReducePass>();
     run_pass<FloatAtomicFixPass>();
 
     run_pass<InsertProfileLoopPass>();
@@ -4,6 +4,7 @@
 // This file is subject to the terms and conditions defined in
 // file 'LICENSE.txt', which is part of this source code package.
 // ***************************************************************
+#include <cctype>
 #include "utils/str_utils.h"
 
 namespace jittor {
@ -78,12 +79,26 @@ string replace(const string& a, const string& b, const string& c) {
|
|||
|
||||
static inline bool isvar(char x) { return isalnum(x) || x == '_' || x == ':'; }
|
||||
|
||||
vector<string> token_split(const string& s) {
|
||||
vector<string> token_split(const string& s, bool exclude_comments) {
|
||||
vector<string> ss;
|
||||
if (!s.size()) return ss;
|
||||
ss.push_back(string()+s[0]);
|
||||
for (int i=1; i<s.size(); i++) {
|
||||
if (isvar(s[i]) != isvar(s[i-1]))
|
||||
ss.push_back("");
|
||||
for (int i = 0; i < s.size(); i++) {
|
||||
if (exclude_comments) {
|
||||
if (s[i] == '/' && s[i+1] == '/') {
|
||||
i = s.find('\n', i);
|
||||
if (i == string::npos)
|
||||
return ss;
|
||||
}
|
||||
if (s[i] == '/' && s[i+1] == '*') {
|
||||
i = s.find("*/", i);
|
||||
if (i == string::npos)
|
||||
return ss;
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
if (i && (isvar(s[i]) != isvar(s[i-1])))
|
||||
ss.push_back("");
|
||||
ss.back() += s[i];
|
||||
}
|
||||
|
@ -92,7 +107,8 @@ vector<string> token_split(const string& s) {
|
|||
|
||||
static void parse_reg(const string& src,
|
||||
vector<string>& patterns,
|
||||
vector<int>& arg_id) {
|
||||
vector<int>& arg_id,
|
||||
bool match_whitespace=true) {
|
||||
patterns.clear();
|
||||
arg_id.clear();
|
||||
patterns.push_back("");
|
||||
|
@ -103,11 +119,12 @@ static void parse_reg(const string& src,
|
|||
patterns.push_back("");
|
||||
continue;
|
||||
}
|
||||
patterns.back() += src[j];
|
||||
if (match_whitespace || !isspace(src[j]))
|
||||
patterns.back() += src[j];
|
||||
}
|
||||
}
|
||||
|
||||
int token_replace(vector<string>& tokens, int i, const string& src, const string& dst) {
|
||||
int token_replace(vector<string>& tokens, int i, const string& src, const string& dst, bool match_whitespace) {
|
||||
ASSERT(src.at(0) != '$' && src.at(src.size()-1) != '$' &&
|
||||
src.at(src.size()-2) != '$') << "illegal src:" << src;
|
||||
vector<string> patterns;
|
||||
|
@ -115,7 +132,7 @@ int token_replace(vector<string>& tokens, int i, const string& src, const string
|
|||
vector<string> patterns2;
|
||||
vector<int> arg_id2;
|
||||
unordered_map<int, string> args;
|
||||
parse_reg(src, patterns, arg_id);
|
||||
parse_reg(src, patterns, arg_id, match_whitespace);
|
||||
parse_reg(dst, patterns2, arg_id2);
|
||||
|
||||
int start_i, start_pos, end_i, end_pos;
|
||||
|
@ -123,17 +140,23 @@ int token_replace(vector<string>& tokens, int i, const string& src, const string
|
|||
int match_i, match_pos;
|
||||
string c_arg;
|
||||
|
||||
auto next = [&tokens](int &c_i, int &c_pos) {
|
||||
c_pos ++;
|
||||
if (c_pos >= tokens[c_i].size()) {
|
||||
c_pos = 0;
|
||||
c_i ++;
|
||||
if (c_i >= tokens.size())
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
auto match = [&](int c_i, int c_pos, const string& pat) -> bool {
|
||||
for (int i=0; i<pat.size(); i++) {
|
||||
while (!match_whitespace && isspace(tokens[c_i][c_pos]))
|
||||
next(c_i, c_pos);
|
||||
if (tokens[c_i][c_pos] != pat[i])
|
||||
return false;
|
||||
c_pos ++;
|
||||
if (c_pos >= tokens[c_i].size()) {
|
||||
c_pos = 0;
|
||||
c_i ++;
|
||||
if (c_i >= tokens.size())
|
||||
return false;
|
||||
}
|
||||
next(c_i, c_pos);
|
||||
}
|
||||
match_i = c_i;
|
||||
match_pos = c_pos;
|
||||
|
@ -189,9 +212,9 @@ int token_replace(vector<string>& tokens, int i, const string& src, const string
|
|||
return end_i;
|
||||
}
|
||||
|
||||
string token_replace(const string& s, const string& src, const string& dst) {
|
||||
string token_replace(const string& s, const string& src, const string& dst, bool match_whitespace) {
|
||||
vector<string> ss{s};
|
||||
token_replace(ss, 0, src, dst);
|
||||
token_replace(ss, 0, src, dst, match_whitespace);
|
||||
return join(ss, "");
|
||||
}
|
||||
|
||||
|
|
|
@ -33,11 +33,11 @@ string replace(const string& a, const string& b, const string& c);
|
|||
|
||||
string join(const vector<string>& vs, const string& x);
|
||||
|
||||
vector<string> token_split(const string& s);
|
||||
vector<string> token_split(const string& s, bool exclude_comments=false);
|
||||
|
||||
int token_replace(vector<string>& tokens, int i, const string& src, const string& dst);
|
||||
int token_replace(vector<string>& tokens, int i, const string& src, const string& dst, bool match_whitespace=true);
|
||||
|
||||
string token_replace(const string& s, const string& src, const string& dst);
|
||||
string token_replace_all(const string& s, const string& src, const string& dst);
|
||||
|
||||
string token_replace_all(const string& s, const string& src, const string& dst);
|
||||
} // jittor
|
|
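
The two flags threaded through str_utils above serve the ROCm source translator: token_split(s, exclude_comments=true) tokenizes while skipping // and /* */ comments, and the match_whitespace=false mode of token_replace lets a pattern match across differing whitespace. A rough Python analog of the new token_split behavior (hypothetical illustration mirroring the C++ above, including its quirk of comparing against the character just before a skipped comment):

def token_split(s, exclude_comments=False):
    def isvar(c):
        return c.isalnum() or c in "_:"
    ss = []
    i = 0
    while i < len(s):
        if exclude_comments and s.startswith("//", i):
            i = s.find("\n", i)
            if i == -1:
                return ss
        if exclude_comments and s.startswith("/*", i):
            j = s.find("*/", i)
            if j == -1:
                return ss
            i = j + 2
            continue
        # start a new token whenever the identifier-ness of the character flips
        if not ss or isvar(s[i]) != isvar(s[i - 1]):
            ss.append("")
        ss[-1] += s[i]
        i += 1
    return ss

assert token_split("a+b") == ["a", "+", "b"]
assert "".join(token_split("x // note\ny", True)) == "x \ny"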
@@ -0,0 +1,453 @@
+# ***************************************************************
+# Copyright (c) 2021 Jittor. All Rights Reserved.
+# Maintainers: Zheng-Ning Liu <lzhengning@gmail.com>.
+# This file is subject to the terms and conditions defined in
+# file 'LICENSE.txt', which is part of this source code package.
+# ***************************************************************
+import unittest
+
+import os
+import random
+import math
+import time
+
+import numpy as np
+import tqdm
+
+import jittor as jt
+from jittor import init, Module, nn, Function
+from jittor.models import vgg
+from jittor.dataset.mnist import MNIST
+import jittor.transform as trans
+
+from .test_core import expect_error
+from .test_reorder_tuner import simple_parser
+from .test_log import find_log_with_re
+
+
+def test_rocm(use_rocm=1):
+    @unittest.skipIf(not jt.compiler.has_rocm, "No ROCm found")
+    class TestCudaBase(unittest.TestCase):
+        def setUp(self):
+            jt.flags.use_rocm = use_rocm
+        def tearDown(self):
+            jt.flags.use_rocm = 0
+    return TestCudaBase
+
+
+@unittest.skipIf(not jt.compiler.has_rocm, "No ROCm found")
+class TestROCm(unittest.TestCase):
+
+    @jt.flag_scope(use_rocm=1)
+    def test_array(self):
+        a = jt.array([1,2,3])
+        np.testing.assert_allclose(a.numpy(), [1,2,3])
+
+    @jt.flag_scope(use_rocm=1)
+    def test_add(self):
+        a = jt.array([1,2,3])
+        b = a+a
+        np.testing.assert_allclose(b.numpy(), [2,4,6])
+
+    @jt.flag_scope(use_rocm=1)
+    def test_add_float(self):
+        a = jt.array([1.0,2.0,3.0])
+        b = a+a
+        np.testing.assert_allclose(b.numpy(), [2,4,6])
+
+    @jt.flag_scope(use_rocm=1)
+    def test_array_cast(self):
+        # this test cannot pass because cast error
+        x = np.random.rand(10)
+        y = jt.float32(x)
+        np.testing.assert_allclose(x, y.numpy())
+
+    def test_meminfo(self):
+        jt.display_memory_info()
+
+    @jt.flag_scope(use_rocm=1)
+    def test_cuda_flags(self):
+        a = jt.random((10, 10))
+        a.sync()
+
+    @jt.flag_scope(use_rocm=1)
+    def test_rocm_custom_op_from_cuda(self):
+        my_op = jt.compile_custom_op("""
+        struct MyCudaOp : Op {
+            Var* output;
+            MyCudaOp(NanoVector shape, string dtype="float");
+
+            const char* name() const override { return "my_cuda"; }
+            DECLARE_jit_run;
+        };
+        """, """
+        #ifndef JIT
+        MyCudaOp::MyCudaOp(NanoVector shape, string dtype) {
+            flags.set(NodeFlags::_cuda);
+            output = create_output(shape, dtype);
+        }
+
+        void MyCudaOp::jit_prepare(JK& jk) {
+            add_jit_define(jk, "T", output->dtype());
+        }
+
+        #else // JIT
+        #ifdef JIT_cuda
+
+        __global__ void kernel(index_t n, T *x) {
+            int index = blockIdx.x * blockDim.x + threadIdx.x;
+            int stride = blockDim.x * gridDim.x;
+            for (int i = index; i < n; i += stride)
+                x[i] = (T)-i;
+        }
+
+        void MyCudaOp::jit_run() {
+            index_t num = output->num;
+            auto* __restrict__ x = output->ptr<T>();
+            int blockSize = 256;
+            int numBlocks = (num + blockSize - 1) / blockSize;
+            kernel<<<numBlocks, blockSize>>>(num, x);
+        }
+        #endif // JIT_cuda
+        #endif // JIT
+        """,
+        "my_cuda")
+        a = my_op([3,4,5], 'float')
+        na = a.data
+        assert a.shape == [3,4,5] and a.dtype == 'float'
+        assert (-na.flatten() == range(3*4*5)).all(), na
+
+    def test_rocm_fused_op(self):
+        a = jt.array([1,2,3])
+        a.sync()
+        with jt.flag_scope(use_rocm=1):
+            ((a+a)*2).data
+
+
+class Model(Module):
+    def __init__(self, input_size):
+        self.linear1 = nn.Linear(input_size, 10)
+        self.relu1 = nn.Relu()
+        self.linear2 = nn.Linear(10, 1)
+    def execute(self, x):
+        x = self.linear1(x)
+        x = self.relu1(x)
+        return self.linear2(x)
+
+
+@unittest.skipIf(not jt.compiler.has_rocm, "No ROCm found")
+class TestExample(unittest.TestCase):
+    @jt.flag_scope(use_rocm=1)
+    def test1(self):
+        np.random.seed(0)
+        jt.set_seed(3)
+        n = 1000
+        batch_size = 50
+        lr = 0.05
+
+        def get_data(n):
+            for i in range(n):
+                x = np.random.rand(batch_size, 1).astype("float32")
+                y = x*x
+                yield jt.float32(x), jt.float32(y)
+
+        model = Model(input_size=1)
+        ps = model.parameters()
+
+        for i,(x,y) in enumerate(get_data(n)):
+            jt.sync_all(True)
+            pred_y = model(x).name("pred_y")
+            loss = ((pred_y - y).sqr()).name("loss")
+            loss_mean = loss.mean()
+
+            gs = jt.grad(loss_mean, ps)
+            for p, g in zip(ps, gs):
+                p -= g * lr
+
+            if i>2:
+                assert prev == jt.liveness_info(), f"memory leak {prev} {jt.liveness_info()}"
+            prev = jt.liveness_info()
+
+        possible_results = [
+            0.0009948202641680837,
+            0.001381353591568768,
+            0.00110957445576787,
+            0.001124994712881744
+        ]
+        loss_mean = loss_mean.data
+        assert any(abs(loss_mean - r) < 1e-6 for r in possible_results)
+
+        jt.clean()
+
+
+from .test_unary_op import TestUnaryOp
+@unittest.skipIf(not jt.compiler.has_rocm, "No ROCm found")
+class TestROCmUnaryOp(TestUnaryOp, test_rocm(1)):
+    pass
+
+
+from .test_binary_op import TestBinaryOp
+@unittest.skipIf(not jt.compiler.has_rocm, "No ROCm found")
+class TestROCmBinaryOp(TestBinaryOp, test_rocm(1)):
+    pass
+
+
+from .test_reduce_op import TestReduceOp
+@unittest.skipIf(not jt.compiler.has_rocm, "No ROCm found")
+class TestROCmReduceOp(TestReduceOp, test_rocm(1)):
+    pass
+
+
+from .test_reindex_op import TestReindexOp
+@unittest.skipIf(not jt.compiler.has_rocm, "No ROCm found")
+class TestROCmReindexOp(TestReindexOp, test_rocm(1)):
+    pass
+
+
+# from .test_reindex_reduce_op import TestReindexReduceOp
+# @unittest.skipIf(not jt.compiler.has_rocm, "No ROCm found")
+# class TestROCmReindexReduceOp(TestReindexReduceOp, test_rocm(1)):
+#     pass
+
+
+@unittest.skipIf(not jt.compiler.has_rocm, "No ROCm found")
+class TestROCmCodeOp(unittest.TestCase):
+    @jt.flag_scope(use_rocm=1)
+    def test_cuda(self):
+        a = jt.random([100000])
+        b = jt.random([100000])
+        c = jt.code(a.shape, a.dtype, [a,b],
+            cuda_src='''
+                __global__ static void kernel1(@ARGS_DEF) {
+                    @PRECALC
+                    int i = threadIdx.x + blockIdx.x * blockDim.x;
+                    int stride = blockDim.x * gridDim.x;
+                    for (; i<in0_shape0; i+=stride)
+                        @out(i) = @in0(i)*@in1(i);
+                }
+                kernel1<<<(in0_shape0-1)/1024+1, 1024>>>(@ARGS);
+            ''',
+            cuda_grad_src = ['''
+                __global__ static void kernel2(@ARGS_DEF) {
+                    @PRECALC
+                    int i = threadIdx.x + blockIdx.x * blockDim.x;
+                    int stride = blockDim.x * gridDim.x;
+                    for (; i<in0_shape0; i+=stride)
+                        @out(i) = @dout(i)*@in1(i);
+                }
+                kernel2<<<(in0_shape0-1)/1024+1, 1024>>>(@ARGS);
+            ''', '''
+                __global__ static void kernel3(@ARGS_DEF) {
+                    @PRECALC
+                    int i = threadIdx.x + blockIdx.x * blockDim.x;
+                    int stride = blockDim.x * gridDim.x;
+                    for (; i<in0_shape0; i+=stride)
+                        @out(i) = @dout(i)*@in0(i);
+                }
+                kernel3<<<(in0_shape0-1)/1024+1, 1024>>>(@ARGS);
+            '''])
+        da, db = jt.grad(c, [a, b])
+        assert np.allclose(c.data, a.data*b.data), (c.data, a.data*b.data)
+        assert np.allclose(da.data, b.data)
+        assert np.allclose(db.data, a.data)
+
+    @jt.flag_scope(use_rocm=1)
+    def test_cuda2(self):
+        a = jt.random((100,100))
+        b = jt.random((100,100))
+        c = jt.code(a.shape, a.dtype, [a,b],
+            cuda_src='''
+                __global__ static void kernel1(@ARGS_DEF) {
+                    @PRECALC
+                    for (int i=blockIdx.x; i<in0_shape0; i+=gridDim.x)
+                    for (int j=threadIdx.x; j<in0_shape1; j+=blockDim.x)
+                        @out(i,j) = @in0(i,j)*@in1(i,j);
+                }
+                kernel1<<<32, 32>>>(@ARGS);
+            ''',
+            cuda_grad_src = ['''
+                __global__ static void kernel(@ARGS_DEF) {
+                    @PRECALC
+                    for (int i=blockIdx.x; i<in0_shape0; i+=gridDim.x)
+                    for (int j=threadIdx.x; j<in0_shape1; j+=blockDim.x)
+                        @out(i,j) = @dout(i,j)*@in1(i,j);
+                }
+                kernel<<<32, 32>>>(@ARGS);
+            ''', '''
+                __global__ static void kernel(@ARGS_DEF) {
+                    @PRECALC
+                    @pout(0,0);
+                    for (int i=blockIdx.x; i<in0_shape0; i+=gridDim.x)
+                    for (int j=threadIdx.x; j<in0_shape1; j+=blockDim.x)
+                        @out(i,j) = @dout(i,j)*@in0(i,j);
+                }
+                kernel<<<32, 32>>>(@ARGS);
+            '''])
+        da, db = jt.grad(c, [a, b])
+        assert np.allclose(c.data, a.data*b.data), (c.data, a.data*b.data)
+        assert np.allclose(da.data, b.data)
+        assert np.allclose(db.data, a.data)
+
+    @jt.flag_scope(use_rocm=1)
+    def test_cuda2_use_func(self):
+        class Func(Function):
+            def execute(self, a, b):
+                self.save_vars = a, b
+                return jt.code(a.shape, a.dtype, [a,b],
+                    cuda_src='''
+                        __global__ static void kernel1(@ARGS_DEF) {
+                            @PRECALC
+                            for (int i=blockIdx.x; i<in0_shape0; i+=gridDim.x)
+                            for (int j=threadIdx.x; j<in0_shape1; j+=blockDim.x)
+                                @out(i,j) = @in0(i,j)*@in1(i,j);
+                        }
+                        kernel1<<<32, 32>>>(@ARGS);
+                    ''')
+
+            def grad(self, grad):
+                a, b = self.save_vars
+                return jt.code([a.shape, b.shape], [a.dtype, b.dtype], [a, b, grad],
+                    cuda_src='''
+                        __global__ static void kernel2(@ARGS_DEF) {
+                            @PRECALC
+                            for (int i=blockIdx.x; i<in0_shape0; i+=gridDim.x)
+                            for (int j=threadIdx.x; j<in0_shape1; j+=blockDim.x) {
+                                @out0(i,j) = @in2(i,j)*@in1(i,j);
+                                @out1(i,j) = @in2(i,j)*@in0(i,j);
+                            }
+                        }
+                        kernel2<<<32, 32>>>(@ARGS);
+                    ''')
+
+        a = jt.random((100,100))
+        b = jt.random((100,100))
+
+        func = Func()
+        c = func(a,b)
+        da, db = jt.grad(c, [a, b])
+        assert np.allclose(c.data, a.data*b.data), (c.data, a.data*b.data)
+        assert np.allclose(da.data, b.data)
+        assert np.allclose(db.data, a.data)
+
+
+@unittest.skipIf(not jt.compiler.has_rocm, "No ROCm found")
+class TestBMM(unittest.TestCase):
+    def test_bmm_rocm(self):
+        def check(batch, n, m, k):
+            def calc(use_rocm, a, b, mask):
+                jt.flags.use_rocm = use_rocm
+                a = jt.array(a)
+                b = jt.array(b)
+                mask = jt.array(mask)
+                c = nn.bmm(a, b)
+                da, db = jt.grad(c*mask, [a, b])
+                return c.data, da.data, db.data
+            mask = np.random.rand(batch, n, k).astype("float32")
+            a = np.random.rand(batch, n, m).astype("float32")
+            b = np.random.rand(batch, m, k).astype("float32")
+            a1,a2,a3 = calc(0, a, b, mask)
+            b1,b2,b3 = calc(1, a, b, mask)
+            assert np.allclose(a1, b1)
+            assert np.allclose(a2, b2)
+            assert np.allclose(a3, b3)
+        check(10,3,4,5)
+        check(10,8,8,8)
+        check(10,8,1,8)
+        check(10,8,8,1)
+        check(10,1,8,8)
+        check(1,7,8,8)
+
+class Model(Module):
+    def __init__(self, input_size):
+        self.linear1 = nn.Linear(input_size, 10)
+        self.relu1 = nn.Relu()
+        self.linear2 = nn.Linear(10, 1)
+    def execute(self, x):
+        x = self.linear1(x)
+        x = self.relu1(x)
+        return self.linear2(x)
+
+from jittor.models import resnet
+
+class MnistNet(Module):
+    def __init__(self):
+        self.model = resnet.Resnet18()
+        self.layer = nn.Linear(1000,10)
+    def execute(self, x):
+        x = self.model(x)
+        x = self.layer(x)
+        return x
+
+@unittest.skipIf(not jt.compiler.has_rocm, "skip_this_test")
+class TestResnetFp32(unittest.TestCase):
+    # setup random seed
+    def setup_seed(self, seed):
+        np.random.seed(seed)
+        random.seed(seed)
+        jt.seed(seed)
+
+    @jt.flag_scope(use_cuda=1)
+    def test_resnet(self):
+        self.setup_seed(1)
+
+        # hyper-parameters
+        self.batch_size = int(os.environ.get("TEST_BATCH_SIZE", "100"))
+        self.weight_decay = 0.0001
+        self.momentum = 0.9
+        self.learning_rate = 0.1
+        if jt.flags.amp_reg:
+            self.learning_rate = 0.01
+        # mnist dataset
+        self.train_loader = MNIST(train=True, transform=trans.Resize(224)) \
+            .set_attrs(batch_size=self.batch_size, shuffle=True)
+        self.train_loader.num_workers = 4
+
+        loss_list=[]
+        acc_list=[]
+        mnist_net = MnistNet()
+        global prev
+        SGD = nn.SGD(mnist_net.parameters(), self.learning_rate, self.momentum, self.weight_decay)
+        self.train_loader.endless = True
+
+        for data, target in self.train_loader:
+            batch_id = self.train_loader.batch_id
+            epoch_id = self.train_loader.epoch_id
+            data = data.float_auto()
+            output = mnist_net(data)
+            loss = nn.cross_entropy_loss(output, target)
+
+            break
+        jt.sync_all(True)
+
+        for _ in range(10):
+            output = mnist_net(data)
+            loss = nn.cross_entropy_loss(output, target)
+            SGD.step(loss)
+            def callback(epoch_id, batch_id, loss, output, target):
+                pred = np.argmax(output, axis=1)
+                acc = np.mean(target==pred)
+            jt.fetch(epoch_id, _, loss, output, target, callback)
+        jt.sync_all(True)
+
+        all_time = time.time()
+        prev = time.time()
+        print('starting')
+        for _ in range(100):
+            output = mnist_net(data)
+            loss = nn.cross_entropy_loss(output, target)
+            SGD.step(loss)
+            def callback(epoch_id, batch_id, loss, output, target):
+                global prev
+                pred = np.argmax(output, axis=1)
+                acc = np.mean(target==pred)
+                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAcc: {:.6f} \tTime:{:.3f}'
+                    .format(epoch_id, batch_id, 600,1. * batch_id / 6.0, loss[0], acc, time.time()-prev))
+                prev = time.time()
+            jt.fetch(epoch_id, _, loss, output, target, callback)
+        jt.sync_all(True)
+        print(f'all = {time.time() - all_time}')
+
+
+
+if __name__ == "__main__":
+    unittest.main()
Binary file not shown.
@@ -653,7 +653,7 @@ def process_jittor_source(device_type, callback):
         for name in files:
             fname = os.path.join(root, name)
             fname2 = os.path.join(root2, name)
-            if fname.endswith(".h") or fname.endswith(".cc"):
+            if fname.endswith(".h") or fname.endswith(".cc") or fname.endswith(".cu"):
                 with open(fname, 'r', encoding="utf8") as f:
                     src = f.read()
                 src = callback(src, name, {"fname":fname, "fname2":fname2})
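
The callback that process_jittor_source feeds each .h/.cc (and now .cu) file through has the shape below; the rocm backend passes the compiled mod.process here. A hypothetical stub (the single replace is illustrative only, not the real CUDA-to-HIP translation):

def demo_callback(src, name, kargs):
    # kargs == {"fname": <source path>, "fname2": <destination path>}
    return src.replace("cudaStream_t", "hipStream_t")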