FIX_TORCH_ERROR by default

This commit is contained in:
Dun Liang 2022-04-26 14:27:05 +08:00
parent a1322782ae
commit 0280f141e2
5 changed files with 14 additions and 9 deletions

View File

@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.3.3.11'
__version__ = '1.3.3.12'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int

View File

@ -560,7 +560,13 @@ def setup_mpi():
if k == "mpi_test": continue
setattr(core.Var, k, wrapper(mpi_ops.__dict__[k]))
if os.environ.get("FIX_TORCH_ERROR", "0") == "1":
in_mpi = inside_mpi()
FIX_TORCH_ERROR = 0
if os.name != 'nt' and not in_mpi:
FIX_TORCH_ERROR = 1
if "FIX_TORCH_ERROR" in os.environ:
FIX_TORCH_ERROR = os.environ["FIX_TORCH_ERROR"] != "0"
if FIX_TORCH_ERROR:
try:
import torch
from jittor_utils import dirty_fix_pytorch_runtime_error
@ -570,7 +576,6 @@ if os.environ.get("FIX_TORCH_ERROR", "0") == "1":
cudnn = cublas = curand = cufft = None
setup_mpi()
in_mpi = inside_mpi()
rank = mpi.world_rank() if in_mpi else 0
world_size = mpi.world_size() if in_mpi else 1
setup_nccl()

View File

@ -101,7 +101,7 @@ std::ostream& operator<<(std::ostream& os, const Var& var) {
<< ":s" << var.is_finished()
<< ":n" << var.flags.get(NodeFlags::_needed_by_backward)
<< ','
<< var.dtype().to_cstring() << ',' << var.name << ',' << std::hex <<(uint64)var.mem_ptr
<< var.dtype().to_cstring() << ',' << var.name << ',' << std::hex <<(uint64)var.mem_ptr << std::dec
<< ')' << var.shape;
#ifdef NODE_MEMCHECK
os << '<' << var.__id() << '>';

View File

@ -202,7 +202,7 @@ class TestArray(unittest.TestCase):
jt.sync([a, b])
assert a.data == 1
assert b.data == -1
assert len(rep) == 2
assert len(rep) == 3
@unittest.skipIf(not jt.has_cuda, "Cuda not found")
def test_scalar_fuse_unary_cuda(self):

View File

@ -85,14 +85,14 @@ class TestFusedOp(unittest.TestCase):
(hv, lv, lo))
for i in range(8):
check(0,0,0)
a = jt.array(1.0).name('a').stop_fuse()
b = (a+jt.array(1.0).name('t1').stop_fuse()).name('b')
c = (b+jt.array(1.0).name('t2').stop_fuse()).name('c')
a = jt.array([1.0,1.0]).name('a').stop_fuse()
b = (a+jt.array([1.0,1.0]).name('t1').stop_fuse()).name('b')
c = (b+jt.array([1.0,1.0]).name('t2').stop_fuse()).name('c')
check(3,5,5)
graph = jt.dump_all_graphs()
# for n in graph.nodes_info:
# print(n)
self.assertEqual(c.data, 3)
np.testing.assert_allclose(c.data, [3,3])
graph2 = jt.dump_all_graphs()
print("check", i)
for n in graph2.nodes_info: