add unsupport_ops for converter

This commit is contained in:
zwy 2020-04-19 22:41:11 +08:00
parent 6e08781df4
commit 52317b040d
1 changed files with 42 additions and 13 deletions

View File

@ -25,6 +25,18 @@ pjmap = {
'links': {},
'extras': {},
},
'ConvTranspose2d': {
'pytorch': {
'args': "in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=True, dilation=1, padding_mode='zeros'"
},
'jittor': {
'module': 'nn',
'name': 'ConvTranspose',
'args': 'in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=True, dilation=1'
},
'links': {},
'extras': {},
},
'MaxPool2d': {
'pytorch': {
'args': 'kernel_size, stride=None, padding=0, dilation=1, return_indices=False',
@ -138,7 +150,7 @@ pjmap = {
'extras': {},
},
# ***************************************************************
# torch.Tensor.xxx(...) and torch.xxx(torch.Tensor, ...)
# Convert format for function which can be writen as either torch.Tensor.xxx(...) or torch.xxx(torch.Tensor, ...)
# Example: x.reshape([2,3]) and torch.reshape(x, [2,3])
# ***************************************************************
'flatten': {
@ -208,22 +220,37 @@ pjmap = {
}
}
unsupport_ops = [
# ***************************************************************
# torch.nn
# ***************************************************************
'Parameter', 'ModuleList', 'ModuleDict', 'ParameterList', 'ParameterDict',
'Conv1d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose3d', 'Unfold', 'Fold',
'MaxPool1d', 'MaxPool3d', 'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d', 'AvgPool1d', 'AvgPool3d', 'FractionalMaxPool2d', 'LPPool1d', 'LPPool2d', 'AdaptiveMaxPool1d', 'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d', 'AdaptiveAvgPool1d', 'AdaptiveAvgPool3d',
'ReflectionPad1d', 'ReflectionPad2d', 'ReplicationPad1d', 'ReplicationPad2d', 'ReplicationPad3d', 'ZeroPad2d', 'ConstantPad1d', 'ConstantPad2d', 'ConstantPad3d', 'ELU', 'Hardshrink', 'Hardtanh', 'LeakyReLU', 'LogSigmoid', 'MultiheadAttention',
'PReLU', 'RReLU', 'SELU', 'CELU', 'GELU', 'Softplus', 'Softshrink', 'Softsign', 'Tanhshrink', 'Threshold', 'Softmin', 'Softmax2d', 'LogSoftmax', 'AdaptiveLogSoftmaxWithLoss', 'BatchNorm1d', 'BatchNorm3d', 'GroupNorm', 'SyncBatchNorm', 'InstanceNorm1d', 'InstanceNorm2d', 'InstanceNorm3d', 'LayerNorm', 'LocalResponseNorm', 'RNNBase', 'RNN', 'LSTM', 'GRU', 'RNNCell', 'LSTMCell', 'GRUCell', 'Transformer', 'TransformerEncoder', 'TransformerDecoder', 'TransformerEncoderLayer', 'TransformerDecoderLayer', 'Identity', 'Bilinear', 'Dropout2d', 'Dropout3d', 'AlphaDropout', 'Embedding', 'EmbeddingBag', 'CosineSimilarity', 'PairwiseDistance', 'L1Loss', 'MSELoss', 'CTCLoss', 'NLLLoss', 'PoissonNLLLoss', 'KLDivLoss', 'BCELoss', 'BCEWithLogitsLoss', 'MarginRankingLoss', 'HingeEmbeddingLoss', 'MultiLabelMarginLoss', 'SmoothL1Loss', 'SoftMarginLoss', 'MultiLabelSoftMarginLoss', 'CosineEmbeddingLoss', 'MultiMarginLoss', 'TripletMarginLoss', 'PixelShuffle', 'Upsample', 'UpsamplingNearest2d', 'UpsamplingBilinear2d', 'DataParallel', 'DistributedDataParallel', 'clip_grad_norm_', 'clip_grad_value_', 'parameters_to_vector', 'vector_to_parameters', 'BasePruningMethod', 'PruningContainer', 'Identity', 'RandomUnstructured', 'L1Unstructured', 'RandomStructured', 'LnStructured', 'CustomFromMask', 'identity', 'random_unstructured', 'l1_unstructured', 'random_structured', 'ln_structured', 'global_unstructured', 'custom_from_mask', 'remove', 'is_pruned', 'weight_norm', 'remove_weight_norm', 'spectral_norm', 'remove_spectral_norm', 'PackedSequence', 'pack_padded_sequence', 'pad_packed_sequence', 'pad_sequence', 'pack_sequence'
]
support_ops = {}
for key in pjmap.keys():
module = pjmap[key]['jittor']['module']
name = pjmap[key]['jittor']['name']
if module == 'nn':
support_ops[key] = name
def replace(a):
if hasattr(a, "attr") and a.attr in unsupport_ops:
raise RuntimeError(f'{a.attr} is not supported in Jittor yet. We will appreciate it if you code {a.attr} function and make pull request at https://github.com/Jittor/jittor.')
if hasattr(a, "id") and a.id in unsupport_ops:
raise RuntimeError(f'{a.id} is not supported in Jittor yet. We will appreciate it if you code {a.id} function and make pull request at https://github.com/Jittor/jittor.')
if hasattr(a, "attr"):
if a.attr == "Conv2d": a.attr = "Conv"
if a.attr == "BatchNorm2d": a.attr = "BatchNorm"
if a.attr == "ReLU": a.attr = "Relu"
if a.attr == "AvgPool2d": a.attr = "Pool"
if a.attr == "MaxPool2d": a.attr = "Pool"
if a.attr == "LeakyReLU": a.attr = "Leaky_relu"
if a.attr in support_ops.keys(): a.attr = support_ops[a.attr]
if hasattr(a, "id"):
if a.id == "Conv2d": a.id = "Conv"
if a.id == "BatchNorm2d": a.id = "BatchNorm"
if a.id == "ReLU": a.id = "Relu"
if a.id == "AvgPool2d": a.id = "Pool"
if a.id == "MaxPool2d": a.id = "Pool"
if a.id == "LeakyReLU": a.id = "Leaky_relu"
if a.id in support_ops.keys(): a.id = support_ops[a.id]
import_flag = []
def convert(code):
@ -379,6 +406,8 @@ def dfs(a):
func = astunparse.unparse(a.func).strip('\n').split('.')
prefix = '.'.join(func[0:-1])
func_name = func[-1]
if func_name in unsupport_ops:
raise RuntimeError(f'{func_name} is not supported in Jittor yet. We will appreciate it if you code {func_name} function and make pull request at https://github.com/Jittor/jittor.')
if func_name in pjmap.keys():
ags = [astunparse.unparse(ag).strip('\n') for ag in a.args]
kws = [astunparse.unparse(kw).strip('\n') for kw in a.keywords]