mirror of https://github.com/Jittor/Jittor
small fix & add test
This commit is contained in:
parent
d94f227272
commit
4ddb2b7053
|
@ -21,6 +21,7 @@ with lock.lock_scope():
|
|||
import contextlib
|
||||
import numpy as np
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Sequence, Mapping
|
||||
import types
|
||||
import pickle
|
||||
import sys
|
||||
|
@ -340,19 +341,24 @@ def detach(x):
|
|||
return x.clone().stop_grad().clone()
|
||||
Var.detach = detach
|
||||
|
||||
def view(x, *shape):
|
||||
if isinstance(shape[0], tuple):
|
||||
origin_reshape = reshape
|
||||
def reshape(x, *shape):
|
||||
if len(shape) == 1 and isinstance(shape[0], Sequence):
|
||||
shape = shape[0]
|
||||
return x.reshape(shape)
|
||||
Var.view = view
|
||||
return origin_reshape(x, shape)
|
||||
reshape.__doc__ = origin_reshape.__doc__
|
||||
Var.view = Var.reshape = view = reshape
|
||||
|
||||
def permute(x, *dim):
|
||||
if isinstance(dim[0], tuple):
|
||||
origin_transpose = transpose
|
||||
def transpose(x, *dim):
|
||||
if len(dim) == 1 and isinstance(dim[0], Sequence):
|
||||
dim = dim[0]
|
||||
return transpose(x, dim)
|
||||
Var.permute = permute
|
||||
return origin_transpose(x, dim)
|
||||
transpose.__doc__ = origin_transpose.__doc__
|
||||
Var.transpose = Var.permute = permute = transpose
|
||||
|
||||
def flatten(input, start_dim=0, end_dim=-1):
|
||||
'''flatten dimentions by reshape'''
|
||||
in_shape = input.shape
|
||||
start_dim = len(in_shape) + start_dim if start_dim < 0 else start_dim
|
||||
end_dim = len(in_shape) + end_dim if end_dim < 0 else end_dim
|
||||
|
@ -668,8 +674,9 @@ def jittor_exit():
|
|||
core.sync_all(True)
|
||||
atexit.register(jittor_exit)
|
||||
|
||||
Var.__repr__ = Var.__str__ = lambda x: str(x.data)
|
||||
Var.peek = lambda x: str(x.dtype)+str(x.shape)
|
||||
Var.__str__ = lambda x: str(x.data)
|
||||
Var.__repr__ = lambda x: f"jt.Var:{x.dtype}{x.uncertain_shape}"
|
||||
Var.peek = lambda x: f"{x.dtype}{x.shape}"
|
||||
|
||||
from . import nn
|
||||
from .nn import matmul
|
||||
|
|
|
@ -60,5 +60,16 @@ class TestReshapeOp(unittest.TestCase):
|
|||
assert node_dict['a'] == node_dict['d']
|
||||
assert node_dict['a'] == node_dict['e']
|
||||
|
||||
def test_view(self):
|
||||
a = jt.ones([2,3,4])
|
||||
assert a.view(2,-1).shape == [2,12]
|
||||
|
||||
def test_flatten(self):
|
||||
a = jt.ones([2,3,4])
|
||||
assert a.flatten().shape == [24]
|
||||
assert a.flatten(1).shape == [2,12]
|
||||
assert a.flatten(0,-2).shape == [6,4]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -61,5 +61,10 @@ class TestTransposeOp(unittest.TestCase):
|
|||
assert ((da-jda.data)<1e-5).all(), (da, jda.data, da-jda.data)
|
||||
assert ((db-jdb.data)<1e-5).all(), (db-jdb.data)
|
||||
|
||||
def test_permute(self):
|
||||
a = jt.ones([2,3,4])
|
||||
assert a.permute().shape == [4,3,2]
|
||||
assert a.permute(0,2,1).shape == [2,4,3]
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -250,13 +250,15 @@ for key in pjmap.keys():
|
|||
if module == 'nn':
|
||||
support_ops[key] = name
|
||||
|
||||
def raise_unspoort(name):
|
||||
raise RuntimeError(f'{a.attr} is not supported in Jittor yet. We will appreciate it if you provide an implementation of {a.attr} and make pull request at https://github.com/Jittor/jittor.')
|
||||
|
||||
def replace(a):
|
||||
if hasattr(a, "attr") and a.attr in unsupport_ops:
|
||||
raise RuntimeError(f'{a.attr} is not supported in Jittor yet. We will appreciate it if you code {a.attr} function and make pull request at https://github.com/Jittor/jittor.')
|
||||
raise_unspoort(a.attr)
|
||||
|
||||
if hasattr(a, "id") and a.id in unsupport_ops:
|
||||
raise RuntimeError(f'{a.id} is not supported in Jittor yet. We will appreciate it if you code {a.id} function and make pull request at https://github.com/Jittor/jittor.')
|
||||
raise_unspoort(a.id)
|
||||
|
||||
if hasattr(a, "attr"):
|
||||
if a.attr in support_ops.keys(): a.attr = support_ops[a.attr]
|
||||
|
@ -419,7 +421,7 @@ def dfs(a):
|
|||
prefix = '.'.join(func[0:-1])
|
||||
func_name = func[-1]
|
||||
if func_name in unsupport_ops:
|
||||
raise RuntimeError(f'{func_name} is not supported in Jittor yet. We will appreciate it if you code {func_name} function and make pull request at https://github.com/Jittor/jittor.')
|
||||
raise_unspoort(func_name)
|
||||
if func_name in pjmap.keys():
|
||||
ags = [astunparse.unparse(ag).strip('\n') for ag in a.args]
|
||||
kws = [astunparse.unparse(kw).strip('\n') for kw in a.keywords]
|
||||
|
@ -456,4 +458,4 @@ def dfs(a):
|
|||
else:
|
||||
ret = dfs(a.__dict__[k])
|
||||
if ret is not None:
|
||||
a.__dict__[k] = ret
|
||||
a.__dict__[k] = ret
|
||||
|
|
Loading…
Reference in New Issue