mirror of https://github.com/Jittor/Jittor
polish torch issue
This commit is contained in:
parent
2901e578dc
commit
2721e9fb55
|
@ -563,6 +563,8 @@ def setup_mpi():
|
||||||
if os.environ.get("FIX_TORCH_ERROR", "0") == "1":
|
if os.environ.get("FIX_TORCH_ERROR", "0") == "1":
|
||||||
try:
|
try:
|
||||||
import torch
|
import torch
|
||||||
|
from jittor_utils import dirty_fix_pytorch_runtime_error
|
||||||
|
dirty_fix_pytorch_runtime_error()
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
|
@ -535,6 +535,23 @@ def get_total_mem():
|
||||||
else:
|
else:
|
||||||
return os.sysconf('SC_PAGE_SIZE') * os.sysconf('SC_PHYS_PAGES')
|
return os.sysconf('SC_PAGE_SIZE') * os.sysconf('SC_PHYS_PAGES')
|
||||||
|
|
||||||
|
def dirty_fix_pytorch_runtime_error():
|
||||||
|
''' This funtion should be called before pytorch.
|
||||||
|
|
||||||
|
Example::
|
||||||
|
|
||||||
|
import jittor as jt
|
||||||
|
jt.dirty_fix_pytorch_runtime_error()
|
||||||
|
import torch
|
||||||
|
'''
|
||||||
|
import os, platform
|
||||||
|
|
||||||
|
if platform.system() == 'Linux':
|
||||||
|
os.RTLD_GLOBAL = os.RTLD_GLOBAL | os.RTLD_DEEPBIND
|
||||||
|
import jittor_utils
|
||||||
|
with jittor_utils.import_scope(os.RTLD_GLOBAL | os.RTLD_NOW):
|
||||||
|
import torch
|
||||||
|
|
||||||
is_in_ipynb = in_ipynb()
|
is_in_ipynb = in_ipynb()
|
||||||
cc = None
|
cc = None
|
||||||
LOG = Logwrapper()
|
LOG = Logwrapper()
|
||||||
|
|
Loading…
Reference in New Issue