This commit is contained in:
Gword 2020-12-28 16:33:17 +08:00
parent a1c4785a6b
commit 64ddd81f63
3 changed files with 8 additions and 2 deletions

View File

@ -11,6 +11,12 @@ import jittor as jt
import numpy as np
import math
def eye(shape, dtype):
return jt.array(np.identity(shape[0])).unary(dtype)
def eye_(var):
var.assign(eye(var.shape, var.dtype))
def constant(shape, dtype, value=0.0):
return jt.array(value).unary(dtype).broadcast(shape)

View File

@ -211,7 +211,7 @@ vector<VarPtr> grad(Var* loss, vector<Var*> targets) {
if (id>=0)
grad = move(grads[id]);
if (!grad) {
LOGw << "grads[">>i>>"] '">> var->name>>"' doesn't have gradient. It will be set to zero:" << var;
// LOGw << "grads[">>i>>"] '">> var->name>>"' doesn't have gradient. It will be set to zero:" << var;
grad = make_number(0.f, var);
assign_attrs(grad.ptr, var);
}

View File

@ -201,7 +201,7 @@ void GetitemOp::graph_optimize() {
// setitem_grad_opt(this);
(void)setitem_grad_opt;
// (void)getitem_inplace;
getitem_inplace(this);
// getitem_inplace(this);
(void)getitem_inplace;
}