mirror of https://github.com/Jittor/Jittor
pytorch converter server
This commit is contained in:
parent
c8252ac7fb
commit
3c55f6ecff
|
@ -7,7 +7,7 @@
|
|||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
__version__ = '1.2.2.3'
|
||||
__version__ = '1.2.2.4'
|
||||
from . import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -0,0 +1,27 @@
|
|||
from flask import Flask
|
||||
from flask import request
|
||||
from flask import jsonify
|
||||
app = Flask(__name__)
|
||||
import json
|
||||
|
||||
from jittor.utils.pytorch_converter import convert
|
||||
|
||||
@app.route('/', methods=["GET", "POST"])
|
||||
def hello():
|
||||
msg = request
|
||||
data = msg.data.decode("utf-8")
|
||||
try:
|
||||
data = json.loads(data)
|
||||
src = data["src"]
|
||||
pjmap = json.loads(data["pjmap"])
|
||||
jt_src = convert(src, pjmap)
|
||||
except Exception as e:
|
||||
jt_src = str(e)
|
||||
response = jsonify(jt_src=jt_src)
|
||||
|
||||
# Enable Access-Control-Allow-Origin
|
||||
response.headers.add("Access-Control-Allow-Origin", "*")
|
||||
return response
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(host="0.0.0.0")
|
|
@ -363,6 +363,32 @@ pjmap = {
|
|||
}
|
||||
}
|
||||
|
||||
unsupport_ops = [
|
||||
# ***************************************************************
|
||||
# torch.nn
|
||||
# ***************************************************************
|
||||
'ModuleDict', 'ParameterList', 'ParameterDict',
|
||||
'Conv1d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose3d', 'Unfold', 'Fold',
|
||||
'MaxPool1d', 'MaxPool3d', 'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d', 'AvgPool1d',
|
||||
'AvgPool3d', 'FractionalMaxPool2d', 'LPPool1d', 'LPPool2d', 'AdaptiveMaxPool1d',
|
||||
'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d', 'AdaptiveAvgPool1d', 'AdaptiveAvgPool3d',
|
||||
'ReflectionPad1d', 'ReplicationPad1d', 'ReplicationPad3d', 'ConstantPad1d', 'ConstantPad3d',
|
||||
'ELU', 'Hardshrink', 'Hardtanh', 'LogSigmoid', 'MultiheadAttention',
|
||||
'RReLU', 'SELU', 'CELU', 'GELU', 'Softshrink', 'Softsign', 'Tanhshrink',
|
||||
'Threshold', 'Softmin', 'Softmax2d', 'LogSoftmax', 'AdaptiveLogSoftmaxWithLoss',
|
||||
'BatchNorm3d', 'SyncBatchNorm', 'InstanceNorm1d', 'InstanceNorm3d', 'LocalResponseNorm',
|
||||
'RNNBase', 'RNN', 'LSTM', 'GRU', 'RNNCell', 'LSTMCell', 'GRUCell', 'Transformer', 'TransformerEncoder',
|
||||
'TransformerDecoder', 'TransformerEncoderLayer', 'TransformerDecoderLayer', 'Identity', 'Bilinear',
|
||||
'Dropout3d', 'AlphaDropout', 'EmbeddingBag', 'CosineSimilarity', 'PairwiseDistance', 'CTCLoss', 'NLLLoss', 'PoissonNLLLoss', 'KLDivLoss', 'BCEWithLogitsLoss',
|
||||
'MarginRankingLoss', 'HingeEmbeddingLoss', 'MultiLabelMarginLoss', 'SmoothL1Loss', 'SoftMarginLoss',
|
||||
'MultiLabelSoftMarginLoss', 'CosineEmbeddingLoss', 'MultiMarginLoss', 'TripletMarginLoss', 'UpsamplingNearest2d',
|
||||
'UpsamplingBilinear2d', 'DataParallel', 'DistributedDataParallel', 'clip_grad_norm_', 'clip_grad_value_',
|
||||
'parameters_to_vector', 'vector_to_parameters', 'BasePruningMethod', 'PruningContainer', 'Identity',
|
||||
'RandomUnstructured', 'L1Unstructured', 'RandomStructured', 'LnStructured', 'CustomFromMask', 'identity',
|
||||
'random_unstructured', 'l1_unstructured', 'random_structured', 'ln_structured', 'global_unstructured',
|
||||
'custom_from_mask', 'remove', 'is_pruned', 'weight_norm', 'remove_weight_norm', 'spectral_norm',
|
||||
'remove_spectral_norm', 'PackedSequence', 'pack_padded_sequence', 'pad_packed_sequence', 'pad_sequence', 'pack_sequence'
|
||||
]
|
||||
|
||||
def pjmap_append(pytorch_func_name, pytorch_args, jittor_func_module, jittor_func_name, jittor_args, extras=None, links=None, delete=None):
|
||||
''' adding map to pjmap for converting new function, example: convert AvgPool2d to Pool
|
||||
|
@ -405,67 +431,268 @@ def pjmap_append(pytorch_func_name, pytorch_args, jittor_func_module, jittor_fun
|
|||
'delete': delete,
|
||||
}
|
||||
|
||||
unsupport_ops = [
|
||||
# ***************************************************************
|
||||
# torch.nn
|
||||
# ***************************************************************
|
||||
'ModuleDict', 'ParameterList', 'ParameterDict',
|
||||
'Conv1d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose3d', 'Unfold', 'Fold',
|
||||
'MaxPool1d', 'MaxPool3d', 'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d', 'AvgPool1d',
|
||||
'AvgPool3d', 'FractionalMaxPool2d', 'LPPool1d', 'LPPool2d', 'AdaptiveMaxPool1d',
|
||||
'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d', 'AdaptiveAvgPool1d', 'AdaptiveAvgPool3d',
|
||||
'ReflectionPad1d', 'ReplicationPad1d', 'ReplicationPad3d', 'ConstantPad1d', 'ConstantPad3d',
|
||||
'ELU', 'Hardshrink', 'Hardtanh', 'LogSigmoid', 'MultiheadAttention',
|
||||
'RReLU', 'SELU', 'CELU', 'GELU', 'Softshrink', 'Softsign', 'Tanhshrink',
|
||||
'Threshold', 'Softmin', 'Softmax2d', 'LogSoftmax', 'AdaptiveLogSoftmaxWithLoss',
|
||||
'BatchNorm3d', 'SyncBatchNorm', 'InstanceNorm1d', 'InstanceNorm3d', 'LocalResponseNorm',
|
||||
'RNNBase', 'RNN', 'LSTM', 'GRU', 'RNNCell', 'LSTMCell', 'GRUCell', 'Transformer', 'TransformerEncoder',
|
||||
'TransformerDecoder', 'TransformerEncoderLayer', 'TransformerDecoderLayer', 'Identity', 'Bilinear',
|
||||
'Dropout3d', 'AlphaDropout', 'EmbeddingBag', 'CosineSimilarity', 'PairwiseDistance', 'CTCLoss', 'NLLLoss', 'PoissonNLLLoss', 'KLDivLoss', 'BCEWithLogitsLoss',
|
||||
'MarginRankingLoss', 'HingeEmbeddingLoss', 'MultiLabelMarginLoss', 'SmoothL1Loss', 'SoftMarginLoss',
|
||||
'MultiLabelSoftMarginLoss', 'CosineEmbeddingLoss', 'MultiMarginLoss', 'TripletMarginLoss', 'UpsamplingNearest2d',
|
||||
'UpsamplingBilinear2d', 'DataParallel', 'DistributedDataParallel', 'clip_grad_norm_', 'clip_grad_value_',
|
||||
'parameters_to_vector', 'vector_to_parameters', 'BasePruningMethod', 'PruningContainer', 'Identity',
|
||||
'RandomUnstructured', 'L1Unstructured', 'RandomStructured', 'LnStructured', 'CustomFromMask', 'identity',
|
||||
'random_unstructured', 'l1_unstructured', 'random_structured', 'ln_structured', 'global_unstructured',
|
||||
'custom_from_mask', 'remove', 'is_pruned', 'weight_norm', 'remove_weight_norm', 'spectral_norm',
|
||||
'remove_spectral_norm', 'PackedSequence', 'pack_padded_sequence', 'pad_packed_sequence', 'pad_sequence', 'pack_sequence'
|
||||
]
|
||||
|
||||
support_ops = {}
|
||||
for key in pjmap.keys():
|
||||
module = pjmap[key]['jittor']['module']
|
||||
name = pjmap[key]['jittor']['name']
|
||||
if module == 'nn':
|
||||
support_ops[key] = name
|
||||
|
||||
def raise_unsupport(name,ori_src):
|
||||
def raise_unsupport(name, ori_src):
|
||||
ret = f"raise RuntimeError('''original source: <{ori_src.strip()}>, {name} is not supported in Jittor yet. We will appreciate it if you provide an implementation of {name} and make pull request at https://github.com/Jittor/jittor.''')"
|
||||
print(ret+'\n')
|
||||
ret = ast.parse(ret).body[0]
|
||||
return ret
|
||||
|
||||
def replace(a):
|
||||
if hasattr(a, "attr") and a.attr in unsupport_ops:
|
||||
ori_src = astunparse.unparse(a)
|
||||
print('2')
|
||||
return raise_unsupport(a.attr,ori_src)
|
||||
|
||||
if hasattr(a, "id") and a.id in unsupport_ops:
|
||||
ori_src = astunparse.unparse(a)
|
||||
print('3')
|
||||
return raise_unsupport(a.id)
|
||||
|
||||
if hasattr(a, "attr"):
|
||||
if a.attr in support_ops.keys(): a.attr = support_ops[a.attr]
|
||||
|
||||
if hasattr(a, "id"):
|
||||
if a.id in support_ops.keys(): a.id = support_ops[a.id]
|
||||
|
||||
return None
|
||||
class Converter:
|
||||
def __init__(self, ex_pjmap):
|
||||
import copy
|
||||
self.pjmap = copy.deepcopy(pjmap)
|
||||
if ex_pjmap:
|
||||
self.pjmap.update(ex_pjmap)
|
||||
self.unsupport_ops = set(unsupport_ops)
|
||||
support_ops = {}
|
||||
for key in self.pjmap.keys():
|
||||
module = self.pjmap[key]['jittor']['module']
|
||||
name = self.pjmap[key]['jittor']['name']
|
||||
if module == 'nn':
|
||||
support_ops[key] = name
|
||||
if key in self.unsupport_ops:
|
||||
self.unsupport_ops.remove(key)
|
||||
self.support_ops = support_ops
|
||||
self.import_flag = []
|
||||
|
||||
import_flag = []
|
||||
def convert(code):
|
||||
def replace(self, a):
|
||||
if hasattr(a, "attr") and a.attr in self.unsupport_ops:
|
||||
ori_src = astunparse.unparse(a)
|
||||
return raise_unsupport(a.attr, ori_src)
|
||||
|
||||
if hasattr(a, "id") and a.id in self.unsupport_ops:
|
||||
ori_src = astunparse.unparse(a)
|
||||
return raise_unsupport(a.id, ori_src)
|
||||
|
||||
if hasattr(a, "attr"):
|
||||
if a.attr in self.support_ops.keys(): a.attr = self.support_ops[a.attr]
|
||||
|
||||
if hasattr(a, "id"):
|
||||
if a.id in self.support_ops.keys(): a.id = self.support_ops[a.id]
|
||||
|
||||
return None
|
||||
|
||||
def convert_(self, prefix, func_name, ags, kws, ori_src):
|
||||
info = self.pjmap[func_name]
|
||||
p_prefix = info['pytorch']['prefix'] if 'prefix' in info['pytorch'].keys() else None
|
||||
if p_prefix is not None and prefix in p_prefix:
|
||||
p_ags = info['pytorch']['args_prefix']
|
||||
j_ags = info['jittor']['args_prefix']
|
||||
else:
|
||||
p_ags = info['pytorch']['args']
|
||||
j_ags = info['jittor']['args']
|
||||
if 'delete' in info.keys():
|
||||
delete = info['delete']
|
||||
else:
|
||||
delete = None
|
||||
j_prefix = info['jittor']['prefix'] if 'prefix' in info['jittor'].keys() else None
|
||||
j_module = info['jittor']['module']
|
||||
j_name = info['jittor']['name']
|
||||
links = info['links']
|
||||
extras = info['extras']
|
||||
jj_ags = []
|
||||
jj_kws = {}
|
||||
pp_ags = []
|
||||
pp_kws = {}
|
||||
if j_ags == '' and p_ags == '':
|
||||
# no args in Pytorch and Jittor.
|
||||
if p_prefix is None:
|
||||
return f"{j_module}.{j_name}()"
|
||||
else:
|
||||
if prefix in p_prefix:
|
||||
return f"{j_prefix}.{j_name}()"
|
||||
else:
|
||||
return f"{prefix}.{j_name}()"
|
||||
else:
|
||||
j_ags = j_ags.replace(' ','').split(',')
|
||||
for j_ag in j_ags:
|
||||
if '=' in j_ag:
|
||||
k,v = j_ag.split('=')
|
||||
jj_kws[k] = v
|
||||
else:
|
||||
jj_ags.append(j_ag)
|
||||
p_ags = p_ags.replace(' ','').split(',')
|
||||
for p_ag in p_ags:
|
||||
if '=' in p_ag:
|
||||
k,v = p_ag.split('=')
|
||||
pp_kws[k] = v
|
||||
else:
|
||||
pp_ags.append(p_ag)
|
||||
if len(jj_ags) == 0 and len(pp_ags) != 0:
|
||||
return f"raise AttributeError('''origin source: <{ori_src.strip()}>, {func_name} in Jittor has no Attribute {pp_ags[0]}''')"
|
||||
# raise AttributeError(f"{func_name} in Jittor has no Attribute {pp_ags[0]}")
|
||||
if delete is not None:
|
||||
for d in delete:
|
||||
if d in pp_ags:
|
||||
jj_ags.append(d)
|
||||
if d in pp_kws.keys():
|
||||
jj_kws[d] = None
|
||||
if len(pp_ags) > len(ags) + len(kws):
|
||||
return f"raise RuntimeError('''origin source: <{ori_src.strip()}>, There are needed {len(pp_ags) + len(list(pp_kws.keys()))} args in Pytorch {func_name} function, but you only provide {len(ags) + len(kws)}''')"
|
||||
# raise RuntimeError(f'There are needed {len(pp_ags) + len(list(pp_kws.keys()))} args in Pytorch {func_name} function, but you only provide {len(ags) + len(kws)}')
|
||||
ags_ = []
|
||||
for i in range(len(pp_ags)):
|
||||
if i < len(ags):
|
||||
if '*' in pp_ags[i]:
|
||||
ags_.append('(' + ', '.join(ags[i:]) + ')')
|
||||
ags = ags_
|
||||
break
|
||||
else:
|
||||
ags_.append(ags[i])
|
||||
else:
|
||||
break
|
||||
if len(pp_ags) + len(list(pp_kws.keys())) < len(ags) + len(kws):
|
||||
return f"raise RuntimeError('''origin source: <{ori_src.strip()}>,There are only {len(pp_ags) + len(list(pp_kws.keys()))} args in Pytorch {func_name} function, but you provide {len(ags) + len(kws)}''')"
|
||||
# raise RuntimeError(f'There are only {len(pp_ags) + len(list(pp_kws.keys()))} args in Pytorch {func_name} function, but you provide {len(ags) + len(kws)}')
|
||||
j_ags_flag = np.zeros(len(jj_ags))
|
||||
j_ags_values = {}
|
||||
j_kws_values = {}
|
||||
for i,ag in enumerate(ags):
|
||||
if len(pp_ags) == 0:
|
||||
ag_name = list(pp_kws.keys())[i]
|
||||
elif i < len(pp_ags):
|
||||
ag_name = pp_ags[i]
|
||||
elif i >= len(pp_ags) and (i-len(pp_ags)) <= len(list(pp_kws.keys())):
|
||||
ag_name = list(pp_kws.keys())[i-len(pp_ags)]
|
||||
else:
|
||||
return f"raise RuntimeError('''origin source: <{ori_src.strip()}>,The args number is not matc{func_name} in Jittor has no Attribute {ag_name}''')"
|
||||
# raise RuntimeError(f'The args number is not matc{func_name} in Jittor has no Attribute {ag_name}')
|
||||
if ag_name in links.keys():
|
||||
ag_name = links[ag_name]
|
||||
if ag_name in jj_ags:
|
||||
j_ags_flag[jj_ags.index(ag_name)] = 1
|
||||
j_ags_values[str(jj_ags.index(ag_name))] = ag
|
||||
elif ag_name in jj_kws.keys():
|
||||
j_kws_values[ag_name] = ag
|
||||
else:
|
||||
return f"raise AttributeError('''origin source: <{ori_src.strip()}>, {func_name} in Jittor has no Attribute {ag_name}''')"
|
||||
# raise AttributeError(f'{func_name} in Jittor has no Attribute {ag_name}')
|
||||
for i,kw in enumerate(kws):
|
||||
kw_name, kw_value = kw.split('=')
|
||||
if kw_name in links.keys():
|
||||
kw_name = links[kw_name]
|
||||
if kw_name in jj_ags:
|
||||
j_ags_flag[jj_ags.index(kw_name)] = 1
|
||||
j_ags_values[str(jj_ags.index(kw_name))] = kw_value
|
||||
elif kw_name in jj_kws.keys():
|
||||
j_kws_values[kw_name] = kw_value
|
||||
else:
|
||||
return f"raise AttributeError('''origin source: <{ori_src.strip()}>, {func_name} in Jittor has no Attribute {kw_name}''')"
|
||||
# raise AttributeError(f'{func_name} in Jittor has no Attribute {kw_name}')
|
||||
len_jj_ags = len(jj_ags) if len(jj_ags) == 0 or jj_ags[0] != '' else 0
|
||||
if j_ags_flag.sum() < len_jj_ags:
|
||||
missing_args = []
|
||||
for i in range(len(jj_ags)):
|
||||
if j_ags_flag[i] == 0:
|
||||
missing_args.append(jj_ags[i])
|
||||
return f"raise AttributeError('''origin source: <{ori_src.strip()}>, the needed args of {func_name} in Jittor is {', '.join(jj_ags)}, so you need to give value of {', '.join(missing_args)}.''')"
|
||||
# raise AttributeError(f"the needed args of {func_name} in Jittor is {', '.join(jj_ags)}, so you need to give value of {', '.join(missing_args)}.")
|
||||
if extras:
|
||||
for k in extras.keys():
|
||||
if k in jj_ags:
|
||||
j_ags_values[str(jj_ags.index(k))] = extras[k]
|
||||
elif k in jj_kws.keys():
|
||||
j_kws_values[k] = extras[k]
|
||||
else:
|
||||
return f"raise AttributeError('''origin source: <{ori_src.strip()}>, there is not attribute named {k} in Jittor {func_name}, you should delete it in {func_name} extras.''')"
|
||||
# raise AttributeError(f"there is not attribute named {k} in Jittor {func_name}, you should delete it in {func_name} extras.")
|
||||
if delete is not None:
|
||||
for d in delete:
|
||||
if d in j_ags_values:
|
||||
del j_ags_values[d]
|
||||
if d in j_kws_values.keys():
|
||||
j_kws_values.pop(d)
|
||||
j_ags_ = [j_ags_values[str(i)] for i in range(len(list(j_ags_values.keys())))]
|
||||
j_kws_ = [key + "=" + j_kws_values[key] for key in j_kws_values.keys()]
|
||||
j_func = f"{j_module}.{j_name}({', '.join(j_ags_+j_kws_)})"
|
||||
if p_prefix is None:
|
||||
return f"{j_module}.{j_name}({', '.join(j_ags_+j_kws_)})"
|
||||
else:
|
||||
if prefix in p_prefix:
|
||||
return f"{j_prefix}.{j_name}({', '.join(j_ags_+j_kws_)})"
|
||||
else:
|
||||
return f"{prefix}.{j_name}({', '.join(j_ags_+j_kws_)})"
|
||||
return j_func
|
||||
|
||||
def dfs(self, a):
|
||||
if isinstance(a, ast.Import):
|
||||
if 'torch' in astunparse.unparse(a) and 'init' in astunparse.unparse(a):
|
||||
self.import_flag.append('init')
|
||||
return ast.parse('from jittor import init').body[0]
|
||||
if 'torch' in astunparse.unparse(a) and a.names[0].asname == 'nn':
|
||||
self.import_flag.append('nn')
|
||||
return ast.parse('from jittor import nn').body[0]
|
||||
if 'torch' in a.names[0].name:
|
||||
return 'delete'
|
||||
elif isinstance(a, ast.ImportFrom):
|
||||
if 'torch' in a.module:
|
||||
return 'delete'
|
||||
elif isinstance(a, ast.Call):
|
||||
for idx, ag in enumerate(a.args):
|
||||
ret = self.dfs(ag)
|
||||
if ret is not None:
|
||||
a.args[idx] = ret
|
||||
for idx, kw in enumerate(a.keywords):
|
||||
ret = self.dfs(kw)
|
||||
if ret is not None:
|
||||
a.keywords[idx] = ret
|
||||
ori_src = astunparse.unparse(a)
|
||||
func = astunparse.unparse(a.func).strip('\n').split('.')
|
||||
prefix = '.'.join(func[0:-1])
|
||||
func_name = func[-1]
|
||||
if func_name in self.unsupport_ops:
|
||||
ret = raise_unsupport(func_name, ori_src)
|
||||
return ret
|
||||
if func_name in self.pjmap:
|
||||
ags = [astunparse.unparse(ag).strip('\n') for ag in a.args]
|
||||
kws = [astunparse.unparse(kw).strip('\n') for kw in a.keywords]
|
||||
ret = self.convert_(prefix, func_name, ags, kws, ori_src)
|
||||
ret_tmp = ret
|
||||
ret = ast.parse(ret).body[0]
|
||||
if hasattr(ret,'value'):
|
||||
return ret.value
|
||||
else:
|
||||
print(ret_tmp+'\n')
|
||||
return ret
|
||||
if ".load_state_dict" in astunparse.unparse(a.func):
|
||||
a.func.attr = 'load_parameters'
|
||||
if astunparse.unparse(a.func).strip('\n').endswith(".size"):
|
||||
ags = [astunparse.unparse(ag).strip('\n') for ag in a.args]
|
||||
if len(ags) != 0:
|
||||
con = astunparse.unparse(a.func).split('.size')[0] + '.shape[' + ','.join(ags) + ']'
|
||||
else:
|
||||
con = astunparse.unparse(a.func).replace('size', 'shape')
|
||||
return ast.parse(con).body[0].value
|
||||
elif isinstance(a, ast.Expr): pass
|
||||
elif isinstance(a, ast.Attribute) or isinstance(a, ast.Name):
|
||||
ret = self.replace(a)
|
||||
if ret is not None:
|
||||
print(ret)
|
||||
return ret
|
||||
elif isinstance(a, ast.FunctionDef):
|
||||
if a.name == 'forward': a.name = 'execute'
|
||||
if hasattr(a, '__dict__'):
|
||||
for k in a.__dict__.keys():
|
||||
if isinstance(a.__dict__[k], list):
|
||||
delete_flag = []
|
||||
for i,a_ in enumerate(a.__dict__[k]):
|
||||
ret = self.dfs(a_)
|
||||
if ret == 'delete':
|
||||
delete_flag.append(True)
|
||||
continue
|
||||
if ret is not None:
|
||||
a.__dict__[k][i] = ret
|
||||
delete_flag.append(False)
|
||||
tmp = [a_ for i,a_ in enumerate(a.__dict__[k]) if delete_flag[i] == False]
|
||||
a.__dict__[k] = tmp
|
||||
else:
|
||||
ret = self.dfs(a.__dict__[k])
|
||||
if ret is not None:
|
||||
a.__dict__[k] = ret
|
||||
|
||||
|
||||
def convert(code, ex_pjmaps=None):
|
||||
''' Model code converter, example:
|
||||
|
||||
from jittor.utils.pytorch_converter import convert
|
||||
|
@ -490,229 +717,13 @@ def convert(code):
|
|||
model = Model()
|
||||
print("## Jittor model:", model)
|
||||
'''
|
||||
|
||||
a = ast.parse(code)
|
||||
dfs(a)
|
||||
converter = Converter(ex_pjmaps)
|
||||
converter.dfs(a)
|
||||
a.body.insert(0, ast.parse('import jittor as jt').body[0])
|
||||
if 'init' not in import_flag:
|
||||
if 'init' not in converter.import_flag:
|
||||
a.body.insert(1, ast.parse('from jittor import init').body[0])
|
||||
if 'nn' not in import_flag:
|
||||
if 'nn' not in converter.import_flag:
|
||||
a.body.insert(2, ast.parse('from jittor import nn').body[0])
|
||||
return astunparse.unparse(a)
|
||||
|
||||
def convert_(prefix, func_name, ags, kws, ori_src):
|
||||
info = pjmap[func_name]
|
||||
p_prefix = info['pytorch']['prefix'] if 'prefix' in info['pytorch'].keys() else None
|
||||
if p_prefix is not None and prefix in p_prefix:
|
||||
p_ags = info['pytorch']['args_prefix']
|
||||
j_ags = info['jittor']['args_prefix']
|
||||
else:
|
||||
p_ags = info['pytorch']['args']
|
||||
j_ags = info['jittor']['args']
|
||||
if 'delete' in info.keys():
|
||||
delete = info['delete']
|
||||
else:
|
||||
delete = None
|
||||
j_prefix = info['jittor']['prefix'] if 'prefix' in info['jittor'].keys() else None
|
||||
j_module = info['jittor']['module']
|
||||
j_name = info['jittor']['name']
|
||||
links = info['links']
|
||||
extras = info['extras']
|
||||
jj_ags = []
|
||||
jj_kws = {}
|
||||
pp_ags = []
|
||||
pp_kws = {}
|
||||
if j_ags == '' and p_ags == '':
|
||||
# no args in Pytorch and Jittor.
|
||||
if p_prefix is None:
|
||||
return f"{j_module}.{j_name}()"
|
||||
else:
|
||||
if prefix in p_prefix:
|
||||
return f"{j_prefix}.{j_name}()"
|
||||
else:
|
||||
return f"{prefix}.{j_name}()"
|
||||
else:
|
||||
j_ags = j_ags.replace(' ','').split(',')
|
||||
for j_ag in j_ags:
|
||||
if '=' in j_ag:
|
||||
k,v = j_ag.split('=')
|
||||
jj_kws[k] = v
|
||||
else:
|
||||
jj_ags.append(j_ag)
|
||||
p_ags = p_ags.replace(' ','').split(',')
|
||||
for p_ag in p_ags:
|
||||
if '=' in p_ag:
|
||||
k,v = p_ag.split('=')
|
||||
pp_kws[k] = v
|
||||
else:
|
||||
pp_ags.append(p_ag)
|
||||
if len(jj_ags) == 0 and len(pp_ags) != 0:
|
||||
return f"raise AttributeError('''origin source: <{ori_src.strip()}>, {func_name} in Jittor has no Attribute {pp_ags[0]}''')"
|
||||
# raise AttributeError(f"{func_name} in Jittor has no Attribute {pp_ags[0]}")
|
||||
if delete is not None:
|
||||
for d in delete:
|
||||
if d in pp_ags:
|
||||
jj_ags.append(d)
|
||||
if d in pp_kws.keys():
|
||||
jj_kws[d] = None
|
||||
if len(pp_ags) > len(ags) + len(kws):
|
||||
return f"raise RuntimeError('''origin source: <{ori_src.strip()}>, There are needed {len(pp_ags) + len(list(pp_kws.keys()))} args in Pytorch {func_name} function, but you only provide {len(ags) + len(kws)}''')"
|
||||
# raise RuntimeError(f'There are needed {len(pp_ags) + len(list(pp_kws.keys()))} args in Pytorch {func_name} function, but you only provide {len(ags) + len(kws)}')
|
||||
ags_ = []
|
||||
for i in range(len(pp_ags)):
|
||||
if i < len(ags):
|
||||
if '*' in pp_ags[i]:
|
||||
ags_.append('(' + ', '.join(ags[i:]) + ')')
|
||||
ags = ags_
|
||||
break
|
||||
else:
|
||||
ags_.append(ags[i])
|
||||
else:
|
||||
break
|
||||
if len(pp_ags) + len(list(pp_kws.keys())) < len(ags) + len(kws):
|
||||
return f"raise RuntimeError('''origin source: <{ori_src.strip()}>,There are only {len(pp_ags) + len(list(pp_kws.keys()))} args in Pytorch {func_name} function, but you provide {len(ags) + len(kws)}''')"
|
||||
# raise RuntimeError(f'There are only {len(pp_ags) + len(list(pp_kws.keys()))} args in Pytorch {func_name} function, but you provide {len(ags) + len(kws)}')
|
||||
j_ags_flag = np.zeros(len(jj_ags))
|
||||
j_ags_values = {}
|
||||
j_kws_values = {}
|
||||
for i,ag in enumerate(ags):
|
||||
if len(pp_ags) == 0:
|
||||
ag_name = list(pp_kws.keys())[i]
|
||||
elif i < len(pp_ags):
|
||||
ag_name = pp_ags[i]
|
||||
elif i >= len(pp_ags) and (i-len(pp_ags)) <= len(list(pp_kws.keys())):
|
||||
ag_name = list(pp_kws.keys())[i-len(pp_ags)]
|
||||
else:
|
||||
return f"raise RuntimeError('''origin source: <{ori_src.strip()}>,The args number is not matc{func_name} in Jittor has no Attribute {ag_name}''')"
|
||||
# raise RuntimeError(f'The args number is not matc{func_name} in Jittor has no Attribute {ag_name}')
|
||||
if ag_name in links.keys():
|
||||
ag_name = links[ag_name]
|
||||
if ag_name in jj_ags:
|
||||
j_ags_flag[jj_ags.index(ag_name)] = 1
|
||||
j_ags_values[str(jj_ags.index(ag_name))] = ag
|
||||
elif ag_name in jj_kws.keys():
|
||||
j_kws_values[ag_name] = ag
|
||||
else:
|
||||
return f"raise AttributeError('''origin source: <{ori_src.strip()}>, {func_name} in Jittor has no Attribute {ag_name}''')"
|
||||
# raise AttributeError(f'{func_name} in Jittor has no Attribute {ag_name}')
|
||||
for i,kw in enumerate(kws):
|
||||
kw_name, kw_value = kw.split('=')
|
||||
if kw_name in links.keys():
|
||||
kw_name = links[kw_name]
|
||||
if kw_name in jj_ags:
|
||||
j_ags_flag[jj_ags.index(kw_name)] = 1
|
||||
j_ags_values[str(jj_ags.index(kw_name))] = kw_value
|
||||
elif kw_name in jj_kws.keys():
|
||||
j_kws_values[kw_name] = kw_value
|
||||
else:
|
||||
return f"raise AttributeError('''origin source: <{ori_src.strip()}>, {func_name} in Jittor has no Attribute {kw_name}''')"
|
||||
# raise AttributeError(f'{func_name} in Jittor has no Attribute {kw_name}')
|
||||
len_jj_ags = len(jj_ags) if len(jj_ags) == 0 or jj_ags[0] != '' else 0
|
||||
if j_ags_flag.sum() < len_jj_ags:
|
||||
missing_args = []
|
||||
for i in range(len(jj_ags)):
|
||||
if j_ags_flag[i] == 0:
|
||||
missing_args.append(jj_ags[i])
|
||||
return f"raise AttributeError('''origin source: <{ori_src.strip()}>, the needed args of {func_name} in Jittor is {', '.join(jj_ags)}, so you need to give value of {', '.join(missing_args)}.''')"
|
||||
# raise AttributeError(f"the needed args of {func_name} in Jittor is {', '.join(jj_ags)}, so you need to give value of {', '.join(missing_args)}.")
|
||||
if extras:
|
||||
for k in extras.keys():
|
||||
if k in jj_ags:
|
||||
j_ags_values[str(jj_ags.index(k))] = extras[k]
|
||||
elif k in jj_kws.keys():
|
||||
j_kws_values[k] = extras[k]
|
||||
else:
|
||||
return f"raise AttributeError('''origin source: <{ori_src.strip()}>, there is not attribute named {k} in Jittor {func_name}, you should delete it in {func_name} extras.''')"
|
||||
# raise AttributeError(f"there is not attribute named {k} in Jittor {func_name}, you should delete it in {func_name} extras.")
|
||||
if delete is not None:
|
||||
for d in delete:
|
||||
if d in j_ags_values:
|
||||
j_ags_values.remove(d)
|
||||
if d in j_kws_values.keys():
|
||||
j_kws_values.pop(d)
|
||||
j_ags_ = [j_ags_values[str(i)] for i in range(len(list(j_ags_values.keys())))]
|
||||
j_kws_ = [key + "=" + j_kws_values[key] for key in j_kws_values.keys()]
|
||||
j_func = f"{j_module}.{j_name}({', '.join(j_ags_+j_kws_)})"
|
||||
if p_prefix is None:
|
||||
return f"{j_module}.{j_name}({', '.join(j_ags_+j_kws_)})"
|
||||
else:
|
||||
if prefix in p_prefix:
|
||||
return f"{j_prefix}.{j_name}({', '.join(j_ags_+j_kws_)})"
|
||||
else:
|
||||
return f"{prefix}.{j_name}({', '.join(j_ags_+j_kws_)})"
|
||||
return j_func
|
||||
|
||||
def dfs(a):
|
||||
if isinstance(a, ast.Import):
|
||||
if 'torch' in astunparse.unparse(a) and 'init' in astunparse.unparse(a):
|
||||
import_flag.append('init')
|
||||
return ast.parse('from jittor import init').body[0]
|
||||
if 'torch' in astunparse.unparse(a) and a.names[0].asname == 'nn':
|
||||
import_flag.append('nn')
|
||||
return ast.parse('from jittor import nn').body[0]
|
||||
if 'torch' in a.names[0].name:
|
||||
return 'delete'
|
||||
elif isinstance(a, ast.ImportFrom):
|
||||
if 'torch' in a.module:
|
||||
return 'delete'
|
||||
elif isinstance(a, ast.Call):
|
||||
for idx, ag in enumerate(a.args):
|
||||
ret = dfs(ag)
|
||||
if ret is not None:
|
||||
a.args[idx] = ret
|
||||
for idx, kw in enumerate(a.keywords):
|
||||
ret = dfs(kw)
|
||||
if ret is not None:
|
||||
a.keywords[idx] = ret
|
||||
ori_src = astunparse.unparse(a)
|
||||
func = astunparse.unparse(a.func).strip('\n').split('.')
|
||||
prefix = '.'.join(func[0:-1])
|
||||
func_name = func[-1]
|
||||
if func_name in unsupport_ops:
|
||||
ret = raise_unsupport(func_name, ori_src)
|
||||
return ret
|
||||
if func_name in pjmap.keys():
|
||||
ags = [astunparse.unparse(ag).strip('\n') for ag in a.args]
|
||||
kws = [astunparse.unparse(kw).strip('\n') for kw in a.keywords]
|
||||
ret = convert_(prefix, func_name, ags, kws, ori_src)
|
||||
ret_tmp = ret
|
||||
ret = ast.parse(ret).body[0]
|
||||
if hasattr(ret,'value'):
|
||||
return ret.value
|
||||
else:
|
||||
print(ret_tmp+'\n')
|
||||
return ret
|
||||
if ".load_state_dict" in astunparse.unparse(a.func):
|
||||
a.func.attr = 'load_parameters'
|
||||
if astunparse.unparse(a.func).strip('\n').endswith(".size"):
|
||||
ags = [astunparse.unparse(ag).strip('\n') for ag in a.args]
|
||||
if len(ags) != 0:
|
||||
con = astunparse.unparse(a.func).split('.size')[0] + '.shape[' + ','.join(ags) + ']'
|
||||
else:
|
||||
con = astunparse.unparse(a.func).replace('size', 'shape')
|
||||
return ast.parse(con).body[0].value
|
||||
elif isinstance(a, ast.Expr): pass
|
||||
elif isinstance(a, ast.Attribute) or isinstance(a, ast.Name):
|
||||
ret = replace(a)
|
||||
if ret is not None:
|
||||
print(ret)
|
||||
return ret
|
||||
elif isinstance(a, ast.FunctionDef):
|
||||
if a.name == 'forward': a.name = 'execute'
|
||||
if hasattr(a, '__dict__'):
|
||||
for k in a.__dict__.keys():
|
||||
if isinstance(a.__dict__[k], list):
|
||||
delete_flag = []
|
||||
for i,a_ in enumerate(a.__dict__[k]):
|
||||
ret = dfs(a_)
|
||||
if ret is 'delete':
|
||||
delete_flag.append(True)
|
||||
continue
|
||||
if ret is not None:
|
||||
a.__dict__[k][i] = ret
|
||||
delete_flag.append(False)
|
||||
tmp = [a_ for i,a_ in enumerate(a.__dict__[k]) if delete_flag[i] == False]
|
||||
a.__dict__[k] = tmp
|
||||
else:
|
||||
ret = dfs(a.__dict__[k])
|
||||
if ret is not None:
|
||||
a.__dict__[k] = ret
|
|
@ -0,0 +1,9 @@
|
|||
cat > /tmp/converter_server.dockerfile <<\EOF
|
||||
FROM jittor/jittor
|
||||
|
||||
RUN python3.7 -m pip install flask
|
||||
RUN apt update && apt install git -y
|
||||
EOF
|
||||
|
||||
sudo docker build --tag jittor/converter_server -f /tmp/converter_server.dockerfile .
|
||||
sudo docker run --rm jittor/converter_server bash -c "python3.7 -m pip install -U git+https://github.com/Jittor/jittor.git && python3.7 -m jittor.utils.converter_server"
|
Loading…
Reference in New Issue