mirror of https://github.com/Jittor/Jittor
add cuda archs
This commit is contained in:
parent
2153b65856
commit
a112d0fb86
|
@ -914,7 +914,11 @@ compile_extern()
|
|||
|
||||
with jit_utils.import_scope(import_flags):
|
||||
import jittor_core as core
|
||||
|
||||
flags = core.flags()
|
||||
if has_cuda:
|
||||
nvcc_flags += f" -arch={','.join(map(lambda x:'sm_'+str(x),flags.cuda_archs))} "
|
||||
|
||||
flags.cc_path = cc_path
|
||||
flags.cc_type = cc_type
|
||||
flags.cc_flags = cc_flags + link_flags + kernel_opt_flags
|
||||
|
|
|
@ -9,6 +9,7 @@
|
|||
# ***************************************************************
|
||||
import sys
|
||||
import os
|
||||
os.environ["log_silent"] = "1"
|
||||
import re
|
||||
import jittor_utils as jit_utils
|
||||
from jittor_utils import LOG
|
||||
|
|
|
@ -17,7 +17,8 @@ from ctypes import cdll
|
|||
|
||||
class LogWarper:
|
||||
def __init__(self):
|
||||
pass
|
||||
self.log_silent = int(os.environ.get("log_silent", "0"))
|
||||
self.log_v = int(os.environ.get("log_v", "0"))
|
||||
|
||||
def log_capture_start(self):
|
||||
cc.log_capture_start()
|
||||
|
@ -39,6 +40,8 @@ class LogWarper:
|
|||
if cc and hasattr(cc, "log"):
|
||||
cc.log(fileline, level, verbose, msg)
|
||||
else:
|
||||
if self.log_silent or verbose > self.log_v:
|
||||
return
|
||||
time = datetime.datetime.now().strftime("%m%d %H:%M:%S.%f")
|
||||
tid = threading.get_ident()%100
|
||||
v = f" v{verbose}" if verbose else ""
|
||||
|
@ -150,7 +153,7 @@ def find_cache_path():
|
|||
cache_name = os.environ["cache_name"]
|
||||
else:
|
||||
# try to get branch name from git
|
||||
r = sp.run("git branch", cwd=os.path.dirname(__file__), stdout=sp.PIPE,
|
||||
r = sp.run(["git","branch"], cwd=os.path.dirname(__file__), stdout=sp.PIPE,
|
||||
stderr=sp.PIPE)
|
||||
assert r.returncode == 0
|
||||
bs = r.stdout.decode()
|
||||
|
|
27
src/init.cc
27
src/init.cc
|
@ -3,6 +3,10 @@
|
|||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#ifdef HAS_CUDA
|
||||
#include <cuda_runtime.h>
|
||||
#include <helper_cuda.h>
|
||||
#endif
|
||||
#include <random>
|
||||
|
||||
#include "init.h"
|
||||
|
@ -10,16 +14,39 @@
|
|||
|
||||
namespace jittor {
|
||||
|
||||
DEFINE_FLAG(vector<int>, cuda_archs, {}, "Cuda arch");
|
||||
|
||||
unique_ptr<std::default_random_engine> eng;
|
||||
|
||||
vector<set_seed_callback> callbacks;
|
||||
int current_seed;
|
||||
|
||||
static void init_cuda_devices() {
|
||||
#ifdef HAS_CUDA
|
||||
int count;
|
||||
cudaGetDeviceCount(&count);
|
||||
for (int i=0; i<count; i++) {
|
||||
cudaDeviceProp devProp;
|
||||
cudaGetDeviceProperties(&devProp, i);
|
||||
int number = devProp.major * 10 + devProp.minor;
|
||||
int found = 0;
|
||||
for (auto v : cuda_archs)
|
||||
if (v==number) {
|
||||
found = 1;
|
||||
break;
|
||||
}
|
||||
if (!found) cuda_archs.push_back(number);
|
||||
}
|
||||
LOGi << "Found cuda archs:" << cuda_archs;
|
||||
#endif
|
||||
}
|
||||
|
||||
void init() {
|
||||
// init default_random_engine
|
||||
set_seed(time(0));
|
||||
// init fused op
|
||||
op_registe({"fused","",""});
|
||||
init_cuda_devices();
|
||||
}
|
||||
|
||||
void set_seed(int seed) {
|
||||
|
|
|
@ -162,6 +162,14 @@ std::ostream& operator<<(std::ostream& os, const set<T>& input) {
|
|||
return os << ']';
|
||||
}
|
||||
|
||||
template <class T>
|
||||
std::istream& operator>>(std::istream& is, vector<T>& out) {
|
||||
T value;
|
||||
while (is >> value)
|
||||
out.push_back(value);
|
||||
return is;
|
||||
}
|
||||
|
||||
template <class Ta, class Tb>
|
||||
std::ostream& operator<<(std::ostream& os, const map<Ta, Tb>& input) {
|
||||
os << '{';
|
||||
|
|
Loading…
Reference in New Issue