Merge branch 'master' into fft

This commit is contained in:
cxjyxx_me 2022-03-26 23:22:55 -04:00
commit b04ad0ccb4
29 changed files with 203 additions and 67 deletions

View File

@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.3.1.38'
__version__ = '1.3.1.46'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int
@ -410,9 +410,7 @@ def flatten(input, start_dim=0, end_dim=-1):
return input.reshape(out_shape)
Var.flatten = flatten
def start_grad(x):
return x._update(x)
Var.detach_inplace = Var.start_grad = start_grad
Var.detach_inplace = Var.start_grad
def detach(x):
return x.detach()

View File

@ -4716,7 +4716,7 @@ class Var:
>>> jt.arg_reduce(x, 'max', dim=1, keepdims=False)
[jt.Var([2 1], dtype=int32), jt.Var([5 7], dtype=int32)]
>>> jt.arg_reduce(x, 'min', dim=1, keepdims=False)
[jt.Var([1 2], dtype=int32), jt.Var([5 7], dtype=int32)]'''
[jt.Var([1 2], dtype=int32), jt.Var([2 1], dtype=int32)]'''
...
@overload
def reduce(self, op: str, dim: int, keepdims: bool=False)-> Var: ...

View File

@ -190,6 +190,17 @@ def setup_cub():
def setup_cuda_extern():
if not has_cuda: return
def split(a): return a.replace(";",":").split(":")
check_ld_path = split(os.environ.get("LD_LIBRARY_PATH", "")) + \
split(os.environ.get("PATH", ""))
for cp in check_ld_path:
cp = cp.lower()
if "cuda" in cp and \
"lib" in cp and \
"jtcuda" not in cp:
LOG.w(f"CUDA related path found in LD_LIBRARY_PATH or PATH({check_ld_path}), "
"This path may cause jittor found the wrong libs, "
"please unset LD_LIBRARY_PATH and remove cuda lib path in Path. ")
LOG.vv("setup cuda extern...")
cache_path_cuda = os.path.join(cache_path, "cuda")
cuda_include = os.path.join(jittor_path, "extern", "cuda", "inc")

View File

