mirror of https://github.com/Jittor/Jittor
add __iter__ in sequential
This commit is contained in:
parent
d72ebb8ea1
commit
ee268c7456
|
@ -1142,6 +1142,14 @@ class Sequential(Module):
|
|||
self.append(mod)
|
||||
def __getitem__(self, idx):
|
||||
return self.layers[idx]
|
||||
def __iter__(self):
|
||||
return self.layers.values().__iter__()
|
||||
def keys(self):
|
||||
return self.layers.keys()
|
||||
def values(self):
|
||||
return self.layers.values()
|
||||
def items(self):
|
||||
return self.layers.items()
|
||||
def execute(self, x):
|
||||
for k, layer in self.layers.items():
|
||||
x = layer(x)
|
||||
|
|
|
@ -149,6 +149,14 @@ jt.mkl_ops.mkl_conv(x, w, 1, 2).sync()
|
|||
da = jt.grad(a.sigmoid(), a)
|
||||
assert np.isnan(da.data).sum()==0, da.data
|
||||
|
||||
def test_sequential(self):
|
||||
x = jt.nn.Sequential(lambda x:x, lambda x:x)
|
||||
n = 0
|
||||
for a in x:
|
||||
n += 1
|
||||
assert n == 2
|
||||
assert list(x.keys()) == [0,1]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in New Issue