mirror of https://github.com/Jittor/Jittor
Merge branch 'master' of https://github.com/Jittor/jittor into mpi_init
This commit is contained in:
commit
989cbba133
|
@ -0,0 +1,57 @@
|
|||
# This is a basic workflow to help you get started with Actions
|
||||
|
||||
name: CI
|
||||
|
||||
# Controls when the action will run. Triggers the workflow on push or pull request
|
||||
# events but only for the master branch
|
||||
on: [ push ]
|
||||
# push:
|
||||
# branches: [ master ]
|
||||
# pull_request:
|
||||
# branches: [ master ]
|
||||
|
||||
# A workflow run is made up of one or more jobs that can run sequentially or in parallel
|
||||
jobs:
|
||||
test_clang_8_cuda_10:
|
||||
# The type of runner that the job will run on
|
||||
runs-on: self-hosted
|
||||
|
||||
# Steps represent a sequence of tasks that will be executed as part of the job
|
||||
steps:
|
||||
# Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it
|
||||
- uses: actions/checkout@v2
|
||||
|
||||
- name: test
|
||||
run: |
|
||||
export cache_name=github_${GITHUB_REF##*/}
|
||||
export cc_path="clang++-8"
|
||||
export cc_flags=" -g "
|
||||
export log_sync=0
|
||||
export log_v=0
|
||||
export PYTHONIOENCODING=utf8
|
||||
export PYTHONPATH=`pwd`/python
|
||||
export nvcc_path=/usr/local/cuda/bin/nvcc
|
||||
python3.7 -c "import jittor"
|
||||
python3.7 -m jittor.test -v
|
||||
|
||||
test_gcc:
|
||||
# The type of runner that the job will run on
|
||||
runs-on: self-hosted
|
||||
|
||||
# Steps represent a sequence of tasks that will be executed as part of the job
|
||||
steps:
|
||||
# Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it
|
||||
- uses: actions/checkout@v2
|
||||
|
||||
- name: test
|
||||
run: |
|
||||
export cache_name=github_${GITHUB_REF##*/}
|
||||
export cc_path="g++"
|
||||
export cc_flags=" -g "
|
||||
export log_sync=0
|
||||
export log_v=0
|
||||
export PYTHONIOENCODING=utf8
|
||||
export PYTHONPATH=`pwd`/python
|
||||
export nvcc_path=
|
||||
python3.7 -c "import jittor"
|
||||
python3.7 -m jittor.test -v
|
|
@ -79,6 +79,11 @@ const char *_cudaGetErrorEnum(cusolverStatus_t error);
|
|||
const char *_cudaGetErrorEnum(curandStatus_t error);
|
||||
#endif
|
||||
|
||||
#ifdef NCCL_H_
|
||||
// cuRAND API errors
|
||||
const char *_cudaGetErrorEnum(ncclResult_t error);
|
||||
#endif
|
||||
|
||||
#ifdef NV_NPPIDEFS_H
|
||||
// NPP API errors
|
||||
const char *_cudaGetErrorEnum(NppStatus error);
|
||||
|
|
|
@ -0,0 +1,109 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2019 Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include "var.h"
|
||||
#include "nccl_test_op.h"
|
||||
#include "misc/str_utils.h"
|
||||
|
||||
#include <nccl.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <helper_cuda.h>
|
||||
|
||||
#ifndef JIT
|
||||
const char *_cudaGetErrorEnum(ncclResult_t error) {
|
||||
return ncclGetErrorString(error);
|
||||
}
|
||||
#endif
|
||||
|
||||
namespace jittor {
|
||||
|
||||
#ifndef JIT
|
||||
NcclTestOp::NcclTestOp(string cmd) : cmd(cmd) {
|
||||
flags.set(NodeFlags::_cpu, 0);
|
||||
flags.set(NodeFlags::_cuda, 1);
|
||||
output = create_output(1, ns_float32);
|
||||
}
|
||||
|
||||
void NcclTestOp::jit_prepare() {
|
||||
add_jit_define("T", ns_float32);
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
#ifdef JIT_cuda
|
||||
|
||||
void NcclTestOp::jit_run() {
|
||||
auto args = split(cmd, " ");
|
||||
if (!cmd.size()) args.clear();
|
||||
vector<char*> v(args.size());
|
||||
for (uint i=0; i<args.size(); i++)
|
||||
v[i] = &args[i][0];
|
||||
output->ptr<T>()[0] = 123;
|
||||
|
||||
|
||||
|
||||
//managing 4 devices
|
||||
int nDev;
|
||||
checkCudaErrors(cudaGetDeviceCount(&nDev));
|
||||
nDev = std::min(nDev, 2);
|
||||
|
||||
ncclComm_t comms[nDev];
|
||||
int size = 32*1024*1024;
|
||||
int devs[4] = { 0, 1, 2, 3 };
|
||||
|
||||
|
||||
//allocating and initializing device buffers
|
||||
float** sendbuff = (float**)malloc(nDev * sizeof(float*));
|
||||
float** recvbuff = (float**)malloc(nDev * sizeof(float*));
|
||||
cudaStream_t* s = (cudaStream_t*)malloc(sizeof(cudaStream_t)*nDev);
|
||||
|
||||
|
||||
for (int i = 0; i < nDev; ++i) {
|
||||
checkCudaErrors(cudaSetDevice(i));
|
||||
checkCudaErrors(cudaMalloc(sendbuff + i, size * sizeof(float)));
|
||||
checkCudaErrors(cudaMalloc(recvbuff + i, size * sizeof(float)));
|
||||
checkCudaErrors(cudaMemset(sendbuff[i], 1, size * sizeof(float)));
|
||||
checkCudaErrors(cudaMemset(recvbuff[i], 0, size * sizeof(float)));
|
||||
checkCudaErrors(cudaStreamCreate(s+i));
|
||||
}
|
||||
|
||||
|
||||
//initializing NCCL
|
||||
checkCudaErrors(ncclCommInitAll(comms, nDev, devs));
|
||||
|
||||
|
||||
//calling NCCL communication API. Group API is required when using
|
||||
//multiple devices per thread
|
||||
checkCudaErrors(ncclGroupStart());
|
||||
for (int i = 0; i < nDev; ++i)
|
||||
checkCudaErrors(ncclAllReduce((const void*)sendbuff[i], (void*)recvbuff[i], size, ncclFloat, ncclSum,
|
||||
comms[i], s[i]));
|
||||
checkCudaErrors(ncclGroupEnd());
|
||||
|
||||
|
||||
//synchronizing on CUDA streams to wait for completion of NCCL operation
|
||||
for (int i = 0; i < nDev; ++i) {
|
||||
checkCudaErrors(cudaSetDevice(i));
|
||||
checkCudaErrors(cudaStreamSynchronize(s[i]));
|
||||
}
|
||||
|
||||
|
||||
//free device buffers
|
||||
for (int i = 0; i < nDev; ++i) {
|
||||
checkCudaErrors(cudaSetDevice(i));
|
||||
checkCudaErrors(cudaFree(sendbuff[i]));
|
||||
checkCudaErrors(cudaFree(recvbuff[i]));
|
||||
}
|
||||
|
||||
|
||||
//finalizing NCCL
|
||||
for(int i = 0; i < nDev; ++i)
|
||||
ncclCommDestroy(comms[i]);
|
||||
checkCudaErrors(cudaSetDevice(0));
|
||||
}
|
||||
|
||||
#endif
|
||||
#endif // JIT
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,23 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include "op.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
struct NcclTestOp : Op {
|
||||
Var* output;
|
||||
string cmd;
|
||||
|
||||
NcclTestOp(string cmd);
|
||||
|
||||
const char* name() const override { return "nccl_test"; }
|
||||
DECLARE_jit_run;
|
||||
};
|
||||
|
||||
} // jittor
|
|
@ -3,8 +3,9 @@
|
|||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import os, sys
|
||||
import os, sys, shutil
|
||||
from .compiler import *
|
||||
from jittor_utils import run_cmd
|
||||
from jittor.dataset.utils import download_url_to_local
|
||||
|
||||
def search_file(dirs, name):
|
||||
|
@ -171,10 +172,10 @@ def install_cutt(root_folder):
|
|||
true_md5 = "a6f4f7f75310a69b131e21f1ebec768a"
|
||||
|
||||
if os.path.exists(fullname):
|
||||
md5 = os.popen('md5sum ' + fullname).read().split()[0]
|
||||
md5 = run_cmd('md5sum '+fullname).split()[0]
|
||||
if md5 != true_md5:
|
||||
os.system('rm ' + fullname)
|
||||
os.system('rm -rf ' + dirname)
|
||||
os.remove(fullname)
|
||||
shutil.rmtree(dirname)
|
||||
if not os.path.isfile(os.path.join(dirname, "bin", "cutt_test")):
|
||||
LOG.i("Downloading cutt...")
|
||||
download_url_to_local(url, filename, root_folder, true_md5)
|
||||
|
@ -186,11 +187,11 @@ def install_cutt(root_folder):
|
|||
zf.extractall(path=root_folder)
|
||||
except RuntimeError as e:
|
||||
print(e)
|
||||
raise
|
||||
zf.close()
|
||||
|
||||
from jittor_utils import run_cmd
|
||||
LOG.i("installing cutt...")
|
||||
run_cmd(f"cd {dirname} && make")
|
||||
run_cmd(f"make", cwd=dirname)
|
||||
return dirname
|
||||
|
||||
def setup_cutt():
|
||||
|
@ -233,7 +234,73 @@ def setup_cutt():
|
|||
LOG.vv("Get cutt_ops: "+str(dir(cutt_ops)))
|
||||
|
||||
|
||||
def install_nccl(root_folder):
|
||||
url = "https://github.com/NVIDIA/nccl/archive/v2.6.4-1.tar.gz"
|
||||
|
||||
filename = "nccl.tgz"
|
||||
fullname = os.path.join(root_folder, filename)
|
||||
dirname = os.path.join(root_folder, "nccl-2.6.4-1")
|
||||
true_md5 = "38d7a9e98d95a99df0a4f1ad6fb50fa7"
|
||||
|
||||
if os.path.exists(fullname):
|
||||
md5 = run_cmd('md5sum '+fullname).split()[0]
|
||||
if md5 != true_md5:
|
||||
os.remove(fullname)
|
||||
if os.path.isdir(dirname):
|
||||
shutil.rmtree(dirname)
|
||||
if not os.path.isfile(os.path.join(dirname, "build", "lib", "libnccl.so")):
|
||||
LOG.i("Downloading nccl...")
|
||||
download_url_to_local(url, filename, root_folder, true_md5)
|
||||
|
||||
import tarfile
|
||||
with tarfile.open(fullname, "r") as tar:
|
||||
tar.extractall(root_folder)
|
||||
|
||||
LOG.i("installing nccl...")
|
||||
arch_flag = f" -arch={','.join(map(lambda x:'sm_'+str(x),flags.cuda_archs))} "
|
||||
run_cmd(f"make -j8 src.build CUDA_HOME='{cuda_home}' NVCC_GENCODE='{arch_flag}' ", cwd=dirname)
|
||||
return dirname
|
||||
|
||||
def setup_nccl():
|
||||
global nccl_ops, use_nccl
|
||||
use_nccl = os.environ.get("use_nccl", "1")=="1"
|
||||
nccl_ops = None
|
||||
if not has_cuda:
|
||||
use_nccl = False
|
||||
return
|
||||
if not use_nccl: return
|
||||
nccl_include_path = os.environ.get("nccl_include_path")
|
||||
nccl_lib_path = os.environ.get("nccl_lib_path")
|
||||
|
||||
if nccl_lib_path is None or nccl_include_path is None:
|
||||
LOG.v("setup nccl...")
|
||||
# nccl_path decouple with cc_path
|
||||
from pathlib import Path
|
||||
nccl_path = os.path.join(str(Path.home()), ".cache", "jittor", "nccl")
|
||||
|
||||
make_cache_dir(nccl_path)
|
||||
nccl_home = install_nccl(nccl_path)
|
||||
nccl_include_path = os.path.join(nccl_home, "build", "include")
|
||||
nccl_lib_path = os.path.join(nccl_home, "build", "lib")
|
||||
|
||||
nccl_lib_name = os.path.join(nccl_lib_path, "libnccl.so")
|
||||
assert os.path.isdir(nccl_include_path)
|
||||
assert os.path.isdir(nccl_lib_path)
|
||||
assert os.path.isfile(nccl_lib_name), nccl_lib_name
|
||||
LOG.v(f"nccl_include_path: {nccl_include_path}")
|
||||
LOG.v(f"nccl_lib_path: {nccl_lib_path}")
|
||||
LOG.v(f"nccl_lib_name: {nccl_lib_name}")
|
||||
# We do not link manualy, link in custom ops
|
||||
ctypes.CDLL(nccl_lib_name, dlopen_flags)
|
||||
|
||||
nccl_op_dir = os.path.join(jittor_path, "extern", "cuda", "nccl", "ops")
|
||||
nccl_op_files = [os.path.join(nccl_op_dir, name) for name in os.listdir(nccl_op_dir)]
|
||||
nccl_ops = compile_custom_ops(nccl_op_files,
|
||||
extra_flags=f" -I'{nccl_include_path}'")
|
||||
LOG.vv("Get nccl_ops: "+str(dir(nccl_ops)))
|
||||
|
||||
setup_cutt()
|
||||
setup_mkl()
|
||||
setup_nccl()
|
||||
|
||||
setup_cuda_extern()
|
||||
|
|
|
@ -689,8 +689,9 @@ def compile_extern():
|
|||
def check_cuda():
|
||||
if nvcc_path == "":
|
||||
return
|
||||
global cc_flags, has_cuda, core_link_flags, cuda_dir, cuda_lib, cuda_include
|
||||
global cc_flags, has_cuda, core_link_flags, cuda_dir, cuda_lib, cuda_include, cuda_home
|
||||
cuda_dir = os.path.dirname(get_full_path_of_executable(nvcc_path))
|
||||
cuda_home = os.path.abspath(os.path.join(cuda_dir, ".."))
|
||||
assert cuda_dir.endswith("bin") and "cuda" in cuda_dir.lower(), f"Wrong cuda_dir: {cuda_dir}"
|
||||
cuda_include = os.path.abspath(os.path.join(cuda_dir, "..", "include"))
|
||||
cuda_lib = os.path.abspath(os.path.join(cuda_dir, "..", "lib64"))
|
||||
|
@ -920,7 +921,11 @@ compile_extern()
|
|||
|
||||
with jit_utils.import_scope(import_flags):
|
||||
import jittor_core as core
|
||||
|
||||
flags = core.flags()
|
||||
if has_cuda:
|
||||
nvcc_flags += f" -arch={','.join(map(lambda x:'sm_'+str(x),flags.cuda_archs))} "
|
||||
|
||||
flags.cc_path = cc_path
|
||||
flags.cc_type = cc_type
|
||||
flags.cc_flags = cc_flags + link_flags + kernel_opt_flags
|
||||
|
|
|
@ -45,6 +45,7 @@ def check(jt_model, torch_model, shape, near_data):
|
|||
|
||||
@unittest.skipIf(skip_this_test, "No Torch found")
|
||||
class TestArgPoolOp(unittest.TestCase):
|
||||
@unittest.skipIf(jt.compiler.has_cuda, "No cuda found")
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
def test_cuda(self):
|
||||
jt_model = jt.nn.Sequential(Pool(2, 2, 0), Pool(2, 2, 0), Pool(2, 2, 0, ceil_mode=True), Pool(2, 2, 0), Pool(2, 2, 0), Pool(3, 1, 1))
|
||||
|
|
|
@ -133,8 +133,8 @@ def check_backward(xshape, wshape, stride, padding, dilation, use_cuda, nhwc):
|
|||
jt.sync([cy, closs, cdx, cdw])
|
||||
logs = find_log_with_re(raw_log, "(Jit op key (not )?found: " + op_name + ".*)")
|
||||
assert len(logs)==3 and "oihw" in logs[0][0], (logs)
|
||||
assert np.allclose(y.data, cy.data)
|
||||
assert np.allclose(dw.data, cdw.data), (dw.data, cdw.data)
|
||||
assert np.allclose(y.data, cy.data, 1e-3)
|
||||
assert np.allclose(dw.data, cdw.data, 1e-3), (dw.data, cdw.data)
|
||||
assert np.allclose(dx.data, cdx.data, 1e-3), (dx.data, cdx.data, np.abs(cdx.data).max(), np.abs(dx.data - cdx.data).max())
|
||||
|
||||
class TestConvTuner(unittest.TestCase):
|
||||
|
|
|
@ -123,7 +123,7 @@ class TestCudnnConvOp(unittest.TestCase):
|
|||
assert len(logs)==3 and "oihw" in logs[0][0], logs
|
||||
assert np.allclose(y.data, cy.data)
|
||||
assert np.allclose(dx.data, cdx.data, 1e-2)
|
||||
assert np.allclose(dw.data, cdw.data)
|
||||
assert np.allclose(dw.data, cdw.data, 1e-2)
|
||||
check([10,3,100,100], [5,3,3,3], stride=2, padding=0, dilation=1)
|
||||
check([10,4,40,50], [5,4,5,5], stride=1, padding=1, dilation=1)
|
||||
check([10,4,40,50], [5,4,4,4], stride=3, padding=1, dilation=1)
|
||||
|
|
|
@ -0,0 +1,18 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
# All Rights Reserved.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import jittor as jt
|
||||
import unittest
|
||||
|
||||
@unittest.skipIf(jt.compile_extern.nccl_ops is None, "no nccl found")
|
||||
class TestNccl(unittest.TestCase):
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
def test_nccl(self):
|
||||
assert jt.compile_extern.nccl_ops.nccl_test("").data == 123
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -116,7 +116,7 @@ class TestResnet(unittest.TestCase):
|
|||
assert jt.core.number_of_lived_vars() < 3500
|
||||
|
||||
jt.sync_all(True)
|
||||
assert np.mean(loss_list[-50:])<0.2
|
||||
assert np.mean(loss_list[-50:])<0.3
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
@ -31,7 +31,7 @@ class MnistNet(Module):
|
|||
x = self.layer(x)
|
||||
return x
|
||||
|
||||
@unittest.skipIf(skip_model_test, "skip_this_test")
|
||||
@unittest.skipIf(skip_model_test, "skip_this_test, model_test != 1")
|
||||
class TestVGGClass(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(self):
|
||||
|
|
|
@ -9,6 +9,7 @@
|
|||
# ***************************************************************
|
||||
import sys
|
||||
import os
|
||||
os.environ["log_silent"] = "1"
|
||||
import re
|
||||
import jittor_utils as jit_utils
|
||||
from jittor_utils import LOG
|
||||
|
|
|
@ -17,7 +17,8 @@ from ctypes import cdll
|
|||
|
||||
class LogWarper:
|
||||
def __init__(self):
|
||||
pass
|
||||
self.log_silent = int(os.environ.get("log_silent", "0"))
|
||||
self.log_v = int(os.environ.get("log_v", "0"))
|
||||
|
||||
def log_capture_start(self):
|
||||
cc.log_capture_start()
|
||||
|
@ -39,6 +40,8 @@ class LogWarper:
|
|||
if cc and hasattr(cc, "log"):
|
||||
cc.log(fileline, level, verbose, msg)
|
||||
else:
|
||||
if self.log_silent or verbose > self.log_v:
|
||||
return
|
||||
time = datetime.datetime.now().strftime("%m%d %H:%M:%S.%f")
|
||||
tid = threading.get_ident()%100
|
||||
v = f" v{verbose}" if verbose else ""
|
||||
|
@ -100,8 +103,10 @@ def try_import_jit_utils_core(silent=None):
|
|||
|
||||
def run_cmd(cmd, cwd=None, err_msg=None, print_error=True):
|
||||
LOG.v(f"Run cmd: {cmd}")
|
||||
if cwd: cmd = f"cd {cwd} && {cmd}"
|
||||
r = sp.run(cmd, shell=True, stdout=sp.PIPE, stderr=sp.STDOUT)
|
||||
if cwd:
|
||||
r = sp.run(cmd, cwd=cwd, shell=True, stdout=sp.PIPE, stderr=sp.STDOUT)
|
||||
else:
|
||||
r = sp.run(cmd, shell=True, stdout=sp.PIPE, stderr=sp.STDOUT)
|
||||
s = r.stdout.decode('utf8')
|
||||
if r.returncode != 0:
|
||||
if print_error:
|
||||
|
@ -150,7 +155,7 @@ def find_cache_path():
|
|||
cache_name = os.environ["cache_name"]
|
||||
else:
|
||||
# try to get branch name from git
|
||||
r = sp.run("git branch", cwd=os.path.dirname(__file__), stdout=sp.PIPE,
|
||||
r = sp.run(["git","branch"], cwd=os.path.dirname(__file__), stdout=sp.PIPE,
|
||||
stderr=sp.PIPE)
|
||||
assert r.returncode == 0
|
||||
bs = r.stdout.decode()
|
||||
|
|
27
src/init.cc
27
src/init.cc
|
@ -3,6 +3,10 @@
|
|||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#ifdef HAS_CUDA
|
||||
#include <cuda_runtime.h>
|
||||
#include <helper_cuda.h>
|
||||
#endif
|
||||
#include <random>
|
||||
|
||||
#include "init.h"
|
||||
|
@ -10,16 +14,39 @@
|
|||
|
||||
namespace jittor {
|
||||
|
||||
DEFINE_FLAG(vector<int>, cuda_archs, {}, "Cuda arch");
|
||||
|
||||
unique_ptr<std::default_random_engine> eng;
|
||||
|
||||
vector<set_seed_callback> callbacks;
|
||||
int current_seed;
|
||||
|
||||
static void init_cuda_devices() {
|
||||
#ifdef HAS_CUDA
|
||||
int count;
|
||||
cudaGetDeviceCount(&count);
|
||||
for (int i=0; i<count; i++) {
|
||||
cudaDeviceProp devProp;
|
||||
cudaGetDeviceProperties(&devProp, i);
|
||||
int number = devProp.major * 10 + devProp.minor;
|
||||
int found = 0;
|
||||
for (auto v : cuda_archs)
|
||||
if (v==number) {
|
||||
found = 1;
|
||||
break;
|
||||
}
|
||||
if (!found) cuda_archs.push_back(number);
|
||||
}
|
||||
LOGi << "Found cuda archs:" << cuda_archs;
|
||||
#endif
|
||||
}
|
||||
|
||||
void init() {
|
||||
// init default_random_engine
|
||||
set_seed(time(0));
|
||||
// init fused op
|
||||
op_registe({"fused","",""});
|
||||
init_cuda_devices();
|
||||
}
|
||||
|
||||
void set_seed(int seed) {
|
||||
|
|
|
@ -45,7 +45,7 @@ jit_op_entry_t load_jit_lib(string name, string symbol_name="jit_entry") {
|
|||
}
|
||||
|
||||
void run_cmd(string cmd, string cwd="") {
|
||||
if (cwd.size()) cmd = "cd "+cwd + " && " + cmd;
|
||||
if (cwd.size()) cmd = "cd '"+cwd + "' && " + cmd;
|
||||
LOGvvv << "Run cmd:" << cmd;
|
||||
system_with_check(cmd.c_str());
|
||||
}
|
||||
|
|
|
@ -162,6 +162,14 @@ std::ostream& operator<<(std::ostream& os, const set<T>& input) {
|
|||
return os << ']';
|
||||
}
|
||||
|
||||
template <class T>
|
||||
std::istream& operator>>(std::istream& is, vector<T>& out) {
|
||||
T value;
|
||||
while (is >> value)
|
||||
out.push_back(value);
|
||||
return is;
|
||||
}
|
||||
|
||||
template <class Ta, class Tb>
|
||||
std::ostream& operator<<(std::ostream& os, const map<Ta, Tb>& input) {
|
||||
os << '{';
|
||||
|
|
Loading…
Reference in New Issue