@ -199,7 +199,7 @@ def gen_jit_tests():
}} // jittor
"""
LOG.vvvv(jit_src)
with open(os.path.join(cache_path, "gen", "jit_tests.h"), 'w') as f:
with open(os.path.join(cache_path, "gen", "jit_tests.h"), 'w', encoding='utf8') as f:
f.write(jit_src)
def gen_jit_flags():
@ -257,7 +257,7 @@ def gen_jit_flags():
}} // jittor
"""
LOG.vvvv(jit_src)
with open(os.path.join(cache_path, "gen", "jit_flags.h"), 'w') as f:
with open(os.path.join(cache_path, "gen", "jit_flags.h"), 'w', encoding='utf8') as f:
f.write(jit_src)
def gen_jit_op_maker(op_headers, export=False, extra_flags=""):
@ -639,9 +639,9 @@ def compile_custom_op(header, source, op_name, warp=True):
make_cache_dir(cops_dir)
hname = os.path.join(cops_dir, op_name+"_op.h")
ccname = os.path.join(cops_dir, op_name+"_op.cc")
with open(hname, 'w') as f:
with open(hname, 'w', encoding='utf8') as f:
f.write(header)
with open(ccname, 'w') as f:
with open(ccname, 'w', encoding='utf8') as f:
f.write(source)
m = compile_custom_ops([hname, ccname])
return getattr(m, op_name)
@ -679,7 +679,7 @@ def compile_custom_ops(
dirname = os.path.dirname(name)
if dirname.endswith("inc"):
includes.append(dirname)
with open(name, "r") as f:
with open(name, "r", encoding='utf8') as f:
if "@pyjt" in f.read():
pyjt_includes.append(name)
bname = os.path.basename(name)
@ -736,7 +736,7 @@ def compile_custom_ops(
"init_module(PyModuleDef* mdef, PyObject* m) {",
f"jittor::pyjt_def_{bname}(m);")
with open(gen_head_fname, "w") as f:
with open(gen_head_fname, "w", encoding='utf8') as f:
f.write(gen_src)
LOG.vvv(f"Build custum ops lib:{gen_lib}")
@ -781,7 +781,7 @@ def compile_extern():
files = os.listdir(jittor_path_llvm)
# test_pass.cc is used for test link problem of llvm pass plugin
test_pass_path = os.path.join(cache_path_llvm, "test_pass.cc")
with open(test_pass_path, 'w') as f:
with open(test_pass_path, 'w', encoding='utf8') as f:
f.write("int main() {return 0;}")
# -fno-rtti fix link error
@ -987,7 +987,7 @@ if nvcc_path:
nvcc_version = list(map(int,v.split('.')))
cu += v
try:
r, s = sp.getstatusoutput(f"{sys.executable} -m jittor_utils.query_cuda_cc")
r, s = sp.getstatusoutput(f"log_v=0 {sys.executable} -m jittor_utils.query_cuda_cc")
if r==0:
s = sorted(list(set(s.strip().split())))
cu += "_sm_" + "_".join(s)
@ -1082,7 +1082,7 @@ if os.name == 'nt':
cc_flags = cc_flags.replace("-lstdc++", "")
cc_flags = cc_flags.replace("-ldl", "")
cc_flags += f" -L\"{py3_link_path}\" -lpython3{sys.version_info.minor} "
cc_flags += " -EHa -MD "
cc_flags += " -EHa -MD -utf-8 "
import jittor_utils
if jittor_utils.msvc_path:
mp = jittor_utils.msvc_path
@ -1176,6 +1176,7 @@ if has_cuda:
nvcc_flags = nvcc_flags.replace("-fp:", "-Xcompiler -fp:")
nvcc_flags = nvcc_flags.replace("-EH", "-Xcompiler -EH")
nvcc_flags = nvcc_flags.replace("-M", "-Xcompiler -M")
nvcc_flags = nvcc_flags.replace("-utf", "-Xcompiler -utf")
nvcc_flags = nvcc_flags.replace("-nologo", "")
nvcc_flags = nvcc_flags.replace("-std:", "-std=")
nvcc_flags = nvcc_flags.replace("-Fo:", "-o")
@ -1217,7 +1218,7 @@ gen_jit_tests()
op_headers = glob.glob(jittor_path+"/src/ops/**/*op.h", recursive=True)
jit_src = gen_jit_op_maker(op_headers)
LOG.vvvv(jit_src)
with open(os.path.join(cache_path, "gen", "jit_op_maker.h"), 'w') as f:
with open(os.path.join(cache_path, "gen", "jit_op_maker.h"), 'w', encoding='utf8') as f:
f.write(jit_src)
cc_flags += f' -I\"{cache_path}\" -L\"{cache_path}\" -L\"{jit_utils.cache_path}\" '
# gen pyjt
@ -1314,7 +1315,8 @@ with jit_utils.import_scope(import_flags):
flags = core.Flags()
if has_cuda:
nvcc_flags = convert_nvcc_flags(cc_flags)
nvcc_flags = " " + os.environ.get("nvcc_flags", "") + " "
nvcc_flags += convert_nvcc_flags(cc_flags)
nvcc_version = list(jit_utils.get_int_version(nvcc_path))
max_arch = 1000
if nvcc_version < [11,]:

View File

@ -243,7 +243,7 @@ def concat(arr, dim=0):
Example::
jt.concat([jt.array([[1],[2]]), jt.array([[2],[2]])], dim=1)
# return [[1],[2],[2],[2]]
# return jt.Var([[1,2],[2,2]],dtype=int32)
'''
if not isinstance(arr, Sequence):
raise TypeError("concat arr needs to be a tuple or list")

View File

@ -683,9 +683,9 @@ def cub_cumsum(x, dim=None):
if (dim == None):
dim = -1
assert(dim >= -1 and dim < len(x.shape))
shape = x.shape
shape = list(x.shape)
if (dim != -1 and dim != len(shape) - 1):
order = range(len(shape))
order = list(range(len(shape)))
order[dim], order[-1] = order[-1], order[dim]
shape[dim], shape[-1] = shape[-1], shape[dim]
x = x.permute(order)
@ -712,7 +712,7 @@ def cumsum(x, dim=None):
if (dim == None):
dim = -1
assert(dim >= -1 and dim < len(x.shape))
if jt.has_cuda:
if jt.flags.use_cuda:
return cub_cumsum(x, dim)
else:
return numpy_cumsum(x, dim)
@ -1656,4 +1656,4 @@ class CTCLoss(jt.Module):
self.zero_infinity = zero_infinity
def execute(self, log_probs, targets, input_lengths, target_lengths):
return ctc_loss(log_probs, targets, input_lengths, target_lengths, self.blank, self.reduction, self.zero_infinity)
return ctc_loss(log_probs, targets, input_lengths, target_lengths, self.blank, self.reduction, self.zero_infinity)

View File

@ -2101,7 +2101,7 @@ ParameterDict = ParameterList
def Parameter(data, requires_grad=True):
''' The `Parameter` interface isn't needed in Jittor, this interface
doesn't nothings and it is just used for compatible.
does nothings and it is just used for compatible.
A Jittor Var is a Parameter
when it is a member of Module, if you don't want a Jittor

View File

@ -460,7 +460,7 @@
" train_loss_list.append(train_loss)\n",
" # 在验证集上进行验证,模型参数不做更新。\n",
" val_loss = val(model, x_val_var, y_val_var, loss_function)\n",
" val_loss_list.append(val_loss)\n",
" val_loss_list.append(val_loss.item())\n",
" \n",
"# 打印训练结束后的模型参数\n",
"print(\"After training: \\n\", model.state_dict())"
@ -598,4 +598,4 @@
},
"nbformat": 4,
"nbformat_minor": 4
}
}

