This commit is contained in:
Dun Liang 2021-11-01 11:42:38 +08:00
parent 47c2f65749
commit 08cc8a0451
32 changed files with 34 additions and 34 deletions

View File

@ -531,7 +531,7 @@ def setup_mpi():
mpi_ops = mpi.ops
LOG.vv("Get mpi: "+str(mpi.__dict__.keys()))
LOG.vv("Get mpi_ops: "+str(mpi_ops.__dict__.keys()))
def warper(func):
def wrapper(func):
def inner(self, *args, **kw):
return func(self, *args, **kw)
inner.__doc__ = func.__doc__
@ -539,7 +539,7 @@ def setup_mpi():
for k in mpi_ops.__dict__:
if not k.startswith("mpi_"): continue
if k == "mpi_test": continue
setattr(core.Var, k, warper(mpi_ops.__dict__[k]))
setattr(core.Var, k, wrapper(mpi_ops.__dict__[k]))
if os.environ.get("FIX_TORCH_ERROR", "0") == "1":
try:

View File

@ -13,7 +13,7 @@
#include "var.h"
#include "cublas_batched_matmul_op.h"
#include "cublas_warper.h"
#include "cublas_wrapper.h"
using namespace std;

View File

@ -10,7 +10,7 @@
#include "var.h"
#include "cublas_matmul_op.h"
#include "cublas_warper.h"
#include "cublas_wrapper.h"
using namespace std;

View File

