nanostring support numpy type, error with example

This commit is contained in:
Dun Liang 2020-04-27 22:17:56 +08:00
parent 5325225e9f
commit 2ed7ae6aec
7 changed files with 193 additions and 10 deletions

View File

@ -239,6 +239,46 @@ reg = re.compile(
# attrs args $5
, re.DOTALL)
def generate_error_code_from_func_header(func_head, target_scope_name, name, dfs, basename, h, class_info):
    """Build the C++ stream-insertion snippet that describes a rejected call.

    The returned text is a chain of ``<< ...`` fragments: a hint that points
    the user at ``help(jt.<name>)``, followed by one printer object for each
    Python-facing argument found in ``func_head``.

    func_head is a string like:
        (PyObject* self, PyObject** args, int64 n, PyObject* kw) -> PyObject*

    ``dfs``, ``basename`` and ``class_info`` are accepted but not read here.
    """
    # The library name is the header file name up to its first underscore.
    lib_name = os.path.basename(h).split("_")[0]
    # TODO: fix/add var help
    scope = None if target_scope_name == "Var" else target_scope_name
    if not scope:
        help_name = name
    elif scope == "flags":
        help_name = "flags"
    else:
        help_name = "" + scope + '.' + name
    # Names from these libraries are prefixed with the library name.
    if lib_name in ["mpi", "nccl", "cudnn", "curand", "cublas", "mkl"]:
        help_name = lib_name + '.' + help_name
    help_cmd = f"help(jt.{help_name})"

    LOG.vvv("gen err from func_head", func_head)
    pieces = [
        f" << \"Wrong inputs arguments, Please refer to examples(e.g. {help_cmd}).\"",
        r' << "\n\nTypes of your inputs are:\n"',
    ]
    # Only the text between the leading "(" and the first ")" is inspected.
    param_list = func_head[1:].split(")")[0]
    for raw in param_list.split(","):
        decl = raw.strip()
        if decl.startswith("PyObject** "):
            # Fast-call signature: the generated printer refers to the
            # C++ variables `n` and `kw` by name, so scanning stops here.
            _, var = decl.split(' ')
            pieces.append(f" << PyFastCallArgPrinter{{{var}, n, kw}} ")
            break
        if not decl.startswith("PyObject* "):
            LOG.vvv("Unhandled arg", decl)
            continue
        _, var = decl.split(' ')
        if var in ("args", "_args"):
            pieces.append(f" << PyTupleArgPrinter{{{var}, \"args\"}} ")
        elif var == "kw":
            pieces.append(f" << PyKwArgPrinter{{{var}}} ")
        else:
            pieces.append(f" << PyArgPrinter{{{var}, \"{var}\"}} ")
    error_code = "".join(pieces)
    LOG.vvv("gen err from func_head", func_head, " -> ", error_code)
    return error_code
def compile_src(src, h, basename):
res = list(reg.finditer(src, re.S))
if len(res)==0: return
@ -586,7 +626,7 @@ def compile_src(src, h, basename):
arr_func_return = []
doc_all = ""
decs = "Declarations:\n"
decs = "The function declarations are:\n"
for did, has_return in enumerate(arr_has_return):
df = dfs[did]
func_call = arr_func_call[did]
@ -595,7 +635,7 @@ def compile_src(src, h, basename):
doc_all += df["doc"]
doc_all += "\nDeclaration:\n"
doc_all += df["dec"]
decs += df["dec"]+'\n'
decs += " " + df["dec"]+'\n'
if has_return:
assert "-> int" not in func_head
if "-> PyObject*" in func_head:
@ -618,6 +658,8 @@ def compile_src(src, h, basename):
assert "-> void" in func_head
arr_func_return.append(f"{func_call};return")
func_return_failed = "return"
# generate error msg when not a valid call
error_log_code = generate_error_code_from_func_header(func_head, target_scope_name, name, dfs, basename ,h, class_info)
func = f"""
{func_cast}[]{func_head} {{
try {{
@ -633,11 +675,15 @@ def compile_src(src, h, basename):
'''
for did in range(len(arr_func_return))
])}
LOGf << "Not a valid call";
LOGf << "Not a valid call.";
}} catch (const std::exception& e) {{
PyErr_Format(PyExc_RuntimeError, "%s\\n%s",
e.what(),
R""({decs})""
std::stringstream ss;
ss {error_log_code};
PyErr_Format(PyExc_RuntimeError,
"%s\\n%s\\nFailed reason:%s",
ss.str().c_str(),
R""({decs})"",
e.what()
);
}}
{func_return_failed};
@ -711,6 +757,7 @@ def compile_src(src, h, basename):
has_seq = class_name == "NanoVector"
code = f"""
#include "pyjt/py_converter.h"
#include "pyjt/py_arg_printer.h"
#include "common.h"
#include "{include_name}"

