mirror of https://github.com/Jittor/Jittor
jt.sum -> nanostring(add)
This commit is contained in:
parent
c1e0861446
commit
fc9b22e796
|
@ -44,6 +44,14 @@ class TestNanoString(unittest.TestCase):
|
||||||
assert str(jt.NanoString(np.float64)) == "float64"
|
assert str(jt.NanoString(np.float64)) == "float64"
|
||||||
assert str(jt.NanoString(np.int8)) == "int8"
|
assert str(jt.NanoString(np.int8)) == "int8"
|
||||||
assert str(jt.NanoString(np.array([1,2,3]).dtype)) == "int64"
|
assert str(jt.NanoString(np.array([1,2,3]).dtype)) == "int64"
|
||||||
|
|
||||||
|
assert str(jt.NanoString(jt.float)) == "float"
|
||||||
|
assert str(jt.NanoString(jt.float32)) == "float32"
|
||||||
|
assert str(jt.NanoString(jt.float64)) == "float64"
|
||||||
|
assert str(jt.NanoString(jt.int8)) == "int8"
|
||||||
|
assert str(jt.NanoString(jt.array([1,2,3]).dtype)) == "int64"
|
||||||
|
assert str(jt.NanoString(jt.sum)) == "add"
|
||||||
|
|
||||||
def get_error_str(call):
|
def get_error_str(call):
|
||||||
es = ""
|
es = ""
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -147,6 +147,7 @@ static void init_ns() {
|
||||||
#define INIT_NS(T) func(#T, ns_##T);
|
#define INIT_NS(T) func(#T, ns_##T);
|
||||||
FOR_ALL_NS(INIT_NS);
|
FOR_ALL_NS(INIT_NS);
|
||||||
ASSERT(NanoString::__ns_to_string.size()<=(1<<NanoString::_index_nbits));
|
ASSERT(NanoString::__ns_to_string.size()<=(1<<NanoString::_index_nbits));
|
||||||
|
NanoString::__string_to_ns["sum"] = ns_add;
|
||||||
LOGvv << "init __string_to_ns" << NanoString::__string_to_ns;
|
LOGvv << "init __string_to_ns" << NanoString::__string_to_ns;
|
||||||
LOGvv << "init __ns_to_string" << NanoString::__ns_to_string;
|
LOGvv << "init __ns_to_string" << NanoString::__ns_to_string;
|
||||||
}
|
}
|
||||||
|
|
|
@ -161,6 +161,8 @@ DEF_IS(NanoString, bool) is_type(PyObject* obj) {
|
||||||
return Py_TYPE(obj) == &PyjtNanoString ||
|
return Py_TYPE(obj) == &PyjtNanoString ||
|
||||||
PyUnicode_CheckExact(obj) ||
|
PyUnicode_CheckExact(obj) ||
|
||||||
PyType_CheckExact(obj) ||
|
PyType_CheckExact(obj) ||
|
||||||
|
// jt.float.__name__
|
||||||
|
PyCallable_Check(obj) ||
|
||||||
// numpy.dtype.type
|
// numpy.dtype.type
|
||||||
PyObject_HasAttrString(obj, "type");
|
PyObject_HasAttrString(obj, "type");
|
||||||
}
|
}
|
||||||
|
@ -180,6 +182,11 @@ DEF_IS(NanoString, T) from_py_object(PyObject* obj) {
|
||||||
// PyType
|
// PyType
|
||||||
if (PyType_CheckExact(obj))
|
if (PyType_CheckExact(obj))
|
||||||
return T(_PyType_Name((PyTypeObject *)obj));
|
return T(_PyType_Name((PyTypeObject *)obj));
|
||||||
|
// jt.float.__name__
|
||||||
|
if (PyCallable_Check(obj)) {
|
||||||
|
PyObjHolder t(PyObject_GetAttrString(obj, "__name__"));
|
||||||
|
return T(PyUnicode_AsUTF8(t.obj));
|
||||||
|
}
|
||||||
PyObjHolder t(PyObject_GetAttrString(obj, "type"));
|
PyObjHolder t(PyObject_GetAttrString(obj, "type"));
|
||||||
CHECK(PyType_CheckExact(t.obj)) << "Not a valid type:" << t.obj;
|
CHECK(PyType_CheckExact(t.obj)) << "Not a valid type:" << t.obj;
|
||||||
return T(_PyType_Name((PyTypeObject *)t.obj));
|
return T(_PyType_Name((PyTypeObject *)t.obj));
|
||||||
|
|
Loading…
Reference in New Issue