mirror of https://github.com/Jittor/Jittor
add jt.Var constructor
This commit is contained in:
parent
c1731cb215
commit
e7909d33ba
|
@ -132,6 +132,14 @@ class TestArray(unittest.TestCase):
|
|||
a = jt.array([1,2,3], dtype=jt.NanoString("float32"))
|
||||
a = jt.array([1,2,3], dtype=jt.float32)
|
||||
|
||||
def test_var(self):
|
||||
a = jt.Var([1,2,3])
|
||||
b = jt.Var([1,2,3], "float32")
|
||||
assert a.dtype == "int32"
|
||||
assert b.dtype == "float32"
|
||||
assert (a.numpy() == [1,2,3]).all()
|
||||
assert (b.numpy() == [1,2,3]).all()
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
#include "graph.h"
|
||||
#include "update_queue.h"
|
||||
#include "mem/allocator/cuda_dual_allocator.h"
|
||||
#include "ops/op_register.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
|
@ -61,7 +62,23 @@ VarHolder::VarHolder(VarHolder* v) : var(v->var) {
|
|||
operator delete(v);
|
||||
}
|
||||
|
||||
static auto make_array_from_pyobj = get_op_info("array")
|
||||
.get_constructor<VarPtr, PyObject*>();
|
||||
static auto make_unary = get_op_info("unary")
|
||||
.get_constructor<VarPtr, Var*, NanoString>();
|
||||
|
||||
VarHolder::VarHolder(PyObject* obj, NanoString dtype) {
|
||||
auto vp = make_array_from_pyobj(obj);
|
||||
if (dtype != ns_void)
|
||||
vp = make_unary(vp, dtype);
|
||||
var = vp.ptr;
|
||||
vp.ptr = nullptr;
|
||||
add_hold_vars(this);
|
||||
}
|
||||
|
||||
|
||||
VarHolder::~VarHolder() {
|
||||
if (PREDICT_BRANCH_NOT_TAKEN(!var)) return;
|
||||
hold_vars.erase(iter);
|
||||
var->release_both_liveness();
|
||||
}
|
||||
|
|
|
@ -28,6 +28,8 @@ struct ItemData {
|
|||
NanoString dtype;
|
||||
};
|
||||
|
||||
typedef struct _object PyObject;
|
||||
|
||||
// @pyjt(Var)
|
||||
// @attrs(heaptype)
|
||||
struct VarHolder {
|
||||
|
@ -37,6 +39,8 @@ struct VarHolder {
|
|||
VarHolder(VarPtr&& v);
|
||||
// will move and delete v
|
||||
VarHolder(VarHolder* v);
|
||||
// @pyjt(__init__)
|
||||
VarHolder(PyObject* v, NanoString dtype=ns_void);
|
||||
// @pyjt(__dealloc__)
|
||||
~VarHolder();
|
||||
string to_string();
|
||||
|
|
Loading…
Reference in New Issue