mirror of https://github.com/Jittor/Jittor
fix bug
This commit is contained in:
parent
a1c4785a6b
commit
64ddd81f63
|
@ -11,6 +11,12 @@ import jittor as jt
|
|||
import numpy as np
|
||||
import math
|
||||
|
||||
def eye(shape, dtype):
|
||||
return jt.array(np.identity(shape[0])).unary(dtype)
|
||||
|
||||
def eye_(var):
|
||||
var.assign(eye(var.shape, var.dtype))
|
||||
|
||||
def constant(shape, dtype, value=0.0):
|
||||
return jt.array(value).unary(dtype).broadcast(shape)
|
||||
|
||||
|
|
|
@ -211,7 +211,7 @@ vector<VarPtr> grad(Var* loss, vector<Var*> targets) {
|
|||
if (id>=0)
|
||||
grad = move(grads[id]);
|
||||
if (!grad) {
|
||||
LOGw << "grads[">>i>>"] '">> var->name>>"' doesn't have gradient. It will be set to zero:" << var;
|
||||
// LOGw << "grads[">>i>>"] '">> var->name>>"' doesn't have gradient. It will be set to zero:" << var;
|
||||
grad = make_number(0.f, var);
|
||||
assign_attrs(grad.ptr, var);
|
||||
}
|
||||
|
|
|
@ -201,7 +201,7 @@ void GetitemOp::graph_optimize() {
|
|||
// setitem_grad_opt(this);
|
||||
(void)setitem_grad_opt;
|
||||
// (void)getitem_inplace;
|
||||
getitem_inplace(this);
|
||||
// getitem_inplace(this);
|
||||
(void)getitem_inplace;
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue