mirror of https://github.com/Jittor/Jittor
fix windows encoding bugs
This commit is contained in:
parent
53b377ee7d
commit
495d78ad20
|
@ -9,7 +9,7 @@
|
||||||
# file 'LICENSE.txt', which is part of this source code package.
|
# file 'LICENSE.txt', which is part of this source code package.
|
||||||
# ***************************************************************
|
# ***************************************************************
|
||||||
|
|
||||||
__version__ = '1.3.1.44'
|
__version__ = '1.3.1.45'
|
||||||
from jittor_utils import lock
|
from jittor_utils import lock
|
||||||
with lock.lock_scope():
|
with lock.lock_scope():
|
||||||
ori_int = int
|
ori_int = int
|
||||||
|
@ -410,9 +410,7 @@ def flatten(input, start_dim=0, end_dim=-1):
|
||||||
return input.reshape(out_shape)
|
return input.reshape(out_shape)
|
||||||
Var.flatten = flatten
|
Var.flatten = flatten
|
||||||
|
|
||||||
def start_grad(x):
|
Var.detach_inplace = Var.start_grad
|
||||||
return x._update(x)
|
|
||||||
Var.detach_inplace = Var.start_grad = start_grad
|
|
||||||
|
|
||||||
def detach(x):
|
def detach(x):
|
||||||
return x.detach()
|
return x.detach()
|
||||||
|
|
|
@ -190,6 +190,11 @@ def setup_cub():
|
||||||
|
|
||||||
def setup_cuda_extern():
|
def setup_cuda_extern():
|
||||||
if not has_cuda: return
|
if not has_cuda: return
|
||||||
|
check_ld_path = os.environ.get("LD_LIBRARY_PATH", "")
|
||||||
|
if "cuda" in check_ld_path.lower() and "lib" in check_ld_path.lower():
|
||||||
|
LOG.w(f"CUDA related path found in LD_LIBRARY_PATH({check_ld_path}), "
|
||||||
|
"This path may cause jittor found the wrong libs, "
|
||||||
|
"please unset LD_LIBRARY_PATH. ")
|
||||||
LOG.vv("setup cuda extern...")
|
LOG.vv("setup cuda extern...")
|
||||||
cache_path_cuda = os.path.join(cache_path, "cuda")
|
cache_path_cuda = os.path.join(cache_path, "cuda")
|
||||||
cuda_include = os.path.join(jittor_path, "extern", "cuda", "inc")
|
cuda_include = os.path.join(jittor_path, "extern", "cuda", "inc")
|
||||||
|
|
|
@ -199,7 +199,7 @@ def gen_jit_tests():
|
||||||
}} // jittor
|
}} // jittor
|
||||||
"""
|
"""
|
||||||
LOG.vvvv(jit_src)
|
LOG.vvvv(jit_src)
|
||||||
with open(os.path.join(cache_path, "gen", "jit_tests.h"), 'w') as f:
|
with open(os.path.join(cache_path, "gen", "jit_tests.h"), 'w', encoding='utf8') as f:
|
||||||
f.write(jit_src)
|
f.write(jit_src)
|
||||||
|
|
||||||
def gen_jit_flags():
|
def gen_jit_flags():
|
||||||
|
@ -257,7 +257,7 @@ def gen_jit_flags():
|
||||||
}} // jittor
|
}} // jittor
|
||||||
"""
|
"""
|
||||||
LOG.vvvv(jit_src)
|
LOG.vvvv(jit_src)
|
||||||
with open(os.path.join(cache_path, "gen", "jit_flags.h"), 'w') as f:
|
with open(os.path.join(cache_path, "gen", "jit_flags.h"), 'w', encoding='utf8') as f:
|
||||||
f.write(jit_src)
|
f.write(jit_src)
|
||||||
|
|
||||||
def gen_jit_op_maker(op_headers, export=False, extra_flags=""):
|
def gen_jit_op_maker(op_headers, export=False, extra_flags=""):
|
||||||
|
@ -639,9 +639,9 @@ def compile_custom_op(header, source, op_name, warp=True):
|
||||||
make_cache_dir(cops_dir)
|
make_cache_dir(cops_dir)
|
||||||
hname = os.path.join(cops_dir, op_name+"_op.h")
|
hname = os.path.join(cops_dir, op_name+"_op.h")
|
||||||
ccname = os.path.join(cops_dir, op_name+"_op.cc")
|
ccname = os.path.join(cops_dir, op_name+"_op.cc")
|
||||||
with open(hname, 'w') as f:
|
with open(hname, 'w', encoding='utf8') as f:
|
||||||
f.write(header)
|
f.write(header)
|
||||||
with open(ccname, 'w') as f:
|
with open(ccname, 'w', encoding='utf8') as f:
|
||||||
f.write(source)
|
f.write(source)
|
||||||
m = compile_custom_ops([hname, ccname])
|
m = compile_custom_ops([hname, ccname])
|
||||||
return getattr(m, op_name)
|
return getattr(m, op_name)
|
||||||
|
@ -679,7 +679,7 @@ def compile_custom_ops(
|
||||||
dirname = os.path.dirname(name)
|
dirname = os.path.dirname(name)
|
||||||
if dirname.endswith("inc"):
|
if dirname.endswith("inc"):
|
||||||
includes.append(dirname)
|
includes.append(dirname)
|
||||||
with open(name, "r") as f:
|
with open(name, "r", encoding='utf8') as f:
|
||||||
if "@pyjt" in f.read():
|
if "@pyjt" in f.read():
|
||||||
pyjt_includes.append(name)
|
pyjt_includes.append(name)
|
||||||
bname = os.path.basename(name)
|
bname = os.path.basename(name)
|
||||||
|
@ -736,7 +736,7 @@ def compile_custom_ops(
|
||||||
"init_module(PyModuleDef* mdef, PyObject* m) {",
|
"init_module(PyModuleDef* mdef, PyObject* m) {",
|
||||||
f"jittor::pyjt_def_{bname}(m);")
|
f"jittor::pyjt_def_{bname}(m);")
|
||||||
|
|
||||||
with open(gen_head_fname, "w") as f:
|
with open(gen_head_fname, "w", encoding='utf8') as f:
|
||||||
f.write(gen_src)
|
f.write(gen_src)
|
||||||
|
|
||||||
LOG.vvv(f"Build custum ops lib:{gen_lib}")
|
LOG.vvv(f"Build custum ops lib:{gen_lib}")
|
||||||
|
@ -781,7 +781,7 @@ def compile_extern():
|
||||||
files = os.listdir(jittor_path_llvm)
|
files = os.listdir(jittor_path_llvm)
|
||||||
# test_pass.cc is used for test link problem of llvm pass plugin
|
# test_pass.cc is used for test link problem of llvm pass plugin
|
||||||
test_pass_path = os.path.join(cache_path_llvm, "test_pass.cc")
|
test_pass_path = os.path.join(cache_path_llvm, "test_pass.cc")
|
||||||
with open(test_pass_path, 'w') as f:
|
with open(test_pass_path, 'w', encoding='utf8') as f:
|
||||||
f.write("int main() {return 0;}")
|
f.write("int main() {return 0;}")
|
||||||
|
|
||||||
# -fno-rtti fix link error
|
# -fno-rtti fix link error
|
||||||
|
@ -1082,7 +1082,7 @@ if os.name == 'nt':
|
||||||
cc_flags = cc_flags.replace("-lstdc++", "")
|
cc_flags = cc_flags.replace("-lstdc++", "")
|
||||||
cc_flags = cc_flags.replace("-ldl", "")
|
cc_flags = cc_flags.replace("-ldl", "")
|
||||||
cc_flags += f" -L\"{py3_link_path}\" -lpython3{sys.version_info.minor} "
|
cc_flags += f" -L\"{py3_link_path}\" -lpython3{sys.version_info.minor} "
|
||||||
cc_flags += " -EHa -MD "
|
cc_flags += " -EHa -MD -utf-8 "
|
||||||
import jittor_utils
|
import jittor_utils
|
||||||
if jittor_utils.msvc_path:
|
if jittor_utils.msvc_path:
|
||||||
mp = jittor_utils.msvc_path
|
mp = jittor_utils.msvc_path
|
||||||
|
@ -1217,7 +1217,7 @@ gen_jit_tests()
|
||||||
op_headers = glob.glob(jittor_path+"/src/ops/**/*op.h", recursive=True)
|
op_headers = glob.glob(jittor_path+"/src/ops/**/*op.h", recursive=True)
|
||||||
jit_src = gen_jit_op_maker(op_headers)
|
jit_src = gen_jit_op_maker(op_headers)
|
||||||
LOG.vvvv(jit_src)
|
LOG.vvvv(jit_src)
|
||||||
with open(os.path.join(cache_path, "gen", "jit_op_maker.h"), 'w') as f:
|
with open(os.path.join(cache_path, "gen", "jit_op_maker.h"), 'w', encoding='utf8') as f:
|
||||||
f.write(jit_src)
|
f.write(jit_src)
|
||||||
cc_flags += f' -I\"{cache_path}\" -L\"{cache_path}\" -L\"{jit_utils.cache_path}\" '
|
cc_flags += f' -I\"{cache_path}\" -L\"{cache_path}\" -L\"{jit_utils.cache_path}\" '
|
||||||
# gen pyjt
|
# gen pyjt
|
||||||
|
|
|
@ -460,7 +460,7 @@
|
||||||
" train_loss_list.append(train_loss)\n",
|
" train_loss_list.append(train_loss)\n",
|
||||||
" # 在验证集上进行验证,模型参数不做更新。\n",
|
" # 在验证集上进行验证,模型参数不做更新。\n",
|
||||||
" val_loss = val(model, x_val_var, y_val_var, loss_function)\n",
|
" val_loss = val(model, x_val_var, y_val_var, loss_function)\n",
|
||||||
" val_loss_list.append(val_loss)\n",
|
" val_loss_list.append(val_loss.item())\n",
|
||||||
" \n",
|
" \n",
|
||||||
"# 打印训练结束后的模型参数\n",
|
"# 打印训练结束后的模型参数\n",
|
||||||
"print(\"After training: \\n\", model.state_dict())"
|
"print(\"After training: \\n\", model.state_dict())"
|
||||||
|
@ -598,4 +598,4 @@
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
"nbformat_minor": 4
|
"nbformat_minor": 4
|
||||||
}
|
}
|
|
@ -192,7 +192,7 @@
|
||||||
" \n",
|
" \n",
|
||||||
" plt.imshow(inputs[num].numpy().transpose(1, 2, 0)) # 利用 matplotlib 根据第一个 input 绘制手写数字的图像\n",
|
" plt.imshow(inputs[num].numpy().transpose(1, 2, 0)) # 利用 matplotlib 根据第一个 input 绘制手写数字的图像\n",
|
||||||
" plt.show() # 展示图像\n",
|
" plt.show() # 展示图像\n",
|
||||||
" print(\"target:\", targets[num].data[0]) # 打印第一个 input 数据的真实标签值,即手写数字图像所表达的真实数字\n",
|
" print(\"target:\", targets[num].numpy()[0]) # 打印第一个 input 数据的真实标签值,即手写数字图像所表达的真实数字\n",
|
||||||
" break"
|
" break"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
@ -910,12 +910,12 @@
|
||||||
" outputs = model(inputs) # 通过模型预测手写数字。outputs 中每个数据输出有 10 个分量,对应十个数字的相似度\n",
|
" outputs = model(inputs) # 通过模型预测手写数字。outputs 中每个数据输出有 10 个分量,对应十个数字的相似度\n",
|
||||||
" loss = loss_function(outputs, targets) # 计算损失函数\n",
|
" loss = loss_function(outputs, targets) # 计算损失函数\n",
|
||||||
" optimizer.step(loss) # 根据损失函数,对模型参数进行优化、更新\n",
|
" optimizer.step(loss) # 根据损失函数,对模型参数进行优化、更新\n",
|
||||||
" train_losses.append(loss) # 记录该批次的 Loss\n",
|
" train_losses.append(loss.item()) # 记录该批次的 Loss\n",
|
||||||
" \n",
|
" \n",
|
||||||
" if batch_idx % 10 == 0: # 每十个批次,打印一次训练集上的 Loss \n",
|
" if batch_idx % 10 == 0: # 每十个批次,打印一次训练集上的 Loss \n",
|
||||||
" print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n",
|
" print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n",
|
||||||
" epoch, batch_idx, len(train_loader),\n",
|
" epoch, batch_idx, len(train_loader),\n",
|
||||||
" 100. * batch_idx / len(train_loader), loss.data[0]))\n",
|
" 100. * batch_idx / len(train_loader), loss.item()))\n",
|
||||||
" return train_losses # 返回本纪元的 Loss\n",
|
" return train_losses # 返回本纪元的 Loss\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
@ -926,8 +926,8 @@
|
||||||
" total_num = 0 # 本纪元数据总数\n",
|
" total_num = 0 # 本纪元数据总数\n",
|
||||||
" for batch_idx, (inputs, targets) in enumerate(val_loader): # 通过测试集加载器,按批次迭代数据\n",
|
" for batch_idx, (inputs, targets) in enumerate(val_loader): # 通过测试集加载器,按批次迭代数据\n",
|
||||||
" outputs = model(inputs) # 通过模型预测手写数字。outputs 中每个数据输出有 10 个分量,对应十个数字的相似度\n",
|
" outputs = model(inputs) # 通过模型预测手写数字。outputs 中每个数据输出有 10 个分量,对应十个数字的相似度\n",
|
||||||
" pred = np.argmax(outputs.data, axis=1) # 根据 10 个分量,选择最大相似度的为预测的数字值\n",
|
" pred = np.argmax(outputs.numpy(), axis=1) # 根据 10 个分量,选择最大相似度的为预测的数字值\n",
|
||||||
" correct = np.sum(targets.data==pred) # 计算本批次中,正确预测的次数,即数据标签等于预测值的数目\n",
|
" correct = np.sum(targets.numpy()==pred) # 计算本批次中,正确预测的次数,即数据标签等于预测值的数目\n",
|
||||||
" batch_size = inputs.shape[0] # 计算本批次中,数据的总数目\n",
|
" batch_size = inputs.shape[0] # 计算本批次中,数据的总数目\n",
|
||||||
" acc = correct / batch_size # 计算本批次的正确率\n",
|
" acc = correct / batch_size # 计算本批次的正确率\n",
|
||||||
" \n",
|
" \n",
|
||||||
|
@ -1075,10 +1075,10 @@
|
||||||
" plt.imshow(inputs[num].numpy().transpose(1, 2, 0)) # 绘制该数据的手写数字图像\n",
|
" plt.imshow(inputs[num].numpy().transpose(1, 2, 0)) # 绘制该数据的手写数字图像\n",
|
||||||
" plt.show() \n",
|
" plt.show() \n",
|
||||||
" \n",
|
" \n",
|
||||||
" print(\"target:\", targets[num].data[0]) # 打印该数据的真实标签值\n",
|
" print(\"target:\", targets[num].numpy()[0]) # 打印该数据的真实标签值\n",
|
||||||
" \n",
|
" \n",
|
||||||
" outputs = model(inputs) # 模型根据输入数据进行预测\n",
|
" outputs = model(inputs) # 模型根据输入数据进行预测\n",
|
||||||
" pred = np.argmax(outputs.data, axis=1) # 根据最大相似度得到预测值\n",
|
" pred = np.argmax(outputs.numpy(), axis=1) # 根据最大相似度得到预测值\n",
|
||||||
" print(\"prediction:\", pred[num]) # 打印该数据的预测值\n",
|
" print(\"prediction:\", pred[num]) # 打印该数据的预测值\n",
|
||||||
" break"
|
" break"
|
||||||
]
|
]
|
||||||
|
@ -1158,4 +1158,4 @@
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
"nbformat_minor": 4
|
"nbformat_minor": 4
|
||||||
}
|
}
|
|
@ -63,5 +63,5 @@ for mdname in all_md:
|
||||||
ipynb_name = os.path.basename(mdname[:-2])+"ipynb"
|
ipynb_name = os.path.basename(mdname[:-2])+"ipynb"
|
||||||
ipynb_name = os.path.join(notebook_dir, ipynb_name)
|
ipynb_name = os.path.join(notebook_dir, ipynb_name)
|
||||||
print(mdname, len(src), len(blocks), len(cells), "--->", ipynb_name)
|
print(mdname, len(src), len(blocks), len(cells), "--->", ipynb_name)
|
||||||
with open(ipynb_name, "w") as f:
|
with open(ipynb_name, "w", encoding='utf8') as f:
|
||||||
f.write(json.dumps(ipynb))
|
f.write(json.dumps(ipynb))
|
|
@ -858,13 +858,13 @@ def compile_src(src, h, basename):
|
||||||
def compile_single(head_file_name, src_file_name, src=None):
|
def compile_single(head_file_name, src_file_name, src=None):
|
||||||
basename = os.path.basename(head_file_name).split(".")[0]
|
basename = os.path.basename(head_file_name).split(".")[0]
|
||||||
if src==None:
|
if src==None:
|
||||||
with open(head_file_name, 'r') as f:
|
with open(head_file_name, 'r', encoding='utf8') as f:
|
||||||
src = f.read()
|
src = f.read()
|
||||||
code = compile_src(src, head_file_name, basename)
|
code = compile_src(src, head_file_name, basename)
|
||||||
if not code: return False
|
if not code: return False
|
||||||
LOG.vvv("write to", src_file_name)
|
LOG.vvv("write to", src_file_name)
|
||||||
LOG.vvvv(code)
|
LOG.vvvv(code)
|
||||||
with open(src_file_name, 'w') as f:
|
with open(src_file_name, 'w', encoding='utf8') as f:
|
||||||
f.write(code)
|
f.write(code)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@ -875,14 +875,14 @@ def compile(cache_path, jittor_path):
|
||||||
basenames = []
|
basenames = []
|
||||||
pyjt_names = []
|
pyjt_names = []
|
||||||
for h in headers:
|
for h in headers:
|
||||||
with open(h, 'r') as f:
|
with open(h, 'r', encoding='utf8') as f:
|
||||||
src = f.read()
|
src = f.read()
|
||||||
|
|
||||||
bh = os.path.basename(h)
|
bh = os.path.basename(h)
|
||||||
# jit_op_maker.h merge compile with var_holder.h
|
# jit_op_maker.h merge compile with var_holder.h
|
||||||
if bh == "var_holder.h": continue
|
if bh == "var_holder.h": continue
|
||||||
if bh == "jit_op_maker.h":
|
if bh == "jit_op_maker.h":
|
||||||
with open(os.path.join(jittor_path, "src", "var_holder.h"), "r") as f:
|
with open(os.path.join(jittor_path, "src", "var_holder.h"), "r", encoding='utf8') as f:
|
||||||
src = f.read() + src
|
src = f.read() + src
|
||||||
basename = bh.split(".")[0]
|
basename = bh.split(".")[0]
|
||||||
fname = "pyjt_"+basename+".cc"
|
fname = "pyjt_"+basename+".cc"
|
||||||
|
@ -913,7 +913,7 @@ def compile(cache_path, jittor_path):
|
||||||
fname = os.path.join(cache_path, "gen", "pyjt_all.cc")
|
fname = os.path.join(cache_path, "gen", "pyjt_all.cc")
|
||||||
LOG.vvv(("write to", fname))
|
LOG.vvv(("write to", fname))
|
||||||
LOG.vvvv(code)
|
LOG.vvvv(code)
|
||||||
with open(fname, "w") as f:
|
with open(fname, "w", encoding='utf8') as f:
|
||||||
f.write(code)
|
f.write(code)
|
||||||
pyjt_names.append(fname)
|
pyjt_names.append(fname)
|
||||||
return pyjt_names
|
return pyjt_names
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import os
|
import os
|
||||||
|
|
||||||
def fix_config(in_name, out_name, src_path, out_path):
|
def fix_config(in_name, out_name, src_path, out_path):
|
||||||
data = open(in_name, 'r').readlines()
|
data = open(in_name, 'r', encoding='utf8').readlines()
|
||||||
out = []
|
out = []
|
||||||
for d in data:
|
for d in data:
|
||||||
if d.startswith('INPUT ='):
|
if d.startswith('INPUT ='):
|
||||||
|
@ -9,7 +9,7 @@ def fix_config(in_name, out_name, src_path, out_path):
|
||||||
elif d.startswith('OUTPUT_DIRECTORY ='):
|
elif d.startswith('OUTPUT_DIRECTORY ='):
|
||||||
d = f'OUTPUT_DIRECTORY ={out_path}\n'
|
d = f'OUTPUT_DIRECTORY ={out_path}\n'
|
||||||
out.append(d)
|
out.append(d)
|
||||||
f = open(out_name, 'w')
|
f = open(out_name, 'w', encoding='utf8')
|
||||||
f.writelines(out)
|
f.writelines(out)
|
||||||
|
|
||||||
jt_path = os.getcwd()
|
jt_path = os.getcwd()
|
||||||
|
|
|
@ -153,7 +153,7 @@ jit_op_entry_t load_jit_lib(string name, string symbol_name="jit_entry") {
|
||||||
const char* msg = "";
|
const char* msg = "";
|
||||||
LOGvv << "Opening jit lib:" << name;
|
LOGvv << "Opening jit lib:" << name;
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
void* handle = (void*)LoadLibraryExA(name.c_str(), nullptr,
|
void* handle = (void*)LoadLibraryExA(Utf8ToGbk(name.c_str()).c_str(), nullptr,
|
||||||
LOAD_LIBRARY_SEARCH_DEFAULT_DIRS |
|
LOAD_LIBRARY_SEARCH_DEFAULT_DIRS |
|
||||||
LOAD_LIBRARY_SEARCH_USER_DIRS);
|
LOAD_LIBRARY_SEARCH_USER_DIRS);
|
||||||
#elif defined(__linux__)
|
#elif defined(__linux__)
|
||||||
|
@ -206,13 +206,15 @@ jit_op_entry_t compile(const string& jit_key, const string& src, const bool is_c
|
||||||
string jit_src_path = Op::get_filename_from_jit_key(jit_key, ".cc");
|
string jit_src_path = Op::get_filename_from_jit_key(jit_key, ".cc");
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
string jit_lib_path = Op::get_filename_from_jit_key(jit_key, ".dll");
|
string jit_lib_path = Op::get_filename_from_jit_key(jit_key, ".dll");
|
||||||
|
string jit_src_path2 = Utf8ToGbk(jit_src_path.c_str());
|
||||||
#else
|
#else
|
||||||
string jit_lib_path = Op::get_filename_from_jit_key(jit_key, ".so");
|
string jit_lib_path = Op::get_filename_from_jit_key(jit_key, ".so");
|
||||||
|
string& jit_src_path2 = jit_src_path;
|
||||||
#endif
|
#endif
|
||||||
string other_src;
|
string other_src;
|
||||||
LOGvvv << "Generate" << jit_src_path >> "\n" >> src;
|
LOGvvv << "Generate" << jit_src_path >> "\n" >> src;
|
||||||
if (rewrite_op || !file_exist(jit_src_path))
|
if (rewrite_op || !file_exist(jit_src_path2))
|
||||||
write(jit_src_path, src);
|
write(jit_src_path2, src);
|
||||||
string cmd;
|
string cmd;
|
||||||
|
|
||||||
auto symbol_name = get_symbol_name(jit_key);
|
auto symbol_name = get_symbol_name(jit_key);
|
||||||
|
|
|
@ -45,13 +45,7 @@ DEF_IS(string, PyObject*) to_py_object(const string& a) {
|
||||||
|
|
||||||
DEF_IS(string, string) from_py_object(PyObject* obj) {
|
DEF_IS(string, string) from_py_object(PyObject* obj) {
|
||||||
Py_ssize_t size;
|
Py_ssize_t size;
|
||||||
#ifdef _WIN32
|
|
||||||
PyObjHolder a(PyUnicode_AsEncodedString(obj, win_encode.c_str(), "strict"));
|
|
||||||
char* s;
|
|
||||||
auto ret = PyBytes_AsStringAndSize(a.obj, &s, &size);
|
|
||||||
#else
|
|
||||||
const char* s = PyUnicode_AsUTF8AndSize(obj, &size);
|
const char* s = PyUnicode_AsUTF8AndSize(obj, &size);
|
||||||
#endif
|
|
||||||
CHECK(s);
|
CHECK(s);
|
||||||
return string(s, size);
|
return string(s, size);
|
||||||
}
|
}
|
||||||
|
|
|
@ -234,7 +234,15 @@ static inline bool is_full_path(const string& name) {
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
bool cache_compile(string cmd, const string& cache_path, const string& jittor_path) {
|
bool cache_compile(string cmd, const string& cache_path_, const string& jittor_path_) {
|
||||||
|
#ifdef _WIN32
|
||||||
|
cmd = Utf8ToGbk(cmd.c_str());
|
||||||
|
string cache_path = Utf8ToGbk(cache_path_.c_str());
|
||||||
|
string jittor_path = Utf8ToGbk(jittor_path_.c_str());
|
||||||
|
#else
|
||||||
|
const string& cache_path = cache_path_;
|
||||||
|
const string& jittor_path = jittor_path_;
|
||||||
|
#endif
|
||||||
vector<string> input_names;
|
vector<string> input_names;
|
||||||
map<string,vector<string>> extra;
|
map<string,vector<string>> extra;
|
||||||
string output_name;
|
string output_name;
|
||||||
|
@ -255,6 +263,9 @@ bool cache_compile(string cmd, const string& cache_path, const string& jittor_pa
|
||||||
continue;
|
continue;
|
||||||
processed.insert(input_names[i]);
|
processed.insert(input_names[i]);
|
||||||
auto src = read_all(input_names[i]);
|
auto src = read_all(input_names[i]);
|
||||||
|
#ifdef _WIN32
|
||||||
|
src = Utf8ToGbk(src.c_str());
|
||||||
|
#endif
|
||||||
auto back = input_names[i].back();
|
auto back = input_names[i].back();
|
||||||
// *.lib
|
// *.lib
|
||||||
if (back == 'b') continue;
|
if (back == 'b') continue;
|
||||||
|
|
|
@ -316,19 +316,13 @@ int register_sigaction() {
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef _WIN32
|
|
||||||
string win_encode;
|
|
||||||
#endif
|
|
||||||
|
|
||||||
static int log_init() {
|
static int log_init() {
|
||||||
|
#ifdef _WIN32
|
||||||
|
SetConsoleCP(CP_UTF8);
|
||||||
|
SetConsoleOutputCP(CP_UTF8);
|
||||||
|
#endif
|
||||||
register_sigaction();
|
register_sigaction();
|
||||||
std::atexit(log_exiting);
|
std::atexit(log_exiting);
|
||||||
#ifdef _WIN32
|
|
||||||
if (getenv("JITTOR_ENCODE"))
|
|
||||||
win_encode = getenv("JITTOR_ENCODE");
|
|
||||||
else
|
|
||||||
win_encode = "gbk";
|
|
||||||
#endif
|
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -456,6 +450,39 @@ If you still have problems, please contact us:
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
|
|
||||||
|
string GbkToUtf8(const char *src_str)
|
||||||
|
{
|
||||||
|
int len = MultiByteToWideChar(CP_ACP, 0, src_str, -1, NULL, 0);
|
||||||
|
wchar_t* wstr = new wchar_t[len + 1];
|
||||||
|
memset(wstr, 0, len + 1);
|
||||||
|
MultiByteToWideChar(CP_ACP, 0, src_str, -1, wstr, len);
|
||||||
|
len = WideCharToMultiByte(CP_UTF8, 0, wstr, -1, NULL, 0, NULL, NULL);
|
||||||
|
char* str = new char[len + 1];
|
||||||
|
memset(str, 0, len + 1);
|
||||||
|
WideCharToMultiByte(CP_UTF8, 0, wstr, -1, str, len, NULL, NULL);
|
||||||
|
string strTemp = str;
|
||||||
|
if (wstr) delete[] wstr;
|
||||||
|
if (str) delete[] str;
|
||||||
|
return strTemp;
|
||||||
|
}
|
||||||
|
|
||||||
|
string Utf8ToGbk(const char *src_str)
|
||||||
|
{
|
||||||
|
int len = MultiByteToWideChar(CP_UTF8, 0, src_str, -1, NULL, 0);
|
||||||
|
wchar_t* wszGBK = new wchar_t[len + 1];
|
||||||
|
memset(wszGBK, 0, len * 2 + 2);
|
||||||
|
MultiByteToWideChar(CP_UTF8, 0, src_str, -1, wszGBK, len);
|
||||||
|
len = WideCharToMultiByte(CP_ACP, 0, wszGBK, -1, NULL, 0, NULL, NULL);
|
||||||
|
char* szGBK = new char[len + 1];
|
||||||
|
memset(szGBK, 0, len + 1);
|
||||||
|
WideCharToMultiByte(CP_ACP, 0, wszGBK, -1, szGBK, len, NULL, NULL);
|
||||||
|
string strTemp(szGBK);
|
||||||
|
if (wszGBK) delete[] wszGBK;
|
||||||
|
if (szGBK) delete[] szGBK;
|
||||||
|
return strTemp;
|
||||||
|
}
|
||||||
|
|
||||||
int system_popen(const char *cmd, const char* cwd) {
|
int system_popen(const char *cmd, const char* cwd) {
|
||||||
HANDLE g_hChildStd_OUT_Rd = NULL;
|
HANDLE g_hChildStd_OUT_Rd = NULL;
|
||||||
HANDLE g_hChildStd_OUT_Wr = NULL;
|
HANDLE g_hChildStd_OUT_Wr = NULL;
|
||||||
|
|
|
@ -16,6 +16,10 @@ namespace jittor {
|
||||||
// define in tracer.cc
|
// define in tracer.cc
|
||||||
void print_trace();
|
void print_trace();
|
||||||
void breakpoint();
|
void breakpoint();
|
||||||
|
#ifdef _WIN32
|
||||||
|
string GbkToUtf8(const char *src_str);
|
||||||
|
string Utf8ToGbk(const char *src_str);
|
||||||
|
#endif
|
||||||
|
|
||||||
constexpr int32_t basename_index(const char * const path, const int32_t index = 0, const int32_t slash_index = -1) {
|
constexpr int32_t basename_index(const char * const path, const int32_t index = 0, const int32_t slash_index = -1) {
|
||||||
return path[index]
|
return path[index]
|
||||||
|
@ -277,8 +281,4 @@ bool check_vlog(const char* fileline, int verbose);
|
||||||
|
|
||||||
void system_with_check(const char* cmd, const char* cwd=nullptr);
|
void system_with_check(const char* cmd, const char* cwd=nullptr);
|
||||||
|
|
||||||
#ifdef _WIN32
|
|
||||||
extern string win_encode;
|
|
||||||
#endif
|
|
||||||
|
|
||||||
} // jittor
|
} // jittor
|
|
@ -181,12 +181,24 @@ struct VarHolder {
|
||||||
inline void set_requires_grad(bool flag) {
|
inline void set_requires_grad(bool flag) {
|
||||||
if (flag == get_requires_grad()) return;
|
if (flag == get_requires_grad()) return;
|
||||||
if (flag)
|
if (flag)
|
||||||
_update(this);
|
start_grad();
|
||||||
else
|
else
|
||||||
stop_grad();
|
stop_grad();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* enable the gradient calculation for the Var.
|
||||||
|
*/
|
||||||
|
// @pyjt(start_grad)
|
||||||
|
// @attrs(return_self)
|
||||||
|
inline VarHolder* start_grad() {
|
||||||
|
if (!var->dtype().is_float())
|
||||||
|
LOGw << "cannot enable grad of a non-float value:" << var;
|
||||||
|
_update(this);
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
// @pyjt(__get__uncertain_shape)
|
// @pyjt(__get__uncertain_shape)
|
||||||
inline NanoVector uncertain_shape() {
|
inline NanoVector uncertain_shape() {
|
||||||
return var->shape;
|
return var->shape;
|
||||||
|
|
|
@ -142,6 +142,11 @@ class TestGrad(unittest.TestCase):
|
||||||
expect_error(lambda: jt.grad(z, [y1]))
|
expect_error(lambda: jt.grad(z, [y1]))
|
||||||
dx, = jt.grad(z, [x])
|
dx, = jt.grad(z, [x])
|
||||||
self.assertEqual(dx.data, 48)
|
self.assertEqual(dx.data, 48)
|
||||||
|
|
||||||
|
def test_int_enable_grad(self):
|
||||||
|
a = jt.int([1,2,3])
|
||||||
|
a.requires_grad = True
|
||||||
|
a.start_grad()
|
||||||
|
|
||||||
def test_nth_grad(self):
|
def test_nth_grad(self):
|
||||||
x = jt.array(2.0)
|
x = jt.array(2.0)
|
||||||
|
|
|
@ -23,7 +23,7 @@ def check_is_both(src):
|
||||||
|
|
||||||
for mdname in all_src_md:
|
for mdname in all_src_md:
|
||||||
print(mdname)
|
print(mdname)
|
||||||
with open(mdname, "r") as f:
|
with open(mdname, "r", encoding='utf8') as f:
|
||||||
src = f.read()
|
src = f.read()
|
||||||
src = src.split("```")
|
src = src.split("```")
|
||||||
en_src = []
|
en_src = []
|
||||||
|
@ -47,8 +47,8 @@ for mdname in all_src_md:
|
||||||
cn_src.append("\n".join(cn_s))
|
cn_src.append("\n".join(cn_s))
|
||||||
en_src = "```".join(en_src)
|
en_src = "```".join(en_src)
|
||||||
cn_src = "```".join(cn_src)
|
cn_src = "```".join(cn_src)
|
||||||
with open(mdname.replace(".src.md", ".md"), 'w') as f:
|
with open(mdname.replace(".src.md", ".md"), 'w', encoding='utf8') as f:
|
||||||
f.write(en_src)
|
f.write(en_src)
|
||||||
with open(mdname.replace(".src.md", ".cn.md"), 'w') as f:
|
with open(mdname.replace(".src.md", ".cn.md"), 'w', encoding='utf8') as f:
|
||||||
f.write(cn_src)
|
f.write(cn_src)
|
||||||
|
|
Loading…
Reference in New Issue