Gword 2020-12-17 20:32:29 +08:00
commit 25510ddf32
47 changed files with 1226 additions and 366 deletions

View File

@ -15,7 +15,6 @@
#include <cuda_runtime.h>
#include "helper_cuda.h"
#ifdef _CUFFT_H_
// cuFFT API errors
const char *_cudaGetErrorEnum(cufftResult error) {

View File

@ -7,7 +7,7 @@
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.2.2.1'
__version__ = '1.2.2.8'
from . import lock
with lock.lock_scope():
ori_int = int
@ -33,9 +33,38 @@ from collections import OrderedDict
from collections.abc import Sequence, Mapping
import types
import pickle
import sys
import hashlib
import sys, os
import traceback
def safepickle(obj, path):
s = pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
checksum = hashlib.sha1(s).digest()
s += bytes(checksum)
s += b"HCAJSLHD"
with open(path, 'wb') as f:
f.write(s)
def safeunpickle(path):
if path.startswith("jittorhub://"):
path = path.replace("jittorhub://", "https://cg.cs.tsinghua.edu.cn/jittor/assets/build/checkpoints/")
if path.startswith("https:") or path.startswith("http:"):
base = path.split("/")[-1]
fname = os.path.join(compiler.ck_path, base)
from jittor.utils.misc import download_url_to_local
download_url_to_local(path, base, compiler.ck_path, None)
path = fname
with open(path, "rb") as f:
s = f.read()
if not s.endswith(b"HCAJSLHD"):
return pickle.loads(s)
checksum = s[-28:-8]
s = s[:-28]
if hashlib.sha1(s).digest() != checksum:
raise ValueError("Pickle checksum does not match! path: "+path)
return pickle.loads(s)
class _call_no_record_scope:
def __enter__(self): pass
def __exit__(self, *exc): pass
@ -436,8 +465,15 @@ def display_memory_info():
core.display_memory_info(fileline)
def load(path):
pkl_file = open(path, 'rb')
model_dict = pickle.load(pkl_file)
if path.endswith(".pth"):
try:
dirty_fix_pytorch_runtime_error()
import torch
except:
raise RuntimeError("pytorch need to be installed when load pth format.")
model_dict = torch.load(path, map_location=torch.device('cpu'))
else:
model_dict = safeunpickle(path)
return model_dict
def _uniq(x):
@ -647,20 +683,10 @@ class Module:
params_dict = {}
for p in params:
params_dict[p.name()] = p.data
with open(path, 'wb') as f:
pickle.dump(params_dict, f, pickle.HIGHEST_PROTOCOL)
safepickle(params_dict, path)
def load(self, path):
if path.endswith(".pth"):
try:
dirty_fix_pytorch_runtime_error()
import torch
except:
raise RuntimeError("pytorch need to be installed when load pth format.")
self.load_parameters(torch.load(path, map_location=torch.device('cpu')))
return
with open(path, 'rb') as f:
self.load_parameters(pickle.load(f))
self.load_parameters(load(path))
def eval(self):
def callback(parents, k, v, n):
@ -805,6 +831,11 @@ can also be None)::
def dfs(self, parents, k, callback, callback_leave=None):
pass
@classmethod
def apply(cls, *args, **kw):
func = cls()
return func(*args, **kw)
def make_module(func, exec_n_args=1):
class MakeModule(Module):
@ -904,7 +935,10 @@ def format(v, spec):
return v.item().__format__(spec)
Var.__format__ = format
def get_len(var):
return var.shape[0]
Var.__len__ = get_len
int = int32
Var.int = Var.int32
float = float32
@ -921,3 +955,4 @@ from . import contrib
from . import numpy2cupy
from .contrib import concat
from .misc import *
from . import sparse
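A minimal usage sketch of the checkpoint helpers added above: safepickle appends a SHA1 checksum (plus the "HCAJSLHD" tail) to the pickle stream, safeunpickle verifies it and also resolves jittorhub:// and http(s) paths into the checkpoint cache, and the new top-level load dispatches .pth files to torch. File paths below are illustrative.
import jittor as jt
params = {"w": jt.random((3, 3)).data}         # any picklable object
jt.safepickle(params, "/tmp/params.pkl")       # pickle + SHA1 checksum + magic tail
restored = jt.safeunpickle("/tmp/params.pkl")  # raises ValueError if the checksum is wrong
# load() now dispatches on the extension: .pth goes through torch,
# everything else (including jittorhub:// URLs) goes through safeunpickle.
state = jt.load("/tmp/params.pkl")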

View File

@ -74,10 +74,17 @@ def compile(compiler, flags, inputs, output, combind_build=False):
for input, obj_file in zip(inputs, obj_files):
cc = compiler
nflags = oflags
if has_cuda and input.endswith(".cu"):
nflags = convert_nvcc_flags(oflags)
cc = nvcc_path
if input.endswith(".cu"):
if has_cuda:
nflags = convert_nvcc_flags(oflags)
cc = nvcc_path
else:
continue
cmd = f"{cc} {input} {nflags} -c {lto_flags} -o {obj_file}"
if "nan_checker" in input:
# nan checker needs to disable fast_math
cmd = cmd.replace("--use_fast_math", "")
cmd = cmd.replace("-Ofast", "-O2")
cmds.append(cmd)
jit_utils.run_cmds(cmds, cache_path, jittor_path, "Compiling "+base_output)
cmd = f"{compiler} {' '.join(obj_files)} {flags} {lto_flags} {link} -o {output}"
@ -894,6 +901,8 @@ make_cache_dir(cache_path)
make_cache_dir(os.path.join(cache_path, "jit"))
make_cache_dir(os.path.join(cache_path, "obj_files"))
make_cache_dir(os.path.join(cache_path, "gen"))
ck_path = os.path.join(cache_path, "checkpoints")
make_cache_dir(ck_path)
# build cache_compile
cc_flags += f" -I{jittor_path}/src "
@ -943,7 +952,8 @@ pyjt_gen_src = pyjt_compiler.compile(cache_path, jittor_path)
# 3. op_utils
# 4. other
files2 = pyjt_gen_src
files4 = run_cmd('find -L src | grep "cc$"', jittor_path).splitlines()
grep_args = '"c[cu]$"' if has_cuda else '"cc$"'
files4 = run_cmd('find -L src | grep '+grep_args, jittor_path).splitlines()
at_beginning = [
"src/ops/op_utils.cc",
"src/event_queue.cc",

View File

@ -12,6 +12,7 @@ import numpy as np
from jittor import pool
from collections.abc import Sequence
def argmax_pool(x, size, stride, padding=0):
return pool.pool(x, size, 'maximum', padding, stride)
@ -196,8 +197,14 @@ def setitem(x, slices, value):
mask = jt.broadcast(slices, x)
value = jt.broadcast(value, x)
return x.assign(mask.ternary(value, x))
if isinstance(slices, list):
slices = tuple(slices)
if isinstance(slices, Sequence):
ss = []
for s in slices:
if isinstance(s, jt.Var) and s.dtype == "bool":
ss.extend(s.where())
else:
ss.append(s)
slices = tuple(ss)
return x.assign(x.setitem(slices, value))
jt.Var.__getitem__ = jt.Var.slice_var = getitem
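A short sketch of what the setitem change above enables: a boolean jt.Var inside a list or tuple of slices is expanded through its where() indices before assignment (shapes and values here are illustrative).
import jittor as jt
x = jt.random((4, 2))
rows = jt.array([True, False, True, False])  # bool Var inside the slice tuple
x[rows, 1] = jt.zeros((2,))                  # rows is converted via rows.where() by setitem above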

View File

@ -27,8 +27,9 @@ mpi = jt.mpi
img_open_hook = HookTimer(Image, "open")
class Worker:
def __init__(self, target, args, buffer_size):
def __init__(self, target, args, buffer_size, keep_numpy_array=False):
self.buffer = jt.RingBuffer(buffer_size)
self.buffer.keep_numpy_array(keep_numpy_array)
self.status = mp.Array('f', 5, lock=False)
self.p = mp.Process(target=target, args=args+(self.buffer,self.status))
@ -67,7 +68,8 @@ class Dataset(object):
drop_last = False,
num_workers = 0,
buffer_size = 512*1024*1024,
stop_grad = True):
stop_grad = True,
keep_numpy_array = False):
super().__init__()
self.total_len = None
self.batch_size = batch_size
@ -76,6 +78,7 @@ class Dataset(object):
self.num_workers = num_workers
self.buffer_size = buffer_size
self.stop_grad = stop_grad
self.keep_numpy_array = keep_numpy_array
def __getitem__(self, index):
raise NotImplementedError
@ -117,6 +120,8 @@ class Dataset(object):
'''
Convert batch data such as np.ndarray, int, and float into jittor arrays.
'''
if self.keep_numpy_array: return batch
if isinstance(batch, jt.Var): return batch
to_jt = lambda x: jt.array(x).stop_grad() \
if self.stop_grad else jt.array(x)
if isinstance(batch, np.ndarray):
@ -299,7 +304,8 @@ Example::
self.num_idle_c = mp.Condition(self.gid.get_lock())
for i in range(self.num_workers):
w = Worker(target=self._worker_main, args=(i,),
buffer_size=self.buffer_size)
buffer_size=self.buffer_size,
keep_numpy_array=self.keep_numpy_array)
workers.append(w)
self.workers = workers
self.index_list_numpy = np.ndarray(dtype='int32', shape=self.real_len, buffer=self.index_list)
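A minimal sketch of the new keep_numpy_array option: when set, batches skip the jt.array conversion shown above and come back as plain numpy data. The set_attrs/total_len pattern is the usual jittor Dataset idiom, assumed here rather than shown in this diff.
import numpy as np
from jittor.dataset import Dataset

class ToyData(Dataset):
    def __init__(self):
        super().__init__(keep_numpy_array=True)
        self.set_attrs(total_len=8, batch_size=4)
    def __getitem__(self, idx):
        return np.full((2, 2), idx, dtype="float32")

for batch in ToyData():
    print(type(batch))  # numpy.ndarray instead of jt.Var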

View File

@ -96,6 +96,27 @@ def repeat(x, *shape):
jt.Var.repeat = repeat
def repeat_interleave(x,repeats,dim=None):
# TODO: support repeats given as a jt.Var
assert isinstance(repeats,int)
if dim is None:
x = x.reshape(-1)
dim=0
if dim<0: dim+=x.ndim
tar_shape = list(x.shape)
x_shape = list(x.shape)
tar_shape[dim] = tar_shape[dim]*repeats
dims = []
for i in range(len(tar_shape)):
if dim==i:
dims.append(f"i{i}/{repeats}")
else:
dims.append(f"i{i}")
return x.reindex(tar_shape,dims)
jt.Var.repeat_interleave = repeat_interleave
def chunk(x, chunks, dim=0):
r'''
Splits a var into a specific number of chunks. Each chunk is a view of the input var.
@ -209,15 +230,18 @@ def flip(x, dim=0):
>>> x.flip(1)
[[4 3 2 1]]
'''
assert isinstance(dim, int)
if dim<0:
dim+=x.ndim
assert dim>=0 and dim<len(x.shape)
if isinstance(dim, int):
dim = [dim]
for i in range(len(dim)):
if dim[i]<0:
dim[i] += x.ndim
assert dim[i]>=0 and dim[i]<x.ndim
dim = set(dim)
tar_dims = []
for i in range(len(x.shape)):
if i == dim:
tar_dims.append(f"{x.shape[dim]-1}-i{i}")
if i in dim:
tar_dims.append(f"xshape{i}-1-i{i}")
else:
tar_dims.append(f"i{i}")
return x.reindex(x.shape, tar_dims)
@ -335,16 +359,37 @@ def unbind(x, dim=0):
jt.Var.unbind = unbind
def make_grid(x, nrow=8, padding=2, normalize=False, range=None, scale_each=False, pad_value=0):
assert range == None
assert isinstance(range, tuple) or range is None
assert scale_each == False
if isinstance(x, list): x = jt.stack(x)
if normalize: x = (x - x.min()) / (x.max() - x.min())
if normalize:
if range is None: x = (x - x.min()) / (x.max() - x.min())
else: x = (x - range[0]) / (range[1] - range[0])
b,c,h,w = x.shape
ncol = math.ceil(b / nrow)
return x.reindex([c, h*ncol+(ncol+1)*padding, w*nrow+(nrow+1)*padding],
[f"i1/{padding+h}*{nrow}+i2/{padding+w}", "i0",
f"i1-i1/{padding+h}*{padding+h}-{padding}", f"i2-i2/{padding+w}*{padding+w}-{padding}"], overflow_value=pad_value)
def save_image(
x,
filepath,
nrow: int = 8,
padding: int = 2,
normalize: bool = False,
range = None,
scale_each = False,
pad_value = 0,
format = None
):
from PIL import Image
grid = make_grid(x, nrow=nrow, padding=padding, pad_value=pad_value,
normalize=normalize, range=range, scale_each=scale_each)
ndarr = (grid*255+0.5).clamp(0, 255).permute(1, 2, 0).uint8().numpy()
im = Image.fromarray(ndarr)
im.save(filepath, format=format)
def _ntuple(n):
def parse(x):
@ -582,12 +627,11 @@ def gather(x,dim,index):
return x.reindex(ins)
jt.Var.gather = gather
def prod(x,dim=0):
def _prod(x,dim=0):
x = jt.log(x)
x = x.sum(dim=dim)
return jt.exp(x)
jt.Var.prod = prod
def cumsum_forward(np, data):
a = data['inputs'][0]
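Brief usage of the new pieces above (repeat_interleave, multi-axis flip, make_grid with an explicit range, and save_image); shapes and the output path are illustrative, and save_image requires Pillow.
import jittor as jt
x = jt.array([[1, 2], [3, 4]])
print(x.repeat_interleave(2, dim=1).shape)  # [2, 4]: every column repeated twice
print(x.flip([0, 1]).data)                  # flip now accepts a list of dims
imgs = jt.random((16, 3, 10, 10))
grid = jt.make_grid(imgs, nrow=4, normalize=True, range=(0.0, 1.0))
jt.save_image(imgs, "/tmp/grid.png", nrow=4)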

View File

@ -61,6 +61,7 @@ class AlexNet(nn.Module):
x = self.classifier(x)
return x
def alexnet(**kwargs):
def alexnet(pretrained=False, **kwargs):
model = AlexNet(**kwargs)
if pretrained: model.load("jittorhub://alexnet.pkl")
return model

View File

@ -21,7 +21,7 @@ def densenet121(pretrained=False, **kwargs):
pretrained (bool): If True, returns a model pre-trained on ImageNet
'''
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), **kwargs)
assert not pretrained, "pretrained doesn't support now"
if pretrained: model.load("jittorhub://densenet121.pkl")
return model
def densenet161(pretrained=False, **kwargs):
@ -32,7 +32,7 @@ def densenet161(pretrained=False, **kwargs):
pretrained (bool): If True, returns a model pre-trained on ImageNet
'''
model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24), **kwargs)
assert not pretrained, "pretrained doesn't support now"
if pretrained: model.load("jittorhub://densenet161.pkl")
return model
def densenet169(pretrained=False, **kwargs):
@ -43,7 +43,7 @@ def densenet169(pretrained=False, **kwargs):
pretrained (bool): If True, returns a model pre-trained on ImageNet
'''
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32), **kwargs)
assert not pretrained, "pretrained doesn't support now"
if pretrained: model.load("jittorhub://densenet169.pkl")
return model
def densenet201(pretrained=False, **kwargs):
@ -54,7 +54,7 @@ def densenet201(pretrained=False, **kwargs):
pretrained (bool): If True, returns a model pre-trained on ImageNet
'''
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32), **kwargs)
assert not pretrained, "pretrained doesn't support now"
if pretrained: model.load("jittorhub://densenet201.pkl")
return model

View File

@ -12,8 +12,10 @@ from jittor import nn
__all__ = ['GoogLeNet', 'googlenet']
def googlenet(**kwargs):
return GoogLeNet(**kwargs)
def googlenet(pretrained=False, **kwargs):
model = GoogLeNet(**kwargs)
if pretrained: model.load("jittorhub://googlenet.pkl")
return model
class GoogLeNet(nn.Module):
""" GoogLeNet model architecture.

View File

@ -4,7 +4,9 @@ from jittor import nn
__all__ = ['Inception3', 'inception_v3']
def inception_v3(pretrained=False, progress=True, **kwargs):
return Inception3(**kwargs)
model = Inception3(**kwargs)
if pretrained: model.load("jittorhub://inception_v3.pkl")
return model
class Inception3(nn.Module):
""" Inceptionv3 model architecture.

View File

@ -90,18 +90,22 @@ class MNASNet(nn.Module):
x = x.mean([2, 3])
return self.classifier(x)
def mnasnet0_5(**kwargs):
def mnasnet0_5(pretrained=False, **kwargs):
model = MNASNet(0.5, **kwargs)
if pretrained: model.load("jittorhub://mnasnet0_5.pkl")
return model
def mnasnet0_75(**kwargs):
def mnasnet0_75(pretrained=False, **kwargs):
model = MNASNet(0.75, **kwargs)
if pretrained: model.load("jittorhub://mnasnet0_75.pkl")
return model
def mnasnet1_0(**kwargs):
def mnasnet1_0(pretrained=False, **kwargs):
model = MNASNet(1.0, **kwargs)
if pretrained: model.load("jittorhub://mnasnet1_0.pkl")
return model
def mnasnet1_3(**kwargs):
def mnasnet1_3(pretrained=False, **kwargs):
model = MNASNet(1.3, **kwargs)
if pretrained: model.load("jittorhub://mnasnet1_3.pkl")
return model

View File

@ -93,7 +93,8 @@ class MobileNetV2(nn.Module):
def execute(self, x):
return self._forward_impl(x)
def mobilenet_v2():
def mobilenet_v2(pretrained=False):
model = MobileNetV2()
if pretrained: model.load("jittorhub://mobilenet_v2.pkl")
return model

View File

@ -154,19 +154,26 @@ def _resnet(block, layers, **kwargs):
model = ResNet(block, layers, **kwargs)
return model
def Resnet18(**kwargs):
return _resnet(BasicBlock, [2, 2, 2, 2], **kwargs)
def Resnet18(pretrained=False, **kwargs):
model = _resnet(BasicBlock, [2, 2, 2, 2], **kwargs)
if pretrained: model.load("jittorhub://resnet18.pkl")
return model
resnet18 = Resnet18
def Resnet34(**kwargs):
return _resnet( BasicBlock, [3, 4, 6, 3], **kwargs)
def Resnet34(pretrained=False, **kwargs):
model = _resnet(BasicBlock, [3, 4, 6, 3], **kwargs)
if pretrained: model.load("jittorhub://resnet34.pkl")
return model
resnet34 = Resnet34
def Resnet50(**kwargs):
return _resnet(Bottleneck, [3, 4, 6, 3], **kwargs)
def Resnet50(pretrained=False, **kwargs):
model = _resnet(Bottleneck, [3, 4, 6, 3], **kwargs)
if pretrained: model.load("jittorhub://resnet50.pkl")
return model
resnet50 = Resnet50
def Resnet101(**kwargs):
def Resnet101(pretrained=False, **kwargs):
"""
ResNet-101 model architecture.
@ -180,28 +187,38 @@ def Resnet101(**kwargs):
return _resnet(Bottleneck, [3, 4, 23, 3], **kwargs)
resnet101 = Resnet101
def Resnet152(**kwargs):
return _resnet(Bottleneck, [3, 8, 36, 3], **kwargs)
def Resnet152(pretrained=False, **kwargs):
model = _resnet(Bottleneck, [3, 8, 36, 3], **kwargs)
if pretrained: model.load("jittorhub://resnet152.pkl")
return model
resnet152 = Resnet152
def Resnext50_32x4d(**kwargs):
def Resnext50_32x4d(pretrained=False, **kwargs):
kwargs['groups'] = 32
kwargs['width_per_group'] = 4
return _resnet(Bottleneck, [3, 4, 6, 3], **kwargs)
model = _resnet(Bottleneck, [3, 4, 6, 3], **kwargs)
if pretrained: model.load("jittorhub://resnext50_32x4d.pkl")
return model
resnext50_32x4d = Resnext50_32x4d
def Resnext101_32x8d(**kwargs):
def Resnext101_32x8d(pretrained=False, **kwargs):
kwargs['groups'] = 32
kwargs['width_per_group'] = 8
return _resnet(Bottleneck, [3, 4, 23, 3], **kwargs)
model = _resnet(Bottleneck, [3, 4, 23, 3], **kwargs)
if pretrained: model.load("jittorhub://resnext101_32x8d.pkl")
return model
resnext101_32x8d = Resnext101_32x8d
def Wide_resnet50_2(**kwargs):
def Wide_resnet50_2(pretrained=False, **kwargs):
kwargs['width_per_group'] = (64 * 2)
return _resnet(Bottleneck, [3, 4, 6, 3], **kwargs)
model = _resnet(Bottleneck, [3, 4, 6, 3], **kwargs)
if pretrained: model.load("jittorhub://wide_resnet50_2.pkl")
return model
wide_resnet50_2 = Wide_resnet50_2
def Wide_resnet101_2(**kwargs):
def Wide_resnet101_2(pretrained=False, **kwargs):
kwargs['width_per_group'] = (64 * 2)
return _resnet(Bottleneck, [3, 4, 23, 3], **kwargs)
model = _resnet(Bottleneck, [3, 4, 23, 3], **kwargs)
if pretrained: model.load("jittorhub://wide_resnet101_2.pkl")
return model
wide_resnet101_2 = Wide_resnet101_2
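All of these constructors now take a pretrained flag that fetches the matching jittorhub:// checkpoint through safeunpickle. A minimal sketch; the pattern is the same for alexnet, densenet, vgg and the other model files touched in this commit:
from jittor.models import resnet18
net = resnet18(pretrained=True)  # downloads jittorhub://resnet18.pkl into the checkpoint cache
net.eval()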

View File

@ -93,14 +93,22 @@ def _shufflenetv2(arch, *args):
model = ShuffleNetV2(*args)
return model
def shufflenet_v2_x0_5():
return _shufflenetv2('shufflenetv2_x0.5', [4, 8, 4], [24, 48, 96, 192, 1024])
def shufflenet_v2_x0_5(pretrained=False):
model = _shufflenetv2('shufflenetv2_x0.5', [4, 8, 4], [24, 48, 96, 192, 1024])
if pretrained: model.load("jittorhub://shufflenet_v2_x0_5.pkl")
return model
def shufflenet_v2_x1_0():
return _shufflenetv2('shufflenetv2_x1.0', [4, 8, 4], [24, 116, 232, 464, 1024])
def shufflenet_v2_x1_0(pretrained=False):
model = _shufflenetv2('shufflenetv2_x1.0', [4, 8, 4], [24, 116, 232, 464, 1024])
if pretrained: model.load("jittorhub://shufflenet_v2_x1_0.pkl")
return model
def shufflenet_v2_x1_5():
return _shufflenetv2('shufflenetv2_x1.5', [4, 8, 4], [24, 176, 352, 704, 1024])
def shufflenet_v2_x1_5(pretrained=False):
model = _shufflenetv2('shufflenetv2_x1.5', [4, 8, 4], [24, 176, 352, 704, 1024])
if pretrained: model.load("jittorhub://shufflenet_v2_x1_5.pkl")
return model
def shufflenet_v2_x2_0():
return _shufflenetv2('shufflenetv2_x2.0', [4, 8, 4], [24, 244, 488, 976, 2048])
def shufflenet_v2_x2_0(pretrained=False):
model = _shufflenetv2('shufflenetv2_x2.0', [4, 8, 4], [24, 244, 488, 976, 2048])
if pretrained: model.load("jittorhub://shufflenet_v2_x2_0.pkl")
return model

View File

@ -83,8 +83,12 @@ def _squeezenet(version, **kwargs):
model = SqueezeNet(version, **kwargs)
return model
def squeezenet1_0(**kwargs):
return _squeezenet('1_0', **kwargs)
def squeezenet1_0(pretrained=False, **kwargs):
model = _squeezenet('1_0', **kwargs)
if pretrained: model.load("jittorhub://squeezenet1_0.pkl")
return model
def squeezenet1_1(**kwargs):
return _squeezenet('1_1', **kwargs)
def squeezenet1_1(pretrained=False, **kwargs):
model = _squeezenet('1_1', **kwargs)
if pretrained: model.load("jittorhub://squeezenet1_1.pkl")
return model

View File

@ -67,33 +67,49 @@ def _vgg(arch, cfg, batch_norm, **kwargs):
return model
def vgg11(**kwargs):
return _vgg('vgg11', 'A', False, **kwargs)
def vgg11(pretrained=False, **kwargs):
model = _vgg('vgg11', 'A', False, **kwargs)
if pretrained: model.load("jittorhub://vgg11.pkl")
return model
def vgg11_bn(**kwargs):
return _vgg('vgg11_bn', 'A', True, **kwargs)
def vgg11_bn(pretrained=False, **kwargs):
model = _vgg('vgg11_bn', 'A', True, **kwargs)
if pretrained: model.load("jittorhub://vgg11_bn.pkl")
return model
def vgg13(**kwargs):
return _vgg('vgg13', 'B', False, **kwargs)
def vgg13(pretrained=False, **kwargs):
model = _vgg('vgg13', 'B', False, **kwargs)
if pretrained: model.load("jittorhub://vgg13.pkl")
return model
def vgg13_bn(**kwargs):
return _vgg('vgg13_bn', 'B', True, **kwargs)
def vgg13_bn(pretrained=False, **kwargs):
model = _vgg('vgg13_bn', 'B', True, **kwargs)
if pretrained: model.load("jittorhub://vgg13_bn.pkl")
return model
def vgg16(**kwargs):
return _vgg('vgg16', 'D', False, **kwargs)
def vgg16(pretrained=False, **kwargs):
model = _vgg('vgg16', 'D', False, **kwargs)
if pretrained: model.load("jittorhub://vgg16.pkl")
return model
def vgg16_bn(**kwargs):
return _vgg('vgg16_bn', 'D', True, **kwargs)
def vgg16_bn(pretrained=False, **kwargs):
model = _vgg('vgg16_bn', 'D', True, **kwargs)
if pretrained: model.load("jittorhub://vgg16_bn.pkl")
return model
def vgg19(**kwargs):
return _vgg('vgg19', 'E', False, **kwargs)
def vgg19(pretrained=False, **kwargs):
model = _vgg('vgg19', 'E', False, **kwargs)
if pretrained: model.load("jittorhub://vgg19.pkl")
return model
def vgg19_bn(**kwargs):
return _vgg('vgg19_bn', 'E', True, **kwargs)
def vgg19_bn(pretrained=False, **kwargs):
model = _vgg('vgg19_bn', 'E', True, **kwargs)
if pretrained: model.load("jittorhub://vgg19_bn.pkl")
return model

View File

@ -154,6 +154,7 @@ def get_init_var_rand(shape, dtype):
def relu(x): return jt.ternary((x>0.0), x, jt.broadcast_var(0.0, x))
def leaky_relu(x, scale=0.01): return jt.ternary(x>0, x, x*scale)
def relu6(x): return jt.minimum(jt.maximum(x, 0.0), 6.0)
def elu(x,alpha=1.0):return jt.ternary(x>0,x,alpha*(x.exp()-1))
def sign(x):
one = jt.ones(x.shape)
x = jt.ternary(x>0, one, x)
@ -165,6 +166,13 @@ def gelu(x):
r = erf*x*.5
return r
class ELU(Module):
def __init__(self,alpha=1.0):
self.alpha=alpha
def execute(self,x):
return elu(x,self.alpha)
class PReLU(Module):
def __init__(self, num_parameters=1, init_=0.25):
self.num_parameters = num_parameters
@ -238,6 +246,30 @@ def smooth_l1_loss(y_true, y_pred,reduction="mean"):
else:
raise ValueError(f'not support {reduction}')
def nll_loss(output,target,weight=None,ignore_index=-100,reduction='mean'):
assert output.ndim<=2 and output.ndim>0 and target.ndim==1
n_classes = output.shape[-1]
assert weight is None or weight.numel()==n_classes
assert ignore_index<0 or ignore_index<n_classes
if weight is None:
weight = jt.ones((n_classes,))
if ignore_index>0:
weight[ignore_index]=0
if output.ndim==2:
index = jt.index((output.shape[0],),dim=0)
loss = -output[index,target]*weight[target]
else:
loss = -output[target[0]]*weight[target[0]]
if reduction=="mean":
total_weight = weight[target].sum() if output.ndim==2 else weight[target[0]].sum()
return loss.sum()/total_weight
elif reduction=="sum":
return loss.sum()
elif reduction=="none":
return loss
else:
raise ValueError(f'not support {reduction}')
class CrossEntropyLoss(Module):
def __init__(self,ignore_index=None):
self.ignore_index = ignore_index
@ -330,6 +362,9 @@ class Dropout(Module):
output = output * noise / (1.0 - self.p) # div keep prob
return output
def dropout(x,p=0.5,is_train=False):
return Dropout(p=p,is_train=is_train)(x)
class Linear(Module):
def __init__(self, in_features, out_features, bias=True):
self.in_features = in_features
@ -707,6 +742,45 @@ class ConvTranspose(Module):
y = y + b
return y
def conv_transpose(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
x = input
N,C,H,W = x.shape
i,o,h,w = weight.shape
assert C==i
assert groups==1, "Group conv not supported yet."
stride = stride if isinstance(stride, tuple) else (stride, stride)
dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation)
# normalize padding and output_padding to tuples
padding = padding if isinstance(padding, tuple) else (padding, padding)
output_padding = output_padding if isinstance (output_padding, tuple) else (output_padding, output_padding)
assert output_padding[0] < max(stride[0], dilation[0]) and \
output_padding[1] < max(stride[1], dilation[1]), \
"output padding must be smaller than max(stride, dilation)"
stride_h, stride_w = stride
padding_h, padding_w = padding
dilation_h, dilation_w = dilation
h_out = (H-1) * stride_h + output_padding[0] - 2*padding_h + 1 + (h-1)*dilation_h
w_out = (W-1) * stride_w + output_padding[1] - 2*padding_w + 1 + (w-1)*dilation_w
out_shape = (N, o, h_out, w_out)
shape = (N, i, o, H, W, h, w)
xx = x.broadcast(shape, (2, 5, 6)) # o,h,w
ww = weight.broadcast(shape, (0, 3, 4)) # N,H,W
y = (ww*xx).reindex_reduce("add", out_shape, [
'i0', # N
'i2', # o
f'i3*{stride_h}-{padding_h}+i5*{dilation_h}', # Hid+Khid
f'i4*{stride_w}-{padding_w}+i6*{dilation_w}', # Wid+KWid
])
if isinstance(bias, jt.Var):
b = bias.broadcast(y.shape, [0,2,3])
y = y + b
else:
assert not bias, "bias must be None or a jittor Var"
return y
conv_transpose2d = conv_transpose
def pad(x,padding, mode='constant', value=0):
assert mode in ['constant','replicate','reflect','circular'],'only support constant,replicate,reflect,circular pad'
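Small hedged sketches of the functional additions above: nll_loss over log-probabilities, the functional elu, and conv_transpose2d with explicit stride/padding/output_padding (all shapes below are illustrative).
import numpy as np
import jittor as jt
from jittor import nn

logits = jt.array(np.random.randn(5, 3).astype("float32"))  # treated as log-probabilities
target = jt.array(np.array([0, 2, 1, 2, 0]))
print(nn.nll_loss(logits, target, reduction="mean"))

print(nn.elu(jt.array([-1.0, 0.5])).data)

x = jt.random((2, 4, 8, 8))
w = jt.random((4, 6, 3, 3))  # (in, out, kh, kw) layout, as in conv_transpose above
y = nn.conv_transpose2d(x, w, stride=2, padding=1, output_padding=1)
print(y.shape)  # (2, 6, 16, 16)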

View File

@ -33,6 +33,9 @@ class Optimizer(object):
assert isinstance(pg, dict)
self.param_groups.append(pg)
self.n_step = 0
def add_param_group(self, group):
self.param_groups.append(group)
@property
def defaults(self):
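A sketch of add_param_group: it simply appends another group dict, mirroring the dict layout accepted by Optimizer.__init__ above (whether a later step() needs extra bookkeeping keys in the group is not shown in this diff).
import jittor as jt
from jittor import nn, optim

net = nn.Linear(4, 2)
opt = optim.SGD(net.parameters(), 0.1)

extra = jt.random((2,))
opt.add_param_group({"params": [extra], "lr": 0.01})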

python/jittor/sparse.py (new file, 53 lines)
View File

@ -0,0 +1,53 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Dun Liang <randonlang@gmail.com>.
# Xiangli Li <190569238@qq.com>
#
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import jittor as jt
import numpy as np
class SparseVar:
def __init__(self,indices,values,shape):
assert isinstance(indices,jt.Var) and isinstance(values,jt.Var) and isinstance(shape,jt.NanoVector)
self.indices = indices
self.values = values
self.shape = shape
self.ndim = len(shape)
def _indices(self):
return self.indices
def _values(self):
return self.values
def t(self):
indices = list(self.indices.split(1,dim=0))
indices[-1],indices[-2] = indices[-2],indices[-1]
indices = jt.contrib.concat(indices,dim=0)
shape = list(self.shape)
shape[-1],shape[-2] = shape[-2],shape[-1]
shape = jt.NanoVector(shape)
return SparseVar(indices,self.values,shape)
def to_dense(self):
ret = jt.zeros(self.shape,self.values.dtype)
indices = tuple(self.indices.split(1,dim=0))
ret[indices]=self.values
return ret
def sparse_array(indices,values,shape):
return SparseVar(indices,values,shape)
def spmm(sparse_x,y):
assert isinstance(sparse_x,SparseVar) and isinstance(y,jt.Var)
assert sparse_x.ndim==2 and y.ndim==2 and sparse_x.shape[-1]==y.shape[0]
# TODO: implement a real sparse matmul; fall back to dense for now
x = sparse_x.to_dense()
return jt.matmul(x,y)
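A short usage sketch of the new sparse module: COO-style indices plus values, with spmm currently densifying as the TODO notes.
import numpy as np
import jittor as jt

indices = jt.array(np.array([[0, 1, 1], [2, 0, 2]]))  # 2 x nnz coordinate list
values = jt.array(np.array([3.0, 4.0, 5.0], dtype=np.float32))
s = jt.sparse.sparse_array(indices, values, jt.NanoVector([2, 3]))

print(s.to_dense().data)      # dense 2x3 matrix
print(s.t().to_dense().data)  # transpose swaps the last two dims
print(jt.sparse.spmm(s, jt.random((3, 4))).shape)  # falls back to dense matmul for now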

View File

@ -60,6 +60,41 @@ class TestConvTranspose(unittest.TestCase):
check((4, 5, 100, 100), (5, 6, 5, 5), 1, 2)
check((4, 5, 100, 100), (5, 6, 5, 5), 2, 2)
check((4, 5, 100, 100), (5, 6, 5, 5), 2, 3)
def test_function(self):
def check(data_shape, weights_shape, stride=1, dilation=1):
N,C,H,W = data_shape
i,o,h,w = weights_shape
img = np.random.rand(N,C,H,W).astype("float32")
weights = np.random.rand(i,o,h,w).astype("float32")
m1 = jt.nn.ConvTranspose(i,o,h, stride=stride, dilation=dilation, bias=False)
m2 = torch.nn.ConvTranspose2d(i,o,h, stride=stride, dilation=dilation, bias=False)
m1.weight.data = weights
m2.weight.data = torch.Tensor(weights)
x = jt.array(img)
# out1 = m1(x)
out1 = jt.nn.conv_transpose2d(x, m1.weight, stride=stride, dilation=dilation, bias=False)
mask = jt.random(out1.shape)
out1 = out1*mask
tx = torch.Tensor(img)
tx.requires_grad = True
out2 = m2(tx) * torch.Tensor(mask.data)
with jt.log_capture_scope(log_silent=1,
log_vprefix="var_re=0,conv=0,op.cc=100") as logs:
assert np.allclose(out1.data, out2.data)
dx, dw = jt.grad(out1, [x, m1.weight])
jt.sync([dx, dw])
out2.sum().backward()
assert np.allclose(dw.data, m2.weight.grad.numpy(), 1e-3)
assert np.allclose(dx.data, tx.grad.numpy())
assert len(find_log_with_re(logs, "conv")) == 3
check((4, 5, 10, 10), (5, 6, 3, 3))
check((4, 5, 10, 10), (5, 6, 3, 3), 2)
check((4, 5, 100, 100), (5, 6, 4, 4), 2)
check((4, 5, 100, 100), (5, 6, 4, 4), 3)
check((4, 5, 100, 100), (5, 6, 5, 5), 1, 2)
check((4, 5, 100, 100), (5, 6, 5, 5), 2, 2)
check((4, 5, 100, 100), (5, 6, 5, 5), 2, 3)
if __name__ == "__main__":
unittest.main()

View File

@ -26,6 +26,19 @@ class TestFunction(unittest.TestCase):
da = jt.grad(b, a)
assert da.data == -1
def test_apply(self):
class MyFunc(Function):
def execute(self, x):
return x+1
def grad(self, grad):
return grad-2
a = jt.ones(1)
func = MyFunc.apply
b = func(a)
da = jt.grad(b, a)
assert da.data == -1
def test2(self):
class MyFunc(Function):
def execute(self, x):

View File

@ -40,6 +40,21 @@ class TestLoss(unittest.TestCase):
jt_y=jt_loss(jt.array(output), jt.array(target))
tc_y=tc_loss(torch.from_numpy(output), torch.from_numpy(target))
assert np.allclose(jt_y.numpy(), tc_y.numpy())
def test_nll_loss(self):
tc_loss = tnn.functional.nll_loss
jt_loss = jnn.nll_loss
output=np.random.randn(10,10).astype(np.float32)
target=np.random.randint(10, size=(10))
jt_y=jt_loss(jt.array(output), jt.array(target),reduction='mean')
tc_y=tc_loss(torch.from_numpy(output), torch.from_numpy(target),reduction='mean')
assert np.allclose(jt_y.numpy(), tc_y.numpy())
output=np.random.randn(10,10).astype(np.float32)
target=np.random.randint(10, size=(10))
weight=np.random.randn(10,).astype(np.float32)
jt_y=jt_loss(jt.array(output), jt.array(target),jt.array(weight),reduction='mean')
tc_y=tc_loss(torch.from_numpy(output), torch.from_numpy(target),torch.from_numpy(weight),reduction='mean')
assert np.allclose(jt_y.numpy(), tc_y.numpy())
def test_cross_entropy_loss(self):
jt_loss=jnn.CrossEntropyLoss()

View File

@ -54,6 +54,7 @@ class TestPad(unittest.TestCase):
check_equal(torch.Tensor(arr).flip(1), jt.array(arr).flip(1))
check_equal(torch.Tensor(arr).flip(2), jt.array(arr).flip(2))
check_equal(torch.Tensor(arr).flip(3), jt.array(arr).flip(3))
check_equal(torch.Tensor(arr).flip([2,3]), jt.array(arr).flip([2,3]))
print('pass flip test ...')
def test_cross(self):
@ -83,8 +84,13 @@ class TestPad(unittest.TestCase):
check_equal(torchvision.utils.make_grid(torch.Tensor(arr), nrow=3, padding=4), jt.make_grid(jt.array(arr), nrow=3, padding=4))
check_equal(torchvision.utils.make_grid(torch.Tensor(arr), nrow=3, padding=4, pad_value=-1), jt.make_grid(jt.array(arr), nrow=3, padding=4, pad_value=-1))
check_equal(torchvision.utils.make_grid(torch.Tensor(arr), nrow=3, normalize=True, padding=4, pad_value=-1), jt.make_grid(jt.array(arr), nrow=3, normalize=True, padding=4, pad_value=-1))
check_equal(torchvision.utils.make_grid(torch.Tensor(arr), nrow=3, normalize=True, padding=4, pad_value=-1, range=(-100,100)), jt.make_grid(jt.array(arr), nrow=3, normalize=True, padding=4, pad_value=-1, range=(-100,100)))
print('pass make_grid test ...')
def test_save_image(self):
arr = jt.array(np.random.randn(16,3,10,10))
jt.save_image(arr, "/tmp/a.jpg")
def test_unbind(self):
arr = np.random.randn(2,3,4)
for dim in range(len(arr.shape)):

View File

@ -28,7 +28,7 @@ def check_equal(arr, j_layer, p_layer):
pytorch_arr = torch.Tensor(arr)
jittor_result = j_layer(jittor_arr)
pytorch_result = p_layer(pytorch_arr)
assert np.allclose(pytorch_result.detach().numpy(), jittor_result.numpy())
assert np.allclose(pytorch_result.detach().numpy(), jittor_result.numpy(),rtol=1e-5,atol=1e-5)
@unittest.skipIf(skip_this_test, "No Torch found")
class TestRelu(unittest.TestCase):
@ -61,6 +61,15 @@ class TestRelu(unittest.TestCase):
check_equal(arr, jnn.LeakyReLU(), tnn.LeakyReLU())
check_equal(arr, jnn.LeakyReLU(2), tnn.LeakyReLU(2))
check_equal(arr, jnn.LeakyReLU(99.9), tnn.LeakyReLU(99.9))
# ***************************************************************
# Test ELU Layer
# ***************************************************************
arr = np.random.randn(16,10,224,224)
check_equal(arr, jnn.ELU(), tnn.ELU())
check_equal(arr, jnn.ELU(0.3), tnn.ELU(0.3))
check_equal(arr, jnn.ELU(2), tnn.ELU(2))
check_equal(arr, jnn.ELU(99.9), tnn.ELU(99.9))
# ***************************************************************
# Test GELU Layer

View File

@ -0,0 +1,118 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Wenyang Zhou <576825820@qq.com>.
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import jittor as jt
import numpy as np
skip_this_test = False
@unittest.skipIf(skip_this_test, "No Torch found")
class TestSetitem(unittest.TestCase):
def test_setitem(self):
arr0 = jt.random((4,2,2))
data0 = jt.ones((2,2))
arr0[1] = data0
arr0.sync()
data0.data[0,0] = 0
assert arr0[1,0,0] == 0
arr00 = jt.random((4,2,2))
data00 = jt.ones((2,2))
# sharing memory will fail if data00 has an edge to other nodes.
tmp = data00 + 1
arr00[1] = data00
arr00.sync()
data00.data[0,0] = 0
assert arr00[1,0,0] == 0
arr1 = jt.random((4,2,2))
data1 = jt.zeros((2,2))
arr1[3,:,0:2] = data1
arr1.sync()
data1.data[0,0] = 1
assert arr1[3,0,0] == 1
arr21 = jt.ones((2,2))
arr22 = jt.ones((2,2)) * 2
arr2 = jt.contrib.concat([arr21, arr22], dim=0)
arr2.sync()
arr21.data[0,0] = 3
arr22.data[0,0] = 4
assert arr2[0,0] == 3
assert arr2[2,0] == 4
def test_getitem(self):
# test for different slice type
arr0 = jt.random((4,3))
arr0_res = arr0[2,:]
arr0_res.data[1] = 1
assert arr0[2,1] == 1
arr1 = jt.array([1,2,3,4])
arr1_res = arr1[None]
arr1_res.data[0,2] = -1
assert arr1[2] == -1
arr2 = jt.array([1,2,3,4])
arr2_res = arr2[...]
arr2_res.data[2] = -1
assert arr2[2] == -1
arr3 = jt.array([1,2,3,4])
arr3_res = arr3[3]
arr3_res.data[0] = -1
assert arr3[3] == -1
arr4 = jt.random((4,2,3,3))
arr4_res = arr4[...,:,:]
arr4_res.data[0,0,1,1] = 1
assert arr4[0,0,1,1] == 1
arr5 = jt.random((4,2,3,3))
arr5_res = arr5[1:3,:,:,:]
arr5_res.data[1,0,1,1] = 1
assert arr5[2,0,1,1] == 1
arr6 = jt.random((4,2,3,3))
arr6_res = arr6[1]
arr6_res.data[0,1,1] = 1
assert arr6[1,0,1,1] == 1
# test for different data type (float32/float64/bool/int8/int32)
arr_float32 = jt.random((4,2,3))
arr_float32_res = arr_float32[1:3,:,:]
arr_float32_res.data[0,0,0] = 1
assert arr_float32[1,0,0] == 1
arr_float32_res.data[1,1,2] = 1
assert arr_float32[2,1,2] == 1
arr_float32[1,0,0] = 0
# getitem and setitem do not conflict
assert arr_float32_res[0,0,0] == 1
arr_bool = jt.bool(np.ones((4,2,3)))
arr_bool_res = arr_bool[1:3,:,:]
arr_bool_res.data[0,0,0] = False
assert arr_bool[1,0,0] == False
arr_bool_res.data[0,0,1] = False
assert arr_bool[1,0,1] == False
arr_float64 = jt.random((4,2,3), dtype='float64')
arr_float64_res = arr_float64[1:3,:,:]
arr_float64_res.data[0,0,0] = 1
assert arr_float64[1,0,0] == 1
arr_float64_res.data[1,1,2] = 1
assert arr_float64[2,1,2] == 1
arr_int32 = jt.ones((4,2,3), dtype='int32')
arr_int32_res = arr_int32[1:3,:,:]
arr_int32_res.data[0,0,0] = 0
assert arr_int32[1,0,0] == 0
arr_int32_res.data[1,1,2] = 0
assert arr_int32[2,1,2] == 0
if __name__ == "__main__":
unittest.main()

View File

@ -140,6 +140,11 @@ class TestSlice(unittest.TestCase):
a[c] = 0
assert (a.data == [1,2,3,0,0]).all()
def test_numpy_scalar_slice(self):
a = jt.random((2,2))
b = np.array([1])[0]
assert a[b].shape == [2]
if __name__ == "__main__":

View File

@ -0,0 +1,39 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Xiangli Li <1905692338@qq.com>
# Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import jittor as jt
import numpy as np
import jittor.nn as jnn
skip_this_test = False
try:
jt.dirty_fix_pytorch_runtime_error()
import torch
import torch.nn as tnn
except:
torch = None
tnn = None
skip_this_test = True
@unittest.skipIf(skip_this_test, "No Torch found")
class TestSparse(unittest.TestCase):
def test_sparse_var(self):
indices = np.array([[0,1,1],[2,0,2]])
values = np.array([3,4,5]).astype(np.float32)
shape = [2,3]
jt_array = jt.sparse.sparse_array(jt.array(indices),jt.array(values),jt.NanoVector(shape))
torch_tensor = torch.sparse.FloatTensor(torch.from_numpy(indices),torch.from_numpy(values),torch.Size(shape))
jt_numpy = jt_array.to_dense().numpy()
torch_numpy = torch_tensor.to_dense().numpy()
assert np.allclose(jt_numpy,torch_numpy)
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,27 @@
from flask import Flask
from flask import request
from flask import jsonify
app = Flask(__name__)
import json
from jittor.utils.pytorch_converter import convert
@app.route('/', methods=["GET", "POST"])
def hello():
msg = request
data = msg.data.decode("utf-8")
try:
data = json.loads(data)
src = data["src"]
pjmap = json.loads(data["pjmap"])
jt_src = convert(src, pjmap)
except Exception as e:
jt_src = str(e)
response = jsonify(jt_src=jt_src)
# Enable Access-Control-Allow-Origin
response.headers.add("Access-Control-Allow-Origin", "*")
return response
if __name__ == '__main__':
app.run(host="0.0.0.0")
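A hedged client-side sketch for this service: POST a JSON body carrying the PyTorch source under "src" and a JSON-encoded pjmap override under "pjmap". The host, port, and the requests dependency are assumptions, not part of this diff.
import json
import requests

payload = {
    "src": "import torch.nn as nn\nmodel = nn.Linear(3, 4)\n",
    "pjmap": json.dumps({}),  # extra mappings, JSON-encoded as the server expects
}
resp = requests.post("http://localhost:5000/", data=json.dumps(payload))
print(resp.json()["jt_src"])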

View File

@ -38,7 +38,7 @@ def download_url_to_local(url, filename, root_folder, md5):
ensure_dir(root_folder)
file_path = os.path.join(root_folder, filename)
if check_file_exist(file_path, md5):
print("Data file has been downloaded and verified")
return
else:
try:
print('Downloading ' + url + ' to ' + file_path)

View File

@ -179,6 +179,18 @@ pjmap = {
'links': {},
'extras': {'affine': 'None'},
},
'Parameter':{
'pytorch': {
'args': "data,require_grad=True"
},
'jittor': {
'module': 'jt',
'name': 'array',
'args': 'data,dtype=None',
},
'links': {},
'extras': {},
},
'Dropout2d': {
'pytorch': {
'args': 'p=0.5, inplace=False',
@ -351,6 +363,32 @@ pjmap = {
}
}
unsupport_ops = [
# ***************************************************************
# torch.nn
# ***************************************************************
'ModuleDict', 'ParameterList', 'ParameterDict',
'Conv1d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose3d', 'Unfold', 'Fold',
'MaxPool1d', 'MaxPool3d', 'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d', 'AvgPool1d',
'AvgPool3d', 'FractionalMaxPool2d', 'LPPool1d', 'LPPool2d', 'AdaptiveMaxPool1d',
'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d', 'AdaptiveAvgPool1d', 'AdaptiveAvgPool3d',
'ReflectionPad1d', 'ReplicationPad1d', 'ReplicationPad3d', 'ConstantPad1d', 'ConstantPad3d',
'ELU', 'Hardshrink', 'Hardtanh', 'LogSigmoid', 'MultiheadAttention',
'RReLU', 'SELU', 'CELU', 'GELU', 'Softshrink', 'Softsign', 'Tanhshrink',
'Threshold', 'Softmin', 'Softmax2d', 'LogSoftmax', 'AdaptiveLogSoftmaxWithLoss',
'BatchNorm3d', 'SyncBatchNorm', 'InstanceNorm1d', 'InstanceNorm3d', 'LocalResponseNorm',
'RNNBase', 'RNN', 'LSTM', 'GRU', 'RNNCell', 'LSTMCell', 'GRUCell', 'Transformer', 'TransformerEncoder',
'TransformerDecoder', 'TransformerEncoderLayer', 'TransformerDecoderLayer', 'Identity', 'Bilinear',
'Dropout3d', 'AlphaDropout', 'EmbeddingBag', 'CosineSimilarity', 'PairwiseDistance', 'CTCLoss', 'NLLLoss', 'PoissonNLLLoss', 'KLDivLoss', 'BCEWithLogitsLoss',
'MarginRankingLoss', 'HingeEmbeddingLoss', 'MultiLabelMarginLoss', 'SmoothL1Loss', 'SoftMarginLoss',
'MultiLabelSoftMarginLoss', 'CosineEmbeddingLoss', 'MultiMarginLoss', 'TripletMarginLoss', 'UpsamplingNearest2d',
'UpsamplingBilinear2d', 'DataParallel', 'DistributedDataParallel', 'clip_grad_norm_', 'clip_grad_value_',
'parameters_to_vector', 'vector_to_parameters', 'BasePruningMethod', 'PruningContainer', 'Identity',
'RandomUnstructured', 'L1Unstructured', 'RandomStructured', 'LnStructured', 'CustomFromMask', 'identity',
'random_unstructured', 'l1_unstructured', 'random_structured', 'ln_structured', 'global_unstructured',
'custom_from_mask', 'remove', 'is_pruned', 'weight_norm', 'remove_weight_norm', 'spectral_norm',
'remove_spectral_norm', 'PackedSequence', 'pack_padded_sequence', 'pad_packed_sequence', 'pad_sequence', 'pack_sequence'
]
def pjmap_append(pytorch_func_name, pytorch_args, jittor_func_module, jittor_func_name, jittor_args, extras=None, links=None, delete=None):
''' Add a mapping to pjmap for converting a new function, e.g. convert AvgPool2d to Pool
@ -393,58 +431,268 @@ def pjmap_append(pytorch_func_name, pytorch_args, jittor_func_module, jittor_fun
'delete': delete,
}
unsupport_ops = [
# ***************************************************************
# torch.nn
# ***************************************************************
'Parameter', 'ModuleDict', 'ParameterList', 'ParameterDict',
'Conv1d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose3d', 'Unfold', 'Fold',
'MaxPool1d', 'MaxPool3d', 'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d', 'AvgPool1d',
'AvgPool3d', 'FractionalMaxPool2d', 'LPPool1d', 'LPPool2d', 'AdaptiveMaxPool1d',
'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d', 'AdaptiveAvgPool1d', 'AdaptiveAvgPool3d',
'ReflectionPad1d', 'ReplicationPad1d', 'ReplicationPad3d', 'ConstantPad1d', 'ConstantPad3d',
'ELU', 'Hardshrink', 'Hardtanh', 'LogSigmoid', 'MultiheadAttention',
'RReLU', 'SELU', 'CELU', 'GELU', 'Softshrink', 'Softsign', 'Tanhshrink',
'Threshold', 'Softmin', 'Softmax2d', 'LogSoftmax', 'AdaptiveLogSoftmaxWithLoss',
'BatchNorm3d', 'SyncBatchNorm', 'InstanceNorm1d', 'InstanceNorm3d', 'LocalResponseNorm',
'RNNBase', 'RNN', 'LSTM', 'GRU', 'RNNCell', 'LSTMCell', 'GRUCell', 'Transformer', 'TransformerEncoder',
'TransformerDecoder', 'TransformerEncoderLayer', 'TransformerDecoderLayer', 'Identity', 'Bilinear',
'Dropout3d', 'AlphaDropout', 'EmbeddingBag', 'CosineSimilarity', 'PairwiseDistance', 'CTCLoss', 'NLLLoss', 'PoissonNLLLoss', 'KLDivLoss', 'BCEWithLogitsLoss',
'MarginRankingLoss', 'HingeEmbeddingLoss', 'MultiLabelMarginLoss', 'SmoothL1Loss', 'SoftMarginLoss',
'MultiLabelSoftMarginLoss', 'CosineEmbeddingLoss', 'MultiMarginLoss', 'TripletMarginLoss', 'UpsamplingNearest2d',
'UpsamplingBilinear2d', 'DataParallel', 'DistributedDataParallel', 'clip_grad_norm_', 'clip_grad_value_',
'parameters_to_vector', 'vector_to_parameters', 'BasePruningMethod', 'PruningContainer', 'Identity',
'RandomUnstructured', 'L1Unstructured', 'RandomStructured', 'LnStructured', 'CustomFromMask', 'identity',
'random_unstructured', 'l1_unstructured', 'random_structured', 'ln_structured', 'global_unstructured',
'custom_from_mask', 'remove', 'is_pruned', 'weight_norm', 'remove_weight_norm', 'spectral_norm',
'remove_spectral_norm', 'PackedSequence', 'pack_padded_sequence', 'pad_packed_sequence', 'pad_sequence', 'pack_sequence'
]
support_ops = {}
for key in pjmap.keys():
module = pjmap[key]['jittor']['module']
name = pjmap[key]['jittor']['name']
if module == 'nn':
support_ops[key] = name
def raise_unsupport(name, ori_src):
ret = f"raise RuntimeError('''original source: <{ori_src.strip()}>, {name} is not supported in Jittor yet. We will appreciate it if you provide an implementation of {name} and make pull request at https://github.com/Jittor/jittor.''')"
print(ret+'\n')
ret = ast.parse(ret).body[0]
return ret
def raise_unsupport(name):
raise RuntimeError(f'{name} is not supported in Jittor yet. We will appreciate it if you provide an implementation of {name} and make pull request at https://github.com/Jittor/jittor.')
class Converter:
def __init__(self, ex_pjmap):
import copy
self.pjmap = copy.deepcopy(pjmap)
if ex_pjmap:
self.pjmap.update(ex_pjmap)
self.unsupport_ops = set(unsupport_ops)
support_ops = {}
for key in self.pjmap.keys():
module = self.pjmap[key]['jittor']['module']
name = self.pjmap[key]['jittor']['name']
if module == 'nn':
support_ops[key] = name
if key in self.unsupport_ops:
self.unsupport_ops.remove(key)
self.support_ops = support_ops
self.import_flag = []
def replace(a):
if hasattr(a, "attr") and a.attr in unsupport_ops:
raise_unsupport(a.attr)
if hasattr(a, "id") and a.id in unsupport_ops:
raise_unsupport(a.id)
if hasattr(a, "attr"):
if a.attr in support_ops.keys(): a.attr = support_ops[a.attr]
def replace(self, a):
if hasattr(a, "attr") and a.attr in self.unsupport_ops:
ori_src = astunparse.unparse(a)
return raise_unsupport(a.attr, ori_src)
if hasattr(a, "id"):
if a.id in support_ops.keys(): a.id = support_ops[a.id]
if hasattr(a, "id") and a.id in self.unsupport_ops:
ori_src = astunparse.unparse(a)
return raise_unsupport(a.id, ori_src)
import_flag = []
def convert(code):
if hasattr(a, "attr"):
if a.attr in self.support_ops.keys(): a.attr = self.support_ops[a.attr]
if hasattr(a, "id"):
if a.id in self.support_ops.keys(): a.id = self.support_ops[a.id]
return None
def convert_(self, prefix, func_name, ags, kws, ori_src):
info = self.pjmap[func_name]
p_prefix = info['pytorch']['prefix'] if 'prefix' in info['pytorch'].keys() else None
if p_prefix is not None and prefix in p_prefix:
p_ags = info['pytorch']['args_prefix']
j_ags = info['jittor']['args_prefix']
else:
p_ags = info['pytorch']['args']
j_ags = info['jittor']['args']
if 'delete' in info.keys():
delete = info['delete']
else:
delete = None
j_prefix = info['jittor']['prefix'] if 'prefix' in info['jittor'].keys() else None
j_module = info['jittor']['module']
j_name = info['jittor']['name']
links = info['links']
extras = info['extras']
jj_ags = []
jj_kws = {}
pp_ags = []
pp_kws = {}
if j_ags == '' and p_ags == '':
# no args in Pytorch and Jittor.
if p_prefix is None:
return f"{j_module}.{j_name}()"
else:
if prefix in p_prefix:
return f"{j_prefix}.{j_name}()"
else:
return f"{prefix}.{j_name}()"
else:
j_ags = j_ags.replace(' ','').split(',')
for j_ag in j_ags:
if '=' in j_ag:
k,v = j_ag.split('=')
jj_kws[k] = v
else:
jj_ags.append(j_ag)
p_ags = p_ags.replace(' ','').split(',')
for p_ag in p_ags:
if '=' in p_ag:
k,v = p_ag.split('=')
pp_kws[k] = v
else:
pp_ags.append(p_ag)
if len(jj_ags) == 0 and len(pp_ags) != 0:
return f"raise AttributeError('''origin source: <{ori_src.strip()}>, {func_name} in Jittor has no Attribute {pp_ags[0]}''')"
# raise AttributeError(f"{func_name} in Jittor has no Attribute {pp_ags[0]}")
if delete is not None:
for d in delete:
if d in pp_ags:
jj_ags.append(d)
if d in pp_kws.keys():
jj_kws[d] = None
if len(pp_ags) > len(ags) + len(kws):
return f"raise RuntimeError('''origin source: <{ori_src.strip()}>, There are needed {len(pp_ags) + len(list(pp_kws.keys()))} args in Pytorch {func_name} function, but you only provide {len(ags) + len(kws)}''')"
# raise RuntimeError(f'There are needed {len(pp_ags) + len(list(pp_kws.keys()))} args in Pytorch {func_name} function, but you only provide {len(ags) + len(kws)}')
ags_ = []
for i in range(len(pp_ags)):
if i < len(ags):
if '*' in pp_ags[i]:
ags_.append('(' + ', '.join(ags[i:]) + ')')
ags = ags_
break
else:
ags_.append(ags[i])
else:
break
if len(pp_ags) + len(list(pp_kws.keys())) < len(ags) + len(kws):
return f"raise RuntimeError('''origin source: <{ori_src.strip()}>,There are only {len(pp_ags) + len(list(pp_kws.keys()))} args in Pytorch {func_name} function, but you provide {len(ags) + len(kws)}''')"
# raise RuntimeError(f'There are only {len(pp_ags) + len(list(pp_kws.keys()))} args in Pytorch {func_name} function, but you provide {len(ags) + len(kws)}')
j_ags_flag = np.zeros(len(jj_ags))
j_ags_values = {}
j_kws_values = {}
for i,ag in enumerate(ags):
if len(pp_ags) == 0:
ag_name = list(pp_kws.keys())[i]
elif i < len(pp_ags):
ag_name = pp_ags[i]
elif i >= len(pp_ags) and (i-len(pp_ags)) <= len(list(pp_kws.keys())):
ag_name = list(pp_kws.keys())[i-len(pp_ags)]
else:
return f"raise RuntimeError('''origin source: <{ori_src.strip()}>,The args number is not matc{func_name} in Jittor has no Attribute {ag_name}''')"
# raise RuntimeError(f'The args number is not matc{func_name} in Jittor has no Attribute {ag_name}')
if ag_name in links.keys():
ag_name = links[ag_name]
if ag_name in jj_ags:
j_ags_flag[jj_ags.index(ag_name)] = 1
j_ags_values[str(jj_ags.index(ag_name))] = ag
elif ag_name in jj_kws.keys():
j_kws_values[ag_name] = ag
else:
return f"raise AttributeError('''origin source: <{ori_src.strip()}>, {func_name} in Jittor has no Attribute {ag_name}''')"
# raise AttributeError(f'{func_name} in Jittor has no Attribute {ag_name}')
for i,kw in enumerate(kws):
kw_name, kw_value = kw.split('=')
if kw_name in links.keys():
kw_name = links[kw_name]
if kw_name in jj_ags:
j_ags_flag[jj_ags.index(kw_name)] = 1
j_ags_values[str(jj_ags.index(kw_name))] = kw_value
elif kw_name in jj_kws.keys():
j_kws_values[kw_name] = kw_value
else:
return f"raise AttributeError('''origin source: <{ori_src.strip()}>, {func_name} in Jittor has no Attribute {kw_name}''')"
# raise AttributeError(f'{func_name} in Jittor has no Attribute {kw_name}')
len_jj_ags = len(jj_ags) if len(jj_ags) == 0 or jj_ags[0] != '' else 0
if j_ags_flag.sum() < len_jj_ags:
missing_args = []
for i in range(len(jj_ags)):
if j_ags_flag[i] == 0:
missing_args.append(jj_ags[i])
return f"raise AttributeError('''origin source: <{ori_src.strip()}>, the needed args of {func_name} in Jittor is {', '.join(jj_ags)}, so you need to give value of {', '.join(missing_args)}.''')"
# raise AttributeError(f"the needed args of {func_name} in Jittor is {', '.join(jj_ags)}, so you need to give value of {', '.join(missing_args)}.")
if extras:
for k in extras.keys():
if k in jj_ags:
j_ags_values[str(jj_ags.index(k))] = extras[k]
elif k in jj_kws.keys():
j_kws_values[k] = extras[k]
else:
return f"raise AttributeError('''origin source: <{ori_src.strip()}>, there is not attribute named {k} in Jittor {func_name}, you should delete it in {func_name} extras.''')"
# raise AttributeError(f"there is not attribute named {k} in Jittor {func_name}, you should delete it in {func_name} extras.")
if delete is not None:
for d in delete:
if d in j_ags_values:
del j_ags_values[d]
if d in j_kws_values.keys():
j_kws_values.pop(d)
j_ags_ = [j_ags_values[str(i)] for i in range(len(list(j_ags_values.keys())))]
j_kws_ = [key + "=" + j_kws_values[key] for key in j_kws_values.keys()]
j_func = f"{j_module}.{j_name}({', '.join(j_ags_+j_kws_)})"
if p_prefix is None:
return f"{j_module}.{j_name}({', '.join(j_ags_+j_kws_)})"
else:
if prefix in p_prefix:
return f"{j_prefix}.{j_name}({', '.join(j_ags_+j_kws_)})"
else:
return f"{prefix}.{j_name}({', '.join(j_ags_+j_kws_)})"
return j_func
def dfs(self, a):
if isinstance(a, ast.Import):
if 'torch' in astunparse.unparse(a) and 'init' in astunparse.unparse(a):
self.import_flag.append('init')
return ast.parse('from jittor import init').body[0]
if 'torch' in astunparse.unparse(a) and a.names[0].asname == 'nn':
self.import_flag.append('nn')
return ast.parse('from jittor import nn').body[0]
if 'torch' in a.names[0].name:
return 'delete'
elif isinstance(a, ast.ImportFrom):
if 'torch' in a.module:
return 'delete'
elif isinstance(a, ast.Call):
for idx, ag in enumerate(a.args):
ret = self.dfs(ag)
if ret is not None:
a.args[idx] = ret
for idx, kw in enumerate(a.keywords):
ret = self.dfs(kw)
if ret is not None:
a.keywords[idx] = ret
ori_src = astunparse.unparse(a)
func = astunparse.unparse(a.func).strip('\n').split('.')
prefix = '.'.join(func[0:-1])
func_name = func[-1]
if func_name in self.unsupport_ops:
ret = raise_unsupport(func_name, ori_src)
return ret
if func_name in self.pjmap:
ags = [astunparse.unparse(ag).strip('\n') for ag in a.args]
kws = [astunparse.unparse(kw).strip('\n') for kw in a.keywords]
ret = self.convert_(prefix, func_name, ags, kws, ori_src)
ret_tmp = ret
ret = ast.parse(ret).body[0]
if hasattr(ret,'value'):
return ret.value
else:
print(ret_tmp+'\n')
return ret
if ".load_state_dict" in astunparse.unparse(a.func):
a.func.attr = 'load_parameters'
if astunparse.unparse(a.func).strip('\n').endswith(".size"):
ags = [astunparse.unparse(ag).strip('\n') for ag in a.args]
if len(ags) != 0:
con = astunparse.unparse(a.func).split('.size')[0] + '.shape[' + ','.join(ags) + ']'
else:
con = astunparse.unparse(a.func).replace('size', 'shape')
return ast.parse(con).body[0].value
elif isinstance(a, ast.Expr): pass
elif isinstance(a, ast.Attribute) or isinstance(a, ast.Name):
ret = self.replace(a)
if ret is not None:
print(ret)
return ret
elif isinstance(a, ast.FunctionDef):
if a.name == 'forward': a.name = 'execute'
if hasattr(a, '__dict__'):
for k in a.__dict__.keys():
if isinstance(a.__dict__[k], list):
delete_flag = []
for i,a_ in enumerate(a.__dict__[k]):
ret = self.dfs(a_)
if ret == 'delete':
delete_flag.append(True)
continue
if ret is not None:
a.__dict__[k][i] = ret
delete_flag.append(False)
tmp = [a_ for i,a_ in enumerate(a.__dict__[k]) if delete_flag[i] == False]
a.__dict__[k] = tmp
else:
ret = self.dfs(a.__dict__[k])
if ret is not None:
a.__dict__[k] = ret
def convert(code, ex_pjmaps=None):
''' Model code converter, example:
from jittor.utils.pytorch_converter import convert
@ -469,209 +717,13 @@ def convert(code):
model = Model()
print("## Jittor model:", model)
'''
a = ast.parse(code)
dfs(a)
converter = Converter(ex_pjmaps)
converter.dfs(a)
a.body.insert(0, ast.parse('import jittor as jt').body[0])
if 'init' not in import_flag:
if 'init' not in converter.import_flag:
a.body.insert(1, ast.parse('from jittor import init').body[0])
if 'nn' not in import_flag:
if 'nn' not in converter.import_flag:
a.body.insert(2, ast.parse('from jittor import nn').body[0])
return astunparse.unparse(a)
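The refactor above turns the converter into a per-call Converter instance and lets callers pass extra mappings. A minimal sketch of the entry point; the ex_pjmaps entry layout mirrors the pjmap dicts defined earlier in this file.
from jittor.utils.pytorch_converter import convert

pytorch_src = '''
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3)
    def forward(self, x):
        return self.conv(x)
'''
print(convert(pytorch_src))  # pass ex_pjmaps={...} to extend the mapping table for a single call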
def convert_(prefix, func_name, ags, kws):
info = pjmap[func_name]
p_prefix = info['pytorch']['prefix'] if 'prefix' in info['pytorch'].keys() else None
if p_prefix is not None and prefix in p_prefix:
p_ags = info['pytorch']['args_prefix']
j_ags = info['jittor']['args_prefix']
else:
p_ags = info['pytorch']['args']
j_ags = info['jittor']['args']
if 'delete' in info.keys():
delete = info['delete']
else:
delete = None
j_prefix = info['jittor']['prefix'] if 'prefix' in info['jittor'].keys() else None
j_module = info['jittor']['module']
j_name = info['jittor']['name']
links = info['links']
extras = info['extras']
jj_ags = []
jj_kws = {}
pp_ags = []
pp_kws = {}
if j_ags == '' and p_ags == '':
# no args in Pytorch and Jittor.
if p_prefix is None:
return f"{j_module}.{j_name}()"
else:
if prefix in p_prefix:
return f"{j_prefix}.{j_name}()"
else:
return f"{prefix}.{j_name}()"
else:
j_ags = j_ags.replace(' ','').split(',')
for j_ag in j_ags:
if '=' in j_ag:
k,v = j_ag.split('=')
jj_kws[k] = v
else:
jj_ags.append(j_ag)
p_ags = p_ags.replace(' ','').split(',')
for p_ag in p_ags:
if '=' in p_ag:
k,v = p_ag.split('=')
pp_kws[k] = v
else:
pp_ags.append(p_ag)
if len(jj_ags) == 0 and len(pp_ags) != 0:
raise AttributeError(f"{func_name} in Jittor has no Attribute {pp_ags[0]}")
if delete is not None:
for d in delete:
if d in pp_ags:
jj_ags.append(d)
if d in pp_kws.keys():
jj_kws[d] = None
if len(pp_ags) > len(ags) + len(kws):
raise RuntimeError(f'There are needed {len(pp_ags) + len(list(pp_kws.keys()))} args in Pytorch {func_name} function, but you only provide {len(ags) + len(kws)}')
ags_ = []
for i in range(len(pp_ags)):
if i < len(ags):
if '*' in pp_ags[i]:
ags_.append('(' + ', '.join(ags[i:]) + ')')
ags = ags_
break
else:
ags_.append(ags[i])
else:
break
if len(pp_ags) + len(list(pp_kws.keys())) < len(ags) + len(kws):
raise RuntimeError(f'There are only {len(pp_ags) + len(list(pp_kws.keys()))} args in Pytorch {func_name} function, but you provide {len(ags) + len(kws)}')
j_ags_flag = np.zeros(len(jj_ags))
j_ags_values = {}
j_kws_values = {}
for i,ag in enumerate(ags):
if len(pp_ags) == 0:
ag_name = list(pp_kws.keys())[i]
elif i < len(pp_ags):
ag_name = pp_ags[i]
elif i >= len(pp_ags) and (i-len(pp_ags)) <= len(list(pp_kws.keys())):
ag_name = list(pp_kws.keys())[i-len(pp_ags)]
else:
raise RuntimeError(f'The args number is not matc{func_name} in Jittor has no Attribute {ag_name}')
if ag_name in links.keys():
ag_name = links[ag_name]
if ag_name in jj_ags:
j_ags_flag[jj_ags.index(ag_name)] = 1
j_ags_values[str(jj_ags.index(ag_name))] = ag
elif ag_name in jj_kws.keys():
j_kws_values[ag_name] = ag
else:
raise AttributeError(f'{func_name} in Jittor has no Attribute {ag_name}')
for i,kw in enumerate(kws):
kw_name, kw_value = kw.split('=')
if kw_name in links.keys():
kw_name = links[kw_name]
if kw_name in jj_ags:
j_ags_flag[jj_ags.index(kw_name)] = 1
j_ags_values[str(jj_ags.index(kw_name))] = kw_value
elif kw_name in jj_kws.keys():
j_kws_values[kw_name] = kw_value
else:
raise AttributeError(f'{func_name} in Jittor has no Attribute {kw_name}')
len_jj_ags = len(jj_ags) if len(jj_ags) == 0 or jj_ags[0] != '' else 0
if j_ags_flag.sum() < len_jj_ags:
missing_args = []
for i in range(len(jj_ags)):
if j_ags_flag[i] == 0:
missing_args.append(jj_ags[i])
raise AttributeError(f"the needed args of {func_name} in Jittor is {', '.join(jj_ags)}, so you need to give value of {', '.join(missing_args)}.")
if extras:
for k in extras.keys():
if k in jj_ags:
j_ags_values[str(jj_ags.index(k))] = extras[k]
elif k in jj_kws.keys():
j_kws_values[k] = extras[k]
else:
raise AttributeError(f"there is not attribute named {k} in Jittor {func_name}, you should delete it in {func_name} extras.")
if delete is not None:
for d in delete:
if d in j_ags_values:
j_ags_values.remove(d)
if d in j_kws_values.keys():
j_kws_values.pop(d)
j_ags_ = [j_ags_values[str(i)] for i in range(len(list(j_ags_values.keys())))]
j_kws_ = [key + "=" + j_kws_values[key] for key in j_kws_values.keys()]
j_func = f"{j_module}.{j_name}({', '.join(j_ags_+j_kws_)})"
if p_prefix is None:
return f"{j_module}.{j_name}({', '.join(j_ags_+j_kws_)})"
else:
if prefix in p_prefix:
return f"{j_prefix}.{j_name}({', '.join(j_ags_+j_kws_)})"
else:
return f"{prefix}.{j_name}({', '.join(j_ags_+j_kws_)})"
return j_func
def dfs(a):
if isinstance(a, ast.Import):
if 'torch' in astunparse.unparse(a) and 'init' in astunparse.unparse(a):
import_flag.append('init')
return ast.parse('from jittor import init').body[0]
if 'torch' in astunparse.unparse(a) and a.names[0].asname == 'nn':
import_flag.append('nn')
return ast.parse('from jittor import nn').body[0]
if 'torch' in a.names[0].name:
return 'delete'
elif isinstance(a, ast.ImportFrom):
if 'torch' in a.module:
return 'delete'
elif isinstance(a, ast.Call):
for idx, ag in enumerate(a.args):
ret = dfs(ag)
if ret is not None:
a.args[idx] = ret
for idx, kw in enumerate(a.keywords):
ret = dfs(kw)
if ret is not None:
a.keywords[idx] = ret
func = astunparse.unparse(a.func).strip('\n').split('.')
prefix = '.'.join(func[0:-1])
func_name = func[-1]
if func_name in unsupport_ops:
raise_unsupport(func_name)
if func_name in pjmap.keys():
ags = [astunparse.unparse(ag).strip('\n') for ag in a.args]
kws = [astunparse.unparse(kw).strip('\n') for kw in a.keywords]
ret = convert_(prefix, func_name, ags, kws)
return ast.parse(ret).body[0].value
if ".load_state_dict" in astunparse.unparse(a.func):
a.func.attr = 'load_parameters'
if astunparse.unparse(a.func).strip('\n').endswith(".size"):
ags = [astunparse.unparse(ag).strip('\n') for ag in a.args]
if len(ags) != 0:
con = astunparse.unparse(a.func).split('.size')[0] + '.shape[' + ','.join(ags) + ']'
else:
con = astunparse.unparse(a.func).replace('size', 'shape')
return ast.parse(con).body[0].value
elif isinstance(a, ast.Expr): pass
elif isinstance(a, ast.Attribute) or isinstance(a, ast.Name): replace(a)
elif isinstance(a, ast.FunctionDef):
if a.name == 'forward': a.name = 'execute'
if hasattr(a, '__dict__'):
for k in a.__dict__.keys():
if isinstance(a.__dict__[k], list):
delete_flag = []
for i,a_ in enumerate(a.__dict__[k]):
ret = dfs(a_)
if ret == 'delete':
delete_flag.append(True)
continue
if ret is not None:
a.__dict__[k][i] = ret
delete_flag.append(False)
tmp = [a_ for i,a_ in enumerate(a.__dict__[k]) if not delete_flag[i]]
a.__dict__[k] = tmp
else:
ret = dfs(a.__dict__[k])
if ret is not None:
a.__dict__[k] = ret
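A quick way to sanity-check the ".size" -> ".shape" rewrite in dfs above is to run the same string surgery on a toy expression. This is an illustrative sketch only (it assumes astunparse is installed; the variable name x is made up):

import ast
import astunparse

src = "x.size(0)"
call = ast.parse(src).body[0].value  # the ast.Call node for x.size(0)
ags = [astunparse.unparse(ag).strip('\n') for ag in call.args]
if len(ags) != 0:
    # x.size(0) -> x.shape[0]
    con = astunparse.unparse(call.func).split('.size')[0] + '.shape[' + ','.join(ags) + ']'
else:
    # x.size() -> x.shape
    con = astunparse.unparse(call.func).replace('size', 'shape')
print(con.strip())  # prints: x.shape[0]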

View File

@@ -1 +1 @@
84596508776983dce645fc4ef77c7f35700549d5
d2eb452b81e704188346a788d8d53889f7b12007

View File

@@ -0,0 +1,14 @@
cat > /tmp/converter_server.dockerfile <<\EOF
FROM jittor/jittor
RUN python3.7 -m pip install flask
RUN apt update && apt install git -y
EOF
docker build --tag jittor/converter_server -f /tmp/converter_server.dockerfile .
# docker run --rm -it -m 16g --cpus=8 -p 0.0.0.0:5000:5000 jittor/converter_server bash -c "python3.7 -m pip install -U git+https://github.com/Jittor/jittor.git && python3.7 -m jittor.utils.converter_server"
while true; do
timeout --foreground 24h docker run --rm -it -m 16g --cpus=8 -p 0.0.0.0:5000:5000 jittor/converter_server bash -c "python3.7 -m pip install -U git+https://github.com/Jittor/jittor.git && python3.7 -m jittor.utils.converter_server"
sleep 10
done

View File

@@ -21,6 +21,7 @@
#include "fuser.h"
#include "profiler/profiler_guard.h"
#include "parallel_compiler.h"
#include "misc/nan_checker.h"
namespace jittor {
@@ -46,7 +47,10 @@ void load_fused_op(FusedOp& fused_op, vector<int>& fuse_ops, vector<Op*>& ops, i
for (Op* op : fused_op.ops) {
uint fid1 = op->custom_data;
int iid = 0;
for (Var* v : op->inputs()) {
for (auto ve : op->_inputs) {
// this is a control dependency edge, do not use it
if (ve.back->index<0) continue;
auto v = ve.node->var();
iid++;
int iop_id;
int iv_id;
@@ -450,6 +454,8 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
if (use_cuda)
checkCudaErrors(cudaDeviceSynchronize());
#endif
for (Var* var : op->outputs())
check_nan(var);
}
LOGvvv << "Finished Op(" >> op->name() << rid >>
"/" >> queue.size() >> ") output:" << op->outputs();

View File

@@ -7,6 +7,10 @@
#pragma once
#include "common.h"
#include "mem/allocator.h"
#ifdef HAS_CUDA
#include <cuda_runtime.h>
#include "helper_cuda.h"
#endif
namespace jittor {

View File

@@ -22,6 +22,7 @@ static auto make_number = get_op_info("number")
VarPtr make_grad(Op* op, Var* out, Var* dout, Var* x, int x_index) {
if (dout == nullptr) return nullptr;
if (x_index<0) return nullptr;
LOGvvvv << "Make grad op:" >> op->name() << "inputs:" >> op->inputs()
<< "out:" >> out << "dout:" >> dout << "x:" >> x << "xid:" >> x_index;
auto dx = op->grad(out, dout, x, x_index);

73
src/misc/nan_checker.cc Normal file
View File

@@ -0,0 +1,73 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include <cfloat>
#include <cmath>
#include "misc/nan_checker.h"
#ifdef HAS_CUDA
#include "misc/cuda_flags.h"
#include <cuda_runtime.h>
#include "helper_cuda.h"
#endif
#include "mem/allocator.h"
#include "op.h"
namespace jittor {
#ifdef HAS_CUDA
extern void check_nan_float32(float32* ptr, int64 num);
extern void check_nan_float64(float64* ptr, int64 num);
#endif
bool check_nan(Var* v) {
if (!v->dtype().is_float()) return true;
if (v->input() && (
v->input()->name() == string("empty") ||
v->input()->name() == string("setitem")))
return true;
#ifdef HAS_CUDA
if (v->allocator->is_cuda()) {
if (v->dtype() == ns_float32) {
check_nan_float32((float32*)v->mem_ptr, v->num);
} else
if (v->dtype() == ns_float64) {
check_nan_float64((float64*)v->mem_ptr, v->num);
}
ASSERT(cudaDeviceSynchronize()==0) << "detect nan or inf at" << v;
} else
#endif
{
if (v->dtype() == ns_float32) {
auto* __restrict__ ptr = v->ptr<float32>();
auto num = v->num;
bool ok = true;
int64 i=0;
for (; i<num; i++) {
if (std::isinf(ptr[i]) || std::isnan(ptr[i])) {
ok = false;
break;
}
}
ASSERT(ok) << "detect nan or inf at index" << i << v;
}
if (v->dtype() == ns_float64) {
auto* __restrict__ ptr = v->ptr<float64>();
auto num = v->num;
bool ok = true;
int64 i=0;
for (; i<num; i++) {
if (std::isinf(ptr[i]) || std::isnan(ptr[i])) {
ok = false;
break;
}
}
ASSERT(ok) << "detect nan or inf at index" << i << v;
}
}
return true;
}
}

47
src/misc/nan_checker.cu Normal file
View File

@@ -0,0 +1,47 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "misc/nan_checker.h"
#include "misc/cuda_flags.h"
#include <cuda_runtime.h>
#include "helper_cuda.h"
#include <cassert>
namespace jittor {
#ifdef HAS_CUDA
__global__ void _check_nan_float32(float32* __restrict__ ptr, int64 num) {
int64 i = threadIdx.x + blockIdx.x * (int64)blockDim.x;
if (i<num) {
if (::isnan(ptr[i]) || ::isinf(ptr[i]))
__trap();
}
}
__global__ void _check_nan_float64(float64* __restrict__ ptr, int64 num) {
int64 i = threadIdx.x + blockIdx.x * (int64)blockDim.x;
if (i<num) {
if (::isnan(ptr[i]) || ::isinf(ptr[i]))
__trap();
}
}
void check_nan_float64(float64* ptr, int64 num) {
int block_num = std::max((int64)1, (num-1)/1024+1);
int thread_num = std::min((int64)1024, num);
_check_nan_float64<<<block_num, thread_num>>>(ptr, num);
}
void check_nan_float32(float32* ptr, int64 num) {
int block_num = std::max((int64)1, (num-1)/1024+1);
int thread_num = std::min((int64)1024, num);
_check_nan_float32<<<block_num, thread_num>>>(ptr, num);
}
#endif
}
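For reference, the launch configuration used by the two wrappers above is a plain ceil-division of the element count by 1024, with the per-block thread count capped at 1024; threads that fall past the end of the buffer are filtered out by the "if (i<num)" guard inside the kernels. A small Python sketch of the same arithmetic (illustrative only, not part of the source):

def launch_config(num, max_threads=1024):
    # mirrors: block_num = max(1, (num-1)/1024+1); thread_num = min(1024, num)
    block_num = max(1, (num - 1) // max_threads + 1)
    thread_num = min(max_threads, num)
    return block_num, thread_num

print(launch_config(5000))  # (5, 1024): 5*1024 = 5120 threads cover 5000 elements
print(launch_config(100))   # (1, 100)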

13
src/misc/nan_checker.h Normal file
View File

@@ -0,0 +1,13 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "var.h"
namespace jittor {
bool check_nan(Var* v);
}

View File

@@ -312,6 +312,10 @@ void SetitemOp::jit_run() {
checkCudaErrors(cudaMemcpyAsync(op, ip, out->size, cudaMemcpyDefault, 0));
#endif
if (data->allocation == in->allocation &&
data->allocator == in->allocator)
return;
@for(d, 0, ODIM, for (index_t i@d=0; i@d < oshape@d; i@d++)) {
index_t did = 0 @for(d, 0, ODIM, @if((BMASK>>d)&1,+ i@d * dstride@d));
@for(d, 0, IDIM, index_t iid@d =

View File

@@ -14,11 +14,6 @@ namespace jittor {
static auto make_transpose = get_op_info("transpose")
.get_constructor<VarPtr, Var*, NanoVector>();
#ifdef HAS_CUDA
static auto make_reshape = get_op_info("reshape")
.get_constructor<VarPtr, Var*, NanoVector>();
#endif
TransposeOp::TransposeOp(Var* x, NanoVector axes_) : x(x), axes(axes_) {
int i=0;
for (; i<axes.size(); i++)

View File

@@ -16,8 +16,19 @@ inline static bool fast_strcmp(const char* a, const char* b) {
return !*b;
}
// add dependency b -> a
static inline void add_dependency(Node* a, const vector<Node*>& b) {
a->add_inputs(b);
auto edge = a->_inputs.end();
for (int i=0; i<b.size(); i++) {
edge = std::prev(edge);
// an index of -1 means this is a control dependency edge
edge->back->index = -1;
}
}
static void setitem_inplace(SetitemOp* op) {
// LOGir << "setitem_inplace";
// LOGir << "in setitem_inplace";
auto input = op->inputs().front();
if (!(input->outputs().size() == 1 &&
input->forward_liveness<=1 &&
@@ -29,8 +40,7 @@ static void setitem_inplace(SetitemOp* op) {
// make sure input op will not use input
auto input_name = input_op->name();
if (!(input_op->type() == OpType::broadcast ||
fast_strcmp(input_name, "array") ||
fast_strcmp(input_name, "empty") ||
input_op->inputs().size() == 0 ||
fast_strcmp(input_name, "setitem") ||
fast_strcmp(input_name, "getitem")))
// TODO: inplace getitem maybe risky, getitem maybe inplace too
@@ -38,7 +48,50 @@ static void setitem_inplace(SetitemOp* op) {
}
auto output = op->outputs().front();
output->share_with(input);
// LOGir << "apply setitem_inplace on" << op << "input:" << input << "output:" << output;
// return;
// LOGir << "pass setitem optim one";
auto data = op->input(1);
input_op = input->input();
if (input_op && input_op->inputs().size() == 1) {
input_op = input_op->inputs().front()->input();
}
if (input_op && input_op->inputs().size() == 1) {
input_op = input_op->inputs().front()->input();
}
VarSlices vs = op->vs;
if (!(data->is_finished() == 0 &&
(data->outputs().size() == 1 ||
(!input_op
|| input_op->inputs().size() == 0))))
return;
if (data->allocator)
return;
auto in_shape = input->shape;
for (int i = vs.n - 1; i > 0; --i) {
VarSlice s = vs.slices[i];
if (!(s.is_slice())) return;
Slice ss = s.slice;
if (!(ss.start == 0 && ss.stop >= in_shape[i] && ss.step == 1))
return;
}
VarSlice s = vs.slices[0];
if (s.is_var()) return;
int64 size = 0;
if (s.is_int())
size = s.i * input->size / in_shape[0];
else if (s.is_slice())
size = s.slice.start * input->size / in_shape[0];
add_dependency(data->input(), {input->node()});
data->share_with(input, size);
// LOGir << "pass setitem optim two";
}
struct BBox {
@@ -103,9 +156,43 @@ static void setitem_grad_opt(GetitemOp* op) {
}
static void getitem_inplace(GetitemOp* op) {
// LOGir << "in getitem_inplace";
auto in = op->inputs().front();
auto ou = op->outputs().front();
// return if input or output's shape is variable
if (in->num < 0 || ou->num < 0)
return;
VarSlices vs = op->vs;
auto in_shape = in->shape;
for (int i = vs.n - 1; i > 0; --i) {
VarSlice s = vs.slices[i];
if (!(s.is_slice())) return;
Slice ss = s.slice;
if (!(ss.start == 0 && ss.stop >= in_shape[i] && ss.step == 1))
return;
}
VarSlice s = vs.slices[0];
if (s.is_var()) return;
int64 size = 0;
if (s.is_int())
size = s.i * in->size / in_shape[0];
else if (s.is_slice())
size = s.slice.start * in->size / in_shape[0];
ou->share_with(in, size);
// LOGir << "pass getitem_inplace";
}
void SetitemOp::graph_optimize() {
// LOGir << "hello graph_optimize";
setitem_inplace(this);
(void)setitem_inplace;
}
void GetitemOp::graph_optimize() {
@@ -113,6 +200,9 @@ void GetitemOp::graph_optimize() {
// LOGir << "hello getitem graph_optimize";
// setitem_grad_opt(this);
(void)setitem_grad_opt;
// (void)getitem_inplace;
getitem_inplace(this);
(void)getitem_inplace;
}
}
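The offset passed to share_with by getitem_inplace/setitem_inplace above is just the byte offset of the selected leading-axis block, which is only valid because every trailing slice is required to cover its full axis. A numpy analogy of the same arithmetic (illustrative sketch; the array a and index i are made up):

import numpy as np

a = np.zeros((4, 3, 2), dtype=np.float32)  # plays the role of the getitem input
i = 2
offset = i * a.nbytes // a.shape[0]        # same formula as: s.i * in->size / in_shape[0]
view = a[i]                                # a view sharing a's buffer at that offset
base = a.__array_interface__['data'][0]
assert view.__array_interface__['data'][0] == base + offset  # offset == 48 bytes here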

View File

@@ -91,7 +91,8 @@ void PassManager::run_passes() {
run_pass<SolveConflictDefinePass>();
run_pass<MergeLoopVarPass>();
run_pass<ConstVarPass>();
// tmp disable ConstVarPass
// run_pass<ConstVarPass>();
run_pass<RestridePass>();

View File

@@ -734,9 +734,10 @@ void load_var_slice(PyObject* obj, T* var_slice, vector<unique_ptr<VarHolder>>&
} else
if (obj == Py_None) {
var_slice->set_none();
} else
} else
if (PyObject_TypeCheck(obj, PyNumberArrType_Type)) {
PyArrayDescr_Proxy array_descr = {.type_num = 5}; // 5: int32
PyArrayDescr_Proxy array_descr;
array_descr.type_num = 5; // 5: int32
int value;
PyArray_CastScalarToCtype(obj, &value, &array_descr);
var_slice->set_int(value);

View File

@@ -142,7 +142,7 @@ static PyObject* to_py_object3(ArrayArgs&& a) {
return to_py_object(jit_op_maker::array_(move(a)));
}
static PyObject* pop_py_object(RingBuffer* rb, uint64& __restrict__ offset) {
static PyObject* pop_py_object(RingBuffer* rb, uint64& __restrict__ offset, bool keep_numpy_array) {
auto t = rb->pop_t<uint8>(offset);
if (t==0) {
auto x = rb->pop_t<int64>(offset);
@@ -161,7 +161,7 @@ static PyObject* pop_py_object(RingBuffer* rb, uint64& __restrict__ offset) {
auto size = rb->pop_t<int64>(offset);
PyObjHolder list(PyList_New(size));
for (uint i=0; i<size; i++) {
PyObject* o = pop_py_object(rb, offset);
PyObject* o = pop_py_object(rb, offset, keep_numpy_array);
PyList_SET_ITEM(list.obj, i, o);
}
return list.release();
@@ -170,8 +170,8 @@ static PyObject* pop_py_object(RingBuffer* rb, uint64& __restrict__ offset) {
auto size = rb->pop_t<int64>(offset);
PyObjHolder dict(PyDict_New());
for (int64 i=0; i<size; i++) {
PyObject* key = pop_py_object(rb, offset);
PyObject* value = pop_py_object(rb, offset);
PyObject* key = pop_py_object(rb, offset, keep_numpy_array);
PyObject* value = pop_py_object(rb, offset, keep_numpy_array);
PyDict_SetItem(dict.obj, key, value);
}
return dict.release();
@@ -185,7 +185,10 @@ static PyObject* pop_py_object(RingBuffer* rb, uint64& __restrict__ offset) {
size *= args.shape[i];
rb->pop(size, offset);
args.ptr = rb->get_ptr(size, offset);
return to_py_object3(move(args));
if (!keep_numpy_array)
return to_py_object3(move(args));
else
return to_py_object<ArrayArgs>(args);
}
if (t==6) {
return pop_py_object_pickle(rb, offset);
@@ -212,7 +215,7 @@ void PyMultiprocessRingBuffer::push(PyObject* obj) {
PyObject* PyMultiprocessRingBuffer::pop() {
auto offset = rb->l;
auto obj = pop_py_object(rb, offset);
auto obj = pop_py_object(rb, offset, _keep_numpy_array);
rb->commit_pop(offset);
return obj;
}

View File

@@ -13,6 +13,7 @@ namespace jittor {
// @pyjt(RingBuffer)
struct PyMultiprocessRingBuffer {
RingBuffer* rb;
bool _keep_numpy_array = false;
// @pyjt(__init__)
PyMultiprocessRingBuffer(uint64 size);
// @pyjt(__dealloc__)
@@ -23,6 +24,8 @@ struct PyMultiprocessRingBuffer {
PyObject* pop();
// @pyjt(clear)
inline void clear() { rb->clear(); }
// @pyjt(keep_numpy_array)
inline void keep_numpy_array(bool keep) { _keep_numpy_array = keep; }
// @pyjt(stop)
inline void stop() { rb->stop(); }
// @pyjt(is_stop)

View File

@@ -64,7 +64,7 @@ bool Var::alloc(Allocator* allocator) {
if (mem_ptr) return true;
if (auto* x = (Var*)(this->allocator)) {
if (x->allocator->share_with(size, x->allocation)) {
mem_ptr = x->mem_ptr;
mem_ptr = ((char*) x->mem_ptr) + allocation;
allocation = x->allocation;
this->allocator = x->allocator;
return true;

View File

@@ -42,7 +42,7 @@ struct Var : Node {
int64_t numel();
void set_shape(NanoVector shape);
bool alloc(Allocator* allocator);
inline void share_with(Var* x) { CHECK_EXIST; allocator = (Allocator*)x; }
inline void share_with(Var* x, size_t offset = 0) { CHECK_EXIST; allocator = (Allocator*)x; allocation = offset; }
};
struct VarPtr {