// JittorMirror/python/jittor/vcompiler/vcompiler.cc
// ***************************************************************
// Copyright (c) 2023 Jittor. All Rights Reserved.
// Maintainers:
// Dun Liang <randonlang@gmail.com>.
// Guoye Yang <498731903@qq.com>
//
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include <algorithm>
#include <functional>
#include <queue>
#ifdef HAS_CUDA
#include <cuda_runtime.h>
#include "helper_cuda.h"
#include "mem/allocator/cuda_dual_allocator.h"
#include "event_queue.h"
#endif
#include "misc/cuda_flags.h"
#include "executor.h"
#include "var.h"
#include "op.h"
#include "mem/allocator.h"
#include "graph.h"
#include "fused_op.h"
#include "fuser.h"
#include "profiler/profiler_guard.h"
#include "parallel_compiler.h"
#include "memory_profiler.h"
#include "misc/nan_checker.h"
#include "memory_profiler.h"
#include "utils/seh.h"
#include "utils/cache_compile.h"
#include "var_holder.h"
#include "mem/swap.h"
#include "mem/mem_info.h"
#include <cuda_fp16.h>
#include "var_holder.h"
#include "vcompiler.h"
namespace jittor {
EXTERN_LIB MemoryProfiler memory_profiler;
DECLARE_FLAG(int, profile_memory_enable);
DECLARE_FLAG(int, gopt_disable);
DECLARE_FLAG(int, use_threading);
// from cuda_managed_allocator
#ifdef HAS_CUDA
DECLARE_FLAG(int, use_cuda_managed_allocator);
#endif
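// Collect ops fuse_ops[ll, rr) into fused_op and rebuild its edge list.
// Inputs produced inside the fused op point at the producer op; inputs
// coming from outside get an op id offset past ops.size() (see below).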
void load_fused_op(FusedOp& fused_op, vector<int>& fuse_ops, vector<Op*>& ops, int ll, int rr, int64 tt) {
fused_op.ops.clear();
fused_op.edges.clear();
auto ntt = ++tflag_count;
for (int i=ll; i<rr; i++) {
int opid = fuse_ops[i];
Op* op = ops[opid];
uint64_t fid1 = fused_op.ops.size();
op->custom_data = fid1;
op->tflag = ntt;
fused_op.ops.push_back(op);
}
LOGvvv << "Prepare fused_op" << fused_op.ops;
fused_op.update_ops();
for (Op* op : fused_op.ops) {
uint fid1 = op->custom_data;
int iid = 0;
for (auto ve : op->_inputs) {
// this is a control dependency edge, don't use it
if (ve.back->index<0) continue;
auto v = ve.node->var();
iid++;
int iop_id;
int iv_id;
if (v->_inputs.size() && v->input()->tflag == ntt) {
auto e = v->_inputs.front();
iop_id = e.node->custom_data;
iv_id = e.back->index;
} else {
iv_id = v->custom_data >> 2;
// offset by iv_id so external inputs get distinct ids in the jit key
iop_id = fused_op.ops.size() + iv_id;
}
fused_op.edges.emplace_back(iop_id, iv_id, fid1, iid-1);
}
// TODO: can we remove this?
// uint oid = 0;
// for (Var* v : op->outputs()) {
// oid++;
// if (v->tflag != tt) {
// // this var node does not belong to the current execution
// // this can happen in a multi-output fuseable op
// // v->custom_data = 0 represents this var cannot be fused
// v->custom_data = 0;
// continue;
// }
// // for (auto o : v->outputs_with_index()) {
// // Op* op2 = o.op;
// // uint iid = o.index;
// // if (op2->tflag != ntt) continue;
// // uint fid2 = op2->custom_data;
// // fused_op.edges.emplace_back(fid1, oid-1, fid2, iid);
// // }
// }
}
}
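// Walk the fused ops in reverse order: if any output of an op is needed by
// backward (and not masked via bit 0 of custom_data), mark all of its inputs
// as needed by backward as well.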
static inline void propergate_needed_flags(FusedOp& fused_op) {
auto& ops = fused_op.ops;
for (int i=ops.size()-1; i>=0; i--) {
bool has_need = 0;
auto op = ops[i];
for (auto o : op->outputs())
if (o->flags.get(NodeFlags::_needed_by_backward) &&
!(o->custom_data&1)) {
has_need = 1;
}
if (has_need)
for (auto i : op->inputs()) {
i->flags.set(NodeFlags::_needed_by_backward);
}
}
}
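// Build a fatal log for an op that failed during asynchronous execution:
// op type, input/output vars, and the traced Python backtrace of one output
// var when tracing is enabled.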
void check_op_async_error(Op* op, bool is_fused_op, const std::exception& e, jittor::Log& logf) {
vector<Stack> stack;
if (is_fused_op) {
FusedOp& fused_op = *((FusedOp*)op);
logf >> "[OP TYPE]:" << "fused_op:(";
for (auto& op : fused_op.ops)
logf << op->name_ex() >> ",";
logf >> ")\n";
logf >> "[Input]:";
for (auto& vi : fused_op.vars)
if (vi.type == 0) logf << vi.var->dtype() >> vi.var->shape >> vi.var->name >> ",";
logf << "\n[Output]:";
Var* ov = nullptr;
for (auto& vi : fused_op.vars)
if (vi.type == 2) {
logf << vi.var->dtype() >> vi.var->shape >> vi.var->name >> ",";
ov = vi.var;
}
if (ov)
stack = get_node_trace(ov);
} else {
logf >> "[OP TYPE]:" << op->name_ex();
logf << "\n[Input]:";
for (auto v : op->inputs())
logf << v->dtype() >> v->shape >> v->name >> ",";
logf << "\n[Output]:";
Var* ov = nullptr;
for (auto v : op->outputs()) {
logf << v->dtype() >> v->shape >> v->name >> ",";
ov = v;
}
if (ov)
stack = get_node_trace(ov);
}
logf << "\n[Async Backtrace]:";
if (stack.size()) {
logf << "---";
for (auto& s : stack) {
logf << "\n " << s.file_path >> ":" >> s.lineno;
if (s.module_type.size()) logf << '<' >> s.module_type >> '>';
if (s.module_name.size() && s.module_name.find(":") == string::npos)
logf << '[' >> s.module_name >> ']';
}
} else
logf << "not found, please set env JT_SYNC=1, trace_py_var=3";
logf << "\n[Reason]:" << e.what();
jittor::LogFatalVoidify() && logf;
}
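// Weak sync: also pull in held vars that are older than the newest requested
// var and have no consumers yet, so they get computed in the same pass.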
static void top_weak_sync(vector<Var*>& vars) {
auto t = ++tflag_count;
int64 max_id=0;
for (auto v : vars) {
if (v->is_finished()) continue;
max_id = std::max(v->id, max_id);
v->tflag = t;
}
while (true) {
if (sync_ptr == hold_vars.begin())
break;
auto next_ptr = std::prev(sync_ptr);
auto v = (*next_ptr)->var;
if (v->id > max_id) break;
sync_ptr = next_ptr;
if (v->tflag == t) continue;
if (v->_outputs.size()) continue;
if (v->is_finished()) continue;
vars.push_back(v);
}
}
extern void free_var_mem(Var* v);
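// Move x's allocation into a fresh VarPtr and wrap it in a new VarHolder;
// x itself is left without memory so the captured graph can be replayed.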
VarHolder* get_output(Var* x) {
ASSERT(x->mem_ptr) << x;
VarPtr vp(x->shape, x->dtype());
vp->mem_ptr = x->mem_ptr;
vp->allocation = x->allocation;
vp->allocator = x->allocator;
vp->finish_pending_liveness();
x->mem_ptr = nullptr;
x->allocator = nullptr;
x->allocation = 0;
return new VarHolder(std::move(vp));
}
} // jittor
#include <cuda_runtime.h>
#include "common.h"
#include "ops/array_op.h"
#include "ops/code_op.h"
#include "ops/getitem_op.h"
namespace jittor {
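// Only compares the first 4 bytes of both strings; this is enough to tell
// apart the op names checked below (array/code/getitem/setitem) and assumes
// both strings have at least 4 readable bytes.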
inline static bool fast_strcmp(const char* a, const char* b) {
return ((const uint32*)a)[0] == ((const uint32*)b)[0];
}
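// Snapshot all shape-dependent state of the traced nodes (var shapes/sizes,
// force-fused scalar array payloads, code-op data, getitem/setitem slices)
// into a flat list of uint64 values.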
inline static void get_shape_value(vector<Node*>& nodes, ShapeValue& k) {
auto add_shape = [&](NanoVector shape) {
k.values.push_back(shape.data);
k.values.push_back(shape.offset);
};
for (auto* node : nodes) {
if (node->is_var()) {
Var* v = (Var*)node;
add_shape(v->shape);
k.values.push_back(v->num);
k.values.push_back(v->size);
continue;
}
auto* op = node->op();
auto* name = op->name();
if (fast_strcmp(name, "array")) {
auto* op_ = (ArrayOp*)op;
if (op_->output->flags.get(NodeFlags::_force_fuse))
k.values.push_back(op_->ptr<uint64>()[0]);
} else
if (fast_strcmp(name, "code")) {
auto* op_ = (CodeOp*)op;
for (auto& kv : op_->data) {
double v = kv.second;
// bitwise copy
k.values.push_back(*(uint64*)&v);
}
} else
if (fast_strcmp(name, "getitem") ||
fast_strcmp(name, "setitem")) {
auto* op_ = (GetitemOp*)op;
for (int i=0; i<op_->vs.n; i++) {
auto& vs = op_->vs.slices[i];
if (vs.is_int() || vs.is_slice()) {
k.values.push_back(vs.slice.start);
k.values.push_back(vs.slice.stop);
k.values.push_back(vs.slice.step);
k.values.push_back(vs.slice.mask);
}
}
add_shape(op_->o_shape);
}
}
}
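// Inverse of get_shape_value: write the recorded values back into the traced
// nodes so the captured graph matches the shapes it was recorded with, and
// re-run graph_optimize for getitem/setitem ops whose slices changed.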
inline static void restore_shape_value(vector<Node*>& nodes, ShapeValue& k) {
int iter = 0;
auto pop_number = [&]() {
ASSERT(iter < k.values.size());
return k.values[iter++];
};
auto pop_shape = [&]() {
ASSERT(iter < k.values.size());
NanoVector nv;
nv.data = k.values[iter++];
nv.offset = k.values[iter++];
return nv;
};
for (auto* node : nodes) {
if (node->is_var()) {
Var* v = (Var*)node;
v->shape = pop_shape();
v->num = pop_number();
v->size = pop_number();
continue;
}
auto* op = node->op();
auto* name = op->name();
if (fast_strcmp(name, "array")) {
auto* op_ = (ArrayOp*)op;
if (op_->output->flags.get(NodeFlags::_force_fuse))
op_->ptr<uint64>()[0] = pop_number();
} else
if (fast_strcmp(name, "code")) {
auto* op_ = (CodeOp*)op;
for (auto& kv : op_->data) {
double& v = kv.second;
// bitwise copy
*(uint64*)&v = pop_number();
}
} else
if (fast_strcmp(name, "getitem") ||
fast_strcmp(name, "setitem")) {
auto* op_ = (GetitemOp*)op;
for (int i=0; i<op_->vs.n; i++) {
auto& vs = op_->vs.slices[i];
if (vs.is_int() || vs.is_slice()) {
vs.slice.start = pop_number();
vs.slice.stop = pop_number();
vs.slice.step = pop_number();
vs.slice.mask = pop_number();
}
}
op_->o_shape = pop_shape();
op->graph_optimize();
}
}
}
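// Trace the computation reachable from `outputs`: collect all pending nodes,
// group ops into fused ops (union-find + topological sort), compile them, and
// record everything needed to replay the graph later (fused ops, cached jit
// entries, per-shape snapshots) into an SGraph.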
SGraphPtr build_sgraph(const vector<VarHolder*>& outputs, const vector<VarHolder*>& inputs) {
vector<Var*> vars;
vars.reserve(outputs.size());
for (auto* vh : outputs)
vars.push_back(vh->var);
bool weak_sync = false;
if (weak_sync && !use_threading)
top_weak_sync(vars);
auto allocator = get_allocator();
auto temp_allocator = get_allocator(true);
exe.allocator = allocator;
exe.temp_allocator = temp_allocator;
auto& last_is_cuda = exe.last_is_cuda;
// BFS to find all ops that need to run
int op_num = 0;
vector<Node*> bfs_q;
bfs_q.reserve(vars.size());
int start_var_num = 0;
while (1) {
op_num = 0;
start_var_num = 0;
bfs_q.clear();
// gather all nodes that need to be executed
int need_opt = 0;
auto t = ++tflag_count;
int64 max_id = 0;
for (Var* v : vars)
if (!v->is_finished() && v->tflag != t) {
v->tflag = t;
start_var_num++;
bfs_q.push_back(v);
max_id = std::max(max_id, v->id);
}
for (int i=0; i<bfs_q.size(); i++) {
auto node = bfs_q[i];
op_num += !node->is_var();
for (auto i : node->_inputs)
if (i.node->tflag != t && !i.node->is_finished()) {
i.node->tflag = t;
need_opt += i.node->flags.get(NodeFlags::_has_gopt);
bfs_q.push_back(i.node);
}
// if this var is fetched (or under weak sync), also traverse its consumers
if (weak_sync || node->flags.get(NodeFlags::_fetch)) {
for (auto& n : node->_outputs) {
// enqueue unvisited, still-pending consumers within the id range (or fetch ops)
if (n.node->tflag != t &&
n.node->pending_liveness &&
!n.node->is_finished() &&
(n.node->id <= max_id ||
n.node->flags.get(NodeFlags::_fetch))) {
n.node->tflag = t;
need_opt += n.node->flags.get(NodeFlags::_has_gopt);
bfs_q.push_back(n.node);
}
}
}
}
if (!need_opt || gopt_disable) break;
for (Node* n : bfs_q) {
if (n->flags.get(NodeFlags::_has_gopt)) {
n->op()->graph_optimize();
n->flags.set(NodeFlags::_has_gopt, 0);
}
}
}
auto tt = tflag_count;
vector<Op*> ops;
vector<Var*> all_vars;
ops.reserve(op_num);
all_vars.reserve(bfs_q.size() - op_num);
for (Node* node : bfs_q)
if (!node->is_var()) {
node->custom_data = ops.size();
ops.push_back(node->op());
} else {
// record this var's index in custom_data
node->custom_data = all_vars.size();
all_vars.push_back(node->var());
}
int var_num = all_vars.size();
// father: parent array of the union-find set over ops
vector<int> father(op_num);
for (int i=0; i<op_num; i++) {
father[i] = i;
}
// union-find algorithm
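// find the root of i, compressing the path along the way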
auto find_fa = [&](int i) -> int {
int j=i;
while (father[j] != j) j = father[j];
while (i != j) {
int tmp = father[i];
father[i] = j;
i = tmp;
}
return j;
};
vector<int> var_fused(var_num);
if (V_ON(100)) {
for (uint i=0; i<ops.size(); i++) {
Op* op = ops[i];
string st="others";
if (op->type()==OpType::reduce) st="reduce";
if (op->type()==OpType::broadcast) st="broadcast";
if (op->type()==OpType::element) st="element";
LOGvvv << "id:" << ops[i]->custom_data << " type:" <<
st << " addr:" << op;
for (Var* v : op->inputs()) {
Op* next_op = v->input();
// continue if this is a boundary (producer absent or outside this execution)
if (!next_op || next_op->tflag != tt) {
LOGvvv << "input:" << v;
continue;
}
LOGvvv << "input:" << next_op->custom_data << " addr:" << next_op;
}
LOGvvv << "";
}
}
count_fuse(tt, start_var_num, ops, all_vars, father, var_fused);
// var_fused values:
// 0: can be fused
// 1: cannot be fused
// 2: weakly shared (may turn into 1 or 3 by shared-operator cutting)
// 3: strongly shared (forced share)
vector<int> roots, next(op_num, -1);
vector<int> deps(op_num, 0);
roots.reserve(op_num);
for (int i=0; i<op_num; i++) {
int fa = find_fa(i);
if (fa == i)
roots.push_back(i);
else {
next[i] = next[fa];
next[fa] = i;
}
}
vector<int> queue;
queue.reserve(roots.size());
// ** topological_sort external **
// output:
// queue: topological order of fused ops
{
// queue.clear();
#ifndef JT_bfs_executor
std::priority_queue<pair<int64,int64>> p_queue;
#endif
for (int root : roots) {
for (int i=root; i>=0; i=next[i]) {
Op* op = ops[i];
for (Var* v : op->inputs()) {
if (v->tflag != tt) continue;
Op* opi = v->input();
// if those two ops are not fused
if (father[opi->custom_data] != root) {
deps[root]++;
}
}
}
#ifdef JT_bfs_executor
if (deps[root] == 0)
queue.push_back(root);
#else
if (deps[root] == 0)
p_queue.emplace(-ops[root]->order(), root);
#endif
}
#ifdef JT_bfs_executor
for (uint s=0; s<queue.size(); s++)
#else
while (p_queue.size())
#endif
{
#ifdef JT_bfs_executor
int op_id = queue[s];
#else
int op_id = p_queue.top().second;
p_queue.pop();
queue.push_back(op_id);
#endif
for (int i=op_id; i>=0; i=next[i]) {
Op* op = ops[i];
for (Var* v : op->outputs())
{
if (v->tflag == tt)
for (Op* op2 : v->outputs())
{
if (op2->tflag != tt) continue;
int op2_id = father[op2->custom_data];
// continue if those two ops are fused
if (op2_id == op_id) continue;
deps[op2_id]--;
#ifdef JT_bfs_executor
if (deps[op2_id] == 0)
queue.push_back(op2_id);
#else
if (deps[op2_id] == 0)
p_queue.emplace(-op2->order(), op2_id);
#endif
}
}
}
}
ASSERTop(queue.size(),==,roots.size());
}
// ** topological_sort internal **
// output:
// fuse_ops: fused op id [000|1111|22|3333]
// range: split index    ^   ^    ^  ^    ^
vector<int> fuse_ops;
fuse_ops.reserve(op_num*2);
vector<int> range(queue.size());
{
vector<int> subgraph;
subgraph.reserve(16);
vector<int> sharegraph;
sharegraph.reserve(16);
vector<int> sharegraph_q;
sharegraph_q.reserve(16);
vector<int> shared_id(op_num, -1);
// for fused op in reversed order
for (uint rid=0; rid<queue.size(); rid++) {
int root = queue[queue.size()-rid-1];
auto& queue = subgraph;
queue.clear();
sharegraph.clear();
int total=0;
for (int i=root; i>=0; i=next[i], total++) {
Op* op = ops[i];
for (Var* v : op->inputs()) {
if (v->tflag != tt) continue;
Op* opi = v->input();
// if those two ops are fused
int opid = opi->custom_data;
auto fopid = father[opid];
if (fopid == root)
deps[i]++;
else if (shared_id[opid] != root) {
auto& vf = var_fused[v->custom_data];
// var_fused == 1: this input op cannot be shared
// TODO: check that all of this input op's output vars can be shared
if (vf == 1)
continue;
// if weak share, turn into strong share
if (vf == 2) vf = 3;
// new shared op
deps[opid] = 0;
shared_id[opid] = root;
sharegraph.push_back(opid);
}
}
if (deps[i] == 0)
queue.push_back(i);
}
// collect the full shared subgraph
uint sn = sharegraph.size();
for (uint i=0; i<sharegraph.size(); i++) {
int id = sharegraph[i];
Op* op = ops[id];
for (Var* v : op->inputs()) {
if (v->tflag != tt) continue;
int vi = v->custom_data;
if (var_fused[vi] == 1)
continue;
// if weak share, cut off
if (var_fused[vi] == 2) {
if (sharegraph.size() - sn < 32)
var_fused[vi] = 3;
else {
var_fused[vi] = 1;
continue;
}
}
Op* opi = v->input();
int opid = opi->custom_data;
int& dep = deps[opid];
if (shared_id[opid] != root) {
shared_id[opid] = root;
dep = 1;
sharegraph.push_back(opid);
} else
dep ++;
}
}
sharegraph_q.clear();
for (uint i=0; i<sn; i++)
if (deps[sharegraph[i]]==0)
sharegraph_q.push_back(sharegraph[i]);
// topsort in sharegraph_q
for (uint i=0; i<sharegraph_q.size(); i++) {
int id = sharegraph_q[i];
Op* op = ops[id];
for (Var* v : op->inputs()) {
if (v->tflag != tt) continue;
int vi = v->custom_data;
if (var_fused[vi] == 1)
continue;
Op* opi = v->input();
int opid = opi->custom_data;
int& dep = deps[opid];
dep --;
if (dep == 0)
sharegraph_q.push_back(opid);
}
}
LOGvvvv << "sharegraph_q" << sharegraph_q;
ASSERTop(sharegraph.size(),==,sharegraph_q.size());
// topsort fused op internal
for (uint s=0; s<queue.size(); s++) {
int i = queue[s];
Op* op = ops[i];
for (Var* v : op->outputs())
if (v->tflag == tt)
for (Op* op2 : v->outputs()) {
if (op2->tflag != tt) continue;
int op2_id = op2->custom_data;
// continue if those two ops are not fused
if (father[op2_id] != root) continue;
deps[op2_id]--;
if (deps[op2_id] == 0)
queue.push_back(op2_id);
}
}
ASSERTop(queue.size(),==,(uint)total);
LOGvvvv << "topsort internal" << queue;
for (int i=(int)sharegraph_q.size()-1; i>=0; i--)
fuse_ops.push_back(sharegraph_q[i]);
for (uint i=0; i<queue.size(); i++)
fuse_ops.push_back(queue[i]);
range[rid] = fuse_ops.size();
}
}
for (int i=0; i<var_num; i++) {
all_vars[i]->custom_data = var_fused[i]==1;
}
FusedOp fused_op;
// compile all ops ahead of time, so nothing needs to compile while running
parallel_compile_all_ops(queue, range, fused_op, fuse_ops, ops, tt, true);
// flags
std::sort(bfs_q.begin(), bfs_q.end(), [&](Node* x, Node* y) { return x->id<y->id; });
unordered_map<Var*,pair<Var*,uint64>> share_map;
auto min_id = bfs_q.front()->id;
auto max_id = bfs_q.back()->id;
vector<char> flags(max_id-min_id+1);
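// per-node bit flags, indexed by id - min_id:
//   is_output:  var is a requested graph output
//   is_new_var: var was traced in this build pass
//   is_share:   var already carries an allocator/allocation, recorded in
//               share_map so it can be restored after execution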
constexpr int is_output = 0;
constexpr int is_new_var = 1;
constexpr int is_share = 2;
auto lived = [&](Node* n) { return n->id>=min_id && n->id<=max_id; };
auto get_flags = [&](Node* n, int f) -> int {
if (!lived(n)) return 0;
return (flags[n->id-min_id]>>f)&1;
};
auto set_flags = [&](Node* n, int f) {
if (!lived(n)) return;
flags[n->id-min_id] |= (1<<f);
};
for (auto v : vars) {
set_flags(v, is_output);
}
for (auto v : all_vars) {
set_flags(v, is_new_var);
if (v->allocator) {
share_map[v] = std::make_pair((Var*)v->allocator, v->allocation);
set_flags(v, is_share);
}
}
// build fused ops
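// For each scheduled op, pre-load its FusedOp (when fusable), run do_prepare,
// and cache an entry point in jit_entries[rid]: FusedOp::do_run for fused ops,
// the compiled jit entry (or Op::run) for plain ops, so execution can skip the
// jit key lookup.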
vector<FusedOp> fused_ops(queue.size());
vector<Op*> rid_ops(queue.size());
vector<int> v_last_rid(max_id-min_id+1, -1);
vector<jit_op_entry_t> jit_entries(queue.size());
auto& jkl = get_jk();
for (uint rid=0; rid<queue.size(); rid++) {
int root = queue[rid];
Op* op = ops[root];
bool is_fused_op = false;
if (op->type() != OpType::other) {
auto& fused_op = fused_ops[rid];
op = &fused_op;
is_fused_op = true;
int ll = (rid<queue.size()-1)?range[queue.size()-rid-2]:0, rr = range[queue.size()-rid-1];
root = fuse_ops[rr-1];
load_fused_op(fused_op, fuse_ops, ops, ll, rr, tt);
op->do_prepare(jkl);
jit_entries[rid] = (jit_op_entry_t)&FusedOp::do_run;
} else {
op->do_prepare(jkl);
if (!jkl.empty()) {
const char* jit_key = jkl.to_cstring();
auto iter = jit_ops.find(jit_key);
ASSERT(iter != jit_ops.end()) << jit_key << op << rid;
jit_entries[rid] = iter->second;
} else {
jit_entries[rid] = (jit_op_entry_t)&Op::run;
}
}
rid_ops[rid] = op;
for (auto v : op->inputs())
if (get_flags(v, is_new_var))
v_last_rid[v->id-min_id] = rid;
}
SGraphPtr sgraph_ptr;
sgraph_ptr.ptr = std::make_unique<SGraph>();
auto& g = *sgraph_ptr.ptr;
g.outputs.reserve(outputs.size());
for (auto v : outputs) {
g.outputs.push_back(v->var);
}
g.inputs.reserve(inputs.size());
for (auto v : inputs) {
g.inputs.push_back(v->var);
}
g.bfs_q = std::move(bfs_q);
g.share_map = std::move(share_map);
g.flags = std::move(flags);
g.fused_ops = std::move(fused_ops);
g.rid_ops = std::move(rid_ops);
g.v_last_rid = std::move(v_last_rid);
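// Record the shape snapshot for the current input shapes and append the
// resolved jit entry pointers to the same value blob; exec_sgraph recovers
// them from the tail of the blob.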
ShapeKey key;
key.shapes.reserve(inputs.size());
for (auto v : inputs) {
key.shapes.push_back(v->var->shape);
}
ShapeValue& value = g.shape_values[key];
get_shape_value(g.bfs_q, value);
auto prev_size = value.values.size();
value.values.resize(value.values.size() + jit_entries.size());
memcpy(&value.values[prev_size], &jit_entries[0], jit_entries.size()*sizeof(jit_op_entry_t));
g.shape_value_len = value.values.size();
return sgraph_ptr;
}
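// Return whether this graph already holds a shape snapshot for the given
// input shapes.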
bool prob_sgraph(SGraphPtr* sgraph, const vector<VarHolder*>& inputs) {
// return true;
ShapeKey key;
key.shapes.reserve(inputs.size());
for (auto v : inputs) {
key.shapes.push_back(v->var->shape);
}
auto& g = *sgraph->ptr;
auto it = g.shape_values.find(key);
if (it == g.shape_values.end()) return false;
return true;
}
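// Merge the shape snapshots of two structurally identical graphs (same node
// kinds, dtypes and shape ranks) so a single SGraph can serve several input
// shapes.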
void merge_sgraph(SGraphPtr* sgraph, SGraphPtr* sgraph2) {
auto& g1 = *sgraph->ptr;
auto& g2 = *sgraph2->ptr;
ASSERT(g1.outputs.size() == g2.outputs.size());
ASSERT(g1.inputs.size() == g2.inputs.size());
ASSERTop(g1.bfs_q.size(),==,g2.bfs_q.size());
ASSERT(g1.share_map.size() == g2.share_map.size());
ASSERT(g1.flags.size() == g2.flags.size());
ASSERT(g1.fused_ops.size() == g2.fused_ops.size());
ASSERT(g1.rid_ops.size() == g2.rid_ops.size());
ASSERT(g1.v_last_rid.size() == g2.v_last_rid.size());
ASSERT(g1.shape_value_len == g2.shape_value_len);
for (int i=0; i<g1.bfs_q.size(); i++) {
auto n1 = g1.bfs_q[i];
auto n2 = g2.bfs_q[i];
ASSERT(n1->is_var() == n2->is_var());
if (n1->is_var()) {
ASSERT(n1->var()->shape.size() == n2->var()->shape.size());
ASSERT(n1->var()->dtype() == n2->var()->dtype());
} else {
ASSERT(fast_strcmp(n1->op()->name(), n2->op()->name()) == 1);
}
}
for (auto& kv : g2.shape_values) {
g1.shape_values[kv.first] = kv.second;
}
}
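// Replay a captured graph on new inputs: look up and restore the shape
// snapshot matching the input shapes, rebind the graph's input vars to the
// caller's memory, then run each recorded (fused) op through its cached jit
// entry.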
vector<VarHolder*> exec_sgraph(SGraphPtr* sgraph, const vector<VarHolder*>& inputs) {
ShapeKey key;
key.shapes.reserve(inputs.size());
for (auto v : inputs) {
key.shapes.push_back(v->var->shape);
}
auto& g = *sgraph->ptr;
auto it = g.shape_values.find(key);
ASSERT(it != g.shape_values.end());
auto& value = it->second;
restore_shape_value(g.bfs_q, value);
vector<jit_op_entry_t> jit_entries(g.rid_ops.size());
memcpy(&jit_entries[0], &value.values[value.values.size() - jit_entries.size()], jit_entries.size()*sizeof(jit_op_entry_t));
ASSERT(inputs.size() == g.inputs.size());
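// Rebind the graph's input vars to the caller's buffers: drop any stale
// allocation, alias the caller's memory, and let the allocator know the
// allocation is now shared by both vars.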
for (int i=0; i<inputs.size(); i++) {
auto* v2 = inputs[i]->var;
auto* v = g.inputs[i];
if (v != v2) {
if (v->mem_ptr) {
free_var_mem(v);
}
ASSERT(v2->mem_ptr);
v->mem_ptr = v2->mem_ptr;
v->allocator = v2->allocator;
v->allocation = v2->allocation;
v->shape = v2->shape;
v->num = v2->num;
v->size = v2->size;
v->allocator->share_with(v->size, v->allocation);
}
}
auto allocator = get_allocator();
auto temp_allocator = get_allocator(true);
exe.allocator = allocator;
exe.temp_allocator = temp_allocator;
auto& last_is_cuda = exe.last_is_cuda;
vector<Var*>& vars = g.outputs;
vector<Node*>& bfs_q = g.bfs_q;
unordered_map<Var*,pair<Var*,uint64>>& share_map = g.share_map;
vector<char>& flags = g.flags;
vector<FusedOp>& fused_ops = g.fused_ops;
vector<Op*>& rid_ops = g.rid_ops;
vector<int>& v_last_rid = g.v_last_rid;
constexpr int is_output = 0;
constexpr int is_new_var = 1;
constexpr int is_share = 2;
auto min_id = bfs_q.front()->id;
auto max_id = bfs_q.back()->id;
auto lived = [&](Node* n) { return n->id>=min_id && n->id<=max_id; };
auto get_flags = [&](Node* n, int f) -> int {
if (!lived(n)) return 0;
return (flags[n->id-min_id]>>f)&1;
};
auto set_flags = [&](Node* n, int f) {
if (!lived(n)) return;
flags[n->id-min_id] |= (1<<f);
};
// running
SetupFreeBuffer setup_free_buffer;
#ifdef HAS_CUDA
int sync_times = 0;
#endif
auto& jkl = get_jk();
for (uint rid=0; rid<rid_ops.size(); rid++) {
Op* op = rid_ops[rid];
bool is_fused_op = op->type() != OpType::other;
try {
for (auto* var : op->outputs())
var->alloc(allocator);
if (PREDICT_BRANCH_NOT_TAKEN(profile_memory_enable))
memory_profiler.check();
LOGvvv << "Run" << op << "inputs:" << op->inputs() << "outputs:" << op->outputs();
// op->do_prepare(jkl);
bool is_cuda = op->flags.get(NodeFlags::_cuda);
#ifdef HAS_CUDA
if (!is_cuda) {
if (last_is_cuda) {
// previous op ran on GPU and this op runs on CPU:
// synchronize CUDA before the CPU op touches the data
checkCudaErrors(cudaDeviceSynchronize());
sync_times++;
}
for (Var* v : op->inputs()) {
if (v->allocator->is_cuda())
migrate_to_cpu(v, allocator);
}
if (!use_cuda_managed_allocator) {
for (auto* var : op->outputs())
if (var->allocator->is_cuda())
migrate_to_cpu(var, allocator);
}
} else {
for (Var* v : op->inputs()) {
if (!v->allocator->is_cuda())
migrate_to_gpu(v, allocator);
}
for (Var* v : op->outputs()) {
if (!v->allocator->is_cuda())
migrate_to_gpu(v, allocator);
}
}
#endif
last_is_cuda = is_cuda;
// _JT_SEH_START2;
if (profiler_enable)
op->do_run();
else {
jit_op_entry_t& jit_entry = jit_entries[rid];
jit_entry(op);
}
// _JT_SEH_END2;
#ifdef HAS_CUDA
// migrate to gpu
if (PREDICT_BRANCH_NOT_TAKEN((!is_cuda && use_cuda && !use_cuda_managed_allocator))) {
for (Var* v : op->outputs()) {
migrate_to_gpu(v, allocator);
}
}
#endif
#ifdef JT_CHECK_NAN
for (Var* var : op->outputs())
check_nan(var, op);
#endif
#ifdef JT_SYNC
#ifdef HAS_CUDA
checkCudaErrors(cudaGetLastError());
checkCudaErrors(cudaDeviceSynchronize());
#endif
#endif
LOGvvv << "Finished Op(" >> op->name() << rid >>
"/" >> rid_ops.size() >> ") output:" << op->outputs();
for (Var* v : op->inputs())
if (get_flags(v, is_new_var) && !get_flags(v, is_output) && v_last_rid[v->id-min_id] == rid) {
if (v->mem_ptr)
free_var_mem(v);
if (get_flags(v, is_share)) {
// recover share var
auto kv = share_map.find(v)->second;
v->allocator = (Allocator*)kv.first;
v->allocation = kv.second;
}
}
for (Var* v : op->outputs()) {
if (!get_flags(v, is_new_var) && !get_flags(v, is_output) && v->mem_ptr) {
// this output is not used in this graph, so we free it directly
free_var_mem(v);
}
}
} catch (const std::exception& e) {
// log memory info
display_memory_info(__FILELINE__, false, true);
// log jit_key and file location
op->do_prepare(jkl);
string jit_src_path = Op::get_filename_from_jit_key(jkl.to_cstring(), ".cc");
jittor::Log logf(__FILELINE__, 'f', 0);
logf << "\nExecute fused operator(" >> rid >> '/' >> rid_ops.size() >> ")"
<< "failed.";
if (jit_compiler::file_exist(jit_src_path))
logf << "\n[JIT Source]:" << jit_src_path << "\n";
check_op_async_error(op, is_fused_op, e, logf);
}
}
for (Var* v : vars) ASSERT(v->mem_ptr || v->flags.get(NodeFlags::_is_swapped) || !v->backward_liveness) << v;
// clean fetcher free buffer
// fetcher_to_free.clear();
#ifdef HAS_CUDA
event_queue.flush();
#endif
vector<VarHolder*> ret;
ret.reserve(vars.size());
for (Var* v : vars) {
ASSERT(get_flags(v, is_new_var));
ret.push_back(get_output(v));
if (get_flags(v, is_share)) {
// recover share var
auto kv = share_map.find(v)->second;
v->allocator = (Allocator*)kv.first;
v->allocation = kv.second;
}
}
return ret;
}
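// Double-buffered device-to-host fetch: return the host copies queued on the
// previous call, queue async copies of the current inputs, and use a CUDA
// event to make sure the previous batch has finished copying before returning.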
vector<VarHolder*> delay_fetch(const vector<VarHolder*>& inputs) {
static vector<VarPtr> prev_vars;
static cudaEvent_t event;
static bool init = false;
if (!init) {
init = true;
checkCudaErrors(cudaEventCreateWithFlags(&event, cudaEventDisableTiming));
}
sync(inputs);
vector<VarHolder*> ret;
ret.reserve(prev_vars.size());
for (auto& v : prev_vars) {
ret.push_back(new VarHolder(move(v)));
}
prev_vars.clear();
prev_vars.reserve(inputs.size());
for (auto& v : inputs) {
VarPtr vp(v->var->shape, v->var->dtype());
vp->alloc(cpu_allocator);
vp->finish_pending_liveness();
cudaMemcpyAsync(vp->mem_ptr, v->var->mem_ptr, v->var->size, cudaMemcpyDeviceToHost, 0);
prev_vars.emplace_back(move(vp));
}
cudaEventSynchronize(event);
cudaEventRecord(event, 0);
return ret;
}
}