@ -7,7 +7,7 @@
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "cublas_warper.h"
#include "cublas_wrapper.h"
#include "misc/cuda_flags.h"
namespace jittor {

View File

@ -7,7 +7,7 @@
// ***************************************************************
#pragma once
#include "op.h"
#include "cudnn_warper.h"
#include "cudnn_wrapper.h"
#include "executor.h"
#include "init.h"

View File

@ -10,7 +10,7 @@
#include "mem/allocator.h"
#include "var.h"
#include "cudnn_conv3d_backward_w_op.h"
#include "cudnn_warper.h"
#include "cudnn_wrapper.h"
#include "executor.h"
#include "ops/op_register.h"

View File

@ -10,7 +10,7 @@
#include "mem/allocator.h"
#include "var.h"
#include "cudnn_conv3d_backward_x_op.h"
#include "cudnn_warper.h"
#include "cudnn_wrapper.h"
#include "executor.h"
#include "ops/op_register.h"

View File

@ -7,7 +7,7 @@
// ***************************************************************
#include "var.h"
#include "cudnn_conv3d_op.h"
#include "cudnn_warper.h"
#include "cudnn_wrapper.h"
#include "executor.h"
#include "ops/op_register.h"

View File

@ -10,7 +10,7 @@
#include "mem/allocator.h"
#include "var.h"
#include "cudnn_conv_backward_w_op.h"
#include "cudnn_warper.h"
#include "cudnn_wrapper.h"
#include "executor.h"
using namespace std;

View File

@ -10,7 +10,7 @@
#include "mem/allocator.h"
#include "var.h"
#include "cudnn_conv_backward_x_op.h"
#include "cudnn_warper.h"
#include "cudnn_wrapper.h"
#include "executor.h"
using namespace std;

View File

@ -7,7 +7,7 @@
// ***************************************************************
#include "var.h"
#include "cudnn_conv_op.h"
#include "cudnn_warper.h"
#include "cudnn_wrapper.h"
#include "executor.h"
using namespace std;

View File

@ -8,7 +8,7 @@
#include "var.h"
#include "cudnn_rnn_descriptor.h"
#include "cudnn_rnn_backward_x_op.h"
#include "cudnn_warper.h"
#include "cudnn_wrapper.h"
#include "executor.h"
#include "ops/op_register.h"

View File

@ -8,7 +8,7 @@
#include "var.h"
#include "cudnn_rnn_descriptor.h"
#include "cudnn_rnn_op.h"
#include "cudnn_warper.h"
#include "cudnn_wrapper.h"
#include "executor.h"
#include "ops/op_register.h"

View File

@ -4,7 +4,7 @@
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "cudnn_warper.h"
#include "cudnn_wrapper.h"
#include "misc/cuda_flags.h"
namespace jittor {

View File

@ -12,7 +12,7 @@
#include <curand.h>
#include "helper_cuda.h"
#include "curand_random_op.h"
#include "curand_warper.h"
#include "curand_wrapper.h"
namespace jittor {

View File

@ -7,7 +7,7 @@
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "curand_warper.h"
#include "curand_wrapper.h"
#include "init.h"
#include "misc/cuda_flags.h"

View File

@ -7,7 +7,7 @@
#include "cutt_transpose_op.h"
#include "ops/op_register.h"
#include "cutt.h"
#include "cutt_warper.h"
#include "cutt_wrapper.h"
#include "misc/stack_vector.h"
#include "helper_cuda.h"

View File

@ -6,7 +6,7 @@
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "cutt_warper.h"
#include "cutt_wrapper.h"
namespace jittor {

View File

@ -8,7 +8,7 @@
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "mpi_warper.h"
#include "mpi_wrapper.h"
#include <cuda_runtime.h>
#include <nccl.h>

View File

@ -13,7 +13,7 @@
#include <nccl.h>
#include <cuda_runtime.h>
#include "helper_cuda.h"
#include "nccl_warper.h"
#include "nccl_wrapper.h"
#include "ops/op_register.h"
namespace jittor {

View File

@ -13,7 +13,7 @@
#include <nccl.h>
#include <cuda_runtime.h>
#include "helper_cuda.h"
#include "nccl_warper.h"
#include "nccl_wrapper.h"
#include "ops/op_register.h"
namespace jittor {

View File

@ -13,7 +13,7 @@
#include <nccl.h>
#include <cuda_runtime.h>
#include "helper_cuda.h"
#include "nccl_warper.h"
#include "nccl_wrapper.h"
#include "ops/op_register.h"
namespace jittor {

View File

@ -7,7 +7,7 @@
#include "nccl_test_op.h"
#include "utils/str_utils.h"
#include "nccl_warper.h"
#include "nccl_wrapper.h"
namespace jittor {

View File

@ -8,7 +8,7 @@
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "misc/cuda_flags.h"
#include "nccl_warper.h"
#include "nccl_wrapper.h"
#include "event_queue.h"
const char *_cudaGetErrorEnum(ncclResult_t error) {

View File

@ -6,7 +6,7 @@
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "mpi_warper.h"
#include "mpi_wrapper.h"
#include "var.h"
#include "mpi_all_reduce_op.h"
#include "ops/op_register.h"

View File

@ -6,7 +6,7 @@
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "mpi_warper.h"
#include "mpi_wrapper.h"
#include "var.h"
#include "mpi_broadcast_op.h"
#include "ops/op_register.h"

View File

@ -6,7 +6,7 @@
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "mpi_warper.h"
#include "mpi_wrapper.h"
#include "var.h"
#include "mpi_reduce_op.h"
#include "ops/op_register.h"

View File

@ -3,7 +3,7 @@
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "mpi_warper.h"
#include "mpi_wrapper.h"
#include "var.h"
#include "mpi_test_op.h"

View File

@ -11,7 +11,7 @@
#include <stdint.h>
#include <stdio.h>
#include "mpi_warper.h"
#include "mpi_wrapper.h"
#include "common.h"
#include "ops/array_op.h"

View File

@ -937,7 +937,7 @@ Output::
print(out)
return tree, out
def python_pass_warper(mod_func, args, kw):
def python_pass_wrapper(mod_func, args, kw):
import importlib
mod, func = mod_func.rsplit(".", 1)
mod = importlib.import_module(mod)

View File

@ -14,7 +14,7 @@ namespace jittor {
string py_caller(const string& mod_func, const vector<string>& args, const map<string,string>& kw) {
PyObjHolder mod(PyImport_ImportModule("jittor"));
PyObjHolder func(PyObject_GetAttrString(mod.obj, "python_pass_warper"));
PyObjHolder func(PyObject_GetAttrString(mod.obj, "python_pass_wrapper"));
PyObjHolder py_name(to_py_object<string>(mod_func));
PyObjHolder py_args(to_py_tuple(args));
PyObjHolder py_kw(to_py_object(kw));

View File

@ -23,7 +23,7 @@ import urllib.request
if platform.system() == 'Darwin':
mp.set_start_method('fork')
class LogWarper:
class Logwrapper:
def __init__(self):
self.log_silent = int(os.environ.get("log_silent", "0"))
self.log_v = int(os.environ.get("log_v", "0"))
@ -482,7 +482,7 @@ def get_total_mem():
is_in_ipynb = in_ipynb()
cc = None
LOG = LogWarper()
LOG = Logwrapper()
check_msvc_install = False
msvc_path = ""