add parameter list and dict

This commit is contained in:
Dun Liang 2021-06-22 21:25:38 +08:00
parent 1acf6492f4
commit 698fc6fe88
3 changed files with 78 additions and 3 deletions

View File

@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.2.3.37'
__version__ = '1.2.3.38'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int
@ -780,8 +780,11 @@ class Module:
stack = []
def callback(parents, k, v, n):
stack.append(str(k))
for k2, p in v.__dict__.items():
if k2.startswith("_"): continue
dc = v.__dict__
if isinstance(v, nn.ParameterList):
dc = v.params
for k2, p in dc.items():
if isinstance(k2, str) and k2.startswith("_"): continue
if isinstance(p, Var):
ps.append(p)
p.name(".".join(stack[1:]+[str(k2)]))

View File

@ -1564,6 +1564,47 @@ class Sequential(Module):
return len(self.layers)
class ParameterList(Module):
def __init__(self, *args):
self.params = collections.OrderedDict()
for var in args:
if isinstance(var, (collections.OrderedDict, dict)):
for k, v in var.items():
self.add_param(k, v)
elif isinstance(var, list):
for v in var:
self.append(v)
else:
self.append(var)
def __getitem__(self, idx):
if idx not in self.params:
return list(self.params.values())[idx]
return self.params[idx]
def __iter__(self):
return self.params.values().__iter__()
def keys(self):
return self.params.keys()
def values(self):
return self.params.values()
def items(self):
return self.params.items()
def execute(self, x):
raise NotImplementedError("Parameters is not executable")
def append(self, var):
assert isinstance(var, jt.Var), f"argument <{type(var)}> is not jittor var"
self.params[len(self.params)] = var
def add_param(self, name, var):
assert isinstance(var, jt.Var), f"argument <{type(var)}> is not jittor var"
self.params[name]=var
def __setitem__(self, name, var):
self.add_param(name, var)
def __len__(self):
return len(self.params)
ParameterDict = ParameterList
def unfold(X, kernel_size, dilation=1, padding=0, stride=1):
assert X.ndim == 4
if not isinstance(kernel_size, tuple):

View File

@ -0,0 +1,31 @@
# ***************************************************************
# Copyright (c) 2021 Jittor. All Rights Reserved.
# Maintainers: Dun Liang <randonlang@gmail.com>.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import jittor as jt
import numpy as np
class TestParamList(unittest.TestCase):
def test_param_list(self):
ps = jt.nn.ParameterList([jt.array([1,2,3]), jt.rand(10)])
assert len(ps.parameters()) == 2
assert list(ps.state_dict().keys()) == ['0', '1'], ps.state_dict().keys()
def test_with_module(self):
class Net(jt.nn.Module):
def __init__(self):
self.ps1 = jt.nn.ParameterList([jt.array([1,2,3]), jt.rand(10)])
self.ps2 = jt.nn.ParameterDict({
"aaa":jt.array([1,2,3]),
"bbb": jt.rand(10)
})
net = Net()
assert net.state_dict().keys() == ['ps1.0', 'ps1.1', 'ps2.aaa', 'ps2.bbb']
if __name__ == "__main__":
unittest.main()