code op support more than 10 args

This commit is contained in:
Dun Liang 2022-10-05 16:43:00 +08:00
parent b2fb32aa52
commit 9a5e7ea6f5
4 changed files with 64 additions and 10 deletions

View File

@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.3.5.16'
__version__ = '1.3.5.17'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int

View File

@ -69,6 +69,16 @@ struct JitKey {
uint data;
explicit Oxhex2(uint data) : data(data) {}
};
// Tag wrapper around a uint: values wrapped in dec1 are written through
// the dec1 overloads of operator<< instead of the default integer ones.
struct dec1 {
    // The wrapped value; read directly by the stream operators.
    uint data;

    // explicit, so a plain uint never converts to dec1 by accident.
    explicit dec1(uint value) : data{value} {}
};
// Tag wrapper around a uint: values wrapped in dec2 are written through
// the dec2 overloads of operator<< instead of the default integer ones.
struct dec2 {
    // The wrapped value; read directly by the stream operators.
    uint data;

    // explicit, so a plain uint never converts to dec2 by accident.
    explicit dec2(uint value) : data{value} {}
};
};
struct __jk_int128 {
@ -173,6 +183,30 @@ inline JK& operator<<(JK& jk, const JK::Oxhex2& h) {
return jk << "0x" << JK::hex2(h.data);
}
// Appends h.data to jk as an unpadded decimal number (no leading zeros).
//
// Values 0..99 produce exactly the same one or two characters as before.
// Values >= 100 used to emit a non-digit character (the tens part was
// turned into a single char, so 100 became ":0"); all digits are now
// written, so the result is always a valid decimal literal.
//
// Returns the result of the final insertion into jk.
inline JK& operator<<(JK& jk, const JK::dec2& h) {
    // 3 chars per byte is enough for the decimal form of any unsigned
    // integer of that width (a 4-byte value needs at most 10 digits).
    char digits[3 * sizeof(h.data)];
    int count = 0;
    auto rest = h.data;
    // Collect digits least-significant first; do/while so 0 yields "0".
    do {
        digits[count++] = static_cast<char>('0' + rest % 10);
        rest /= 10;
    } while (rest != 0);
    // Emit most-significant first, keeping the last digit for the return.
    while (count > 1) jk << digits[--count];
    return jk << digits[0];
}
// Appends exactly one character to jk: the last decimal digit of h.data.
// Higher digits are dropped (h.data % 10).
inline JK& operator<<(JK& jk, const JK::dec1& h) {
    const char digit = static_cast<char>('0' + h.data % 10);
    return jk << digit;
}
// Writes h.data to os as an unpadded decimal number (no leading zeros),
// one character at a time.
//
// Values 0..99 produce exactly the same one or two characters as before.
// Values >= 100 used to emit a non-digit character (the tens part was
// turned into a single char, so 100 became ":0"); all digits are now
// written, so the result is always a valid decimal literal.
//
// Returns os.
inline std::ostream& operator<<(std::ostream& os, const JK::dec2& h) {
    // 3 chars per byte is enough for the decimal form of any unsigned
    // integer of that width (a 4-byte value needs at most 10 digits).
    char digits[3 * sizeof(h.data)];
    int count = 0;
    auto rest = h.data;
    // Collect digits least-significant first; do/while so 0 yields "0".
    do {
        digits[count++] = static_cast<char>('0' + rest % 10);
        rest /= 10;
    } while (rest != 0);
    // Emit most-significant first, keeping the last digit for the return.
    while (count > 1) os << digits[--count];
    return os << digits[0];
}
// Writes exactly one character to os: the last decimal digit of h.data.
// Higher digits are dropped (h.data % 10).
inline std::ostream& operator<<(std::ostream& os, const JK::dec1& h) {
    const char digit = static_cast<char>('0' + h.data % 10);
    return os << digit;
}
inline JK& operator<<(JK& jk, int c) {
if (c<0) {
c = -c;

View File

@ -100,13 +100,13 @@ VarPtr CodeOp::grad(Var* out, Var* dout, Var* v, int v_index) {
// TODO: remove unused deps
// dout -> dout
std::stringstream new_alias;
new_alias << "\n@alias(dout,in" << inputs.size() << ")\n";
new_alias << "\n@alias(dout,in" << JK::dec2(inputs.size()) << ")\n";
inputs.push_back(dout);
// _outputs[i] -> poutj
for (int i=0; i<_outputs.size(); i++) {
new_alias << "\n@alias(pout" << i << ",in" << inputs.size() << ")\n";
new_alias << "\n@alias(pout" << JK::dec2(i) << ",in" << JK::dec2(inputs.size()) << ")\n";
if (_outputs[i] == out)
new_alias << "\n@alias(pout,in" << inputs.size() << ")\n";
new_alias << "\n@alias(pout,in" << JK::dec2(inputs.size()) << ")\n";
inputs.push_back(_outputs[i]);
}
auto alias = new_alias.str();
@ -123,18 +123,18 @@ void CodeOp::jit_prepare(JK& jk) {
// forward: in0 in1 in2 -> out0 out1
// backward: in0 in1 in2 in3(pout0) in4(pout1)
jk << "«IN_SIZE=" << JK::hex(_inputs.size());
jk << "«IN_SIZE:" << JK::dec2(_inputs.size());
for (uint i=0; i<_inputs.size(); i++) {
jk << "«in" << JK::hex(i) << "_dim="
jk << "«in" << JK::dec2(i) << "_dim:"
<< JK::hex1(_inputs[i]->shape.size());
jk << "«in" << JK::hex(i) << "_type:"
jk << "«in" << JK::dec2(i) << "_type:"
<< _inputs[i]->dtype();
}
jk << "«OUT_SIZE=" << JK::hex(_outputs.size());
jk << "«OUT_SIZE:" << JK::dec2(_outputs.size());
for (uint i=0; i<_outputs.size(); i++) {
jk << "«out" << JK::hex(i) << "_dim="
jk << "«out" << JK::dec2(i) << "_dim:"
<< JK::hex1(_outputs[i]->shape.size());
jk << "«out" << JK::hex(i) << "_type:"
jk << "«out" << JK::dec2(i) << "_type:"
<< _outputs[i]->dtype();
}
string& header = flags.get(NodeFlags::_cuda) ?

View File

@ -59,6 +59,26 @@ class TestCodeOp(unittest.TestCase):
cpu_src="out0_p[0] = ++a_global_int_var; ").item() == 124
assert jt.code([1], "int", [], cpu_header=header,
cpu_src="out0_p[0] = ++a_global_int_var; ").item() == 125
def test_ten_args(self):
    """Code op with eleven inputs and eleven outputs.

    The kernels refer to in10/out10, i.e. two-digit argument names.
    Checks the forward result (2*a*a on the last output) and the
    gradient with respect to the last input (4*a times the upstream
    gradient).
    """
    # Kernel sources are runtime strings and are kept exactly as written.
    forward_src = '''
for (int i=0; i<in10_shape0; i++)
@out10(i) = @in10(i)*@in10(i)*2;
'''
    backward_src = '''
for (int i=0; i<in10_shape0; i++) {
@out0(i) = @dout(i)*@in10(i)*4;
}
'''
    a = jt.random([10])
    # One random var repeated for the ten leading inputs; `a` is input 10.
    filler = jt.random([10])
    outputs = jt.code(
        [a.shape]*11, [a.dtype]*11, [filler]*10+[a],
        cpu_src=forward_src,
        cpu_grad_src=['']*10+[backward_src])
    b = outputs[-1]

    na, nb = jt.fetch_sync([a, b])
    assert np.allclose(na*na*2, nb)

    c = jt.random([10])
    da = jt.grad(c*b, a)
    expected = c.data*na*4
    actual = da.data
    assert np.allclose(expected, actual), (expected, actual)
def test_use_func(self):
class Func(Function):