mirror of https://github.com/Jittor/Jittor
update load_pytorch.py && add linear
This commit is contained in:
parent
b451a1eb85
commit
8a9f0bb904
|
@ -9,7 +9,7 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
__version__ = '1.3.6.15'
|
||||
__version__ = '1.3.6.16'
|
||||
from jittor_utils import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -636,6 +636,14 @@ class Linear(Module):
|
|||
if self.bias is not None:
|
||||
return x + self.bias
|
||||
return x
|
||||
|
||||
def linear(x, weight, bias=None):
|
||||
''' Returns x * weight^T
|
||||
'''
|
||||
x = matmul_transpose(input, weight)
|
||||
if bias is not None:
|
||||
return x + bias
|
||||
return x
|
||||
|
||||
class BatchNorm(Module):
|
||||
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, is_train=True, sync=True):
|
||||
|
|
|
@ -247,21 +247,28 @@ def load_pytorch(fn_name):
|
|||
f.read(8)
|
||||
if offset is not None:
|
||||
offset = f.tell()
|
||||
for key, params in result.items():
|
||||
requires_grad = params.requires_grad
|
||||
shape = params.size
|
||||
result[key] = jt.array(params.storage)
|
||||
if shape is not None and len(shape) > 0:
|
||||
if len(params.stride) > 1:
|
||||
eval_list = []
|
||||
for idx in range(len(params.stride)):
|
||||
eval_list.append(f"@e0({idx}) * i{idx}")
|
||||
evals = "+".join(eval_list)
|
||||
result[key] = result[key].reindex(params.size, [evals], extras=[jt.array(params.stride)])
|
||||
else:
|
||||
result[key] = result[key].reshape(shape)
|
||||
if requires_grad is not None:
|
||||
result[key].requires_grad = requires_grad
|
||||
|
||||
def dfs_results(result):
|
||||
for key, params in result.items():
|
||||
if isinstance(params, dict):
|
||||
result[key] = dfs_results(params)
|
||||
elif isinstance(params, ArrayWrapper):
|
||||
requires_grad = params.requires_grad
|
||||
shape = params.size
|
||||
result[key] = jt.array(params.storage)
|
||||
if shape is not None and len(shape) > 0:
|
||||
if len(params.stride) > 1:
|
||||
eval_list = []
|
||||
for idx in range(len(params.stride)):
|
||||
eval_list.append(f"@e0({idx}) * i{idx}")
|
||||
evals = "+".join(eval_list)
|
||||
result[key] = result[key].reindex(params.size, [evals], extras=[jt.array(params.stride)])
|
||||
else:
|
||||
result[key] = result[key].reshape(shape)
|
||||
if requires_grad is not None:
|
||||
result[key].requires_grad = requires_grad
|
||||
return result
|
||||
result = dfs_results(result)
|
||||
return result
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Reference in New Issue