mirror of https://github.com/Jittor/Jittor
add jt.Var constructor
This commit is contained in:
parent
c1731cb215
commit
e7909d33ba
|
@ -132,6 +132,14 @@ class TestArray(unittest.TestCase):
|
||||||
a = jt.array([1,2,3], dtype=jt.NanoString("float32"))
|
a = jt.array([1,2,3], dtype=jt.NanoString("float32"))
|
||||||
a = jt.array([1,2,3], dtype=jt.float32)
|
a = jt.array([1,2,3], dtype=jt.float32)
|
||||||
|
|
||||||
|
def test_var(self):
|
||||||
|
a = jt.Var([1,2,3])
|
||||||
|
b = jt.Var([1,2,3], "float32")
|
||||||
|
assert a.dtype == "int32"
|
||||||
|
assert b.dtype == "float32"
|
||||||
|
assert (a.numpy() == [1,2,3]).all()
|
||||||
|
assert (b.numpy() == [1,2,3]).all()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
#include "graph.h"
|
#include "graph.h"
|
||||||
#include "update_queue.h"
|
#include "update_queue.h"
|
||||||
#include "mem/allocator/cuda_dual_allocator.h"
|
#include "mem/allocator/cuda_dual_allocator.h"
|
||||||
|
#include "ops/op_register.h"
|
||||||
|
|
||||||
namespace jittor {
|
namespace jittor {
|
||||||
|
|
||||||
|
@ -61,7 +62,23 @@ VarHolder::VarHolder(VarHolder* v) : var(v->var) {
|
||||||
operator delete(v);
|
operator delete(v);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static auto make_array_from_pyobj = get_op_info("array")
|
||||||
|
.get_constructor<VarPtr, PyObject*>();
|
||||||
|
static auto make_unary = get_op_info("unary")
|
||||||
|
.get_constructor<VarPtr, Var*, NanoString>();
|
||||||
|
|
||||||
|
VarHolder::VarHolder(PyObject* obj, NanoString dtype) {
|
||||||
|
auto vp = make_array_from_pyobj(obj);
|
||||||
|
if (dtype != ns_void)
|
||||||
|
vp = make_unary(vp, dtype);
|
||||||
|
var = vp.ptr;
|
||||||
|
vp.ptr = nullptr;
|
||||||
|
add_hold_vars(this);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
VarHolder::~VarHolder() {
|
VarHolder::~VarHolder() {
|
||||||
|
if (PREDICT_BRANCH_NOT_TAKEN(!var)) return;
|
||||||
hold_vars.erase(iter);
|
hold_vars.erase(iter);
|
||||||
var->release_both_liveness();
|
var->release_both_liveness();
|
||||||
}
|
}
|
||||||
|
|
|
@ -28,6 +28,8 @@ struct ItemData {
|
||||||
NanoString dtype;
|
NanoString dtype;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
typedef struct _object PyObject;
|
||||||
|
|
||||||
// @pyjt(Var)
|
// @pyjt(Var)
|
||||||
// @attrs(heaptype)
|
// @attrs(heaptype)
|
||||||
struct VarHolder {
|
struct VarHolder {
|
||||||
|
@ -37,6 +39,8 @@ struct VarHolder {
|
||||||
VarHolder(VarPtr&& v);
|
VarHolder(VarPtr&& v);
|
||||||
// will move and delete v
|
// will move and delete v
|
||||||
VarHolder(VarHolder* v);
|
VarHolder(VarHolder* v);
|
||||||
|
// @pyjt(__init__)
|
||||||
|
VarHolder(PyObject* v, NanoString dtype=ns_void);
|
||||||
// @pyjt(__dealloc__)
|
// @pyjt(__dealloc__)
|
||||||
~VarHolder();
|
~VarHolder();
|
||||||
string to_string();
|
string to_string();
|
||||||
|
|
Loading…
Reference in New Issue