mirror of https://github.com/Jittor/Jittor
add parameter list and dict
This commit is contained in:
parent
1acf6492f4
commit
698fc6fe88
|
@ -9,7 +9,7 @@
|
||||||
# file 'LICENSE.txt', which is part of this source code package.
|
# file 'LICENSE.txt', which is part of this source code package.
|
||||||
# ***************************************************************
|
# ***************************************************************
|
||||||
|
|
||||||
__version__ = '1.2.3.37'
|
__version__ = '1.2.3.38'
|
||||||
from jittor_utils import lock
|
from jittor_utils import lock
|
||||||
with lock.lock_scope():
|
with lock.lock_scope():
|
||||||
ori_int = int
|
ori_int = int
|
||||||
|
@ -780,8 +780,11 @@ class Module:
|
||||||
stack = []
|
stack = []
|
||||||
def callback(parents, k, v, n):
|
def callback(parents, k, v, n):
|
||||||
stack.append(str(k))
|
stack.append(str(k))
|
||||||
for k2, p in v.__dict__.items():
|
dc = v.__dict__
|
||||||
if k2.startswith("_"): continue
|
if isinstance(v, nn.ParameterList):
|
||||||
|
dc = v.params
|
||||||
|
for k2, p in dc.items():
|
||||||
|
if isinstance(k2, str) and k2.startswith("_"): continue
|
||||||
if isinstance(p, Var):
|
if isinstance(p, Var):
|
||||||
ps.append(p)
|
ps.append(p)
|
||||||
p.name(".".join(stack[1:]+[str(k2)]))
|
p.name(".".join(stack[1:]+[str(k2)]))
|
||||||
|
|
|
@ -1564,6 +1564,47 @@ class Sequential(Module):
|
||||||
return len(self.layers)
|
return len(self.layers)
|
||||||
|
|
||||||
|
|
||||||
|
class ParameterList(Module):
|
||||||
|
def __init__(self, *args):
|
||||||
|
self.params = collections.OrderedDict()
|
||||||
|
for var in args:
|
||||||
|
if isinstance(var, (collections.OrderedDict, dict)):
|
||||||
|
for k, v in var.items():
|
||||||
|
self.add_param(k, v)
|
||||||
|
elif isinstance(var, list):
|
||||||
|
for v in var:
|
||||||
|
self.append(v)
|
||||||
|
else:
|
||||||
|
self.append(var)
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
if idx not in self.params:
|
||||||
|
return list(self.params.values())[idx]
|
||||||
|
|
||||||
|
return self.params[idx]
|
||||||
|
def __iter__(self):
|
||||||
|
return self.params.values().__iter__()
|
||||||
|
def keys(self):
|
||||||
|
return self.params.keys()
|
||||||
|
def values(self):
|
||||||
|
return self.params.values()
|
||||||
|
def items(self):
|
||||||
|
return self.params.items()
|
||||||
|
def execute(self, x):
|
||||||
|
raise NotImplementedError("Parameters is not executable")
|
||||||
|
def append(self, var):
|
||||||
|
assert isinstance(var, jt.Var), f"argument <{type(var)}> is not jittor var"
|
||||||
|
self.params[len(self.params)] = var
|
||||||
|
def add_param(self, name, var):
|
||||||
|
assert isinstance(var, jt.Var), f"argument <{type(var)}> is not jittor var"
|
||||||
|
self.params[name]=var
|
||||||
|
def __setitem__(self, name, var):
|
||||||
|
self.add_param(name, var)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.params)
|
||||||
|
|
||||||
|
ParameterDict = ParameterList
|
||||||
|
|
||||||
def unfold(X, kernel_size, dilation=1, padding=0, stride=1):
|
def unfold(X, kernel_size, dilation=1, padding=0, stride=1):
|
||||||
assert X.ndim == 4
|
assert X.ndim == 4
|
||||||
if not isinstance(kernel_size, tuple):
|
if not isinstance(kernel_size, tuple):
|
||||||
|
|
|
@ -0,0 +1,31 @@
|
||||||
|
# ***************************************************************
|
||||||
|
# Copyright (c) 2021 Jittor. All Rights Reserved.
|
||||||
|
# Maintainers: Dun Liang <randonlang@gmail.com>.
|
||||||
|
# This file is subject to the terms and conditions defined in
|
||||||
|
# file 'LICENSE.txt', which is part of this source code package.
|
||||||
|
# ***************************************************************
|
||||||
|
import unittest
|
||||||
|
import jittor as jt
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class TestParamList(unittest.TestCase):
|
||||||
|
def test_param_list(self):
|
||||||
|
ps = jt.nn.ParameterList([jt.array([1,2,3]), jt.rand(10)])
|
||||||
|
assert len(ps.parameters()) == 2
|
||||||
|
assert list(ps.state_dict().keys()) == ['0', '1'], ps.state_dict().keys()
|
||||||
|
|
||||||
|
def test_with_module(self):
|
||||||
|
class Net(jt.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
self.ps1 = jt.nn.ParameterList([jt.array([1,2,3]), jt.rand(10)])
|
||||||
|
self.ps2 = jt.nn.ParameterDict({
|
||||||
|
"aaa":jt.array([1,2,3]),
|
||||||
|
"bbb": jt.rand(10)
|
||||||
|
})
|
||||||
|
net = Net()
|
||||||
|
assert net.state_dict().keys() == ['ps1.0', 'ps1.1', 'ps2.aaa', 'ps2.bbb']
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
Loading…
Reference in New Issue