mirror of https://github.com/Jittor/Jittor
polish setitem optimize
This commit is contained in:
parent e1472a7a8f, commit 69979f71e4
@@ -1,10 +1,10 @@
-jittor.attention
+jittor.loss3d
 =====================

 This is the API documentation for Jittor's 3D loss function module. You can access it via `from jittor import loss3d`.

 ```eval_rst
 .. automodule:: jittor.loss3d
-   :members:
+   :members: chamfer_loss, ChamferLoss, earth_mover_distance, EarthMoverDistance
    :undoc-members:
 ```
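For context, a minimal usage sketch of the documented module; the `(batch, n_points, 3)` point-cloud layout and the exact call signatures are assumptions, not part of this diff:

```python
import jittor as jt
from jittor import loss3d

# Two batches of 3D point clouds; (batch, n_points, xyz) layout is assumed.
pc1 = jt.random((4, 256, 3))
pc2 = jt.random((4, 256, 3))

cd  = loss3d.chamfer_loss(pc1, pc2)          # functional form
emd = loss3d.EarthMoverDistance()(pc1, pc2)  # module form
```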
@@ -9,7 +9,7 @@
 # file 'LICENSE.txt', which is part of this source code package.
 # ***************************************************************

-__version__ = '1.2.3.70'
+__version__ = '1.2.3.71'
 from jittor_utils import lock
 with lock.lock_scope():
     ori_int = int
@@ -1333,3 +1333,4 @@ from . import numpy2cupy
 from .contrib import concat
 from .misc import *
 from . import sparse
+from . import optim
@@ -37,6 +37,7 @@ class Optimizer(object):
         # __zero_grad is a flag for fast determining whether the grad is zero,
         # so we can omit 0+x
         self.__zero_grad = True
+        self._grad_map = {}

     def add_param_group(self, group):
         self.param_groups.append(group)
@@ -189,6 +190,35 @@ class Optimizer(object):
             p.update(p - g * lr)
         self.zero_grad()

+    def _build_grad_map(self):
+        # cache id(param) -> grad so find_grad is a dict lookup after the first call
+        _grad_map = {}
+        for pg in self.param_groups:
+            for p, g in zip(pg["params"], pg["grads"]):
+                _grad_map[id(p)] = g
+        self._grad_map = _grad_map
+
+    def find_grad(self, v: jt.Var) -> jt.Var:
+        if id(v) not in self._grad_map:
+            self._build_grad_map()
+            if id(v) not in self._grad_map:
+                raise RuntimeError("This variable is not managed by this optimizer")
+        return self._grad_map[id(v)]
+
+def opt_grad(v: jt.Var, opt: Optimizer):
+    ''' Get the gradient of a variable managed by the given optimizer. Example::
+
+        model = Model()
+        optimizer = SGD(model.parameters())
+        ...
+        optimizer.backward(loss)
+
+        for p in model.parameters():
+            grad = p.opt_grad(optimizer)
+    '''
+    return opt.find_grad(v)
+
+jt.Var.opt_grad = opt_grad
+
 class SGD(Optimizer):
     """ SGD Optimizer.
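For reference, a minimal sketch of the new accessor in action; it mirrors the unit test added later in this commit:

```python
import jittor as jt
import numpy as np

a = jt.ones(2)
opt = jt.optim.SGD([a], 0.1)
opt.backward(a**2)           # accumulate grads without taking a step
g = a.opt_grad(opt)          # look up a's grad inside the optimizer
np.testing.assert_allclose(g.data, 2)   # d(a^2)/da = 2*a = [2, 2]
```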
@@ -667,7 +667,7 @@ def compile_src(src, h, basename):
             arr_func_return.append(f"return ({func_call},0)")
             func_return_failed = "return -1"
         else:
-            assert "-> void" in func_head
+            assert "-> void" in func_head, func_head
             arr_func_return.append(f"{func_call};{before_return}return")
             func_return_failed = "return"
     # generate error msg when not a valid call
@@ -47,6 +47,8 @@ static void setitem_inplace(SetitemOp* op) {
         return;
     }
     auto output = op->outputs().front();
+    // return if output is already shared
+    if (output->allocator) return;
     output->share_with(input);

     auto data = op->input(1);
@@ -78,13 +80,13 @@ static void setitem_inplace(SetitemOp* op) {
         VarSlice s = vs.slices[i];
         if (!(s.is_slice())) return;
         Slice ss = s.slice;
-        if (!(ss.start == 0 && ss.stop >= in_shape[i] && ss.step == 1))
+        if (!(ss.start == 0 && (ss.mask&2) && ss.step == 1))
             return;
         inplace_size *= in_shape[i];
     }

     VarSlice s = vs.slices[0];
-    if (s.is_var()) return;
+    if (s.is_var() || s.is_str()) return;

     auto size = 0;
     if (s.is_int())
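The effect of the relaxed slice check, as a hedged Python sketch; reading `ss.mask&2` as "the stop was omitted, so the slice runs to the end of the axis" is an assumption, but the aliasing behavior matches the test change later in this commit:

```python
import jittor as jt

arr = jt.random((4, 2, 2))
data = jt.zeros((2, 2))
arr[3, :, :] = data       # full trailing slices: setitem may now run in place
arr.sync()
data.data[0, 0] = 1       # mutate data's underlying buffer...
assert arr[3, 0, 0] == 1  # ...and the change shows through arr: they alias
```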
@@ -175,7 +177,10 @@ static void getitem_inplace(GetitemOp* op) {

     auto in = op->inputs().front();
     auto ou = op->outputs().front();

+    // return if out is already in-placed
+    if (ou->allocator)
+        return;
     // return if input or output's shape is variable
     if (in->num <= 0 || ou->num <= 0)
         return;
@@ -192,7 +197,7 @@ static void getitem_inplace(GetitemOp* op) {
     }

     VarSlice s = vs.slices[0];
-    if (s.is_var()) return;
+    if (s.is_var() || s.is_str()) return;

     auto size = 0;
     if (s.is_int())
@@ -214,7 +219,7 @@ void SetitemOp::graph_optimize() {
 void GetitemOp::graph_optimize() {
     // This optimization is still WIP
     // LOGir << "hello getitem graph_optimize";
-    setitem_grad_opt(this);
-    // (void)getitem_inplace;
+    // setitem_grad_opt(this);
+    (void)setitem_grad_opt;
+    getitem_inplace(this);
@@ -207,4 +207,23 @@ string VarHolder::debug_msg() {
     return ss.str();
 }

+int VarHolder::grad() {
+    LOGf << R""(Jittor Var doesn't have this interface, please change
+your code as below::
+
+    model = Model()
+    optimizer = SGD(model.parameters())
+    ...
+    optimizer.backward(loss)
+
+    for p in model.parameters():
+        # prev code:
+        # grad = p.grad
+
+        # change to:
+        grad = p.opt_grad(optimizer)
+)"";
+    return 0;
+}
+
 } // jittor
@@ -264,6 +264,23 @@ struct VarHolder {
     */
     // @pyjt(debug_msg)
     string debug_msg();

+    /* Jittor Var doesn't have this interface, please change your code as below::
+
+        model = Model()
+        optimizer = SGD(model.parameters())
+        ...
+        optimizer.backward(loss)
+
+        for p in model.parameters():
+            # prev code:
+            # grad = p.grad
+
+            # change to:
+            grad = p.opt_grad(optimizer)
+    */
+    // @pyjt(__get__grad)
+    int grad();
 };

 // @pyjt(sync)
@@ -41,6 +41,14 @@ class TestOptimizer(unittest.TestCase):
         # print(s)
         opt.load_state_dict(s)

+    def test_opt_grad(self):
+        a = jt.ones(2)
+        opt = jt.optim.SGD([a], 0.1)
+        opt.backward(a**2)
+        g = a.opt_grad(opt)
+        np.testing.assert_allclose(g.data, 2)
+
+
 if __name__ == "__main__":
@@ -13,7 +13,7 @@ skip_this_test = False

 @unittest.skipIf(skip_this_test, "No Torch found")
 class TestSetitem(unittest.TestCase):
-    def test_setitem(self):
+    def test_setitem_(self):
         arr0 = jt.random((4,2,2))
         data0 = jt.ones((2,2))
         arr0[1] = data0
@@ -32,7 +32,7 @@ class TestSetitem(unittest.TestCase):

         arr1 = jt.random((4,2,2))
         data1 = jt.zeros((2,2))
-        arr1[3,:,0:2] = data1
+        arr1[3,:,:] = data1
         arr1.sync()
         data1.data[0,0] = 1
         assert arr1[3,0,0] == 1
@@ -956,6 +956,14 @@ class ToTensor:
         return self.__class__.__name__ + '()'

+class ToPILImage(object):
+    """Convert a tensor or an ndarray to a PIL Image.
+
+    Args:
+        pic (Tensor or numpy.ndarray): Image (HWC format) to be converted to a PIL Image.
+        mode (`PIL.Image mode`_): color space and pixel depth of the input data (optional).
+
+    .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes
+
+    Returns:
+        PIL Image: Image converted to PIL Image.
+    """
+    def __init__(self, mode=None):
+        self.mode = mode
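A small usage sketch; this hunk only shows `__init__`, so the transform being callable on an HWC array (as in its torchvision counterpart) and the `jittor.transform` module path are assumptions:

```python
import numpy as np
from jittor.transform import ToPILImage

to_pil = ToPILImage(mode='RGB')
pil_img = to_pil(np.zeros((32, 32, 3), dtype=np.uint8))  # HWC uint8 -> PIL Image
```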