polish setitem optimize

This commit is contained in:
Dun Liang 2021-07-20 21:14:55 +08:00
parent e1472a7a8f
commit 69979f71e4
10 changed files with 99 additions and 11 deletions

View File

@ -1,10 +1,10 @@
jittor.attention
jittor.loss3d
=====================
这里是Jittor的 3d 损失函数 模块的API文档您可以通过`from jittor import loss3d`来获取该模块。
```eval_rst
.. automodule:: jittor.loss3d
:members:
:members: chamfer_loss, ChamferLoss, earth_mover_distance, EarthMoverDistance
:undoc-members:
```

View File

@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.2.3.70'
__version__ = '1.2.3.71'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int
@ -1333,3 +1333,4 @@ from . import numpy2cupy
from .contrib import concat
from .misc import *
from . import sparse
from . import optim

View File

@ -37,6 +37,7 @@ class Optimizer(object):
# __zero_grad is a value for fast determ the grad is zero or not
# so we can omit 0+x
self.__zero_grad = True
self._grad_map = {}
def add_param_group(self, group):
self.param_groups.append(group)
@ -189,6 +190,35 @@ class Optimizer(object):
p.update(p - g * lr)
self.zero_grad()
def _build_grad_map(self):
_grad_map = {}
for pg in self.param_groups:
for p, g in zip(pg["params"], pg["grads"]):
_grad_map[id(p)] = g
self._grad_map = _grad_map
def find_grad(self, v:jt.Var) -> jt.Var:
if id(v) not in self._grad_map:
self._build_grad_map()
if id(v) not in self._grad_map:
raise RuntimeError("This variable is not managed by this optimizer")
return self._grad_map[id(v)]
def opt_grad(v:jt.Var, opt:Optimizer):
''' Get grad of certain variable in optimizer, Example::
model = Model()
optimizer = SGD(model.parameters())
...
optimizer.backward(loss)
for p in model.parameters():
grad = p.opt_grad(optimizer)
'''
return opt.find_grad(v)
jt.Var.opt_grad = opt_grad
class SGD(Optimizer):
""" SGD Optimizer.

View File

@ -667,7 +667,7 @@ def compile_src(src, h, basename):
arr_func_return.append(f"return ({func_call},0)")
func_return_failed = "return -1"
else:
assert "-> void" in func_head
assert "-> void" in func_head, func_head
arr_func_return.append(f"{func_call};{before_return}return")
func_return_failed = "return"
# generate error msg when not a valid call

View File

@ -47,6 +47,8 @@ static void setitem_inplace(SetitemOp* op) {
return;
}
auto output = op->outputs().front();
// return if output is all ready shared
if (output->allocator) return;
output->share_with(input);
auto data = op->input(1);
@ -78,13 +80,13 @@ static void setitem_inplace(SetitemOp* op) {
VarSlice s = vs.slices[i];
if (!(s.is_slice())) return;
Slice ss = s.slice;
if (!(ss.start == 0 && ss.stop >= in_shape[i] && ss.step == 1))
if (!(ss.start == 0 && (ss.mask&2) && ss.step == 1))
return;
inplace_size *= in_shape[i];
}
VarSlice s = vs.slices[0];
if (s.is_var()) return;
if (s.is_var() || s.is_str()) return;
auto size = 0;
if (s.is_int())
@ -175,7 +177,10 @@ static void getitem_inplace(GetitemOp* op) {
auto in = op->inputs().front();
auto ou = op->outputs().front();
// return if out is all ready inplaced
if (ou->allocator)
return;
// return if input or output's shape is variable
if (in->num <= 0 || ou->num <= 0)
return;
@ -192,7 +197,7 @@ static void getitem_inplace(GetitemOp* op) {
}
VarSlice s = vs.slices[0];
if (s.is_var()) return;
if (s.is_var() || s.is_str()) return;
auto size = 0;
if (s.is_int())
@ -214,7 +219,7 @@ void SetitemOp::graph_optimize() {
void GetitemOp::graph_optimize() {
// This optimize is still WIP
// LOGir << "hello getitem graph_optimize";
setitem_grad_opt(this);
// setitem_grad_opt(this);
(void)setitem_grad_opt;
// (void)getitem_inplace;
getitem_inplace(this);

View File

@ -207,4 +207,23 @@ string VarHolder::debug_msg() {
return ss.str();
}
int VarHolder::grad() {
LOGf << R""(Jittor Var doesn't have this interface, please change
your code as below::
model = Model()
optimizer = SGD(model.parameters())
...
optimizer.backward(loss)
for p in model.parameters():
# prev code:
# grad = p.grad
# change to:
grad = p.opt_grad(optimizer)
)"";
return 0;
}
} // jittor

View File

@ -264,6 +264,23 @@ struct VarHolder {
*/
// @pyjt(debug_msg)
string debug_msg();
/* Jittor Var doesn't have this interface, please change your code as below::
model = Model()
optimizer = SGD(model.parameters())
...
optimizer.backward(loss)
for p in model.parameters():
# prev code:
# grad = p.grad
# change to:
grad = p.opt_grad(optimizer)
*/
// @pyjt(__get__grad)
int grad();
};
// @pyjt(sync)

View File

@ -41,6 +41,14 @@ class TestOptimizer(unittest.TestCase):
# print(s)
opt.load_state_dict(s)
def test_opt_grad(self):
a = jt.ones(2)
opt = jt.optim.SGD([a], 0.1)
opt.backward(a**2)
g = a.opt_grad(opt)
np.testing.assert_allclose(g.data, 2)
if __name__ == "__main__":

View File

@ -13,7 +13,7 @@ skip_this_test = False
@unittest.skipIf(skip_this_test, "No Torch found")
class TestSetitem(unittest.TestCase):
def test_setitem(self):
def test_setitem_(self):
arr0 = jt.random((4,2,2))
data0 = jt.ones((2,2))
arr0[1] = data0
@ -32,7 +32,7 @@ class TestSetitem(unittest.TestCase):
arr1 = jt.random((4,2,2))
data1 = jt.zeros((2,2))
arr1[3,:,0:2] = data1
arr1[3,:,:] = data1
arr1.sync()
data1.data[0,0] = 1
assert arr1[3,0,0] == 1

View File

@ -956,6 +956,14 @@ class ToTensor:
return self.__class__.__name__ + '()'
class ToPILImage(object):
"""Convert a tensor or an ndarray to PIL Image.
Args:
pic (Tensor or numpy.ndarray): Image(HWC format) to be converted to PIL Image.
mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).
.. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes
Returns:
PIL Image: Image converted to PIL Image.
"""
def __init__(self, mode=None):
self.mode = mode