mirror of https://github.com/Jittor/Jittor
add binary doc string compiler
This commit is contained in:
parent
4f102fd3e8
commit
3c6b5c3b53
|
@ -460,9 +460,9 @@ def gen_jit_op_maker(op_headers, export=False, extra_flags=""):
|
|||
with open(os.path.join(jittor_path, f"src/ops/{func_name}_op.cc"), encoding="utf-8") as f:
|
||||
src = f.read()
|
||||
src = src.split(f"unordered_set<string> {func_name}_ops = ""{")[1].split("};")[0]
|
||||
res2 = re.findall(pybind_reg + "\"([a-z_A-Z0-9]*)\"", src, re.S)
|
||||
match_result = re.findall(pybind_reg + "\"([a-z_A-Z0-9]*)\"", src, re.S)
|
||||
# remove /* doc_string */ pattern
|
||||
res2 = [ (_[3], _[4]) for _ in res2 ]
|
||||
res2 = [ (_[3], _[4]) for _ in match_result ]
|
||||
LOG.vvvv(f"All supported {func_name} ops: {res2}")
|
||||
# remove op args
|
||||
if func_name == "reduce":
|
||||
|
@ -477,6 +477,10 @@ def gen_jit_op_maker(op_headers, export=False, extra_flags=""):
|
|||
last_tid = res2.index(("","float64"))
|
||||
# for each functor
|
||||
for tid, (bind_name, func_name2) in enumerate(res2):
|
||||
# get certain op doc_string
|
||||
doc_string2 = match_result[tid][1].strip()
|
||||
if len(doc_string2) == 0:
|
||||
doc_string2 = doc_string
|
||||
# add _ for types
|
||||
if func_name == "unary" and tid <= last_tid:
|
||||
func_name3 = func_name2 + "_"
|
||||
|
@ -501,7 +505,7 @@ def gen_jit_op_maker(op_headers, export=False, extra_flags=""):
|
|||
bind_name,
|
||||
py_args_s,
|
||||
jit_cc_src,
|
||||
doc_string,
|
||||
doc_string2,
|
||||
attrs
|
||||
)
|
||||
|
||||
|
|
Loading…
Reference in New Issue