mirror of https://github.com/Jittor/Jittor
nanostring support numpy type, error with example
This commit is contained in:
parent
5325225e9f
commit
2ed7ae6aec
|
@ -239,6 +239,46 @@ reg = re.compile(
|
|||
# attrs args $5
|
||||
, re.DOTALL)
|
||||
|
||||
def generate_error_code_from_func_header(func_head, target_scope_name, name, dfs, basename, h, class_info):
|
||||
# func_head is a string like:
|
||||
# (PyObject* self, PyObject** args, int64 n, PyObject* kw) -> PyObject*
|
||||
lib_name = os.path.basename(h).split("_")[0]
|
||||
# TODO: fix/add var help
|
||||
if target_scope_name == "Var": target_scope_name = None
|
||||
if target_scope_name:
|
||||
if target_scope_name == "flags":
|
||||
help_name = "flags"
|
||||
else:
|
||||
help_name = ""+target_scope_name+'.'+name
|
||||
else:
|
||||
help_name = name
|
||||
if lib_name in ["mpi", "nccl", "cudnn", "curand", "cublas", "mkl"]:
|
||||
help_name = lib_name+'.'+help_name
|
||||
help_cmd = f"help(jt.{help_name})"
|
||||
|
||||
LOG.vvv("gen err from func_head", func_head)
|
||||
args = func_head[1:].split(")")[0].split(",")
|
||||
error_code = f" << \"Wrong inputs arguments, Please refer to examples(e.g. {help_cmd}).\""
|
||||
error_code += r' << "\n\nTypes of your inputs are:\n"'
|
||||
for arg in args:
|
||||
arg = arg.strip()
|
||||
if arg.startswith("PyObject* "):
|
||||
t, n = arg.split(' ')
|
||||
if n == "args" or n == "_args":
|
||||
error_code += f" << PyTupleArgPrinter{{{n}, \"args\"}} "
|
||||
elif n == "kw":
|
||||
error_code += f" << PyKwArgPrinter{{{n}}} "
|
||||
else:
|
||||
error_code += f" << PyArgPrinter{{{n}, \"{n}\"}} "
|
||||
elif arg.startswith("PyObject** "):
|
||||
t, n = arg.split(' ')
|
||||
error_code += f" << PyFastCallArgPrinter{{{n}, n, kw}} "
|
||||
break
|
||||
else:
|
||||
LOG.vvv("Unhandled arg", arg)
|
||||
LOG.vvv("gen err from func_head", func_head, " -> ", error_code)
|
||||
return error_code
|
||||
|
||||
def compile_src(src, h, basename):
|
||||
res = list(reg.finditer(src, re.S))
|
||||
if len(res)==0: return
|
||||
|
@ -586,7 +626,7 @@ def compile_src(src, h, basename):
|
|||
|
||||
arr_func_return = []
|
||||
doc_all = ""
|
||||
decs = "Declarations:\n"
|
||||
decs = "The function declarations are:\n"
|
||||
for did, has_return in enumerate(arr_has_return):
|
||||
df = dfs[did]
|
||||
func_call = arr_func_call[did]
|
||||
|
@ -595,7 +635,7 @@ def compile_src(src, h, basename):
|
|||
doc_all += df["doc"]
|
||||
doc_all += "\nDeclaration:\n"
|
||||
doc_all += df["dec"]
|
||||
decs += df["dec"]+'\n'
|
||||
decs += " " + df["dec"]+'\n'
|
||||
if has_return:
|
||||
assert "-> int" not in func_head
|
||||
if "-> PyObject*" in func_head:
|
||||
|
@ -618,6 +658,8 @@ def compile_src(src, h, basename):
|
|||
assert "-> void" in func_head
|
||||
arr_func_return.append(f"{func_call};return")
|
||||
func_return_failed = "return"
|
||||
# generate error msg when not a valid call
|
||||
error_log_code = generate_error_code_from_func_header(func_head, target_scope_name, name, dfs, basename ,h, class_info)
|
||||
func = f"""
|
||||
{func_cast}[]{func_head} {{
|
||||
try {{
|
||||
|
@ -633,11 +675,15 @@ def compile_src(src, h, basename):
|
|||
'''
|
||||
for did in range(len(arr_func_return))
|
||||
])}
|
||||
LOGf << "Not a valid call";
|
||||
LOGf << "Not a valid call.";
|
||||
}} catch (const std::exception& e) {{
|
||||
PyErr_Format(PyExc_RuntimeError, "%s\\n%s",
|
||||
e.what(),
|
||||
R""({decs})""
|
||||
std::stringstream ss;
|
||||
ss {error_log_code};
|
||||
PyErr_Format(PyExc_RuntimeError,
|
||||
"%s\\n%s\\nFailed reason:%s",
|
||||
ss.str().c_str(),
|
||||
R""({decs})"",
|
||||
e.what()
|
||||
);
|
||||
}}
|
||||
{func_return_failed};
|
||||
|
@ -711,6 +757,7 @@ def compile_src(src, h, basename):
|
|||
has_seq = class_name == "NanoVector"
|
||||
code = f"""
|
||||
#include "pyjt/py_converter.h"
|
||||
#include "pyjt/py_arg_printer.h"
|
||||
#include "common.h"
|
||||
#include "{include_name}"
|
||||
|
||||
|
|
|
@ -23,6 +23,7 @@ class TestNanoString(unittest.TestCase):
|
|||
t = (time.time() - t)/n
|
||||
# t is about 0.01 for 100w loop
|
||||
# 92ns one loop
|
||||
print("nanostring time", t)
|
||||
assert t < [1.5e-7, 1.7e-7][mid], t
|
||||
|
||||
assert (jt.hash("asdasd") == 4152566416)
|
||||
|
@ -34,6 +35,31 @@ class TestNanoString(unittest.TestCase):
|
|||
# int init: 1.2
|
||||
# dtype init(cache): 0.75
|
||||
# final: 1.0
|
||||
|
||||
def test_type(self):
|
||||
import numpy as np
|
||||
assert str(jt.NanoString(float)) == "float"
|
||||
assert str(jt.NanoString(np.float)) == "float"
|
||||
assert str(jt.NanoString(np.float32)) == "float32"
|
||||
assert str(jt.NanoString(np.float64)) == "float64"
|
||||
assert str(jt.NanoString(np.int8)) == "int8"
|
||||
assert str(jt.NanoString(np.array([1,2,3]).dtype)) == "int64"
|
||||
def get_error_str(call):
|
||||
es = ""
|
||||
try:
|
||||
call()
|
||||
except Exception as e:
|
||||
es = str(e)
|
||||
return es
|
||||
|
||||
e = get_error_str(lambda: jt.code([1,], {}, [1], cpu_header=""))
|
||||
assert "help(jt.ops.code)" in e
|
||||
assert "cpu_header=str" in e
|
||||
e = get_error_str(lambda: jt.NanoString([1,2,3], fuck=1))
|
||||
assert "fuck=int" in str(e)
|
||||
assert "(list, )" in str(e)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -128,13 +128,14 @@ struct NanoString {
|
|||
inline ns_t is_unary() const { return get(_type, _type_nbits)==_unary; }
|
||||
|
||||
inline NanoString() {}
|
||||
inline NanoString(const NanoString& other) : data(other.data) {}
|
||||
// @pyjt(__init__)
|
||||
inline NanoString(const char* s) {
|
||||
auto iter = __string_to_ns.find(s);
|
||||
ASSERT(iter != __string_to_ns.end()) << s;
|
||||
data = iter->second.data;
|
||||
}
|
||||
// @pyjt(__init__)
|
||||
inline NanoString(const NanoString& other) : data(other.data) {}
|
||||
inline NanoString(const string& s) : NanoString(s.c_str()) {}
|
||||
// @pyjt(__repr__)
|
||||
inline const char* to_cstring() const
|
||||
|
|
|
@ -57,7 +57,7 @@ struct CodeOp : Op {
|
|||
|
||||
```
|
||||
a = jt.random([10])
|
||||
b = jt.code(a.shape, a.dtype, [a],
|
||||
b = jt.code(a.shape, "float32", [a],
|
||||
cpu_src='''
|
||||
for (int i=0; i<in0_shape0; i++)
|
||||
@out(i) = @in0(i)*@in0(i)*2;
|
||||
|
|
|
@ -0,0 +1,63 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include "pyjt/py_arg_printer.h"
|
||||
#include "pyjt/py_obj_holder.h"
|
||||
#include "pyjt/py_converter.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, const PyArgPrinter& arg) {
|
||||
os << " " << arg.name << "\t= ";
|
||||
if (!arg.obj) return os << "null,";
|
||||
return os << _PyType_Name(Py_TYPE(arg.obj)) << ",\n";
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, const PyTupleArgPrinter& args) {
|
||||
os << " " << args.name << "\t= (";
|
||||
auto size = Py_SIZE(args.obj);
|
||||
auto arr = PySequence_Fast_ITEMS(args.obj);
|
||||
for (int i=0; i<size; i++) {
|
||||
os << _PyType_Name(Py_TYPE(arr[i])) << ", ";
|
||||
}
|
||||
return os << "),\n";
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, const PyKwArgPrinter& args) {
|
||||
auto obj = args.obj;
|
||||
if (!obj) return os;
|
||||
|
||||
// auto size = Py_SIZE(obj);
|
||||
PyObject *key, *value;
|
||||
Py_ssize_t pos = 0;
|
||||
os << " " << "kwargs\t= {";
|
||||
while (PyDict_Next(obj, &pos, &key, &value)) {
|
||||
os << from_py_object<std::string>(key) << "=" <<
|
||||
_PyType_Name(Py_TYPE(value)) << ", ";
|
||||
}
|
||||
return os << "},\n";
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, const PyFastCallArgPrinter& args) {
|
||||
os << " args\t= (";
|
||||
auto size = args.n;
|
||||
auto arr = args.obj;
|
||||
for (int i=0; i<size; i++) {
|
||||
os << _PyType_Name(Py_TYPE(arr[i])) << ", ";
|
||||
}
|
||||
os << "),\n";
|
||||
auto kw = args.kw;
|
||||
if (!kw) return os;
|
||||
os << " kwargs\t= {";
|
||||
auto kw_n = Py_SIZE(kw);
|
||||
for (int i=0; i<kw_n; i++) {
|
||||
auto ko = PyTuple_GET_ITEM(kw, i);
|
||||
auto ks = PyUnicode_AsUTF8(ko);
|
||||
os << ks << "=" << _PyType_Name(Py_TYPE(arr[i+size])) << ", ";
|
||||
}
|
||||
return os << "},\n";
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,36 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include <Python.h>
|
||||
#include "common.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
struct PyArgPrinter {
|
||||
PyObject* obj;
|
||||
const char* name;
|
||||
};
|
||||
std::ostream& operator<<(std::ostream& os, const PyArgPrinter& arg);
|
||||
|
||||
struct PyTupleArgPrinter {
|
||||
PyObject* obj;
|
||||
const char* name;
|
||||
};
|
||||
std::ostream& operator<<(std::ostream& os, const PyTupleArgPrinter& args);
|
||||
|
||||
struct PyKwArgPrinter {
|
||||
PyObject* obj;
|
||||
};
|
||||
std::ostream& operator<<(std::ostream& os, const PyKwArgPrinter& args);
|
||||
|
||||
struct PyFastCallArgPrinter {
|
||||
PyObject** obj;
|
||||
int64 n;
|
||||
PyObject* kw;
|
||||
};
|
||||
std::ostream& operator<<(std::ostream& os, const PyFastCallArgPrinter& args);
|
||||
|
||||
}
|
|
@ -159,7 +159,10 @@ struct NanoString;
|
|||
extern PyTypeObject PyjtNanoString;
|
||||
DEF_IS(NanoString, bool) is_type(PyObject* obj) {
|
||||
return Py_TYPE(obj) == &PyjtNanoString ||
|
||||
PyUnicode_CheckExact(obj);
|
||||
PyUnicode_CheckExact(obj) ||
|
||||
PyType_CheckExact(obj) ||
|
||||
// numpy.dtype.type
|
||||
PyObject_HasAttrString(obj, "type");
|
||||
}
|
||||
|
||||
DEF_IS(NanoString, PyObject*) to_py_object(T a) {
|
||||
|
@ -172,7 +175,14 @@ DEF_IS(NanoString, PyObject*) to_py_object(T a) {
|
|||
DEF_IS(NanoString, T) from_py_object(PyObject* obj) {
|
||||
if (Py_TYPE(obj) == &PyjtNanoString)
|
||||
return *GET_RAW_PTR(T, obj);
|
||||
return T(PyUnicode_AsUTF8(obj));
|
||||
if (PyUnicode_CheckExact(obj))
|
||||
return T(PyUnicode_AsUTF8(obj));
|
||||
// PyType
|
||||
if (PyType_CheckExact(obj))
|
||||
return T(_PyType_Name((PyTypeObject *)obj));
|
||||
PyObjHolder t(PyObject_GetAttrString(obj, "type"));
|
||||
CHECK(PyType_CheckExact(t.obj)) << "Not a valid type:" << t.obj;
|
||||
return T(_PyType_Name((PyTypeObject *)t.obj));
|
||||
}
|
||||
|
||||
// NanoVector
|
||||
|
|
Loading…
Reference in New Issue