From a112d0fb8695cefb834a01a484332ff9c9ec69e7 Mon Sep 17 00:00:00 2001 From: Dun Liang Date: Wed, 1 Apr 2020 16:53:40 +0800 Subject: [PATCH] add cuda archs --- python/jittor/compiler.py | 4 ++++ python/jittor/utils/asm_tuner.py | 1 + python/jittor_utils/__init__.py | 7 +++++-- src/init.cc | 27 +++++++++++++++++++++++++++ src/types.h | 8 ++++++++ 5 files changed, 45 insertions(+), 2 deletions(-) diff --git a/python/jittor/compiler.py b/python/jittor/compiler.py index a8ecb600..8930390e 100644 --- a/python/jittor/compiler.py +++ b/python/jittor/compiler.py @@ -914,7 +914,11 @@ compile_extern() with jit_utils.import_scope(import_flags): import jittor_core as core + flags = core.flags() +if has_cuda: + nvcc_flags += f" -arch={','.join(map(lambda x:'sm_'+str(x),flags.cuda_archs))} " + flags.cc_path = cc_path flags.cc_type = cc_type flags.cc_flags = cc_flags + link_flags + kernel_opt_flags diff --git a/python/jittor/utils/asm_tuner.py b/python/jittor/utils/asm_tuner.py index b88d016f..468fea5c 100755 --- a/python/jittor/utils/asm_tuner.py +++ b/python/jittor/utils/asm_tuner.py @@ -9,6 +9,7 @@ # *************************************************************** import sys import os +os.environ["log_silent"] = "1" import re import jittor_utils as jit_utils from jittor_utils import LOG diff --git a/python/jittor_utils/__init__.py b/python/jittor_utils/__init__.py index 72203fc2..b9fcff1b 100644 --- a/python/jittor_utils/__init__.py +++ b/python/jittor_utils/__init__.py @@ -17,7 +17,8 @@ from ctypes import cdll class LogWarper: def __init__(self): - pass + self.log_silent = int(os.environ.get("log_silent", "0")) + self.log_v = int(os.environ.get("log_v", "0")) def log_capture_start(self): cc.log_capture_start() @@ -39,6 +40,8 @@ class LogWarper: if cc and hasattr(cc, "log"): cc.log(fileline, level, verbose, msg) else: + if self.log_silent or verbose > self.log_v: + return time = datetime.datetime.now().strftime("%m%d %H:%M:%S.%f") tid = threading.get_ident()%100 v = f" v{verbose}" if verbose else "" @@ -150,7 +153,7 @@ def find_cache_path(): cache_name = os.environ["cache_name"] else: # try to get branch name from git - r = sp.run("git branch", cwd=os.path.dirname(__file__), stdout=sp.PIPE, + r = sp.run(["git","branch"], cwd=os.path.dirname(__file__), stdout=sp.PIPE, stderr=sp.PIPE) assert r.returncode == 0 bs = r.stdout.decode() diff --git a/src/init.cc b/src/init.cc index 745ad436..468b5370 100644 --- a/src/init.cc +++ b/src/init.cc @@ -3,6 +3,10 @@ // This file is subject to the terms and conditions defined in // file 'LICENSE.txt', which is part of this source code package. // *************************************************************** +#ifdef HAS_CUDA +#include +#include +#endif #include #include "init.h" @@ -10,16 +14,39 @@ namespace jittor { +DEFINE_FLAG(vector, cuda_archs, {}, "Cuda arch"); + unique_ptr eng; vector callbacks; int current_seed; +static void init_cuda_devices() { +#ifdef HAS_CUDA + int count; + cudaGetDeviceCount(&count); + for (int i=0; i& input) { return os << ']'; } +template +std::istream& operator>>(std::istream& is, vector& out) { + T value; + while (is >> value) + out.push_back(value); + return is; +} + template std::ostream& operator<<(std::ostream& os, const map& input) { os << '{';