add doc to converter

This commit is contained in:
zwy 2020-05-01 00:01:28 +08:00
parent 353f582885
commit fca8224c0d
1 changed files with 44 additions and 0 deletions

View File

@ -298,6 +298,26 @@ pjmap = {
def pjmap_append(pytorch_func_name, pytorch_args, jittor_func_module, jittor_func_name, jittor_args, extras=None, links=None, delete=None):
''' adding map to pjmap for converting new function, example: convert AvgPool2d to Pool
args:
* `pytorch_func_name`: Pytorch function name
* `pytorch_args`: Pytorch parameter list
* `jittor_func_module`: to which module the Jittor function belongs
* `jittor_func_name`: Jittor function name
* `jittor_args`: Jittor parameter list
* `extras`: parameter assignment
* `links`: connection parameters
* `delete`: delete parameters
example:
from jittor.utils.pytorch_converter import pjmap_append
pjmap_append(pytorch_func_name='AvgPool2d',
pytorch_args='kernel_size, stride=None, padding=0, dilation=1, return_indices=False',
jittor_func_module='nn',
jittor_func_name='Pool',
jittor_args='kernel_size, stride=None, padding=0, dilation=None, return_indices=None, ceil_mode=False, op="maximum"',
extras={"op": "'mean'"})
'''
if links == None: links = {}
if extras == None: extras = {}
if delete == None: delete = []
@ -371,6 +391,30 @@ def replace(a):
import_flag = []
def convert(code):
''' Model code converter, example:
from jittor.utils.pytorch_converter import convert
pytorch_code = """
class Model(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 10, 3)
self.conv2 = nn.Conv2d(10, 32, 3)
self.fc = nn.Linear(1200, 100)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
"""
jittor_code = convert(pytorch_code)
print("## Generate Jittor code:", jittor_code)
exec(jittor_code)
model = Model()
print("## Jittor model:", model)
'''
a = ast.parse(code)
dfs(a)
a.body.insert(0, ast.parse('import jittor as jt').body[0])