mirror of https://github.com/Jittor/Jittor
code op support more than 10 args
This commit is contained in:
parent
b2fb32aa52
commit
9a5e7ea6f5
|
@ -9,7 +9,7 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
__version__ = '1.3.5.16'
|
||||
__version__ = '1.3.5.17'
|
||||
from jittor_utils import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -69,6 +69,16 @@ struct JitKey {
|
|||
uint data;
|
||||
explicit Oxhex2(uint data) : data(data) {}
|
||||
};
|
||||
|
||||
struct dec1 {
|
||||
uint data;
|
||||
explicit dec1(uint data) : data(data) {}
|
||||
};
|
||||
|
||||
struct dec2 {
|
||||
uint data;
|
||||
explicit dec2(uint data) : data(data) {}
|
||||
};
|
||||
};
|
||||
|
||||
struct __jk_int128 {
|
||||
|
@ -173,6 +183,30 @@ inline JK& operator<<(JK& jk, const JK::Oxhex2& h) {
|
|||
return jk << "0x" << JK::hex2(h.data);
|
||||
}
|
||||
|
||||
// Appends h.data to jk as unpadded decimal text ("7", "42", "128").
// Every digit is emitted, most significant first, so values of 100 or more
// stay numeric; a fixed tens/ones split would turn 100 into ":0" because the
// tens slot (data / 10) no longer fits a single digit.
// Output for values below 100 is a lone ones digit or tens followed by ones.
inline JK& operator<<(JK& jk, const JK::dec2& h) {
    // Three decimal digits per byte is an upper bound for any unsigned width.
    char digits[3 * sizeof(h.data) + 1];
    int count = 0;
    auto value = h.data;
    // Collect digits least-significant first; do/while so that 0 prints "0".
    do {
        digits[count++] = static_cast<char>('0' + value % 10);
        value /= 10;
    } while (value);
    // Replay them in reading order, one character at a time.
    while (count > 0) jk << static_cast<char>(digits[--count]);
    return jk;
}
|
||||
|
||||
// Appends exactly one character to jk: the ones digit of h.data.
// Higher decimal digits are dropped by the modulo.
inline JK& operator<<(JK& jk, const JK::dec1& h) {
    return jk << static_cast<char>('0' + h.data % 10);
}
|
||||
|
||||
// Writes h.data to os as unpadded decimal text ("7", "42", "128").
// Every digit is emitted, most significant first, so values of 100 or more
// stay numeric; a fixed tens/ones split would turn 100 into ":0" because the
// tens slot (data / 10) no longer fits a single digit.
// Output for values below 100 is a lone ones digit or tens followed by ones.
inline std::ostream& operator<<(std::ostream& os, const JK::dec2& h) {
    // Three decimal digits per byte is an upper bound for any unsigned width.
    char digits[3 * sizeof(h.data) + 1];
    int count = 0;
    auto value = h.data;
    // Collect digits least-significant first; do/while so that 0 prints "0".
    do {
        digits[count++] = static_cast<char>('0' + value % 10);
        value /= 10;
    } while (value);
    // Replay them in reading order, one character at a time.
    while (count > 0) os << static_cast<char>(digits[--count]);
    return os;
}
|
||||
|
||||
// Writes exactly one character to os: the ones digit of h.data.
// Higher decimal digits are dropped by the modulo.
inline std::ostream& operator<<(std::ostream& os, const JK::dec1& h) {
    return os << static_cast<char>('0' + h.data % 10);
}
|
||||
|
||||
inline JK& operator<<(JK& jk, int c) {
|
||||
if (c<0) {
|
||||
c = -c;
|
||||
|
|
|
@ -100,13 +100,13 @@ VarPtr CodeOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
|||
// TODO: remove unused deps
|
||||
// dout -> dout
|
||||
std::stringstream new_alias;
|
||||
new_alias << "\n@alias(dout,in" << inputs.size() << ")\n";
|
||||
new_alias << "\n@alias(dout,in" << JK::dec2(inputs.size()) << ")\n";
|
||||
inputs.push_back(dout);
|
||||
// _outputs[i] -> poutj
|
||||
for (int i=0; i<_outputs.size(); i++) {
|
||||
new_alias << "\n@alias(pout" << i << ",in" << inputs.size() << ")\n";
|
||||
new_alias << "\n@alias(pout" << JK::dec2(i) << ",in" << JK::dec2(inputs.size()) << ")\n";
|
||||
if (_outputs[i] == out)
|
||||
new_alias << "\n@alias(pout,in" << inputs.size() << ")\n";
|
||||
new_alias << "\n@alias(pout,in" << JK::dec2(inputs.size()) << ")\n";
|
||||
inputs.push_back(_outputs[i]);
|
||||
}
|
||||
auto alias = new_alias.str();
|
||||
|
@ -123,18 +123,18 @@ void CodeOp::jit_prepare(JK& jk) {
|
|||
|
||||
// forward: in0 in1 in2 -> out0 out1
|
||||
// backward: in0 in1 in2 in3(pout0) in4(pout1)
|
||||
jk << "«IN_SIZE=" << JK::hex(_inputs.size());
|
||||
jk << "«IN_SIZE:" << JK::dec2(_inputs.size());
|
||||
for (uint i=0; i<_inputs.size(); i++) {
|
||||
jk << "«in" << JK::hex(i) << "_dim="
|
||||
jk << "«in" << JK::dec2(i) << "_dim:"
|
||||
<< JK::hex1(_inputs[i]->shape.size());
|
||||
jk << "«in" << JK::hex(i) << "_type:"
|
||||
jk << "«in" << JK::dec2(i) << "_type:"
|
||||
<< _inputs[i]->dtype();
|
||||
}
|
||||
jk << "«OUT_SIZE=" << JK::hex(_outputs.size());
|
||||
jk << "«OUT_SIZE:" << JK::dec2(_outputs.size());
|
||||
for (uint i=0; i<_outputs.size(); i++) {
|
||||
jk << "«out" << JK::hex(i) << "_dim="
|
||||
jk << "«out" << JK::dec2(i) << "_dim:"
|
||||
<< JK::hex1(_outputs[i]->shape.size());
|
||||
jk << "«out" << JK::hex(i) << "_type:"
|
||||
jk << "«out" << JK::dec2(i) << "_type:"
|
||||
<< _outputs[i]->dtype();
|
||||
}
|
||||
string& header = flags.get(NodeFlags::_cuda) ?
|
||||
|
|
|
@ -59,6 +59,26 @@ class TestCodeOp(unittest.TestCase):
|
|||
cpu_src="out0_p[0] = ++a_global_int_var; ").item() == 124
|
||||
assert jt.code([1], "int", [], cpu_header=header,
|
||||
cpu_src="out0_p[0] = ++a_global_int_var; ").item() == 125
|
||||
|
||||
# Exercises jt.code with eleven inputs and eleven outputs, so the kernel
# source has to refer to two-digit argument names (in10 / out10).
def test_ten_args(self):
|
||||
# `a` is the only input the checks below look at; it is passed last (index 10).
a = jt.random([10])
|
||||
# Eleven output shapes/dtypes; inputs are ten references to one random var
# followed by `a`. The trailing [-1] keeps only the last output.
b = jt.code([a.shape]*11, [a.dtype]*11, [jt.random([10])]*10+[a],
|
||||
cpu_src='''
|
||||
for (int i=0; i<in10_shape0; i++)
|
||||
@out10(i) = @in10(i)*@in10(i)*2;
|
||||
''',
|
||||
cpu_grad_src = ['']*10+['''
|
||||
for (int i=0; i<in10_shape0; i++) {
|
||||
@out0(i) = @dout(i)*@in10(i)*4;
|
||||
}
|
||||
'''])[-1]
|
||||
# Forward check: the last output must equal 2 * a * a.
na, nb = jt.fetch_sync([a,b])
|
||||
assert np.allclose(na*na*2, nb)
|
||||
|
||||
|
||||
# Backward check: the gradient of c*b w.r.t. a must equal c * a * 4,
# matching the expression in the last cpu_grad_src entry.
c = jt.random([10])
|
||||
da = jt.grad(c*b, a)
|
||||
assert np.allclose(c.data*na*4, da.data), (c.data*na*4, da.data)
|
||||
|
||||
|
||||
def test_use_func(self):
|
||||
class Func(Function):
|
||||
|
|
Loading…
Reference in New Issue