This commit is contained in:
Gword 2020-07-14 14:51:35 +08:00
commit ef918ccbf7
5 changed files with 7 additions and 3 deletions

View File

@ -273,6 +273,7 @@ Var.start_grad = Var.detach_inplace = detach_inplace
def unsqueeze(x, dim):
shape = list(x.shape)
if dim < 0: dim += len(shape) + 1
assert dim <= len(shape)
return x.reshape(shape[:dim] + [1] + shape[dim:])
Var.unsqueeze = unsqueeze

View File

@ -21,7 +21,7 @@ with open(os.path.join(path, "README.md"), "r", encoding='utf8') as fh:
setuptools.setup(
name='jittor',
version='1.1.5.2',
version='1.1.5.4',
# scripts=[],
author="Jittor Group",
author_email="ran.donglang@gmail.com",

View File

@ -92,7 +92,10 @@ vector<VarPtr> grad(Var* loss, vector<Var*> targets) {
Op* op = it.op;
auto index = it.index;
if (op->tflag != nt) continue;
<<<<<<< HEAD
// TODO: support two outputs backprop.
=======
>>>>>>> b27082f9444a4e627f7dfc574d0114302ba27b5e
for (Var* out : op->outputs()) {
if (out->tflag != nt) continue;
Var* dout = grads[out->custom_data];

View File

@ -303,10 +303,10 @@ DEF_IS(ArrayArgs, T) from_py_object(PyObject* obj) {
auto arr = (PyArray_Proxy*)holder.obj;
int64 size = PyArray_Size(arr);
T args;
args.ptr = arr->data;
args.shape = vector<int64>(arr->dimensions, arr->dimensions+arr->nd);
args.dtype = get_type_str(arr);
args.buffer.reset(new char[size]);
args.ptr = (void*)args.buffer.get();
memcpy((void*)args.buffer.get(), (void*)arr->data, size);
return args;
}

View File

@ -106,7 +106,7 @@ struct VarHolder {
/* detach the grad */
// @pyjt(detach)
inline VarHolder* detach() {
return new VarHolder(move(jittor::detach(var)));
return new VarHolder(jittor::detach(var));
}