add acl test for linear

This commit is contained in:
Dun Liang 2022-03-31 16:59:49 +08:00
parent 52127befec
commit baf9a91e0c
7 changed files with 132 additions and 18 deletions

View File

@ -161,24 +161,33 @@ string process_acl(const string& src, const string& name, const map<string,strin
return join(tokens, "");
}
// Rewrites the source of a JIT-compiled op so it can be built for ACL.
// Does nothing unless `is_acl` is true. Otherwise it:
//   - appends the tik SoC version option to `extra_flags`,
//   - swaps ".cc" for ".tikcc" in `filename`,
//   - passes `src` through process_acl(), then strips/rewrites CUDA-specific
//     constructs (includes, __global__, launch bounds, launch configuration,
//     thread index/count variables, atomicAdd, ::max),
//   - stores the rewritten text back into `src`.
// All four reference parameters except `is_acl` are modified in place.
// NOTE(review): this span appears to interleave the pre-change and post-change
// lines of a diff (two signatures; the same patterns via token_replace and
// token_replace_all) -- confirm which lines belong to the current version.
void acl_jittor_op_compiler(string& filename, string& src, bool is_acl) {
void acl_jittor_op_compiler(string& filename, string& src, bool is_acl, string& extra_flags) {
if (!is_acl) return;
// extra compiler option selecting the Ascend910 SoC
extra_flags += " --tik-soc-version=Ascend910 ";
filename = replace(filename, ".cc", ".tikcc");
// LOGir << filename;
string new_src = process_acl(src, "", {});
// drop the CUDA helper headers and rename the kernel entry qualifier
new_src = replace(new_src, R"(#include "misc/cuda_atomic.h")", "");
new_src = replace(new_src, R"(#include "misc/cuda_limits.h")", "");
new_src = replace(new_src, "__global__", "__ai_device_entry__");
// remove launch bounds and pin thread count, launch configuration and thread id to constants
new_src = token_replace(new_src, "__launch_bounds__($1)", "");
new_src = token_replace(new_src, "int thread_num = $1;", "int thread_num = 1;");
new_src = token_replace(new_src, "tn0=std::max(tn0, $1);", "");
new_src = token_replace(new_src, "<<<$1,$2>>>", "<<<1,0>>>");
new_src = token_replace(new_src, "int thread_id = $1;", "int thread_id = 1;");
new_src = token_replace_all(new_src, "__launch_bounds__($1)", "");
new_src = token_replace_all(new_src, "int thread_num = $1;", "int thread_num = 1;");
new_src = token_replace_all(new_src, "tn0=std::max(tn0, $1);", "");
new_src = token_replace_all(new_src, "<<<$1>>>", "<<<1,0>>>");
new_src = token_replace_all(new_src, "int thread_id = $1;", "int thread_id = 1;");
// for inc error: turn every `+=` loop step into a `++` step
new_src = token_replace(new_src, "for ($1+=$2)", "for ($1++)");
new_src = token_replace_all(new_src, "for ($1+=$2)", "for ($1++)");
// bit op error: delete the tnum/p1/p2 declarations and zero the tn*/tid* initialisers
new_src = token_replace(new_src, "int tnum$1;", "");
new_src = token_replace(new_src, "int tid$1=$2;", "int tid$1=0;");
new_src = token_replace_all(new_src, "int tnum$1;", "");
new_src = token_replace_all(new_src, "int p1$1;", "");
new_src = token_replace_all(new_src, "int p2$1;", "");
new_src = token_replace_all(new_src, "int tn$1=$2;", "int tn$1=0;");
new_src = token_replace_all(new_src, "int tid$1=$2;", "int tid$1=0;");
src = new_src;
// atomicAdd becomes a plain read-modify-write; ::max becomes a ternary expression
new_src = token_replace_all(new_src, "atomicAdd(&$1,$2);", "$1=$1+$2;");
new_src = token_replace_all(new_src, "::max($1,$2);", "($1)>($2)?($1):($2);");
// new_src = replace(new_src, "::max", "fmax");
// hand the fully rewritten text back to the caller
src = new_src;
// auto tokens = token_split(new_src);
}

View File

@ -14,6 +14,6 @@ namespace jittor {
EXTERN_LIB uint64_t acl_jittor_tid;
void acl_jittor_op_compiler(string& filename, string& src, bool is_acl);
void acl_jittor_op_compiler(string& filename, string& src, bool is_acl, string& extra_flags);
}

View File

@ -206,7 +206,8 @@ jit_op_entry_t compile(const string& jit_key, const string& src, const bool is_c
CHECK(cc_path.size());
string jit_src_path = Op::get_filename_from_jit_key(jit_key, ".cc");
string* src2 = (string*)&src;
JPU(op_compiler(jit_src_path, *src2, is_cuda_op));
string* extra_flags2 = (string*)&extra_flags;
JPU(op_compiler(jit_src_path, *src2, is_cuda_op, *extra_flags2));
#ifdef _WIN32
string jit_lib_path = Op::get_filename_from_jit_key(jit_key, ".dll");
string jit_src_path2 = _to_winstr(jit_src_path);

View File

@ -107,7 +107,7 @@ static void parse_reg(const string& src,
}
}
void token_replace(vector<string>& tokens, int i, const string& src, const string& dst) {
int token_replace(vector<string>& tokens, int i, const string& src, const string& dst) {
ASSERT(src.at(0) != '$' && src.at(src.size()-1) != '$' &&
src.at(src.size()-2) != '$') << "illegal src:" << src;
vector<string> patterns;
@ -186,6 +186,7 @@ void token_replace(vector<string>& tokens, int i, const string& src, const strin
for (int j=start_i+1; j<end_i; j++)
tokens[j] = "";
}
return end_i;
}
string token_replace(const string& s, const string& src, const string& dst) {
@ -194,4 +195,18 @@ string token_replace(const string& s, const string& src, const string& dst) {
return join(ss, "");
}
// Applies the `src` -> `dst` token-pattern rewrite repeatedly over `s` and
// returns the re-joined text.
//
// `s` is split into tokens once. Each round calls the index-based
// token_replace() starting at `cursor`; the scan then resumes one token past
// the index that call returns. The scan stops when the cursor reaches the end
// of the token list or when token_replace() throws a std::exception, and
// whatever has been rewritten up to that point is returned.
// NOTE(review): the exception is presumably how token_replace() reports
// "no further match" -- confirm, because any other failure it raises is
// swallowed here as well.
string token_replace_all(const string& s, const string& src, const string& dst) {
    auto tokens = token_split(s);
    for (int cursor = 0; static_cast<size_t>(cursor) < tokens.size();) {
        try {
            cursor = token_replace(tokens, cursor, src, dst) + 1;
        } catch (const std::exception&) {
            // stop scanning; keep the replacements made so far
            break;
        }
    }
    return join(tokens, "");
}
} // jittor

View File

@ -35,8 +35,9 @@ string join(const vector<string>& vs, const string& x);
vector<string> token_split(const string& s);
void token_replace(vector<string>& tokens, int i, const string& src, const string& dst);
int token_replace(vector<string>& tokens, int i, const string& src, const string& dst);
string token_replace(const string& s, const string& src, const string& dst);
string token_replace_all(const string& s, const string& src, const string& dst);
} // jittor

View File

@ -8,6 +8,9 @@ import unittest
import jittor as jt
from .test_core import expect_error
import numpy as np
from jittor import init, Module
import numpy as np
@unittest.skipIf(not jt.compiler.has_acl, "No ACL found")
class TestACL(unittest.TestCase):
@ -24,8 +27,96 @@ class TestACL(unittest.TestCase):
b = a+a
np.testing.assert_allclose(b.numpy(), [2,4,6])
@jt.flag_scope(use_acl=1)
def test_add_float(self):
    # With use_acl=1, adding a float array to itself must double every element.
    values = jt.array([1.0, 2.0, 3.0])
    doubled = values + values
    expected = [2, 4, 6]
    np.testing.assert_allclose(doubled.numpy(), expected)
@jt.flag_scope(use_acl=1)
def test_array_cast(self):
    # Known failure recorded by the original author: this test cannot pass
    # because of a cast error.
    # Round-trips 10 random numpy values through jt.float32 and compares them
    # with the originals.
    host_values = np.random.rand(10)
    converted = jt.float32(host_values)
    np.testing.assert_allclose(host_values, converted.numpy())
# Smoke test: only calls jt.display_memory_info(); it makes no assertions.
def test_meminfo(self):
jt.display_memory_info()
def matmul(a, b):
    """Multiply `a` by `b` using broadcast, element-wise product and a sum.

    `a.shape` must unpack into exactly two sizes (n, m); k is the last
    dimension of `b`. Both operands are broadcast to [n, m, k] and the
    product is summed over dimension 1, giving an (n, k) result.
    """
    rows, inner = a.shape
    cols = b.shape[-1]
    full_shape = [rows, inner, cols]
    lhs = a.broadcast(full_shape, dims=[2])
    rhs = b.broadcast(full_shape, dims=[0])
    product = lhs * rhs
    return product.sum(dim=1)
class Linear(Module):
    """Linear layer built on the local `matmul` helper.

    `w` is drawn with jt.random, shifted by -0.5 and divided by
    sqrt(in_features). `b` is drawn with jt.random and shifted by -0.5,
    or is None when `bias` is false.
    """

    def __init__(self, in_features, out_features, bias=True):
        scale = in_features ** 0.5
        # The weight is drawn before the bias; keep this order so seeded runs
        # consume random numbers in the same sequence.
        self.w = (jt.random((in_features, out_features)) - 0.5) / scale
        if bias:
            self.b = jt.random((out_features,)) - 0.5
        else:
            self.b = None

    def execute(self, x):
        out = matmul(x, self.w)
        return out if self.b is None else out + self.b
def relu(x):
    """Return the element-wise maximum of `x` and 0.0."""
    floor = 0.0
    return jt.maximum(x, floor)


# Module wrapper around `relu`, created by jt.make_module.
Relu = jt.make_module(relu)
class Model(Module):
    """Two-layer network: Linear(input_size, 10) -> Relu -> Linear(10, 1)."""

    def __init__(self, input_size):
        hidden_size = 10
        # Sub-modules are created in forward order; keep it, because each
        # Linear draws random initial values when constructed.
        self.linear1 = Linear(input_size, hidden_size)
        self.relu1 = Relu()
        self.linear2 = Linear(hidden_size, 1)

    def execute(self, x):
        return self.linear2(self.relu1(self.linear1(x)))
# End-to-end training test for the small Model defined above, run with use_acl=1.
# The whole class is skipped when jt.compiler.has_acl is false.
@unittest.skipIf(not jt.compiler.has_acl, "No ACL found")
class TestExample(unittest.TestCase):
@jt.flag_scope(use_acl=1)
def test1(self):
# fixed seeds for numpy and jittor so the final loss is reproducible
np.random.seed(0)
jt.set_seed(3)
n = 1000
batch_size = 50
lr = 0.05
# yields n batches: x is a (batch_size, 1) float32 array from np.random.rand, target y = x*x
def get_data(n):
for i in range(n):
x = np.random.rand(batch_size, 1).astype("float32")
y = x*x
yield jt.float32(x), jt.float32(y)
model = Model(input_size=1)
ps = model.parameters()
for i,(x,y) in enumerate(get_data(n)):
jt.sync_all(True)
pred_y = model(x).name("pred_y")
# squared error per sample, then its mean
loss = ((pred_y - y).sqr()).name("loss")
loss_mean = loss.mean()
gs = jt.grad(loss_mean, ps)
# plain gradient-descent step applied in place to each parameter
for p, g in zip(ps, gs):
p -= g * lr
# after the first iterations, jt.liveness_info() must equal the value recorded
# at the end of the previous iteration; a difference is reported as a memory leak
if i>2:
assert prev == jt.liveness_info(), f"memory leak {prev} {jt.liveness_info()}"
prev = jt.liveness_info()
print(f"step {i}, loss = {loss_mean.data.sum()} {jt.liveness_info()}")
# accepted final losses; the last loss must be within 1e-6 of one of them
# NOTE(review): presumably the values differ by platform/backend -- confirm
possible_results = [
0.0009948202641680837,
0.001381353591568768,
0.00110957445576787,
]
loss_mean = loss_mean.data
assert any(abs(loss_mean - r) < 1e-6 for r in possible_results)
jt.clean()
if __name__ == "__main__":
unittest.main()

View File

@ -8,7 +8,6 @@ import unittest
import jittor as jt
from jittor import init, Module
import numpy as np
f32 = jt.float32
def matmul(a, b):
(n, m), k = a.shape, b.shape[-1]
@ -46,9 +45,7 @@ class TestExample(unittest.TestCase):
jt.set_seed(3)
n = 1000
batch_size = 50
base_lr = 0.05
# we need to stop grad of global value to prevent memory leak
lr = f32(base_lr).name("lr").stop_grad()
lr = 0.05
def get_data(n):
for i in range(n):
@ -61,7 +58,7 @@ class TestExample(unittest.TestCase):
for i,(x,y) in enumerate(get_data(n)):
pred_y = model(x).name("pred_y")
loss = ((pred_y - y)**f32(2)).name("loss")
loss = ((pred_y - y).sqr()).name("loss")
loss_mean = loss.mean()
gs = jt.grad(loss_mean, ps)