mirror of https://github.com/Jittor/Jittor

Commit 607d13079f: add data interface for code op without recompile
Parent: f7ba3cab31
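In short: a code op can now carry a small map of named scalars (DataMap, name to float64) that the kernel reads at run time as data["name"], so the values can change without triggering a recompile. A minimal sketch of the new interface, following the docstring example added in the diff below:

    import jittor as jt

    # pass a run-time scalar through the new `data` argument;
    # the kernel reads it when it executes, not when it compiles
    a = jt.code([1], "float32", inputs=[],
        data={"x": 123},
        cpu_src='''
            @out0(0) = data["x"];
        ''').sync()
    assert a.item() == 123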
@@ -9,7 +9,7 @@
 # file 'LICENSE.txt', which is part of this source code package.
 # ***************************************************************
 
-__version__ = '1.3.5.37'
+__version__ = '1.3.5.38'
 from jittor_utils import lock
 with lock.lock_scope():
     ori_int = int

@@ -17,7 +17,7 @@
 namespace jittor {
 
 static auto make_code = get_op_info("code")
-    .get_constructor<VarPtr, NanoVector, NanoString, vector<Var*>&&, string&&, vector<string>&&, string&&, string&&, vector<string>&&, string&&>();
+    .get_constructor<VarPtr, NanoVector, NanoString, vector<Var*>&&, string&&, vector<string>&&, string&&, string&&, vector<string>&&, string&&, DataMap&&>();
 
 static inline void check_vary_shape(NanoVector v) {
     ASSERT(v.size()) << "Vary shape should not be zero dimension";

@@ -28,9 +28,11 @@ static inline void check_vary_shape(NanoVector v) {
 
 CodeOp::CodeOp(NanoVector shape, NanoString dtype, vector<Var*>&& inputs,
     string&& cpu_src, vector<string>&& cpu_grad_src, string&& cpu_header,
-    string&& cuda_src, vector<string>&& cuda_grad_src, string&& cuda_header)
+    string&& cuda_src, vector<string>&& cuda_grad_src, string&& cuda_header,
+    DataMap&& data)
     : _inputs(inputs), cpu_src(move(cpu_src)), cpu_grad_src(move(cpu_grad_src)), cpu_header(move(cpu_header)),
-    cuda_src(move(cuda_src)), cuda_grad_src(move(cuda_grad_src)), cuda_header(move(cuda_header))
+    cuda_src(move(cuda_src)), cuda_grad_src(move(cuda_grad_src)), cuda_header(move(cuda_header)),
+    data(move(data))
 {
     flags.set(NodeFlags::_cpu, !!this->cpu_src.size());
     flags.set(NodeFlags::_cuda, !!this->cuda_src.size());

@@ -48,9 +50,11 @@ CodeOp::CodeOp(NanoVector shape, NanoString dtype, vector<Var*>&& inputs,
 CodeOp::CodeOp(
     vector<NanoVector>&& shapes, vector<NanoString>&& dtypes, vector<Var*>&& inputs,
     string&& cpu_src, vector<string>&& cpu_grad_src, string&& cpu_header,
-    string&& cuda_src, vector<string>&& cuda_grad_src, string&& cuda_header)
+    string&& cuda_src, vector<string>&& cuda_grad_src, string&& cuda_header,
+    DataMap&& data)
     : _inputs(inputs), cpu_src(move(cpu_src)), cpu_grad_src(move(cpu_grad_src)), cpu_header(move(cpu_header)),
-    cuda_src(move(cuda_src)), cuda_grad_src(move(cuda_grad_src)), cuda_header(move(cuda_header))
+    cuda_src(move(cuda_src)), cuda_grad_src(move(cuda_grad_src)), cuda_header(move(cuda_header)),
+    data(move(data))
 {
     flags.set(NodeFlags::_cpu, !!this->cpu_src.size());
     flags.set(NodeFlags::_cuda, !!this->cuda_src.size());

@@ -70,9 +74,11 @@ CodeOp::CodeOp(
 CodeOp::CodeOp(
     vector<Var*>&& inputs, vector<Var*>&& outputs,
     string&& cpu_src, vector<string>&& cpu_grad_src, string&& cpu_header,
-    string&& cuda_src, vector<string>&& cuda_grad_src, string&& cuda_header)
+    string&& cuda_src, vector<string>&& cuda_grad_src, string&& cuda_header,
+    DataMap&& data)
     : _inputs(inputs), cpu_src(move(cpu_src)), cpu_grad_src(move(cpu_grad_src)), cpu_header(move(cpu_header)),
-    cuda_src(move(cuda_src)), cuda_grad_src(move(cuda_grad_src)), cuda_header(move(cuda_header))
+    cuda_src(move(cuda_src)), cuda_grad_src(move(cuda_grad_src)), cuda_header(move(cuda_header)),
+    data(move(data))
 {
     flags.set(NodeFlags::_cpu, !!this->cpu_src.size());
     flags.set(NodeFlags::_cuda, !!this->cuda_src.size());

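All three CodeOp constructors (single output, multiple outputs, replace-outputs) gain the same trailing DataMap parameter, so the map should be available in the multi-output form too. A hedged sketch; the multi-output call style follows the existing jt.code examples, and combining it with data is an assumption, not something this diff demonstrates:

    import jittor as jt

    # assumed usage: one data map shared by both output kernels
    b, c = jt.code([(1,), (1,)], ["float32", "float32"], inputs=[],
        data={"x": 1.5},
        cpu_src='''
            @out0(0) = data["x"];
            @out1(0) = data["x"] * 2;
        ''')
    print(b.item(), c.item())  # expected: 1.5 3.0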
@@ -115,7 +121,8 @@ VarPtr CodeOp::grad(Var* out, Var* dout, Var* v, int v_index) {
         _inputs[v_index]->dtype(),
         move(inputs),
         move(cpu_src), {}, alias+cpu_header,
-        move(cuda_src), {}, alias+cuda_header
+        move(cuda_src), {}, alias+cuda_header,
+        DataMap(data)
     );
 }
 
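Note that grad() forwards a copy of the map, DataMap(data), so the backward op carries the same entries as the forward op. Under that reading, grad kernels can read data as well; a hedged sketch (the data-in-grad-kernel usage is inferred from this constructor call, not shown elsewhere in the commit):

    import jittor as jt

    a = jt.random([4])
    # forward and backward kernels both scale by the run-time value data["k"]
    b = jt.code(a.shape, a.dtype, [a],
        data={"k": 3.0},
        cpu_src='''
            for (int i=0; i<in0_shape0; i++)
                @out(i) = @in0(i) * data["k"];
        ''',
        cpu_grad_src=['''
            for (int i=0; i<in0_shape0; i++)
                @out(i) = @dout(i) * data["k"];
        '''])
    da = jt.grad(b.sum(), a)   # each entry should equal 3.0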
@@ -216,9 +223,12 @@ int __tmp
 )
 
 @alias(out, out0)
+#undef out
 
 @HEADER
 
+#define out out0
+
 namespace jittor {
 
 void CodeOp::jit_run() {

@@ -9,6 +9,8 @@
 
 namespace jittor {
 
+typedef unordered_map<string,double> DataMap;
+
 struct CodeOp : Op {
     vector<Var*> _inputs;
     vector<Var*> _outputs;
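DataMap stores every value as a C++ double (float64 on the Python side), which is why the docstring note further down warns about integer values; cast inside the kernel when an int is needed. A minimal sketch:

    import jittor as jt

    # values arrive in the kernel as double; cast explicitly for integers
    n = jt.code([1], "int32", inputs=[],
        data={"n": 7},
        cpu_src='''
            @out0(0) = (int)data["n"];
        ''').sync()
    assert n.item() == 7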
@@ -18,6 +20,7 @@ struct CodeOp : Op {
     string cuda_src;
     vector<string> cuda_grad_src;
     string cuda_header;
+    DataMap data;
     /**
     Code Operator for easily customized op.
 
@@ -154,6 +157,22 @@ struct CodeOp : Op {
         print(b[0])
         # will output 233
 
+    Example-6::
+
+        # This example shows how to pass custom data
+        # into the code op kernel without recompiling it.
+        # The value in data={"x":123} can vary between runs
+        # and the kernel will not be recompiled.
+        # NOTE: values are passed into the kernel as float64;
+        # cast to int inside the kernel if you need an integer.
+
+        a = jt.code([1], "float32", inputs=[],
+            data = {"x":123},
+            cpu_src='''
+                @out0(0) = data["x"];
+            ''').sync()
+        assert a.item() == 123
+
     CUDA Example-1::
 
 
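The payoff of the feature is kernel reuse: the source string stays identical across launches and only the data values differ, so (per this commit) the kernel is compiled once. A sketch of that pattern:

    import jittor as jt

    # identical kernel source on every iteration; only `data` changes,
    # so no recompilation should occur between launches
    for x in (1.0, 2.0, 3.0):
        a = jt.code([1], "float32", inputs=[],
            data={"x": x},
            cpu_src='''
                @out0(0) = data["x"];
            ''')
        print(a.item())  # 1.0, then 2.0, then 3.0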
@@ -243,13 +262,13 @@ struct CodeOp : Op {
         print(c)
         print(jt.grad(c, [a, b]))
     */
-    CodeOp(NanoVector shape, NanoString dtype, vector<Var*>&& inputs={}, string&& cpu_src="", vector<string>&& cpu_grad_src={}, string&& cpu_header="", string&& cuda_src="", vector<string>&& cuda_grad_src={}, string&& cuda_header="");
+    CodeOp(NanoVector shape, NanoString dtype, vector<Var*>&& inputs={}, string&& cpu_src="", vector<string>&& cpu_grad_src={}, string&& cpu_header="", string&& cuda_src="", vector<string>&& cuda_grad_src={}, string&& cuda_header="", DataMap&& data={});
 
     // @attrs(multiple_outputs)
-    CodeOp(vector<NanoVector>&& shapes, vector<NanoString>&& dtypes, vector<Var*>&& inputs={}, string&& cpu_src="", vector<string>&& cpu_grad_src={}, string&& cpu_header="", string&& cuda_src="", vector<string>&& cuda_grad_src={}, string&& cuda_header="");
+    CodeOp(vector<NanoVector>&& shapes, vector<NanoString>&& dtypes, vector<Var*>&& inputs={}, string&& cpu_src="", vector<string>&& cpu_grad_src={}, string&& cpu_header="", string&& cuda_src="", vector<string>&& cuda_grad_src={}, string&& cuda_header="", DataMap&& data={});
 
     // @attrs(multiple_outputs,replace_outputs)
-    CodeOp(vector<Var*>&& inputs, vector<Var*>&& outputs, string&& cpu_src="", vector<string>&& cpu_grad_src={}, string&& cpu_header="", string&& cuda_src="", vector<string>&& cuda_grad_src={}, string&& cuda_header="");
+    CodeOp(vector<Var*>&& inputs, vector<Var*>&& outputs, string&& cpu_src="", vector<string>&& cpu_grad_src={}, string&& cpu_header="", string&& cuda_src="", vector<string>&& cuda_grad_src={}, string&& cuda_header="", DataMap&& data={});
 
 
     const char* name() const override { return "code"; }

@@ -221,8 +221,7 @@ static string get_marks(Op* op, bool is_fused) {
 
 static string origin_key(const string& s) {
     if (s.size() && s[0]=='[') {
-        auto vs = split(s, "]");
-        return vs[vs.size()-1];
+        return s.substr(s.find("]")+1);
     }
     return s;
 }

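The origin_key cleanup replaces split-and-take-last with a single substr call. The two agree whenever the key holds one ']'; with several, the old code cut after the last ']' while the new code cuts after the first. A Python sketch of the new behavior (the sample keys are illustrative, not taken from the commit):

    def origin_key(s: str) -> str:
        # strip a leading "[...]" tag from a key, e.g. "[s0]x" -> "x"
        if s and s[0] == "[":
            return s[s.find("]") + 1:]
        return s

    assert origin_key("[s0]x") == "x"
    assert origin_key("plain") == "plain"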
@@ -155,8 +155,9 @@ VarHolder* VarHolder::_update(VarHolder* v) {
 
 EXTERN_LIB Executor exe;
 
-void VarHolder::sync(bool device_sync, bool weak_sync) {
+VarHolder* VarHolder::sync(bool device_sync, bool weak_sync) {
     jittor::sync({this}, device_sync, weak_sync);
+    return this;
 }
 
 ArrayArgs VarHolder::fetch_sync() {

@@ -49,7 +49,8 @@ struct VarHolder {
     ~VarHolder();
     string to_string();
     // @pyjt(sync)
-    void sync(bool device_sync = false, bool weak_sync = true);
+    // @attrs(return_self)
+    VarHolder* sync(bool device_sync = false, bool weak_sync = true);
 
     /**
      * Returns a numpy array copy of the Var.

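With sync() changed from void to VarHolder* and annotated @attrs(return_self), the Python binding returns the Var itself, which is what lets the new docstring and test write jt.code(...).sync() as one expression. A small usage sketch:

    import jittor as jt

    # .sync() now returns the Var, so it can sit inside an expression chain
    a = jt.random([3]).sync()
    print(a.numpy())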
@@ -374,6 +374,13 @@ class TestCodeOp(unittest.TestCase):
         assert np.allclose(da.data, b.data)
         assert np.allclose(db.data, a.data)
 
+    def test_simple_var(self):
+        a = jt.code([1], "float32", inputs=[],
+            data = {"x":123},
+            cpu_src='''
+                @out0(0) = data["x"];
+            ''').sync()
+        assert a.item() == 123
 
 if __name__ == "__main__":
     unittest.main()