mirror of https://github.com/Jittor/Jittor

Merge branch 'master' of https://github.com/Jittor/jittor

commit 25510ddf32
@@ -15,7 +15,6 @@
#include <cuda_runtime.h>
#include "helper_cuda.h"

#ifdef _CUFFT_H_
// cuFFT API errors
const char *_cudaGetErrorEnum(cufftResult error) {
@@ -7,7 +7,7 @@
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.2.2.1'
__version__ = '1.2.2.8'
from . import lock
with lock.lock_scope():
    ori_int = int

@@ -33,9 +33,38 @@ from collections import OrderedDict
from collections.abc import Sequence, Mapping
import types
import pickle
import sys
import hashlib
import sys, os
import traceback

def safepickle(obj, path):
    s = pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
    checksum = hashlib.sha1(s).digest()
    s += bytes(checksum)
    s += b"HCAJSLHD"
    with open(path, 'wb') as f:
        f.write(s)

def safeunpickle(path):
    if path.startswith("jittorhub://"):
        path = path.replace("jittorhub://", "https://cg.cs.tsinghua.edu.cn/jittor/assets/build/checkpoints/")
    if path.startswith("https:") or path.startswith("http:"):
        base = path.split("/")[-1]
        fname = os.path.join(compiler.ck_path, base)
        from jittor.utils.misc import download_url_to_local
        download_url_to_local(path, base, compiler.ck_path, None)
        path = fname
    with open(path, "rb") as f:
        s = f.read()
    if not s.endswith(b"HCAJSLHD"):
        return pickle.loads(s)
    checksum = s[-28:-8]
    s = s[:-28]
    if hashlib.sha1(s).digest() != checksum:
        raise ValueError("Pickle checksum does not match! path: "+path)
    return pickle.loads(s)

class _call_no_record_scope:
    def __enter__(self): pass
    def __exit__(self, *exc): pass
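The new safepickle/safeunpickle pair guards checkpoints with a checksum: the file is the pickle payload, followed by the 20-byte SHA-1 digest of that payload, followed by the 8-byte magic trailer "HCAJSLHD". A minimal standalone sketch of that layout and its verification (no Jittor required):

    import hashlib, pickle

    payload = pickle.dumps({"w": [1.0, 2.0]}, pickle.HIGHEST_PROTOCOL)
    blob = payload + hashlib.sha1(payload).digest() + b"HCAJSLHD"

    # Verification mirrors safeunpickle: strip the 8-byte magic and 20-byte digest.
    assert blob.endswith(b"HCAJSLHD")
    body, digest = blob[:-28], blob[-28:-8]
    assert hashlib.sha1(body).digest() == digest
    print(pickle.loads(body))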
@@ -436,8 +465,15 @@ def display_memory_info():
    core.display_memory_info(fileline)

def load(path):
    pkl_file = open(path, 'rb')
    model_dict = pickle.load(pkl_file)
    if path.endswith(".pth"):
        try:
            dirty_fix_pytorch_runtime_error()
            import torch
        except:
            raise RuntimeError("pytorch need to be installed when load pth format.")
        model_dict = torch.load(path, map_location=torch.device('cpu'))
    else:
        model_dict = safeunpickle(path)
    return model_dict

def _uniq(x):
@@ -647,20 +683,10 @@ class Module:
        params_dict = {}
        for p in params:
            params_dict[p.name()] = p.data
        with open(path, 'wb') as f:
            pickle.dump(params_dict, f, pickle.HIGHEST_PROTOCOL)
        safepickle(params_dict, path)

    def load(self, path):
        if path.endswith(".pth"):
            try:
                dirty_fix_pytorch_runtime_error()
                import torch
            except:
                raise RuntimeError("pytorch need to be installed when load pth format.")
            self.load_parameters(torch.load(path, map_location=torch.device('cpu')))
            return
        with open(path, 'rb') as f:
            self.load_parameters(pickle.load(f))
        self.load_parameters(load(path))

    def eval(self):
        def callback(parents, k, v, n):
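With this change Module.save writes the checksummed safepickle format, and Module.load funnels through the new global load(), so .pkl, .pth and jittorhub:// paths are handled in one place. A hedged usage sketch:

    import jittor as jt
    from jittor import nn

    model = nn.Linear(4, 2)
    model.save("/tmp/linear.pkl")   # safepickle format with SHA-1 trailer
    model.load("/tmp/linear.pkl")   # round-trips through load()/safeunpickle()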
@@ -805,6 +831,11 @@ can also be None)::
    def dfs(self, parents, k, callback, callback_leave=None):
        pass

    @classmethod
    def apply(cls, *args, **kw):
        func = cls()
        return func(*args, **kw)


def make_module(func, exec_n_args=1):
    class MakeModule(Module):
@@ -904,7 +935,10 @@ def format(v, spec):
    return v.item().__format__(spec)
Var.__format__ = format

def get_len(var):
    return var.shape[0]

Var.__len__ = get_len
int = int32
Var.int = Var.int32
float = float32

@@ -921,3 +955,4 @@ from . import contrib
from . import numpy2cupy
from .contrib import concat
from .misc import *
from . import sparse
@@ -74,10 +74,17 @@ def compile(compiler, flags, inputs, output, combind_build=False):
    for input, obj_file in zip(inputs, obj_files):
        cc = compiler
        nflags = oflags
        if has_cuda and input.endswith(".cu"):
            nflags = convert_nvcc_flags(oflags)
            cc = nvcc_path
        if input.endswith(".cu"):
            if has_cuda:
                nflags = convert_nvcc_flags(oflags)
                cc = nvcc_path
            else:
                continue
        cmd = f"{cc} {input} {nflags} -c {lto_flags} -o {obj_file}"
        if "nan_checker" in input:
            # nan checker needs to disable fast_math
            cmd = cmd.replace("--use_fast_math", "")
            cmd = cmd.replace("-Ofast", "-O2")
        cmds.append(cmd)
    jit_utils.run_cmds(cmds, cache_path, jittor_path, "Compiling "+base_output)
    cmd = f"{compiler} {' '.join(obj_files)} {flags} {lto_flags} {link} -o {output}"

@@ -894,6 +901,8 @@ make_cache_dir(cache_path)
make_cache_dir(os.path.join(cache_path, "jit"))
make_cache_dir(os.path.join(cache_path, "obj_files"))
make_cache_dir(os.path.join(cache_path, "gen"))
ck_path = os.path.join(cache_path, "checkpoints")
make_cache_dir(ck_path)

# build cache_compile
cc_flags += f" -I{jittor_path}/src "

@@ -943,7 +952,8 @@ pyjt_gen_src = pyjt_compiler.compile(cache_path, jittor_path)
# 3. op_utils
# 4. other
files2 = pyjt_gen_src
files4 = run_cmd('find -L src | grep "cc$"', jittor_path).splitlines()
grep_args = '"c[cu]$"' if has_cuda else '"cc$"'
files4 = run_cmd('find -L src | grep '+grep_args, jittor_path).splitlines()
at_beginning = [
    "src/ops/op_utils.cc",
    "src/event_queue.cc",
@@ -12,6 +12,7 @@ import numpy as np
from jittor import pool
from collections.abc import Sequence


def argmax_pool(x, size, stride, padding=0):
    return pool.pool(x, size, 'maximum', padding, stride)

@@ -196,8 +197,14 @@ def setitem(x, slices, value):
        mask = jt.broadcast(slices, x)
        value = jt.broadcast(value, x)
        return x.assign(mask.ternary(value, x))
    if isinstance(slices, list):
        slices = tuple(slices)
    if isinstance(slices, Sequence):
        ss = []
        for s in slices:
            if isinstance(s, jt.Var) and s.dtype == "bool":
                ss.extend(s.where())
            else:
                ss.append(s)
        slices = tuple(ss)
    return x.assign(x.setitem(slices, value))

jt.Var.__getitem__ = jt.Var.slice_var = getitem
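The new setitem branch converts any boolean jt.Var found inside an indexing list or tuple into integer indices via where() before delegating to x.setitem. A hedged sketch of the behaviour this enables (the exact broadcasting rules are those of setitem itself):

    import numpy as np
    import jittor as jt

    x = jt.zeros((4, 3))
    row_mask = jt.array(np.array([True, False, True, False]))
    # Equivalent to x[(row_mask.where()[0],)] = 1.0 after the conversion above.
    x[(row_mask,)] = 1.0
    print(x.numpy())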
@@ -27,8 +27,9 @@ mpi = jt.mpi
img_open_hook = HookTimer(Image, "open")

class Worker:
    def __init__(self, target, args, buffer_size):
    def __init__(self, target, args, buffer_size, keep_numpy_array=False):
        self.buffer = jt.RingBuffer(buffer_size)
        self.buffer.keep_numpy_array(keep_numpy_array)

        self.status = mp.Array('f', 5, lock=False)
        self.p = mp.Process(target=target, args=args+(self.buffer,self.status))

@@ -67,7 +68,8 @@ class Dataset(object):
                 drop_last = False,
                 num_workers = 0,
                 buffer_size = 512*1024*1024,
                 stop_grad = True):
                 stop_grad = True,
                 keep_numpy_array = False):
        super().__init__()
        self.total_len = None
        self.batch_size = batch_size

@@ -76,6 +78,7 @@ class Dataset(object):
        self.num_workers = num_workers
        self.buffer_size = buffer_size
        self.stop_grad = stop_grad
        self.keep_numpy_array = keep_numpy_array

    def __getitem__(self, index):
        raise NotImplementedError

@@ -117,6 +120,8 @@ class Dataset(object):
        '''
        Change batch data to jittor array, such as np.ndarray, int, and float.
        '''
        if self.keep_numpy_array: return batch
        if isinstance(batch, jt.Var): return batch
        to_jt = lambda x: jt.array(x).stop_grad() \
            if self.stop_grad else jt.array(x)
        if isinstance(batch, np.ndarray):

@@ -299,7 +304,8 @@ Example::
            self.num_idle_c = mp.Condition(self.gid.get_lock())
            for i in range(self.num_workers):
                w = Worker(target=self._worker_main, args=(i,),
                           buffer_size=self.buffer_size)
                           buffer_size=self.buffer_size,
                           keep_numpy_array=self.keep_numpy_array)
                workers.append(w)
            self.workers = workers
            self.index_list_numpy = np.ndarray(dtype='int32', shape=self.real_len, buffer=self.index_list)
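A hedged sketch of the new keep_numpy_array flag: when it is set, to_jittor returns batches untouched (and each Worker's ring buffer is told to keep numpy arrays), so the loader yields np.ndarray batches instead of jt.Var. The import path and the set_attrs call follow the usual Jittor dataset idiom and are assumptions here:

    import numpy as np
    from jittor.dataset import Dataset   # import path assumed

    class ToyData(Dataset):
        def __init__(self):
            super().__init__(batch_size=2, keep_numpy_array=True)
            self.set_attrs(total_len=4)

        def __getitem__(self, i):
            return np.float32([i, i + 1])

    for batch in ToyData():
        print(type(batch))   # expected: <class 'numpy.ndarray'> rather than jt.Var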
@@ -96,6 +96,27 @@ def repeat(x, *shape):

jt.Var.repeat = repeat

def repeat_interleave(x,repeats,dim=None):
    # TODO repeats is jt.Var
    assert isinstance(repeats,int)
    if dim == None:
        x = x.reshape(-1)
        dim=0
    if dim<0: dim+=x.ndim

    tar_shape = list(x.shape)
    x_shape = list(x.shape)
    tar_shape[dim] = tar_shape[dim]*repeats
    dims = []
    for i in range(len(tar_shape)):
        if dim==i:
            dims.append(f"i{i}/{repeats}")
        else:
            dims.append(f"i{i}")
    return x.reindex(tar_shape,dims)

jt.Var.repeat_interleave = repeat_interleave

def chunk(x, chunks, dim=0):
    r'''
    Splits a var into a specific number of chunks. Each chunk is a view of the input var.
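Usage sketch of the reindex-based repeat_interleave defined above; with dim=None the input is flattened first, mirroring the PyTorch operator of the same name:

    import jittor as jt

    x = jt.array([[1, 2], [3, 4]])
    print(x.repeat_interleave(2, dim=0).numpy())
    # [[1 2]
    #  [1 2]
    #  [3 4]
    #  [3 4]]
    print(x.repeat_interleave(3).numpy())   # flattened: [1 1 1 2 2 2 3 3 3 4 4 4]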
@@ -209,15 +230,18 @@ def flip(x, dim=0):
    >>> x.flip(1)
    [[4 3 2 1]]
    '''
    assert isinstance(dim, int)
    if dim<0:
        dim+=x.ndim
    assert dim>=0 and dim<len(x.shape)
    if isinstance(dim, int):
        dim = [dim]
    for i in range(len(dim)):
        if dim[i]<0:
            dim[i] += x.ndim
        assert dim[i]>=0 and dim[i]<x.ndim
    dim = set(dim)

    tar_dims = []
    for i in range(len(x.shape)):
        if i == dim:
            tar_dims.append(f"{x.shape[dim]-1}-i{i}")
        if i in dim:
            tar_dims.append(f"xshape{i}-1-i{i}")
        else:
            tar_dims.append(f"i{i}")
    return x.reindex(x.shape, tar_dims)
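flip now accepts either a single axis or a list of axes; a short sketch of both call forms:

    import jittor as jt

    x = jt.array([[1, 2], [3, 4]])
    print(x.flip(0).numpy())        # [[3 4] [1 2]]
    print(x.flip([0, 1]).numpy())   # [[4 3] [2 1]]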
@@ -335,16 +359,37 @@ def unbind(x, dim=0):
jt.Var.unbind = unbind

def make_grid(x, nrow=8, padding=2, normalize=False, range=None, scale_each=False, pad_value=0):
    assert range == None
    assert isinstance(range, tuple) or range is None
    assert scale_each == False
    if isinstance(x, list): x = jt.stack(x)
    if normalize: x = (x - x.min()) / (x.max() - x.min())
    if normalize:
        if range is None: x = (x - x.min()) / (x.max() - x.min())
        else: x = (x - range[0]) / (range[1] - range[0])
    b,c,h,w = x.shape
    ncol = math.ceil(b / nrow)
    return x.reindex([c, h*ncol+(ncol+1)*padding, w*nrow+(nrow+1)*padding],
                     [f"i1/{padding+h}*{nrow}+i2/{padding+w}", "i0",
                      f"i1-i1/{padding+h}*{padding+h}-{padding}", f"i2-i2/{padding+w}*{padding+w}-{padding}"], overflow_value=pad_value)

def save_image(
    x,
    filepath,
    nrow: int = 8,
    padding: int = 2,
    normalize: bool = False,
    range = None,
    scale_each = False,
    pad_value = 0,
    format = None
):
    from PIL import Image
    grid = make_grid(x, nrow=nrow, padding=padding, pad_value=pad_value,
                     normalize=normalize, range=range, scale_each=scale_each)

    ndarr = (grid*255+0.5).clamp(0, 255).permute(1, 2, 0).uint8().numpy()
    im = Image.fromarray(ndarr)
    im.save(filepath, format=format)


def _ntuple(n):
    def parse(x):
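Usage sketch of the extended make_grid (which now honours a (low, high) normalisation range) and the new save_image helper, mirroring the tests added later in this commit:

    import numpy as np
    import jittor as jt

    imgs = jt.array(np.random.rand(16, 3, 10, 10).astype("float32"))
    grid = jt.make_grid(imgs, nrow=4, padding=2, normalize=True)   # (3, H, W) grid image
    jt.save_image(imgs, "/tmp/grid.jpg", nrow=4, normalize=True)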
@@ -582,12 +627,11 @@ def gather(x,dim,index):
    return x.reindex(ins)
jt.Var.gather = gather

def prod(x,dim=0):
def _prod(x,dim=0):
    x = jt.log(x)
    x = x.sum(dim=dim)
    return jt.exp(x)

jt.Var.prod = prod

def cumsum_forward(np, data):
    a = data['inputs'][0]
@@ -61,6 +61,7 @@ class AlexNet(nn.Module):
        x = self.classifier(x)
        return x

def alexnet(**kwargs):
def alexnet(pretrained=False, **kwargs):
    model = AlexNet(**kwargs)
    if pretrained: model.load("jittorhub://alexnet.pkl")
    return model
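Every torchvision-style constructor in the model zoo gains the same pretrained flag in this commit; the weight file is resolved through the jittorhub:// scheme handled by safeunpickle. A hedged usage sketch (the jittor.models import path is assumed):

    import jittor.models as models   # import path assumed

    net = models.alexnet(pretrained=True)   # fetches jittorhub://alexnet.pkl on first use
    net.eval()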
@@ -21,7 +21,7 @@ def densenet121(pretrained=False, **kwargs):
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    '''
    model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), **kwargs)
    assert not pretrained, "pretrained doesn't support now"
    if pretrained: model.load("jittorhub://densenet121.pkl")
    return model

def densenet161(pretrained=False, **kwargs):

@@ -32,7 +32,7 @@ def densenet161(pretrained=False, **kwargs):
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    '''
    model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24), **kwargs)
    assert not pretrained, "pretrained doesn't support now"
    if pretrained: model.load("jittorhub://densenet161.pkl")
    return model

def densenet169(pretrained=False, **kwargs):

@@ -43,7 +43,7 @@ def densenet169(pretrained=False, **kwargs):
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    '''
    model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32), **kwargs)
    assert not pretrained, "pretrained doesn't support now"
    if pretrained: model.load("jittorhub://densenet169.pkl")
    return model

def densenet201(pretrained=False, **kwargs):

@@ -54,7 +54,7 @@ def densenet201(pretrained=False, **kwargs):
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    '''
    model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32), **kwargs)
    assert not pretrained, "pretrained doesn't support now"
    if pretrained: model.load("jittorhub://densenet201.pkl")
    return model
@@ -12,8 +12,10 @@ from jittor import nn

__all__ = ['GoogLeNet', 'googlenet']

def googlenet(**kwargs):
    return GoogLeNet(**kwargs)
def googlenet(pretrained=False, **kwargs):
    model = GoogLeNet(**kwargs)
    if pretrained: model.load("jittorhub://googlenet.pkl")
    return model

class GoogLeNet(nn.Module):
    """ GoogLeNet model architecture.
@@ -4,7 +4,9 @@ from jittor import nn
__all__ = ['Inception3', 'inception_v3']

def inception_v3(pretrained=False, progress=True, **kwargs):
    return Inception3(**kwargs)
    model = Inception3(**kwargs)
    if pretrained: model.load("jittorhub://inception_v3.pkl")
    return model

class Inception3(nn.Module):
    """ Inceptionv3 model architecture.
@@ -90,18 +90,22 @@ class MNASNet(nn.Module):
        x = x.mean([2, 3])
        return self.classifier(x)

def mnasnet0_5(**kwargs):
def mnasnet0_5(pretrained=False, **kwargs):
    model = MNASNet(0.5, **kwargs)
    if pretrained: model.load("jittorhub://mnasnet0_5.pkl")
    return model

def mnasnet0_75(**kwargs):
def mnasnet0_75(pretrained=False, **kwargs):
    model = MNASNet(0.75, **kwargs)
    if pretrained: model.load("jittorhub://mnasnet0_75.pkl")
    return model

def mnasnet1_0(**kwargs):
def mnasnet1_0(pretrained=False, **kwargs):
    model = MNASNet(1.0, **kwargs)
    if pretrained: model.load("jittorhub://mnasnet1_0.pkl")
    return model

def mnasnet1_3(**kwargs):
def mnasnet1_3(pretrained=False, **kwargs):
    model = MNASNet(1.3, **kwargs)
    if pretrained: model.load("jittorhub://mnasnet1_3.pkl")
    return model
@@ -93,7 +93,8 @@ class MobileNetV2(nn.Module):
    def execute(self, x):
        return self._forward_impl(x)

def mobilenet_v2():
def mobilenet_v2(pretrained=False):
    model = MobileNetV2()
    if pretrained: model.load("jittorhub://mobilenet_v2.pkl")
    return model
@@ -154,19 +154,26 @@ def _resnet(block, layers, **kwargs):
    model = ResNet(block, layers, **kwargs)
    return model

def Resnet18(**kwargs):
    return _resnet(BasicBlock, [2, 2, 2, 2], **kwargs)
def Resnet18(pretrained=False, **kwargs):
    model = _resnet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained: model.load("jittorhub://resnet18.pkl")
    return model
resnet18 = Resnet18

def Resnet34(**kwargs):
    return _resnet( BasicBlock, [3, 4, 6, 3], **kwargs)
def Resnet34(pretrained=False, **kwargs):
    model = _resnet(BasicBlock, [3, 4, 6, 3], **kwargs)
    if pretrained: model.load("jittorhub://resnet34.pkl")
    return model
resnet34 = Resnet34

def Resnet50(**kwargs):
    return _resnet(Bottleneck, [3, 4, 6, 3], **kwargs)
def Resnet50(pretrained=False, **kwargs):
    model = _resnet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained: model.load("jittorhub://resnet50.pkl")
    return model

resnet50 = Resnet50

def Resnet101(**kwargs):
def Resnet101(pretrained=False, **kwargs):
    """
    ResNet-101 model architecture.

@@ -180,28 +187,38 @@ def Resnet101(**kwargs):
    return _resnet(Bottleneck, [3, 4, 23, 3], **kwargs)
resnet101 = Resnet101

def Resnet152(**kwargs):
    return _resnet(Bottleneck, [3, 8, 36, 3], **kwargs)
def Resnet152(pretrained=False, **kwargs):
    model = _resnet(Bottleneck, [3, 8, 36, 3], **kwargs)
    if pretrained: model.load("jittorhub://resnet152.pkl")
    return model
resnet152 = Resnet152

def Resnext50_32x4d(**kwargs):
def Resnext50_32x4d(pretrained=False, **kwargs):
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 4
    return _resnet(Bottleneck, [3, 4, 6, 3], **kwargs)
    model = _resnet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained: model.load("jittorhub://resnext50_32x4d.pkl")
    return model
resnext50_32x4d = Resnext50_32x4d

def Resnext101_32x8d(**kwargs):
def Resnext101_32x8d(pretrained=False, **kwargs):
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 8
    return _resnet(Bottleneck, [3, 4, 23, 3], **kwargs)
    model = _resnet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if pretrained: model.load("jittorhub://resnext101_32x8d.pkl")
    return model
resnext101_32x8d = Resnext101_32x8d

def Wide_resnet50_2(**kwargs):
def Wide_resnet50_2(pretrained=False, **kwargs):
    kwargs['width_per_group'] = (64 * 2)
    return _resnet(Bottleneck, [3, 4, 6, 3], **kwargs)
    model = _resnet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained: model.load("jittorhub://wide_resnet50_2.pkl")
    return model
wide_resnet50_2 = Wide_resnet50_2

def Wide_resnet101_2(**kwargs):
def Wide_resnet101_2(pretrained=False, **kwargs):
    kwargs['width_per_group'] = (64 * 2)
    return _resnet(Bottleneck, [3, 4, 23, 3], **kwargs)
    model = _resnet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if pretrained: model.load("jittorhub://wide_resnet101_2.pkl")
    return model
wide_resnet101_2 = Wide_resnet101_2
@@ -93,14 +93,22 @@ def _shufflenetv2(arch, *args):
    model = ShuffleNetV2(*args)
    return model

def shufflenet_v2_x0_5():
    return _shufflenetv2('shufflenetv2_x0.5', [4, 8, 4], [24, 48, 96, 192, 1024])
def shufflenet_v2_x0_5(pretrained=False):
    model = _shufflenetv2('shufflenetv2_x0.5', [4, 8, 4], [24, 48, 96, 192, 1024])
    if pretrained: model.load("jittorhub://shufflenet_v2_x0_5.pkl")
    return model

def shufflenet_v2_x1_0():
    return _shufflenetv2('shufflenetv2_x1.0', [4, 8, 4], [24, 116, 232, 464, 1024])
def shufflenet_v2_x1_0(pretrained=False):
    model = _shufflenetv2('shufflenetv2_x1.0', [4, 8, 4], [24, 116, 232, 464, 1024])
    if pretrained: model.load("jittorhub://shufflenet_v2_x1_0.pkl")
    return model

def shufflenet_v2_x1_5():
    return _shufflenetv2('shufflenetv2_x1.5', [4, 8, 4], [24, 176, 352, 704, 1024])
def shufflenet_v2_x1_5(pretrained=False):
    model = _shufflenetv2('shufflenetv2_x1.5', [4, 8, 4], [24, 176, 352, 704, 1024])
    if pretrained: model.load("jittorhub://shufflenet_v2_x1_5.pkl")
    return model

def shufflenet_v2_x2_0():
    return _shufflenetv2('shufflenetv2_x2.0', [4, 8, 4], [24, 244, 488, 976, 2048])
def shufflenet_v2_x2_0(pretrained=False):
    model = _shufflenetv2('shufflenetv2_x2.0', [4, 8, 4], [24, 244, 488, 976, 2048])
    if pretrained: model.load("jittorhub://shufflenet_v2_x2_0.pkl")
    return model
@@ -83,8 +83,12 @@ def _squeezenet(version, **kwargs):
    model = SqueezeNet(version, **kwargs)
    return model

def squeezenet1_0(**kwargs):
    return _squeezenet('1_0', **kwargs)
def squeezenet1_0(pretrained=False, **kwargs):
    model = _squeezenet('1_0', **kwargs)
    if pretrained: model.load("jittorhub://squeezenet1_0.pkl")
    return model

def squeezenet1_1(**kwargs):
    return _squeezenet('1_1', **kwargs)
def squeezenet1_1(pretrained=False, **kwargs):
    model = _squeezenet('1_1', **kwargs)
    if pretrained: model.load("jittorhub://squeezenet1_1.pkl")
    return model
@@ -67,33 +67,49 @@ def _vgg(arch, cfg, batch_norm, **kwargs):
    return model


def vgg11(**kwargs):
    return _vgg('vgg11', 'A', False, **kwargs)
def vgg11(pretrained=False, **kwargs):
    model = _vgg('vgg11', 'A', False, **kwargs)
    if pretrained: model.load("jittorhub://vgg11.pkl")
    return model


def vgg11_bn(**kwargs):
    return _vgg('vgg11_bn', 'A', True, **kwargs)
def vgg11_bn(pretrained=False, **kwargs):
    model = _vgg('vgg11_bn', 'A', True, **kwargs)
    if pretrained: model.load("jittorhub://vgg11_bn.pkl")
    return model


def vgg13(**kwargs):
    return _vgg('vgg13', 'B', False, **kwargs)
def vgg13(pretrained=False, **kwargs):
    model = _vgg('vgg13', 'B', False, **kwargs)
    if pretrained: model.load("jittorhub://vgg13.pkl")
    return model


def vgg13_bn(**kwargs):
    return _vgg('vgg13_bn', 'B', True, **kwargs)
def vgg13_bn(pretrained=False, **kwargs):
    model = _vgg('vgg13_bn', 'B', True, **kwargs)
    if pretrained: model.load("jittorhub://vgg13_bn.pkl")
    return model


def vgg16(**kwargs):
    return _vgg('vgg16', 'D', False, **kwargs)
def vgg16(pretrained=False, **kwargs):
    model = _vgg('vgg16', 'D', False, **kwargs)
    if pretrained: model.load("jittorhub://vgg16.pkl")
    return model


def vgg16_bn(**kwargs):
    return _vgg('vgg16_bn', 'D', True, **kwargs)
def vgg16_bn(pretrained=False, **kwargs):
    model = _vgg('vgg16_bn', 'D', True, **kwargs)
    if pretrained: model.load("jittorhub://vgg16_bn.pkl")
    return model


def vgg19(**kwargs):
    return _vgg('vgg19', 'E', False, **kwargs)
def vgg19(pretrained=False, **kwargs):
    model = _vgg('vgg19', 'E', False, **kwargs)
    if pretrained: model.load("jittorhub://vgg19.pkl")
    return model


def vgg19_bn(**kwargs):
    return _vgg('vgg19_bn', 'E', True, **kwargs)
def vgg19_bn(pretrained=False, **kwargs):
    model = _vgg('vgg19_bn', 'E', True, **kwargs)
    if pretrained: model.load("jittorhub://vgg19_bn.pkl")
    return model
@@ -154,6 +154,7 @@ def get_init_var_rand(shape, dtype):
def relu(x): return jt.ternary((x>0.0), x, jt.broadcast_var(0.0, x))
def leaky_relu(x, scale=0.01): return jt.ternary(x>0, x, x*scale)
def relu6(x): return jt.minimum(jt.maximum(x, 0.0), 6.0)
def elu(x,alpha=1.0):return jt.ternary(x>0,x,alpha*(x.exp()-1))
def sign(x):
    one = jt.ones(x.shape)
    x = jt.ternary(x>0, one, x)
@@ -165,6 +166,13 @@ def gelu(x):
    r = erf*x*.5
    return r

class ELU(Module):
    def __init__(self,alpha=1.0):
        self.alpha=alpha

    def execute(self,x):
        return elu(x,self.alpha)

class PReLU(Module):
    def __init__(self, num_parameters=1, init_=0.25):
        self.num_parameters = num_parameters
@@ -238,6 +246,30 @@ def smooth_l1_loss(y_true, y_pred,reduction="mean"):
    else:
        raise ValueError(f'not support {reduction}')

def nll_loss(output,target,weight=None,ignore_index=-100,reduction='mean'):
    assert output.ndim<=2 and output.ndim>0 and target.ndim==1
    n_classes = output.shape[-1]
    assert weight is None or weight.numel()==n_classes
    assert ignore_index<0 or ignore_index<n_classes
    if weight is None:
        weight = jt.ones((n_classes,))
    if ignore_index>0:
        weight[ignore_index]=0
    if output.ndim==2:
        index = jt.index((output.shape[0],),dim=0)
        loss = -output[index,target]*weight[target]
    else:
        loss = -output[target[0]]*weight[target[0]]
    if reduction=="mean":
        total_weight = weight[target].sum() if output.ndim==2 else weight[target[0]].sum()
        return loss.sum()/total_weight
    elif reduction=="sum":
        return loss.sum()
    elif reduction=="none":
        return loss
    else:
        raise ValueError(f'not support {reduction}')

class CrossEntropyLoss(Module):
    def __init__(self,ignore_index=None):
        self.ignore_index = ignore_index
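A usage sketch of the new functional nll_loss, following the test added in this commit: output carries per-class log-probabilities, target carries integer class indices, and weight optionally re-weights classes:

    import numpy as np
    import jittor as jt
    from jittor import nn

    logp = jt.array(np.random.randn(10, 10).astype(np.float32))
    target = jt.array(np.random.randint(10, size=(10,)))
    loss = nn.nll_loss(logp, target, reduction='mean')
    print(loss)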
@@ -330,6 +362,9 @@ class Dropout(Module):
        output = output * noise / (1.0 - self.p) # div keep prob
        return output

def dropout(x,p=0.5,is_train=False):
    return Dropout(p=p,is_train=is_train)(x)

class Linear(Module):
    def __init__(self, in_features, out_features, bias=True):
        self.in_features = in_features
@@ -707,6 +742,45 @@ class ConvTranspose(Module):
        y = y + b
        return y

def conv_transpose(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
    x = input
    N,C,H,W = x.shape
    i,o,h,w = weight.shape
    assert C==i
    assert groups==1, "Group conv not supported yet."
    stride = stride if isinstance(stride, tuple) else (stride, stride)
    dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation)
    # added
    padding = padding if isinstance(padding, tuple) else (padding, padding)
    output_padding = output_padding if isinstance (output_padding, tuple) else (output_padding, output_padding)
    assert output_padding[0] < max(stride[0], dilation[0]) and \
        output_padding[1] < max(stride[1], dilation[1]), \
        "output padding must be smaller than max(stride, dilation)"

    stride_h, stride_w = stride
    padding_h, padding_w = padding
    dilation_h, dilation_w = dilation

    h_out = (H-1) * stride_h + output_padding[0] - 2*padding_h + 1 + (h-1)*dilation_h
    w_out = (W-1) * stride_w + output_padding[1] - 2*padding_w + 1 + (w-1)*dilation_w
    out_shape = (N, o, h_out, w_out)
    shape = (N, i, o, H, W, h, w)
    xx = x.broadcast(shape, (2, 5, 6)) # i,h,w
    ww = weight.broadcast(shape, (0, 3, 4)) # N,H,W
    y = (ww*xx).reindex_reduce("add", out_shape, [
        'i0', # N
        'i2', # o
        f'i3*{stride_h}-{padding_h}+i5*{dilation_h}', # Hid+Khid
        f'i4*{stride_w}-{padding_w}+i6*{dilation_w}', # Wid+KWid
    ])
    if isinstance(bias, jt.Var):
        b = bias.broadcast(y.shape, [0,2,3])
        y = y + b
    else:
        assert not bias, "Bias should be none or jittor var"
    return y

conv_transpose2d = conv_transpose

def pad(x,padding, mode='constant', value=0):
    assert mode in ['constant','replicate','reflect','circular'],'only support constant,replicate,reflect,circular pad'
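The functional conv_transpose / conv_transpose2d defined in the hunk above expects weights laid out as (in_channels, out_channels, kh, kw), matching ConvTranspose.weight. A short usage sketch, with the output size following the h_out/w_out formula:

    import jittor as jt
    from jittor import nn

    x = jt.random((4, 5, 10, 10))   # N, C, H, W
    w = jt.random((5, 6, 3, 3))     # in_channels, out_channels, kh, kw
    y = nn.conv_transpose2d(x, w, stride=2, bias=False)
    print(y.shape)   # (4, 6, 21, 21): (10-1)*2 + 1 + (3-1)*1 = 21 per spatial dim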
@@ -33,6 +33,9 @@ class Optimizer(object):
            assert isinstance(pg, dict)
            self.param_groups.append(pg)
        self.n_step = 0

    def add_param_group(self, group):
        self.param_groups.append(group)

    @property
    def defaults(self):
@@ -0,0 +1,53 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
#     Dun Liang <randonlang@gmail.com>.
#     Xiangli Li <190569238@qq.com>
#
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************

import jittor as jt
import numpy as np

class SparseVar:
    def __init__(self,indices,values,shape):
        assert isinstance(indices,jt.Var) and isinstance(values,jt.Var) and isinstance(shape,jt.NanoVector)
        self.indices = indices
        self.values = values
        self.shape = shape
        self.ndim = len(shape)

    def _indices(self):
        return self.indices

    def _values(self):
        return self.values

    def t(self):
        indices = list(self.indices.split(1,dim=0))
        indices[-1],indices[-2] = indices[-2],indices[-1]
        indices = jt.contrib.concat(indices,dim=0)
        shape = list(self.shape)
        shape[-1],shape[-2] = shape[-2],shape[-1]
        shape = jt.NanoVector(shape)
        return SparseVar(indices,self.values,shape)

    def to_dense(self):
        ret = jt.zeros(self.shape,self.values.dtype)
        indices = tuple(self.indices.split(1,dim=0))
        ret[indices]=self.values
        return ret

def sparse_array(indices,values,shape):
    return SparseVar(indices,values,shape)

def spmm(spase_x,y):
    assert isinstance(spase_x,SparseVar) and isinstance(y,jt.Var)
    assert spase_x.ndim==2 and y.ndim==2 and spase_x.shape[-1]==y.shape[0]

    # TODO
    x = spase_x.to_dense()
    return jt.matmul(x,y)
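A usage sketch of the new COO-style SparseVar, mirroring the test_sparse case added later in this commit; spmm currently densifies before the matmul, as the TODO in the file notes:

    import numpy as np
    import jittor as jt

    indices = jt.array(np.array([[0, 1, 1], [2, 0, 2]]))
    values = jt.array(np.array([3, 4, 5]).astype(np.float32))
    sv = jt.sparse.sparse_array(indices, values, jt.NanoVector([2, 3]))
    print(sv.to_dense().numpy())                        # dense 2x3 matrix
    print(jt.sparse.spmm(sv, jt.ones((3, 2))).numpy())  # (2, 3) @ (3, 2) -> (2, 2)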
@@ -60,6 +60,41 @@ class TestConvTranspose(unittest.TestCase):
        check((4, 5, 100, 100), (5, 6, 5, 5), 1, 2)
        check((4, 5, 100, 100), (5, 6, 5, 5), 2, 2)
        check((4, 5, 100, 100), (5, 6, 5, 5), 2, 3)

    def test_function(self):
        def check(data_shape, weights_shape, stride=1, dilation=1):
            N,C,H,W = data_shape
            i,o,h,w = weights_shape
            img = np.random.rand(N,C,H,W).astype("float32")
            weights = np.random.rand(i,o,h,w).astype("float32")
            m1 = jt.nn.ConvTranspose(i,o,h, stride=stride, dilation=dilation, bias=False)
            m2 = torch.nn.ConvTranspose2d(i,o,h, stride=stride, dilation=dilation, bias=False)
            m1.weight.data = weights
            m2.weight.data = torch.Tensor(weights)
            x = jt.array(img)
            # out1 = m1(x)
            out1 = jt.nn.conv_transpose2d(x, m1.weight, stride=stride, dilation=dilation, bias=False)
            mask = jt.random(out1.shape)
            out1 = out1*mask
            tx = torch.Tensor(img)
            tx.requires_grad = True
            out2 = m2(tx) * torch.Tensor(mask.data)
            with jt.log_capture_scope(log_silent=1,
                log_vprefix="var_re=0,conv=0,op.cc=100") as logs:
                assert np.allclose(out1.data, out2.data)
                dx, dw = jt.grad(out1, [x, m1.weight])
                jt.sync([dx, dw])
                out2.sum().backward()
                assert np.allclose(dw.data, m2.weight.grad.numpy(), 1e-3)
                assert np.allclose(dx.data, tx.grad.numpy())
            assert len(find_log_with_re(logs, "conv")) == 3
        check((4, 5, 10, 10), (5, 6, 3, 3))
        check((4, 5, 10, 10), (5, 6, 3, 3), 2)
        check((4, 5, 100, 100), (5, 6, 4, 4), 2)
        check((4, 5, 100, 100), (5, 6, 4, 4), 3)
        check((4, 5, 100, 100), (5, 6, 5, 5), 1, 2)
        check((4, 5, 100, 100), (5, 6, 5, 5), 2, 2)
        check((4, 5, 100, 100), (5, 6, 5, 5), 2, 3)

if __name__ == "__main__":
    unittest.main()
@@ -26,6 +26,19 @@ class TestFunction(unittest.TestCase):
        da = jt.grad(b, a)
        assert da.data == -1

    def test_apply(self):
        class MyFunc(Function):
            def execute(self, x):
                return x+1

            def grad(self, grad):
                return grad-2
        a = jt.ones(1)
        func = MyFunc.apply
        b = func(a)
        da = jt.grad(b, a)
        assert da.data == -1

    def test2(self):
        class MyFunc(Function):
            def execute(self, x):
@@ -40,6 +40,21 @@ class TestLoss(unittest.TestCase):
        jt_y=jt_loss(jt.array(output), jt.array(target))
        tc_y=tc_loss(torch.from_numpy(output), torch.from_numpy(target))
        assert np.allclose(jt_y.numpy(), tc_y.numpy())

    def test_nll_loss(self):
        tc_loss = tnn.functional.nll_loss
        jt_loss = jnn.nll_loss
        output=np.random.randn(10,10).astype(np.float32)
        target=np.random.randint(10, size=(10))
        jt_y=jt_loss(jt.array(output), jt.array(target),reduction='mean')
        tc_y=tc_loss(torch.from_numpy(output), torch.from_numpy(target),reduction='mean')
        assert np.allclose(jt_y.numpy(), tc_y.numpy())
        output=np.random.randn(10,10).astype(np.float32)
        target=np.random.randint(10, size=(10))
        weight=np.random.randn(10,).astype(np.float32)
        jt_y=jt_loss(jt.array(output), jt.array(target),jt.array(weight),reduction='mean')
        tc_y=tc_loss(torch.from_numpy(output), torch.from_numpy(target),torch.from_numpy(weight),reduction='mean')
        assert np.allclose(jt_y.numpy(), tc_y.numpy())

    def test_cross_entropy_loss(self):
        jt_loss=jnn.CrossEntropyLoss()
@@ -54,6 +54,7 @@ class TestPad(unittest.TestCase):
        check_equal(torch.Tensor(arr).flip(1), jt.array(arr).flip(1))
        check_equal(torch.Tensor(arr).flip(2), jt.array(arr).flip(2))
        check_equal(torch.Tensor(arr).flip(3), jt.array(arr).flip(3))
        check_equal(torch.Tensor(arr).flip([2,3]), jt.array(arr).flip([2,3]))
        print('pass flip test ...')

    def test_cross(self):

@@ -83,8 +84,13 @@
        check_equal(torchvision.utils.make_grid(torch.Tensor(arr), nrow=3, padding=4), jt.make_grid(jt.array(arr), nrow=3, padding=4))
        check_equal(torchvision.utils.make_grid(torch.Tensor(arr), nrow=3, padding=4, pad_value=-1), jt.make_grid(jt.array(arr), nrow=3, padding=4, pad_value=-1))
        check_equal(torchvision.utils.make_grid(torch.Tensor(arr), nrow=3, normalize=True, padding=4, pad_value=-1), jt.make_grid(jt.array(arr), nrow=3, normalize=True, padding=4, pad_value=-1))
        check_equal(torchvision.utils.make_grid(torch.Tensor(arr), nrow=3, normalize=True, padding=4, pad_value=-1, range=(-100,100)), jt.make_grid(jt.array(arr), nrow=3, normalize=True, padding=4, pad_value=-1, range=(-100,100)))
        print('pass make_grid test ...')

    def test_save_image(self):
        arr = jt.array(np.random.randn(16,3,10,10))
        jt.save_image(arr, "/tmp/a.jpg")

    def test_unbind(self):
        arr = np.random.randn(2,3,4)
        for dim in range(len(arr.shape)):
@@ -28,7 +28,7 @@ def check_equal(arr, j_layer, p_layer):
    pytorch_arr = torch.Tensor(arr)
    jittor_result = j_layer(jittor_arr)
    pytorch_result = p_layer(pytorch_arr)
    assert np.allclose(pytorch_result.detach().numpy(), jittor_result.numpy())
    assert np.allclose(pytorch_result.detach().numpy(), jittor_result.numpy(),rtol=1e-5,atol=1e-5)

@unittest.skipIf(skip_this_test, "No Torch found")
class TestRelu(unittest.TestCase):

@@ -61,6 +61,15 @@ class TestRelu(unittest.TestCase):
        check_equal(arr, jnn.LeakyReLU(), tnn.LeakyReLU())
        check_equal(arr, jnn.LeakyReLU(2), tnn.LeakyReLU(2))
        check_equal(arr, jnn.LeakyReLU(99.9), tnn.LeakyReLU(99.9))

        # ***************************************************************
        # Test ELU Layer
        # ***************************************************************
        arr = np.random.randn(16,10,224,224)
        check_equal(arr, jnn.ELU(), tnn.ELU())
        check_equal(arr, jnn.ELU(0.3), tnn.ELU(0.3))
        check_equal(arr, jnn.ELU(2), tnn.ELU(2))
        check_equal(arr, jnn.ELU(99.9), tnn.ELU(99.9))

        # ***************************************************************
        # Test GELU Layer
@@ -0,0 +1,118 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
#     Wenyang Zhou <576825820@qq.com>.
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import jittor as jt
import numpy as np
skip_this_test = False

@unittest.skipIf(skip_this_test, "No Torch found")
class TestSetitem(unittest.TestCase):
    def test_setitem(self):
        arr0 = jt.random((4,2,2))
        data0 = jt.ones((2,2))
        arr0[1] = data0
        arr0.sync()
        data0.data[0,0] = 0
        assert arr0[1,0,0] == 0

        arr00 = jt.random((4,2,2))
        data00 = jt.ones((2,2))
        # share memory will fail if d has an edge to other nodes.
        tmp = data00 + 1
        arr00[1] = data00
        arr00.sync()
        data00.data[0,0] = 0
        assert arr00[1,0,0] == 0

        arr1 = jt.random((4,2,2))
        data1 = jt.zeros((2,2))
        arr1[3,:,0:2] = data1
        arr1.sync()
        data1.data[0,0] = 1
        assert arr1[3,0,0] == 1

        arr21 = jt.ones((2,2))
        arr22 = jt.ones((2,2)) * 2
        arr2 = jt.contrib.concat([arr21, arr22], dim=0)
        arr2.sync()
        arr21.data[0,0] = 3
        arr22.data[0,0] = 4
        assert arr2[0,0] == 3
        assert arr2[2,0] == 4

    def test_getitem(self):
        # test for different slice type
        arr0 = jt.random((4,3))
        arr0_res = arr0[2,:]
        arr0_res.data[1] = 1
        assert arr0[2,1] == 1

        arr1 = jt.array([1,2,3,4])
        arr1_res = arr1[None]
        arr1_res.data[0,2] = -1
        assert arr1[2] == -1

        arr2 = jt.array([1,2,3,4])
        arr2_res = arr2[...]
        arr2_res.data[2] = -1
        assert arr2[2] == -1

        arr3 = jt.array([1,2,3,4])
        arr3_res = arr3[3]
        arr3_res.data[0] = -1
        assert arr3[3] == -1

        arr4 = jt.random((4,2,3,3))
        arr4_res = arr4[...,:,:]
        arr4_res.data[0,0,1,1] = 1
        assert arr4[0,0,1,1] == 1

        arr5 = jt.random((4,2,3,3))
        arr5_res = arr5[1:3,:,:,:]
        arr5_res.data[1,0,1,1] = 1
        assert arr5[2,0,1,1] == 1

        arr6 = jt.random((4,2,3,3))
        arr6_res = arr6[1]
        arr6_res.data[0,1,1] = 1
        assert arr6[1,0,1,1] == 1

        # test for different data type (float32/float64/bool/int8/int32)
        arr_float32 = jt.random((4,2,3))
        arr_float32_res = arr_float32[1:3,:,:]
        arr_float32_res.data[0,0,0] = 1
        assert arr_float32[1,0,0] == 1
        arr_float32_res.data[1,1,2] = 1
        assert arr_float32[2,1,2] == 1
        arr_float32[1,0,0] = 0
        # getitem and setitem do not conflict
        assert arr_float32_res[0,0,0] == 1

        arr_bool = jt.bool(np.ones((4,2,3)))
        arr_bool_res = arr_bool[1:3,:,:]
        arr_bool_res.data[0,0,0] = False
        assert arr_bool[1,0,0] == False
        arr_bool_res.data[0,0,1] = False
        assert arr_bool[1,0,1] == False

        arr_float64 = jt.random((4,2,3), dtype='float64')
        arr_float64_res = arr_float64[1:3,:,:]
        arr_float64_res.data[0,0,0] = 1
        assert arr_float64[1,0,0] == 1
        arr_float64_res.data[1,1,2] = 1
        assert arr_float64[2,1,2] == 1

        arr_int32 = jt.ones((4,2,3), dtype='int32')
        arr_int32_res = arr_int32[1:3,:,:]
        arr_int32_res.data[0,0,0] = 0
        assert arr_int32[1,0,0] == 0
        arr_int32_res.data[1,1,2] = 0
        assert arr_int32[2,1,2] == 0

if __name__ == "__main__":
    unittest.main()
@@ -140,6 +140,11 @@ class TestSlice(unittest.TestCase):
        a[c] = 0
        assert (a.data == [1,2,3,0,0]).all()

    def test_numpy_scalar_slice(self):
        a = jt.random((2,2))
        b = np.array([1])[0]
        assert a[b].shape == [2]


if __name__ == "__main__":
@@ -0,0 +1,39 @@

# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
#     Xiangli Li <1905692338@qq.com>
#     Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import jittor as jt
import numpy as np
import jittor.nn as jnn

skip_this_test = False

try:
    jt.dirty_fix_pytorch_runtime_error()
    import torch
    import torch.nn as tnn
except:
    torch = None
    tnn = None
    skip_this_test = True

@unittest.skipIf(skip_this_test, "No Torch found")
class TestSparse(unittest.TestCase):
    def test_sparse_var(self):
        indices = np.array([[0,1,1],[2,0,2]])
        values = np.array([3,4,5]).astype(np.float32)
        shape = [2,3]
        jt_array = jt.sparse.sparse_array(jt.array(indices),jt.array(values),jt.NanoVector(shape))
        torch_tensor = torch.sparse.FloatTensor(torch.from_numpy(indices),torch.from_numpy(values),torch.Size(shape))
        jt_numpy = jt_array.to_dense().numpy()
        torch_numpy = torch_tensor.to_dense().numpy()
        assert np.allclose(jt_numpy,torch_numpy)

if __name__ == "__main__":
    unittest.main()
@@ -0,0 +1,27 @@
from flask import Flask
from flask import request
from flask import jsonify
app = Flask(__name__)
import json

from jittor.utils.pytorch_converter import convert

@app.route('/', methods=["GET", "POST"])
def hello():
    msg = request
    data = msg.data.decode("utf-8")
    try:
        data = json.loads(data)
        src = data["src"]
        pjmap = json.loads(data["pjmap"])
        jt_src = convert(src, pjmap)
    except Exception as e:
        jt_src = str(e)
    response = jsonify(jt_src=jt_src)

    # Enable Access-Control-Allow-Origin
    response.headers.add("Access-Control-Allow-Origin", "*")
    return response

if __name__ == '__main__':
    app.run(host="0.0.0.0")
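A hedged client-side sketch for the converter service above: the handler expects a JSON body with a src string and a pjmap field that is itself a JSON-encoded string, and answers with {"jt_src": ...}. It assumes the Flask app is running locally on its default development port:

    import json
    import requests

    payload = {
        "src": "import torch.nn as nn\nrelu = nn.ReLU()",
        "pjmap": json.dumps({}),          # extra conversion rules, none here
    }
    r = requests.post("http://127.0.0.1:5000/", data=json.dumps(payload))
    print(r.json()["jt_src"])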
@@ -38,7 +38,7 @@ def download_url_to_local(url, filename, root_folder, md5):
    ensure_dir(root_folder)
    file_path = os.path.join(root_folder, filename)
    if check_file_exist(file_path, md5):
        print("Data file has been downloaded and verified")
        return
    else:
        try:
            print('Downloading ' + url + ' to ' + file_path)
@@ -179,6 +179,18 @@ pjmap = {
        'links': {},
        'extras': {'affine': 'None'},
    },
    'Parameter':{
        'pytorch': {
            'args': "data,require_grad=True"
        },
        'jittor': {
            'module': 'jt',
            'name': 'array',
            'args': 'data,dtype=None',
        },
        'links': {},
        'extras': {},
    },
    'Dropout2d': {
        'pytorch': {
            'args': 'p=0.5, inplace=False',
@@ -351,6 +363,32 @@ pjmap = {
    }
}

unsupport_ops = [
    # ***************************************************************
    # torch.nn
    # ***************************************************************
    'ModuleDict', 'ParameterList', 'ParameterDict',
    'Conv1d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose3d', 'Unfold', 'Fold',
    'MaxPool1d', 'MaxPool3d', 'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d', 'AvgPool1d',
    'AvgPool3d', 'FractionalMaxPool2d', 'LPPool1d', 'LPPool2d', 'AdaptiveMaxPool1d',
    'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d', 'AdaptiveAvgPool1d', 'AdaptiveAvgPool3d',
    'ReflectionPad1d', 'ReplicationPad1d', 'ReplicationPad3d', 'ConstantPad1d', 'ConstantPad3d',
    'ELU', 'Hardshrink', 'Hardtanh', 'LogSigmoid', 'MultiheadAttention',
    'RReLU', 'SELU', 'CELU', 'GELU', 'Softshrink', 'Softsign', 'Tanhshrink',
    'Threshold', 'Softmin', 'Softmax2d', 'LogSoftmax', 'AdaptiveLogSoftmaxWithLoss',
    'BatchNorm3d', 'SyncBatchNorm', 'InstanceNorm1d', 'InstanceNorm3d', 'LocalResponseNorm',
    'RNNBase', 'RNN', 'LSTM', 'GRU', 'RNNCell', 'LSTMCell', 'GRUCell', 'Transformer', 'TransformerEncoder',
    'TransformerDecoder', 'TransformerEncoderLayer', 'TransformerDecoderLayer', 'Identity', 'Bilinear',
    'Dropout3d', 'AlphaDropout', 'EmbeddingBag', 'CosineSimilarity', 'PairwiseDistance', 'CTCLoss', 'NLLLoss', 'PoissonNLLLoss', 'KLDivLoss', 'BCEWithLogitsLoss',
    'MarginRankingLoss', 'HingeEmbeddingLoss', 'MultiLabelMarginLoss', 'SmoothL1Loss', 'SoftMarginLoss',
    'MultiLabelSoftMarginLoss', 'CosineEmbeddingLoss', 'MultiMarginLoss', 'TripletMarginLoss', 'UpsamplingNearest2d',
    'UpsamplingBilinear2d', 'DataParallel', 'DistributedDataParallel', 'clip_grad_norm_', 'clip_grad_value_',
    'parameters_to_vector', 'vector_to_parameters', 'BasePruningMethod', 'PruningContainer', 'Identity',
    'RandomUnstructured', 'L1Unstructured', 'RandomStructured', 'LnStructured', 'CustomFromMask', 'identity',
    'random_unstructured', 'l1_unstructured', 'random_structured', 'ln_structured', 'global_unstructured',
    'custom_from_mask', 'remove', 'is_pruned', 'weight_norm', 'remove_weight_norm', 'spectral_norm',
    'remove_spectral_norm', 'PackedSequence', 'pack_padded_sequence', 'pad_packed_sequence', 'pad_sequence', 'pack_sequence'
]

def pjmap_append(pytorch_func_name, pytorch_args, jittor_func_module, jittor_func_name, jittor_args, extras=None, links=None, delete=None):
    ''' adding map to pjmap for converting new function, example: convert AvgPool2d to Pool
@ -393,58 +431,268 @@ def pjmap_append(pytorch_func_name, pytorch_args, jittor_func_module, jittor_fun
|
|||
'delete': delete,
|
||||
}
|
||||
|
||||
unsupport_ops = [
|
||||
# ***************************************************************
|
||||
# torch.nn
|
||||
# ***************************************************************
|
||||
'Parameter', 'ModuleDict', 'ParameterList', 'ParameterDict',
|
||||
'Conv1d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose3d', 'Unfold', 'Fold',
|
||||
'MaxPool1d', 'MaxPool3d', 'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d', 'AvgPool1d',
|
||||
'AvgPool3d', 'FractionalMaxPool2d', 'LPPool1d', 'LPPool2d', 'AdaptiveMaxPool1d',
|
||||
'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d', 'AdaptiveAvgPool1d', 'AdaptiveAvgPool3d',
|
||||
'ReflectionPad1d', 'ReplicationPad1d', 'ReplicationPad3d', 'ConstantPad1d', 'ConstantPad3d',
|
||||
'ELU', 'Hardshrink', 'Hardtanh', 'LogSigmoid', 'MultiheadAttention',
|
||||
'RReLU', 'SELU', 'CELU', 'GELU', 'Softshrink', 'Softsign', 'Tanhshrink',
|
||||
'Threshold', 'Softmin', 'Softmax2d', 'LogSoftmax', 'AdaptiveLogSoftmaxWithLoss',
|
||||
'BatchNorm3d', 'SyncBatchNorm', 'InstanceNorm1d', 'InstanceNorm3d', 'LocalResponseNorm',
|
||||
'RNNBase', 'RNN', 'LSTM', 'GRU', 'RNNCell', 'LSTMCell', 'GRUCell', 'Transformer', 'TransformerEncoder',
|
||||
'TransformerDecoder', 'TransformerEncoderLayer', 'TransformerDecoderLayer', 'Identity', 'Bilinear',
|
||||
'Dropout3d', 'AlphaDropout', 'EmbeddingBag', 'CosineSimilarity', 'PairwiseDistance', 'CTCLoss', 'NLLLoss', 'PoissonNLLLoss', 'KLDivLoss', 'BCEWithLogitsLoss',
|
||||
'MarginRankingLoss', 'HingeEmbeddingLoss', 'MultiLabelMarginLoss', 'SmoothL1Loss', 'SoftMarginLoss',
|
||||
'MultiLabelSoftMarginLoss', 'CosineEmbeddingLoss', 'MultiMarginLoss', 'TripletMarginLoss', 'UpsamplingNearest2d',
|
||||
'UpsamplingBilinear2d', 'DataParallel', 'DistributedDataParallel', 'clip_grad_norm_', 'clip_grad_value_',
|
||||
'parameters_to_vector', 'vector_to_parameters', 'BasePruningMethod', 'PruningContainer', 'Identity',
|
||||
'RandomUnstructured', 'L1Unstructured', 'RandomStructured', 'LnStructured', 'CustomFromMask', 'identity',
|
||||
'random_unstructured', 'l1_unstructured', 'random_structured', 'ln_structured', 'global_unstructured',
|
||||
'custom_from_mask', 'remove', 'is_pruned', 'weight_norm', 'remove_weight_norm', 'spectral_norm',
|
||||
'remove_spectral_norm', 'PackedSequence', 'pack_padded_sequence', 'pad_packed_sequence', 'pad_sequence', 'pack_sequence'
|
||||
]
|
||||
|
||||
support_ops = {}
|
||||
for key in pjmap.keys():
|
||||
module = pjmap[key]['jittor']['module']
|
||||
name = pjmap[key]['jittor']['name']
|
||||
if module == 'nn':
|
||||
support_ops[key] = name
|
||||
def raise_unsupport(name, ori_src):
|
||||
ret = f"raise RuntimeError('''original source: <{ori_src.strip()}>, {name} is not supported in Jittor yet. We will appreciate it if you provide an implementation of {name} and make pull request at https://github.com/Jittor/jittor.''')"
|
||||
print(ret+'\n')
|
||||
ret = ast.parse(ret).body[0]
|
||||
return ret
|
||||
|
||||
def raise_unsupport(name):
|
||||
raise RuntimeError(f'{name} is not supported in Jittor yet. We will appreciate it if you provide an implementation of {name} and make pull request at https://github.com/Jittor/jittor.')
|
||||
class Converter:
|
||||
def __init__(self, ex_pjmap):
|
||||
import copy
|
||||
self.pjmap = copy.deepcopy(pjmap)
|
||||
if ex_pjmap:
|
||||
self.pjmap.update(ex_pjmap)
|
||||
self.unsupport_ops = set(unsupport_ops)
|
||||
support_ops = {}
|
||||
for key in self.pjmap.keys():
|
||||
module = self.pjmap[key]['jittor']['module']
|
||||
name = self.pjmap[key]['jittor']['name']
|
||||
if module == 'nn':
|
||||
support_ops[key] = name
|
||||
if key in self.unsupport_ops:
|
||||
self.unsupport_ops.remove(key)
|
||||
self.support_ops = support_ops
|
||||
self.import_flag = []
|
||||
|
||||
def replace(a):
|
||||
if hasattr(a, "attr") and a.attr in unsupport_ops:
|
||||
raise_unsupport(a.attr)
|
||||
|
||||
if hasattr(a, "id") and a.id in unsupport_ops:
|
||||
raise_unsupport(a.id)
|
||||
|
||||
if hasattr(a, "attr"):
|
||||
if a.attr in support_ops.keys(): a.attr = support_ops[a.attr]
|
||||
def replace(self, a):
|
||||
if hasattr(a, "attr") and a.attr in self.unsupport_ops:
|
||||
ori_src = astunparse.unparse(a)
|
||||
return raise_unsupport(a.attr, ori_src)
|
||||
|
||||
if hasattr(a, "id"):
|
||||
if a.id in support_ops.keys(): a.id = support_ops[a.id]
|
||||
if hasattr(a, "id") and a.id in self.unsupport_ops:
|
||||
ori_src = astunparse.unparse(a)
|
||||
return raise_unsupport(a.id, ori_src)
|
||||
|
||||
import_flag = []
|
||||
def convert(code):
|
||||
if hasattr(a, "attr"):
|
||||
if a.attr in self.support_ops.keys(): a.attr = self.support_ops[a.attr]
|
||||
|
||||
if hasattr(a, "id"):
|
||||
if a.id in self.support_ops.keys(): a.id = self.support_ops[a.id]
|
||||
|
||||
return None
|
||||
|
||||
def convert_(self, prefix, func_name, ags, kws, ori_src):
|
||||
info = self.pjmap[func_name]
|
||||
p_prefix = info['pytorch']['prefix'] if 'prefix' in info['pytorch'].keys() else None
|
||||
if p_prefix is not None and prefix in p_prefix:
|
||||
p_ags = info['pytorch']['args_prefix']
|
||||
j_ags = info['jittor']['args_prefix']
|
||||
else:
|
||||
p_ags = info['pytorch']['args']
|
||||
j_ags = info['jittor']['args']
|
||||
if 'delete' in info.keys():
|
||||
delete = info['delete']
|
||||
else:
|
||||
delete = None
|
||||
j_prefix = info['jittor']['prefix'] if 'prefix' in info['jittor'].keys() else None
|
||||
j_module = info['jittor']['module']
|
||||
j_name = info['jittor']['name']
|
||||
links = info['links']
|
||||
extras = info['extras']
|
||||
jj_ags = []
|
||||
jj_kws = {}
|
||||
pp_ags = []
|
||||
pp_kws = {}
|
||||
if j_ags == '' and p_ags == '':
|
||||
# no args in Pytorch and Jittor.
|
||||
if p_prefix is None:
|
||||
return f"{j_module}.{j_name}()"
|
||||
else:
|
||||
if prefix in p_prefix:
|
||||
return f"{j_prefix}.{j_name}()"
|
||||
else:
|
||||
return f"{prefix}.{j_name}()"
|
||||
else:
|
||||
j_ags = j_ags.replace(' ','').split(',')
|
||||
for j_ag in j_ags:
|
||||
if '=' in j_ag:
|
||||
k,v = j_ag.split('=')
|
||||
jj_kws[k] = v
|
||||
else:
|
||||
jj_ags.append(j_ag)
|
||||
p_ags = p_ags.replace(' ','').split(',')
|
||||
for p_ag in p_ags:
|
||||
if '=' in p_ag:
|
||||
k,v = p_ag.split('=')
|
||||
pp_kws[k] = v
|
||||
else:
|
||||
pp_ags.append(p_ag)
|
||||
if len(jj_ags) == 0 and len(pp_ags) != 0:
|
||||
return f"raise AttributeError('''origin source: <{ori_src.strip()}>, {func_name} in Jittor has no Attribute {pp_ags[0]}''')"
|
||||
# raise AttributeError(f"{func_name} in Jittor has no Attribute {pp_ags[0]}")
|
||||
if delete is not None:
|
||||
for d in delete:
|
||||
if d in pp_ags:
|
||||
jj_ags.append(d)
|
||||
if d in pp_kws.keys():
|
||||
jj_kws[d] = None
|
||||
if len(pp_ags) > len(ags) + len(kws):
|
||||
return f"raise RuntimeError('''origin source: <{ori_src.strip()}>, There are needed {len(pp_ags) + len(list(pp_kws.keys()))} args in Pytorch {func_name} function, but you only provide {len(ags) + len(kws)}''')"
|
||||
# raise RuntimeError(f'There are needed {len(pp_ags) + len(list(pp_kws.keys()))} args in Pytorch {func_name} function, but you only provide {len(ags) + len(kws)}')
|
||||
ags_ = []
|
||||
for i in range(len(pp_ags)):
|
||||
if i < len(ags):
|
||||
if '*' in pp_ags[i]:
|
||||
ags_.append('(' + ', '.join(ags[i:]) + ')')
|
||||
ags = ags_
|
||||
break
|
||||
else:
|
||||
ags_.append(ags[i])
|
||||
else:
|
||||
break
|
||||
if len(pp_ags) + len(list(pp_kws.keys())) < len(ags) + len(kws):
|
||||
return f"raise RuntimeError('''origin source: <{ori_src.strip()}>,There are only {len(pp_ags) + len(list(pp_kws.keys()))} args in Pytorch {func_name} function, but you provide {len(ags) + len(kws)}''')"
|
||||
# raise RuntimeError(f'There are only {len(pp_ags) + len(list(pp_kws.keys()))} args in Pytorch {func_name} function, but you provide {len(ags) + len(kws)}')
|
||||
j_ags_flag = np.zeros(len(jj_ags))
|
||||
j_ags_values = {}
|
||||
j_kws_values = {}
|
||||
for i,ag in enumerate(ags):
|
||||
if len(pp_ags) == 0:
|
||||
ag_name = list(pp_kws.keys())[i]
|
||||
elif i < len(pp_ags):
|
||||
ag_name = pp_ags[i]
|
||||
elif i >= len(pp_ags) and (i-len(pp_ags)) <= len(list(pp_kws.keys())):
|
||||
ag_name = list(pp_kws.keys())[i-len(pp_ags)]
|
||||
else:
|
||||
return f"raise RuntimeError('''origin source: <{ori_src.strip()}>,The args number is not matc{func_name} in Jittor has no Attribute {ag_name}''')"
|
||||
# raise RuntimeError(f'The args number is not matc{func_name} in Jittor has no Attribute {ag_name}')
|
||||
if ag_name in links.keys():
|
||||
ag_name = links[ag_name]
|
||||
if ag_name in jj_ags:
|
||||
j_ags_flag[jj_ags.index(ag_name)] = 1
|
||||
j_ags_values[str(jj_ags.index(ag_name))] = ag
|
||||
elif ag_name in jj_kws.keys():
|
||||
j_kws_values[ag_name] = ag
|
||||
else:
|
||||
return f"raise AttributeError('''origin source: <{ori_src.strip()}>, {func_name} in Jittor has no Attribute {ag_name}''')"
|
||||
# raise AttributeError(f'{func_name} in Jittor has no Attribute {ag_name}')
|
||||
for i,kw in enumerate(kws):
|
||||
kw_name, kw_value = kw.split('=')
|
||||
if kw_name in links.keys():
|
||||
kw_name = links[kw_name]
|
||||
if kw_name in jj_ags:
|
||||
j_ags_flag[jj_ags.index(kw_name)] = 1
|
||||
j_ags_values[str(jj_ags.index(kw_name))] = kw_value
|
||||
elif kw_name in jj_kws.keys():
|
||||
j_kws_values[kw_name] = kw_value
|
||||
else:
|
||||
return f"raise AttributeError('''origin source: <{ori_src.strip()}>, {func_name} in Jittor has no Attribute {kw_name}''')"
|
||||
# raise AttributeError(f'{func_name} in Jittor has no Attribute {kw_name}')
|
||||
len_jj_ags = len(jj_ags) if len(jj_ags) == 0 or jj_ags[0] != '' else 0
|
||||
if j_ags_flag.sum() < len_jj_ags:
|
||||
missing_args = []
|
||||
for i in range(len(jj_ags)):
|
||||
if j_ags_flag[i] == 0:
|
||||
missing_args.append(jj_ags[i])
|
||||
return f"raise AttributeError('''origin source: <{ori_src.strip()}>, the needed args of {func_name} in Jittor is {', '.join(jj_ags)}, so you need to give value of {', '.join(missing_args)}.''')"
|
||||
# raise AttributeError(f"the needed args of {func_name} in Jittor is {', '.join(jj_ags)}, so you need to give value of {', '.join(missing_args)}.")
|
||||
if extras:
|
||||
for k in extras.keys():
|
||||
if k in jj_ags:
|
||||
j_ags_values[str(jj_ags.index(k))] = extras[k]
|
||||
elif k in jj_kws.keys():
|
||||
j_kws_values[k] = extras[k]
|
||||
else:
|
||||
return f"raise AttributeError('''origin source: <{ori_src.strip()}>, there is not attribute named {k} in Jittor {func_name}, you should delete it in {func_name} extras.''')"
|
||||
# raise AttributeError(f"there is not attribute named {k} in Jittor {func_name}, you should delete it in {func_name} extras.")
|
||||
if delete is not None:
|
||||
for d in delete:
|
||||
if d in j_ags_values:
|
||||
del j_ags_values[d]
|
||||
if d in j_kws_values.keys():
|
||||
j_kws_values.pop(d)
|
||||
j_ags_ = [j_ags_values[str(i)] for i in range(len(list(j_ags_values.keys())))]
|
||||
j_kws_ = [key + "=" + j_kws_values[key] for key in j_kws_values.keys()]
|
||||
j_func = f"{j_module}.{j_name}({', '.join(j_ags_+j_kws_)})"
|
||||
if p_prefix is None:
|
||||
return f"{j_module}.{j_name}({', '.join(j_ags_+j_kws_)})"
|
||||
else:
|
||||
if prefix in p_prefix:
|
||||
return f"{j_prefix}.{j_name}({', '.join(j_ags_+j_kws_)})"
|
||||
else:
|
||||
return f"{prefix}.{j_name}({', '.join(j_ags_+j_kws_)})"
|
||||
return j_func
|
||||
|
||||
def dfs(self, a):
|
||||
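# Recursively rewrites one AST node of the PyTorch source: returns a replacement
# node, the string 'delete' when the node should be dropped (e.g. torch imports),
# or None to keep the original node unchanged.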
if isinstance(a, ast.Import):
|
||||
if 'torch' in astunparse.unparse(a) and 'init' in astunparse.unparse(a):
|
||||
self.import_flag.append('init')
|
||||
return ast.parse('from jittor import init').body[0]
|
||||
if 'torch' in astunparse.unparse(a) and a.names[0].asname == 'nn':
|
||||
self.import_flag.append('nn')
|
||||
return ast.parse('from jittor import nn').body[0]
|
||||
if 'torch' in a.names[0].name:
|
||||
return 'delete'
|
||||
elif isinstance(a, ast.ImportFrom):
|
||||
if 'torch' in a.module:
|
||||
return 'delete'
|
||||
elif isinstance(a, ast.Call):
|
||||
for idx, ag in enumerate(a.args):
|
||||
ret = self.dfs(ag)
|
||||
if ret is not None:
|
||||
a.args[idx] = ret
|
||||
for idx, kw in enumerate(a.keywords):
|
||||
ret = self.dfs(kw)
|
||||
if ret is not None:
|
||||
a.keywords[idx] = ret
|
||||
ori_src = astunparse.unparse(a)
|
||||
func = astunparse.unparse(a.func).strip('\n').split('.')
|
||||
prefix = '.'.join(func[0:-1])
|
||||
func_name = func[-1]
|
||||
if func_name in self.unsupport_ops:
|
||||
ret = raise_unsupport(func_name, ori_src)
|
||||
return ret
|
||||
if func_name in self.pjmap:
|
||||
ags = [astunparse.unparse(ag).strip('\n') for ag in a.args]
|
||||
kws = [astunparse.unparse(kw).strip('\n') for kw in a.keywords]
|
||||
ret = self.convert_(prefix, func_name, ags, kws, ori_src)
|
||||
ret_tmp = ret
|
||||
ret = ast.parse(ret).body[0]
|
||||
if hasattr(ret,'value'):
|
||||
return ret.value
|
||||
else:
|
||||
print(ret_tmp+'\n')
|
||||
return ret
|
||||
if ".load_state_dict" in astunparse.unparse(a.func):
|
||||
a.func.attr = 'load_parameters'
|
||||
if astunparse.unparse(a.func).strip('\n').endswith(".size"):
|
||||
ags = [astunparse.unparse(ag).strip('\n') for ag in a.args]
|
||||
if len(ags) != 0:
|
||||
con = astunparse.unparse(a.func).split('.size')[0] + '.shape[' + ','.join(ags) + ']'
|
||||
else:
|
||||
con = astunparse.unparse(a.func).replace('size', 'shape')
|
||||
return ast.parse(con).body[0].value
|
||||
elif isinstance(a, ast.Expr): pass
|
||||
elif isinstance(a, ast.Attribute) or isinstance(a, ast.Name):
|
||||
ret = self.replace(a)
|
||||
if ret is not None:
|
||||
print(ret)
|
||||
return ret
|
||||
elif isinstance(a, ast.FunctionDef):
|
||||
if a.name == 'forward': a.name = 'execute'
|
||||
if hasattr(a, '__dict__'):
|
||||
for k in a.__dict__.keys():
|
||||
if isinstance(a.__dict__[k], list):
|
||||
delete_flag = []
|
||||
for i,a_ in enumerate(a.__dict__[k]):
|
||||
ret = self.dfs(a_)
|
||||
if ret == 'delete':
|
||||
delete_flag.append(True)
|
||||
continue
|
||||
if ret is not None:
|
||||
a.__dict__[k][i] = ret
|
||||
delete_flag.append(False)
|
||||
tmp = [a_ for i,a_ in enumerate(a.__dict__[k]) if delete_flag[i] == False]
|
||||
a.__dict__[k] = tmp
|
||||
else:
|
||||
ret = self.dfs(a.__dict__[k])
|
||||
if ret is not None:
|
||||
a.__dict__[k] = ret
|
||||
|
||||
|
||||
def convert(code, ex_pjmaps=None):
|
||||
''' Model code converter, example:
|
||||
|
||||
from jittor.utils.pytorch_converter import convert
|
||||
|
@ -469,209 +717,13 @@ def convert(code):
|
|||
model = Model()
|
||||
print("## Jittor model:", model)
|
||||
'''
|
||||
|
||||
a = ast.parse(code)
|
||||
dfs(a)
|
||||
converter = Converter(ex_pjmaps)
|
||||
converter.dfs(a)
|
||||
a.body.insert(0, ast.parse('import jittor as jt').body[0])
|
||||
if 'init' not in import_flag:
|
||||
if 'init' not in converter.import_flag:
|
||||
a.body.insert(1, ast.parse('from jittor import init').body[0])
|
||||
if 'nn' not in import_flag:
|
||||
if 'nn' not in converter.import_flag:
|
||||
a.body.insert(2, ast.parse('from jittor import nn').body[0])
|
||||
return astunparse.unparse(a)
|
||||
|
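# Usage sketch for convert() (hedged: the sample PyTorch source below is
# illustrative only; convert() takes a string of PyTorch code and returns the
# converted Jittor code as a string):
#
#     from jittor.utils.pytorch_converter import convert
#     pytorch_code = '''
#     import torch
#     import torch.nn as nn
#     class Model(nn.Module):
#         def __init__(self):
#             super(Model, self).__init__()
#             self.conv1 = nn.Conv2d(1, 10, 3)
#         def forward(self, x):
#             return self.conv1(x)
#     '''
#     jittor_code = convert(pytorch_code)
#     print(jittor_code)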
||||
def convert_(prefix, func_name, ags, kws):
|
||||
info = pjmap[func_name]
|
||||
p_prefix = info['pytorch']['prefix'] if 'prefix' in info['pytorch'].keys() else None
|
||||
if p_prefix is not None and prefix in p_prefix:
|
||||
p_ags = info['pytorch']['args_prefix']
|
||||
j_ags = info['jittor']['args_prefix']
|
||||
else:
|
||||
p_ags = info['pytorch']['args']
|
||||
j_ags = info['jittor']['args']
|
||||
if 'delete' in info.keys():
|
||||
delete = info['delete']
|
||||
else:
|
||||
delete = None
|
||||
j_prefix = info['jittor']['prefix'] if 'prefix' in info['jittor'].keys() else None
|
||||
j_module = info['jittor']['module']
|
||||
j_name = info['jittor']['name']
|
||||
links = info['links']
|
||||
extras = info['extras']
|
||||
jj_ags = []
|
||||
jj_kws = {}
|
||||
pp_ags = []
|
||||
pp_kws = {}
|
||||
if j_ags == '' and p_ags == '':
|
||||
# no args in Pytorch and Jittor.
|
||||
if p_prefix is None:
|
||||
return f"{j_module}.{j_name}()"
|
||||
else:
|
||||
if prefix in p_prefix:
|
||||
return f"{j_prefix}.{j_name}()"
|
||||
else:
|
||||
return f"{prefix}.{j_name}()"
|
||||
else:
|
||||
j_ags = j_ags.replace(' ','').split(',')
|
||||
for j_ag in j_ags:
|
||||
if '=' in j_ag:
|
||||
k,v = j_ag.split('=')
|
||||
jj_kws[k] = v
|
||||
else:
|
||||
jj_ags.append(j_ag)
|
||||
p_ags = p_ags.replace(' ','').split(',')
|
||||
for p_ag in p_ags:
|
||||
if '=' in p_ag:
|
||||
k,v = p_ag.split('=')
|
||||
pp_kws[k] = v
|
||||
else:
|
||||
pp_ags.append(p_ag)
|
||||
if len(jj_ags) == 0 and len(pp_ags) != 0:
|
||||
raise AttributeError(f"{func_name} in Jittor has no Attribute {pp_ags[0]}")
|
||||
if delete is not None:
|
||||
for d in delete:
|
||||
if d in pp_ags:
|
||||
jj_ags.append(d)
|
||||
if d in pp_kws.keys():
|
||||
jj_kws[d] = None
|
||||
if len(pp_ags) > len(ags) + len(kws):
|
||||
raise RuntimeError(f'the Pytorch {func_name} function requires {len(pp_ags) + len(list(pp_kws.keys()))} args, but you only provided {len(ags) + len(kws)}')
|
||||
ags_ = []
|
||||
for i in range(len(pp_ags)):
|
||||
if i < len(ags):
|
||||
if '*' in pp_ags[i]:
|
||||
ags_.append('(' + ', '.join(ags[i:]) + ')')
|
||||
ags = ags_
|
||||
break
|
||||
else:
|
||||
ags_.append(ags[i])
|
||||
else:
|
||||
break
|
||||
if len(pp_ags) + len(list(pp_kws.keys())) < len(ags) + len(kws):
|
||||
raise RuntimeError(f'the Pytorch {func_name} function only takes {len(pp_ags) + len(list(pp_kws.keys()))} args, but you provided {len(ags) + len(kws)}')
|
||||
j_ags_flag = np.zeros(len(jj_ags))
|
||||
j_ags_values = {}
|
||||
j_kws_values = {}
|
||||
for i,ag in enumerate(ags):
|
||||
if len(pp_ags) == 0:
|
||||
ag_name = list(pp_kws.keys())[i]
|
||||
elif i < len(pp_ags):
|
||||
ag_name = pp_ags[i]
|
||||
elif i >= len(pp_ags) and (i-len(pp_ags)) <= len(list(pp_kws.keys())):
|
||||
ag_name = list(pp_kws.keys())[i-len(pp_ags)]
|
||||
else:
|
||||
raise RuntimeError(f'the number of args does not match: {func_name} in Jittor has no attribute for argument {ag}')
|
||||
if ag_name in links.keys():
|
||||
ag_name = links[ag_name]
|
||||
if ag_name in jj_ags:
|
||||
j_ags_flag[jj_ags.index(ag_name)] = 1
|
||||
j_ags_values[str(jj_ags.index(ag_name))] = ag
|
||||
elif ag_name in jj_kws.keys():
|
||||
j_kws_values[ag_name] = ag
|
||||
else:
|
||||
raise AttributeError(f'{func_name} in Jittor has no Attribute {ag_name}')
|
||||
for i,kw in enumerate(kws):
|
||||
kw_name, kw_value = kw.split('=')
|
||||
if kw_name in links.keys():
|
||||
kw_name = links[kw_name]
|
||||
if kw_name in jj_ags:
|
||||
j_ags_flag[jj_ags.index(kw_name)] = 1
|
||||
j_ags_values[str(jj_ags.index(kw_name))] = kw_value
|
||||
elif kw_name in jj_kws.keys():
|
||||
j_kws_values[kw_name] = kw_value
|
||||
else:
|
||||
raise AttributeError(f'{func_name} in Jittor has no Attribute {kw_name}')
|
||||
len_jj_ags = len(jj_ags) if len(jj_ags) == 0 or jj_ags[0] != '' else 0
|
||||
if j_ags_flag.sum() < len_jj_ags:
|
||||
missing_args = []
|
||||
for i in range(len(jj_ags)):
|
||||
if j_ags_flag[i] == 0:
|
||||
missing_args.append(jj_ags[i])
|
||||
raise AttributeError(f"the needed args of {func_name} in Jittor is {', '.join(jj_ags)}, so you need to give value of {', '.join(missing_args)}.")
|
||||
if extras:
|
||||
for k in extras.keys():
|
||||
if k in jj_ags:
|
||||
j_ags_values[str(jj_ags.index(k))] = extras[k]
|
||||
elif k in jj_kws.keys():
|
||||
j_kws_values[k] = extras[k]
|
||||
else:
|
||||
raise AttributeError(f"there is not attribute named {k} in Jittor {func_name}, you should delete it in {func_name} extras.")
|
||||
if delete is not None:
|
||||
for d in delete:
|
||||
if d in j_ags_values:
|
||||
del j_ags_values[d]
|
||||
if d in j_kws_values.keys():
|
||||
j_kws_values.pop(d)
|
||||
j_ags_ = [j_ags_values[str(i)] for i in range(len(list(j_ags_values.keys())))]
|
||||
j_kws_ = [key + "=" + j_kws_values[key] for key in j_kws_values.keys()]
|
||||
j_func = f"{j_module}.{j_name}({', '.join(j_ags_+j_kws_)})"
|
||||
if p_prefix is None:
|
||||
return f"{j_module}.{j_name}({', '.join(j_ags_+j_kws_)})"
|
||||
else:
|
||||
if prefix in p_prefix:
|
||||
return f"{j_prefix}.{j_name}({', '.join(j_ags_+j_kws_)})"
|
||||
else:
|
||||
return f"{prefix}.{j_name}({', '.join(j_ags_+j_kws_)})"
|
||||
return j_func
|
||||
|
||||
def dfs(a):
|
||||
if isinstance(a, ast.Import):
|
||||
if 'torch' in astunparse.unparse(a) and 'init' in astunparse.unparse(a):
|
||||
import_flag.append('init')
|
||||
return ast.parse('from jittor import init').body[0]
|
||||
if 'torch' in astunparse.unparse(a) and a.names[0].asname == 'nn':
|
||||
import_flag.append('nn')
|
||||
return ast.parse('from jittor import nn').body[0]
|
||||
if 'torch' in a.names[0].name:
|
||||
return 'delete'
|
||||
elif isinstance(a, ast.ImportFrom):
|
||||
if 'torch' in a.module:
|
||||
return 'delete'
|
||||
elif isinstance(a, ast.Call):
|
||||
for idx, ag in enumerate(a.args):
|
||||
ret = dfs(ag)
|
||||
if ret is not None:
|
||||
a.args[idx] = ret
|
||||
for idx, kw in enumerate(a.keywords):
|
||||
ret = dfs(kw)
|
||||
if ret is not None:
|
||||
a.keywords[idx] = ret
|
||||
func = astunparse.unparse(a.func).strip('\n').split('.')
|
||||
prefix = '.'.join(func[0:-1])
|
||||
func_name = func[-1]
|
||||
if func_name in unsupport_ops:
|
||||
raise_unsupport(func_name)
|
||||
if func_name in pjmap.keys():
|
||||
ags = [astunparse.unparse(ag).strip('\n') for ag in a.args]
|
||||
kws = [astunparse.unparse(kw).strip('\n') for kw in a.keywords]
|
||||
ret = convert_(prefix, func_name, ags, kws)
|
||||
return ast.parse(ret).body[0].value
|
||||
if ".load_state_dict" in astunparse.unparse(a.func):
|
||||
a.func.attr = 'load_parameters'
|
||||
if astunparse.unparse(a.func).strip('\n').endswith(".size"):
|
||||
ags = [astunparse.unparse(ag).strip('\n') for ag in a.args]
|
||||
if len(ags) != 0:
|
||||
con = astunparse.unparse(a.func).split('.size')[0] + '.shape[' + ','.join(ags) + ']'
|
||||
else:
|
||||
con = astunparse.unparse(a.func).replace('size', 'shape')
|
||||
return ast.parse(con).body[0].value
|
||||
elif isinstance(a, ast.Expr): pass
|
||||
elif isinstance(a, ast.Attribute) or isinstance(a, ast.Name): replace(a)
|
||||
elif isinstance(a, ast.FunctionDef):
|
||||
if a.name == 'forward': a.name = 'execute'
|
||||
if hasattr(a, '__dict__'):
|
||||
for k in a.__dict__.keys():
|
||||
if isinstance(a.__dict__[k], list):
|
||||
delete_flag = []
|
||||
for i,a_ in enumerate(a.__dict__[k]):
|
||||
ret = dfs(a_)
|
||||
if ret == 'delete':
|
||||
delete_flag.append(True)
|
||||
continue
|
||||
if ret is not None:
|
||||
a.__dict__[k][i] = ret
|
||||
delete_flag.append(False)
|
||||
tmp = [a_ for i,a_ in enumerate(a.__dict__[k]) if delete_flag[i] == False]
|
||||
a.__dict__[k] = tmp
|
||||
else:
|
||||
ret = dfs(a.__dict__[k])
|
||||
if ret is not None:
|
||||
a.__dict__[k] = ret
|
|
@ -1 +1 @@
|
|||
84596508776983dce645fc4ef77c7f35700549d5
|
||||
d2eb452b81e704188346a788d8d53889f7b12007
|
||||
|
|
|
@ -0,0 +1,14 @@
|
|||
cat > /tmp/converter_server.dockerfile <<\EOF
|
||||
FROM jittor/jittor
|
||||
|
||||
RUN python3.7 -m pip install flask
|
||||
RUN apt update && apt install git -y
|
||||
EOF
|
||||
|
||||
docker build --tag jittor/converter_server -f /tmp/converter_server.dockerfile .
|
||||
|
||||
# docker run --rm -it -m 16g --cpus=8 -p 0.0.0.0:5000:5000 jittor/converter_server bash -c "python3.7 -m pip install -U git+https://github.com/Jittor/jittor.git && python3.7 -m jittor.utils.converter_server"
|
||||
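# Restart the converter server every 24 hours so the container always runs the
# latest jittor from git (each run re-installs it via pip install -U before starting).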
while true; do
|
||||
timeout --foreground 24h docker run --rm -it -m 16g --cpus=8 -p 0.0.0.0:5000:5000 jittor/converter_server bash -c "python3.7 -m pip install -U git+https://github.com/Jittor/jittor.git && python3.7 -m jittor.utils.converter_server"
|
||||
sleep 10
|
||||
done
|
|
@ -21,6 +21,7 @@
|
|||
#include "fuser.h"
|
||||
#include "profiler/profiler_guard.h"
|
||||
#include "parallel_compiler.h"
|
||||
#include "misc/nan_checker.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
|
@ -46,7 +47,10 @@ void load_fused_op(FusedOp& fused_op, vector<int>& fuse_ops, vector<Op*>& ops, i
|
|||
for (Op* op : fused_op.ops) {
|
||||
uint fid1 = op->custom_data;
|
||||
int iid = 0;
|
||||
for (Var* v : op->inputs()) {
|
||||
for (auto ve : op->_inputs) {
|
||||
// this is a control dependency edge, do not use it
|
||||
if (ve.back->index<0) continue;
|
||||
auto v = ve.node->var();
|
||||
iid++;
|
||||
int iop_id;
|
||||
int iv_id;
|
||||
|
@ -450,6 +454,8 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
|
|||
if (use_cuda)
|
||||
checkCudaErrors(cudaDeviceSynchronize());
|
||||
#endif
|
||||
for (Var* var : op->outputs())
|
||||
check_nan(var);
|
||||
}
|
||||
LOGvvv << "Finished Op(" >> op->name() << rid >>
|
||||
"/" >> queue.size() >> ") output:" << op->outputs();
|
||||
|
|
|
@ -7,6 +7,10 @@
|
|||
#pragma once
|
||||
#include "common.h"
|
||||
#include "mem/allocator.h"
|
||||
#ifdef HAS_CUDA
|
||||
#include <cuda_runtime.h>
|
||||
#include "helper_cuda.h"
|
||||
#endif
|
||||
|
||||
namespace jittor {
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@ static auto make_number = get_op_info("number")
|
|||
|
||||
VarPtr make_grad(Op* op, Var* out, Var* dout, Var* x, int x_index) {
|
||||
if (dout == nullptr) return nullptr;
|
||||
if (x_index<0) return nullptr;
|
||||
LOGvvvv << "Make grad op:" >> op->name() << "inputs:" >> op->inputs()
|
||||
<< "out:" >> out << "dout:" >> dout << "x:" >> x << "xid:" >> x_index;
|
||||
auto dx = op->grad(out, dout, x, x_index);
|
||||
|
|
|
@ -0,0 +1,73 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include <cfloat>
|
||||
#include <cmath>
|
||||
#include "misc/nan_checker.h"
|
||||
#ifdef HAS_CUDA
|
||||
#include "misc/cuda_flags.h"
|
||||
#include <cuda_runtime.h>
|
||||
#include "helper_cuda.h"
|
||||
#endif
|
||||
#include "mem/allocator.h"
|
||||
#include "op.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
|
||||
#ifdef HAS_CUDA
|
||||
extern void check_nan_float32(float32* ptr, int64 num);
|
||||
extern void check_nan_float64(float64* ptr, int64 num);
|
||||
#endif
|
||||
|
||||
bool check_nan(Var* v) {
|
||||
if (!v->dtype().is_float()) return true;
|
||||
if (v->input() && (
|
||||
v->input()->name() == string("empty") ||
|
||||
v->input()->name() == string("setitem")))
|
||||
return true;
|
||||
#ifdef HAS_CUDA
|
||||
if (v->allocator->is_cuda()) {
|
||||
if (v->dtype() == ns_float32) {
|
||||
check_nan_float32((float32*)v->mem_ptr, v->num);
|
||||
} else
|
||||
if (v->dtype() == ns_float64) {
|
||||
check_nan_float64((float64*)v->mem_ptr, v->num);
|
||||
}
|
||||
ASSERT(cudaDeviceSynchronize()==0) << "detect nan or inf at" << v;
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
if (v->dtype() == ns_float32) {
|
||||
auto* __restrict__ ptr = v->ptr<float32>();
|
||||
auto num = v->num;
|
||||
bool ok = true;
|
||||
int64 i=0;
|
||||
for (; i<num; i++) {
|
||||
if (std::isinf(ptr[i]) || std::isnan(ptr[i])) {
|
||||
ok = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
ASSERT(ok) << "detect nan at index" << i << v;
|
||||
}
|
||||
if (v->dtype() == ns_float64) {
|
||||
auto* __restrict__ ptr = v->ptr<float64>();
|
||||
auto num = v->num;
|
||||
bool ok = true;
|
||||
int64 i=0;
|
||||
for (; i<num; i++) {
|
||||
if (std::isinf(ptr[i]) || std::isnan(ptr[i])) {
|
||||
ok = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
ASSERT(ok) << "detect nan at index" << i << v;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,47 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include "misc/nan_checker.h"
|
||||
#include "misc/cuda_flags.h"
|
||||
#include <cuda_runtime.h>
|
||||
#include "helper_cuda.h"
|
||||
#include <cassert>
|
||||
|
||||
namespace jittor {
|
||||
|
||||
|
||||
#ifdef HAS_CUDA
|
||||
__global__ void _check_nan_float32(float32* __restrict__ ptr, int64 num) {
|
||||
int64 i = threadIdx.x + blockIdx.x * (int64)blockDim.x;
|
||||
if (i<num) {
|
||||
if (::isnan(ptr[i]) || ::isinf(ptr[i]))
|
||||
__trap();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
__global__ void _check_nan_float64(float64* __restrict__ ptr, int64 num) {
|
||||
int64 i = threadIdx.x + blockIdx.x * (int64)blockDim.x;
|
||||
if (i<num) {
|
||||
if (::isnan(ptr[i]) || ::isinf(ptr[i]))
|
||||
__trap();
|
||||
}
|
||||
}
|
||||
|
||||
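// Launch one thread per element: the grid covers `num` elements with up to 1024
// threads per block; any thread that sees a NaN/Inf calls __trap(), aborting the
// kernel so the host-side cudaDeviceSynchronize() check reports an error.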
void check_nan_float64(float64* ptr, int64 num) {
|
||||
int block_num = std::max((int64)1, (num-1)/1024+1);
|
||||
int thread_num = std::min((int64)1024, num);
|
||||
_check_nan_float64<<<block_num, thread_num>>>(ptr, num);
|
||||
}
|
||||
|
||||
void check_nan_float32(float32* ptr, int64 num) {
|
||||
int block_num = std::max((int64)1, (num-1)/1024+1);
|
||||
int thread_num = std::min((int64)1024, num);
|
||||
_check_nan_float32<<<block_num, thread_num>>>(ptr, num);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
}
|
|
@ -0,0 +1,13 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include "var.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
bool check_nan(Var* v);
|
||||
|
||||
}
|
|
@ -312,6 +312,10 @@ void SetitemOp::jit_run() {
|
|||
checkCudaErrors(cudaMemcpyAsync(op, ip, out->size, cudaMemcpyDefault, 0));
|
||||
#endif
|
||||
|
||||
if (data->allocation == in->allocation &&
|
||||
data->allocator == in->allocator)
|
||||
return;
|
||||
|
||||
@for(d, 0, ODIM, for (index_t i@d=0; i@d < oshape@d; i@d++)) {
|
||||
index_t did = 0 @for(d, 0, ODIM, @if((BMASK>>d)&1,+ i@d * dstride@d));
|
||||
@for(d, 0, IDIM, index_t iid@d =
|
||||
|
|
|
@ -14,11 +14,6 @@ namespace jittor {
|
|||
static auto make_transpose = get_op_info("transpose")
|
||||
.get_constructor<VarPtr, Var*, NanoVector>();
|
||||
|
||||
#ifdef HAS_CUDA
|
||||
static auto make_reshape = get_op_info("reshape")
|
||||
.get_constructor<VarPtr, Var*, NanoVector>();
|
||||
#endif
|
||||
|
||||
TransposeOp::TransposeOp(Var* x, NanoVector axes_) : x(x), axes(axes_) {
|
||||
int i=0;
|
||||
for (; i<axes.size(); i++)
|
||||
|
|
|
@ -16,8 +16,19 @@ inline static bool fast_strcmp(const char* a, const char* b) {
|
|||
return !*b;
|
||||
}
|
||||
|
||||
// add dependency b -> a
|
||||
static inline void add_dependency(Node* a, const vector<Node*>& b) {
|
||||
a->add_inputs(b);
|
||||
auto edge = a->_inputs.end();
|
||||
for (int i=0; i<b.size(); i++) {
|
||||
edge = std::prev(edge);
|
||||
// index -1 marks this as a control dependency edge
|
||||
edge->back->index = -1;
|
||||
}
|
||||
}
|
||||
|
||||
static void setitem_inplace(SetitemOp* op) {
|
||||
// LOGir << "setitem_inplace";
|
||||
// LOGir << "in setitem_inplace";
|
||||
auto input = op->inputs().front();
|
||||
if (!(input->outputs().size() == 1 &&
|
||||
input->forward_liveness<=1 &&
|
||||
|
@ -29,8 +40,7 @@ static void setitem_inplace(SetitemOp* op) {
|
|||
// make sure input op will not use input
|
||||
auto input_name = input_op->name();
|
||||
if (!(input_op->type() == OpType::broadcast ||
|
||||
fast_strcmp(input_name, "array") ||
|
||||
fast_strcmp(input_name, "empty") ||
|
||||
input_op->inputs().size() == 0 ||
|
||||
fast_strcmp(input_name, "setitem") ||
|
||||
fast_strcmp(input_name, "getitem")))
|
||||
// TODO: inplace getitem may be risky; getitem may be made inplace too
|
||||
|
@ -38,7 +48,50 @@ static void setitem_inplace(SetitemOp* op) {
|
|||
}
|
||||
auto output = op->outputs().front();
|
||||
output->share_with(input);
|
||||
// LOGir << "apply setitem_inplace on" << op << "input:" << input << "output:" << output;
|
||||
// return;
|
||||
|
||||
// LOGir << "pass setitem optim one";
|
||||
|
||||
auto data = op->input(1);
|
||||
input_op = input->input();
|
||||
|
||||
if (input_op && input_op->inputs().size() == 1) {
|
||||
input_op = input_op->inputs().front()->input();
|
||||
}
|
||||
if (input_op && input_op->inputs().size() == 1) {
|
||||
input_op = input_op->inputs().front()->input();
|
||||
}
|
||||
|
||||
VarSlices vs = op->vs;
|
||||
if (!(data->is_finished() == 0 &&
|
||||
(data->outputs().size() == 1 ||
|
||||
(!input_op
|
||||
|| input_op->inputs().size() == 0))))
|
||||
return;
|
||||
if (data->allocator)
|
||||
return;
|
||||
|
||||
auto in_shape = input->shape;
|
||||
for (int i = vs.n - 1; i > 0; --i) {
|
||||
VarSlice s = vs.slices[i];
|
||||
if (!(s.is_slice())) return;
|
||||
Slice ss = s.slice;
|
||||
if (!(ss.start == 0 && ss.stop >= in_shape[i] && ss.step == 1))
|
||||
return;
|
||||
}
|
||||
|
||||
VarSlice s = vs.slices[0];
|
||||
if (s.is_var()) return;
|
||||
|
||||
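// Compute the byte offset of the updated region inside the input buffer,
// so `data` can alias that region directly instead of being copied in.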
auto size = 0;
|
||||
if (s.is_int())
|
||||
size = s.i * input->size / in_shape[0];
|
||||
else if (s.is_slice())
|
||||
size = s.slice.start * input->size / in_shape[0];
|
||||
|
||||
add_dependency(data->input(), {input->node()});
|
||||
data->share_with(input, size);
|
||||
// LOGir << "pass setitem optim two";
|
||||
}
|
||||
|
||||
struct BBox {
|
||||
|
@ -103,9 +156,43 @@ static void setitem_grad_opt(GetitemOp* op) {
|
|||
|
||||
}
|
||||
|
||||
static void getitem_inplace(GetitemOp* op) {
|
||||
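// Inplace optimization: when the getitem keeps full slices on every axis except
// a leading int/slice index, the output can share the input's allocation at a
// byte offset instead of allocating and copying a new buffer.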
// LOGir << "in getitem_inplace";
|
||||
|
||||
auto in = op->inputs().front();
|
||||
auto ou = op->outputs().front();
|
||||
|
||||
// return if input or output's shape is variable
|
||||
if (in->num < 0 || ou->num < 0)
|
||||
return;
|
||||
|
||||
VarSlices vs = op->vs;
|
||||
auto in_shape = in->shape;
|
||||
|
||||
for (int i = vs.n - 1; i > 0; --i) {
|
||||
VarSlice s = vs.slices[i];
|
||||
if (!(s.is_slice())) return;
|
||||
Slice ss = s.slice;
|
||||
if (!(ss.start == 0 && ss.stop >= in_shape[i] && ss.step == 1))
|
||||
return;
|
||||
}
|
||||
|
||||
VarSlice s = vs.slices[0];
|
||||
if (s.is_var()) return;
|
||||
|
||||
auto size = 0;
|
||||
if (s.is_int())
|
||||
size = s.i * in->size / in_shape[0];
|
||||
else if (s.is_slice())
|
||||
size = s.slice.start * in->size / in_shape[0];
|
||||
ou->share_with(in, size);
|
||||
// LOGir << "pass getitem_inplace";
|
||||
}
|
||||
|
||||
void SetitemOp::graph_optimize() {
|
||||
// LOGir << "hello graph_optimize";
|
||||
setitem_inplace(this);
|
||||
(void)setitem_inplace;
|
||||
}
|
||||
|
||||
void GetitemOp::graph_optimize() {
|
||||
|
@ -113,6 +200,9 @@ void GetitemOp::graph_optimize() {
|
|||
// LOGir << "hello getitem graph_optimize";
|
||||
// setitem_grad_opt(this);
|
||||
(void)setitem_grad_opt;
|
||||
// (void)getitem_inplace;
|
||||
getitem_inplace(this);
|
||||
(void)getitem_inplace;
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -91,7 +91,8 @@ void PassManager::run_passes() {
|
|||
|
||||
run_pass<SolveConflictDefinePass>();
|
||||
run_pass<MergeLoopVarPass>();
|
||||
run_pass<ConstVarPass>();
|
||||
// tmp disable ConstVarPass
|
||||
// run_pass<ConstVarPass>();
|
||||
|
||||
run_pass<RestridePass>();
|
||||
|
||||
|
|
|
@ -734,9 +734,10 @@ void load_var_slice(PyObject* obj, T* var_slice, vector<unique_ptr<VarHolder>>&
|
|||
} else
|
||||
if (obj == Py_None) {
|
||||
var_slice->set_none();
|
||||
} else
|
||||
} else
|
||||
if (PyObject_TypeCheck(obj, PyNumberArrType_Type)) {
|
||||
PyArrayDescr_Proxy array_descr = {.type_num = 5}; // 5: int32
|
||||
PyArrayDescr_Proxy array_descr;
|
||||
array_descr.type_num = 5; // 5: int32
|
||||
int value;
|
||||
PyArray_CastScalarToCtype(obj, &value, &array_descr);
|
||||
var_slice->set_int(value);
|
||||
|
|
|
@ -142,7 +142,7 @@ static PyObject* to_py_object3(ArrayArgs&& a) {
|
|||
return to_py_object(jit_op_maker::array_(move(a)));
|
||||
}
|
||||
|
||||
static PyObject* pop_py_object(RingBuffer* rb, uint64& __restrict__ offset) {
|
||||
static PyObject* pop_py_object(RingBuffer* rb, uint64& __restrict__ offset, bool keep_numpy_array) {
|
||||
auto t = rb->pop_t<uint8>(offset);
|
||||
if (t==0) {
|
||||
auto x = rb->pop_t<int64>(offset);
|
||||
|
@ -161,7 +161,7 @@ static PyObject* pop_py_object(RingBuffer* rb, uint64& __restrict__ offset) {
|
|||
auto size = rb->pop_t<int64>(offset);
|
||||
PyObjHolder list(PyList_New(size));
|
||||
for (uint i=0; i<size; i++) {
|
||||
PyObject* o = pop_py_object(rb, offset);
|
||||
PyObject* o = pop_py_object(rb, offset, keep_numpy_array);
|
||||
PyList_SET_ITEM(list.obj, i, o);
|
||||
}
|
||||
return list.release();
|
||||
|
@ -170,8 +170,8 @@ static PyObject* pop_py_object(RingBuffer* rb, uint64& __restrict__ offset) {
|
|||
auto size = rb->pop_t<int64>(offset);
|
||||
PyObjHolder dict(PyDict_New());
|
||||
for (int64 i=0; i<size; i++) {
|
||||
PyObject* key = pop_py_object(rb, offset);
|
||||
PyObject* value = pop_py_object(rb, offset);
|
||||
PyObject* key = pop_py_object(rb, offset, keep_numpy_array);
|
||||
PyObject* value = pop_py_object(rb, offset, keep_numpy_array);
|
||||
PyDict_SetItem(dict.obj, key, value);
|
||||
}
|
||||
return dict.release();
|
||||
|
@ -185,7 +185,10 @@ static PyObject* pop_py_object(RingBuffer* rb, uint64& __restrict__ offset) {
|
|||
size *= args.shape[i];
|
||||
rb->pop(size, offset);
|
||||
args.ptr = rb->get_ptr(size, offset);
|
||||
return to_py_object3(move(args));
|
||||
if (!keep_numpy_array)
|
||||
return to_py_object3(move(args));
|
||||
else
|
||||
return to_py_object<ArrayArgs>(args);
|
||||
}
|
||||
if (t==6) {
|
||||
return pop_py_object_pickle(rb, offset);
|
||||
|
@ -212,7 +215,7 @@ void PyMultiprocessRingBuffer::push(PyObject* obj) {
|
|||
|
||||
PyObject* PyMultiprocessRingBuffer::pop() {
|
||||
auto offset = rb->l;
|
||||
auto obj = pop_py_object(rb, offset);
|
||||
auto obj = pop_py_object(rb, offset, _keep_numpy_array);
|
||||
rb->commit_pop(offset);
|
||||
return obj;
|
||||
}
|
||||
|
|
|
@ -13,6 +13,7 @@ namespace jittor {
|
|||
// @pyjt(RingBuffer)
|
||||
struct PyMultiprocessRingBuffer {
|
||||
RingBuffer* rb;
|
||||
bool _keep_numpy_array = false;
|
||||
// @pyjt(__init__)
|
||||
PyMultiprocessRingBuffer(uint64 size);
|
||||
// @pyjt(__dealloc__)
|
||||
|
@ -23,6 +24,8 @@ struct PyMultiprocessRingBuffer {
|
|||
PyObject* pop();
|
||||
// @pyjt(clear)
|
||||
inline void clear() { rb->clear(); }
|
||||
// @pyjt(keep_numpy_array)
|
||||
inline void keep_numpy_array(bool keep) { _keep_numpy_array = keep; }
|
||||
// @pyjt(stop)
|
||||
inline void stop() { rb->stop(); }
|
||||
// @pyjt(is_stop)
|
||||
|
|
|
@ -64,7 +64,7 @@ bool Var::alloc(Allocator* allocator) {
|
|||
if (mem_ptr) return true;
|
||||
if (auto* x = (Var*)(this->allocator)) {
|
||||
if (x->allocator->share_with(size, x->allocation)) {
|
||||
mem_ptr = x->mem_ptr;
|
||||
mem_ptr = ((char*) x->mem_ptr) + allocation;
|
||||
allocation = x->allocation;
|
||||
this->allocator = x->allocator;
|
||||
return true;
|
||||
|
|
|
@ -42,7 +42,7 @@ struct Var : Node {
|
|||
int64_t numel();
|
||||
void set_shape(NanoVector shape);
|
||||
bool alloc(Allocator* allocator);
|
||||
inline void share_with(Var* x) { CHECK_EXIST; allocator = (Allocator*)x; }
|
||||
inline void share_with(Var* x, size_t offset = 0) { CHECK_EXIST; allocator = (Allocator*)x; allocation = offset; }
|
||||
};
|
||||
|
||||
struct VarPtr {
|
||||
|
|