Support AMD ROCm and HIP backend

lzhengning 2022-04-07 18:22:23 +08:00
parent 46a03098f9
commit e316f511c3
17 changed files with 787 additions and 28 deletions

View File

@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.3.4.14'
__version__ = '1.3.4.15'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int

View File

@ -588,3 +588,8 @@ except Exception as e:
LOG.w("MKL install failed, msg:", e)
setup_cuda_extern()
# install backend extern library
for mod in jit_utils.backends:
if mod.install_extern():
break
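
Each module registered with jit_utils.add_backend is expected to expose check() and install_extern(), which this loop calls in turn. A minimal sketch of that protocol follows; the module and flag names are illustrative only, not part of the commit:

# minimal backend-module sketch mirroring acl_compiler / rocm_compiler below
has_mybackend = 0   # hypothetical detection flag for an illustrative backend

def check():
    # probe the toolchain; on success set the flag and patch jittor.compiler
    global has_mybackend
    return bool(has_mybackend)

def install_extern():
    # build backend-specific extern libraries; return True if this backend handled it
    return False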

View File

@ -133,7 +133,7 @@ def compile(compiler, flags, inputs, output, combind_build=False, cuda_flags="",
nflags = oflags
cmd = f"{cm(input)} {nflags} {lto_flags} -c -o {cm(obj_file)}"
if input.endswith(".cu"):
if has_cuda:
if has_cuda or has_rocm:
cmd = f"\"{nvcc_path}\" {cuda_flags} {cmd}"
cmd = convert_nvcc_flags(fix_cl_flags(cmd))
else:
@ -143,8 +143,10 @@ def compile(compiler, flags, inputs, output, combind_build=False, cuda_flags="",
cmd = fix_cl_flags(cmd)
if "nan_checker" in input:
# nan checker needs to disable fast_math
cmd = cmd.replace("--use_fast_math", "")
cmd = cmd.replace("-Ofast", "-O2")
if "--use_fast_math" in cmd:
cmd = cmd.replace("--use_fast_math", "")
if "-Ofast" in cmd:
cmd = cmd.replace("-Ofast", "-O2")
cmds.append(cmd)
jit_utils.run_cmds(cmds, cache_path, jittor_path, "Compiling "+base_output)
obj_files += ex_obj_files
@ -232,7 +234,7 @@ def gen_jit_flags():
jit_declares.append(f"DECLARE_FLAG({type}, {name});")
alias = []
if name == "use_cuda":
alias = ["use_device", "use_acl"]
alias = ["use_device", "use_acl", "use_rocm"]
elif name == "auto_mixed_precision_level":
alias = ["amp_level"]
get_names = ",".join(["__get__"+a for a in [name]+alias])
@ -1229,6 +1231,8 @@ if has_cuda:
# from .acl_compiler import check_acl
from .extern.acl import acl_compiler
jit_utils.add_backend(acl_compiler)
from .extern.rocm import rocm_compiler
jit_utils.add_backend(rocm_compiler)
for mod in jit_utils.backends:
if mod.check():
@ -1252,7 +1256,7 @@ pyjt_gen_src = pyjt_compiler.compile(cache_path, jittor_path)
# 3. op_utils
# 4. other
files2 = pyjt_gen_src
ext_args = 'c[cu]' if has_cuda else 'cc'
ext_args = 'c[cu]' if has_cuda or has_rocm else 'cc'
files4 = glob.glob(jittor_path+"/src/**/*."+ext_args, recursive=True)
files4 = [ f[len(jittor_path)+1:] for f in files4 ]
# files4 = run_cmd('find -L src | grep '+grep_args, jittor_path).splitlines()
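
With the use_rocm alias and the has_rocm branch above, enabling the ROCm device from Python mirrors the CUDA flow. A usage sketch, assuming a detected ROCm install (this is essentially what the new tests below do):

import jittor as jt

if jt.compiler.has_rocm:
    jt.flags.use_rocm = 1            # alias of use_cuda, added above
    a = jt.array([1.0, 2.0, 3.0])
    print((a + a).numpy())           # executed through the HIP backend
    jt.flags.use_rocm = 0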

View File

@ -35,6 +35,10 @@ string process_acl(const string& src, const string& name, const map<string,strin
has_acl = 1
def install_extern():
return False
def check():
import jittor.compiler as compiler
global has_acl, cc_flags

BIN
python/jittor/extern/rocm/rocm_cache.gz vendored Normal file

Binary file not shown.

View File

@ -0,0 +1,139 @@
# ***************************************************************
# Copyright (c) 2021 Jittor. All Rights Reserved.
# Maintainers: Zheng-Ning Liu <lzhengning@gmail.com>.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import os
import ctypes
import glob
import jittor_utils
from jittor_utils import env_or_try_find, run_cmd, cache_path, LOG
has_rocm = 0
cc_flags = ""
hipcc_path = env_or_try_find('hipcc_path', 'hipcc')
rocm_home = ""
dlopen_flags = os.RTLD_NOW | os.RTLD_GLOBAL
def install_rocm_jittor_core():
import jittor.compiler as compiler
global has_rocm, cc_flags, rocm_home
rocm_home = run_cmd("hipconfig -R")
rocm_version = run_cmd("hipconfig -v")
rocm_compiler_home = os.path.dirname(__file__)
rocm_cache_path = os.path.join(rocm_compiler_home, "rocm_cache.o")
rocm_cache_gz_path = os.path.join(rocm_compiler_home, "rocm_cache.gz")
if not os.path.exists(rocm_cache_path) and os.path.exists(rocm_cache_gz_path):
import gzip
with gzip.open(rocm_cache_gz_path, "rb") as f:
data = f.read()
with open(rocm_cache_path, "wb") as f:
f.write(data)
cc_files = sorted(glob.glob(rocm_compiler_home + "/**/*.cc", recursive=True))
o_files = sorted(glob.glob(rocm_compiler_home + "/**/*.o", recursive=True))
cc_flags += f" -DHAS_CUDA -DIS_ROCM -I{rocm_compiler_home} "
cc_flags += " " + run_cmd("hipconfig -C") + " "
cc_flags += ' -L"' + os.path.join(rocm_home, "lib") + '" -lamdhip64 '
LOG.i(f"ROCm ({rocm_version}) detected in {rocm_home}")
mod = jittor_utils.compile_module('''
#include "common.h"
namespace jittor {
// @pyjt(process)
string process_rocm(const string& src, const string& name, const map<string,string>& kargs);
}''', compiler.cc_flags + " " + " ".join(cc_files + o_files) + cc_flags)
jittor_utils.process_jittor_source("rocm", mod.process)
# preload hip driver to ensure the correct initialization of hip context
hip_driver = ctypes.CDLL(os.path.join(rocm_home, 'lib', 'libamdhip64.so'), os.RTLD_GLOBAL | os.RTLD_NOW)
r = hip_driver.hipDeviceSynchronize()
has_rocm = 1
def install_hip():
import jittor.compiler as compiler
LOG.vv("setup rocm extern...")
cache_path_cuda = os.path.join(cache_path, "cuda")
cuda_include = os.path.join(compiler.jittor_path, "extern", "cuda", "inc")
compiler.make_cache_dir(cache_path_cuda)
cuda_extern_src = os.path.join(compiler.jittor_path, "extern", "cuda", "src")
cuda_extern_files = [os.path.join(cuda_extern_src, name) for name in os.listdir(cuda_extern_src)]
so_name = os.path.join(cache_path_cuda, "libcuda_extern" + compiler.so)
compiler.compile(compiler.cc_path, compiler.cc_flags+f" -I\"{cuda_include}\" ", cuda_extern_files, so_name)
ctypes.CDLL(so_name, dlopen_flags)
def install_rocm_library(lib_name, cuda_name, link=True):
import jittor.compiler as compiler
import jittor.compile_extern as compile_extern
LOG.vv(f"setup {lib_name}...")
rocmlib_include_path = os.path.join(rocm_home, lib_name.lower(), "include")
jt_cuda_include = os.path.join(compiler.jittor_path, "extern", "cuda", "inc")
jt_culib_include = os.path.join(compiler.jittor_path, "extern", "cuda", cuda_name, "inc")
culib_src_dir = os.path.join(compiler.jittor_path, "extern", "cuda", cuda_name)
culib_src_files = []
for r, _, f in os.walk(culib_src_dir):
for fname in f:
culib_src_files.append(os.path.join(r, fname))
extra_flags = f" -I\"{jt_cuda_include}\" -I\"{jt_culib_include}\" -I\"{rocmlib_include_path}\" "
extra_flags += f" -L\"{os.path.join(cache_path, 'cuda')}\" -llibcuda_extern "
if lib_name == "rccl":
extra_flags += compile_extern.mpi_compile_flags
if link:
rocmlib_lib_path = os.path.join(rocm_home, lib_name.lower(), "lib")
if os.path.exists(os.path.join(rocmlib_lib_path, f"lib{lib_name}.so")):
jittor_utils.LOG.i(f"Found {os.path.join(rocmlib_lib_path, 'lib' + lib_name + '.so')}")
extra_flags += f" -L{rocmlib_lib_path} -l{lib_name} "
rocmlib = compiler.compile_custom_ops(culib_src_files, return_module=True, extra_flags=extra_flags)
setattr(compile_extern, cuda_name, rocmlib)
setattr(compile_extern, cuda_name + "_ops", rocmlib.ops)
def install_extern():
if has_rocm:
install_hip()
install_rocm_library("MIOpen", "cudnn")
install_rocm_library("rocblas", "cublas")
install_rocm_library("rocprim", "cub", link=False)
install_rocm_library("rccl", "nccl")
return True
else:
return False
def convert_nvcc_flags(nvcc_flags):
return nvcc_flags
def check():
import jittor.compiler as compiler
global has_rocm, cc_flags
if hipcc_path:
try:
install_rocm_jittor_core()
except Exception as e:
jittor_utils.LOG.w(f"load ROCm failed, exception: {e}")
has_rocm = 0
compiler.has_rocm = has_rocm
compiler.hipcc_path = hipcc_path
if not has_rocm:
return False
compiler.cc_flags += cc_flags
compiler.nvcc_path = hipcc_path
compiler.nvcc_flags = compiler.cc_flags.replace("-std=c++14", "-std=c++17")
compiler.convert_nvcc_flags = convert_nvcc_flags
return True
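
Because install_rocm_library registers each ROCm library module under its CUDA counterpart's name, downstream code keeps using the CUDA-style handles. A hedged sketch of what is exposed after install_extern() succeeds:

import jittor as jt
from jittor import compile_extern

if jt.compiler.has_rocm:
    # MIOpen is registered as cudnn, rocBLAS as cublas, rocPRIM as cub, RCCL as nccl
    print(compile_extern.cublas)       # rocBLAS-backed module under the cublas name
    print(compile_extern.cublas_ops)   # its generated operators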

python/jittor/extern/rocm/rocm_jittor.h vendored Normal file
View File

@ -0,0 +1,14 @@
// ***************************************************************
// Copyright (c) 2021 Jittor. All Rights Reserved.
// Maintainers: Zheng-Ning Liu <lzhengning@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "common.h"
namespace jittor {
void rocm_jittor_op_compiler(string& filename, string& src, bool is_rocm, string& extra_flags);
}

View File

@ -30,7 +30,7 @@ __device__ inline static long long floatToOrderedInt(double floatVal) {
return (intVal >= 0 ) ? intVal : intVal ^ 0x7FFFFFFFFFFFFFFF;
}
__device__ inline static double orderedIntToFloat(long long intVal) {
return __longlong_as_double((intVal >= 0) ? intVal : intVal ^ 0x7FFFFFFFFFFFFFFF);
return __longlong_as_double((intVal >= 0) ? intVal : intVal ^ 0x7FFFFFFFFFFFFFFF);
}
__global__ inline static void fix_float_kernel(double* x, int num) {
@ -113,4 +113,86 @@ T cuda_atomic_mul(T* a, T b) {
return old_f;
}
template<typename T>
__device__ inline T shared_reduce_add(T a, T b) {
return a + b;
}
template<typename T>
__device__ inline T shared_reduce_mul(T a, T b) {
return a * b;
}
template<typename T>
__device__ inline T shared_reduce_max(T a, T b) {
return a > b ? a : b;
}
template<typename T>
__device__ inline T shared_reduce_min(T a, T b) {
return a < b ? a : b;
}
template<typename T>
__device__ inline T shared_reduce_and(T a, T b) {
return a & b;
}
template<typename T>
__device__ inline T shared_reduce_or(T a, T b) {
return a | b;
}
template<typename T>
__device__ inline T shared_reduce_xor(T a, T b) {
return a ^ b;
}
template<typename T, T(*op)(T, T)>
__device__ inline void warpReduce(volatile T* sdata, int tid) {
if (blockDim.x >= 64)
sdata[tid] = op(sdata[tid], sdata[tid + 32]);
sdata[tid] = op(sdata[tid], sdata[tid + 16]);
sdata[tid] = op(sdata[tid], sdata[tid + 8]);
sdata[tid] = op(sdata[tid], sdata[tid + 4]);
sdata[tid] = op(sdata[tid], sdata[tid + 2]);
sdata[tid] = op(sdata[tid], sdata[tid + 1]);
}
template<typename T, T(*op)(T, T)>
__device__ inline static T shared_reduce(T u) {
__shared__ T sdata[1024];
int tid = threadIdx.x;
sdata[tid] = u;
__syncthreads();
if (blockDim.x >= 1024 && tid < 512) {
sdata[tid] = u = op(u, sdata[tid + 512]);
}
__syncthreads();
if (blockDim.x >= 512 && tid < 256) {
sdata[tid] = u = op(u, sdata[tid + 256]);
}
__syncthreads();
if (blockDim.x >= 256 && tid < 128) {
sdata[tid] = u = op(u, sdata[tid + 128]);
}
__syncthreads();
if (blockDim.x >= 128 && tid < 64) {
sdata[tid] = u = op(u, sdata[tid + 64]);
}
__syncthreads();
if (tid < 32)
warpReduce<T, op>(sdata, tid);
return sdata[0];
}
} // jittor
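
These device helpers back the new SharedReducePass added below; from Python they are exercised transparently by reduce ops compiled for the GPU. A usage sketch, assuming a ROCm device is available:

import numpy as np
import jittor as jt

with jt.flag_scope(use_rocm=1):
    x = jt.random((64, 1024))
    y = x.sum(1)   # reduce kernel; the shared_reduce pass can rewrite its inner loop
    np.testing.assert_allclose(y.numpy(), x.numpy().sum(axis=1), rtol=1e-5, atol=1e-5)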

View File

@ -5,8 +5,24 @@
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#ifdef IS_CUDA
#include <npp.h>
#include <math_constants.h>
#else
#include <limits>
#define NPP_MIN_32U ( 0 )
#define NPP_MAX_32U ( 4294967295U )
#define NPP_MIN_32S (-2147483647 - 1 )
#define NPP_MAX_32S ( 2147483647 )
#define NPP_MIN_64U ( 0 )
#define NPP_MAX_64U ( 18446744073709551615ULL )
#define NPP_MIN_64S (-9223372036854775807LL - 1)
#define NPP_MAX_64S ( 9223372036854775807LL )
#define CUDART_INF_F std::numeric_limits<float>::infinity()
#define CUDART_INF std::numeric_limits<double>::infinity()
#endif
template<class T> __device__ T numeric_min();
template<class T> __device__ T numeric_max();

View File

@ -0,0 +1,17 @@
// ***************************************************************
// Copyright (c) 2022 Jittor. All Rights Reserved.
// Maintainers: Zheng-Ning Liu <lzhengning@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "opt/pass/pass.h"
namespace jittor {
struct SharedReducePass : Pass {
SharedReducePass() : Pass("shared_reduce") {};
void run() override;
};
} // jittor

View File

@ -27,6 +27,7 @@
#include "opt/pass/assume_aligned_pass.h"
#include "opt/pass/parallel_pass.h"
#include "opt/pass/atomic_tuner_pass.h"
#include "opt/pass/shared_reduce_pass.h"
#include "opt/pass/float_atomic_fix_pass.h"
#include "opt/pass/insert_profile_loop_pass.h"
#include "opt/pass/fake_main_pass.h"
@ -111,6 +112,7 @@ void PassManager::run_passes() {
run_pass<AssumeAlignedPass>();
run_pass<ParallelPass>();
run_pass<AtomicTunerPass>();
run_pass<SharedReducePass>();
run_pass<FloatAtomicFixPass>();
run_pass<InsertProfileLoopPass>();

View File

@ -4,6 +4,7 @@
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include <cctype>
#include "utils/str_utils.h"
namespace jittor {
@ -78,12 +79,26 @@ string replace(const string& a, const string& b, const string& c) {
static inline bool isvar(char x) { return isalnum(x) || x == '_' || x == ':'; }
vector<string> token_split(const string& s) {
vector<string> token_split(const string& s, bool exclude_comments) {
vector<string> ss;
if (!s.size()) return ss;
ss.push_back(string()+s[0]);
for (int i=1; i<s.size(); i++) {
if (isvar(s[i]) != isvar(s[i-1]))
ss.push_back("");
for (int i = 0; i < s.size(); i++) {
if (exclude_comments) {
if (s[i] == '/' && s[i+1] == '/') {
i = s.find('\n', i);
if (i == string::npos)
return ss;
}
if (s[i] == '/' && s[i+1] == '*') {
i = s.find("*/", i);
if (i == string::npos)
return ss;
i += 1;
continue;
}
}
if (i && (isvar(s[i]) != isvar(s[i-1])))
ss.push_back("");
ss.back() += s[i];
}
@ -92,7 +107,8 @@ vector<string> token_split(const string& s) {
static void parse_reg(const string& src,
vector<string>& patterns,
vector<int>& arg_id) {
vector<int>& arg_id,
bool match_whitespace=true) {
patterns.clear();
arg_id.clear();
patterns.push_back("");
@ -103,11 +119,12 @@ static void parse_reg(const string& src,
patterns.push_back("");
continue;
}
patterns.back() += src[j];
if (match_whitespace || !isspace(src[j]))
patterns.back() += src[j];
}
}
int token_replace(vector<string>& tokens, int i, const string& src, const string& dst) {
int token_replace(vector<string>& tokens, int i, const string& src, const string& dst, bool match_whitespace) {
ASSERT(src.at(0) != '$' && src.at(src.size()-1) != '$' &&
src.at(src.size()-2) != '$') << "illegal src:" << src;
vector<string> patterns;
@ -115,7 +132,7 @@ int token_replace(vector<string>& tokens, int i, const string& src, const string
vector<string> patterns2;
vector<int> arg_id2;
unordered_map<int, string> args;
parse_reg(src, patterns, arg_id);
parse_reg(src, patterns, arg_id, match_whitespace);
parse_reg(dst, patterns2, arg_id2);
int start_i, start_pos, end_i, end_pos;
@ -123,17 +140,23 @@ int token_replace(vector<string>& tokens, int i, const string& src, const string
int match_i, match_pos;
string c_arg;
auto next = [&tokens](int &c_i, int &c_pos) {
c_pos ++;
if (c_pos >= tokens[c_i].size()) {
c_pos = 0;
c_i ++;
if (c_i >= tokens.size())
return false;
}
};
auto match = [&](int c_i, int c_pos, const string& pat) -> bool {
for (int i=0; i<pat.size(); i++) {
while (!match_whitespace && isspace(tokens[c_i][c_pos]))
next(c_i, c_pos);
if (tokens[c_i][c_pos] != pat[i])
return false;
c_pos ++;
if (c_pos >= tokens[c_i].size()) {
c_pos = 0;
c_i ++;
if (c_i >= tokens.size())
return false;
}
next(c_i, c_pos);
}
match_i = c_i;
match_pos = c_pos;
@ -189,9 +212,9 @@ int token_replace(vector<string>& tokens, int i, const string& src, const string
return end_i;
}
string token_replace(const string& s, const string& src, const string& dst) {
string token_replace(const string& s, const string& src, const string& dst, bool match_whitespace) {
vector<string> ss{s};
token_replace(ss, 0, src, dst);
token_replace(ss, 0, src, dst, match_whitespace);
return join(ss, "");
}

View File

@ -33,11 +33,11 @@ string replace(const string& a, const string& b, const string& c);
string join(const vector<string>& vs, const string& x);
vector<string> token_split(const string& s);
vector<string> token_split(const string& s, bool exclude_comments=false);
int token_replace(vector<string>& tokens, int i, const string& src, const string& dst);
int token_replace(vector<string>& tokens, int i, const string& src, const string& dst, bool match_whitespace=true);
string token_replace(const string& s, const string& src, const string& dst);
string token_replace_all(const string& s, const string& src, const string& dst);
string token_replace_all(const string& s, const string& src, const string& dst);
} // jittor

View File

@ -0,0 +1,453 @@
# ***************************************************************
# Copyright (c) 2021 Jittor. All Rights Reserved.
# Maintainers: Zheng-Ning Liu <lzhengning@gmail.com>.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import os
import random
import math
import time
import numpy as np
import tqdm
import jittor as jt
from jittor import init, Module, nn, Function
from jittor.models import vgg
from jittor.dataset.mnist import MNIST
import jittor.transform as trans
from .test_core import expect_error
from .test_reorder_tuner import simple_parser
from .test_log import find_log_with_re
def test_rocm(use_rocm=1):
@unittest.skipIf(not jt.compiler.has_rocm, "No ROCm found")
class TestCudaBase(unittest.TestCase):
def setUp(self):
jt.flags.use_rocm = use_rocm
def tearDown(self):
jt.flags.use_rocm = 0
return TestCudaBase
@unittest.skipIf(not jt.compiler.has_rocm, "No ROCm found")
class TestROCm(unittest.TestCase):
@jt.flag_scope(use_rocm=1)
def test_array(self):
a = jt.array([1,2,3])
np.testing.assert_allclose(a.numpy(), [1,2,3])
@jt.flag_scope(use_rocm=1)
def test_add(self):
a = jt.array([1,2,3])
b = a+a
np.testing.assert_allclose(b.numpy(), [2,4,6])
@jt.flag_scope(use_rocm=1)
def test_add_float(self):
a = jt.array([1.0,2.0,3.0])
b = a+a
np.testing.assert_allclose(b.numpy(), [2,4,6])
@jt.flag_scope(use_rocm=1)
def test_array_cast(self):
# this test cannot pass because of a cast error
x = np.random.rand(10)
y = jt.float32(x)
np.testing.assert_allclose(x, y.numpy())
def test_meminfo(self):
jt.display_memory_info()
@jt.flag_scope(use_rocm=1)
def test_cuda_flags(self):
a = jt.random((10, 10))
a.sync()
@jt.flag_scope(use_rocm=1)
def test_rocm_custom_op_from_cuda(self):
my_op = jt.compile_custom_op("""
struct MyCudaOp : Op {
Var* output;
MyCudaOp(NanoVector shape, string dtype="float");
const char* name() const override { return "my_cuda"; }
DECLARE_jit_run;
};
""", """
#ifndef JIT
MyCudaOp::MyCudaOp(NanoVector shape, string dtype) {
flags.set(NodeFlags::_cuda);
output = create_output(shape, dtype);
}
void MyCudaOp::jit_prepare(JK& jk) {
add_jit_define(jk, "T", output->dtype());
}
#else // JIT
#ifdef JIT_cuda
__global__ void kernel(index_t n, T *x) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
for (int i = index; i < n; i += stride)
x[i] = (T)-i;
}
void MyCudaOp::jit_run() {
index_t num = output->num;
auto* __restrict__ x = output->ptr<T>();
int blockSize = 256;
int numBlocks = (num + blockSize - 1) / blockSize;
kernel<<<numBlocks, blockSize>>>(num, x);
}
#endif // JIT_cuda
#endif // JIT
""",
"my_cuda")
a = my_op([3,4,5], 'float')
na = a.data
assert a.shape == [3,4,5] and a.dtype == 'float'
assert (-na.flatten() == range(3*4*5)).all(), na
def test_rocm_fused_op(self):
a = jt.array([1,2,3])
a.sync()
with jt.flag_scope(use_rocm=1):
((a+a)*2).data
class Model(Module):
def __init__(self, input_size):
self.linear1 = nn.Linear(input_size, 10)
self.relu1 = nn.Relu()
self.linear2 = nn.Linear(10, 1)
def execute(self, x):
x = self.linear1(x)
x = self.relu1(x)
return self.linear2(x)
@unittest.skipIf(not jt.compiler.has_rocm, "No ROCm found")
class TestExample(unittest.TestCase):
@jt.flag_scope(use_rocm=1)
def test1(self):
np.random.seed(0)
jt.set_seed(3)
n = 1000
batch_size = 50
lr = 0.05
def get_data(n):
for i in range(n):
x = np.random.rand(batch_size, 1).astype("float32")
y = x*x
yield jt.float32(x), jt.float32(y)
model = Model(input_size=1)
ps = model.parameters()
for i,(x,y) in enumerate(get_data(n)):
jt.sync_all(True)
pred_y = model(x).name("pred_y")
loss = ((pred_y - y).sqr()).name("loss")
loss_mean = loss.mean()
gs = jt.grad(loss_mean, ps)
for p, g in zip(ps, gs):
p -= g * lr
if i>2:
assert prev == jt.liveness_info(), f"memory leak {prev} {jt.liveness_info()}"
prev = jt.liveness_info()
possible_results = [
0.0009948202641680837,
0.001381353591568768,
0.00110957445576787,
0.001124994712881744
]
loss_mean = loss_mean.data
assert any(abs(loss_mean - r) < 1e-6 for r in possible_results)
jt.clean()
from .test_unary_op import TestUnaryOp
@unittest.skipIf(not jt.compiler.has_rocm, "No ROCm found")
class TestROCmUnaryOp(TestUnaryOp, test_rocm(1)):
pass
from .test_binary_op import TestBinaryOp
@unittest.skipIf(not jt.compiler.has_rocm, "No ROCm found")
class TestROCmBinaryOp(TestBinaryOp, test_rocm(1)):
pass
from .test_reduce_op import TestReduceOp
@unittest.skipIf(not jt.compiler.has_rocm, "No ROCm found")
class TestROCmReduceOp(TestReduceOp, test_rocm(1)):
pass
from .test_reindex_op import TestReindexOp
@unittest.skipIf(not jt.compiler.has_rocm, "No ROCm found")
class TestROCmReindexOp(TestReindexOp, test_rocm(1)):
pass
# from .test_reindex_reduce_op import TestReindexReduceOp
# @unittest.skipIf(not jt.compiler.has_rocm, "No ROCm found")
# class TestROCmReindexReduceOp(TestReindexReduceOp, test_rocm(1)):
# pass
@unittest.skipIf(not jt.compiler.has_rocm, "No ROCm found")
class TestROCmCodeOp(unittest.TestCase):
@jt.flag_scope(use_rocm=1)
def test_cuda(self):
a = jt.random([100000])
b = jt.random([100000])
c = jt.code(a.shape, a.dtype, [a,b],
cuda_src='''
__global__ static void kernel1(@ARGS_DEF) {
@PRECALC
int i = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x;
for (; i<in0_shape0; i+=stride)
@out(i) = @in0(i)*@in1(i);
}
kernel1<<<(in0_shape0-1)/1024+1, 1024>>>(@ARGS);
''',
cuda_grad_src = ['''
__global__ static void kernel2(@ARGS_DEF) {
@PRECALC
int i = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x;
for (; i<in0_shape0; i+=stride)
@out(i) = @dout(i)*@in1(i);
}
kernel2<<<(in0_shape0-1)/1024+1, 1024>>>(@ARGS);
''', '''
__global__ static void kernel3(@ARGS_DEF) {
@PRECALC
int i = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x;
for (; i<in0_shape0; i+=stride)
@out(i) = @dout(i)*@in0(i);
}
kernel3<<<(in0_shape0-1)/1024+1, 1024>>>(@ARGS);
'''])
da, db = jt.grad(c, [a, b])
assert np.allclose(c.data, a.data*b.data), (c.data, a.data*b.data)
assert np.allclose(da.data, b.data)
assert np.allclose(db.data, a.data)
@jt.flag_scope(use_rocm=1)
def test_cuda2(self):
a = jt.random((100,100))
b = jt.random((100,100))
c = jt.code(a.shape, a.dtype, [a,b],
cuda_src='''
__global__ static void kernel1(@ARGS_DEF) {
@PRECALC
for (int i=blockIdx.x; i<in0_shape0; i+=gridDim.x)
for (int j=threadIdx.x; j<in0_shape1; j+=blockDim.x)
@out(i,j) = @in0(i,j)*@in1(i,j);
}
kernel1<<<32, 32>>>(@ARGS);
''',
cuda_grad_src = ['''
__global__ static void kernel(@ARGS_DEF) {
@PRECALC
for (int i=blockIdx.x; i<in0_shape0; i+=gridDim.x)
for (int j=threadIdx.x; j<in0_shape1; j+=blockDim.x)
@out(i,j) = @dout(i,j)*@in1(i,j);
}
kernel<<<32, 32>>>(@ARGS);
''', '''
__global__ static void kernel(@ARGS_DEF) {
@PRECALC
@pout(0,0);
for (int i=blockIdx.x; i<in0_shape0; i+=gridDim.x)
for (int j=threadIdx.x; j<in0_shape1; j+=blockDim.x)
@out(i,j) = @dout(i,j)*@in0(i,j);
}
kernel<<<32, 32>>>(@ARGS);
'''])
da, db = jt.grad(c, [a, b])
assert np.allclose(c.data, a.data*b.data), (c.data, a.data*b.data)
assert np.allclose(da.data, b.data)
assert np.allclose(db.data, a.data)
@jt.flag_scope(use_rocm=1)
def test_cuda2_use_func(self):
class Func(Function):
def execute(self, a, b):
self.save_vars = a, b
return jt.code(a.shape, a.dtype, [a,b],
cuda_src='''
__global__ static void kernel1(@ARGS_DEF) {
@PRECALC
for (int i=blockIdx.x; i<in0_shape0; i+=gridDim.x)
for (int j=threadIdx.x; j<in0_shape1; j+=blockDim.x)
@out(i,j) = @in0(i,j)*@in1(i,j);
}
kernel1<<<32, 32>>>(@ARGS);
''')
def grad(self, grad):
a, b = self.save_vars
return jt.code([a.shape, b.shape], [a.dtype, b.dtype], [a, b, grad],
cuda_src='''
__global__ static void kernel2(@ARGS_DEF) {
@PRECALC
for (int i=blockIdx.x; i<in0_shape0; i+=gridDim.x)
for (int j=threadIdx.x; j<in0_shape1; j+=blockDim.x) {
@out0(i,j) = @in2(i,j)*@in1(i,j);
@out1(i,j) = @in2(i,j)*@in0(i,j);
}
}
kernel2<<<32, 32>>>(@ARGS);
''')
a = jt.random((100,100))
b = jt.random((100,100))
func = Func()
c = func(a,b)
da, db = jt.grad(c, [a, b])
assert np.allclose(c.data, a.data*b.data), (c.data, a.data*b.data)
assert np.allclose(da.data, b.data)
assert np.allclose(db.data, a.data)
@unittest.skipIf(not jt.compiler.has_rocm, "No ROCm found")
class TestBMM(unittest.TestCase):
def test_bmm_rocm(self):
def check(batch, n, m, k):
def calc(use_rocm, a, b, mask):
jt.flags.use_rocm = use_rocm
a = jt.array(a)
b = jt.array(b)
mask = jt.array(mask)
c = nn.bmm(a, b)
da, db = jt.grad(c*mask, [a, b])
return c.data, da.data, db.data
mask = np.random.rand(batch, n, k).astype("float32")
a = np.random.rand(batch, n, m).astype("float32")
b = np.random.rand(batch, m, k).astype("float32")
a1,a2,a3 = calc(0, a, b, mask)
b1,b2,b3 = calc(1, a, b, mask)
assert np.allclose(a1, b1)
assert np.allclose(a2, b2)
assert np.allclose(a3, b3)
check(10,3,4,5)
check(10,8,8,8)
check(10,8,1,8)
check(10,8,8,1)
check(10,1,8,8)
check(1,7,8,8)
class Model(Module):
def __init__(self, input_size):
self.linear1 = nn.Linear(input_size, 10)
self.relu1 = nn.Relu()
self.linear2 = nn.Linear(10, 1)
def execute(self, x):
x = self.linear1(x)
x = self.relu1(x)
return self.linear2(x)
from jittor.models import resnet
class MnistNet(Module):
def __init__(self):
self.model = resnet.Resnet18()
self.layer = nn.Linear(1000,10)
def execute(self, x):
x = self.model(x)
x = self.layer(x)
return x
@unittest.skipIf(not jt.compiler.has_rocm, "skip_this_test")
class TestResnetFp32(unittest.TestCase):
# setup random seed
def setup_seed(self, seed):
np.random.seed(seed)
random.seed(seed)
jt.seed(seed)
@jt.flag_scope(use_cuda=1)
def test_resnet(self):
self.setup_seed(1)
# hyper-parameters
self.batch_size = int(os.environ.get("TEST_BATCH_SIZE", "100"))
self.weight_decay = 0.0001
self.momentum = 0.9
self.learning_rate = 0.1
if jt.flags.amp_reg:
self.learning_rate = 0.01
# mnist dataset
self.train_loader = MNIST(train=True, transform=trans.Resize(224)) \
.set_attrs(batch_size=self.batch_size, shuffle=True)
self.train_loader.num_workers = 4
loss_list=[]
acc_list=[]
mnist_net = MnistNet()
global prev
SGD = nn.SGD(mnist_net.parameters(), self.learning_rate, self.momentum, self.weight_decay)
self.train_loader.endless = True
for data, target in self.train_loader:
batch_id = self.train_loader.batch_id
epoch_id = self.train_loader.epoch_id
data = data.float_auto()
output = mnist_net(data)
loss = nn.cross_entropy_loss(output, target)
break
jt.sync_all(True)
for _ in range(10):
output = mnist_net(data)
loss = nn.cross_entropy_loss(output, target)
SGD.step(loss)
def callback(epoch_id, batch_id, loss, output, target):
pred = np.argmax(output, axis=1)
acc = np.mean(target==pred)
jt.fetch(epoch_id, _, loss, output, target, callback)
jt.sync_all(True)
all_time = time.time()
prev = time.time()
print('starting')
for _ in range(100):
output = mnist_net(data)
loss = nn.cross_entropy_loss(output, target)
SGD.step(loss)
def callback(epoch_id, batch_id, loss, output, target):
global prev
pred = np.argmax(output, axis=1)
acc = np.mean(target==pred)
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAcc: {:.6f} \tTime:{:.3f}'
.format(epoch_id, batch_id, 600,1. * batch_id / 6.0, loss[0], acc, time.time()-prev))
prev = time.time()
jt.fetch(epoch_id, _, loss, output, target, callback)
jt.sync_all(True)
print(f'all = {time.time() - all_time}')
if __name__ == "__main__":
unittest.main()

Binary file not shown.

View File

@ -653,7 +653,7 @@ def process_jittor_source(device_type, callback):
for name in files:
fname = os.path.join(root, name)
fname2 = os.path.join(root2, name)
if fname.endswith(".h") or fname.endswith(".cc"):
if fname.endswith(".h") or fname.endswith(".cc") or fname.endswith(".cu"):
with open(fname, 'r', encoding="utf8") as f:
src = f.read()
src = callback(src, name, {"fname":fname, "fname2":fname2})
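
With .cu files now included, the backend callback (mod.process, e.g. process_rocm) also rewrites CUDA sources. On the Python side the callback is simply a callable of (src, name, kargs) returning the new source; a pass-through sketch of that contract:

# a no-op callback with the signature process_jittor_source expects;
# the real ROCm backend substitutes HIP-compatible source here
def passthrough(src, name, kargs):
    # src: file contents, name: file name, kargs: {"fname": ..., "fname2": ...}
    return src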