polish windows

This commit is contained in:
Dun Liang 2021-09-29 20:03:33 +08:00
parent 9293045993
commit cfbadbbfb0
9 changed files with 224 additions and 275 deletions

View File

@ -1043,7 +1043,7 @@ if os.name == 'nt':
cc_flags = cc_flags.replace("-lstdc++", "")
cc_flags = cc_flags.replace("-ldl", "")
cc_flags += f" -L\"{py3_link_path}\" -lpython3{sys.version_info.minor} "
cc_flags += " -EHsc "
cc_flags += " -EHa "
import jittor_utils
if jittor_utils.msvc_path:
mp = jittor_utils.msvc_path
@ -1109,6 +1109,7 @@ make_cache_dir(cache_path)
make_cache_dir(os.path.join(cache_path, "jit"))
make_cache_dir(os.path.join(cache_path, "obj_files"))
make_cache_dir(os.path.join(cache_path, "gen"))
make_cache_dir(os.path.join(cache_path, "tmp"))
ck_path = os.path.join(cache_path, "checkpoints")
make_cache_dir(ck_path)
@ -1128,7 +1129,7 @@ if has_cuda:
# nvcc don't support -Wall option
if os.name == 'nt':
nvcc_flags = nvcc_flags.replace("-fp:", "-Xcompiler -fp:")
nvcc_flags = nvcc_flags.replace("-EHsc", "-Xcompiler -EHsc")
nvcc_flags = nvcc_flags.replace("-EHa", "-Xcompiler -EHa")
nvcc_flags = nvcc_flags.replace("-nologo", "")
nvcc_flags = nvcc_flags.replace("-std:", "-std=")
nvcc_flags = nvcc_flags.replace("-Fo:", "-o")

View File

