add cuda archs

This commit is contained in:
Dun Liang 2020-04-01 16:53:40 +08:00
parent 2153b65856
commit a112d0fb86
5 changed files with 45 additions and 2 deletions

View File

@ -914,7 +914,11 @@ compile_extern()
with jit_utils.import_scope(import_flags):
import jittor_core as core
flags = core.flags()
if has_cuda:
nvcc_flags += f" -arch={','.join(map(lambda x:'sm_'+str(x),flags.cuda_archs))} "
flags.cc_path = cc_path
flags.cc_type = cc_type
flags.cc_flags = cc_flags + link_flags + kernel_opt_flags

View File

@ -9,6 +9,7 @@
# ***************************************************************
import sys
import os
os.environ["log_silent"] = "1"
import re
import jittor_utils as jit_utils
from jittor_utils import LOG

View File

@ -17,7 +17,8 @@ from ctypes import cdll
class LogWarper:
def __init__(self):
pass
self.log_silent = int(os.environ.get("log_silent", "0"))
self.log_v = int(os.environ.get("log_v", "0"))
def log_capture_start(self):
cc.log_capture_start()
@ -39,6 +40,8 @@ class LogWarper:
if cc and hasattr(cc, "log"):
cc.log(fileline, level, verbose, msg)
else:
if self.log_silent or verbose > self.log_v:
return
time = datetime.datetime.now().strftime("%m%d %H:%M:%S.%f")
tid = threading.get_ident()%100
v = f" v{verbose}" if verbose else ""
@ -150,7 +153,7 @@ def find_cache_path():
cache_name = os.environ["cache_name"]
else:
# try to get branch name from git
r = sp.run("git branch", cwd=os.path.dirname(__file__), stdout=sp.PIPE,
r = sp.run(["git","branch"], cwd=os.path.dirname(__file__), stdout=sp.PIPE,
stderr=sp.PIPE)
assert r.returncode == 0
bs = r.stdout.decode()

View File

@ -3,6 +3,10 @@
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#ifdef HAS_CUDA
#include <cuda_runtime.h>
#include <helper_cuda.h>
#endif
#include <random>
#include "init.h"
@ -10,16 +14,39 @@
namespace jittor {
DEFINE_FLAG(vector<int>, cuda_archs, {}, "Cuda arch");
unique_ptr<std::default_random_engine> eng;
vector<set_seed_callback> callbacks;
int current_seed;
static void init_cuda_devices() {
#ifdef HAS_CUDA
int count;
cudaGetDeviceCount(&count);
for (int i=0; i<count; i++) {
cudaDeviceProp devProp;
cudaGetDeviceProperties(&devProp, i);
int number = devProp.major * 10 + devProp.minor;
int found = 0;
for (auto v : cuda_archs)
if (v==number) {
found = 1;
break;
}
if (!found) cuda_archs.push_back(number);
}
LOGi << "Found cuda archs:" << cuda_archs;
#endif
}
void init() {
// init default_random_engine
set_seed(time(0));
// init fused op
op_registe({"fused","",""});
init_cuda_devices();
}
void set_seed(int seed) {

View File

@ -162,6 +162,14 @@ std::ostream& operator<<(std::ostream& os, const set<T>& input) {
return os << ']';
}
template <class T>
std::istream& operator>>(std::istream& is, vector<T>& out) {
T value;
while (is >> value)
out.push_back(value);
return is;
}
template <class Ta, class Tb>
std::ostream& operator<<(std::ostream& os, const map<Ta, Tb>& input) {
os << '{';