mirror of https://github.com/Jittor/Jittor
add cuda archs
This commit is contained in:
parent
2153b65856
commit
a112d0fb86
|
@ -914,7 +914,11 @@ compile_extern()
|
||||||
|
|
||||||
with jit_utils.import_scope(import_flags):
|
with jit_utils.import_scope(import_flags):
|
||||||
import jittor_core as core
|
import jittor_core as core
|
||||||
|
|
||||||
flags = core.flags()
|
flags = core.flags()
|
||||||
|
if has_cuda:
    # Emit one -gencode pair per distinct compute capability reported by
    # core (flags.cuda_archs holds values such as 61 or 70).
    # nvcc's -arch option takes a single architecture, so a comma-joined
    # "-arch=sm_61,sm_70" is rejected when GPUs of different capability are
    # installed; -gencode may be repeated once per target instead.
    # When no arch was detected nothing is appended, so nvcc falls back to
    # its default target rather than receiving a dangling "-arch=".
    cuda_arch_list = sorted(set(flags.cuda_archs))
    nvcc_flags += "".join(
        f" -gencode arch=compute_{arch},code=sm_{arch} "
        for arch in cuda_arch_list)
|
||||||
|
|
||||||
flags.cc_path = cc_path
|
flags.cc_path = cc_path
|
||||||
flags.cc_type = cc_type
|
flags.cc_type = cc_type
|
||||||
flags.cc_flags = cc_flags + link_flags + kernel_opt_flags
|
flags.cc_flags = cc_flags + link_flags + kernel_opt_flags
|
||||||
|
|
|
@ -9,6 +9,7 @@
|
||||||
# ***************************************************************
|
# ***************************************************************
|
||||||
import sys
|
import sys
|
||||||
import os
|
import os
|
||||||
|
os.environ["log_silent"] = "1"
|
||||||
import re
|
import re
|
||||||
import jittor_utils as jit_utils
|
import jittor_utils as jit_utils
|
||||||
from jittor_utils import LOG
|
from jittor_utils import LOG
|
||||||
|
|
|
@ -17,7 +17,8 @@ from ctypes import cdll
|
||||||
|
|
||||||
class LogWarper:
|
class LogWarper:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
self.log_silent = int(os.environ.get("log_silent", "0"))
|
||||||
|
self.log_v = int(os.environ.get("log_v", "0"))
|
||||||
|
|
||||||
def log_capture_start(self):
|
def log_capture_start(self):
|
||||||
cc.log_capture_start()
|
cc.log_capture_start()
|
||||||
|
@ -39,6 +40,8 @@ class LogWarper:
|
||||||
if cc and hasattr(cc, "log"):
|
if cc and hasattr(cc, "log"):
|
||||||
cc.log(fileline, level, verbose, msg)
|
cc.log(fileline, level, verbose, msg)
|
||||||
else:
|
else:
|
||||||
|
if self.log_silent or verbose > self.log_v:
|
||||||
|
return
|
||||||
time = datetime.datetime.now().strftime("%m%d %H:%M:%S.%f")
|
time = datetime.datetime.now().strftime("%m%d %H:%M:%S.%f")
|
||||||
tid = threading.get_ident()%100
|
tid = threading.get_ident()%100
|
||||||
v = f" v{verbose}" if verbose else ""
|
v = f" v{verbose}" if verbose else ""
|
||||||
|
@ -150,7 +153,7 @@ def find_cache_path():
|
||||||
cache_name = os.environ["cache_name"]
|
cache_name = os.environ["cache_name"]
|
||||||
else:
|
else:
|
||||||
# try to get branch name from git
|
# try to get branch name from git
|
||||||
r = sp.run("git branch", cwd=os.path.dirname(__file__), stdout=sp.PIPE,
|
r = sp.run(["git","branch"], cwd=os.path.dirname(__file__), stdout=sp.PIPE,
|
||||||
stderr=sp.PIPE)
|
stderr=sp.PIPE)
|
||||||
assert r.returncode == 0
|
assert r.returncode == 0
|
||||||
bs = r.stdout.decode()
|
bs = r.stdout.decode()
|
||||||
|
|
27
src/init.cc
27
src/init.cc
|
@ -3,6 +3,10 @@
|
||||||
// This file is subject to the terms and conditions defined in
|
// This file is subject to the terms and conditions defined in
|
||||||
// file 'LICENSE.txt', which is part of this source code package.
|
// file 'LICENSE.txt', which is part of this source code package.
|
||||||
// ***************************************************************
|
// ***************************************************************
|
||||||
|
#ifdef HAS_CUDA
|
||||||
|
#include <cuda_runtime.h>
|
||||||
|
#include <helper_cuda.h>
|
||||||
|
#endif
|
||||||
#include <random>
|
#include <random>
|
||||||
|
|
||||||
#include "init.h"
|
#include "init.h"
|
||||||
|
@ -10,16 +14,39 @@
|
||||||
|
|
||||||
namespace jittor {
|
namespace jittor {
|
||||||
|
|
||||||
|
DEFINE_FLAG(vector<int>, cuda_archs, {}, "Cuda arch");
|
||||||
|
|
||||||
unique_ptr<std::default_random_engine> eng;
|
unique_ptr<std::default_random_engine> eng;
|
||||||
|
|
||||||
vector<set_seed_callback> callbacks;
|
vector<set_seed_callback> callbacks;
|
||||||
int current_seed;
|
int current_seed;
|
||||||
|
|
||||||
|
// Query every visible CUDA device and record each distinct compute
// capability (encoded as major*10 + minor, e.g. 61 for 6.1) in the
// `cuda_archs` flag. Values already present in `cuda_archs` are kept and
// never duplicated. Compiles to an empty function when HAS_CUDA is undefined.
static void init_cuda_devices() {
#ifdef HAS_CUDA
    // Initialise so a failed query can never leave an indeterminate bound.
    int count = 0;
    cudaError_t err = cudaGetDeviceCount(&count);
    if (err != cudaSuccess) {
        // No usable driver/device: report it and leave cuda_archs untouched.
        LOGi << "cudaGetDeviceCount failed:" << cudaGetErrorString(err);
        // Reset the runtime's recorded error so later checks start clean.
        cudaGetLastError();
        return;
    }
    for (int i=0; i<count; i++) {
        cudaDeviceProp devProp;
        err = cudaGetDeviceProperties(&devProp, i);
        if (err != cudaSuccess) {
            // Skip this device; devProp must not be read after a failure.
            LOGi << "cudaGetDeviceProperties failed for device" << i
                 << ":" << cudaGetErrorString(err);
            cudaGetLastError();
            continue;
        }
        int number = devProp.major * 10 + devProp.minor;
        bool found = false;
        for (auto v : cuda_archs) {
            if (v == number) {
                found = true;
                break;
            }
        }
        if (!found) cuda_archs.push_back(number);
    }
    LOGi << "Found cuda archs:" << cuda_archs;
#endif
}
|
||||||
|
|
||||||
void init() {
|
void init() {
|
||||||
// init default_random_engine
|
// init default_random_engine
|
||||||
set_seed(time(0));
|
set_seed(time(0));
|
||||||
// init fused op
|
// init fused op
|
||||||
op_registe({"fused","",""});
|
op_registe({"fused","",""});
|
||||||
|
init_cuda_devices();
|
||||||
}
|
}
|
||||||
|
|
||||||
void set_seed(int seed) {
|
void set_seed(int seed) {
|
||||||
|
|
|
@ -162,6 +162,14 @@ std::ostream& operator<<(std::ostream& os, const set<T>& input) {
|
||||||
return os << ']';
|
return os << ']';
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Stream extraction for vectors: repeatedly extracts a T from `is` using T's
// own operator>> and appends it to `out`, stopping at the first extraction
// that fails. Existing elements of `out` are kept; new ones go at the end.
// The stream is returned in whatever state the final, failed extraction
// left it.
template <class T>
std::istream& operator>>(std::istream& is, vector<T>& out) {
    for (T item; is >> item; ) {
        out.push_back(item);
    }
    return is;
}
|
||||||
|
|
||||||
template <class Ta, class Tb>
|
template <class Ta, class Tb>
|
||||||
std::ostream& operator<<(std::ostream& os, const map<Ta, Tb>& input) {
|
std::ostream& operator<<(std::ostream& os, const map<Ta, Tb>& input) {
|
||||||
os << '{';
|
os << '{';
|
||||||
|
|
Loading…
Reference in New Issue