Merge pull request #378 from Exusial/npth

Polish loading weights for PyTorch .pth files
Zheng-Ning Liu 2022-10-06 16:50:35 +08:00 committed by GitHub
commit 7827c45047
2 changed files with 37 additions and 8 deletions


@@ -486,8 +486,26 @@ def einsum(string, *args):
 def einsum_outshape(einsum_expr, inputs):
     shps = np_cpu.concatenate([in_.shape for in_ in inputs])
-    p = einsum_expr.split(',')
+    p = einsum_expr.replace(" ", "").split(',')
     s = p[:-1] + p[-1].split('->')
+    rec_shape = []
+    ellip_expr = None
+    const_rep = '1234567890'  # assume tensors have at most 10 dimensions
+    for idx, expr in enumerate(s[:-1]):
+        if "..." in expr:
+            assert "..." in s[-1]
+        else:
+            continue
+        shp = inputs[idx].shape
+        ellipsis_pos = len(expr.replace("...", ""))
+        nellip_expr = const_rep[0 : len(shp) - ellipsis_pos]
+        if ellip_expr is None:
+            ellip_expr = nellip_expr
+        else:
+            assert ellip_expr == nellip_expr, "The ellipsis must cover the same broadcast dimensions in every operand."
+        s[idx] = expr.replace("...", ellip_expr)
+    if ellip_expr:
+        s[-1] = s[-1].replace("...", ellip_expr)
     if s[-1]=='':
         return ()
     else:

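The new block expands "..." into one placeholder character per unnamed broadcast dimension, so the shape-inference code that follows only ever sees explicit subscripts; when several operands use the ellipsis, the expansions must agree, hence the assert. A standalone sketch of the expansion (NumPy only, function name is illustrative):

import numpy as np

def expand_ellipsis(expr, ndim):
    # Replace "..." with one digit per unnamed dimension,
    # mirroring the const_rep = '1234567890' trick above.
    const_rep = '1234567890'
    named = len(expr.replace("...", ""))
    return expr.replace("...", const_rep[:ndim - named])

print(expand_ellipsis("...ij", 4))          # "...ij" on a 4-D operand -> "12ij"

# NumPy resolves the same ellipsis notation directly, for comparison:
a = np.random.rand(2, 3, 4, 5)
print(np.einsum("...ij->...ji", a).shape)   # (2, 3, 5, 4)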

@@ -90,7 +90,8 @@ class StorageType():
 def jittor_rebuild(storage, storage_offset, size, stride, requires_grad, backward_hooks):
     if len(size) == 0:
         return jt.array(storage)
-    return jt.array(storage).reshape(size)
+    record_size = np.prod(size)
+    return jt.array(storage[:record_size]).reshape(size)
 
 def jittor_rebuild_var(data, requires_grad, backward_hooks):
     v = jt.array(data)
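A PyTorch storage can hold more elements than the tensor that views it (several tensors may share one storage, or the view may start at an offset), so reshaping the whole storage can fail; slicing to np.prod(size) elements first keeps the rebuild robust. A minimal NumPy stand-in for the same slicing logic (the values here are made up for illustration):

import numpy as np

storage = np.arange(10, dtype=np.float32)   # the storage holds 10 elements
size = (2, 3)                               # the tensor only views 6 of them

record_size = np.prod(size)                 # 6
tensor = storage[:record_size].reshape(size)
print(tensor.shape)                         # (2, 3)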
@@ -112,19 +113,20 @@ class UnpicklerWrapper(pickle.Unpickler): # type: ignore[name-defined]
         return super().find_class(mod_name, name)
 
 class ArrayWrapper:
-    def __init__(self, storage, size=None, requires_grad=None):
+    def __init__(self, storage, stride=None, size=None, requires_grad=None):
         self.requires_grad = requires_grad
         self.size = size
         self.storage = storage
+        self.stride = stride
 
     def __str__(self):
         return self.storage.__str__()
 
 def jittor_rebuild_direct(storage, storage_offset, size, stride, requires_grad, backward_hooks):
     if len(size) == 0:
-        return ArrayWrapper(storage, size=size)
+        return ArrayWrapper(storage, stride=stride, size=size)
     storage.reshape(size)
-    return ArrayWrapper(storage, size=size)
+    return ArrayWrapper(storage, stride=stride, size=size)
 
 def jittor_rebuild_var_direct(data, requires_grad, backward_hooks):
     v = ArrayWrapper(storage, requires_grad=requires_grad)
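Recording the stride matters because a checkpoint may contain non-contiguous tensors (for example transposed views), whose flat storage cannot simply be reshaped to the reported size. A small illustration of what PyTorch reports for such a view (assuming PyTorch is available; the tensor is made up):

import torch

w = torch.arange(6.).reshape(2, 3).t()   # a transposed, non-contiguous view
print(w.size())                          # torch.Size([3, 2])
print(w.stride())                        # (1, 3)
print(w.is_contiguous())                 # False
# Rebuilding such a tensor from its flat storage needs both size and stride,
# which is why ArrayWrapper now carries stride alongside size.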
@@ -206,7 +208,9 @@ def persistent_load_direct(saved_id):
     raise RuntimeError("Unknown saved id type: %s" % saved_id[0])
 
 def load_pytorch(fn_name):
-    global contents, deserialized_objects
+    global contents, deserialized_objects, loaded_storages
+    loaded_storages = {}
+    deserialized_objects = {}
     if not fn_name.endswith(".pth"):
         print("This function is designed to load pytorch pth format files.")
         return None
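Resetting loaded_storages and deserialized_objects at the top of each call keeps state from one checkpoint from leaking into the next. A hedged usage sketch (the import path and file name are illustrative assumptions, not confirmed by this diff):

import torch
from jittor.utils.load_pytorch import load_pytorch   # import path is an assumption

# Save a tiny PyTorch checkpoint, then load it back as Jittor arrays.
torch.save({"w": torch.randn(3, 4)}, "tiny.pth")
params = load_pytorch("tiny.pth")
print(params["w"].shape)   # expect a jt.Var of shape (3, 4)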
@@ -252,7 +256,14 @@ def load_pytorch(fn_name):
         shape = params.size
         result[key] = jt.array(params.storage)
         if shape is not None and len(shape) > 0:
-            result[key] = result[key].reshape(shape)
+            if len(params.stride) > 1:
+                eval_list = []
+                for idx in range(len(params.stride)):
+                    eval_list.append(f"@e0({idx}) * i{idx}")
+                evals = "+".join(eval_list)
+                result[key] = result[key].reindex(params.size, [evals], extras=[jt.array(params.stride)])
+            else:
+                result[key] = result[key].reshape(shape)
         if requires_grad is not None:
             result[key].requires_grad = requires_grad
     return result
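For multi-dimensional parameters, the reindex call gathers directly from the flat storage: @e0(idx) reads the idx-th stride passed through extras, i{idx} is the output coordinate along dimension idx, and their products sum to the flat element offset, i.e. offset = stride[0]*i0 + stride[1]*i1 + ... A minimal sketch of the same gather for a transposed (3, 2) view, assuming Jittor is installed (the values are made up):

import jittor as jt

# Flat storage of a (2, 3) tensor that was saved as its transposed (3, 2) view:
# size = (3, 2), stride = (1, 3), so element (i0, i1) lives at offset i0*1 + i1*3.
flat = jt.array([0., 1., 2., 3., 4., 5.])
size, stride = [3, 2], [1, 3]

eval_list = [f"@e0({idx}) * i{idx}" for idx in range(len(stride))]
evals = "+".join(eval_list)            # "@e0(0) * i0+@e0(1) * i1"
restored = flat.reindex(size, [evals], extras=[jt.array(stride)])
print(restored.numpy())
# [[0. 3.]
#  [1. 4.]
#  [2. 5.]]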