fix fused extras

This commit is contained in:
Dun Liang 2020-05-09 13:12:40 +08:00
parent 503da59a3b
commit 6d62d5da22
8 changed files with 89 additions and 7 deletions

View File

@ -105,7 +105,7 @@ template <typename T>
void check(T result, char const *const func, const char *const file,
int const line) {
if (result) {
DEVICE_RESET
// DEVICE_RESET
LOGf << "CUDA error at" << file >> ":" >> line << " code="
>> static_cast<unsigned int>(result) >> "(" << _cudaGetErrorEnum(result) << ")"
<< func;
@ -125,7 +125,7 @@ inline void __getLastCudaError(const char *errorMessage, const char *file,
cudaError_t err = cudaGetLastError();
if (cudaSuccess != err) {
DEVICE_RESET
// DEVICE_RESET
LOGf << "CUDA error at" << file >> ":" >> line << " code="
>> static_cast<unsigned int>(err) >> "(" << _cudaGetErrorEnum(err) << ")"
<< errorMessage;
@ -141,7 +141,7 @@ inline void __printLastCudaError(const char *errorMessage, const char *file,
cudaError_t err = cudaGetLastError();
if (cudaSuccess != err) {
DEVICE_RESET
// DEVICE_RESET
LOGf << "CUDA error at" << file >> ":" >> line << " code="
>> static_cast<unsigned int>(err) >> "(" << _cudaGetErrorEnum(err) << ")"
<< errorMessage;

View File

@ -449,6 +449,13 @@ def fetch_var(var, func, *args, **kw):
Var.fetch = fetch_var
del fetch_var
def display_memory_info():
    """Call ``core.display_memory_info`` tagged with the caller's location.

    The tag has the form ``"<file basename>:<line number>"`` and identifies
    the frame that called this function. If the interpreter cannot provide
    stack frames, an empty tag is passed instead.
    """
    import inspect, os
    frame = inspect.currentframe()
    caller = None
    try:
        # currentframe() returns None on interpreters without Python stack
        # frame support; guard it instead of failing on `.f_back`.
        if frame is not None:
            caller = frame.f_back
        if caller is None:
            fileline = ""
        else:
            info = inspect.getframeinfo(caller)
            fileline = f"{os.path.basename(info.filename)}:{info.lineno}"
    finally:
        # A frame kept in its own locals forms a reference cycle; drop the
        # references explicitly, as the inspect documentation recommends.
        del frame, caller
    core.display_memory_info(fileline)
def import_vars(data):
''' Load variables into current scopes
example:

View File

@ -380,6 +380,12 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
// var->finish_pending_liveness();
var->finish_pending_liveness();
} catch (const std::exception& e) {
// log memory info
display_memory_info(__FILELINE__);
// log jit_key and file location
op->do_prepare();
string jit_src_path = Op::get_filename_from_jit_key(jk.to_cstring(), ".cc");
LOGe << "[Error] source file location:" << jit_src_path;
if (is_fused_op) {
LOGf << "Execute fused operator(" >> rid >> '/' >> queue.size() >> ")"
<< "failed:" << fused_op.ops << "\n\nReason: " >> e.what();

View File

@ -4,7 +4,13 @@
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include <typeinfo>
#include <iomanip>
#include "var.h"
#include "op.h"
#include "var_holder.h"
#include "misc/cuda_flags.h"
#include "mem/allocator/aligned_allocator.h"
#ifdef HAS_CUDA
#include "mem/allocator/cuda_managed_allocator.h"
@ -84,5 +90,55 @@ Allocator* get_allocator() {
// Invoke gc() on every allocator stored in `allocators`.
void gc_all() {
    for (auto& entry : allocators) {
        auto& allocator = entry.second;
        allocator->gc();
    }
}
/**
 * A number to be printed in scaled, human-readable form (e.g. 2048 -> "2KB").
 *
 * Streaming a FloatOutput divides `value` by `base` repeatedly, advancing one
 * character through `scale` per division, until the value drops below `base`
 * or the last character of `scale` is reached. The reduced number is written
 * right-aligned, followed by the selected scale character and `suffix`.
 */
struct FloatOutput {
    double value;        // quantity to print
    std::string scale;   // one character per magnitude step, e.g. " KMG"
    int base;            // ratio between consecutive steps, e.g. 1024 or 1000
    std::string suffix;  // unit text appended after the scale character
    int p=4;             // stream precision used while printing the number
};

/**
 * Write `o` to `os` in scaled form.
 *
 * The numeric part is padded to a field of (8 - 2 - suffix length) characters,
 * never less than zero. The stream's precision is restored before returning,
 * so formatting a FloatOutput does not affect later output on the same stream.
 *
 * @param os  destination stream
 * @param o   value and scaling rules
 * @return    `os`, to allow chaining
 */
std::ostream& operator<<(std::ostream& os, const FloatOutput& o) {
    const int w = 8;
    // Signed arithmetic: a suffix longer than w-2 must yield "no padding",
    // not a wrapped-around unsigned width.
    int width = w - 2 - static_cast<int>(o.suffix.size());
    if (width < 0) width = 0;

    // Pick the scale step: stop at the first step where the value fits below
    // `base`, or at the last available scale character.
    size_t i = 0;
    double k = o.value;
    for (; i+1 < o.scale.size(); i++) {
        if (k < o.base) break;
        k /= o.base;
    }

    // precision(n) installs n and returns the previous setting, which is put
    // back once the number has been written.
    const auto saved_precision = os.precision(o.p);
    os << std::setw(width) << k;
    os.precision(saved_precision);

    // With an empty `scale` there is no prefix character to print.
    if (!o.scale.empty()) os << o.scale[i];
    return os << o.suffix;
}
void display_memory_info(const char* fileline) {
int p = 2;
Log log(fileline, 'i', 0);
log << "\n=== display_memory_info ===\n";
log << "hold_vars:" << VarHolder::hold_vars.size()
<< "lived_vars:" << Var::number_of_lived_vars
<< "lived_ops:" << Op::number_of_lived_ops >> '\n';
if (use_stat_allocator) {
log << "stat:" << use_stat_allocator;
log << "total alloc:" << FloatOutput{(double)(stat_allocator_total_alloc_byte
- stat_allocator_total_free_byte), " KMG", 1024, "B"};
log << "total alloc call:" << FloatOutput{(double)(stat_allocator_total_alloc_call
- stat_allocator_total_free_call), " KMG", 1000, ""} >> '\n';
}
for (auto& a : SFRLAllocator::sfrl_allocators) {
auto total = a->used_memory + a->unused_memory;
log << "name:" << a->name() << "is_cuda:" << a->is_cuda()
<< "used:" << FloatOutput{(double)a->used_memory, " KMG", 1024, "B"}
>> "(" >> std::setprecision(p) >> a->used_memory*100.0 / total >> "%)"
<< "unused:" << FloatOutput{(double)a->unused_memory, " KMG", 1024, "B"}
>> "(" >> std::setprecision(p) >> a->unused_memory*100.0 / total >> "%)"
<< "total:" << FloatOutput{(double)total, " KMG", 1024, "B"} >> "\n";
}
log >> "===========================\n";
log.end();
}
} // jittor

View File

@ -51,4 +51,7 @@ Allocator* get_allocator();
// @pyjt(gc)
void gc_all();
/**
 * Write a memory-usage report to the log.
 * @param fileline  tag for the log entry; defaults to an empty string.
 */
// @pyjt(display_memory_info)
void display_memory_info(const char* fileline="");
} // jittor

View File

@ -19,5 +19,9 @@ struct StatAllocator : Allocator {
};
DECLARE_FLAG(int, use_stat_allocator);
DECLARE_FLAG(size_t, stat_allocator_total_alloc_call);
DECLARE_FLAG(size_t, stat_allocator_total_alloc_byte);
DECLARE_FLAG(size_t, stat_allocator_total_free_call);
DECLARE_FLAG(size_t, stat_allocator_total_free_byte);
} // jittor

View File

@ -758,7 +758,7 @@ string OpCompiler::__get_fused_src(
"for", "const", "auto", "get_random_engine",
"int", "float", "bool", "CHECK", "STRINGIZE",
"void", "__restrict__", "if", "true", "false",
"Op", "Var", "Node", "itof"
"Op", "Var", "Node", "itof", "assert", "ASSERT"
};
auto not_change = [&](const string& s) -> bool {
if (unchanged.count(s)) return true;
@ -914,7 +914,9 @@ string OpCompiler::__get_fused_src(
fused_kernel = fused_kernel_args + "\n" + fused_kernel;
LOGvvvv << "Fused kernel:\n" >> fused_kernel;
auto fused_src = fused_begin + fused_includes + "\n#include \"fused_op.h\"\n" +
auto fused_src = fused_begin + fused_includes +
"\n#include <assert.h>\n" +
"\n#include \"fused_op.h\"\n" +
fused_defines + '\n' +
"void jittor::FusedOp::jit_run() {\n" + fused_kernel + "\n}\n";

View File

@ -170,6 +170,10 @@ void LoopVarAnalyzePass::run() {
for (uint j=0; j<ndim; j++)
if (!(mask>>j&1) && j<loop_var_names.size()) {
for (auto& vname : vnames) {
// cannot replace extras shape
// TODO: optimize it
if (vname.find("extras") != string::npos)
continue;
// replace op{i}_{vname}shape{j} -> {loop_var_names[j]}
std::stringstream name1;
name1 << vname<<"shape"<<j;
@ -193,7 +197,7 @@ void LoopVarAnalyzePass::run() {
replace_vars.emplace_back(name1, name2);
}
LOGvvvv << "replace_vars" << replace_vars;
LOGvvv << "replace_vars" << replace_vars;
ir->replace(replace_vars);
LOGvvvv << "KernelIR after replace\n" >> ir->to_string(0, true);
// move define