mirror of https://github.com/Jittor/Jittor
jt.sum -> nanostring(add)
This commit is contained in:
parent
c1e0861446
commit
fc9b22e796
|
@ -44,6 +44,14 @@ class TestNanoString(unittest.TestCase):
|
|||
assert str(jt.NanoString(np.float64)) == "float64"
|
||||
assert str(jt.NanoString(np.int8)) == "int8"
|
||||
assert str(jt.NanoString(np.array([1,2,3]).dtype)) == "int64"
|
||||
|
||||
assert str(jt.NanoString(jt.float)) == "float"
|
||||
assert str(jt.NanoString(jt.float32)) == "float32"
|
||||
assert str(jt.NanoString(jt.float64)) == "float64"
|
||||
assert str(jt.NanoString(jt.int8)) == "int8"
|
||||
assert str(jt.NanoString(jt.array([1,2,3]).dtype)) == "int64"
|
||||
assert str(jt.NanoString(jt.sum)) == "add"
|
||||
|
||||
def get_error_str(call):
|
||||
es = ""
|
||||
try:
|
||||
|
|
|
@ -147,6 +147,7 @@ static void init_ns() {
|
|||
#define INIT_NS(T) func(#T, ns_##T);
|
||||
FOR_ALL_NS(INIT_NS);
|
||||
ASSERT(NanoString::__ns_to_string.size()<=(1<<NanoString::_index_nbits));
|
||||
NanoString::__string_to_ns["sum"] = ns_add;
|
||||
LOGvv << "init __string_to_ns" << NanoString::__string_to_ns;
|
||||
LOGvv << "init __ns_to_string" << NanoString::__ns_to_string;
|
||||
}
|
||||
|
|
|
@ -161,6 +161,8 @@ DEF_IS(NanoString, bool) is_type(PyObject* obj) {
|
|||
return Py_TYPE(obj) == &PyjtNanoString ||
|
||||
PyUnicode_CheckExact(obj) ||
|
||||
PyType_CheckExact(obj) ||
|
||||
// jt.float.__name__
|
||||
PyCallable_Check(obj) ||
|
||||
// numpy.dtype.type
|
||||
PyObject_HasAttrString(obj, "type");
|
||||
}
|
||||
|
@ -180,6 +182,11 @@ DEF_IS(NanoString, T) from_py_object(PyObject* obj) {
|
|||
// PyType
|
||||
if (PyType_CheckExact(obj))
|
||||
return T(_PyType_Name((PyTypeObject *)obj));
|
||||
// jt.float.__name__
|
||||
if (PyCallable_Check(obj)) {
|
||||
PyObjHolder t(PyObject_GetAttrString(obj, "__name__"));
|
||||
return T(PyUnicode_AsUTF8(t.obj));
|
||||
}
|
||||
PyObjHolder t(PyObject_GetAttrString(obj, "type"));
|
||||
CHECK(PyType_CheckExact(t.obj)) << "Not a valid type:" << t.obj;
|
||||
return T(_PyType_Name((PyTypeObject *)t.obj));
|
||||
|
|
Loading…
Reference in New Issue