small fix & add test

This commit is contained in:
Dun Liang 2020-04-20 14:54:44 +08:00
parent d94f227272
commit 4ddb2b7053
4 changed files with 39 additions and 14 deletions

View File

@ -21,6 +21,7 @@ with lock.lock_scope():
import contextlib
import numpy as np
from collections import OrderedDict
from collections.abc import Sequence, Mapping
import types
import pickle
import sys
@ -340,19 +341,24 @@ def detach(x):
return x.clone().stop_grad().clone()
Var.detach = detach
def view(x, *shape):
if isinstance(shape[0], tuple):
origin_reshape = reshape
def reshape(x, *shape):
if len(shape) == 1 and isinstance(shape[0], Sequence):
shape = shape[0]
return x.reshape(shape)
Var.view = view
return origin_reshape(x, shape)
reshape.__doc__ = origin_reshape.__doc__
Var.view = Var.reshape = view = reshape
def permute(x, *dim):
if isinstance(dim[0], tuple):
origin_transpose = transpose
def transpose(x, *dim):
if len(dim) == 1 and isinstance(dim[0], Sequence):
dim = dim[0]
return transpose(x, dim)
Var.permute = permute
return origin_transpose(x, dim)
transpose.__doc__ = origin_transpose.__doc__
Var.transpose = Var.permute = permute = transpose
def flatten(input, start_dim=0, end_dim=-1):
'''flatten dimentions by reshape'''
in_shape = input.shape
start_dim = len(in_shape) + start_dim if start_dim < 0 else start_dim
end_dim = len(in_shape) + end_dim if end_dim < 0 else end_dim
@ -668,8 +674,9 @@ def jittor_exit():
core.sync_all(True)
atexit.register(jittor_exit)
Var.__repr__ = Var.__str__ = lambda x: str(x.data)
Var.peek = lambda x: str(x.dtype)+str(x.shape)
Var.__str__ = lambda x: str(x.data)
Var.__repr__ = lambda x: f"jt.Var:{x.dtype}{x.uncertain_shape}"
Var.peek = lambda x: f"{x.dtype}{x.shape}"
from . import nn
from .nn import matmul

View File

@ -60,5 +60,16 @@ class TestReshapeOp(unittest.TestCase):
assert node_dict['a'] == node_dict['d']
assert node_dict['a'] == node_dict['e']
def test_view(self):
a = jt.ones([2,3,4])
assert a.view(2,-1).shape == [2,12]
def test_flatten(self):
a = jt.ones([2,3,4])
assert a.flatten().shape == [24]
assert a.flatten(1).shape == [2,12]
assert a.flatten(0,-2).shape == [6,4]
if __name__ == "__main__":
unittest.main()

View File

@ -61,5 +61,10 @@ class TestTransposeOp(unittest.TestCase):
assert ((da-jda.data)<1e-5).all(), (da, jda.data, da-jda.data)
assert ((db-jdb.data)<1e-5).all(), (db-jdb.data)
def test_permute(self):
a = jt.ones([2,3,4])
assert a.permute().shape == [4,3,2]
assert a.permute(0,2,1).shape == [2,4,3]
if __name__ == "__main__":
unittest.main()

View File

@ -250,13 +250,15 @@ for key in pjmap.keys():
if module == 'nn':
support_ops[key] = name
def raise_unspoort(name):
raise RuntimeError(f'{a.attr} is not supported in Jittor yet. We will appreciate it if you provide an implementation of {a.attr} and make pull request at https://github.com/Jittor/jittor.')
def replace(a):
if hasattr(a, "attr") and a.attr in unsupport_ops:
raise RuntimeError(f'{a.attr} is not supported in Jittor yet. We will appreciate it if you code {a.attr} function and make pull request at https://github.com/Jittor/jittor.')
raise_unspoort(a.attr)
if hasattr(a, "id") and a.id in unsupport_ops:
raise RuntimeError(f'{a.id} is not supported in Jittor yet. We will appreciate it if you code {a.id} function and make pull request at https://github.com/Jittor/jittor.')
raise_unspoort(a.id)
if hasattr(a, "attr"):
if a.attr in support_ops.keys(): a.attr = support_ops[a.attr]
@ -419,7 +421,7 @@ def dfs(a):
prefix = '.'.join(func[0:-1])
func_name = func[-1]
if func_name in unsupport_ops:
raise RuntimeError(f'{func_name} is not supported in Jittor yet. We will appreciate it if you code {func_name} function and make pull request at https://github.com/Jittor/jittor.')
raise_unspoort(func_name)
if func_name in pjmap.keys():
ags = [astunparse.unparse(ag).strip('\n') for ag in a.args]
kws = [astunparse.unparse(kw).strip('\n') for kw in a.keywords]
@ -456,4 +458,4 @@ def dfs(a):
else:
ret = dfs(a.__dict__[k])
if ret is not None:
a.__dict__[k] = ret
a.__dict__[k] = ret