View File

@ -192,7 +192,7 @@
" \n",
" plt.imshow(inputs[num].numpy().transpose(1, 2, 0)) # 利用 matplotlib 根据第一个 input 绘制手写数字的图像\n",
" plt.show() # 展示图像\n",
" print(\"target:\", targets[num].data[0]) # 打印第一个 input 数据的真实标签值,即手写数字图像所表达的真实数字\n",
" print(\"target:\", targets[num].numpy()[0]) # 打印第一个 input 数据的真实标签值,即手写数字图像所表达的真实数字\n",
" break"
]
},
@ -910,12 +910,12 @@
" outputs = model(inputs) # 通过模型预测手写数字。outputs 中每个数据输出有 10 个分量,对应十个数字的相似度\n",
" loss = loss_function(outputs, targets) # 计算损失函数\n",
" optimizer.step(loss) # 根据损失函数,对模型参数进行优化、更新\n",
" train_losses.append(loss) # 记录该批次的 Loss\n",
" train_losses.append(loss.item()) # 记录该批次的 Loss\n",
" \n",
" if batch_idx % 10 == 0: # 每十个批次,打印一次训练集上的 Loss \n",
" print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n",
" epoch, batch_idx, len(train_loader),\n",
" 100. * batch_idx / len(train_loader), loss.data[0]))\n",
" 100. * batch_idx / len(train_loader), loss.item()))\n",
" return train_losses # 返回本纪元的 Loss\n",
"\n",
"\n",
@ -926,8 +926,8 @@
" total_num = 0 # 本纪元数据总数\n",
" for batch_idx, (inputs, targets) in enumerate(val_loader): # 通过测试集加载器,按批次迭代数据\n",
" outputs = model(inputs) # 通过模型预测手写数字。outputs 中每个数据输出有 10 个分量,对应十个数字的相似度\n",
" pred = np.argmax(outputs.data, axis=1) # 根据 10 个分量,选择最大相似度的为预测的数字值\n",
" correct = np.sum(targets.data==pred) # 计算本批次中,正确预测的次数,即数据标签等于预测值的数目\n",
" pred = np.argmax(outputs.numpy(), axis=1) # 根据 10 个分量,选择最大相似度的为预测的数字值\n",
" correct = np.sum(targets.numpy()==pred) # 计算本批次中,正确预测的次数,即数据标签等于预测值的数目\n",
" batch_size = inputs.shape[0] # 计算本批次中,数据的总数目\n",
" acc = correct / batch_size # 计算本批次的正确率\n",
" \n",
@ -1075,10 +1075,10 @@
" plt.imshow(inputs[num].numpy().transpose(1, 2, 0)) # 绘制该数据的手写数字图像\n",
" plt.show() \n",
" \n",
" print(\"target:\", targets[num].data[0]) # 打印该数据的真实标签值\n",
" print(\"target:\", targets[num].numpy()[0]) # 打印该数据的真实标签值\n",
" \n",
" outputs = model(inputs) # 模型根据输入数据进行预测\n",
" pred = np.argmax(outputs.data, axis=1) # 根据最大相似度得到预测值\n",
" pred = np.argmax(outputs.numpy(), axis=1) # 根据最大相似度得到预测值\n",
" print(\"prediction:\", pred[num]) # 打印该数据的预测值\n",
" break"
]
@ -1158,4 +1158,4 @@
},
"nbformat": 4,
"nbformat_minor": 4
}
}

