add jt.Var constructor

This commit is contained in:
Dun Liang 2021-01-20 21:35:43 +08:00
parent c1731cb215
commit e7909d33ba
3 changed files with 29 additions and 0 deletions

View File

@ -132,6 +132,14 @@ class TestArray(unittest.TestCase):
a = jt.array([1,2,3], dtype=jt.NanoString("float32"))
a = jt.array([1,2,3], dtype=jt.float32)
def test_var(self):
a = jt.Var([1,2,3])
b = jt.Var([1,2,3], "float32")
assert a.dtype == "int32"
assert b.dtype == "float32"
assert (a.numpy() == [1,2,3]).all()
assert (b.numpy() == [1,2,3]).all()
if __name__ == "__main__":

View File

@ -15,6 +15,7 @@
#include "graph.h"
#include "update_queue.h"
#include "mem/allocator/cuda_dual_allocator.h"
#include "ops/op_register.h"
namespace jittor {
@ -61,7 +62,23 @@ VarHolder::VarHolder(VarHolder* v) : var(v->var) {
operator delete(v);
}
static auto make_array_from_pyobj = get_op_info("array")
.get_constructor<VarPtr, PyObject*>();
static auto make_unary = get_op_info("unary")
.get_constructor<VarPtr, Var*, NanoString>();
VarHolder::VarHolder(PyObject* obj, NanoString dtype) {
auto vp = make_array_from_pyobj(obj);
if (dtype != ns_void)
vp = make_unary(vp, dtype);
var = vp.ptr;
vp.ptr = nullptr;
add_hold_vars(this);
}
VarHolder::~VarHolder() {
if (PREDICT_BRANCH_NOT_TAKEN(!var)) return;
hold_vars.erase(iter);
var->release_both_liveness();
}

View File

@ -28,6 +28,8 @@ struct ItemData {
NanoString dtype;
};
typedef struct _object PyObject;
// @pyjt(Var)
// @attrs(heaptype)
struct VarHolder {
@ -37,6 +39,8 @@ struct VarHolder {
VarHolder(VarPtr&& v);
// will move and delete v
VarHolder(VarHolder* v);
// @pyjt(__init__)
VarHolder(PyObject* v, NanoString dtype=ns_void);
// @pyjt(__dealloc__)
~VarHolder();
string to_string();