mirror of https://github.com/Jittor/Jittor
99 lines
2.8 KiB
C++
99 lines
2.8 KiB
C++
// ***************************************************************
|
|
// Copyright (c) 2021 Jittor. All Rights Reserved.
|
|
// Maintainers:
|
|
// Dun Liang <randonlang@gmail.com>.
|
|
//
|
|
// This file is subject to the terms and conditions defined in
|
|
// file 'LICENSE.txt', which is part of this source code package.
|
|
// ***************************************************************
|
|
#include "var.h"
|
|
#include "ops/array_op.h"
|
|
#include "ops/op_register.h"
|
|
#include "ops/tape_op.h"
|
|
|
|
namespace jittor {
|
|
|
|
/**
 * Construct a tape op for the variable `x`.
 *
 * The op is flagged as runnable on both CPU and CUDA and opts into
 * `_manual_set_vnbb`. Its single output is created with the same dtype
 * as `x`; the output shape is filled in later by infer_shape().
 *
 * @param x variable whose dtype the output copies.
 */
TapeOp::TapeOp(Var* x) {
    // Backend availability flags.
    flags.set(NodeFlags::_cpu);
    flags.set(NodeFlags::_cuda);
    // NOTE(review): presumably tells the framework that this op manages
    // its own "vnbb" bookkeeping -- confirm against NodeFlags.
    flags.set(NodeFlags::_manual_set_vnbb);

    // No shape is supplied here (nullptr); only the dtype is known yet.
    create_output(nullptr, x->dtype());
}
|
|
|
|
/**
 * Gradient of the tape op: the incoming gradient is returned unchanged.
 *
 * @param out      forward output (unused).
 * @param dout     gradient flowing into the output.
 * @param v        input being differentiated (unused).
 * @param v_index  index of that input (unused).
 * @return `dout`, converted to a VarPtr.
 */
VarPtr TapeOp::grad(Var* out, Var* dout, Var* v, int v_index) {
    // Only `dout` participates; the remaining arguments are part of the
    // interface but are not needed for a pass-through gradient.
    (void)out;
    (void)v;
    (void)v_index;
    return dout;
}
|
|
|
|
void TapeOp::infer_shape() {
|
|
auto x = inputs().front();
|
|
auto y = outputs().front();
|
|
y->set_shape(x->shape);
|
|
y->share_with(x);
|
|
}
|
|
|
|
/**
 * Compute the gradients of the taped group by invoking the stored callback.
 *
 * The callback receives the number of outputs with their gradients
 * (`douts`) and the number of inputs with the slots to fill (`dins`).
 *
 * @param douts gradients of the node's outputs.
 * @param dins  destination array for the gradients of the node's inputs.
 *
 * Any exception raised by the callback is rethrown after the callback's
 * deleter has been run and cleared.
 */
void Tapes::grads(Var** douts, VarPtr* dins) {
    // The callback must still be alive (its deleter has not been consumed).
    CHECK(callback.deleter);

    const auto n_outputs = _outputs.size();
    const auto n_inputs = _inputs.size();

    try {
        callback.func(n_outputs, douts, n_inputs, dins);
    } catch (...) {
        // If the callback fails, release it here to avoid leaking it.
        // This is still incomplete: an error may also occur outside of
        // this call, in which case the callback is not released here.
        // TODO: find a more robust solution.
        callback.deleter();
        callback.deleter = nullptr;
        throw;
    }
}
|
|
|
|
/**
 * Build a Tapes node linking a group of taped inputs to a group of taped
 * outputs, with `grad_callback` stored as the gradient callback.
 *
 * Each element of `taped_inputs` / `taped_outputs` is expected to hold a
 * variable produced by an op (reached through `var->input()`), whose first
 * input is the original, un-taped variable.
 *
 * @param taped_inputs   holders of the taped input variables (t_i below).
 * @param taped_outputs  holders of the taped output variables (t_o below).
 * @param grad_callback  gradient callback; moved into this node.
 */
Tapes::Tapes(
    const vector<VarHolder*>& taped_inputs,
    const vector<VarHolder*>& taped_outputs,
    GradCallback&& grad_callback
) {
    flags.set(NodeFlags::_cpu);
    flags.set(NodeFlags::_cuda);
    flags.set(NodeFlags::_grads);
    flags.set(NodeFlags::_manual_set_vnbb);
    callback = move(grad_callback);

    /*
            stop grad                       stop grad
        i --> tape --> t_i ---> .... ---> o --> tape --> t_o
        |                                        ^
        +---> tapes -----------------------------+
    */

    // Outputs: for every taped output, create a fresh variable produced by
    // this node and append it as an extra input of the op that produces
    // the taped output.
    for (VarHolder* holder : taped_outputs) {
        VarPtr bridge(0, ns_float32);
        bridge->add_inputs({this});
        auto taped_var = holder->var;
        auto producer = taped_var->input();
        producer->add_inputs(vector<Node*>{bridge.ptr});
    }

    // Inputs: this node consumes the first input of the op producing each
    // taped input (i.e. the un-taped variable `i` in the diagram).
    vector<Var*> sources;
    sources.reserve(taped_inputs.size());
    for (VarHolder* holder : taped_inputs) {
        sources.push_back(holder->var->input()->inputs().front());
    }
    add_inputs(sources);

    // Stop gradient on the taped inputs (t_i) ...
    for (VarHolder* holder : taped_inputs) {
        holder->var->set_stop_grad();
    }
    // ... and on the un-taped outputs (o), so that gradient only flows
    // through this node.
    for (VarHolder* holder : taped_outputs) {
        holder->var->input()->inputs().front()->set_stop_grad();
    }
}
|
|
|
|
void tape_together(
|
|
const vector<VarHolder*>& taped_inputs,
|
|
const vector<VarHolder*>& taped_outputs,
|
|
GradCallback&& grad_callback
|
|
) {
|
|
new Tapes(taped_inputs, taped_outputs, move(grad_callback));
|
|
}
|
|
|
|
} // jittor
|