View File

@ -1,7 +1,9 @@
from .md_to_ipynb import dirname, notebook_dir
import os
import sys
import shutil
from distutils.dir_util import copy_tree
cmd = f"cp -r {dirname}/* {notebook_dir}/ && cd {notebook_dir} && jupyter notebook {' '.join(sys.argv[1:])}"
print("run cmd:", cmd)
os.system(cmd)
copy_tree(dirname, notebook_dir)
os.chdir(notebook_dir)
os.system(f"{sys.executable} -m jupyter notebook")

View File

@ -11,7 +11,7 @@ for r, _, f in os.walk(dirname):
if not fname.endswith(".md"): continue
all_md.append(os.path.join(r, fname))
for mdname in all_md:
with open(os.path.join(dirname, mdname), "r") as f:
with open(os.path.join(dirname, mdname), "r", encoding="utf-8") as f:
src = f.read()
blocks = []
for i, b in enumerate(src.split("```")):
@ -63,5 +63,5 @@ for mdname in all_md:
ipynb_name = os.path.basename(mdname[:-2])+"ipynb"
ipynb_name = os.path.join(notebook_dir, ipynb_name)
print(mdname, len(src), len(blocks), len(cells), "--->", ipynb_name)
with open(ipynb_name, "w") as f:
with open(ipynb_name, "w", encoding='utf8') as f:
f.write(json.dumps(ipynb))

View File

@ -858,13 +858,13 @@ def compile_src(src, h, basename):
def compile_single(head_file_name, src_file_name, src=None):
basename = os.path.basename(head_file_name).split(".")[0]
if src==None:
with open(head_file_name, 'r') as f:
with open(head_file_name, 'r', encoding='utf8') as f:
src = f.read()
code = compile_src(src, head_file_name, basename)
if not code: return False
LOG.vvv("write to", src_file_name)
LOG.vvvv(code)
with open(src_file_name, 'w') as f:
with open(src_file_name, 'w', encoding='utf8') as f:
f.write(code)
return True
@ -875,14 +875,14 @@ def compile(cache_path, jittor_path):
basenames = []
pyjt_names = []
for h in headers:
with open(h, 'r') as f:
with open(h, 'r', encoding='utf8') as f:
src = f.read()
bh = os.path.basename(h)
# jit_op_maker.h merge compile with var_holder.h
if bh == "var_holder.h": continue
if bh == "jit_op_maker.h":
with open(os.path.join(jittor_path, "src", "var_holder.h"), "r") as f:
with open(os.path.join(jittor_path, "src", "var_holder.h"), "r", encoding='utf8') as f:
src = f.read() + src
basename = bh.split(".")[0]
fname = "pyjt_"+basename+".cc"
@ -913,7 +913,7 @@ def compile(cache_path, jittor_path):
fname = os.path.join(cache_path, "gen", "pyjt_all.cc")
LOG.vvv(("write to", fname))
LOG.vvvv(code)
with open(fname, "w") as f:
with open(fname, "w", encoding='utf8') as f:
f.write(code)
pyjt_names.append(fname)
return pyjt_names

