From e7909d33ba6a87767756fa97bfa6ab07854dc09b Mon Sep 17 00:00:00 2001 From: Dun Liang Date: Wed, 20 Jan 2021 21:35:43 +0800 Subject: [PATCH] add jt.Var constructor --- python/jittor/test/test_array.py | 8 ++++++++ src/var_holder.cc | 17 +++++++++++++++++ src/var_holder.h | 4 ++++ 3 files changed, 29 insertions(+) diff --git a/python/jittor/test/test_array.py b/python/jittor/test/test_array.py index c13a2981..342efd8b 100644 --- a/python/jittor/test/test_array.py +++ b/python/jittor/test/test_array.py @@ -132,6 +132,14 @@ class TestArray(unittest.TestCase): a = jt.array([1,2,3], dtype=jt.NanoString("float32")) a = jt.array([1,2,3], dtype=jt.float32) + def test_var(self): + a = jt.Var([1,2,3]) + b = jt.Var([1,2,3], "float32") + assert a.dtype == "int32" + assert b.dtype == "float32" + assert (a.numpy() == [1,2,3]).all() + assert (b.numpy() == [1,2,3]).all() + if __name__ == "__main__": diff --git a/src/var_holder.cc b/src/var_holder.cc index 36586d09..b133f516 100644 --- a/src/var_holder.cc +++ b/src/var_holder.cc @@ -15,6 +15,7 @@ #include "graph.h" #include "update_queue.h" #include "mem/allocator/cuda_dual_allocator.h" +#include "ops/op_register.h" namespace jittor { @@ -61,7 +62,23 @@ VarHolder::VarHolder(VarHolder* v) : var(v->var) { operator delete(v); } +static auto make_array_from_pyobj = get_op_info("array") + .get_constructor(); +static auto make_unary = get_op_info("unary") + .get_constructor(); + +VarHolder::VarHolder(PyObject* obj, NanoString dtype) { + auto vp = make_array_from_pyobj(obj); + if (dtype != ns_void) + vp = make_unary(vp, dtype); + var = vp.ptr; + vp.ptr = nullptr; + add_hold_vars(this); +} + + VarHolder::~VarHolder() { + if (PREDICT_BRANCH_NOT_TAKEN(!var)) return; hold_vars.erase(iter); var->release_both_liveness(); } diff --git a/src/var_holder.h b/src/var_holder.h index 744c699c..56beeadf 100644 --- a/src/var_holder.h +++ b/src/var_holder.h @@ -28,6 +28,8 @@ struct ItemData { NanoString dtype; }; +typedef struct _object PyObject; + // @pyjt(Var) // @attrs(heaptype) struct VarHolder { @@ -37,6 +39,8 @@ struct VarHolder { VarHolder(VarPtr&& v); // will move and delete v VarHolder(VarHolder* v); + // @pyjt(__init__) + VarHolder(PyObject* v, NanoString dtype=ns_void); // @pyjt(__dealloc__) ~VarHolder(); string to_string();