View File

@ -23,6 +23,7 @@ class TestNanoString(unittest.TestCase):
t = (time.time() - t)/n
# t is about 0.01 for 100w loop
# 92ns one loop
print("nanostring time", t)
assert t < [1.5e-7, 1.7e-7][mid], t
assert (jt.hash("asdasd") == 4152566416)
@ -34,6 +35,31 @@ class TestNanoString(unittest.TestCase):
# int init: 1.2
# dtype init(cache): 0.75
# final: 1.0
def test_type(self):
    """Check NanoString construction from Python/numpy types, and the
    argument-type summary attached to errors raised by invalid calls."""
    import numpy as np
    assert str(jt.NanoString(float)) == "float"
    # np.float was only an alias of the builtin float (already checked above)
    # and was removed in NumPy 1.24, so it is not referenced here.
    assert str(jt.NanoString(np.float32)) == "float32"
    assert str(jt.NanoString(np.float64)) == "float64"
    assert str(jt.NanoString(np.int8)) == "int8"
    # Request int64 explicitly: the default integer dtype chosen by np.array
    # is not int64 on every platform/NumPy version.
    assert str(jt.NanoString(np.array([1,2,3], dtype=np.int64).dtype)) == "int64"

    def get_error_str(call):
        # Return the message of the exception raised by call(), or "" if
        # the call succeeds.
        es = ""
        try:
            call()
        except Exception as e:
            es = str(e)
        return es

    # The error text must point at the help entry and list keyword types.
    e = get_error_str(lambda: jt.code([1,], {}, [1], cpu_header=""))
    assert "help(jt.ops.code)" in e
    assert "cpu_header=str" in e
    # Positional and keyword argument types must both be reported.
    e = get_error_str(lambda: jt.NanoString([1,2,3], fuck=1))
    assert "fuck=int" in str(e)
    assert "(list, )" in str(e)
if __name__ == "__main__":
unittest.main()

View File

