polish torch issue

This commit is contained in:
Dun Liang 2022-04-05 16:55:01 +08:00
parent 2901e578dc
commit 2721e9fb55
2 changed files with 19 additions and 0 deletions

View File

@ -563,6 +563,8 @@ def setup_mpi():
if os.environ.get("FIX_TORCH_ERROR", "0") == "1":
try:
import torch
from jittor_utils import dirty_fix_pytorch_runtime_error
dirty_fix_pytorch_runtime_error()
except:
pass

View File

@ -535,6 +535,23 @@ def get_total_mem():
else:
return os.sysconf('SC_PAGE_SIZE') * os.sysconf('SC_PHYS_PAGES')
def dirty_fix_pytorch_runtime_error():
''' This funtion should be called before pytorch.
Example::
import jittor as jt
jt.dirty_fix_pytorch_runtime_error()
import torch
'''
import os, platform
if platform.system() == 'Linux':
os.RTLD_GLOBAL = os.RTLD_GLOBAL | os.RTLD_DEEPBIND
import jittor_utils
with jittor_utils.import_scope(os.RTLD_GLOBAL | os.RTLD_NOW):
import torch
is_in_ipynb = in_ipynb()
cc = None
LOG = Logwrapper()