update load_pytorch.py && add linear

This commit is contained in:
lzhengning 2023-03-27 23:02:37 +08:00 committed by Zheng-Ning Liu
parent b451a1eb85
commit 8a9f0bb904
3 changed files with 31 additions and 16 deletions

View File

@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.3.6.15'
__version__ = '1.3.6.16'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int

View File

@ -636,6 +636,14 @@ class Linear(Module):
if self.bias is not None:
return x + self.bias
return x
def linear(x, weight, bias=None):
    ''' Returns x * weight^T, plus bias when one is given.

    Functional counterpart of the ``Linear`` module's forward computation.

    :param x: input variable, passed as the first operand to
        ``matmul_transpose``.
    :param weight: weight variable, passed as the second operand to
        ``matmul_transpose`` (which multiplies by its transpose).
    :param bias: optional variable added to the product; skipped when None.
    :return: ``matmul_transpose(x, weight)``, with ``bias`` added if provided.
    '''
    # Use the parameter ``x`` here: the name ``input`` is not defined in this
    # function and would resolve to the Python builtin ``input`` function.
    out = matmul_transpose(x, weight)
    if bias is not None:
        return out + bias
    return out
class BatchNorm(Module):
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, is_train=True, sync=True):

View File

@ -247,21 +247,28 @@ def load_pytorch(fn_name):
f.read(8)
if offset is not None:
offset = f.tell()
for key, params in result.items():
requires_grad = params.requires_grad
shape = params.size
result[key] = jt.array(params.storage)
if shape is not None and len(shape) > 0:
if len(params.stride) > 1:
eval_list = []
for idx in range(len(params.stride)):
eval_list.append(f"@e0({idx}) * i{idx}")
evals = "+".join(eval_list)
result[key] = result[key].reindex(params.size, [evals], extras=[jt.array(params.stride)])
else:
result[key] = result[key].reshape(shape)
if requires_grad is not None:
result[key].requires_grad = requires_grad
def dfs_results(result):
    ''' Walk a (possibly nested) dict and replace each ArrayWrapper value
    with a jittor array, in place.

    Nested dicts are processed recursively. Values that are neither a dict
    nor an ArrayWrapper are left untouched.

    :param result: dict whose values may be dicts, ArrayWrapper objects, or
        anything else.
    :return: the same ``result`` dict, after conversion.
    '''
    for name, entry in result.items():
        if isinstance(entry, dict):
            result[name] = dfs_results(entry)
            continue
        if not isinstance(entry, ArrayWrapper):
            continue

        needs_grad = entry.requires_grad
        dims = entry.size
        tensor = jt.array(entry.storage)
        if dims is not None and len(dims) > 0:
            if len(entry.stride) > 1:
                # Build "@e0(0) * i0+@e0(1) * i1+...": one term per axis,
                # with the strides supplied through ``extras``.
                # NOTE(review): presumably @e0 refers to extras[0] -- confirm
                # against the reindex documentation.
                index_expr = "+".join(
                    f"@e0({axis}) * i{axis}"
                    for axis in range(len(entry.stride))
                )
                tensor = tensor.reindex(
                    entry.size, [index_expr], extras=[jt.array(entry.stride)]
                )
            else:
                tensor = tensor.reshape(dims)
        if needs_grad is not None:
            tensor.requires_grad = needs_grad
        result[name] = tensor
    return result
result = dfs_results(result)
return result
if __name__ == "__main__":