mirror of https://github.com/Jittor/Jittor
add acl test for linear
This commit is contained in:
parent
52127befec
commit
baf9a91e0c
|
@ -161,24 +161,33 @@ string process_acl(const string& src, const string& name, const map<string,strin
|
|||
return join(tokens, "");
|
||||
}
|
||||
|
||||
void acl_jittor_op_compiler(string& filename, string& src, bool is_acl) {
|
||||
void acl_jittor_op_compiler(string& filename, string& src, bool is_acl, string& extra_flags) {
|
||||
if (!is_acl) return;
|
||||
extra_flags += " --tik-soc-version=Ascend910 ";
|
||||
filename = replace(filename, ".cc", ".tikcc");
|
||||
// LOGir << filename;
|
||||
string new_src = process_acl(src, "", {});
|
||||
new_src = replace(new_src, R"(#include "misc/cuda_atomic.h")", "");
|
||||
new_src = replace(new_src, R"(#include "misc/cuda_limits.h")", "");
|
||||
new_src = replace(new_src, "__global__", "__ai_device_entry__");
|
||||
new_src = token_replace(new_src, "__launch_bounds__($1)", "");
|
||||
new_src = token_replace(new_src, "int thread_num = $1;", "int thread_num = 1;");
|
||||
new_src = token_replace(new_src, "tn0=std::max(tn0, $1);", "");
|
||||
new_src = token_replace(new_src, "<<<$1,$2>>>", "<<<1,0>>>");
|
||||
new_src = token_replace(new_src, "int thread_id = $1;", "int thread_id = 1;");
|
||||
new_src = token_replace_all(new_src, "__launch_bounds__($1)", "");
|
||||
new_src = token_replace_all(new_src, "int thread_num = $1;", "int thread_num = 1;");
|
||||
new_src = token_replace_all(new_src, "tn0=std::max(tn0, $1);", "");
|
||||
new_src = token_replace_all(new_src, "<<<$1>>>", "<<<1,0>>>");
|
||||
new_src = token_replace_all(new_src, "int thread_id = $1;", "int thread_id = 1;");
|
||||
// for inc error
|
||||
new_src = token_replace(new_src, "for ($1+=$2)", "for ($1++)");
|
||||
new_src = token_replace_all(new_src, "for ($1+=$2)", "for ($1++)");
|
||||
// bit op error
|
||||
new_src = token_replace(new_src, "int tnum$1;", "");
|
||||
new_src = token_replace(new_src, "int tid$1=$2;", "int tid$1=0;");
|
||||
new_src = token_replace_all(new_src, "int tnum$1;", "");
|
||||
new_src = token_replace_all(new_src, "int p1$1;", "");
|
||||
new_src = token_replace_all(new_src, "int p2$1;", "");
|
||||
new_src = token_replace_all(new_src, "int tn$1=$2;", "int tn$1=0;");
|
||||
new_src = token_replace_all(new_src, "int tid$1=$2;", "int tid$1=0;");
|
||||
src = new_src;
|
||||
|
||||
new_src = token_replace_all(new_src, "atomicAdd(&$1,$2);", "$1=$1+$2;");
|
||||
new_src = token_replace_all(new_src, "::max($1,$2);", "($1)>($2)?($1):($2);");
|
||||
// new_src = replace(new_src, "::max", "fmax");
|
||||
src = new_src;
|
||||
// auto tokens = token_split(new_src);
|
||||
}
|
||||
|
|
|
@ -14,6 +14,6 @@ namespace jittor {
|
|||
|
||||
EXTERN_LIB uint64_t acl_jittor_tid;
|
||||
|
||||
void acl_jittor_op_compiler(string& filename, string& src, bool is_acl);
|
||||
void acl_jittor_op_compiler(string& filename, string& src, bool is_acl, string& extra_flags);
|
||||
|
||||
}
|
||||
|
|
|
@ -206,7 +206,8 @@ jit_op_entry_t compile(const string& jit_key, const string& src, const bool is_c
|
|||
CHECK(cc_path.size());
|
||||
string jit_src_path = Op::get_filename_from_jit_key(jit_key, ".cc");
|
||||
string* src2 = (string*)&src;
|
||||
JPU(op_compiler(jit_src_path, *src2, is_cuda_op));
|
||||
string* extra_flags2 = (string*)&extra_flags;
|
||||
JPU(op_compiler(jit_src_path, *src2, is_cuda_op, *extra_flags2));
|
||||
#ifdef _WIN32
|
||||
string jit_lib_path = Op::get_filename_from_jit_key(jit_key, ".dll");
|
||||
string jit_src_path2 = _to_winstr(jit_src_path);
|
||||
|
|
|
@ -107,7 +107,7 @@ static void parse_reg(const string& src,
|
|||
}
|
||||
}
|
||||
|
||||
void token_replace(vector<string>& tokens, int i, const string& src, const string& dst) {
|
||||
int token_replace(vector<string>& tokens, int i, const string& src, const string& dst) {
|
||||
ASSERT(src.at(0) != '$' && src.at(src.size()-1) != '$' &&
|
||||
src.at(src.size()-2) != '$') << "illegal src:" << src;
|
||||
vector<string> patterns;
|
||||
|
@ -186,6 +186,7 @@ void token_replace(vector<string>& tokens, int i, const string& src, const strin
|
|||
for (int j=start_i+1; j<end_i; j++)
|
||||
tokens[j] = "";
|
||||
}
|
||||
return end_i;
|
||||
}
|
||||
|
||||
string token_replace(const string& s, const string& src, const string& dst) {
|
||||
|
@ -194,4 +195,18 @@ string token_replace(const string& s, const string& src, const string& dst) {
|
|||
return join(ss, "");
|
||||
}
|
||||
|
||||
/**
 * Apply the token pattern `src` -> `dst` repeatedly across `s`.
 *
 * `s` is tokenized with token_split, then the vector overload of
 * token_replace is invoked in a loop; each call resumes one token past the
 * index the previous call returned. The scan ends when the cursor leaves the
 * token list or when token_replace throws a std::exception. In both cases
 * the tokens rewritten so far are joined and returned.
 *
 * @param s   source text to rewrite
 * @param src token pattern handed to token_replace
 * @param dst replacement handed to token_replace
 * @return the rewritten text
 */
string token_replace_all(const string& s, const string& src, const string& dst) {
    auto tokens = token_split(s);
    int pos = 0;
    // Explicit sign check and cast replace the implicit int/size_t comparison;
    // a negative cursor ends the scan, exactly as the wrapped value did before.
    while (pos >= 0 && static_cast<size_t>(pos) < tokens.size()) {
        try {
            pos = token_replace(tokens, pos, src, dst) + 1;
        } catch (const std::exception&) {
            // NOTE(review): presumably token_replace throws once no further
            // match exists -- confirm. Any std::exception stops the scan and
            // is not reported to the caller.
            break;
        }
    }
    return join(tokens, "");
}
|
||||
|
||||
} // jittor
|
|
@ -35,8 +35,9 @@ string join(const vector<string>& vs, const string& x);
|
|||
|
||||
vector<string> token_split(const string& s);
|
||||
|
||||
void token_replace(vector<string>& tokens, int i, const string& src, const string& dst);
|
||||
int token_replace(vector<string>& tokens, int i, const string& src, const string& dst);
|
||||
|
||||
string token_replace(const string& s, const string& src, const string& dst);
|
||||
string token_replace_all(const string& s, const string& src, const string& dst);
|
||||
|
||||
} // jittor
|
|
@ -8,6 +8,9 @@ import unittest
|
|||
import jittor as jt
|
||||
from .test_core import expect_error
|
||||
import numpy as np
|
||||
from jittor import init, Module
|
||||
import numpy as np
|
||||
|
||||
|
||||
@unittest.skipIf(not jt.compiler.has_acl, "No ACL found")
|
||||
class TestACL(unittest.TestCase):
|
||||
|
@ -24,8 +27,96 @@ class TestACL(unittest.TestCase):
|
|||
b = a+a
|
||||
np.testing.assert_allclose(b.numpy(), [2,4,6])
|
||||
|
||||
@jt.flag_scope(use_acl=1)
def test_add_float(self):
    """With use_acl=1, adding a float array to itself yields [2, 4, 6]."""
    values = jt.array([1.0,2.0,3.0])
    doubled = values + values
    np.testing.assert_allclose(doubled.numpy(), [2,4,6])
|
||||
|
||||
@jt.flag_scope(use_acl=1)
def test_array_cast(self):
    """With use_acl=1, convert a random numpy array via jt.float32 and compare it back."""
    # Known issue noted by the original author: this test cannot pass
    # because of a cast error.
    host_values = np.random.rand(10)
    device_values = jt.float32(host_values)
    np.testing.assert_allclose(host_values, device_values.numpy())
|
||||
|
||||
def test_meminfo(self):
    """Smoke test: jt.display_memory_info() runs without raising."""
    jt.display_memory_info()
|
||||
|
||||
|
||||
def matmul(a, b):
    """Matrix product expressed as broadcast, multiply and sum.

    `a` must have a two-element shape (n, m); k is taken from the last
    dimension of `b`. Both operands are broadcast to [n, m, k] (a along a
    new dim 2, b along a new dim 0), multiplied element-wise, and reduced
    over dim 1.
    """
    n, m = a.shape
    k = b.shape[-1]
    lhs = a.broadcast([n, m, k], dims=[2])
    rhs = b.broadcast([n, m, k], dims=[0])
    product = lhs * rhs
    return product.sum(dim=1)
|
||||
|
||||
class Linear(Module):
    """Linear layer that uses the module-level `matmul` helper."""

    def __init__(self, in_features, out_features, bias=True):
        # Weight: jt.random values shifted by -0.5, divided by sqrt(in_features).
        scale = in_features**0.5
        self.w = (jt.random((in_features, out_features))-0.5) / scale
        # Bias: jt.random values shifted by -0.5, or None when disabled.
        if bias:
            self.b = jt.random((out_features,))-0.5
        else:
            self.b = None

    def execute(self, x):
        """Return matmul(x, w), plus the bias when one is present."""
        out = matmul(x, self.w)
        if self.b is None:
            return out
        return out+self.b
|
||||
|
||||
def relu(x):
    """Element-wise jt.maximum of `x` and 0.0."""
    clipped = jt.maximum(x, 0.0)
    return clipped

# Module form of `relu`, produced by jt.make_module.
Relu = jt.make_module(relu)
|
||||
|
||||
class Model(Module):
    """Linear(input_size, 10) -> Relu -> Linear(10, 1)."""

    def __init__(self, input_size):
        hidden_size = 10
        self.linear1 = Linear(input_size, hidden_size)
        self.relu1 = Relu()
        self.linear2 = Linear(hidden_size, 1)

    def execute(self, x):
        """Run `x` through linear1, relu1 and linear2 in that order."""
        hidden = self.linear1(x)
        activated = self.relu1(hidden)
        return self.linear2(activated)
|
||||
|
||||
@unittest.skipIf(not jt.compiler.has_acl, "No ACL found")
class TestExample(unittest.TestCase):
    """Fits the two-layer Model to y = x*x with use_acl=1."""

    @jt.flag_scope(use_acl=1)
    def test1(self):
        """Train for 1000 steps, checking liveness info and the final loss.

        From the fourth step on, jt.liveness_info() must equal the value
        recorded at the end of the previous step. The final mean loss must
        be within 1e-6 of one of the recorded reference values.
        """
        np.random.seed(0)
        jt.set_seed(3)
        n = 1000
        batch_size = 50
        lr = 0.05

        def get_data(n):
            # Yield n (x, x*x) batches of shape (batch_size, 1) as jt.float32.
            for _ in range(n):
                x = np.random.rand(batch_size, 1).astype("float32")
                y = x*x
                yield jt.float32(x), jt.float32(y)

        model = Model(input_size=1)
        params = model.parameters()

        for step, (x, y) in enumerate(get_data(n)):
            jt.sync_all(True)
            pred_y = model(x).name("pred_y")
            loss = ((pred_y - y).sqr()).name("loss")
            loss_mean = loss.mean()

            grads = jt.grad(loss_mean, params)
            # Plain SGD update on each parameter.
            for param, grad in zip(params, grads):
                param -= grad * lr

            # `prev` is first assigned at the end of step 0, so it is
            # always defined by the time this check starts running.
            if step > 2:
                assert prev == jt.liveness_info(), f"memory leak {prev} {jt.liveness_info()}"
            prev = jt.liveness_info()
            print(f"step {step}, loss = {loss_mean.data.sum()} {jt.liveness_info()}")

        possible_results = [
            0.0009948202641680837,
            0.001381353591568768,
            0.00110957445576787,
        ]
        loss_mean = loss_mean.data
        assert any(abs(loss_mean - r) < 1e-6 for r in possible_results)

        jt.clean()
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
@ -8,7 +8,6 @@ import unittest
|
|||
import jittor as jt
|
||||
from jittor import init, Module
|
||||
import numpy as np
|
||||
f32 = jt.float32
|
||||
|
||||
def matmul(a, b):
|
||||
(n, m), k = a.shape, b.shape[-1]
|
||||
|
@ -46,9 +45,7 @@ class TestExample(unittest.TestCase):
|
|||
jt.set_seed(3)
|
||||
n = 1000
|
||||
batch_size = 50
|
||||
base_lr = 0.05
|
||||
# we need to stop grad of global value to prevent memory leak
|
||||
lr = f32(base_lr).name("lr").stop_grad()
|
||||
lr = 0.05
|
||||
|
||||
def get_data(n):
|
||||
for i in range(n):
|
||||
|
@ -61,7 +58,7 @@ class TestExample(unittest.TestCase):
|
|||
|
||||
for i,(x,y) in enumerate(get_data(n)):
|
||||
pred_y = model(x).name("pred_y")
|
||||
loss = ((pred_y - y)**f32(2)).name("loss")
|
||||
loss = ((pred_y - y).sqr()).name("loss")
|
||||
loss_mean = loss.mean()
|
||||
|
||||
gs = jt.grad(loss_mean, ps)
|
||||
|
|
Loading…
Reference in New Issue