mirror of https://github.com/Jittor/Jittor
Merge branch 'master' of github.com:Jittor/jittor
This commit is contained in:
commit
c12549020f
|
@ -486,8 +486,26 @@ def einsum(string, *args):
|
|||
|
||||
def einsum_outshape(einsum_expr, inputs):
|
||||
shps = np_cpu.concatenate([in_.shape for in_ in inputs])
|
||||
p = einsum_expr.split(',')
|
||||
p = einsum_expr.replace(" ", "").split(',')
|
||||
s = p[:-1] + p[-1].split('->')
|
||||
rec_shape = []
|
||||
ellip_expr = None
|
||||
const_rep = '1234567890' # assume tensor shape no more than 10 dimensions
|
||||
for idx, expr in enumerate(s[:-1]):
|
||||
if "..." in expr:
|
||||
assert "..." in s[-1]
|
||||
else:
|
||||
continue
|
||||
shp = inputs[idx].shape
|
||||
ellipsis_pos = len(expr.replace("...", ""))
|
||||
nellip_expr = const_rep[0 : len(shp) - ellipsis_pos]
|
||||
if ellip_expr is None:
|
||||
ellip_expr = nellip_expr
|
||||
else:
|
||||
assert ellip_expr == nellip_expr, "Please keep broadcast ellipsis record the same ellipsis."
|
||||
s[idx] = expr.replace("...", ellip_expr)
|
||||
if ellip_expr:
|
||||
s[-1] = s[-1].replace("...", ellip_expr)
|
||||
if s[-1]=='':
|
||||
return ()
|
||||
else:
|
||||
|
|
|
@ -90,7 +90,8 @@ class StorageType():
|
|||
def jittor_rebuild(storage, storage_offset, size, stride, requires_grad, backward_hooks):
|
||||
if len(size) == 0:
|
||||
return jt.array(storage)
|
||||
return jt.array(storage).reshape(size)
|
||||
record_size = np.prod(size)
|
||||
return jt.array(storage[:record_size]).reshape(size)
|
||||
|
||||
def jittor_rebuild_var(data, requires_grad, backward_hooks):
|
||||
v = jt.array(data)
|
||||
|
@ -112,19 +113,20 @@ class UnpicklerWrapper(pickle.Unpickler): # type: ignore[name-defined]
|
|||
return super().find_class(mod_name, name)
|
||||
|
||||
class ArrayWrapper:
|
||||
def __init__(self, storage, size=None, requires_grad=None):
|
||||
def __init__(self, storage, stride=None, size=None, requires_grad=None):
|
||||
self.requires_grad = requires_grad
|
||||
self.size = size
|
||||
self.storage = storage
|
||||
|
||||
self.stride = stride
|
||||
|
||||
def __str__(self):
|
||||
return self.storage.__str__()
|
||||
|
||||
def jittor_rebuild_direct(storage, storage_offset, size, stride, requires_grad, backward_hooks):
|
||||
if len(size) == 0:
|
||||
return ArrayWrapper(storage, size=size)
|
||||
return ArrayWrapper(storage, stride=stride, size=size)
|
||||
storage.reshape(size)
|
||||
return ArrayWrapper(storage, size=size)
|
||||
return ArrayWrapper(storage, stride=stride, size=size)
|
||||
|
||||
def jittor_rebuild_var_direct(data, requires_grad, backward_hooks):
|
||||
v = ArrayWrapper(storage, requires_grad=requires_grad)
|
||||
|
@ -206,7 +208,9 @@ def persistent_load_direct(saved_id):
|
|||
raise RuntimeError("Unknown saved id type: %s" % saved_id[0])
|
||||
|
||||
def load_pytorch(fn_name):
|
||||
global contents, deserialized_objects
|
||||
global contents, deserialized_objects, loaded_storages
|
||||
loaded_storages = {}
|
||||
deserialized_objects = {}
|
||||
if not fn_name.endswith(".pth"):
|
||||
print("This function is designed to load pytorch pth format files.")
|
||||
return None
|
||||
|
@ -252,7 +256,14 @@ def load_pytorch(fn_name):
|
|||
shape = params.size
|
||||
result[key] = jt.array(params.storage)
|
||||
if shape is not None and len(shape) > 0:
|
||||
result[key] = result[key].reshape(shape)
|
||||
if len(params.stride) > 1:
|
||||
eval_list = []
|
||||
for idx in range(len(params.stride)):
|
||||
eval_list.append(f"@e0({idx}) * i{idx}")
|
||||
evals = "+".join(eval_list)
|
||||
result[key] = result[key].reindex(params.size, [evals], extras=[jt.array(params.stride)])
|
||||
else:
|
||||
result[key] = result[key].reshape(shape)
|
||||
if requires_grad is not None:
|
||||
result[key].requires_grad = requires_grad
|
||||
return result
|
||||
|
|
Loading…
Reference in New Issue