@ -128,13 +128,14 @@ struct NanoString {
inline ns_t is_unary() const { return get(_type, _type_nbits)==_unary; }
inline NanoString() {}
inline NanoString(const NanoString& other) : data(other.data) {}
// @pyjt(__init__)
inline NanoString(const char* s) {
    // Look the name up in __string_to_ns; the assertion fails when the name
    // is not registered, and the name is streamed into the assertion output.
    const auto entry = __string_to_ns.find(s);
    ASSERT(entry != __string_to_ns.end()) << s;
    // Adopt the packed value stored for that name.
    data = entry->second.data;
}
// @pyjt(__init__)
inline NanoString(const NanoString& other) : data(other.data) {}
inline NanoString(const string& s) : NanoString(s.c_str()) {}
// @pyjt(__repr__)
inline const char* to_cstring() const

View File

@ -57,7 +57,7 @@ struct CodeOp : Op {
```
a = jt.random([10])
b = jt.code(a.shape, a.dtype, [a],
b = jt.code(a.shape, "float32", [a],
cpu_src='''
for (int i=0; i<in0_shape0; i++)
@out(i) = @in0(i)*@in0(i)*2;

View File

@ -0,0 +1,63 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "pyjt/py_arg_printer.h"
#include "pyjt/py_obj_holder.h"
#include "pyjt/py_converter.h"
namespace jittor {
// Prints one named argument as " <name>\t= <type name>,\n".
// A null object is printed as "null" in place of the type name.
std::ostream& operator<<(std::ostream& os, const PyArgPrinter& arg) {
    os << " " << arg.name << "\t= ";
    // End the null case with a newline too, so the next printed argument
    // starts on its own line like every other entry.
    if (!arg.obj) return os << "null,\n";
    return os << _PyType_Name(Py_TYPE(arg.obj)) << ",\n";
}
std::ostream& operator<<(std::ostream& os, const PyTupleArgPrinter& args) {
os << " " << args.name << "\t= (";
auto size = Py_SIZE(args.obj);
auto arr = PySequence_Fast_ITEMS(args.obj);
for (int i=0; i<size; i++) {
os << _PyType_Name(Py_TYPE(arr[i])) << ", ";
}
return os << "),\n";
}
std::ostream& operator<<(std::ostream& os, const PyKwArgPrinter& args) {
auto obj = args.obj;
if (!obj) return os;
// auto size = Py_SIZE(obj);
PyObject *key, *value;
Py_ssize_t pos = 0;
os << " " << "kwargs\t= {";
while (PyDict_Next(obj, &pos, &key, &value)) {
os << from_py_object<std::string>(key) << "=" <<
_PyType_Name(Py_TYPE(value)) << ", ";
}
return os << "},\n";
}
std::ostream& operator<<(std::ostream& os, const PyFastCallArgPrinter& args) {
os << " args\t= (";
auto size = args.n;
auto arr = args.obj;
for (int i=0; i<size; i++) {
os << _PyType_Name(Py_TYPE(arr[i])) << ", ";
}
os << "),\n";
auto kw = args.kw;
if (!kw) return os;
os << " kwargs\t= {";
auto kw_n = Py_SIZE(kw);
for (int i=0; i<kw_n; i++) {
auto ko = PyTuple_GET_ITEM(kw, i);
auto ks = PyUnicode_AsUTF8(ko);
os << ks << "=" << _PyType_Name(Py_TYPE(arr[i+size])) << ", ";
}
return os << "},\n";
}
}

36
src/pyjt/py_arg_printer.h Normal file
View File

@ -0,0 +1,36 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include <Python.h>
#include "common.h"
namespace jittor {
// Stream helpers used to describe the Python arguments of a call.
// Each struct only carries borrowed pointers; the matching operator<<
// writes the type names of the referenced objects to the stream.

// A single named argument. `obj` may be null, in which case "null" is
// printed instead of a type name; `name` is the label printed before it.
struct PyArgPrinter {
PyObject* obj;
const char* name;
};
std::ostream& operator<<(std::ostream& os, const PyArgPrinter& arg);
// Positional arguments held in one sequence object. `obj` is read with
// Py_SIZE / PySequence_Fast_ITEMS; `name` is the label printed before the
// parenthesised list of type names.
struct PyTupleArgPrinter {
PyObject* obj;
const char* name;
};
std::ostream& operator<<(std::ostream& os, const PyTupleArgPrinter& args);
// Keyword arguments held in a dict iterated with PyDict_Next. Nothing is
// printed when `obj` is null.
struct PyKwArgPrinter {
PyObject* obj;
};
std::ostream& operator<<(std::ostream& os, const PyKwArgPrinter& args);
// Arguments passed as a flat array: the first `n` entries of `obj` are
// printed as positional arguments. When `kw` is non-null it is read as a
// tuple of keyword names whose values are obj[n], obj[n+1], ...
struct PyFastCallArgPrinter {
PyObject** obj;
int64 n;
PyObject* kw;
};
std::ostream& operator<<(std::ostream& os, const PyFastCallArgPrinter& args);
}

View File

@ -159,7 +159,10 @@ struct NanoString;
extern PyTypeObject PyjtNanoString;
DEF_IS(NanoString, bool) is_type(PyObject* obj) {
    // Accepted inputs: a wrapped NanoString, an exact str, an exact type
    // object, or any object exposing a "type" attribute.
    // All checks belong to one return expression; a statement terminator
    // after the str check would leave the later checks unreachable.
    return Py_TYPE(obj) == &PyjtNanoString ||
        PyUnicode_CheckExact(obj) ||
        PyType_CheckExact(obj) ||
        // numpy.dtype.type
        PyObject_HasAttrString(obj, "type");
}
DEF_IS(NanoString, PyObject*) to_py_object(T a) {
@ -172,7 +175,14 @@ DEF_IS(NanoString, PyObject*) to_py_object(T a) {
DEF_IS(NanoString, T) from_py_object(PyObject* obj) {
    // Already a wrapped NanoString: copy the stored value.
    if (Py_TYPE(obj) == &PyjtNanoString)
        return *GET_RAW_PTR(T, obj);
    // Only exact str objects are passed to PyUnicode_AsUTF8. An
    // unconditional return here would make the type-based branches below
    // unreachable and would pass non-str objects to PyUnicode_AsUTF8.
    if (PyUnicode_CheckExact(obj))
        return T(PyUnicode_AsUTF8(obj));
    // PyType: use the name of the type object itself.
    if (PyType_CheckExact(obj))
        return T(_PyType_Name(reinterpret_cast<PyTypeObject*>(obj)));
    // Otherwise read the object's "type" attribute (numpy.dtype.type) and
    // use the name of that type.
    PyObjHolder t(PyObject_GetAttrString(obj, "type"));
    CHECK(PyType_CheckExact(t.obj)) << "Not a valid type:" << t.obj;
    return T(_PyType_Name(reinterpret_cast<PyTypeObject*>(t.obj)));
}
// NanoVector