mirror of https://github.com/Jittor/Jittor
add th_mode
This commit is contained in:
parent
88815e8dd3
commit
c113eabcca
|
@ -83,7 +83,7 @@ def map_flags(flags, func):
|
|||
output.append(func(s))
|
||||
return " ".join(output)
|
||||
|
||||
def compile(compiler, flags, inputs, output, combind_build=False, cuda_flags=""):
|
||||
def compile(compiler, flags, inputs, output, combind_build=False, cuda_flags="", obj_dirname="obj_files"):
|
||||
def do_compile(cmd):
|
||||
if jit_utils.cc:
|
||||
return jit_utils.cc.cache_compile(cmd, cache_path, jittor_path)
|
||||
|
@ -108,13 +108,15 @@ def compile(compiler, flags, inputs, output, combind_build=False, cuda_flags="")
|
|||
obj_files = []
|
||||
ex_obj_files = []
|
||||
new_inputs = []
|
||||
obj_dir = os.path.join(cache_path, obj_dirname)
|
||||
os.makedirs(obj_dir, exist_ok=True)
|
||||
for name in inputs:
|
||||
if name[-1] in 'oab':
|
||||
ex_obj_files.append(name)
|
||||
else:
|
||||
new_inputs.append(os.path.join(jittor_path, name))
|
||||
obj_files.append(os.path.join(
|
||||
cache_path, "obj_files", os.path.basename(name)+".o"))
|
||||
obj_dir, os.path.basename(name)+".o"))
|
||||
inputs = new_inputs
|
||||
cm = lambda s: f"\"{s}\""
|
||||
cms = lambda arr: [f"\"{s}\"" for s in arr ]
|
||||
|
|
|
@ -897,7 +897,6 @@ def compile(cache_path, jittor_path):
|
|||
pyjt_names.append(fname)
|
||||
|
||||
code = f"""
|
||||
#include "pyjt/numpy.h"
|
||||
#include "pyjt/py_converter.h"
|
||||
#include "common.h"
|
||||
|
||||
|
@ -906,7 +905,6 @@ def compile(cache_path, jittor_path):
|
|||
{ " ".join([f"extern void pyjt_def_{n}(PyObject* m);" for n in basenames])}
|
||||
|
||||
void pyjt_def_all(PyObject* m) {{
|
||||
numpy_init();
|
||||
{ " ".join([f"pyjt_def_{n}(m);" for n in basenames])}
|
||||
}}
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
namespace jittor {
|
||||
|
||||
DECLARE_FLAG(string, cache_path);
|
||||
DECLARE_FLAG(uint8, th_mode);
|
||||
|
||||
DEFINE_FLAG(int, try_use_32bit_index, 0,
|
||||
"If not overflow, try to use 32 bit type as index type.");
|
||||
|
@ -87,6 +88,11 @@ void Op::init() {
|
|||
exe.run_sync(vector<Var*>({need_sync}), false);
|
||||
CHECK(need_sync->num >= 0) << need_sync << "'s shape is error";
|
||||
}
|
||||
if (th_mode) {
|
||||
for (Var* v : outputs()) {
|
||||
v->set_stop_grad();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Op::compile_optimize(string& src) {}
|
||||
|
|
|
@ -8,6 +8,7 @@
|
|||
#include "grad.h"
|
||||
#include "pyjt/py_obj_holder.h"
|
||||
#include "init.h"
|
||||
#include "pyjt/numpy.h"
|
||||
#include "utils/seh.h"
|
||||
|
||||
namespace jittor {
|
||||
|
@ -34,6 +35,7 @@ vector<VarHolder*> _grad(VarHolder* loss, const vector<VarHolder*>& targets, boo
|
|||
static void init_module(PyModuleDef* mdef, PyObject* m) {
|
||||
mdef->m_doc = "Inner c++ core of jittor";
|
||||
jittor::init();
|
||||
jittor::numpy_init();
|
||||
jittor::pyjt_def_all(m);
|
||||
}
|
||||
PYJT_MODULE_INIT(jittor_core);
|
||||
|
|
|
@ -22,6 +22,7 @@ DEFINE_FLAG(bool, no_grad, 0,
|
|||
DEFINE_FLAG(bool, no_fuse, 0,
|
||||
"No fusion optimization for all jittor Var creation");
|
||||
DEFINE_FLAG(uint8, node_order, 0, "id prior");
|
||||
DEFINE_FLAG(uint8, th_mode, 0, "th mode");
|
||||
// TODO: fuse multiple flags
|
||||
DEFINE_FLAG(int, amp_reg, 0, "Auto mixed-precision control registers, bit 0: prefer 32; bit 1: prefer 16; bit 2: keep reduce type; bit 3 keep white list type; bit 4: array like op prefer too");
|
||||
|
||||
|
@ -54,7 +55,6 @@ string Var::to_string() {
|
|||
}
|
||||
|
||||
int64 Var::numel() {
|
||||
if (!shape.size()) return size=num=-1;
|
||||
bool negtive = 0;
|
||||
num=1;
|
||||
for (auto k : shape) {
|
||||
|
|
Loading…
Reference in New Issue