@ -675,7 +675,7 @@ def compile_src(src, h, basename):
error_log_code = generate_error_code_from_func_header(func_head, target_scope_name, name, dfs, basename ,h, class_info)
func = f"""
{func_cast}[]{func_head} {{
try {{_JT_SEH_START3;
try {{
{func_fill};
uint64 arg_filled=0;
(void)arg_filled;
@ -689,7 +689,7 @@ def compile_src(src, h, basename):
for did in range(len(arr_func_return))
])}
LOGf << "Not a valid call.";
_JT_SEH_END3; }} catch (const std::exception& e) {{
}} catch (const std::exception& e) {{
if (!PyErr_Occurred()) {{
std::stringstream ss;
ss {error_log_code};

View File

@ -20,6 +20,24 @@ static auto make_binary = get_op_info("binary")
static auto make_number = get_op_info("number")
.get_constructor<VarPtr, float, Var*>();
#ifdef _WIN32
template<class T> struct StackIniter {
T* a;
int n;
inline StackIniter(T* a, int n) :a(a), n(n) {
for (int i=0; i<n; i++)
new(a+i) T();
}
inline ~StackIniter() {
for (int i=0; i<n; i++)
a[i].~T();
}
};
#define STACK_ALLOC2(T, a, n) T* a = (T*)_alloca(sizeof(T)*(n)); StackIniter<T> __init_##a(a, n);
#else
#define STACK_ALLOC2(T, a, n) T* a = (T*)_alloca(sizeof(T)*(n))
#endif
VarPtr make_grad(Op* op, Var* out, Var* dout, Var* x, int x_index) {
if (dout == nullptr) return nullptr;
@ -154,7 +172,7 @@ vector<VarPtr> grad(Var* loss, vector<Var*> targets) {
// backward together
auto n_i = op->inputs().size();
STACK_ALLOC(Var*, douts, n_o);
STACK_ALLOC(VarPtr, dins, n_i);
STACK_ALLOC2(VarPtr, dins, n_i);
// dump "for (Var* out : op->outputs())"
for (int i=0; i<n_o; i++,j++) {
auto id = id_buffer[j].second;

View File

@ -8,9 +8,12 @@
#include "grad.h"
#include "pyjt/py_obj_holder.h"
#include "init.h"
#include "utils/seh.h"
namespace jittor {
SEH_HOOK;
// Those function is generated by python
EXTERN_LIB void pyjt_def_all(PyObject* m);

View File

@ -19,208 +19,11 @@
#include <iterator>
#include <algorithm>
#include <cstring>
#ifdef _WIN32
#include <exception>
#include <windows.h>
#include <eh.h>
#include <sstream>
#endif
#include "utils/seh.h"
namespace jittor {
#ifdef _WIN32
using std::stringstream;
void raise_win_error(int ierr) {
DWORD err = (DWORD)ierr;
WCHAR *s_buf = NULL; /* Free via LocalFree */
stringstream message;
if (err==0) {
err = GetLastError();
}
auto len = FormatMessageW(
/* Error API error */
FORMAT_MESSAGE_ALLOCATE_BUFFER |
FORMAT_MESSAGE_FROM_SYSTEM |
FORMAT_MESSAGE_IGNORE_INSERTS,
NULL, /* no message source */
err,
MAKELANGID(LANG_NEUTRAL,
SUBLANG_DEFAULT), /* Default language */
(LPWSTR) &s_buf,
0, /* size not used */
NULL); /* no args */
if (len==0) {
/* Only seen this in out of mem situations */
message << "Windows Error " << err;
s_buf = NULL;
} else {
/* remove trailing cr/lf and dots */
while (len > 0 && (s_buf[len-1] <= L' ' || s_buf[len-1] == L'.'))
s_buf[--len] = L'\0';
message << s_buf;
}
if (s_buf)
LocalFree(s_buf);
throw std::runtime_error(message.str());
}
void raise_cxx_exception(DWORD code, const EXCEPTION_RECORD* pr) {
/* The 'code' is a normal win32 error code so it could be handled by
raise_win_error(). However, for some errors, we have additional
information not included in the error code. We handle those here and
delegate all others to the generic function. */
stringstream message;
switch (code) {
case EXCEPTION_ACCESS_VIOLATION:
/* The thread attempted to read from or write
to a virtual address for which it does not
have the appropriate access. */
if (pr->ExceptionInformation[0] == 0)
message << "exception: access violation reading " << (void*)pr->ExceptionInformation[1];
else
message << "exception: access violation writing " << (void*)pr->ExceptionInformation[1];
break;
case EXCEPTION_BREAKPOINT:
/* A breakpoint was encountered. */
message << "exception: breakpoint encountered";
break;
case EXCEPTION_DATATYPE_MISALIGNMENT:
/* The thread attempted to read or write data that is
misaligned on hardware that does not provide
alignment. For example, 16-bit values must be
aligned on 2-byte boundaries, 32-bit values on
4-byte boundaries, and so on. */
message << "exception: datatype misalignment";
break;
case EXCEPTION_SINGLE_STEP:
/* A trace trap or other single-instruction mechanism
signaled that one instruction has been executed. */
message << "exception: single step";
break;
case EXCEPTION_ARRAY_BOUNDS_EXCEEDED:
/* The thread attempted to access an array element
that is out of bounds, and the underlying hardware
supports bounds checking. */
message << "exception: array bounds exceeded";
break;
case EXCEPTION_FLT_DENORMAL_OPERAND:
/* One of the operands in a floating-point operation
is denormal. A denormal value is one that is too
small to represent as a standard floating-point
value. */
message << "exception: floating-point operand denormal";
break;
case EXCEPTION_FLT_DIVIDE_BY_ZERO:
/* The thread attempted to divide a floating-point
value by a floating-point divisor of zero. */
message << "exception: float divide by zero";
break;
case EXCEPTION_FLT_INEXACT_RESULT:
/* The result of a floating-point operation cannot be
represented exactly as a decimal fraction. */
message << "exception: float inexact";
break;
case EXCEPTION_FLT_INVALID_OPERATION:
/* This exception represents any floating-point
exception not included in this list. */
message << "exception: float invalid operation";
break;
case EXCEPTION_FLT_OVERFLOW:
/* The exponent of a floating-point operation is
greater than the magnitude allowed by the
corresponding type. */
message << "exception: float overflow";
break;
case EXCEPTION_FLT_STACK_CHECK:
/* The stack overflowed or underflowed as the result
of a floating-point operation. */
message << "exception: stack over/underflow";
break;
case EXCEPTION_STACK_OVERFLOW:
/* The stack overflowed or underflowed as the result
of a floating-point operation. */
message << "exception: stack overflow";
break;
case EXCEPTION_FLT_UNDERFLOW:
/* The exponent of a floating-point operation is less
than the magnitude allowed by the corresponding
type. */
message << "exception: float underflow";
break;
case EXCEPTION_INT_DIVIDE_BY_ZERO:
/* The thread attempted to divide an integer value by
an integer divisor of zero. */
message << "exception: integer divide by zero";
break;
case EXCEPTION_INT_OVERFLOW:
/* The result of an integer operation caused a carry
out of the most significant bit of the result. */
message << "exception: integer overflow";
break;
case EXCEPTION_PRIV_INSTRUCTION:
/* The thread attempted to execute an instruction
whose operation is not allowed in the current
machine mode. */
message << "exception: privileged instruction";
break;
case EXCEPTION_NONCONTINUABLE_EXCEPTION:
/* The thread attempted to continue execution after a
noncontinuable exception occurred. */
message << "exception: nocontinuable";
break;
case 0xE06D7363:
/* magic number(0xE06D7363) of c++ exception:
https://devblogs.microsoft.com/oldnewthing/20100730-00/?p=13273
*/
message << "Error c++ exception";
break;
default:
raise_win_error(code);
break;
}
// std::cout << message.str() << std::endl;
throw std::runtime_error(message.str());
}
DWORD HandleException(EXCEPTION_POINTERS *ptrs,
DWORD *pdw, EXCEPTION_RECORD *record)
{
*pdw = ptrs->ExceptionRecord->ExceptionCode;
*record = *ptrs->ExceptionRecord;
/* We don't want to catch breakpoint exceptions, they are used to attach
* a debugger to the process.
*/
if (*pdw == EXCEPTION_BREAKPOINT)
return EXCEPTION_CONTINUE_SEARCH;
return EXCEPTION_EXECUTE_HANDLER;
}
#endif
SEH_HOOK;
void init_subprocess() {
#ifdef __linux__
@ -393,7 +196,7 @@ static void pyjt_def_core(PyObject* m) {
{ R""(cache_compile)"",
(PyCFunction)(PyObject* (*)(PyObject*,PyObject**,int64,PyObject*))[](PyObject* self, PyObject** args, int64 n, PyObject* kw) -> PyObject* {
try {_JT_SEH_START3;
try {
;
uint64 arg_filled=0;
(void)arg_filled;
@ -470,7 +273,7 @@ static void pyjt_def_core(PyObject* m) {
}
LOGf << "Not a valid call.";
_JT_SEH_END3; } catch (const std::exception& e) {
} catch (const std::exception& e) {
if (!PyErr_Occurred()) {
PyErr_Format(PyExc_RuntimeError, e.what());
}
@ -487,7 +290,7 @@ bool cache_compile(const string& cmd, const string& cache_path="", const string&
{ R""(log)"",
(PyCFunction)(PyObject* (*)(PyObject*,PyObject**,int64,PyObject*))[](PyObject* self, PyObject** args, int64 n, PyObject* kw) -> PyObject* {
try {_JT_SEH_START3;
try {
;
uint64 arg_filled=0;
(void)arg_filled;
@ -557,7 +360,7 @@ bool cache_compile(const string& cmd, const string& cache_path="", const string&
}
LOGf << "Not a valid call.";
_JT_SEH_END3; } catch (const std::exception& e) {
} catch (const std::exception& e) {
if (!PyErr_Occurred()) {
PyErr_Format(PyExc_RuntimeError, e.what());
}
@ -574,7 +377,7 @@ void log(const std::string& fileline, const char* level, int verbose, const std:
{ R""(init_subprocess)"",
(PyCFunction)(PyObject* (*)(PyObject*,PyObject**,int64,PyObject*))[](PyObject* self, PyObject** args, int64 n, PyObject* kw) -> PyObject* {
try {_JT_SEH_START3;
try {
;
uint64 arg_filled=0;
(void)arg_filled;
@ -586,7 +389,7 @@ void log(const std::string& fileline, const char* level, int verbose, const std:
}
LOGf << "Not a valid call.";
_JT_SEH_END3; } catch (const std::exception& e) {
} catch (const std::exception& e) {
if (!PyErr_Occurred()) {
PyErr_Format(PyExc_RuntimeError, e.what());
}
@ -603,7 +406,7 @@ void init_subprocess()
{ R""(log_capture_start)"",
(PyCFunction)(PyObject* (*)(PyObject*,PyObject**,int64,PyObject*))[](PyObject* self, PyObject** args, int64 n, PyObject* kw) -> PyObject* {
try {_JT_SEH_START3;
try {
;
uint64 arg_filled=0;
(void)arg_filled;
@ -615,7 +418,7 @@ void init_subprocess()
}
LOGf << "Not a valid call.";
_JT_SEH_END3; } catch (const std::exception& e) {
} catch (const std::exception& e) {
if (!PyErr_Occurred()) {
PyErr_Format(PyExc_RuntimeError, e.what());
}
@ -632,7 +435,7 @@ void log_capture_start()
{ R""(log_capture_stop)"",
(PyCFunction)(PyObject* (*)(PyObject*,PyObject**,int64,PyObject*))[](PyObject* self, PyObject** args, int64 n, PyObject* kw) -> PyObject* {
try {_JT_SEH_START3;
try {
;
uint64 arg_filled=0;
(void)arg_filled;
@ -644,7 +447,7 @@ void log_capture_start()
}
LOGf << "Not a valid call.";
_JT_SEH_END3; } catch (const std::exception& e) {
} catch (const std::exception& e) {
if (!PyErr_Occurred()) {
PyErr_Format(PyExc_RuntimeError, e.what());
}
@ -661,7 +464,7 @@ void log_capture_stop()
{ R""(log_capture_read)"",
(PyCFunction)(PyObject* (*)(PyObject*,PyObject**,int64,PyObject*))[](PyObject* self, PyObject** args, int64 n, PyObject* kw) -> PyObject* {
try {_JT_SEH_START3;
try {
;
uint64 arg_filled=0;
(void)arg_filled;
@ -675,7 +478,7 @@ void log_capture_stop()
}
LOGf << "Not a valid call.";
_JT_SEH_END3; } catch (const std::exception& e) {
} catch (const std::exception& e) {
if (!PyErr_Occurred()) {
PyErr_Format(PyExc_RuntimeError, e.what());
}
@ -692,7 +495,7 @@ void log_capture_read()
{ R""(ostream_redirect)"",
(PyCFunction)(PyObject* (*)(PyObject*,PyObject**,int64,PyObject*))[](PyObject* self, PyObject** args, int64 n, PyObject* kw) -> PyObject* {
try {_JT_SEH_START3;
try {
;
uint64 arg_filled=0;
(void)arg_filled;
@ -740,7 +543,7 @@ void log_capture_read()
}
LOGf << "Not a valid call.";
_JT_SEH_END3; } catch (const std::exception& e) {
} catch (const std::exception& e) {
if (!PyErr_Occurred()) {
PyErr_Format(PyExc_RuntimeError, e.what());
}

View File

@ -2,76 +2,193 @@
#pragma once
#ifdef _WIN32
#include <windows.h>
#include <exception>
#include <eh.h>
#include <sstream>
#include "common.h"
namespace jittor {
EXTERN_LIB void raise_win_error(int ierr);
EXTERN_LIB void raise_cxx_exception(DWORD code, const EXCEPTION_RECORD* pr);
EXTERN_LIB DWORD HandleException(EXCEPTION_POINTERS *ptrs,
DWORD *pdw, EXCEPTION_RECORD *record);
#define _JT_SEH_TRY \
DWORD dwExceptionCode = 0; \
EXCEPTION_RECORD record; \
__try {
#define _JT_SEH_CATCH \
} \
__except (HandleException(GetExceptionInformation(), \
&dwExceptionCode, &record)) { \
raise_cxx_exception(dwExceptionCode, &record); \
}
#define _JT_SEH_START \
return [&]() { \
_JT_SEH_TRY; \
return [&]() {
#define _JT_SEH_END \
}(); \
_JT_SEH_CATCH; \
}(); \
using std::stringstream;
inline void raise_win_error(int ierr) {
DWORD err = (DWORD)ierr;
WCHAR *s_buf = NULL; /* Free via LocalFree */
stringstream message;
#define _JT_SEH_START2 \
[&]() { \
_JT_SEH_TRY;
if (err==0) {
err = GetLastError();
}
auto len = FormatMessageW(
/* Error API error */
FORMAT_MESSAGE_ALLOCATE_BUFFER |
FORMAT_MESSAGE_FROM_SYSTEM |
FORMAT_MESSAGE_IGNORE_INSERTS,
NULL, /* no message source */
err,
MAKELANGID(LANG_NEUTRAL,
SUBLANG_DEFAULT), /* Default language */
(LPWSTR) &s_buf,
0, /* size not used */
NULL); /* no args */
if (len==0) {
/* Only seen this in out of mem situations */
message << "Windows Error " << err;
s_buf = NULL;
} else {
/* remove trailing cr/lf and dots */
while (len > 0 && (s_buf[len-1] <= L' ' || s_buf[len-1] == L'.'))
s_buf[--len] = L'\0';
message << s_buf;
}
if (s_buf)
LocalFree(s_buf);
throw std::runtime_error(message.str());
}
#define _JT_SEH_END2 \
_JT_SEH_CATCH; \
}();
inline void raise_cxx_exception(unsigned int code, _EXCEPTION_POINTERS* pExp) {
std::cerr << "raise_cxx_exception " << code << std::endl;
EXCEPTION_RECORD* pr = pExp->ExceptionRecord;
#ifdef JT_SEH_FULL
/* The 'code' is a normal win32 error code so it could be handled by
raise_win_error(). However, for some errors, we have additional
information not included in the error code. We handle those here and
delegate all others to the generic function. */
stringstream message;
switch (code) {
case EXCEPTION_ACCESS_VIOLATION:
/* The thread attempted to read from or write
to a virtual address for which it does not
have the appropriate access. */
if (pr->ExceptionInformation[0] == 0)
message << "exception: access violation reading " << (void*)pr->ExceptionInformation[1];
else
message << "exception: access violation writing " << (void*)pr->ExceptionInformation[1];
break;
case EXCEPTION_BREAKPOINT:
/* A breakpoint was encountered. */
message << "exception: breakpoint encountered";
break;
#define _JT_SEH_START3 \
return [&]() { \
_JT_SEH_TRY; \
return [&]() {
case EXCEPTION_DATATYPE_MISALIGNMENT:
/* The thread attempted to read or write data that is
misaligned on hardware that does not provide
alignment. For example, 16-bit values must be
aligned on 2-byte boundaries, 32-bit values on
4-byte boundaries, and so on. */
message << "exception: datatype misalignment";
break;
#define _JT_SEH_END3 \
}(); \
_JT_SEH_CATCH; \
}(); \
case EXCEPTION_SINGLE_STEP:
/* A trace trap or other single-instruction mechanism
signaled that one instruction has been executed. */
message << "exception: single step";
break;
#else
case EXCEPTION_ARRAY_BOUNDS_EXCEEDED:
/* The thread attempted to access an array element
that is out of bounds, and the underlying hardware
supports bounds checking. */
message << "exception: array bounds exceeded";
break;
#define _JT_SEH_START3
#define _JT_SEH_END3
case EXCEPTION_FLT_DENORMAL_OPERAND:
/* One of the operands in a floating-point operation
is denormal. A denormal value is one that is too
small to represent as a standard floating-point
value. */
message << "exception: floating-point operand denormal";
break;
#endif
case EXCEPTION_FLT_DIVIDE_BY_ZERO:
/* The thread attempted to divide a floating-point
value by a floating-point divisor of zero. */
message << "exception: float divide by zero";
break;
case EXCEPTION_FLT_INEXACT_RESULT:
/* The result of a floating-point operation cannot be
represented exactly as a decimal fraction. */
message << "exception: float inexact";
break;
case EXCEPTION_FLT_INVALID_OPERATION:
/* This exception represents any floating-point
exception not included in this list. */
message << "exception: float invalid operation";
break;
case EXCEPTION_FLT_OVERFLOW:
/* The exponent of a floating-point operation is
greater than the magnitude allowed by the
corresponding type. */
message << "exception: float overflow";
break;
case EXCEPTION_FLT_STACK_CHECK:
/* The stack overflowed or underflowed as the result
of a floating-point operation. */
message << "exception: stack over/underflow";
break;
case EXCEPTION_STACK_OVERFLOW:
/* The stack overflowed or underflowed as the result
of a floating-point operation. */
message << "exception: stack overflow";
break;
case EXCEPTION_FLT_UNDERFLOW:
/* The exponent of a floating-point operation is less
than the magnitude allowed by the corresponding
type. */
message << "exception: float underflow";
break;
case EXCEPTION_INT_DIVIDE_BY_ZERO:
/* The thread attempted to divide an integer value by
an integer divisor of zero. */
message << "exception: integer divide by zero";
break;
case EXCEPTION_INT_OVERFLOW:
/* The result of an integer operation caused a carry
out of the most significant bit of the result. */
message << "exception: integer overflow";
break;
case EXCEPTION_PRIV_INSTRUCTION:
/* The thread attempted to execute an instruction
whose operation is not allowed in the current
machine mode. */
message << "exception: privileged instruction";
break;
case EXCEPTION_NONCONTINUABLE_EXCEPTION:
/* The thread attempted to continue execution after a
noncontinuable exception occurred. */
message << "exception: nocontinuable";
break;
case 0xE06D7363:
/* magic number(0xE06D7363) of c++ exception:
https://devblogs.microsoft.com/oldnewthing/20100730-00/?p=13273
*/
message << "Error c++ exception";
break;
default:
raise_win_error(code);
break;
}
// std::cout << message.str() << std::endl;
throw std::runtime_error(message.str());
}
}
#define SEH_HOOK int _seh_hook = (_set_se_translator(raise_cxx_exception), 0)
#else
#define _JT_SEH_TRY
#define _JT_SEH_CATCH
#define _JT_SEH_START
#define _JT_SEH_END
#define _JT_SEH_START2
#define _JT_SEH_END2
#define _JT_SEH_START3
#define _JT_SEH_END3
#define SEH_HOOK
#endif

View File

@ -119,7 +119,7 @@ class TestPad(unittest.TestCase):
def test_save_image(self):
arr = jt.array(np.random.randn(16,3,10,10))
jt.save_image(arr, "/tmp/a.jpg")
jt.save_image(arr, jt.flags.cache_path+"/tmp/a.jpg")
def test_unbind(self):
arr = np.random.randn(2,3,4)
@ -242,8 +242,9 @@ class TestPad(unittest.TestCase):
class TestOther(unittest.TestCase):
def test_save(self):
pp = [1,2,jt.array([1,2,3]), {"a":[1,2,3], "b":jt.array([1,2,3])}]
jt.save(pp, "/tmp/xx.pkl")
x = jt.load("/tmp/xx.pkl")
name = jt.flags.cache_path+"/xx.pkl"
jt.save(pp, name)
x = jt.load(name)
assert x[:2] == [1,2]
assert (x[2] == np.array([1,2,3])).all()
assert x[3]['a'] == [1,2,3]

View File

@ -23,6 +23,8 @@ import jittor.transform as trans
import time
skip_this_test = False
if os.name == 'nt':
skip_this_test = True
class MnistNet(Module):
def __init__(self):

View File

@ -441,10 +441,14 @@ _py3_include_path = None
_py3_extension_suffix = None
if os.name == 'nt':
from pathlib import Path
if check_msvc_install:
if not os.path.isfile(cc_path):
from jittor_utils import install_msvc
install_msvc.install(msvc_path)
mpath = os.path.join(str(Path.home()), ".cache", "jittor", "msvc")
if cc_path.startswith(mpath):
msvc_path = mpath
os.RTLD_NOW = os.RTLD_GLOBAL = os.RTLD_DEEPBIND = 0
path = os.path.dirname(cc_path).replace('/', '\\')
if path: