diff --git a/.dockerignore b/.dockerignore index 19393373..26a123bd 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1,2 +1,28 @@ Dockerfile -**/publish.py \ No newline at end of file +**/publish.py +my +.git +.refresh +__pycache__ +.ipynb_checkpoints/ +.vscode/ +__res/ +perf.data +perf.data.old +*.swp +*.ipynb +*.pdf +*.zip +*.tgz +test.py +extern/mkl/mkldnn_lnx*/* +data/ +build/ +venv/ +*.md +!*.src.md +!README.md +!README.cn.md +python/jittor.egg-info +dist/ +!doc/source/* diff --git a/Dockerfile b/Dockerfile index d0b5cc0b..34a30b83 100644 --- a/Dockerfile +++ b/Dockerfile @@ -38,12 +38,14 @@ RUN pip3 install matplotlib RUN apt install openmpi-bin openmpi-common libopenmpi-dev -y +RUN pip3 install jittor --timeout 100 && python3.7 -m jittor.test.test_example + +RUN pip3 uninstall jittor -y + COPY . . RUN pip3 install . --timeout 100 RUN python3.7 -m jittor.test.test_example -RUN rm -rf ~/.cache/jittor/default - CMD python3.7 -m jittor.notebook --allow-root --ip=0.0.0.0 \ No newline at end of file diff --git a/extern/cuda/cub/ops/cub_arg_reduce_op.cc b/extern/cuda/cub/ops/cub_arg_reduce_op.cc index adb6e342..8bbce945 100644 --- a/extern/cuda/cub/ops/cub_arg_reduce_op.cc +++ b/extern/cuda/cub/ops/cub_arg_reduce_op.cc @@ -24,7 +24,7 @@ CubArgReduceOp::CubArgReduceOp(Var* x, Var* offsets, NanoString op, bool keepdim : x(x), offsets(offsets), op(op), keepdims(keepdims) { flags.set(NodeFlags::_cpu, 0); flags.set(NodeFlags::_cuda, 1); - ASSERT(offsets->dtype()==ns_int || offsets->dtype()==ns_int32); + ASSERT(offsets->dtype()==ns_int32); y = create_output(nullptr, ns_int32); y_key = create_output(nullptr, x->dtype()); } diff --git a/extern/cuda/cub/ops/cub_argsort_op.cc b/extern/cuda/cub/ops/cub_argsort_op.cc index b80ce99d..ec410df0 100644 --- a/extern/cuda/cub/ops/cub_argsort_op.cc +++ b/extern/cuda/cub/ops/cub_argsort_op.cc @@ -23,7 +23,7 @@ CubArgsortOp::CubArgsortOp(Var* x, Var* indexes, Var* offsets, bool descending, : x(x), indexes(indexes), offsets(offsets), descending(descending) { flags.set(NodeFlags::_cpu, 0); flags.set(NodeFlags::_cuda, 1); - ASSERT(offsets->dtype()==ns_int || offsets->dtype()==ns_int32); + ASSERT(offsets->dtype()==ns_int32); y = create_output(nullptr, dtype); y_key = create_output(nullptr, x->dtype()); } diff --git a/extern/cuda/curand/ops/curand_random_op.h b/extern/cuda/curand/ops/curand_random_op.h index 3e899ecf..9c9fd4cf 100644 --- a/extern/cuda/curand/ops/curand_random_op.h +++ b/extern/cuda/curand/ops/curand_random_op.h @@ -13,7 +13,7 @@ namespace jittor { struct CurandRandomOp : Op { Var* output; - CurandRandomOp(NanoVector shape, NanoString dtype=ns_float); + CurandRandomOp(NanoVector shape, NanoString dtype=ns_float32); const char* name() const override { return "curand_random"; } DECLARE_jit_run; diff --git a/extern/cuda/inc/helper_cuda.h b/extern/cuda/inc/helper_cuda.h index 2a9c079a..a8bdacb3 100644 --- a/extern/cuda/inc/helper_cuda.h +++ b/extern/cuda/inc/helper_cuda.h @@ -101,6 +101,17 @@ const char *_cudaGetErrorEnum(NppStatus error); #endif #endif +template +void peek(T result, char const *const func, const char *const file, + int const line) { + if (result) { + // DEVICE_RESET + LOGe << "Peek CUDA error at" << file >> ":" >> line << " code=" + >> static_cast(result) >> "(" << _cudaGetErrorEnum(result) << ")" + << func; + } +} + template void check(T result, char const *const func, const char *const file, int const line) { @@ -116,6 +127,7 @@ void check(T result, char const *const func, const char *const file, // This will output the proper CUDA 
error strings in the event // that a CUDA host call returns an error #define checkCudaErrors(val) check((val), #val, __FILE__, __LINE__) +#define peekCudaErrors(val) peek((val), #val, __FILE__, __LINE__) // This will output the proper error string when calling cudaGetLastError #define getLastCudaError(msg) __getLastCudaError(msg, __FILE__, __LINE__) diff --git a/extern/mpi/src/mpi_warper.cc b/extern/mpi/src/mpi_warper.cc index fc3d8c73..c06d8a81 100644 --- a/extern/mpi/src/mpi_warper.cc +++ b/extern/mpi/src/mpi_warper.cc @@ -83,7 +83,7 @@ mpi_initer() { MPI_CHECK(MPI_Init(NULL, NULL)); MPI_CHECK(MPI_Comm_size(MPI_COMM_WORLD, &mpi_world_size)); MPI_CHECK(MPI_Comm_rank(MPI_COMM_WORLD, &mpi_world_rank)); - + //calculating localRank based on hostname which is used in selecting a GPU uint64_t hostHashs[mpi_world_rank]; char hostname[1024]; diff --git a/notebook/custom_op.src.md b/notebook/custom_op.src.md index 4a9d7125..8a3158f4 100644 --- a/notebook/custom_op.src.md +++ b/notebook/custom_op.src.md @@ -33,7 +33,7 @@ namespace jittor { struct CustomOp : Op { Var* output; - CustomOp(NanoVector shape, NanoString dtype=ns_float); + CustomOp(NanoVector shape, NanoString dtype=ns_float32); const char* name() const override { return "custom"; } DECLARE_jit_run; diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py index c153969a..04c5db10 100644 --- a/python/jittor/__init__.py +++ b/python/jittor/__init__.py @@ -273,6 +273,7 @@ Var.start_grad = Var.detach_inplace = detach_inplace def unsqueeze(x, dim): shape = list(x.shape) + if dim < 0: dim += len(shape) + 1 assert dim <= len(shape) return x.reshape(shape[:dim] + [1] + shape[dim:]) Var.unsqueeze = unsqueeze @@ -304,11 +305,11 @@ Var.masked_fill = masked_fill def sqr(x): return x*x Var.sqr = sqr -def argmax(x, dim:int, keepdims:bool=False): +def argmax(x, dim, keepdims:bool=False): return x.arg_reduce("max", dim, keepdims) Var.argmax = argmax -def argmin(x, dim:int, keepdims:bool=False): +def argmin(x, dim, keepdims:bool=False): return x.arg_reduce("min", dim, keepdims) Var.argmin = argmin @@ -321,13 +322,54 @@ def attrs(var): } Var.attrs = attrs -def fetch(vars, func, *args, **kw): - core.fetch(vars, lambda *results: func(*results, *args, **kw)) +def fetch(*args): + ''' Async fetch vars with function closure. 
+ +Example 1:: -def fetch_var(var, func, *args, **kw): - core.fetch([var], lambda a: func(a, *args, **kw)) -Var.fetch = fetch_var -del fetch_var + for img,label in your_dataset: + pred = your_model(img) + loss = critic(pred, label) + acc = accuracy(pred, label) + jt.fetch(acc, loss, + lambda acc, loss: + print(f"loss:{loss} acc:{acc}") + ) + +Example 2:: + + for i,(img,label) in enumerate(your_dataset): + pred = your_model(img) + loss = critic(pred, label) + acc = accuracy(pred, label) + # variable i will be bound into the function closure + jt.fetch(i, acc, loss, + lambda i, acc, loss: + print(f"#{i}, loss:{loss} acc:{acc}") + ) + ''' + assert len(args)>=1 + func = args[-1] + assert callable(func) + args = list(args[:-1]) + if len(args)>0 and isinstance(args[0], Sequence) \ + and len(args[0])>=1 and isinstance(args[0][0], Var): + raise TypeError("jt.Var should not be inside a list or tuple.") + + var_map = [] + variables = [] + for i, v in enumerate(args): + if isinstance(v, Var): + variables.append(v) + var_map.append(i) + args[i] = None + def callback(*results): + for i,v in enumerate(results): + args[var_map[i]] = v + func(*args) + core.ops.fetch(variables, callback) + +Var.fetch = fetch def display_memory_info(): import inspect, os @@ -439,11 +481,11 @@ class Module: end = 0 for k in key_: if isinstance(v, nn.Sequential): - if np.int(k) >= len(v.layers): + if ori_int(k) >= len(v.layers): end = 1 break else: - v = v[np.int(k)] + v = v[ori_int(k)] else: if hasattr(v, k): v = getattr(v, k) @@ -574,12 +616,23 @@ def jittor_exit(): pass else: core.sync_all(True) + core.cleanup() atexit.register(jittor_exit) Var.__str__ = lambda x: str(x.data) Var.__repr__ = lambda x: str(x.data) Var.peek = lambda x: f"{x.dtype}{x.shape}" + +ori_int = int + +int = int32 +Var.int = Var.int32 +float = float32 +Var.float = Var.float32 +double = float64 +Var.double = Var.float64 + from . import nn from .nn import matmul from . import contrib diff --git a/python/jittor/compiler.py b/python/jittor/compiler.py index 987f4b2e..a4e9cc49 100644 --- a/python/jittor/compiler.py +++ b/python/jittor/compiler.py @@ -345,7 +345,7 @@ def gen_jit_op_maker(op_headers, export=False, extra_flags=""): with open(os.path.join(jittor_path, header), encoding='utf8') as f: src = f.read() # XxxXxxOp(args) - res = re.findall(pybind_attrs_reg + '('+name2+"\\([^\\n]*\\))", src, re.S) + res = re.findall(pybind_attrs_reg + '[^~]('+name2+"\\([^\\n]*\\))", src, re.S) assert len(res) >= 1, "Wrong op args in " + header # registe op cc_name = os.path.join(jittor_path, header[:-2] + ".cc") @@ -908,14 +908,14 @@ with open(os.path.join(cache_path, "gen", "jit_op_maker.h"), 'w') as f: f.write(jit_src) cc_flags += f' -I{cache_path} ' # gen pyjt -pyjt_compiler.compile(cache_path, jittor_path) +pyjt_gen_src = pyjt_compiler.compile(cache_path, jittor_path) # initialize order: # 1. registers # 2. generate source # 3. op_utils # 4. 
other -files2 = run_cmd(f'find "{os.path.join(cache_path, "gen")}" | grep "cc$"').splitlines() +files2 = pyjt_gen_src files4 = run_cmd('find -L src | grep "cc$"', jittor_path).splitlines() at_beginning = [ "src/ops/op_utils.cc", diff --git a/python/jittor/pyjt_compiler.py b/python/jittor/pyjt_compiler.py index b22b4c1c..85762b96 100644 --- a/python/jittor/pyjt_compiler.py +++ b/python/jittor/pyjt_compiler.py @@ -258,7 +258,7 @@ def generate_error_code_from_func_header(func_head, target_scope_name, name, dfs LOG.vvv("gen err from func_head", func_head) args = func_head[1:].split(")")[0].split(",") - error_code = f" << \"Wrong inputs arguments, Please refer to examples(e.g. {help_cmd}).\"" + error_code = f" << \"Wrong inputs arguments, Please refer to examples({help_cmd}).\"" error_code += r' << "\n\nTypes of your inputs are:\n"' for arg in args: arg = arg.strip() @@ -849,6 +849,7 @@ def compile(cache_path, jittor_path): headers = [ os.path.join(jittor_path, h) for h in headers1 ] + \ [ os.path.join(cache_path, h) for h in headers2 ] basenames = [] + pyjt_names = [] for h in headers: with open(h, 'r') as f: src = f.read() @@ -866,6 +867,7 @@ def compile(cache_path, jittor_path): if not check: continue basenames.append(basename) + pyjt_names.append(fname) code = f""" #include "pyjt/numpy.h" @@ -888,3 +890,5 @@ def compile(cache_path, jittor_path): LOG.vvvv(code) with open(fname, "w") as f: f.write(code) + pyjt_names.append(fname) + return pyjt_names diff --git a/python/jittor/test/test_array.py b/python/jittor/test/test_array.py index f0a451dc..dc859607 100644 --- a/python/jittor/test/test_array.py +++ b/python/jittor/test/test_array.py @@ -60,6 +60,7 @@ class TestArray(unittest.TestCase): for i in range(3): x = jt.array(im) b = net(x) + b.fetch(lambda b: None) b.sync() jt.sync(device_sync=True) @@ -70,6 +71,7 @@ class TestArray(unittest.TestCase): x = jt.array(im) b = net(x) b.fetch(lambda b: results.append(b)) + b.sync() # del c jt.sync(device_sync=True) t2 = time.time() - time_start @@ -111,6 +113,12 @@ class TestArray(unittest.TestCase): """) assert (b.data==[2,8,18]).all() + def test_not_c_style(self): + a = np.array([1,2,3]) + b = a[::-1] + x = jt.array(b) + x = x + b + assert (x.data == [6,4,2]).all() diff --git a/python/jittor/test/test_core.py b/python/jittor/test/test_core.py index fe31536a..867b9f6e 100644 --- a/python/jittor/test/test_core.py +++ b/python/jittor/test/test_core.py @@ -16,7 +16,7 @@ def expect_error(func): class TestCore(unittest.TestCase): def test_number_of_hold_vars(self): - assert jt.random([1,2,3]).peek() == "float[1,2,3,]" + assert jt.random([1,2,3]).peek() == "float32[1,2,3,]" assert jt.core.number_of_hold_vars() == 0 x = jt.random([1,2,3]) assert jt.core.number_of_hold_vars() == 1 diff --git a/python/jittor/test/test_custom_op.py b/python/jittor/test/test_custom_op.py index 9b3482c2..8ae033b3 100644 --- a/python/jittor/test/test_custom_op.py +++ b/python/jittor/test/test_custom_op.py @@ -16,7 +16,7 @@ namespace jittor { struct CustomOp : Op { Var* output; - CustomOp(NanoVector shape, NanoString dtype=ns_float); + CustomOp(NanoVector shape, NanoString dtype=ns_float32); const char* name() const override { return "custom"; } DECLARE_jit_run; @@ -75,7 +75,7 @@ class TestCustomOp(unittest.TestCase): my_op = jt.compile_custom_op(""" struct MyOp : Op { Var* output; - MyOp(NanoVector shape, NanoString dtype=ns_float); + MyOp(NanoVector shape, NanoString dtype=ns_float32); const char* name() const override { return "my"; } DECLARE_jit_run; diff --git 
a/python/jittor/test/test_fetcher.py b/python/jittor/test/test_fetcher.py index 35d7bd8e..cc8c4c85 100644 --- a/python/jittor/test/test_fetcher.py +++ b/python/jittor/test/test_fetcher.py @@ -13,7 +13,10 @@ class TestFetcher(unittest.TestCase): a = jt.array([1,2,3]) a = a*2 v = [] - jt.fetch([a], lambda a: v.append(a)) + jt.fetch(a, lambda a: v.append(a)) + jt.fetch(1, 2, 3, a, + lambda x, y, z, a: self.assertTrue(x==1 and y==2 and z==3 and isinstance(a, np.ndarray)) + ) jt.sync_all(True) assert len(v)==1 and (v[0]==[2,4,6]).all() diff --git a/python/jittor/test/test_nano_string.py b/python/jittor/test/test_nano_string.py index 81eb1a22..1f26e0a1 100644 --- a/python/jittor/test/test_nano_string.py +++ b/python/jittor/test/test_nano_string.py @@ -10,7 +10,7 @@ from .test_core import expect_error import os mid = 0 -if os.uname()[1] == "jittor-ce": +if "jittor" in os.uname()[1]: mid = 1 class TestNanoString(unittest.TestCase): @@ -27,7 +27,8 @@ class TestNanoString(unittest.TestCase): assert t < [1.5e-7, 1.7e-7][mid], t assert (jt.hash("asdasd") == 4152566416) - assert str(jt.NanoString("float"))=="float" + assert str(jt.NanoString("float"))=="float32" + assert jt.NanoString("float")=="float32" # pybind11: 7 # Tuple call: 1.3 # fast call (with or with not): 0.9 @@ -38,14 +39,14 @@ class TestNanoString(unittest.TestCase): def test_type(self): import numpy as np - assert str(jt.NanoString(float)) == "float" - assert str(jt.NanoString(np.float)) == "float" + assert str(jt.NanoString(float)) == "float32" + assert str(jt.NanoString(np.float)) == "float32" assert str(jt.NanoString(np.float32)) == "float32" assert str(jt.NanoString(np.float64)) == "float64" assert str(jt.NanoString(np.int8)) == "int8" assert str(jt.NanoString(np.array([1,2,3]).dtype)) == "int64" - assert str(jt.NanoString(jt.float)) == "float" + assert str(jt.NanoString(jt.float)) == "float32" assert str(jt.NanoString(jt.float32)) == "float32" assert str(jt.NanoString(jt.float64)) == "float64" assert str(jt.NanoString(jt.int8)) == "int8" diff --git a/python/jittor/test/test_resize_and_crop.py b/python/jittor/test/test_resize_and_crop.py index b6e2c3c6..1299ba13 100644 --- a/python/jittor/test/test_resize_and_crop.py +++ b/python/jittor/test/test_resize_and_crop.py @@ -99,6 +99,7 @@ class TestResizeAndCrop(unittest.TestCase): test_case(20, [1024, 1024], [1.2, 1.8][mid]) test_case(20, [1024, 666], [0.8,1.0][mid]) + @unittest.skipIf(torch is None, "no torch found") def test_resize(self): import torch.nn.functional as F x = np.array(range(2*3*25)).reshape(2,3,5,5).astype("float32") @@ -108,11 +109,13 @@ class TestResizeAndCrop(unittest.TestCase): jnn.Resize((r_size, r_size), 'bilinear', align_corners), lambda x: F.interpolate(x, size=(r_size, r_size), mode='bilinear',align_corners=align_corners)) + @unittest.skipIf(torch is None, "no torch found") def test_upsample(self): arr = np.random.randn(2,3,224,224) check_equal(arr, jnn.Upsample(scale_factor=2), tnn.Upsample(scale_factor=2)) check_equal(arr, jnn.Upsample(scale_factor=0.2), tnn.Upsample(scale_factor=0.2)) + @unittest.skipIf(torch is None, "no torch found") def test_pixelshuffle(self): arr = np.random.randn(2,4,224,224) check_equal(arr, jnn.PixelShuffle(upscale_factor=2), tnn.PixelShuffle(upscale_factor=2)) diff --git a/python/jittor/test/test_resnet.py b/python/jittor/test/test_resnet.py index c67e5a51..bbf7b873 100644 --- a/python/jittor/test/test_resnet.py +++ b/python/jittor/test/test_resnet.py @@ -64,16 +64,16 @@ class TestResnet(unittest.TestCase): SGD = 
nn.SGD(mnist_net.parameters(), self.learning_rate, self.momentum, self.weight_decay) for batch_idx, (data, target) in enumerate(self.train_loader): - output = mnist_net(data) - loss = nn.cross_entropy_loss(output, target) # train step with jt.log_capture_scope( log_silent=1, log_v=1, log_vprefix="op.cc=100,exe=10", ) as logs: + output = mnist_net(data) + loss = nn.cross_entropy_loss(output, target) SGD.step(loss) - def callback(loss, output, target, batch_idx): + def callback(batch_idx, loss, output, target): # print train info global prev pred = np.argmax(output, axis=1) @@ -83,13 +83,13 @@ class TestResnet(unittest.TestCase): print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAcc: {:.6f} \tTime:{:.3f}' .format(0, batch_idx, 600,1. * batch_idx / 6.0, loss[0], acc, time.time()-prev)) # prev = time.time() - jt.fetch([loss, output, target], callback, batch_idx) - + jt.fetch(batch_idx, loss, output, target, callback) + log_conv = find_log_with_re(logs, "Jit op key (not )?found: ((mkl)|(cudnn))_conv.*") log_matmul = find_log_with_re(logs, "Jit op key (not )?found: ((mkl)|(cublas))_matmul.*") - if batch_idx: + if batch_idx > 2: assert len(log_conv)==59 and len(log_matmul)==6, (len(log_conv), len(log_matmul)) mem_used = jt.flags.stat_allocator_total_alloc_byte \ @@ -114,15 +114,13 @@ class TestResnet(unittest.TestCase): # Train Epoch: 0 [40/100 (40%)] Loss: 2.286762 Acc: 0.130000 # Train Epoch: 0 [50/100 (50%)] Loss: 2.055014 Acc: 0.290000 - # print(jt.core.number_of_lived_vars(), mem_used) - jt.display_memory_info() - # if jt.in_mpi: - # assert jt.core.number_of_lived_vars() < 3900, jt.core.number_of_lived_vars() - # else: - # assert jt.core.number_of_lived_vars() < 3500, jt.core.number_of_lived_vars() + if jt.in_mpi: + assert jt.core.number_of_lived_vars() < 7500, jt.core.number_of_lived_vars() + else: + assert jt.core.number_of_lived_vars() < 6500, jt.core.number_of_lived_vars() jt.sync_all(True) - assert np.mean(loss_list[-50:])<0.3 + assert np.mean(loss_list[-50:])<0.5 assert np.mean(acc_list[-50:])>0.8 if __name__ == "__main__": diff --git a/python/jittor/test/test_vgg.py b/python/jittor/test/test_vgg.py index f0c04bc5..d119e1f1 100644 --- a/python/jittor/test/test_vgg.py +++ b/python/jittor/test/test_vgg.py @@ -77,7 +77,7 @@ class TestVGGClass(unittest.TestCase): acc_list.append(acc) print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAcc: {:.6f}' .format(0, batch_idx, 100,1. 
* batch_idx, loss[0], acc)) - jt.fetch([loss, output, target], callback, batch_idx) + jt.fetch(batch_idx, loss, output, target, callback) log_conv = find_log_with_re(logs, "Jit op key (not )?found: ((mkl)|(cudnn))_conv.*") diff --git a/python/jittor/utils/publish.py b/python/jittor/utils/publish.py index ad956677..10a90e5e 100644 --- a/python/jittor/utils/publish.py +++ b/python/jittor/utils/publish.py @@ -25,8 +25,8 @@ def docker_task(name, build_cmd): run_cmd(build_cmd) run_cmd(f"sudo docker push {name}") bname = os.path.basename(name) - run_cmd(f"docker save {name}:latest -o /tmp/{bname}.tgz && chmod 666 /tmp/{bname}.tgz") - upload_file(f" /tmp/{bname}.tgz") + run_cmd(f"sudo docker save {name}:latest -o /tmp/{bname}.tgz && sudo chmod 666 /tmp/{bname}.tgz") + upload_file(f"/tmp/{bname}.tgz") docker_task( "jittor/jittor", diff --git a/script/tmpi b/script/tmpi new file mode 100755 index 00000000..f660cb6a --- /dev/null +++ b/script/tmpi @@ -0,0 +1,117 @@ +#!/bin/bash + +# Copyright 2013 Benedikt Morbach +# Distributed under the terms of the GNU General Public License v2 + +# runs multiple MPI processes as a grid in a new tmux window and multiplexes keyboard input to all of them + +additional_vars=( LD_LIBRARY_PATH LD_PRELOAD ) +export "${additional_vars[@]}" + +usage() { + echo 'tmpi: Run multiple MPI processes as a grid in a new tmux window and multiplex keyboard input to all of them.' + echo '' + echo 'Usage:' + echo ' tmpi [number] [command]' + echo '' + echo 'You need to pass at least two arguments.' + echo 'The first argument is the number of processes to use, every argument after that is the commandline to run.' + echo 'If you call this script from outside tmux and your command contains important whitespace then you need to appy two levels of quoting to preserve it.' + echo '' + echo 'LD_LIBRARY_PATH and LD_PRELOAD are passed through, so you can run it like this:' + echo 'LD_LIBRARY_PATH="${PWD}/.libs:${LD_LIBRARY_PATH}" tmpi 16 gdb -q bin/.libs/example' + echo '' + echo 'The new window is set to remain on exit and has to be closed manually. ("C-b + k" by default)' +} + +check_tools() { + tools=( tmux mpirun ) + + for tool in "${tools[@]}"; do + if ! which ${tool}; then + echo "You need to install ${tool} to run this script." + fi + done +} + +if [[ ${#} -lt 2 ]]; then + usage + + exit 1 +fi + +if [[ -z ${TMUX} ]]; then + # it seems we aren't in a tmux session. + # start a new one so that our window doesn't end up in some other session and we have to search it. + # actually start a new server with '-L' to ensure that our environment carries over. + socket=$(mktemp --dry-run tmpi.XXXX) + exec tmux -L ${socket} new-session "${0} ${*}" +fi + +if [[ ${1} == runmpi ]] ; then + # we are being started as one of many processes by mpirun. + shift + + # start the processes in the order of their rank. + # this avoids races, as we have to push the variables in tmux' environment. + # it has the nice side-effect that the panes are also ordered by rank. + while [[ $(cat /tmp/tmpi.lock) -ne ${OMPI_COMM_WORLD_RANK} ]] ; do + sleep 0.02 + done + + # get all the variables that mpirun starts us with so that we can pass them through. + mpi_vars=( $( env | grep -e MPI -e OPAL -e PMIX -e PYTHON -e debug | cut -d '=' -f1 ) ) + mpi_vars+=( "${additional_vars[@]}" ) + + # add the variables to tmux' session environment. + # we can't just export them because the process will be started as a child of tmux, not us. 
+ for var in "${mpi_vars[@]}"; do + tmux set-environment -t ${session} "${var}" "${!var}" + done + + x=( $(tmux split-window -P -F '#{pane_pid} #{pane_id}' -t ${window} "${*}") ) + pid=${x[0]} + pane=${x[1]} + + for var in "${mpi_vars[@]}"; do + tmux set-environment -t ${session} -u "${var}" + done + + # kill the dummy pane that opened the new window + [[ ${OMPI_COMM_WORLD_RANK} -eq 0 ]] && tmux kill-pane -t ${dummy} &> /dev/null + + # set the window to tiled mode. + # have to do this after every new pane is spawned because otherwise the splits get + # smaller and smaller until tmux refuses to open new panes, despite plenty of space being left. + tmux select-layout -t ${pane} tiled &> /dev/null + + # let the next process start + echo $((${OMPI_COMM_WORLD_RANK}+1)) > /tmp/tmpi.lock + + # don't exit here as mpirun needs to be kept alive and it would also exit. + while [[ -d /proc/${pid} ]]; do + sleep 1 + done +else + # we are the parent and set everything up before we start ourselves a bunch of times via mpirun. + processes=${1} + self=${0} + shift + + # create an empty new dummy window which we will later split up for the mpi processes. + x=( $(tmux new-window ${session} -P -F '#{pane_id} #{window_id} #{session_id}') ) + export dummy=${x[0]} + export window=${x[1]} + export session=${x[2]} + + # synchronize input to all panes. + tmux set-window-option -t ${window} synchronize-panes on &> /dev/null + tmux set-window-option -t ${window} remain-on-exit on &> /dev/null + + # always start with rank 0. + echo 0 > /tmp/tmpi.lock + + # re-execute ourselves to spawn off the processes. + echo mpirun -np ${processes} ${self} runmpi "${@}" + mpirun -np ${processes} ${self} runmpi "${@}" +fi diff --git a/setup.py b/setup.py index 3787619d..f4f97762 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,7 @@ with open(os.path.join(path, "README.md"), "r", encoding='utf8') as fh: setuptools.setup( name='jittor', - version='1.1.4.9', + version='1.1.5.4', # scripts=[], author="Jittor Group", author_email="ran.donglang@gmail.com", diff --git a/src/executor.cc b/src/executor.cc index 685d1018..6ee4fda3 100644 --- a/src/executor.cc +++ b/src/executor.cc @@ -9,7 +9,6 @@ #include #include #include "mem/allocator/cuda_dual_allocator.h" -#include "fetcher.h" #include "event_queue.h" #endif #include "misc/cuda_flags.h" @@ -26,6 +25,9 @@ namespace jittor { Executor exe; +// from fetch_op.cc +extern list fetcher_to_free; + void Executor::run_sync(vector vars, bool device_sync) { auto allocator = get_allocator(); this->allocator = allocator; @@ -33,22 +35,43 @@ void Executor::run_sync(vector vars, bool device_sync) { int op_num = 0; vector bfs_q; bfs_q.reserve(vars.size()); - auto nodes = (vector*)&vars; int start_var_num = 0; - for (Var* v : vars) - if (!v->is_finished()) - start_var_num++; - bfs_backward(*nodes, bfs_q, [&](Node *node) -> bool { - node->custom_data = 0; - if (node->is_finished()) - return false; - op_num += !node->is_var(); - return true; - }); + { + // get all nodes that need to be executed + auto t = ++Node::tflag_count; + for (Var* v : vars) + if (!v->is_finished() && v->tflag != t) { + v->tflag = t; + start_var_num++; + bfs_q.push_back(v); + } + for (int i=0; iis_var(); + for (auto i : node->_inputs) + if (i.node->tflag != t && !i.node->is_finished()) { + i.node->tflag = t; + bfs_q.push_back(i.node); + } + // this var has been fetched + if (node->flags.get(NodeFlags::_fetch)) { + for (auto& n : node->_outputs) { + // if not in queue and is fetch op + if (n.node->tflag != t && + !n.node->is_finished() && 
n.node->flags.get(NodeFlags::_fetch)) { + n.node->tflag = t; + bfs_q.push_back(n.node); + } + } + } + } + } auto tt = Node::tflag_count; vector ops; vector all_vars; ops.reserve(op_num); + all_vars.reserve(bfs_q.size() - op_num); for (Node* node : bfs_q) if (!node->is_var()) { node->custom_data = ops.size(); @@ -391,7 +414,6 @@ void Executor::run_sync(vector vars, bool device_sync) { outputs_bk.push_back(var); op->finish_pending_liveness(); for (Var* var : outputs_bk) - // var->finish_pending_liveness(); var->finish_pending_liveness(); } catch (const std::exception& e) { // log memory info @@ -410,6 +432,8 @@ void Executor::run_sync(vector vars, bool device_sync) { } LOGvv << "All" << op_num << "ops finished, return vars:" << vars; for (Var* v : vars) ASSERT(v->mem_ptr); + // clean fetcher free buffer + fetcher_to_free.clear(); #ifdef HAS_CUDA if (device_sync && use_cuda) { last_is_cuda = false; diff --git a/src/grad.cc b/src/grad.cc index 93334811..78fe28ec 100644 --- a/src/grad.cc +++ b/src/grad.cc @@ -27,7 +27,7 @@ VarPtr make_grad(Op* op, Var* out, Var* dout, Var* x, int x_index) { auto dx = op->grad(out, dout, x, x_index); if (x->loop_options) dx->loop_options = x->loop_options; - return move(dx); + return dx; } inline static void assign_attrs(Var* a, Var* b) { @@ -92,29 +92,30 @@ vector grad(Var* loss, vector targets) { Op* op = it.op; auto index = it.index; if (op->tflag != nt) continue; - // TODO: support two outputs backprop. - Var* out = op->outputs().back(); - Var* dout = grads[out->custom_data]; - VarPtr dvar = make_grad(op, out, dout, var, index); - registe_node_trace_grad(dvar.ptr, op, index); - if (dvar) - ASSERT(dvar->num==var->num && dvar->shape.size()==var->shape.size()) - << "dvar" << dvar << "var" << var; - if (!grad) - grad = move(dvar); - else if (dvar) { - grad = make_binary(grad, dvar, ns_add); - #ifdef PREVENT_LARGE_FUSED_OP - gsum ++; - if (gsum>=PREVENT_LARGE_FUSED_OP) { - // TODO: this is a dirty fix for - // stopping fuse lots of op together, - // try to find a better solution - grad->flags.set(NodeFlags::_stop_fuse); + for (Var* out : op->outputs()) { + if (out->tflag != nt) continue; + Var* dout = grads[out->custom_data]; + VarPtr dvar = make_grad(op, out, dout, var, index); + registe_node_trace_grad(dvar.ptr, op, index); + if (dvar) + ASSERT(dvar->num==var->num && dvar->shape.size()==var->shape.size()) + << "dvar" << dvar << "var" << var; + if (!grad) + grad = move(dvar); + else if (dvar) { + grad = make_binary(grad, dvar, ns_add); + #ifdef PREVENT_LARGE_FUSED_OP + gsum ++; + if (gsum>=PREVENT_LARGE_FUSED_OP) { + // TODO: this is a dirty fix for + // stopping fuse lots of op together, + // try to find a better solution + grad->flags.set(NodeFlags::_stop_fuse); + } + #endif + assign_attrs(grad.ptr, var); + registe_node_trace_grad(grad.ptr, var, index); } - #endif - assign_attrs(grad.ptr, var); - registe_node_trace_grad(grad.ptr, var, index); } } } diff --git a/src/init.cc b/src/init.cc index 56023e8d..9db52f14 100644 --- a/src/init.cc +++ b/src/init.cc @@ -11,6 +11,7 @@ #include "init.h" #include "ops/op_register.h" +#include "var.h" namespace jittor { @@ -21,6 +22,15 @@ unique_ptr eng; vector callbacks; int current_seed; +// fron fetch_op.cc +extern list fetcher; +extern list fetcher_to_free; + +void cleanup() { + fetcher_to_free.clear(); + fetcher.clear(); +} + static void init_cuda_devices() { #ifdef HAS_CUDA int count=0; diff --git a/src/init.h b/src/init.h index 6fb24c31..3d8a48dc 100644 --- a/src/init.h +++ b/src/init.h @@ -20,4 +20,8 @@ void 
add_set_seed_callback(set_seed_callback callback); extern "C" std::default_random_engine* get_random_engine(); +// things need to be clean before python exit +// @pyjt(cleanup) +void cleanup(); + } // jittor diff --git a/src/mem/allocator/cuda_dual_allocator.h b/src/mem/allocator/cuda_dual_allocator.h index e83e425f..c0461294 100644 --- a/src/mem/allocator/cuda_dual_allocator.h +++ b/src/mem/allocator/cuda_dual_allocator.h @@ -95,7 +95,7 @@ struct DelayFree final : Allocator { void free(void* mem_ptr, size_t size, const size_t& allocation) override { using namespace cuda_dual_local; allocations.emplace_back(mem_ptr, allocation, size, &cuda_dual_allocator); - checkCudaErrors(_cudaLaunchHostFunc(0, &to_free_allocation, 0)); + peekCudaErrors(_cudaLaunchHostFunc(0, &to_free_allocation, 0)); } void migrate_to_cpu(void*& mem_ptr, size_t& allocation, size_t size, Allocator* allocator) { diff --git a/src/misc/nano_string.cc b/src/misc/nano_string.cc index 62cc9582..3e219014 100644 --- a/src/misc/nano_string.cc +++ b/src/misc/nano_string.cc @@ -9,9 +9,6 @@ namespace jittor { #define FOR_ALL_TYPES(m) \ - m(float) \ - m(double) \ - m(int) \ m(bool) \ m(int8) \ m(int16) \ @@ -151,6 +148,10 @@ static void init_ns() { NanoString::__string_to_ns["sum"] = ns_add; NanoString::__string_to_ns["min"] = ns_minimum; NanoString::__string_to_ns["max"] = ns_maximum; + NanoString::__string_to_ns["float"] = ns_float32; + NanoString::__string_to_ns["double"] = ns_float64; + NanoString::__string_to_ns["int"] = ns_int32; + NanoString::__string_to_ns["uint"] = ns_uint32; LOGvv << "init __string_to_ns" << NanoString::__string_to_ns; LOGvv << "init __ns_to_string" << NanoString::__ns_to_string; } diff --git a/src/misc/nano_string.h b/src/misc/nano_string.h index fd4d9af8..313f69f4 100644 --- a/src/misc/nano_string.h +++ b/src/misc/nano_string.h @@ -12,9 +12,6 @@ namespace jittor { #define FOR_ALL_NS(m) \ \ m(void) \ - m(float) \ - m(double) \ - m(int) \ m(bool) \ m(int8) \ m(int16) \ diff --git a/src/node.h b/src/node.h index f71df9be..d0acccf9 100644 --- a/src/node.h +++ b/src/node.h @@ -24,7 +24,9 @@ struct NodeFlags { _finished=1, // bit2: stop grad _stop_grad=2, - _n=3, + // bit3: is fetch + _fetch=3, + _n=4, // var related flags _force_fuse=_n+0, diff --git a/src/ops/array_op.cc b/src/ops/array_op.cc index 1bc2bc7f..0132f605 100644 --- a/src/ops/array_op.cc +++ b/src/ops/array_op.cc @@ -32,9 +32,9 @@ Init() { } ~Init() { if (!get_device_count()) return; - checkCudaErrors(cudaDeviceSynchronize()); - checkCudaErrors(cudaStreamDestroy(stream)); - checkCudaErrors(cudaEventDestroy(event)); + peekCudaErrors(cudaDeviceSynchronize()); + peekCudaErrors(cudaStreamDestroy(stream)); + peekCudaErrors(cudaEventDestroy(event)); } } init; diff --git a/src/ops/array_op.h b/src/ops/array_op.h index d03c3c1d..d660cd75 100644 --- a/src/ops/array_op.h +++ b/src/ops/array_op.h @@ -20,7 +20,7 @@ struct ArrayOp : Op { Var* output; Allocation allocation; // @pybind(None) - ArrayOp(const void* ptr, NanoVector shape, NanoString dtype=ns_float); + ArrayOp(const void* ptr, NanoVector shape, NanoString dtype=ns_float32); ArrayOp(ArrayArgs&& args); template diff --git a/src/fetcher.cc b/src/ops/fetch_op.cc similarity index 60% rename from src/fetcher.cc rename to src/ops/fetch_op.cc index e5ea9531..ebf0324b 100644 --- a/src/fetcher.cc +++ b/src/ops/fetch_op.cc @@ -1,5 +1,7 @@ // *************************************************************** -// Copyright (c) 2020 Jittor. Authors: Dun Liang . All Rights Reserved. +// Copyright (c) 2020 Jittor. 
+// Authors: Dun Liang . +// All Rights Reserved. // This file is subject to the terms and conditions defined in // file 'LICENSE.txt', which is part of this source code package. // *************************************************************** @@ -12,8 +14,9 @@ #include "mem/allocator/cuda_dual_allocator.h" #include "event_queue.h" #endif -#include "fetcher.h" +#include "ops/fetch_op.h" #include "mem/allocator.h" +#include "executor.h" namespace jittor { @@ -49,31 +52,68 @@ Init() { // do not call deleter on exit for (auto& f : fetch_tasks) f.func.deleter = nullptr; - checkCudaErrors(cudaDeviceSynchronize()); - checkCudaErrors(cudaStreamDestroy(stream)); - checkCudaErrors(cudaEventDestroy(event)); + peekCudaErrors(cudaDeviceSynchronize()); + peekCudaErrors(cudaStreamDestroy(stream)); + peekCudaErrors(cudaEventDestroy(event)); } -}; +} ; } using namespace fetcher_local; #endif -void fetch(const vector& vh, FetchFunc&& func) { +list fetcher; +// this list will be freed at each execution +list fetcher_to_free; + +FetchOp::FetchOp(vector&& inputs, FetchFunc&& func) +: fetch_vars(inputs), func(move(func)) { #ifdef HAS_CUDA - static Init init; + // stream needs to be created after nccl plugin + static Init init_fetch; #endif - sync(vh); - vector allocations(vh.size()); - vector arrays(vh.size()); + VarPtr vp(0, ns_int32); + outputs_holder.emplace_back(vp); + fetcher.emplace_front(move(vp)); + fetcher_iter = fetcher.begin(); + bool all_finished = true; + for (auto v : fetch_vars) + if (!v->is_finished()) { + all_finished = false; + v->flags.set(NodeFlags::_stop_fuse); + v->flags.set(NodeFlags::_fetch); + } + flags.set(NodeFlags::_cpu); + flags.set(NodeFlags::_cuda); + flags.set(NodeFlags::_fetch); + flags.set(NodeFlags::_stop_grad); + fetcher_iter->ptr->flags.set(NodeFlags::_fetch); + // fetcher_to_free.clear(); + if (all_finished) { + // if all finished, run immediately + run(); + } + // if too many fetchers are buffered, force flush + while (fetcher.size() > 20) { + LOGvvvv << "too many fetchers(">>fetcher.size() >> + ") are buffered, force flush"; + exe.run_sync({fetcher.back().ptr}, false); + } +} + +void FetchOp::run() { + vector allocations(fetch_vars.size()); + vector arrays(fetch_vars.size()); #ifdef HAS_CUDA bool has_cuda_memcpy = false; event_queue.flush(); #endif - for (int i=0; ivar; + LOGvvvv << "fetch" << fetch_vars.size() << "vars" << fetch_vars; + int i = 0; + for (auto v : fetch_vars) { auto& allocation = allocations[i]; + #ifdef HAS_CUDA if (v->allocator->is_cuda()) { checkCudaErrors(cudaEventRecord(event, 0)); @@ -98,6 +138,7 @@ void fetch(const vector& vh, FetchFunc&& func) { arrays[i].ptr = allocation.ptr; arrays[i].shape = v->shape; arrays[i].dtype = v->dtype(); + i++; } #ifdef HAS_CUDA if (has_cuda_memcpy) { @@ -109,6 +150,8 @@ void fetch(const vector& vh, FetchFunc&& func) { FetchResult fr{move(func), move(allocations), move(arrays)}; fr.call(); } + fetcher_to_free.emplace_front(move(*fetcher_iter)); + fetcher.erase(fetcher_iter); } } // jittor diff --git a/src/fetcher.h b/src/ops/fetch_op.h similarity index 80% rename from src/fetcher.h rename to src/ops/fetch_op.h index 661dddd0..bf4b47db 100644 --- a/src/fetcher.h +++ b/src/ops/fetch_op.h @@ -5,8 +5,9 @@ // *************************************************************** #pragma once #include -#include "common.h" -#include "var_holder.h" +#include "op.h" +#include "var.h" +#include "mem/allocator.h" #include "ops/array_op.h" namespace jittor { @@ -42,7 +43,15 @@ struct FetchResult { inline void call() { 
func.callback(this); } }; -// @pyjt(fetch) -void fetch(const vector& vh, FetchFunc&& func); +struct FetchOp final : Op { + vector fetch_vars; + FetchFunc func; + list::iterator fetcher_iter; -} // jittor + FetchOp(vector&& inputs, FetchFunc&& func); + + const char* name() const override { return "fetch"; } + void run() override; +}; + +} // jittor \ No newline at end of file diff --git a/src/ops/op_utils.cc b/src/ops/op_utils.cc index 4da665cb..7a18e48a 100644 --- a/src/ops/op_utils.cc +++ b/src/ops/op_utils.cc @@ -16,7 +16,7 @@ static auto make_broadcast_to = get_op_info("broadcast_to") .get_constructor(); VarPtr make_number(float number, Var* x) { - VarPtr nums = make_array(&number, 1, ns_float); + VarPtr nums = make_array(&number, 1, ns_float32); nums = make_broadcast_to(nums, x, {}); return make_unary(nums, x->dtype()); } diff --git a/src/ops/random_op.h b/src/ops/random_op.h index bc8ade3c..438d0cb6 100644 --- a/src/ops/random_op.h +++ b/src/ops/random_op.h @@ -10,7 +10,7 @@ namespace jittor { struct RandomOp : Op { Var* output; - RandomOp(NanoVector shape, NanoString dtype=ns_float); + RandomOp(NanoVector shape, NanoString dtype=ns_float32); const char* name() const override { return "random"; } DECLARE_jit_run; diff --git a/src/ops/unary_op.cc b/src/ops/unary_op.cc index cd9c0057..9a652cee 100644 --- a/src/ops/unary_op.cc +++ b/src/ops/unary_op.cc @@ -22,9 +22,6 @@ static auto make_number = get_op_info("number") .get_constructor(); static unordered_set unary_ops = { - "float", - "double", - "int", "bool", "int8", "int16", diff --git a/src/opt/tuner/conv_tuner.cc b/src/opt/tuner/conv_tuner.cc index 7999ae33..53452c44 100644 --- a/src/opt/tuner/conv_tuner.cc +++ b/src/opt/tuner/conv_tuner.cc @@ -229,7 +229,7 @@ void ConvTuner::forwardTune(FusedOp* fop) { if (!(bop->x->input()->type()==OpType::broadcast && bop->y->input()->type()==OpType::broadcast)) return; // only support float32 currently - if (bop->z->dtype() != ns_float && bop->z->dtype() != ns_float32) + if (bop->z->dtype() != ns_float32) continue; Op* ops[3] = {op, bop->x->input(), bop->y->input()}; int ok = 0; diff --git a/src/pyjt/numpy.cc b/src/pyjt/numpy.cc index 39acf045..f236567e 100644 --- a/src/pyjt/numpy.cc +++ b/src/pyjt/numpy.cc @@ -15,6 +15,7 @@ PyObject* (*PyArray_New)(PyTypeObject *, int, npy_intp const *, int, npy_intp co PyObject* (*PyArray_FromAny)(PyObject *, PyArrayDescr_Proxy *, int, int, int, PyObject *); unsigned int (*PyArray_GetNDArrayCFeatureVersion)(); int (*PyArray_SetBaseObject)(PyObject *arr, PyObject *obj); +PyObject* (*PyArray_NewCopy)(PyObject *, int); tmp_data_t tmp_data; @@ -30,6 +31,7 @@ void numpy_init() { fill(PyArray_New, 93); fill(PyArray_GetNDArrayCFeatureVersion, 211); fill(PyArray_SetBaseObject, 282); + fill(PyArray_NewCopy, 85); ASSERT(PyArray_GetNDArrayCFeatureVersion()>=7); } diff --git a/src/pyjt/numpy.h b/src/pyjt/numpy.h index 42e27776..73d1c14d 100644 --- a/src/pyjt/numpy.h +++ b/src/pyjt/numpy.h @@ -76,12 +76,12 @@ inline int get_typenum(NanoString ns) { if (ns == ns_uint8) return 2; if (ns == ns_int16) return 3; if (ns == ns_uint16) return 4; - if (ns == ns_int32 || ns == ns_int) return 5; + if (ns == ns_int32) return 5; if (ns == ns_uint32) return 6; if (ns == ns_int64) return 7; if (ns == ns_uint64) return 8; - if (ns == ns_float32 || ns == ns_float) return 11; - if (ns == ns_float64 || ns == ns_double) return 12; + if (ns == ns_float32) return 11; + if (ns == ns_float64) return 12; LOGf << ns; return -1; } @@ -97,6 +97,8 @@ extern PyObject* (*PyArray_New)(PyTypeObject *, int, 
npy_intp const *, int, npy_ extern PyObject* (*PyArray_FromAny)(PyObject *, PyArrayDescr_Proxy *, int, int, int, PyObject *); extern unsigned int (*PyArray_GetNDArrayCFeatureVersion)(); extern int (*PyArray_SetBaseObject)(PyObject *arr, PyObject *obj); +extern PyObject* (*PyArray_NewCopy)(PyObject *, int); +#define PyArray_Copy(obj) PyArray_NewCopy(obj, 0) #define NPY_ARRAY_ALIGNED 0x0100 #define NPY_ARRAY_WRITEABLE 0x0400 diff --git a/src/pyjt/py_converter.h b/src/pyjt/py_converter.h index f874cc59..cf8e2144 100644 --- a/src/pyjt/py_converter.h +++ b/src/pyjt/py_converter.h @@ -293,21 +293,23 @@ DEF_IS(ArrayArgs, T) from_py_object(PyObject* obj) { auto ptr = GET_RAW_PTR(VarHolder, obj); return move(fetch_sync({ptr}).at(0)); } - if (Py_TYPE(obj) != PyArray_Type) { - PyObjHolder holder(PyArray_FROM_O(obj)); + // PyArray_Type + auto arr = (PyArray_Proxy*)obj; + if (Py_TYPE(obj) != PyArray_Type || !is_c_style(arr)) { + PyObjHolder holder( + Py_TYPE(obj) != PyArray_Type ? + PyArray_FROM_O(obj) : + PyArray_Copy(obj)); auto arr = (PyArray_Proxy*)holder.obj; int64 size = PyArray_Size(arr); T args; - args.ptr = arr->data; args.shape = vector(arr->dimensions, arr->dimensions+arr->nd); args.dtype = get_type_str(arr); args.buffer.reset(new char[size]); + args.ptr = (void*)args.buffer.get(); memcpy((void*)args.buffer.get(), (void*)arr->data, size); return args; } - // PyArray_Type - auto arr = (PyArray_Proxy*)obj; - CHECK(is_c_style(arr)); T args; args.ptr = arr->data; if (arr->dimensions) diff --git a/src/var_holder.cc b/src/var_holder.cc index 3cb80a38..9ad35aba 100644 --- a/src/var_holder.cc +++ b/src/var_holder.cc @@ -97,6 +97,9 @@ ArrayArgs VarHolder::fetch_sync() { return {var->mem_ptr, var->shape, var->dtype()}; } +// from fetch_op.cc +extern list fetcher; + void sync_all(bool device_sync) { vector vars; vars.reserve(VarHolder::hold_vars.size()); @@ -104,6 +107,8 @@ void sync_all(bool device_sync) { if (!v->var->_outputs.size()) vars.push_back(v->var); } + for (auto& v :fetcher) + vars.push_back(v.ptr); graph_check(); exe.run_sync(vars, device_sync); //need sync at last graph_check(); diff --git a/src/var_holder.h b/src/var_holder.h index c103378c..854b2d0a 100644 --- a/src/var_holder.h +++ b/src/var_holder.h @@ -106,7 +106,7 @@ struct VarHolder { /* detach the grad */ // @pyjt(detach) inline VarHolder* detach() { - return new VarHolder(move(jittor::detach(var))); + return new VarHolder(jittor::detach(var)); }
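
A minimal usage sketch of the variadic jt.fetch interface introduced in this patch (assuming a model, a dataset and a loss function are already defined; the names model and dataset below are illustrative and not part of the patch)::

    import jittor as jt
    from jittor import nn

    for batch_idx, (img, label) in enumerate(dataset):
        pred = model(img)
        loss = nn.cross_entropy_loss(pred, label)
        # non-Var arguments (batch_idx) are bound into the closure unchanged;
        # Var arguments (loss) are delivered to the callback as numpy arrays
        jt.fetch(batch_idx, loss,
                 lambda batch_idx, loss: print(f"#{batch_idx} loss={loss.mean():.4f}"))

    # fetch is asynchronous; force all buffered fetchers to complete before exit
    jt.sync_all(True)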