mirror of https://github.com/Jittor/Jittor
update pytorch_converter
This commit is contained in:
parent
3c95c6d100
commit
ac052171ce
|
@ -179,6 +179,18 @@ pjmap = {
|
|||
'links': {},
|
||||
'extras': {'affine': 'None'},
|
||||
},
|
||||
'Parameter':{
|
||||
'pytorch': {
|
||||
'args': "data,require_grad=True"
|
||||
},
|
||||
'jittor': {
|
||||
'module': 'jt',
|
||||
'name': 'array',
|
||||
'args': 'data,dtype=None',
|
||||
},
|
||||
'links': {},
|
||||
'extras': {},
|
||||
},
|
||||
'Dropout2d': {
|
||||
'pytorch': {
|
||||
'args': 'p=0.5, inplace=False',
|
||||
|
@ -397,7 +409,7 @@ unsupport_ops = [
|
|||
# ***************************************************************
|
||||
# torch.nn
|
||||
# ***************************************************************
|
||||
'Parameter', 'ModuleDict', 'ParameterList', 'ParameterDict',
|
||||
'ModuleDict', 'ParameterList', 'ParameterDict',
|
||||
'Conv1d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose3d', 'Unfold', 'Fold',
|
||||
'MaxPool1d', 'MaxPool3d', 'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d', 'AvgPool1d',
|
||||
'AvgPool3d', 'FractionalMaxPool2d', 'LPPool1d', 'LPPool2d', 'AdaptiveMaxPool1d',
|
||||
|
@ -427,21 +439,30 @@ for key in pjmap.keys():
|
|||
if module == 'nn':
|
||||
support_ops[key] = name
|
||||
|
||||
def raise_unsupport(name):
|
||||
raise RuntimeError(f'{name} is not supported in Jittor yet. We will appreciate it if you provide an implementation of {name} and make pull request at https://github.com/Jittor/jittor.')
|
||||
|
||||
def raise_unsupport(name,ori_src):
|
||||
ret = f"raise RuntimeError('''original source: <{ori_src.strip()}>, {name} is not supported in Jittor yet. We will appreciate it if you provide an implementation of {name} and make pull request at https://github.com/Jittor/jittor.''')"
|
||||
print(ret+'\n')
|
||||
ret = ast.parse(ret).body[0]
|
||||
return ret
|
||||
|
||||
def replace(a):
|
||||
if hasattr(a, "attr") and a.attr in unsupport_ops:
|
||||
raise_unsupport(a.attr)
|
||||
ori_src = astunparse.unparse(a)
|
||||
print('2')
|
||||
return raise_unsupport(a.attr,ori_src)
|
||||
|
||||
if hasattr(a, "id") and a.id in unsupport_ops:
|
||||
raise_unsupport(a.id)
|
||||
ori_src = astunparse.unparse(a)
|
||||
print('3')
|
||||
return raise_unsupport(a.id)
|
||||
|
||||
if hasattr(a, "attr"):
|
||||
if a.attr in support_ops.keys(): a.attr = support_ops[a.attr]
|
||||
|
||||
if hasattr(a, "id"):
|
||||
if a.id in support_ops.keys(): a.id = support_ops[a.id]
|
||||
|
||||
return None
|
||||
|
||||
import_flag = []
|
||||
def convert(code):
|
||||
|
@ -478,7 +499,7 @@ def convert(code):
|
|||
a.body.insert(2, ast.parse('from jittor import nn').body[0])
|
||||
return astunparse.unparse(a)
|
||||
|
||||
def convert_(prefix, func_name, ags, kws):
|
||||
def convert_(prefix, func_name, ags, kws, ori_src):
|
||||
info = pjmap[func_name]
|
||||
p_prefix = info['pytorch']['prefix'] if 'prefix' in info['pytorch'].keys() else None
|
||||
if p_prefix is not None and prefix in p_prefix:
|
||||
|
@ -525,7 +546,8 @@ def convert_(prefix, func_name, ags, kws):
|
|||
else:
|
||||
pp_ags.append(p_ag)
|
||||
if len(jj_ags) == 0 and len(pp_ags) != 0:
|
||||
raise AttributeError(f"{func_name} in Jittor has no Attribute {pp_ags[0]}")
|
||||
return f"raise AttributeError('''origin source: <{ori_src.strip()}>, {func_name} in Jittor has no Attribute {pp_ags[0]}''')"
|
||||
# raise AttributeError(f"{func_name} in Jittor has no Attribute {pp_ags[0]}")
|
||||
if delete is not None:
|
||||
for d in delete:
|
||||
if d in pp_ags:
|
||||
|
@ -533,7 +555,8 @@ def convert_(prefix, func_name, ags, kws):
|
|||
if d in pp_kws.keys():
|
||||
jj_kws[d] = None
|
||||
if len(pp_ags) > len(ags) + len(kws):
|
||||
raise RuntimeError(f'There are needed {len(pp_ags) + len(list(pp_kws.keys()))} args in Pytorch {func_name} function, but you only provide {len(ags) + len(kws)}')
|
||||
return f"raise RuntimeError('''origin source: <{ori_src.strip()}>, There are needed {len(pp_ags) + len(list(pp_kws.keys()))} args in Pytorch {func_name} function, but you only provide {len(ags) + len(kws)}''')"
|
||||
# raise RuntimeError(f'There are needed {len(pp_ags) + len(list(pp_kws.keys()))} args in Pytorch {func_name} function, but you only provide {len(ags) + len(kws)}')
|
||||
ags_ = []
|
||||
for i in range(len(pp_ags)):
|
||||
if i < len(ags):
|
||||
|
@ -546,7 +569,8 @@ def convert_(prefix, func_name, ags, kws):
|
|||
else:
|
||||
break
|
||||
if len(pp_ags) + len(list(pp_kws.keys())) < len(ags) + len(kws):
|
||||
raise RuntimeError(f'There are only {len(pp_ags) + len(list(pp_kws.keys()))} args in Pytorch {func_name} function, but you provide {len(ags) + len(kws)}')
|
||||
return f"raise RuntimeError('''origin source: <{ori_src.strip()}>,There are only {len(pp_ags) + len(list(pp_kws.keys()))} args in Pytorch {func_name} function, but you provide {len(ags) + len(kws)}''')"
|
||||
# raise RuntimeError(f'There are only {len(pp_ags) + len(list(pp_kws.keys()))} args in Pytorch {func_name} function, but you provide {len(ags) + len(kws)}')
|
||||
j_ags_flag = np.zeros(len(jj_ags))
|
||||
j_ags_values = {}
|
||||
j_kws_values = {}
|
||||
|
@ -558,7 +582,8 @@ def convert_(prefix, func_name, ags, kws):
|
|||
elif i >= len(pp_ags) and (i-len(pp_ags)) <= len(list(pp_kws.keys())):
|
||||
ag_name = list(pp_kws.keys())[i-len(pp_ags)]
|
||||
else:
|
||||
raise RuntimeError(f'The args number is not matc{func_name} in Jittor has no Attribute {ag_name}')
|
||||
return f"raise RuntimeError('''origin source: <{ori_src.strip()}>,The args number is not matc{func_name} in Jittor has no Attribute {ag_name}''')"
|
||||
# raise RuntimeError(f'The args number is not matc{func_name} in Jittor has no Attribute {ag_name}')
|
||||
if ag_name in links.keys():
|
||||
ag_name = links[ag_name]
|
||||
if ag_name in jj_ags:
|
||||
|
@ -567,7 +592,8 @@ def convert_(prefix, func_name, ags, kws):
|
|||
elif ag_name in jj_kws.keys():
|
||||
j_kws_values[ag_name] = ag
|
||||
else:
|
||||
raise AttributeError(f'{func_name} in Jittor has no Attribute {ag_name}')
|
||||
return f"raise AttributeError('''origin source: <{ori_src.strip()}>, {func_name} in Jittor has no Attribute {ag_name}''')"
|
||||
# raise AttributeError(f'{func_name} in Jittor has no Attribute {ag_name}')
|
||||
for i,kw in enumerate(kws):
|
||||
kw_name, kw_value = kw.split('=')
|
||||
if kw_name in links.keys():
|
||||
|
@ -578,14 +604,16 @@ def convert_(prefix, func_name, ags, kws):
|
|||
elif kw_name in jj_kws.keys():
|
||||
j_kws_values[kw_name] = kw_value
|
||||
else:
|
||||
raise AttributeError(f'{func_name} in Jittor has no Attribute {kw_name}')
|
||||
return f"raise AttributeError('''origin source: <{ori_src.strip()}>, {func_name} in Jittor has no Attribute {kw_name}''')"
|
||||
# raise AttributeError(f'{func_name} in Jittor has no Attribute {kw_name}')
|
||||
len_jj_ags = len(jj_ags) if len(jj_ags) == 0 or jj_ags[0] != '' else 0
|
||||
if j_ags_flag.sum() < len_jj_ags:
|
||||
missing_args = []
|
||||
for i in range(len(jj_ags)):
|
||||
if j_ags_flag[i] == 0:
|
||||
missing_args.append(jj_ags[i])
|
||||
raise AttributeError(f"the needed args of {func_name} in Jittor is {', '.join(jj_ags)}, so you need to give value of {', '.join(missing_args)}.")
|
||||
return f"raise AttributeError('''origin source: <{ori_src.strip()}>, the needed args of {func_name} in Jittor is {', '.join(jj_ags)}, so you need to give value of {', '.join(missing_args)}.''')"
|
||||
# raise AttributeError(f"the needed args of {func_name} in Jittor is {', '.join(jj_ags)}, so you need to give value of {', '.join(missing_args)}.")
|
||||
if extras:
|
||||
for k in extras.keys():
|
||||
if k in jj_ags:
|
||||
|
@ -593,7 +621,8 @@ def convert_(prefix, func_name, ags, kws):
|
|||
elif k in jj_kws.keys():
|
||||
j_kws_values[k] = extras[k]
|
||||
else:
|
||||
raise AttributeError(f"there is not attribute named {k} in Jittor {func_name}, you should delete it in {func_name} extras.")
|
||||
return f"raise AttributeError('''origin source: <{ori_src.strip()}>, there is not attribute named {k} in Jittor {func_name}, you should delete it in {func_name} extras.''')"
|
||||
# raise AttributeError(f"there is not attribute named {k} in Jittor {func_name}, you should delete it in {func_name} extras.")
|
||||
if delete is not None:
|
||||
for d in delete:
|
||||
if d in j_ags_values:
|
||||
|
@ -634,16 +663,24 @@ def dfs(a):
|
|||
ret = dfs(kw)
|
||||
if ret is not None:
|
||||
a.keywords[idx] = ret
|
||||
ori_src = astunparse.unparse(a)
|
||||
func = astunparse.unparse(a.func).strip('\n').split('.')
|
||||
prefix = '.'.join(func[0:-1])
|
||||
func_name = func[-1]
|
||||
if func_name in unsupport_ops:
|
||||
raise_unsupport(func_name)
|
||||
ret = raise_unsupport(func_name, ori_src)
|
||||
return ret
|
||||
if func_name in pjmap.keys():
|
||||
ags = [astunparse.unparse(ag).strip('\n') for ag in a.args]
|
||||
kws = [astunparse.unparse(kw).strip('\n') for kw in a.keywords]
|
||||
ret = convert_(prefix, func_name, ags, kws)
|
||||
return ast.parse(ret).body[0].value
|
||||
ret = convert_(prefix, func_name, ags, kws, ori_src)
|
||||
ret_tmp = ret
|
||||
ret = ast.parse(ret).body[0]
|
||||
if hasattr(ret,'value'):
|
||||
return ret.value
|
||||
else:
|
||||
print(ret_tmp+'\n')
|
||||
return ret
|
||||
if ".load_state_dict" in astunparse.unparse(a.func):
|
||||
a.func.attr = 'load_parameters'
|
||||
if astunparse.unparse(a.func).strip('\n').endswith(".size"):
|
||||
|
@ -654,7 +691,11 @@ def dfs(a):
|
|||
con = astunparse.unparse(a.func).replace('size', 'shape')
|
||||
return ast.parse(con).body[0].value
|
||||
elif isinstance(a, ast.Expr): pass
|
||||
elif isinstance(a, ast.Attribute) or isinstance(a, ast.Name): replace(a)
|
||||
elif isinstance(a, ast.Attribute) or isinstance(a, ast.Name):
|
||||
ret = replace(a)
|
||||
if ret is not None:
|
||||
print(ret)
|
||||
return ret
|
||||
elif isinstance(a, ast.FunctionDef):
|
||||
if a.name == 'forward': a.name = 'execute'
|
||||
if hasattr(a, '__dict__'):
|
||||
|
|
Loading…
Reference in New Issue