mirror of https://github.com/Jittor/Jittor
add parameter list and dict
This commit is contained in:
parent
1acf6492f4
commit
698fc6fe88
|
@ -9,7 +9,7 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
__version__ = '1.2.3.37'
|
||||
__version__ = '1.2.3.38'
|
||||
from jittor_utils import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
@ -780,8 +780,11 @@ class Module:
|
|||
stack = []
|
||||
def callback(parents, k, v, n):
|
||||
stack.append(str(k))
|
||||
for k2, p in v.__dict__.items():
|
||||
if k2.startswith("_"): continue
|
||||
dc = v.__dict__
|
||||
if isinstance(v, nn.ParameterList):
|
||||
dc = v.params
|
||||
for k2, p in dc.items():
|
||||
if isinstance(k2, str) and k2.startswith("_"): continue
|
||||
if isinstance(p, Var):
|
||||
ps.append(p)
|
||||
p.name(".".join(stack[1:]+[str(k2)]))
|
||||
|
|
|
@ -1564,6 +1564,47 @@ class Sequential(Module):
|
|||
return len(self.layers)
|
||||
|
||||
|
||||
class ParameterList(Module):
|
||||
def __init__(self, *args):
|
||||
self.params = collections.OrderedDict()
|
||||
for var in args:
|
||||
if isinstance(var, (collections.OrderedDict, dict)):
|
||||
for k, v in var.items():
|
||||
self.add_param(k, v)
|
||||
elif isinstance(var, list):
|
||||
for v in var:
|
||||
self.append(v)
|
||||
else:
|
||||
self.append(var)
|
||||
def __getitem__(self, idx):
|
||||
if idx not in self.params:
|
||||
return list(self.params.values())[idx]
|
||||
|
||||
return self.params[idx]
|
||||
def __iter__(self):
|
||||
return self.params.values().__iter__()
|
||||
def keys(self):
|
||||
return self.params.keys()
|
||||
def values(self):
|
||||
return self.params.values()
|
||||
def items(self):
|
||||
return self.params.items()
|
||||
def execute(self, x):
|
||||
raise NotImplementedError("Parameters is not executable")
|
||||
def append(self, var):
|
||||
assert isinstance(var, jt.Var), f"argument <{type(var)}> is not jittor var"
|
||||
self.params[len(self.params)] = var
|
||||
def add_param(self, name, var):
|
||||
assert isinstance(var, jt.Var), f"argument <{type(var)}> is not jittor var"
|
||||
self.params[name]=var
|
||||
def __setitem__(self, name, var):
|
||||
self.add_param(name, var)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.params)
|
||||
|
||||
ParameterDict = ParameterList
|
||||
|
||||
def unfold(X, kernel_size, dilation=1, padding=0, stride=1):
|
||||
assert X.ndim == 4
|
||||
if not isinstance(kernel_size, tuple):
|
||||
|
|
|
@ -0,0 +1,31 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2021 Jittor. All Rights Reserved.
|
||||
# Maintainers: Dun Liang <randonlang@gmail.com>.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import unittest
|
||||
import jittor as jt
|
||||
import numpy as np
|
||||
|
||||
|
||||
|
||||
class TestParamList(unittest.TestCase):
|
||||
def test_param_list(self):
|
||||
ps = jt.nn.ParameterList([jt.array([1,2,3]), jt.rand(10)])
|
||||
assert len(ps.parameters()) == 2
|
||||
assert list(ps.state_dict().keys()) == ['0', '1'], ps.state_dict().keys()
|
||||
|
||||
def test_with_module(self):
|
||||
class Net(jt.nn.Module):
|
||||
def __init__(self):
|
||||
self.ps1 = jt.nn.ParameterList([jt.array([1,2,3]), jt.rand(10)])
|
||||
self.ps2 = jt.nn.ParameterDict({
|
||||
"aaa":jt.array([1,2,3]),
|
||||
"bbb": jt.rand(10)
|
||||
})
|
||||
net = Net()
|
||||
assert net.state_dict().keys() == ['ps1.0', 'ps1.1', 'ps2.aaa', 'ps2.bbb']
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in New Issue