mirror of https://github.com/Jittor/Jittor
commit
7cbed2b1ab
|
@ -0,0 +1,123 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Guowei Yang <471184555@qq.com>
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
# All Rights Reserved.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import unittest
|
||||
import jittor as jt
|
||||
import numpy as np
|
||||
|
||||
class TestCodeOp(unittest.TestCase):
|
||||
def test(self):
|
||||
def forward_code(np, data):
|
||||
a = data["inputs"][0]
|
||||
b = data["outputs"][0]
|
||||
np.add(a,a,out=b)
|
||||
|
||||
def backward_code(np, data):
|
||||
dout = data["dout"]
|
||||
out = data["outputs"][0]
|
||||
np.copyto(out, dout*2.0)
|
||||
|
||||
a = jt.random((5,1))
|
||||
b = jt.numpy_code(
|
||||
a.shape,
|
||||
a.dtype,
|
||||
[a],
|
||||
forward_code,
|
||||
[backward_code],
|
||||
)
|
||||
assert np.allclose(b.data,(a+a).data)
|
||||
da = jt.grad(b,a)
|
||||
one=np.ones(a.shape)
|
||||
assert np.allclose(da.data,one*2.0)
|
||||
|
||||
def test_multi_input(self):
|
||||
def forward_code(np, data):
|
||||
a,b = data["inputs"]
|
||||
c,d = data["outputs"]
|
||||
np.add(a,b,out=c)
|
||||
np.subtract(a,b,out=d)
|
||||
|
||||
def backward_code1(np, data):
|
||||
dout = data["dout"]
|
||||
out = data["outputs"][0]
|
||||
np.copyto(out, dout)
|
||||
|
||||
def backward_code2(np, data):
|
||||
dout = data["dout"]
|
||||
out_index = data["out_index"]
|
||||
out = data["outputs"][0]
|
||||
if out_index==0:
|
||||
np.copyto(out, dout)
|
||||
else:
|
||||
np.negative(dout, out)
|
||||
|
||||
a = jt.random((5,1))
|
||||
b = jt.random((5,1))
|
||||
c, d = jt.numpy_code(
|
||||
[a.shape, a.shape],
|
||||
[a.dtype, a.dtype],
|
||||
[a, b],
|
||||
forward_code,
|
||||
[backward_code1,backward_code2],
|
||||
)
|
||||
assert np.allclose(c.data,(a+b).data)
|
||||
assert np.allclose(d.data,(a-b).data)
|
||||
dca, dcb = jt.grad(c,[a,b])
|
||||
dda, ddb = jt.grad(d,[a,b])
|
||||
one=np.ones(a.shape)
|
||||
mone=one*-1.0
|
||||
assert np.allclose(dca.data,one)
|
||||
assert np.allclose(dcb.data,one)
|
||||
assert np.allclose(dda.data,one)
|
||||
assert np.allclose(ddb.data,mone)
|
||||
|
||||
@unittest.skipIf(True, "Memory leak testing is not in progress, Skip")
|
||||
def test_memory_leak(self):
|
||||
def forward_code(np, data):
|
||||
a,b = data["inputs"]
|
||||
c,d = data["outputs"]
|
||||
np.add(a,b,out=c)
|
||||
np.subtract(a,b,out=d)
|
||||
|
||||
def backward_code1(np, data):
|
||||
dout = data["dout"]
|
||||
out = data["outputs"][0]
|
||||
np.copyto(out, dout)
|
||||
|
||||
def backward_code2(np, data):
|
||||
dout = data["dout"]
|
||||
out_index = data["out_index"]
|
||||
out = data["outputs"][0]
|
||||
if out_index==0:
|
||||
np.copyto(out, dout)
|
||||
else:
|
||||
np.negative(dout, out)
|
||||
|
||||
for i in range(1000000):
|
||||
a = jt.random((10000,1))
|
||||
b = jt.random((10000,1))
|
||||
c, d = jt.numpy_code(
|
||||
[a.shape, a.shape],
|
||||
[a.dtype, a.dtype],
|
||||
[a, b],
|
||||
forward_code,
|
||||
[backward_code1,backward_code2],
|
||||
)
|
||||
assert np.allclose(c.data,(a+b).data)
|
||||
assert np.allclose(d.data,(a-b).data)
|
||||
dca, dcb = jt.grad(c,[a,b])
|
||||
dda, ddb = jt.grad(d,[a,b])
|
||||
one=np.ones(a.shape)
|
||||
mone=one*-1.0
|
||||
assert np.allclose(dca.data,one)
|
||||
assert np.allclose(dcb.data,one)
|
||||
assert np.allclose(dda.data,one)
|
||||
assert np.allclose(ddb.data,mone)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -0,0 +1,52 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Guowei Yang <471184555@qq.com>
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include <functional>
|
||||
#include "common.h"
|
||||
#include "var_holder.h"
|
||||
#include "ops/array_op.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
struct NumpyResult;
|
||||
|
||||
struct NumpyFunc {
|
||||
typedef NumpyResult R;
|
||||
std::function<void(R*)> callback;
|
||||
std::function<void()> deleter;
|
||||
std::function<void()> inc_ref;
|
||||
NumpyFunc() = default;
|
||||
NumpyFunc(NumpyFunc&& other) : callback(other.callback), deleter(other.deleter), inc_ref(other.inc_ref) {
|
||||
other.callback = nullptr;
|
||||
other.deleter = nullptr;
|
||||
other.inc_ref = nullptr;
|
||||
};
|
||||
NumpyFunc(const NumpyFunc& other) : callback(other.callback), deleter(other.deleter), inc_ref(other.inc_ref) {
|
||||
inc_ref();
|
||||
};
|
||||
NumpyFunc(std::function<void(R*)>&& callback) : callback(move(callback)) {}
|
||||
NumpyFunc(std::function<void(R*)>&& callback, std::function<void()>&& deleter)
|
||||
: callback(move(callback)), deleter(move(deleter)) {};
|
||||
NumpyFunc(std::function<void(R*)>&& callback, std::function<void()>&& deleter, std::function<void()>&& inc_ref)
|
||||
: callback(move(callback)), deleter(move(deleter)), inc_ref(move(inc_ref)) {};
|
||||
~NumpyFunc() {
|
||||
if (deleter) {
|
||||
deleter();
|
||||
}
|
||||
}
|
||||
void operator =(NumpyFunc&& other) { this->~NumpyFunc(); new (this) NumpyFunc(move(other)); }
|
||||
};
|
||||
|
||||
struct NumpyResult {
|
||||
map<string, vector<DataView>> varrays;
|
||||
map<string, int> ints;
|
||||
map<string, DataView> arrays;
|
||||
};
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,111 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Guowei Yang <471184555@qq.com>
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include <cmath>
|
||||
#include "var.h"
|
||||
#include "ops/numpy_code_op.h"
|
||||
#include "ops/op_register.h"
|
||||
|
||||
#ifndef JIT
|
||||
|
||||
namespace jittor {
|
||||
|
||||
static auto make_numpy_code = get_op_info("numpy_code")
|
||||
.get_constructor<VarPtr, NanoVector, NanoString, vector<Var*>&&, NumpyFunc, NumpyResult&&>();
|
||||
|
||||
NumpyCodeOp::NumpyCodeOp(NanoVector shape, NanoString dtype, vector<Var*>&& inputs, NumpyFunc&& forward, vector<NumpyFunc>&& sbackward)
|
||||
: _inputs(inputs), forward(move(forward))
|
||||
{
|
||||
_outputs.push_back(create_output(shape, dtype));
|
||||
CHECKop(_inputs.size(),<=,10);
|
||||
ASSERT(_outputs[0]->num >= 0);
|
||||
for (int i=0; i<sbackward.size(); i++) {
|
||||
backward.push_back(sbackward[i]);
|
||||
}
|
||||
}
|
||||
|
||||
NumpyCodeOp::NumpyCodeOp(vector<NanoVector>&& shapes, vector<NanoString>&& dtypes, vector<Var*>&& inputs, NumpyFunc&& forward, vector<NumpyFunc>&& sbackward)
|
||||
: _inputs(inputs), forward(move(forward))
|
||||
{
|
||||
CHECKop(shapes.size(),==,dtypes.size()) << "Number of outputs' shapes and dtypes should be the same";
|
||||
_outputs.resize(shapes.size());
|
||||
CHECKop(_inputs.size(),<=,10);
|
||||
CHECKop(_outputs.size(),<=,10);
|
||||
CHECKop(_outputs.size(),>,0);
|
||||
for (int i=0; i<shapes.size(); i++) {
|
||||
_outputs[i] = create_output(shapes[i], dtypes[i]);
|
||||
ASSERT(_outputs[i]->num >= 0);
|
||||
}
|
||||
for (int i=0; i<sbackward.size(); i++) {
|
||||
backward.push_back(sbackward[i]);
|
||||
}
|
||||
}
|
||||
|
||||
NumpyCodeOp::NumpyCodeOp(NanoVector shape, NanoString dtype, vector<Var*>&& inputs, NumpyFunc forward, NumpyResult&& results)
|
||||
: _inputs(inputs), forward(forward), _results(move(results))
|
||||
{
|
||||
_outputs.push_back(create_output(shape, dtype));
|
||||
CHECKop(_inputs.size(),<=,10);
|
||||
ASSERT(_outputs[0]->num >= 0);
|
||||
}
|
||||
|
||||
VarPtr NumpyCodeOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
||||
NumpyResult result;
|
||||
|
||||
int out_index=-1;
|
||||
for (int i=0; i<_outputs.size(); i++) {
|
||||
if (_outputs[i] == out) {
|
||||
out_index = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
ASSERT(out_index!=-1);
|
||||
result.ints["out_index"] = out_index;
|
||||
result.arrays["dout"].ptr=dout;
|
||||
result.arrays["dout"].shape=dout->shape;
|
||||
result.arrays["dout"].dtype=dout->dtype();
|
||||
auto inputs = clone(_inputs);
|
||||
inputs.push_back(dout);
|
||||
|
||||
return make_numpy_code(
|
||||
_inputs[v_index]->shape,
|
||||
_inputs[v_index]->dtype(),
|
||||
move(inputs),
|
||||
backward[v_index],
|
||||
move(result));
|
||||
}
|
||||
|
||||
void NumpyCodeOp::run() {
|
||||
NumpyResult result;
|
||||
result.varrays = _results.varrays;
|
||||
result.ints = _results.ints;
|
||||
result.arrays = _results.arrays;
|
||||
|
||||
if (result.arrays.count("dout") > 0){
|
||||
result.arrays["dout"].ptr=((Var*)result.arrays["dout"].ptr)->ptr<DataView>();
|
||||
}
|
||||
vector<DataView> inputs(_inputs.size());
|
||||
vector<DataView> outputs(_outputs.size());
|
||||
for (int i=0; i<inputs.size(); i++) {
|
||||
inputs[i].ptr=_inputs[i]->ptr<DataView>();
|
||||
inputs[i].shape=_inputs[i]->shape;
|
||||
inputs[i].dtype=_inputs[i]->dtype();
|
||||
}
|
||||
for (int i=0; i<outputs.size(); i++) {
|
||||
outputs[i].ptr=_outputs[i]->ptr<DataView>();
|
||||
outputs[i].shape=_outputs[i]->shape;
|
||||
outputs[i].dtype=_outputs[i]->dtype();
|
||||
}
|
||||
result.varrays["inputs"] = move(inputs);
|
||||
result.varrays["outputs"] = move(outputs);
|
||||
forward.callback(&result);
|
||||
}
|
||||
|
||||
} // jittor
|
||||
|
||||
#endif // JIT
|
|
@ -0,0 +1,126 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Guowei Yang <471184555@qq.com>
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include "op.h"
|
||||
#include "numpy_func.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
struct NumpyCodeOp : Op {
|
||||
vector<Var*> _inputs;
|
||||
vector<Var*> _outputs;
|
||||
NumpyFunc forward;
|
||||
vector<NumpyFunc> backward;
|
||||
NumpyResult _results;
|
||||
|
||||
/**
|
||||
Code Operator for easily customized op.
|
||||
|
||||
----------------
|
||||
|
||||
* [in] shape: the output shape, a integer array
|
||||
|
||||
* [in] dtype: the output data type
|
||||
|
||||
* [in] inputs: A list of input jittor Vars
|
||||
|
||||
* [in] cpu_src: cpu source code string, buildin value:
|
||||
|
||||
* in{x}, in{x}_shape{y}, in{x}_stride{y}, in{x}_type, in{x}_p, @in0(...)
|
||||
* out{x}, out{x}_shape{y}, out{x}_stride{y}, out{x}_type, out{x}_p, @out0(...)
|
||||
* out, out_shape{y}, out_stride{y}, out_type, out_p, @out(...)
|
||||
|
||||
* [in] cpu_grad_src: A list of string, cpu source code string for gradient, represents gradiant for each inputm buildin value, buildin value:
|
||||
|
||||
* in{x}, in{x}_shape{y}, in{x}_stride{y}, in{x}_type, in{x}_p, @in0(...)
|
||||
* out{x}, out{x}_shape{y}, out{x}_stride{y}, out{x}_type, out{x}_p, @out0(...)
|
||||
* out, out_shape{y}, out_stride{y}, out_type, out_p, @out(...)
|
||||
* pout{x}, pout{x}_shape{y}, pout{x}_stride{y}, pout{x}_type, pout{x}_p, @pout{x}(...)
|
||||
* pout, pout_shape{y}, pout_stride{y}, pout_type, pout_p, @pout(...)
|
||||
* dout, dout_shape{y}, dout_stride{y}, dout_type, dout_p, @dout(...)
|
||||
|
||||
* [in] cpu_header: cpu header code string.
|
||||
|
||||
* [in] cuda_src: cuda source code string.
|
||||
|
||||
* [in] cuda_grad_src: A list of string.
|
||||
|
||||
* [in] cuda_header: cuda header code string.
|
||||
|
||||
----------------
|
||||
|
||||
Example-1::
|
||||
|
||||
def forward_code(np, data):
|
||||
a = data["inputs"][0]
|
||||
b = data["outputs"][0]
|
||||
np.add(a,a,out=b)
|
||||
|
||||
def backward_code(np, data):
|
||||
dout = data["dout"]
|
||||
out = data["outputs"][0]
|
||||
np.copyto(out, dout*2.0)
|
||||
|
||||
a = jt.random((5,1))
|
||||
b = jt.numpy_code(
|
||||
a.shape,
|
||||
a.dtype,
|
||||
[a],
|
||||
forward_code,
|
||||
[backward_code],
|
||||
)
|
||||
|
||||
Example-2::
|
||||
|
||||
def forward_code(np, data):
|
||||
a,b = data["inputs"]
|
||||
c,d = data["outputs"]
|
||||
np.add(a,b,out=c)
|
||||
np.subtract(a,b,out=d)
|
||||
|
||||
def backward_code1(np, data):
|
||||
dout = data["dout"]
|
||||
out = data["outputs"][0]
|
||||
np.copyto(out, dout)
|
||||
|
||||
def backward_code2(np, data):
|
||||
dout = data["dout"]
|
||||
out_index = data["out_index"]
|
||||
out = data["outputs"][0]
|
||||
if out_index==0:
|
||||
np.copyto(out, dout)
|
||||
else:
|
||||
np.negative(dout, out)
|
||||
|
||||
a = jt.random((5,1))
|
||||
b = jt.random((5,1))
|
||||
c, d = jt.numpy_code(
|
||||
[a.shape, a.shape],
|
||||
[a.dtype, a.dtype],
|
||||
[a, b],
|
||||
forward_code,
|
||||
[backward_code1,backward_code2],
|
||||
)
|
||||
|
||||
*/
|
||||
NumpyCodeOp(NanoVector shape, NanoString dtype, vector<Var*>&& inputs, NumpyFunc&& forward, vector<NumpyFunc>&& backward);
|
||||
|
||||
// @attrs(multiple_outputs)
|
||||
NumpyCodeOp(vector<NanoVector>&& shapes, vector<NanoString>&& dtypes, vector<Var*>&& inputs, NumpyFunc&& forward, vector<NumpyFunc>&& backward);
|
||||
|
||||
// @pybind(None)
|
||||
NumpyCodeOp(NanoVector shape, NanoString dtype, vector<Var*>&& inputs, NumpyFunc forward, NumpyResult&& results);
|
||||
|
||||
const char* name() const override { return "numpy_code"; }
|
||||
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
|
||||
|
||||
void run() override;
|
||||
};
|
||||
|
||||
} // jittor
|
|
@ -354,7 +354,6 @@ DEF_IS(VarHolder*, T) from_py_object(PyObject* obj, unique_ptr<VarHolder>& holde
|
|||
|
||||
struct DataView;
|
||||
DEF_IS(DataView, PyObject*) to_py_object(T a) {
|
||||
auto obj = GET_OBJ_FROM_RAW_PTR(a.vh);
|
||||
int64 dims[a.shape.size()];
|
||||
for (int i=0; i<a.shape.size(); i++)
|
||||
dims[i] = a.shape[i];
|
||||
|
@ -369,13 +368,24 @@ DEF_IS(DataView, PyObject*) to_py_object(T a) {
|
|||
NPY_ARRAY_C_CONTIGUOUS | NPY_ARRAY_WRITEABLE, // flags
|
||||
NULL // obj
|
||||
));
|
||||
Py_INCREF(obj);
|
||||
PyObjHolder oh2(obj);
|
||||
ASSERT(PyArray_SetBaseObject(oh.obj, oh2.obj)==0);
|
||||
oh2.release();
|
||||
if (a.vh) {
|
||||
auto obj = GET_OBJ_FROM_RAW_PTR(a.vh);
|
||||
PyObjHolder oh2(obj);
|
||||
Py_INCREF(obj);
|
||||
ASSERT(PyArray_SetBaseObject(oh.obj, oh2.obj)==0);
|
||||
oh2.release();
|
||||
}
|
||||
return oh.release();
|
||||
}
|
||||
|
||||
struct NumpyFunc;
|
||||
|
||||
DEF_IS(NumpyFunc, bool) is_type(PyObject* obj) {
|
||||
return PyCallable_Check(obj);
|
||||
}
|
||||
|
||||
DEF_IS(NumpyFunc, T) from_py_object(PyObject* obj);
|
||||
|
||||
#define CHECK_IS_1(check_type) \
|
||||
template<typename T> struct is_##check_type : public std::false_type {}; \
|
||||
template<typename T> \
|
||||
|
@ -457,6 +467,7 @@ DEF_IS(FetchFunc, T) from_py_object(PyObject* obj) {
|
|||
return func;
|
||||
}
|
||||
|
||||
|
||||
#define CHECK_IS_2(check_type) \
|
||||
template<typename T> struct is_##check_type : public std::false_type {}; \
|
||||
template<typename Ta, typename Tb> \
|
||||
|
@ -549,4 +560,34 @@ DEF_IS_1(fast_shared_ptr, T) from_py_object(PyObject* obj) {
|
|||
}
|
||||
|
||||
|
||||
|
||||
DEF_IS(NumpyFunc, T) from_py_object(PyObject* obj) {
|
||||
// PyObject_Call
|
||||
Py_INCREF(obj);
|
||||
T func(
|
||||
// callback
|
||||
[obj](typename T::R* result) {
|
||||
// import numpy
|
||||
PyObjHolder np(PyImport_ImportModule("numpy"));
|
||||
// data = {}
|
||||
PyObjHolder data(to_py_object(result->varrays));
|
||||
PyObjHolder data2(to_py_object(result->ints));
|
||||
PyObjHolder data3(to_py_object(result->arrays));
|
||||
PyDict_Update(data.obj, data2.obj);
|
||||
PyDict_Update(data.obj, data3.obj);
|
||||
|
||||
// args = []
|
||||
PyObjHolder args(PyTuple_New(2));
|
||||
PyTuple_SET_ITEM(args.obj, 0, np.release());
|
||||
PyTuple_SET_ITEM(args.obj, 1, data.release());
|
||||
PyObjHolder ret(PyObject_Call(obj, args.obj, nullptr));
|
||||
},
|
||||
// deleter
|
||||
[obj]() { Py_DECREF(obj); },
|
||||
// inc_ref
|
||||
[obj]() { Py_INCREF(obj); }
|
||||
);
|
||||
return func;
|
||||
}
|
||||
|
||||
} // jittor
|
||||
|
|
Loading…
Reference in New Issue