// mirror of https://github.com/Jittor/Jittor
// ***************************************************************
// Copyright (c) 2021 Jittor. All Rights Reserved.
// Maintainers: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include <limits>
|
|
|
|
#include "node.h"
|
|
#include "op.h"
|
|
#include "var.h"
|
|
#include "op_compiler.h"
|
|
#include "profiler/profiler.h"
|
|
#include "mem/allocator.h"
|
|
#include "misc/cuda_flags.h"
|
|
#include "pybind/py_var_tracer.h"
|
|
|
|
namespace jittor {
|
|
|
|
// Cache directory root; defined elsewhere, used by get_filename_from_jit_key.
DECLARE_FLAG(string, cache_path);

DEFINE_FLAG(int, try_use_32bit_index, 0,
    "If not overflow, try to use 32 bit type as index type.");

// Cache: jit key -> compiled jit entry point (see Op::jit_run).
string_view_map<jit_op_entry_t> jit_ops;
// Maps the pre-compile jit key to the post-compile key (see Op::jit_run).
string_view_map<string> jit_key_mapper;

// Global count of currently alive Op objects (maintained by ctor/dtor).
int64 Op::number_of_lived_ops = 0;
|
|
|
// Construct a base operator: mark the node as an op (not a var), enable
// the cpu backend by default, and seed the precision-preference bits from
// the global amp_reg setting.
Op::Op() {
    ++number_of_lived_ops;
    flags.set(NodeFlags::_var, 0);
    flags.set(NodeFlags::_cpu, 1);
    // copy the low 3 bits of amp_reg into the flag field at _prefer_32
    flags.flags |= ((amp_reg & 7) << NodeFlags::_prefer_32);
    // register with the python-level tracer when tracing is enabled
    if (PREDICT_BRANCH_NOT_TAKEN(trace_py_var))
        trace_data.record_node(this);
}
|
|
|
|
// Destructor only maintains the global live-op counter.
Op::~Op() {
    --number_of_lived_ops;
}
|
|
|
|
// Forward-mode shortcut: mark this op as forwarded and hold the input
// var in outputs_holder as if it were an output.
void Op::forward(Var* input) {
    flags.set(NodeFlags::_forwarded);
    outputs_holder.emplace_back(input);
}
|
|
|
|
// Clone this op; the base class cannot duplicate, so it returns a null
// VarPtr (callers must treat null as "duplication not supported").
VarPtr Op::duplicate() {
    return nullptr;
}
|
|
|
|
// Gradient of output `out` (with upstream gradient `dout`) w.r.t. input
// `v`, the v_index-th input. Base implementation has no gradient: it
// warns and returns null, which downstream code treats as zeros.
VarPtr Op::grad(Var* out, Var* dout, Var* v, int v_index) {
    LOGw << "Grad of" << name() << "return zeros";
    return nullptr;
}
|
|
|
|
// Batched gradient hook: fill `dins` from upstream gradients `douts`.
// Base implementation warns and leaves `dins` untouched (null entries,
// i.e. zero gradients).
void Op::grads(Var** douts, VarPtr* dins) {
    LOGw << "Grads of" << name() << "return zeros";
}
|
|
|
|
// Allocate a fresh output Var with the given shape/dtype, keep ownership
// in outputs_holder, and hand the raw pointer back to the caller.
Var* Op::create_output(NanoVector shape, NanoString dtype) {
    VarPtr holder(shape, dtype);
    Var* raw = holder.ptr;
    outputs_holder.emplace_back(move(holder));
    return raw;
}
|
|
|
|
void Op::init() {
|
|
bool has_vary_input = 0;
|
|
for (Var* v : inputs())
|
|
if (v->num < 0) {
|
|
has_vary_input = 1;
|
|
break;
|
|
}
|
|
flags.set(NodeFlags::_has_vary_input, has_vary_input);
|
|
infer_shape();
|
|
}
|
|
|
|
// Hook for op-specific rewriting of the generated jit source; default no-op.
void Op::compile_optimize(string& src) {}

// Default shape inference: no-op (ops producing outputs override this).
void Op::infer_shape() {}
// Default non-jit execution path: no-op.
void Op::run() {}
// Default jit-key contribution: none; an empty key marks the op as a
// non-jit op (see do_jit_prepare).
void Op::jit_prepare(JK& jk) {}
// Hook for graph-level rewriting before execution; default no-op.
void Op::graph_optimize() {}
|
|
|
// Extended op name: "name", or "name.ns" when a namespace suffix is set.
string Op::name_ex() const {
    string full = name();
    if (ns != ns_void) {
        full += '.';
        full += ns.to_cstring();
    }
    return full;
}
|
|
|
|
// Reset the jit-key buffer, fill it for this op, and return the key as
// an owned string.
string Op::get_jit_key(JK& jk) {
    jk.clear();
    do_jit_prepare(jk);
    return jk.to_string();
}
|
|
|
|
// Parse this op's jit key into (name, value) define pairs for the compiler.
vector<pair<string,string>> Op::get_jit_define() {
    return parse_jit_keys(get_jit_key(get_jk()));
}
|
|
|
|
// Hash the prepared jit key into a short hex string, used as a stable
// identifier for this op.
string Op::get_hash_name() {
    JK& jk = get_jk();
    do_prepare(jk);
    std::stringstream ss;
    ss << std::hex << std::hash<string>()(jk.to_string());
    return ss.str();
}
|
|
|
|
// Build the complete jit key for this op into `jk`, and as a side effect
// decide the backend (cpu vs cuda flags) and index width (int32/int64).
// A key left empty by jit_prepare() marks a non-jit op.
void Op::do_jit_prepare(JK& jk) {
    memcheck_all_exist();
    jk << name();
    jit_prepare(jk);
    if (jk.empty()) {
        // not a jit op
        bool has_cuda = flags.get(NodeFlags::_cuda);
        bool has_cpu = flags.get(NodeFlags::_cpu);
        CHECK(has_cuda || has_cpu);
        // cuda available but globally disabled: run on cpu
        if (has_cuda && has_cpu && !use_cuda)
            flags.set(NodeFlags::_cuda, 0);
    } else {
        // check use int64_t as index_t if array is too big
        // NOTE(review): in_id/out_id are only referenced by the
        // commented-out key emission below.
        int in_id=0, out_id=0;
        bool use_int64_t = false;
        // TODO: fused op do not have inputs,
        // check use_cuda_op from outputs may not be enough
        bool use_cuda_op = use_cuda;
        for (Var* var : inputs()) {
            if (var->mem_ptr) {
                /* jit key don't include here, because
                parallel compiler don't known
                jk << JK::key << "alloc_i" << JK::hex1(in_id)
                    << JK::hex1(var->allocator->flags()) << JK::end;
                */
                // stay on the cuda path only if every allocated input
                // lives in cuda memory
                use_cuda_op &= var->allocator->is_cuda();
            }
            // element count no longer fits int32: need 64-bit indices
            if (var->num >= std::numeric_limits<int32_t>::max())
                use_int64_t = true;
            in_id ++;
        }
        for (Var* var : outputs()) {
            if (var->mem_ptr) {
                /*
                jk << JK::key << "alloc_o" << JK::hex1(in_id)
                    << JK::hex1(var->allocator->flags()) << JK::end;
                */
                use_cuda_op &= var->allocator->is_cuda();
            }
            if (var->num >= std::numeric_limits<int32_t>::max())
                use_int64_t = true;
            out_id ++;
        }
        jk << _CS("[JIT:1]");
        if (use_cuda_op && flags.get(NodeFlags::_cuda)) {
            jk << _CS("[JIT_cuda:1]");
            flags.set(NodeFlags::_cpu, 0);
            // TODO: 64bit index in CUDA
            use_int64_t = false;
        } else {
            // use_cuda==2 appears to mean cuda is mandatory: fail loudly
            // instead of silently falling back to cpu
            if (use_cuda==2) {
                if (flags.get(NodeFlags::_cuda))
                    LOGf << "Op" << name() >> "'s vars are not allocated in cuda";
                else
                    LOGf << "Op" << name() << "doesn't have cuda version";
            }
            ASSERT(flags.get(NodeFlags::_cpu))
                << "Op" << name() << "doesn't have cpu version";
            jk << _CS("[JIT_cpu:1]");
            flags.set(NodeFlags::_cuda, 0);
        }
        // NOTE(review): when set, try_use_32bit_index forces 32-bit indices
        // even for oversized vars -- presumably the user guarantees no
        // overflow; confirm against the flag's documentation.
        if (try_use_32bit_index) use_int64_t = false;
        if (use_int64_t)
            jk << _CS("[index_t:int64]");
        else
            jk << _CS("[index_t:int32]");
    }
    jk.finilize();
}
|
|
|
|
// Reset the jit-key buffer and run the full jit preparation for this op.
void Op::do_prepare(JK& jk){
    jk.clear();
    do_jit_prepare(jk);
}
|
|
|
|
// Dispatch after key preparation: an empty key means the op has no jit
// kernel, so fall back to the plain run() implementation.
void Op::do_run_after_prepare(JK& jk) {
    if (jk.empty())
        run();
    else
        jit_run(jk);
}
|
|
|
|
// Full execution entry: prepare the jit key, then run via the jit kernel
// or the plain run() path.
void Op::do_run() {
    JK& jk = get_jk();
    do_prepare(jk);
    do_run_after_prepare(jk);
}
|
|
|
|
// Map a jit key to a filesystem-safe source path under <cache_path>/jit/.
// Overlong keys are truncated, a hash suffix keeps names unique, and any
// character outside [A-Za-z0-9] is replaced with '_'.
string Op::get_filename_from_jit_key(const string& jit_key, const string& suffix) {
    // prefer the post-compile key when this key has been remapped
    auto iter = jit_key_mapper.find(jit_key);
    string key = iter==jit_key_mapper.end() ? jit_key : iter->second;
    auto h = std::hash<string>()(key);
    std::stringstream ss;
    if (key.size() > 100)
        ss << key.substr(0, 90) << "...hash_" << std::hex << h;
    else
        ss << key << "_hash_" << std::hex << h;
    string safe = ss.str();
    for (char& c : safe) {
        bool ok = (c>='a' && c<='z') || (c>='A' && c<='Z') || (c>='0' && c<='9');
        if (!ok) c = '_';
    }
#ifndef _WIN32
    string filename = cache_path + "/jit/";
#else
    string filename = cache_path + "\\jit\\";
#endif
    filename += safe;
    filename += "_op";
    filename += suffix;
    return filename;
}
|
|
|
|
// convert xxx.yyy -> xxx
|
|
// Strip an op-name suffix: "xxx.yyy" -> "xxx"; names without '.' pass
// through unchanged.
string Op::op_name_to_file_name(const string& s) {
    auto dot = s.find('.');
    if (dot == string::npos)
        return s;
    return s.substr(0, dot);
}
|
|
// convert xxx_xxx -> XxxXxx
|
|
// Convert a snake_case file name to a CamelCase class name ("xxx_xxx" ->
// "XxxXxx"): underscores are dropped and the character following an
// underscore (or the first character) is upper-cased.
// Fix: apply the 'a'..'z' -> 'A'..'Z' shift only when the character is
// actually a lowercase letter; the previous unconditional `c-'a'+'A'`
// produced garbage bytes for segments starting with a digit or an
// already-uppercase letter (e.g. "conv_2d").
string Op::file_name_to_class_name(const string& s) {
    char prev = '_';
    string res;
    res.reserve(s.size());
    for (char c : s) {
        if (c != '_') {
            if (prev == '_' && c >= 'a' && c <= 'z')
                res += c-'a'+'A';
            else
                res += c;
        }
        prev = c;
    }
    return res;
}
|
|
|
|
// Execute the jit kernel for the key prepared in `jk`, compiling on a
// cache miss. After compilation the key is recomputed; both the
// pre-compile and post-compile keys are cached so later lookups hit
// either way, and jit_key_mapper records the remapping.
void Op::jit_run(JK& jk) {
    const char* jit_key = jk.to_cstring();
    auto iter = jit_ops.find(jit_key);
    if (iter != jit_ops.end()) {
        // cache hit: run the previously compiled entry
        LOGvvv << "Jit op key found:" << jit_key << "jit op entry:" << (void*)iter->second;
        Profiler::record_and_run(iter->second, this, jit_key);
        return;
    }
    LOGvv << "Jit op key not found:" << jit_key;
    // compile JIT op
    // keep an owned copy: get_jit_key below reuses and overwrites jk
    string prev_jit_key = jit_key;
    auto op_entry = OpCompiler::do_compile(this);
    // the key can change after compilation, so register the entry under both
    string new_jit_key = get_jit_key(jk);
    jit_ops[new_jit_key] = jit_ops[prev_jit_key] = op_entry;
    jit_key_mapper[prev_jit_key] = new_jit_key;
    LOGvv << "Get jit op entry:" << (void*)op_entry;
    Profiler::record_and_run(op_entry, this, new_jit_key.c_str());
}
|
|
|
|
void Op::statistics(uint64_t& in, uint64_t& out, uint64_t& compute) {
|
|
in = out = compute = 0;
|
|
for (auto& e : _inputs) {
|
|
auto var = e.node->var();
|
|
if (e.back->index<0) continue;
|
|
in += var->size;
|
|
compute = std::max(compute, (uint64_t)var->num);
|
|
}
|
|
for (auto& e : _outputs) {
|
|
auto var = e.node->var();
|
|
if (e.index<0) continue;
|
|
out += var->size;
|
|
compute = std::max(compute, (uint64_t)var->num);
|
|
}
|
|
}
|
|
|
|
// Debug printer. Format:
//   Op(<ptr>:<fwd>:<bwd>:<pending>:i<#in>:o<#out>:s<finished>,<name_ex>[-><out>])
// optionally followed by <id> under NODE_MEMCHECK and {trace} when
// python-level tracing is enabled.
std::ostream& operator<<(std::ostream& os, const Op* op) {
    if (!op) return os << "Op(0)";
    os << "Op(" << (void*)op
        << ':' << op->forward_liveness
        << ':' << op->backward_liveness
        << ':' << op->pending_liveness
        << ":i" << op->_inputs.size()
        << ":o" << op->_outputs.size()
        << ":s" << op->is_finished()
        << "," << op->name_ex();
    if (op->_outputs.size()>1)
        os << "->...";
    else if (op->_outputs.size() == 1) {
        // single output: print its name when set, otherwise its address
        auto v = (Var*)op->_outputs.front().node;
        if (v->name.size())
            os << "->" << v->name;
        else
            os << "->" << (void*)v;
    }
    os << ')';
#ifdef NODE_MEMCHECK
    os << '<' << op->__id() << '>';
#endif
    if (trace_py_var) {
        os << '{';
        print_node_trace(op, os);
        os << '}';
    }
    return os;
}
|
|
|
|
} // jittor
|