mirror of https://github.com/Jittor/Jittor
polish torch issue
This commit is contained in:
parent
2901e578dc
commit
2721e9fb55
|
@ -563,6 +563,8 @@ def setup_mpi():
|
|||
if os.environ.get("FIX_TORCH_ERROR", "0") == "1":
|
||||
try:
|
||||
import torch
|
||||
from jittor_utils import dirty_fix_pytorch_runtime_error
|
||||
dirty_fix_pytorch_runtime_error()
|
||||
except:
|
||||
pass
|
||||
|
||||
|
|
|
@ -535,6 +535,23 @@ def get_total_mem():
|
|||
else:
|
||||
return os.sysconf('SC_PAGE_SIZE') * os.sysconf('SC_PHYS_PAGES')
|
||||
|
||||
def dirty_fix_pytorch_runtime_error():
|
||||
''' This funtion should be called before pytorch.
|
||||
|
||||
Example::
|
||||
|
||||
import jittor as jt
|
||||
jt.dirty_fix_pytorch_runtime_error()
|
||||
import torch
|
||||
'''
|
||||
import os, platform
|
||||
|
||||
if platform.system() == 'Linux':
|
||||
os.RTLD_GLOBAL = os.RTLD_GLOBAL | os.RTLD_DEEPBIND
|
||||
import jittor_utils
|
||||
with jittor_utils.import_scope(os.RTLD_GLOBAL | os.RTLD_NOW):
|
||||
import torch
|
||||
|
||||
is_in_ipynb = in_ipynb()
|
||||
cc = None
|
||||
LOG = Logwrapper()
|
||||
|
|
Loading…
Reference in New Issue