mirror of https://github.com/Jittor/Jittor
add doc to converter
This commit is contained in:
parent
353f582885
commit
fca8224c0d
|
@ -298,6 +298,26 @@ pjmap = {
|
|||
|
||||
|
||||
def pjmap_append(pytorch_func_name, pytorch_args, jittor_func_module, jittor_func_name, jittor_args, extras=None, links=None, delete=None):
|
||||
''' adding map to pjmap for converting new function, example: convert AvgPool2d to Pool
|
||||
args:
|
||||
* `pytorch_func_name`: Pytorch function name
|
||||
* `pytorch_args`: Pytorch parameter list
|
||||
* `jittor_func_module`: to which module the Jittor function belongs
|
||||
* `jittor_func_name`: Jittor function name
|
||||
* `jittor_args`: Jittor parameter list
|
||||
* `extras`: parameter assignment
|
||||
* `links`: connection parameters
|
||||
* `delete`: delete parameters
|
||||
|
||||
example:
|
||||
from jittor.utils.pytorch_converter import pjmap_append
|
||||
pjmap_append(pytorch_func_name='AvgPool2d',
|
||||
pytorch_args='kernel_size, stride=None, padding=0, dilation=1, return_indices=False',
|
||||
jittor_func_module='nn',
|
||||
jittor_func_name='Pool',
|
||||
jittor_args='kernel_size, stride=None, padding=0, dilation=None, return_indices=None, ceil_mode=False, op="maximum"',
|
||||
extras={"op": "'mean'"})
|
||||
'''
|
||||
if links == None: links = {}
|
||||
if extras == None: extras = {}
|
||||
if delete == None: delete = []
|
||||
|
@ -371,6 +391,30 @@ def replace(a):
|
|||
|
||||
import_flag = []
|
||||
def convert(code):
|
||||
''' Model code converter, example:
|
||||
|
||||
from jittor.utils.pytorch_converter import convert
|
||||
pytorch_code = """
|
||||
class Model(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv1 = nn.Conv2d(1, 10, 3)
|
||||
self.conv2 = nn.Conv2d(10, 32, 3)
|
||||
self.fc = nn.Linear(1200, 100)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.conv2(x)
|
||||
x = x.view(x.size(0), -1)
|
||||
x = self.fc(x)
|
||||
return x
|
||||
"""
|
||||
jittor_code = convert(pytorch_code)
|
||||
print("## Generate Jittor code:", jittor_code)
|
||||
exec(jittor_code)
|
||||
model = Model()
|
||||
print("## Jittor model:", model)
|
||||
'''
|
||||
a = ast.parse(code)
|
||||
dfs(a)
|
||||
a.body.insert(0, ast.parse('import jittor as jt').body[0])
|
||||
|
|
Loading…
Reference in New Issue