View File

@ -1,7 +1,7 @@
import os
def fix_config(in_name, out_name, src_path, out_path):
data = open(in_name, 'r').readlines()
data = open(in_name, 'r', encoding='utf8').readlines()
out = []
for d in data:
if d.startswith('INPUT ='):
@ -9,7 +9,7 @@ def fix_config(in_name, out_name, src_path, out_path):
elif d.startswith('OUTPUT_DIRECTORY ='):
d = f'OUTPUT_DIRECTORY ={out_path}\n'
out.append(d)
f = open(out_name, 'w')
f = open(out_name, 'w', encoding='utf8')
f.writelines(out)
jt_path = os.getcwd()

View File

@ -153,7 +153,7 @@ jit_op_entry_t load_jit_lib(string name, string symbol_name="jit_entry") {
const char* msg = "";
LOGvv << "Opening jit lib:" << name;
#ifdef _WIN32
void* handle = (void*)LoadLibraryExA(name.c_str(), nullptr,
void* handle = (void*)LoadLibraryExA(_to_winstr(name).c_str(), nullptr,
LOAD_LIBRARY_SEARCH_DEFAULT_DIRS |
LOAD_LIBRARY_SEARCH_USER_DIRS);
#elif defined(__linux__)
@ -206,13 +206,15 @@ jit_op_entry_t compile(const string& jit_key, const string& src, const bool is_c
string jit_src_path = Op::get_filename_from_jit_key(jit_key, ".cc");
#ifdef _WIN32
string jit_lib_path = Op::get_filename_from_jit_key(jit_key, ".dll");
string jit_src_path2 = _to_winstr(jit_src_path);
#else
string jit_lib_path = Op::get_filename_from_jit_key(jit_key, ".so");
string& jit_src_path2 = jit_src_path;
#endif
string other_src;
LOGvvv << "Generate" << jit_src_path >> "\n" >> src;
if (rewrite_op || !file_exist(jit_src_path))
write(jit_src_path, src);
if (rewrite_op || !file_exist(jit_src_path2))
write(jit_src_path2, src);
string cmd;
auto symbol_name = get_symbol_name(jit_key);

View File

@ -31,7 +31,7 @@ int _has_lock = 0;
DEFINE_FLAG(bool, disable_lock, 0, "Disable file lock");
void set_lock_path(string path) {
lock_fd = open(path.c_str(), O_RDWR);
lock_fd = open(_to_winstr(path).c_str(), O_RDWR);
ASSERT(lock_fd >= 0);
LOGv << "OPEN LOCK path:" << path << "Pid:" << getpid();
}

View File

@ -313,7 +313,7 @@ string precompile(unordered_map<string,string> defs, string src, unordered_map<s
LOGvvvv << "Found defs include" << inc;
auto src_path = join(jittor_path, "src");
src_path = join(src_path, inc);
auto inc_src = read_all(src_path);
auto inc_src = read_all(_to_winstr(src_path));
// load_macros from include src
precompile(defs, inc_src, macros);
// we do not include defs.h
@ -736,9 +736,9 @@ string OpCompiler::get_jit_src(Op* op) {
else
after_include_src += src;
}
ASSERT(file_exist(src_path)) << src_path;
ASSERT(file_exist(_to_winstr(src_path))) << src_path;
LOGvvv << "Read from" << src_path;
string src = read_all(src_path);
string src = read_all(_to_winstr(src_path));
ASSERT(src.size()) << "Source read failed:" << src_path;
unordered_map<string,string> defs(jit_define.begin(), jit_define.end());
@ -1038,7 +1038,14 @@ jit_op_entry_t OpCompiler::compile(const string& jit_key, const string& src) {
// add extra flags for custom ops
bool is_cuda = _op->flags.get(NodeFlags::_cuda);
auto op_info = get_op_info(_op->name());
return jit_compiler::compile(jit_key, src, is_cuda, op_info.extra_flags);
string extra_flags = op_info.extra_flags;
for (auto v : _op->outputs())
if (v->loop_options)
for (auto& kv : v->loop_options.data()) {
if (kv.second && startswith(kv.first, "FLAGS:"))
extra_flags += " "+kv.first.substr(6)+" ";
}
return jit_compiler::compile(jit_key, src, is_cuda, extra_flags);
}
jit_op_entry_t OpCompiler::do_compile(Op* op) {

View File

@ -137,6 +137,23 @@ struct CodeOp : Op {
assert (b.data == [5,3,1]).all()
assert (c.data == [-4,-2]).all()
Example-5::
# This example shows how to customize code op
# compilation flags, such as add include search
# path, add definitions, or any command line options
a = jt.random([10])
b = jt.code(a.shape, a.dtype, [a],
cpu_src='''
@out0(0) = HAHAHA;
''')
# HAHAHA is defined in flags below
# /any/include/path can be change to any path you want to include
b.compile_options = {"FLAGS: -DHAHAHA=233 -I/any/include/path ": 1}
print(b[0])
# will output 233
CUDA Example-1::

View File

@ -33,7 +33,7 @@ typedef struct {
int pagemap_get_entry(PagemapEntry* entry, int pagemap_fd, uintptr_t vaddr)
{
size_t nread;
ssize_t ret;
int64_t ret;
uint64_t data;
uintptr_t vpn;

View File

@ -31,8 +31,8 @@ struct PyArray_Proxy {
PyObject_HEAD
char* data;
int nd;
ssize_t* dimensions;
ssize_t* strides;
Py_ssize_t* dimensions;
Py_ssize_t* strides;
PyObject *base;
PyArrayDescr_Proxy *descr;
int flags;

View File

@ -234,7 +234,15 @@ static inline bool is_full_path(const string& name) {
#endif
}
bool cache_compile(string cmd, const string& cache_path, const string& jittor_path) {
bool cache_compile(string cmd, const string& cache_path_, const string& jittor_path_) {
#ifdef _WIN32
cmd = _to_winstr(cmd);
string cache_path = _to_winstr(cache_path_);
string jittor_path = _to_winstr(jittor_path_);
#else
const string& cache_path = cache_path_;
const string& jittor_path = jittor_path_;
#endif
vector<string> input_names;
map<string,vector<string>> extra;
string output_name;
@ -255,6 +263,9 @@ bool cache_compile(string cmd, const string& cache_path, const string& jittor_pa
continue;
processed.insert(input_names[i]);
auto src = read_all(input_names[i]);
#ifdef _WIN32
src = _to_winstr(src);
#endif
auto back = input_names[i].back();
// *.lib
if (back == 'b') continue;

View File

@ -174,7 +174,7 @@ void send_log(std::ostringstream&& out, char level, int verbose) {
} else {
std::lock_guard<std::mutex> lk(sync_log_m);
// std::cerr << "[SYNC]";
std::cerr << out.str();
std::cerr << _to_winstr(out.str());
std::cerr.flush();
}
}
@ -304,7 +304,9 @@ int register_sigaction() {
sigaction(SIGKILL, &sa, NULL);
sigaction(SIGSTOP, &sa, NULL);
sigaction(SIGFPE, &sa, NULL);
sigaction(SIGINT, &sa, NULL);
// jupyter use sigint to interp
if (getenv("JPY_PARENT_PID") == nullptr)
sigaction(SIGINT, &sa, NULL);
sigaction(SIGCHLD, &sa, NULL);
sigaction(SIGILL, &sa, NULL);
sigaction(SIGBUS, &sa, NULL);
@ -315,6 +317,10 @@ int register_sigaction() {
}
static int log_init() {
#ifdef _WIN32
// SetConsoleCP(CP_UTF8);
// SetConsoleOutputCP(CP_UTF8);
#endif
register_sigaction();
std::atexit(log_exiting);
return 1;
@ -444,6 +450,39 @@ If you still have problems, please contact us:
}
#ifdef _WIN32
string GbkToUtf8(const char *src_str)
{
int len = MultiByteToWideChar(CP_ACP, 0, src_str, -1, NULL, 0);
wchar_t* wstr = new wchar_t[len + 1];
memset(wstr, 0, len + 1);
MultiByteToWideChar(CP_ACP, 0, src_str, -1, wstr, len);
len = WideCharToMultiByte(CP_UTF8, 0, wstr, -1, NULL, 0, NULL, NULL);
char* str = new char[len + 1];
memset(str, 0, len + 1);
WideCharToMultiByte(CP_UTF8, 0, wstr, -1, str, len, NULL, NULL);
string strTemp = str;
if (wstr) delete[] wstr;
if (str) delete[] str;
return strTemp;
}
string Utf8ToGbk(const char *src_str)
{
int len = MultiByteToWideChar(CP_UTF8, 0, src_str, -1, NULL, 0);
wchar_t* wszGBK = new wchar_t[len + 1];
memset(wszGBK, 0, len * 2 + 2);
MultiByteToWideChar(CP_UTF8, 0, src_str, -1, wszGBK, len);
len = WideCharToMultiByte(CP_ACP, 0, wszGBK, -1, NULL, 0, NULL, NULL);
char* szGBK = new char[len + 1];
memset(szGBK, 0, len + 1);
WideCharToMultiByte(CP_ACP, 0, wszGBK, -1, szGBK, len, NULL, NULL);
string strTemp(szGBK);
if (wszGBK) delete[] wszGBK;
if (szGBK) delete[] szGBK;
return strTemp;
}
int system_popen(const char *cmd, const char* cwd) {
HANDLE g_hChildStd_OUT_Rd = NULL;
HANDLE g_hChildStd_OUT_Wr = NULL;

View File

@ -16,6 +16,15 @@ namespace jittor {
// define in tracer.cc
void print_trace();
void breakpoint();
#ifdef _WIN32
string GbkToUtf8(const char *src_str);
string Utf8ToGbk(const char *src_str);
#define _to_winstr(x) Utf8ToGbk(x.c_str())
#define _from_winstr(x) GbkToUtf8(x.c_str())
#else
#define _to_winstr(x) (x)
#define _from_winstr(x) (x)
#endif
constexpr int32_t basename_index(const char * const path, const int32_t index = 0, const int32_t slash_index = -1) {
return path[index]

View File

@ -181,12 +181,24 @@ struct VarHolder {
inline void set_requires_grad(bool flag) {
if (flag == get_requires_grad()) return;
if (flag)
_update(this);
start_grad();
else
stop_grad();
return;
}
/**
* enable the gradient calculation for the Var.
*/
// @pyjt(start_grad)
// @attrs(return_self)
inline VarHolder* start_grad() {
if (!var->dtype().is_float())
LOGw << "cannot enable grad of a non-float value:" << var;
_update(this);
return this;
}
// @pyjt(__get__uncertain_shape)
inline NanoVector uncertain_shape() {
return var->shape;

View File

@ -29,6 +29,19 @@ class TestCodeOp(unittest.TestCase):
da = jt.grad(c*b, a)
assert np.allclose(c.data*na*4, da.data), (c.data*na*4, da.data)
def test_exflags(self):
a = jt.random([10])
b = jt.code(a.shape, a.dtype, [a],
cpu_src='''
LOGir << HAHAHA;
@out0(0) = HAHAHA;
''')
b.compile_options = {"FLAGS: -DHAHAHA=233 -I/any/include/path ": 1}
# print(b[0])
assert b[0].item() == 233
def test_use_func(self):
class Func(Function):
def execute(self, x):

View File

@ -123,8 +123,9 @@ class TestCudnnConvOp(unittest.TestCase):
logs = find_log_with_re(raw_log, "(Jit op key (not )?found: cudnn_conv.*)")
assert len(logs)==3 and "oihw" in logs[0][0], logs
assert np.allclose(y.data, cy.data)
assert np.allclose(dx.data, cdx.data, 1e-2)
assert np.allclose(dw.data, cdw.data, 1e-2)
np.testing.assert_allclose(dx.data, cdx.data, atol=1e-2)
np.testing.assert_allclose(dw.data, cdw.data, atol=1e-2)
if os.name == 'nt': return
check([10,3,100,100], [5,3,3,3], stride=2, padding=0, dilation=1)
check([10,4,40,50], [5,4,5,5], stride=1, padding=1, dilation=1)
check([10,4,40,50], [5,4,4,4], stride=3, padding=1, dilation=1)
@ -142,13 +143,15 @@ class TestCudnnConvOp(unittest.TestCase):
y2 = jt.nn.conv3d(x, w, None, stride, padding, dilation, group)
dx2, dw2 = jt.grad(masky*y2, [x, w])
np.testing.assert_allclose(y.data, y2.data)
np.testing.assert_allclose(y.data, y2.data, rtol=1e-5, atol=1e-3)
np.testing.assert_allclose(dx.data, dx2.data, rtol=1e-5, atol=1e-3)
np.testing.assert_allclose(dw.data, dw2.data, rtol=1e-5, atol=1e-3)
check((2,4,10,10,10), (5,4,3,3,3), (1,1,1), (1,1,1))
check((2,4,10,10,10), (5,4,3,3,3), (2,2,2), (1,1,1))
check((2,4,10,10,10), (5,4,3,3,3), (2,2,2), (0,0,0))
# TODO: check why windows failed in this test
if os.name == "nt": return
check((2,4,10,10,10), (5,4,3,3,3), (1,2,3), (0,0,0))
check((2,4,10,10,10), (5,4,3,4,5), (1,1,1), (1,1,1))
check((2,4,10,10,10), (5,4,3,4,5), (1,2,3), (0,0,0))
@ -181,6 +184,7 @@ class TestCudnnConvOp(unittest.TestCase):
check((2,5,10,10,10), (5,4,3,3,3), (1,1,1), (1,1,1))
check((2,5,10,10,10), (5,4,3,3,3), (2,2,2), (1,1,1))
check((2,5,10,10,10), (5,4,3,3,3), (2,2,2), (0,0,0))
if os.name == 'nt': return
check((2,5,10,10,10), (5,4,3,3,3), (1,2,3), (0,0,0))
check((2,5,10,10,10), (5,4,3,4,5), (1,1,1), (1,1,1))
check((2,5,10,10,10), (5,4,3,4,5), (1,2,3), (0,0,0))

View File

@ -142,6 +142,11 @@ class TestGrad(unittest.TestCase):
expect_error(lambda: jt.grad(z, [y1]))
dx, = jt.grad(z, [x])
self.assertEqual(dx.data, 48)
def test_int_enable_grad(self):
a = jt.int([1,2,3])
a.requires_grad = True
a.start_grad()
def test_nth_grad(self):
x = jt.array(2.0)

View File

@ -37,5 +37,5 @@ for k in syms:
src += f" {k}\n"
# print(src)
with open(def_path, "w") as f:
with open(def_path, "w", encoding="utf8") as f:
f.write(src)

View File

@ -120,7 +120,11 @@ def try_import_jit_utils_core(silent=None):
if is_in_ipynb: os.environ["log_sync"] = "1"
import jit_utils_core as cc
if is_in_ipynb:
cc.ostream_redirect(True, True)
if os.name != 'nt':
# windows jupyter has import error
# disable ostream redirect
# TODO: find a better way
cc.ostream_redirect(True, True)
except Exception as _:
if int(os.environ.get("log_v", "0")) > 0:
print(_)

View File

@ -23,7 +23,7 @@ def check_is_both(src):
for mdname in all_src_md:
print(mdname)
with open(mdname, "r") as f:
with open(mdname, "r", encoding='utf8') as f:
src = f.read()
src = src.split("```")
en_src = []
@ -47,8 +47,8 @@ for mdname in all_src_md:
cn_src.append("\n".join(cn_s))
en_src = "```".join(en_src)
cn_src = "```".join(cn_src)
with open(mdname.replace(".src.md", ".md"), 'w') as f:
with open(mdname.replace(".src.md", ".md"), 'w', encoding='utf8') as f:
f.write(en_src)
with open(mdname.replace(".src.md", ".cn.md"), 'w') as f:
with open(mdname.replace(".src.md", ".cn.md"), 'w', encoding='utf8') as f:
f.write(cn_src)