Fix Windows encoding bugs

Dun Liang 2022-03-09 18:23:42 +08:00
parent 53b377ee7d
commit 495d78ad20
16 changed files with 113 additions and 59 deletions

View File

@@ -9,7 +9,7 @@
 # file 'LICENSE.txt', which is part of this source code package.
 # ***************************************************************
-__version__ = '1.3.1.44'
+__version__ = '1.3.1.45'
 from jittor_utils import lock
 with lock.lock_scope():
     ori_int = int
@@ -410,9 +410,7 @@ def flatten(input, start_dim=0, end_dim=-1):
     return input.reshape(out_shape)
 Var.flatten = flatten
-def start_grad(x):
-    return x._update(x)
-Var.detach_inplace = Var.start_grad = start_grad
+Var.detach_inplace = Var.start_grad
 def detach(x):
     return x.detach()
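Note on the hunk above: besides the version bump to 1.3.1.45, `start_grad` stops being a small Python wrapper around `x._update(x)` and becomes an alias for the native `Var.start_grad` (the new `VarHolder::start_grad` added in the var_holder.h diff further down), with `Var.detach_inplace` kept as another name for it. A minimal usage sketch, assuming a build that already contains this commit; the values are made up:

import jittor as jt

x = jt.float32([1.0, 2.0, 3.0])
x.stop_grad()              # gradient tracking disabled for x
x.requires_grad = True     # now routed through the native start_grad()

y = jt.int([1, 2, 3])
y.start_grad()             # on a non-float Var this only logs a warning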

View File

@@ -190,6 +190,11 @@ def setup_cub():
 def setup_cuda_extern():
     if not has_cuda: return
+    check_ld_path = os.environ.get("LD_LIBRARY_PATH", "")
+    if "cuda" in check_ld_path.lower() and "lib" in check_ld_path.lower():
+        LOG.w(f"CUDA related path found in LD_LIBRARY_PATH({check_ld_path}), "
+            "This path may cause jittor found the wrong libs, "
+            "please unset LD_LIBRARY_PATH. ")
     LOG.vv("setup cuda extern...")
     cache_path_cuda = os.path.join(cache_path, "cuda")
     cuda_include = os.path.join(jittor_path, "extern", "cuda", "inc")
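The added check only emits a warning: if LD_LIBRARY_PATH already points at some CUDA lib directory, the libraries picked up at load or link time may not be the CUDA setup Jittor configures for itself. A hedged workaround sketch, not part of the commit: clearing the variable from os.environ before the import silences the check and keeps the path out of the compile/link commands Jittor spawns; unsetting it in the shell before starting Python, as the message asks, is the more thorough fix.

import os

# Drop a pre-set LD_LIBRARY_PATH before Jittor is imported (assumption: this
# script is the process entry point, so nothing else still needs the variable).
os.environ.pop("LD_LIBRARY_PATH", None)

import jittor as jt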

View File

@@ -199,7 +199,7 @@ def gen_jit_tests():
 }} // jittor
 """
     LOG.vvvv(jit_src)
-    with open(os.path.join(cache_path, "gen", "jit_tests.h"), 'w') as f:
+    with open(os.path.join(cache_path, "gen", "jit_tests.h"), 'w', encoding='utf8') as f:
         f.write(jit_src)
 def gen_jit_flags():
@@ -257,7 +257,7 @@ def gen_jit_flags():
 }} // jittor
 """
     LOG.vvvv(jit_src)
-    with open(os.path.join(cache_path, "gen", "jit_flags.h"), 'w') as f:
+    with open(os.path.join(cache_path, "gen", "jit_flags.h"), 'w', encoding='utf8') as f:
         f.write(jit_src)
 def gen_jit_op_maker(op_headers, export=False, extra_flags=""):
@@ -639,9 +639,9 @@ def compile_custom_op(header, source, op_name, warp=True):
     make_cache_dir(cops_dir)
     hname = os.path.join(cops_dir, op_name+"_op.h")
     ccname = os.path.join(cops_dir, op_name+"_op.cc")
-    with open(hname, 'w') as f:
+    with open(hname, 'w', encoding='utf8') as f:
         f.write(header)
-    with open(ccname, 'w') as f:
+    with open(ccname, 'w', encoding='utf8') as f:
         f.write(source)
     m = compile_custom_ops([hname, ccname])
     return getattr(m, op_name)
@@ -679,7 +679,7 @@ def compile_custom_ops(
         dirname = os.path.dirname(name)
         if dirname.endswith("inc"):
             includes.append(dirname)
-        with open(name, "r") as f:
+        with open(name, "r", encoding='utf8') as f:
             if "@pyjt" in f.read():
                 pyjt_includes.append(name)
         bname = os.path.basename(name)
@@ -736,7 +736,7 @@ def compile_custom_ops(
             "init_module(PyModuleDef* mdef, PyObject* m) {",
             f"jittor::pyjt_def_{bname}(m);")
-        with open(gen_head_fname, "w") as f:
+        with open(gen_head_fname, "w", encoding='utf8') as f:
             f.write(gen_src)
     LOG.vvv(f"Build custum ops lib:{gen_lib}")
@@ -781,7 +781,7 @@ def compile_extern():
         files = os.listdir(jittor_path_llvm)
         # test_pass.cc is used for test link problem of llvm pass plugin
         test_pass_path = os.path.join(cache_path_llvm, "test_pass.cc")
-        with open(test_pass_path, 'w') as f:
+        with open(test_pass_path, 'w', encoding='utf8') as f:
             f.write("int main() {return 0;}")
         # -fno-rtti fix link error
@@ -1082,7 +1082,7 @@ if os.name == 'nt':
     cc_flags = cc_flags.replace("-lstdc++", "")
     cc_flags = cc_flags.replace("-ldl", "")
     cc_flags += f" -L\"{py3_link_path}\" -lpython3{sys.version_info.minor} "
-    cc_flags += " -EHa -MD "
+    cc_flags += " -EHa -MD -utf-8 "
     import jittor_utils
     if jittor_utils.msvc_path:
         mp = jittor_utils.msvc_path
@@ -1217,7 +1217,7 @@ gen_jit_tests()
 op_headers = glob.glob(jittor_path+"/src/ops/**/*op.h", recursive=True)
 jit_src = gen_jit_op_maker(op_headers)
 LOG.vvvv(jit_src)
-with open(os.path.join(cache_path, "gen", "jit_op_maker.h"), 'w') as f:
+with open(os.path.join(cache_path, "gen", "jit_op_maker.h"), 'w', encoding='utf8') as f:
     f.write(jit_src)
 cc_flags += f' -I\"{cache_path}\" -L\"{cache_path}\" -L\"{jit_utils.cache_path}\" '
 # gen pyjt
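The `encoding='utf8'` pattern repeated through this file (and in the helper scripts below) is the heart of the fix: on Windows, `open()` without an explicit encoding falls back to the ANSI code page reported by `locale.getpreferredencoding()` (cp936/GBK on Chinese systems), so the UTF-8 generated headers could come out garbled or fail to write at all; the new MSVC `-utf-8` switch then makes the compiler read those generated sources as UTF-8 too. A minimal sketch of the failure mode and the fix; the file name and comment text are invented:

import locale

text = "// 自动生成的头文件 (generated header with a non-ASCII comment)\n"
print(locale.getpreferredencoding(False))      # e.g. 'cp936' on zh-CN Windows

# Without an explicit encoding, open() writes the locale code page, so a reader
# or compiler that expects UTF-8 sees garbled bytes. Pinning both ends works:
with open("demo.h", "w", encoding="utf8") as f:
    f.write(text)
with open("demo.h", "r", encoding="utf8") as f:
    assert f.read() == text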

View File

@@ -460,7 +460,7 @@
 " train_loss_list.append(train_loss)\n",
 " # Validate on the validation set; model parameters are not updated.\n",
 " val_loss = val(model, x_val_var, y_val_var, loss_function)\n",
-" val_loss_list.append(val_loss)\n",
+" val_loss_list.append(val_loss.item())\n",
 " \n",
 "# Print the model parameters after training\n",
 "print(\"After training: \\n\", model.state_dict())"
@@ -598,4 +598,4 @@
 },
 "nbformat": 4,
 "nbformat_minor": 4
 }

View File

@@ -192,7 +192,7 @@
 " \n",
 " plt.imshow(inputs[num].numpy().transpose(1, 2, 0)) # Use matplotlib to draw the handwritten digit of the first input\n",
 " plt.show() # Show the image\n",
-" print(\"target:\", targets[num].data[0]) # Print the ground-truth label of the first input, i.e. the digit shown in the image\n",
+" print(\"target:\", targets[num].numpy()[0]) # Print the ground-truth label of the first input, i.e. the digit shown in the image\n",
 " break"
 ]
 },
@@ -910,12 +910,12 @@
 " outputs = model(inputs) # Predict handwritten digits with the model; each output has 10 components, one score per digit\n",
 " loss = loss_function(outputs, targets) # Compute the loss\n",
 " optimizer.step(loss) # Optimize and update the model parameters according to the loss\n",
-" train_losses.append(loss) # Record the loss of this batch\n",
+" train_losses.append(loss.item()) # Record the loss of this batch\n",
 " \n",
 " if batch_idx % 10 == 0: # Print the training loss every ten batches\n",
 " print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n",
 " epoch, batch_idx, len(train_loader),\n",
-" 100. * batch_idx / len(train_loader), loss.data[0]))\n",
+" 100. * batch_idx / len(train_loader), loss.item()))\n",
 " return train_losses # Return the losses of this epoch\n",
 "\n",
 "\n",
@@ -926,8 +926,8 @@
 " total_num = 0 # Total number of samples in this epoch\n",
 " for batch_idx, (inputs, targets) in enumerate(val_loader): # Iterate over the validation loader batch by batch\n",
 " outputs = model(inputs) # Predict handwritten digits with the model; each output has 10 components, one score per digit\n",
-" pred = np.argmax(outputs.data, axis=1) # Take the digit with the highest of the 10 scores as the prediction\n",
-" correct = np.sum(targets.data==pred) # Count correct predictions in this batch, i.e. labels equal to predictions\n",
+" pred = np.argmax(outputs.numpy(), axis=1) # Take the digit with the highest of the 10 scores as the prediction\n",
+" correct = np.sum(targets.numpy()==pred) # Count correct predictions in this batch, i.e. labels equal to predictions\n",
 " batch_size = inputs.shape[0] # Number of samples in this batch\n",
 " acc = correct / batch_size # Accuracy of this batch\n",
 " \n",
@@ -1075,10 +1075,10 @@
 " plt.imshow(inputs[num].numpy().transpose(1, 2, 0)) # Draw the handwritten-digit image of this sample\n",
 " plt.show() \n",
 " \n",
-" print(\"target:\", targets[num].data[0]) # Print the ground-truth label of this sample\n",
+" print(\"target:\", targets[num].numpy()[0]) # Print the ground-truth label of this sample\n",
 " \n",
 " outputs = model(inputs) # Predict from the input with the model\n",
-" pred = np.argmax(outputs.data, axis=1) # Take the highest score as the prediction\n",
+" pred = np.argmax(outputs.numpy(), axis=1) # Take the highest score as the prediction\n",
 " print(\"prediction:\", pred[num]) # Print the predicted value of this sample\n",
 " break"
 ]
@@ -1158,4 +1158,4 @@
 },
 "nbformat": 4,
 "nbformat_minor": 4
 }
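Apart from encoding, the two tutorial notebooks are moved to the current Var accessors: `loss.item()` instead of `loss.data[0]` to pull a Python scalar out of a one-element Var (so the loss lists hold plain floats rather than Vars), and `.numpy()` instead of `.data` where a NumPy array is needed for `np.argmax`/`np.sum`. A small sketch of the same pattern outside the notebooks; the arrays are made up:

import numpy as np
import jittor as jt

outputs = jt.float32([[0.1, 0.9], [0.8, 0.2]])   # fake logits: 2 samples, 2 classes
targets = jt.int32([1, 0])

pred = np.argmax(outputs.numpy(), axis=1)        # Var -> ndarray, as in the notebooks
correct = int(np.sum(targets.numpy() == pred))

loss = jt.float32([0.25])
print("loss:", loss.item(), "correct:", correct) # .item() gives a plain Python float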

View File

@@ -63,5 +63,5 @@ for mdname in all_md:
     ipynb_name = os.path.basename(mdname[:-2])+"ipynb"
     ipynb_name = os.path.join(notebook_dir, ipynb_name)
     print(mdname, len(src), len(blocks), len(cells), "--->", ipynb_name)
-    with open(ipynb_name, "w") as f:
+    with open(ipynb_name, "w", encoding='utf8') as f:
         f.write(json.dumps(ipynb))

View File

@@ -858,13 +858,13 @@ def compile_src(src, h, basename):
 def compile_single(head_file_name, src_file_name, src=None):
     basename = os.path.basename(head_file_name).split(".")[0]
     if src==None:
-        with open(head_file_name, 'r') as f:
+        with open(head_file_name, 'r', encoding='utf8') as f:
             src = f.read()
     code = compile_src(src, head_file_name, basename)
     if not code: return False
     LOG.vvv("write to", src_file_name)
     LOG.vvvv(code)
-    with open(src_file_name, 'w') as f:
+    with open(src_file_name, 'w', encoding='utf8') as f:
         f.write(code)
     return True
@@ -875,14 +875,14 @@ def compile(cache_path, jittor_path):
     basenames = []
     pyjt_names = []
     for h in headers:
-        with open(h, 'r') as f:
+        with open(h, 'r', encoding='utf8') as f:
             src = f.read()
         bh = os.path.basename(h)
         # jit_op_maker.h merge compile with var_holder.h
         if bh == "var_holder.h": continue
         if bh == "jit_op_maker.h":
-            with open(os.path.join(jittor_path, "src", "var_holder.h"), "r") as f:
+            with open(os.path.join(jittor_path, "src", "var_holder.h"), "r", encoding='utf8') as f:
                 src = f.read() + src
         basename = bh.split(".")[0]
         fname = "pyjt_"+basename+".cc"
@@ -913,7 +913,7 @@ def compile(cache_path, jittor_path):
     fname = os.path.join(cache_path, "gen", "pyjt_all.cc")
     LOG.vvv(("write to", fname))
     LOG.vvvv(code)
-    with open(fname, "w") as f:
+    with open(fname, "w", encoding='utf8') as f:
         f.write(code)
     pyjt_names.append(fname)
     return pyjt_names

View File

@@ -1,7 +1,7 @@
 import os
 def fix_config(in_name, out_name, src_path, out_path):
-    data = open(in_name, 'r').readlines()
+    data = open(in_name, 'r', encoding='utf8').readlines()
     out = []
     for d in data:
         if d.startswith('INPUT ='):
@@ -9,7 +9,7 @@ def fix_config(in_name, out_name, src_path, out_path):
         elif d.startswith('OUTPUT_DIRECTORY ='):
             d = f'OUTPUT_DIRECTORY ={out_path}\n'
         out.append(d)
-    f = open(out_name, 'w')
+    f = open(out_name, 'w', encoding='utf8')
     f.writelines(out)
 jt_path = os.getcwd()

View File

@@ -153,7 +153,7 @@ jit_op_entry_t load_jit_lib(string name, string symbol_name="jit_entry") {
     const char* msg = "";
     LOGvv << "Opening jit lib:" << name;
 #ifdef _WIN32
-    void* handle = (void*)LoadLibraryExA(name.c_str(), nullptr,
+    void* handle = (void*)LoadLibraryExA(Utf8ToGbk(name.c_str()).c_str(), nullptr,
         LOAD_LIBRARY_SEARCH_DEFAULT_DIRS |
         LOAD_LIBRARY_SEARCH_USER_DIRS);
 #elif defined(__linux__)
@@ -206,13 +206,15 @@ jit_op_entry_t compile(const string& jit_key, const string& src, const bool is_c
     string jit_src_path = Op::get_filename_from_jit_key(jit_key, ".cc");
 #ifdef _WIN32
     string jit_lib_path = Op::get_filename_from_jit_key(jit_key, ".dll");
+    string jit_src_path2 = Utf8ToGbk(jit_src_path.c_str());
 #else
     string jit_lib_path = Op::get_filename_from_jit_key(jit_key, ".so");
+    string& jit_src_path2 = jit_src_path;
 #endif
     string other_src;
     LOGvvv << "Generate" << jit_src_path >> "\n" >> src;
-    if (rewrite_op || !file_exist(jit_src_path))
-        write(jit_src_path, src);
+    if (rewrite_op || !file_exist(jit_src_path2))
+        write(jit_src_path2, src);
     string cmd;
     auto symbol_name = get_symbol_name(jit_key);
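`LoadLibraryExA` and the other `*A` Win32 entry points interpret `char*` paths in the process ANSI code page (GBK on zh-CN systems), while Jittor keeps its paths as UTF-8, hence the `Utf8ToGbk` conversion before loading the DLL and before checking or writing the generated `.cc` file. A rough illustration of the mismatch in Python; the path is invented and the `gbk` codec stands in for the ANSI code page:

path = "D:/用户/.cache/jittor/jit_key.dll"   # hypothetical UTF-8 path with CJK characters

utf8_bytes = path.encode("utf-8")    # what Jittor stores internally
gbk_bytes = path.encode("gbk")       # what an *A API expects on a cp936 system
assert utf8_bytes != gbk_bytes       # handing over the UTF-8 bytes names a different file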

View File

@@ -45,13 +45,7 @@ DEF_IS(string, PyObject*) to_py_object(const string& a) {
 DEF_IS(string, string) from_py_object(PyObject* obj) {
     Py_ssize_t size;
-#ifdef _WIN32
-    PyObjHolder a(PyUnicode_AsEncodedString(obj, win_encode.c_str(), "strict"));
-    char* s;
-    auto ret = PyBytes_AsStringAndSize(a.obj, &s, &size);
-#else
     const char* s = PyUnicode_AsUTF8AndSize(obj, &size);
-#endif
     CHECK(s);
     return string(s, size);
 }
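The deleted branch re-encoded every string crossing the Python/C++ boundary with `win_encode` (GBK by default), which loses any character that has no GBK mapping; after this commit strings stay UTF-8 on both sides and the conversion happens only right at the Win32 API calls. A small illustration of why the old path was fragile; the sample string is arbitrary, and errors="replace" is used here only to make the loss visible (the removed code used strict mode, which would raise instead):

s = "路径 🚀"                               # CJK plus a character outside GBK
utf8 = s.encode("utf-8")
gbk = s.encode("gbk", errors="replace")      # roughly what the removed branch did
assert utf8.decode("utf-8") == s             # UTF-8 round-trips exactly
assert gbk.decode("gbk") != s                # the emoji degrades to '?'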

View File

@@ -234,7 +234,15 @@ static inline bool is_full_path(const string& name) {
 #endif
 }
-bool cache_compile(string cmd, const string& cache_path, const string& jittor_path) {
+bool cache_compile(string cmd, const string& cache_path_, const string& jittor_path_) {
+#ifdef _WIN32
+    cmd = Utf8ToGbk(cmd.c_str());
+    string cache_path = Utf8ToGbk(cache_path_.c_str());
+    string jittor_path = Utf8ToGbk(jittor_path_.c_str());
+#else
+    const string& cache_path = cache_path_;
+    const string& jittor_path = jittor_path_;
+#endif
     vector<string> input_names;
     map<string,vector<string>> extra;
     string output_name;
@@ -255,6 +263,9 @@ bool cache_compile(string cmd, const string& cache_path, const string& jittor_pa
             continue;
         processed.insert(input_names[i]);
         auto src = read_all(input_names[i]);
+#ifdef _WIN32
+        src = Utf8ToGbk(src.c_str());
+#endif
         auto back = input_names[i].back();
         // *.lib
         if (back == 'b') continue;

View File

@@ -316,19 +316,13 @@ int register_sigaction() {
     return 0;
 }
-#ifdef _WIN32
-string win_encode;
-#endif
 static int log_init() {
+#ifdef _WIN32
+    SetConsoleCP(CP_UTF8);
+    SetConsoleOutputCP(CP_UTF8);
+#endif
     register_sigaction();
     std::atexit(log_exiting);
-#ifdef _WIN32
-    if (getenv("JITTOR_ENCODE"))
-        win_encode = getenv("JITTOR_ENCODE");
-    else
-        win_encode = "gbk";
-#endif
     return 1;
 }
@@ -456,6 +450,39 @@ If you still have problems, please contact us:
 }
 #ifdef _WIN32
+string GbkToUtf8(const char *src_str)
+{
+    int len = MultiByteToWideChar(CP_ACP, 0, src_str, -1, NULL, 0);
+    wchar_t* wstr = new wchar_t[len + 1];
+    memset(wstr, 0, len + 1);
+    MultiByteToWideChar(CP_ACP, 0, src_str, -1, wstr, len);
+    len = WideCharToMultiByte(CP_UTF8, 0, wstr, -1, NULL, 0, NULL, NULL);
+    char* str = new char[len + 1];
+    memset(str, 0, len + 1);
+    WideCharToMultiByte(CP_UTF8, 0, wstr, -1, str, len, NULL, NULL);
+    string strTemp = str;
+    if (wstr) delete[] wstr;
+    if (str) delete[] str;
+    return strTemp;
+}
+string Utf8ToGbk(const char *src_str)
+{
+    int len = MultiByteToWideChar(CP_UTF8, 0, src_str, -1, NULL, 0);
+    wchar_t* wszGBK = new wchar_t[len + 1];
+    memset(wszGBK, 0, len * 2 + 2);
+    MultiByteToWideChar(CP_UTF8, 0, src_str, -1, wszGBK, len);
+    len = WideCharToMultiByte(CP_ACP, 0, wszGBK, -1, NULL, 0, NULL, NULL);
+    char* szGBK = new char[len + 1];
+    memset(szGBK, 0, len + 1);
+    WideCharToMultiByte(CP_ACP, 0, wszGBK, -1, szGBK, len, NULL, NULL);
+    string strTemp(szGBK);
+    if (wszGBK) delete[] wszGBK;
+    if (szGBK) delete[] szGBK;
+    return strTemp;
+}
 int system_popen(const char *cmd, const char* cwd) {
     HANDLE g_hChildStd_OUT_Rd = NULL;
     HANDLE g_hChildStd_OUT_Wr = NULL;
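The two helpers added above follow the usual Win32 two-step conversion: `MultiByteToWideChar` decodes the input bytes to UTF-16, then `WideCharToMultiByte` re-encodes them into the target code page (`CP_ACP` is the ANSI/GBK side, `CP_UTF8` the UTF-8 side). The same round trip sketched in Python for readers without a Windows toolchain at hand; the error handling is an assumption, since the C++ code relies on the API defaults:

def utf8_to_gbk(data: bytes) -> bytes:
    # decode to a Unicode string (the wide-char step), then encode as GBK
    return data.decode("utf-8").encode("gbk", errors="replace")

def gbk_to_utf8(data: bytes) -> bytes:
    return data.decode("gbk").encode("utf-8")

raw = "缓存路径".encode("utf-8")             # "cache path" in Chinese, representable in GBK
assert gbk_to_utf8(utf8_to_gbk(raw)) == raw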

View File

@@ -16,6 +16,10 @@ namespace jittor {
 // define in tracer.cc
 void print_trace();
 void breakpoint();
+#ifdef _WIN32
+string GbkToUtf8(const char *src_str);
+string Utf8ToGbk(const char *src_str);
+#endif
 constexpr int32_t basename_index(const char * const path, const int32_t index = 0, const int32_t slash_index = -1) {
     return path[index]
@@ -277,8 +281,4 @@ bool check_vlog(const char* fileline, int verbose);
 void system_with_check(const char* cmd, const char* cwd=nullptr);
-#ifdef _WIN32
-extern string win_encode;
-#endif
 } // jittor

View File

@@ -181,12 +181,24 @@ struct VarHolder {
     inline void set_requires_grad(bool flag) {
         if (flag == get_requires_grad()) return;
         if (flag)
-            _update(this);
+            start_grad();
         else
             stop_grad();
         return;
     }
+    /**
+     * enable the gradient calculation for the Var.
+     */
+    // @pyjt(start_grad)
+    // @attrs(return_self)
+    inline VarHolder* start_grad() {
+        if (!var->dtype().is_float())
+            LOGw << "cannot enable grad of a non-float value:" << var;
+        _update(this);
+        return this;
+    }
     // @pyjt(__get__uncertain_shape)
     inline NanoVector uncertain_shape() {
         return var->shape;

View File

@@ -142,6 +142,11 @@ class TestGrad(unittest.TestCase):
         expect_error(lambda: jt.grad(z, [y1]))
         dx, = jt.grad(z, [x])
         self.assertEqual(dx.data, 48)
+    def test_int_enable_grad(self):
+        a = jt.int([1,2,3])
+        a.requires_grad = True
+        a.start_grad()
     def test_nth_grad(self):
         x = jt.array(2.0)

View File

@@ -23,7 +23,7 @@ def check_is_both(src):
 for mdname in all_src_md:
     print(mdname)
-    with open(mdname, "r") as f:
+    with open(mdname, "r", encoding='utf8') as f:
         src = f.read()
     src = src.split("```")
     en_src = []
@@ -47,8 +47,8 @@ for mdname in all_src_md:
         cn_src.append("\n".join(cn_s))
     en_src = "```".join(en_src)
     cn_src = "```".join(cn_src)
-    with open(mdname.replace(".src.md", ".md"), 'w') as f:
+    with open(mdname.replace(".src.md", ".md"), 'w', encoding='utf8') as f:
         f.write(en_src)
-    with open(mdname.replace(".src.md", ".cn.md"), 'w') as f:
+    with open(mdname.replace(".src.md", ".cn.md"), 'w', encoding='utf8') as f:
         f.write(cn_src)