add th_mode

This commit is contained in:
Dun Liang 2022-05-10 18:04:35 +08:00
parent 88815e8dd3
commit c113eabcca
5 changed files with 13 additions and 5 deletions

View File

@ -83,7 +83,7 @@ def map_flags(flags, func):
output.append(func(s))
return " ".join(output)
def compile(compiler, flags, inputs, output, combind_build=False, cuda_flags=""):
def compile(compiler, flags, inputs, output, combind_build=False, cuda_flags="", obj_dirname="obj_files"):
def do_compile(cmd):
if jit_utils.cc:
return jit_utils.cc.cache_compile(cmd, cache_path, jittor_path)
@ -108,13 +108,15 @@ def compile(compiler, flags, inputs, output, combind_build=False, cuda_flags="")
obj_files = []
ex_obj_files = []
new_inputs = []
obj_dir = os.path.join(cache_path, obj_dirname)
os.makedirs(obj_dir, exist_ok=True)
for name in inputs:
if name[-1] in 'oab':
ex_obj_files.append(name)
else:
new_inputs.append(os.path.join(jittor_path, name))
obj_files.append(os.path.join(
cache_path, "obj_files", os.path.basename(name)+".o"))
obj_dir, os.path.basename(name)+".o"))
inputs = new_inputs
cm = lambda s: f"\"{s}\""
cms = lambda arr: [f"\"{s}\"" for s in arr ]

View File

@ -897,7 +897,6 @@ def compile(cache_path, jittor_path):
pyjt_names.append(fname)
code = f"""
#include "pyjt/numpy.h"
#include "pyjt/py_converter.h"
#include "common.h"
@ -906,7 +905,6 @@ def compile(cache_path, jittor_path):
{ " ".join([f"extern void pyjt_def_{n}(PyObject* m);" for n in basenames])}
void pyjt_def_all(PyObject* m) {{
numpy_init();
{ " ".join([f"pyjt_def_{n}(m);" for n in basenames])}
}}

View File

@ -20,6 +20,7 @@
namespace jittor {
DECLARE_FLAG(string, cache_path);
DECLARE_FLAG(uint8, th_mode);
DEFINE_FLAG(int, try_use_32bit_index, 0,
"If not overflow, try to use 32 bit type as index type.");
@ -87,6 +88,11 @@ void Op::init() {
exe.run_sync(vector<Var*>({need_sync}), false);
CHECK(need_sync->num >= 0) << need_sync << "'s shape is error";
}
if (th_mode) {
for (Var* v : outputs()) {
v->set_stop_grad();
}
}
}
void Op::compile_optimize(string& src) {}

View File

@ -8,6 +8,7 @@
#include "grad.h"
#include "pyjt/py_obj_holder.h"
#include "init.h"
#include "pyjt/numpy.h"
#include "utils/seh.h"
namespace jittor {
@ -34,6 +35,7 @@ vector<VarHolder*> _grad(VarHolder* loss, const vector<VarHolder*>& targets, boo
static void init_module(PyModuleDef* mdef, PyObject* m) {
mdef->m_doc = "Inner c++ core of jittor";
jittor::init();
jittor::numpy_init();
jittor::pyjt_def_all(m);
}
PYJT_MODULE_INIT(jittor_core);

View File

@ -22,6 +22,7 @@ DEFINE_FLAG(bool, no_grad, 0,
DEFINE_FLAG(bool, no_fuse, 0,
"No fusion optimization for all jittor Var creation");
DEFINE_FLAG(uint8, node_order, 0, "id prior");
DEFINE_FLAG(uint8, th_mode, 0, "th mode");
// TODO: fuse multiple flags
DEFINE_FLAG(int, amp_reg, 0, "Auto mixed-precision control registers, bit 0: prefer 32; bit 1: prefer 16; bit 2: keep reduce type; bit 3 keep white list type; bit 4: array like op prefer too");
@ -54,7 +55,6 @@ string Var::to_string() {
}
int64 Var::numel() {
if (!shape.size()) return size=num=-1;
bool negtive = 0;
num=1;
for (auto k : shape) {