mirror of https://github.com/Jittor/Jittor
FIX_TORCH_ERROR by default
This commit is contained in:
parent
a1322782ae
commit
0280f141e2
|
@ -9,7 +9,7 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
__version__ = '1.3.3.11'
|
||||
__version__ = '1.3.3.12'
|
||||
from jittor_utils import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -560,7 +560,13 @@ def setup_mpi():
|
|||
if k == "mpi_test": continue
|
||||
setattr(core.Var, k, wrapper(mpi_ops.__dict__[k]))
|
||||
|
||||
if os.environ.get("FIX_TORCH_ERROR", "0") == "1":
|
||||
in_mpi = inside_mpi()
|
||||
FIX_TORCH_ERROR = 0
|
||||
if os.name != 'nt' and not in_mpi:
|
||||
FIX_TORCH_ERROR = 1
|
||||
if "FIX_TORCH_ERROR" in os.environ:
|
||||
FIX_TORCH_ERROR = os.environ["FIX_TORCH_ERROR"] != "0"
|
||||
if FIX_TORCH_ERROR:
|
||||
try:
|
||||
import torch
|
||||
from jittor_utils import dirty_fix_pytorch_runtime_error
|
||||
|
@ -570,7 +576,6 @@ if os.environ.get("FIX_TORCH_ERROR", "0") == "1":
|
|||
|
||||
cudnn = cublas = curand = cufft = None
|
||||
setup_mpi()
|
||||
in_mpi = inside_mpi()
|
||||
rank = mpi.world_rank() if in_mpi else 0
|
||||
world_size = mpi.world_size() if in_mpi else 1
|
||||
setup_nccl()
|
||||
|
|
|
@ -101,7 +101,7 @@ std::ostream& operator<<(std::ostream& os, const Var& var) {
|
|||
<< ":s" << var.is_finished()
|
||||
<< ":n" << var.flags.get(NodeFlags::_needed_by_backward)
|
||||
<< ','
|
||||
<< var.dtype().to_cstring() << ',' << var.name << ',' << std::hex <<(uint64)var.mem_ptr
|
||||
<< var.dtype().to_cstring() << ',' << var.name << ',' << std::hex <<(uint64)var.mem_ptr << std::dec
|
||||
<< ')' << var.shape;
|
||||
#ifdef NODE_MEMCHECK
|
||||
os << '<' << var.__id() << '>';
|
||||
|
|
|
@ -202,7 +202,7 @@ class TestArray(unittest.TestCase):
|
|||
jt.sync([a, b])
|
||||
assert a.data == 1
|
||||
assert b.data == -1
|
||||
assert len(rep) == 2
|
||||
assert len(rep) == 3
|
||||
|
||||
@unittest.skipIf(not jt.has_cuda, "Cuda not found")
|
||||
def test_scalar_fuse_unary_cuda(self):
|
||||
|
|
|
@ -85,14 +85,14 @@ class TestFusedOp(unittest.TestCase):
|
|||
(hv, lv, lo))
|
||||
for i in range(8):
|
||||
check(0,0,0)
|
||||
a = jt.array(1.0).name('a').stop_fuse()
|
||||
b = (a+jt.array(1.0).name('t1').stop_fuse()).name('b')
|
||||
c = (b+jt.array(1.0).name('t2').stop_fuse()).name('c')
|
||||
a = jt.array([1.0,1.0]).name('a').stop_fuse()
|
||||
b = (a+jt.array([1.0,1.0]).name('t1').stop_fuse()).name('b')
|
||||
c = (b+jt.array([1.0,1.0]).name('t2').stop_fuse()).name('c')
|
||||
check(3,5,5)
|
||||
graph = jt.dump_all_graphs()
|
||||
# for n in graph.nodes_info:
|
||||
# print(n)
|
||||
self.assertEqual(c.data, 3)
|
||||
np.testing.assert_allclose(c.data, [3,3])
|
||||
graph2 = jt.dump_all_graphs()
|
||||
print("check", i)
|
||||
for n in graph2.nodes_info:
|
||||
|
|
Loading…
Reference in New Issue