Merge branch 'master' of github.com:jittor/jittor

li-xl 2021-12-30 11:05:23 +08:00
commit 928f7ae5be
42 changed files with 396 additions and 87 deletions

View File

@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.3.1.16'
__version__ = '1.3.1.31'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int
@ -1183,6 +1183,13 @@ Arguments of hook are defined as::
for p in self.parameters():
p.update(p.mpi_broadcast(root))
def __setattr__(self, key, value):
object.__setattr__(self, key, value)
def __getattr__(self, key):
return object.__getattribute__(self, key)
class Function(Module):
''' Function Module for customized backward operations

View File

@ -935,6 +935,9 @@ with jit_utils.import_scope(import_flags):
jit_utils.try_import_jit_utils_core()
jittor_path = find_jittor_path()
if os.name == 'nt':
# prevent windows recompile
jittor_path = jittor_path.lower()
check_debug_flags()
sys.path.append(cache_path)
@ -988,6 +991,7 @@ if nvcc_path:
cu += "_sm_" + "_".join(s)
if "cuda_arch" not in os.environ:
os.environ["cuda_arch"] = " ".join(cu)
cu = cu.replace(":", "").replace(" ", "")
except:
pass
LOG.i("cuda key:", cu)

View File

@ -28,8 +28,9 @@ def concat(arr, dim):
Example::
jt.concat([jt.array([[1],[2]]), jt.array([[2],[2]])], dim=1)
# return [[1],[2],[2],[2]]
>>> jt.concat([jt.array([[1],[2]]), jt.array([[2],[2]])], dim=1)
jt.Var([[1 2]
[2 2]], dtype=int32)
'''
# TODO: low performance when concat lots of vars
total_dim = 0

View File

@ -12,6 +12,7 @@
#include "cudnn_conv_backward_w_op.h"
#include "cudnn_wrapper.h"
#include "executor.h"
#include "ops/op_register.h"
using namespace std;
@ -75,6 +76,30 @@ void CudnnConvBackwardWOp::jit_prepare(JK& jk) {
jk << _CS("][YFORMAT:") << yformat;
jk << ']';
}
static auto make_conv = get_op_info("cudnn_conv")
.get_constructor<VarPtr, Var*, Var*, int, int, int, int, int, int, int, string, string, string>();
static auto make_backwardx = get_op_info("cudnn_conv_backward_x")
.get_constructor<VarPtr, Var*, Var*, int, int, int, int, int, int, int, int, int, string, string, string>();
VarPtr CudnnConvBackwardWOp::grad(Var* out, Var* dout, Var* v, int v_index) {
int xn, xc, xh, xw, wh, ww, wci, wco, yn, yc, yh, yw;
if (xformat == "nchw") {
x->shape.unpack(xn, xc, xh, xw);
dy->shape.unpack(yn, yc, yh, yw);
} else {
x->shape.unpack(xn, xh, xw, xc);
dy->shape.unpack(yn, yh, yw, yc);
}
if (v_index == 0) {
return make_backwardx(dout, dy, xh, xw, strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups, xformat, wformat, yformat);
} else {
return make_conv(x, dout, strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups, xformat, wformat, yformat);
}
}
unordered_map<string, cudnnConvolutionBwdFilterAlgo_t> bwdw_algo_cache;
#else // JIT

View File

@ -20,6 +20,7 @@ struct CudnnConvBackwardWOp : Op {
CudnnConvBackwardWOp(Var* x, Var* y, int kh, int kw, int strideh, int stridew, int paddingh, int paddingw, int dilationh, int dilationw, int groups=1, string xformat="abcd", string wformat="oihw", string yformat="abcd");
const char* name() const override { return "cudnn_conv_backward_w"; }
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
void infer_shape() override;
DECLARE_jit_run;
};

View File

@ -12,6 +12,7 @@
#include "cudnn_conv_backward_x_op.h"
#include "cudnn_wrapper.h"
#include "executor.h"
#include "ops/op_register.h"
using namespace std;
@ -74,6 +75,22 @@ void CudnnConvBackwardXOp::jit_prepare(JK& jk) {
jk << _CS("][YFORMAT:") << yformat;
jk << ']';
}
static auto make_conv = get_op_info("cudnn_conv")
.get_constructor<VarPtr, Var*, Var*, int, int, int, int, int, int, int, string, string, string>();
static auto make_backwardw = get_op_info("cudnn_conv_backward_w")
.get_constructor<VarPtr, Var*, Var*, int, int, int, int, int, int, int, int, int, string, string, string>();
VarPtr CudnnConvBackwardXOp::grad(Var* out, Var* dout, Var* v, int v_index) {
int xn, xc, wh, ww, wci, wco, yn, yc, yd, yh, yw;
w->shape.unpack(wco, wci, wh, ww);
if (v_index == 0) {
return make_backwardw(dout, dy, wh, ww, strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups, xformat, wformat, yformat);
} else {
return make_conv(dout, w, strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups, xformat, wformat, yformat);
}
}
unordered_map<string, cudnnConvolutionBwdDataAlgo_t> bwdx_algo_cache;
#else // JIT

View File

@ -20,6 +20,7 @@ struct CudnnConvBackwardXOp : Op {
CudnnConvBackwardXOp(Var* w, Var* y, int height, int width, int strideh, int stridew, int paddingh, int paddingw, int dilationh, int dilationw, int groups=1, string xformat="abcd", string wformat="oihw", string yformat="abcd");
const char* name() const override { return "cudnn_conv_backward_x"; }
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
void infer_shape() override;
DECLARE_jit_run;
};

View File

@ -9,6 +9,7 @@
#include "cudnn_conv_op.h"
#include "cudnn_wrapper.h"
#include "executor.h"
#include "ops/op_register.h"
using namespace std;
@ -76,6 +77,24 @@ void CudnnConvOp::jit_prepare(JK& jk) {
jk << _CS("][YFORMAT:") << yformat;
jk << ']';
}
static auto make_backwardx = get_op_info("cudnn_conv_backward_x")
.get_constructor<VarPtr, Var*, Var*, int, int, int, int, int, int, int, int, int, string, string, string>();
static auto make_backwardw = get_op_info("cudnn_conv_backward_w")
.get_constructor<VarPtr, Var*, Var*, int, int, int, int, int, int, int, int, int, string, string, string>();
VarPtr CudnnConvOp::grad(Var* out, Var* dout, Var* v, int v_index) {
int xn, xc, xh, xw, wh, ww, wci, wco, yn, yc, yh, yw;
if (xformat == "ncdhw")
x->shape.unpack(xn, xc, xh, xw);
else
x->shape.unpack(xn, xh, xw, xc);
w->shape.unpack(wco, wci, wh, ww);
if (v_index == 0) {
return make_backwardx(w, dout, xh, xw, strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups, xformat, wformat, yformat);
} else {
return make_backwardw(x, dout, wh, ww, strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups, xformat, wformat, yformat);
}
}
unordered_map<string, cudnnConvolutionFwdAlgo_t> fwd_algo_cache;
#else // JIT
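Note: the grad() implementations added to cudnn_conv, cudnn_conv_backward_x and cudnn_conv_backward_w above make the cuDNN convolution gradients themselves differentiable. A minimal sketch of what this enables, assuming a CUDA/cuDNN build (not part of the diff):

    import jittor as jt
    from jittor import nn

    jt.flags.use_cuda = 1                       # assumes CUDA + cuDNN are available
    x = jt.rand(2, 3, 16, 16)
    conv = nn.Conv2d(3, 8, 3, padding=1)
    y = conv(x)
    gx, = jt.grad(y.sum(), [x])                 # first-order grad runs the backward ops above
    gw, = jt.grad(gx.sum(), [conv.weight])      # second-order grad needs grad() on those backward ops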

View File

@ -17,6 +17,7 @@ struct CudnnConvOp : Op {
CudnnConvOp(Var* x, Var* w, int strideh, int stridew, int paddingh, int paddingw, int dilationh=1, int dilationw=1, int groups=1, string xformat="abcd", string wformat="oihw", string yformat="");
const char* name() const override { return "cudnn_conv"; }
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
void infer_shape() override;
DECLARE_jit_run;
};

View File

@ -108,11 +108,16 @@ static auto make_backwardx_with_cx = get_op_info("cudnn_rnn_backward_x")
.get_constructor<vector<VarPtr>, Var*, Var*, Var*, Var*, Var*, Var*, Var*, Var*, Var*, string, int, int, int, int, double, bool, bool>();
static auto make_backwardx_without_cx = get_op_info("cudnn_rnn_backward_x")
.get_constructor<vector<VarPtr>, Var*, Var*, Var*, Var*, Var*, Var*, Var*, string, int, int, int, int, double, bool, bool>();
static auto make_number = get_op_info("number")
.get_constructor<VarPtr, float, Var*>();
void CudnnRnnOp::grads(Var** dout, VarPtr* dins) {
Var *dy = dout[0];
Var *dhy = dout[1];
Var *dcy = cx ? dout[2] : nullptr;
VarPtr dy = dout[0];
VarPtr dhy = dout[1];
VarPtr dcy = cx ? dout[2] : nullptr;
if (!dhy.ptr) dhy = make_number(0.0, hy);
if (!dcy.ptr && cx) dcy = make_number(0.0, cy);
vector<VarPtr> dInput;
if (cx)

View File

@ -486,15 +486,14 @@ jt.Var.deg2rad = deg2rad
def arctan2(y,x):
angle = jt.zeros(x.shape,dtype=x.dtype)
mask = x!=0.0
if angle[mask].numel()>0:
angle[mask] = jt.arctan(y[mask]/x[mask])
x = (x!=0.0).ternary(x, x+1e-30)
angle = (y/x).arctan()
mask = (y<0) & (x<0)
if angle[mask].numel()>0:
angle[mask] -= np.pi
mask = (y>0) &(x<0)
mask = (y>=0) &(x<0)
if angle[mask].numel()>0:
angle[mask] +=np.pi
return angle
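A quick sanity check of the rewritten arctan2 against NumPy's quadrant conventions (a sketch, not part of the diff; see also test_arctan2 below):

    import numpy as np
    import jittor as jt

    y = jt.array([1.0, 1.0,  0.0, -1.0])
    x = jt.array([1.0, 0.0, -1.0, -1.0])
    np.testing.assert_allclose(jt.arctan2(y, x).data,
                               np.arctan2(y.data, x.data), atol=1e-6)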

View File

@ -174,12 +174,16 @@ def Resnet50(pretrained=False, **kwargs):
resnet50 = Resnet50
def Resnet38(**kwargs):
return _resnet(Bottleneck, [2, 3, 5, 2], **kwargs)
def Resnet38(pretrained=False, **kwargs):
model = _resnet(Bottleneck, [2, 3, 5, 2], **kwargs)
if pretrained: model.load("jittorhub://resnet38.pkl")
return model
resnet38 = Resnet38
def Resnet26(**kwargs):
return _resnet(Bottleneck, [1, 2, 4, 1], **kwargs)
def Resnet26(pretrained=False, **kwargs):
model = _resnet(Bottleneck, [1, 2, 4, 1], **kwargs)
if pretrained: model.load("jittorhub://resnet26.pkl")
return model
resnet26 = Resnet26
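Hypothetical usage of the new pretrained flag (assumes the referenced jittorhub weight files are published):

    import jittor as jt
    from jittor.models.resnet import resnet38

    model = resnet38(pretrained=True)   # fetches jittorhub://resnet38.pkl
    model.eval()
    logits = model(jt.rand(1, 3, 224, 224))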
def Resnet101(pretrained=False, **kwargs):

View File

@ -803,9 +803,6 @@ class Conv(Module):
assert in_channels % groups == 0, 'in_channels must be divisible by groups'
assert out_channels % groups == 0, 'out_channels must be divisible by groups'
Kh, Kw = self.kernel_size
self.groups = groups
assert in_channels % groups == 0, 'in_channels must be divisible by groups'
assert out_channels % groups == 0, 'out_channels must be divisible by groups'
# self.weight = init.relu_invariant_gauss([out_channels, in_channels//groups, Kh, Kw], dtype="float", mode="fan_out")
self.weight = init.invariant_uniform([out_channels, in_channels//groups, Kh, Kw], dtype="float")
@ -1195,8 +1192,7 @@ class ConvTranspose(Module):
# added
self.dilation = dilation
self.group = groups
assert groups==1, "Group conv not supported yet."
self.groups = groups
self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size)
self.stride = stride if isinstance(stride, tuple) else (stride, stride)
@ -1209,8 +1205,10 @@ class ConvTranspose(Module):
assert self.output_padding[0] < max(self.stride[0], self.dilation[0]) and \
self.output_padding[1] < max(self.stride[1], self.dilation[1]), \
"output padding must be smaller than max(stride, dilation)"
assert in_channels % groups == 0, 'in_channels must be divisible by groups'
assert out_channels % groups == 0, 'out_channels must be divisible by groups'
self.weight = init.invariant_uniform((in_channels, out_channels) + self.kernel_size, dtype="float")
self.weight = init.invariant_uniform((in_channels, out_channels//groups) + self.kernel_size, dtype="float")
if bias:
fan=1
for i in self.weight.shape[1:]:
@ -1221,29 +1219,70 @@ class ConvTranspose(Module):
self.bias = None
def execute(self, x):
N,C,H,W = x.shape
i,o,h,w = self.weight.shape
assert C==i
stride_h, stride_w = self.stride
padding_h, padding_w = self.padding
dilation_h, dilation_w = self.dilation
if self.groups == 1:
N,C,H,W = x.shape
i,o,h,w = self.weight.shape
assert C==i
stride_h, stride_w = self.stride
padding_h, padding_w = self.padding
dilation_h, dilation_w = self.dilation
h_out = (H-1) * stride_h + self.output_padding[0] - 2*padding_h + 1 + (h-1)*dilation_h
w_out = (W-1) * stride_w + self.output_padding[1] - 2*padding_w + 1 + (w-1)*dilation_w
out_shape = (N, o, h_out, w_out)
shape = (N, i, o, H, W, h, w)
xx = x.broadcast(shape, (2, 5, 6)) # i,h,w
ww = self.weight.broadcast(shape, (0, 3, 4)) # N,H,W
y = (ww*xx).reindex_reduce("add", out_shape, [
'i0', # N
'i2', # o
f'i3*{stride_h}-{padding_h}+i5*{dilation_h}', # Hid+Khid
f'i4*{stride_w}-{padding_w}+i6*{dilation_w}', # Wid+KWid
])
if self.bias is not None:
b = self.bias.broadcast(y.shape, [0,2,3])
y = y + b
return y
else:
N,C,H,W = x.shape
Kh, Kw = self.kernel_size
i,o,h,w = self.weight.shape
oc = self.out_channels
G = self.groups
CpG = C // G # channels per group
assert C==self.in_channels
stride_h, stride_w = self.stride
padding_h, padding_w = self.padding
dilation_h, dilation_w = self.dilation
oh = (H-1) * stride_h + self.output_padding[0] - 2*padding_h + 1 + (h-1)*dilation_h
ow = (W-1) * stride_w + self.output_padding[1] - 2*padding_w + 1 + (w-1)*dilation_w
out_shape = (N, oc, oh, ow)
shape = [N,G,oc//G,CpG,oh,ow,Kh,Kw]
xx = x.reindex(shape, [
'i0',
f'i1*{oc//G}+i2',
'i4',
'i5'
])
ww = self.weight.reindex(shape, [
f'i1*{oc//G}+i2',
'i3',
'i6',
'i7'
])
ww.compile_options = xx.compile_options = {"G":G,"C":C}
y = (ww*xx).reindex_reduce("add", out_shape, [
'i0', # Nid
f'i1*{CpG}+i3', # Gid
f'i4*{self.stride[0]}-{self.padding[0]}+i6*{self.dilation[0]}', # Hid+Khid
f'i5*{self.stride[1]}-{self.padding[1]}+i7*{self.dilation[1]}', # Wid+KWid
])
if self.bias is not None:
b = self.bias.broadcast(y.shape, [0,2,3])
y = y + b
return y
h_out = (H-1) * stride_h + self.output_padding[0] - 2*padding_h + 1 + (h-1)*dilation_h
w_out = (W-1) * stride_w + self.output_padding[1] - 2*padding_w + 1 + (w-1)*dilation_w
out_shape = (N, o, h_out, w_out)
shape = (N, i, o, H, W, h, w)
xx = x.broadcast(shape, (2, 5, 6)) # i,h,w
ww = self.weight.broadcast(shape, (0, 3, 4)) # N,H,W
y = (ww*xx).reindex_reduce("add", out_shape, [
'i0', # N
'i2', # o
f'i3*{stride_h}-{padding_h}+i5*{dilation_h}', # Hid+Khid
f'i4*{stride_w}-{padding_w}+i6*{dilation_w}', # Wid+KWid
])
if self.bias is not None:
b = self.bias.broadcast(y.shape, [0,2,3])
y = y + b
return y
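Usage sketch for the new grouped path (groups > 1), showing shapes only; the weight layout is (in_channels, out_channels//groups, Kh, Kw). Not part of the diff:

    import jittor as jt
    from jittor import nn

    x = jt.rand(2, 6, 10, 10)
    m = nn.ConvTranspose(6, 6, kernel_size=5, stride=2, dilation=3, groups=2, bias=False)
    print(m.weight.shape)   # [6, 3, 5, 5] = (in_channels, out_channels//groups, Kh, Kw)
    y = m(x)
    print(y.shape)          # [2, 6, 31, 31] with the default padding of 0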
class ConvTranspose3d(Module):
def __init__(self, in_channels, out_channels, kernel_size, stride=1, \
@ -1653,7 +1692,7 @@ upsample = resize
def interpolate(X, size=None, scale_factor=None, mode='bilinear', align_corners=False, tf_mode=False):
if scale_factor is not None:
size = [X.shape[-2] * scale_factor, X.shape[-1] * scale_factor]
size = [int(X.shape[-2] * scale_factor), int(X.shape[-1] * scale_factor)]
if isinstance(size, int):
size = (size, size)
if scale_factor is not None and scale_factor > 1:

View File

@ -597,10 +597,16 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
if (device_sync && use_cuda) {
last_is_cuda = false;
sync_times++;
try {
// CHECK(EventQueue::OK == event_queue.run_sync([]() {
checkCudaErrors(cudaDeviceSynchronize());
// }));
// TODO: run_sync causes a hang; temporary fix
} catch (const std::exception& e) {
// log memory info
display_memory_info(__FILELINE__, false, true);
throw e;
}
event_queue.flush();
}
LOGvv << "cudaDeviceSynchronize times:" << sync_times << "/" <<queue.size() << "device_sync:" << device_sync;

View File

@ -28,7 +28,7 @@ EXTERN_LIB string_view_map<FusedOpContext*> jit_fused_ops;
struct FusedOp final : Op {
vector<Op*> ops;
// edges: [[i,j,k,l], ...] represents opi.output(j) == opk.input(i)
// edges: [[i,j,k,l], ...] represents opi.output(j) == opk.input(l)
vector<std::tuple<uint,uint,uint,uint>> edges;
vector<VarInfo> vars;
loop_options_t loop_options_merged, loop_options_tuned;

View File

@ -28,6 +28,8 @@ namespace jittor {
static int lock_fd = -1;
int _has_lock = 0;
DEFINE_FLAG(bool, disable_lock, 0, "Disable file lock");
void set_lock_path(string path) {
lock_fd = open(path.c_str(), O_RDWR);
ASSERT(lock_fd >= 0);
@ -35,6 +37,7 @@ void set_lock_path(string path) {
}
void lock() {
if (disable_lock) return;
ASSERT(lock_fd >= 0);
#ifdef _WIN32
OVERLAPPED offset = {0, 0, 0, 0, NULL};
@ -54,6 +57,7 @@ void lock() {
}
void unlock() {
if (disable_lock) return;
ASSERT(lock_fd >= 0);
#ifdef _WIN32
OVERLAPPED offset = {0, 0, 0, 0, NULL};

View File

@ -44,8 +44,8 @@ Date: February 1996
x = std::copysign(num, y) / dem;
}
/* Two steps of Newton-Raphson correction */
x = x - (std::erf(x) - y) / ((static_cast<T>(2.0)/static_cast<T>(std::sqrt(M_PI)))*std::exp(-x*x));
x = x - (std::erf(x) - y) / ((static_cast<T>(2.0)/static_cast<T>(std::sqrt(M_PI)))*std::exp(-x*x));
x = x - (std::erf(x) - y) / ((static_cast<T>(2.0)/static_cast<T>(std::sqrt(3.14159265358979323846)))*std::exp(-x*x));
x = x - (std::erf(x) - y) / ((static_cast<T>(2.0)/static_cast<T>(std::sqrt(3.14159265358979323846)))*std::exp(-x*x));
return x;
}
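For reference, the two correction lines implement Newton-Raphson on f(x) = erf(x) - y with erf'(x) = (2/sqrt(pi)) * exp(-x^2); in LaTeX:

    x_{n+1} = x_n - \frac{\operatorname{erf}(x_n) - y}{\frac{2}{\sqrt{\pi}}\, e^{-x_n^{2}}}

The diff only replaces M_PI with an explicit constant, presumably because M_PI is not guaranteed to be defined on every compiler (e.g. MSVC without _USE_MATH_DEFINES).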

View File

@ -42,7 +42,7 @@ struct ArgReduceOp : Op {
>>> jt.arg_reduce(x, 'max', dim=1, keepdims=False)
[jt.Var([2 1], dtype=int32), jt.Var([5 7], dtype=int32)]
>>> jt.arg_reduce(x, 'min', dim=1, keepdims=False)
[jt.Var([1 2], dtype=int32), jt.Var([5 7], dtype=int32)]
[jt.Var([1 2], dtype=int32), jt.Var([2 1], dtype=int32)]
*/
// @attrs(multiple_outputs)
ArgReduceOp(Var* x, NanoString op, int dim, bool keepdims);

View File

@ -391,9 +391,9 @@ VarPtr GetitemOp::grad(Var* out, Var* dout, Var* v, int v_index) {
// need to analyze the overlap attr of var slices
for (int i=0; i<vs.n; i++)
if (vs.slices[i].is_var()) {
return make_setitem(zeros, VarSlices(vs), dout, ns_add);
return make_setitem(zeros, VarSlices(vs, true), dout, ns_add);
}
return make_setitem(zeros, VarSlices(vs), dout, ns_void);
return make_setitem(zeros, VarSlices(vs, true), dout, ns_void);
}
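The new boolean argument (negtive_set_none in the VarSlices copy constructor further down) adjusts negative-step slices when they are copied for the gradient ops; this is what makes gradients through reversed slices work (see test_flip_grad below). A minimal reproduction, not part of the diff:

    import numpy as np
    import jittor as jt

    a = jt.rand(10)
    b = a[::-1]                      # negative-step getitem
    g, = jt.grad(b.sum(), [a])
    np.testing.assert_allclose(g.data, np.ones(10), atol=1e-6)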
void GetitemOp::jit_prepare(JK& jk) {

View File

@ -128,41 +128,41 @@ VarPtr SetitemOp::grad(Var* out, Var* dout, Var* v, int v_index) {
if (v_index == 0) {
float32 number = 0;
VarPtr zero = make_array(&number, 1, ns_float32);
return make_setitem(dout, VarSlices(vs), zero, ns_void);
return make_setitem(dout, VarSlices(vs, true), zero, ns_void);
} else {
return make_getitem(dout, VarSlices(vs));
return make_getitem(dout, VarSlices(vs, true));
}
}
if (op == ns_add) {
if (v_index == 0) {
return dout;
} else {
return make_getitem(dout, VarSlices(vs));
return make_getitem(dout, VarSlices(vs, true));
}
}
if (op == ns_subtract) {
if (v_index == 0) {
return dout;
} else {
return make_unary(make_getitem(dout, VarSlices(vs)), ns_negative);
return make_unary(make_getitem(dout, VarSlices(vs, true)), ns_negative);
}
}
if (op == ns_multiply) {
if (v_index == 0) {
return make_setitem(dout, VarSlices(vs), input(1), ns_multiply);
return make_setitem(dout, VarSlices(vs, true), input(1), ns_multiply);
} else {
return make_binary(
make_getitem(inputs().front(), VarSlices(vs)),
make_getitem(dout, VarSlices(vs)), ns_multiply);
make_getitem(inputs().front(), VarSlices(vs, true)),
make_getitem(dout, VarSlices(vs, true)), ns_multiply);
}
}
if (op == ns_divide) {
if (v_index == 0) {
return make_setitem(dout, VarSlices(vs), input(1), ns_divide);
return make_setitem(dout, VarSlices(vs, true), input(1), ns_divide);
} else {
// dy = -dz*x / y^2
auto dout2 = make_getitem(dout, VarSlices(vs));
auto x = make_getitem(inputs().front(), VarSlices(vs));
auto dout2 = make_getitem(dout, VarSlices(vs, true));
auto x = make_getitem(inputs().front(), VarSlices(vs, true));
auto y = v;
auto ndz = make_unary(dout2, ns_negative);
auto ndzx = make_binary(ndz, x, ns_multiply);

View File

@ -25,7 +25,7 @@ struct WhereOp : Op {
Example::
jt.where([[0,0,1],[1,0,0]])
# return ( [0,2], [1,0] )
# return [jt.Var([0 1], dtype=int32), jt.Var([2 0], dtype=int32)]
*/
// @attrs(multiple_outputs)
WhereOp(Var* cond, NanoString dtype=ns_int32);

View File

@ -941,16 +941,22 @@ void KernelIR::remove_intermediate(const unordered_set<string>& names) {
if (i >= code.size()) break;
uint j=i+1;
while (j<code.size() && isvar(code[j])) j++;
uint k=j-1;
if (j<code.size() && code[j]=='[') {
// find xxx[...]
while (j<code.size() && code[j]!=']') j++;
// i k j
int prefix = 0;
while (j<code.size()) {
if (code[j] == ']') prefix--;
if (code[j] == '[') prefix++;
if (prefix == 0) break;
j++;
}
CHECK(prefix==0);
j++;
}
uint k=j-1;
if (code[k] == ']') {
if (code[j-1] == ']') {
// xxxp[...] -> xxxd
while (k>=i && code[k]!='[') k--;
k--;
if (k>=i && code[k]=='p' && names.count(code.substr(i,k-i))) {
code[k] = 'd';
for (uint l=k+1; l<j; l++) code[l] = ' ';
@ -961,6 +967,8 @@ void KernelIR::remove_intermediate(const unordered_set<string>& names) {
j += 5;
}
}
i = k+1;
continue;
} else
if (code[k] == 'p' && string(s)=="lvalue" && type=="define") {
if (names.count(code.substr(i,k-i))) {
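The change above replaces a scan for the first ']' with depth counting, so nested subscripts such as a[b[i+1]] are matched correctly before the xxxp[...] -> xxxd rewrite. The same matching logic as a small Python sketch (illustration only):

    def find_matching_bracket(code, j):
        # code[j] == '['; return the index of its matching ']' using depth counting,
        # mirroring the prefix counter in the C++ above
        depth = 0
        while j < len(code):
            if code[j] == '[':
                depth += 1
            elif code[j] == ']':
                depth -= 1
                if depth == 0:
                    return j
            j += 1
        raise ValueError("unbalanced brackets")

    assert find_matching_bracket("xp[a[i+1]]=0", 2) == 9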

View File

@ -10,6 +10,7 @@
#include "opt/pass/loop_var_analyze_pass.h"
#include "ops/reduce_op.h"
#include "ops/broadcast_to_op.h"
#include "ops/reindex_op.h"
namespace jittor {
@ -266,10 +267,34 @@ void LoopVarAnalyzePass::run() {
LOGvvv << "replace_vars" << replace_vars;
ir->replace(replace_vars);
for (int i=0; i<this->op->ops.size(); i++) {
auto op = this->op->ops[i];
if (op->type() == OpType::element &&
op->name() == string("array") &&
op->outputs().front()->num == 1) {
ir->replace({{"op"+S(i)+"_outputshape0", "1"}});
}
}
LOGvvvv << "KernelIR after replace\n" >> ir->to_string(0, true);
// move define
ir->move_loop_back();
LOGvvvv << "KernelIR after move_loop_back\n" >> ir->to_string(0, true);
// check the arguments of reindex ops
for (Op* op : this->op->ops) {
string op_name = op->name();
if (op_name == "reindex" || op_name == "reindex_reduce") {
ReindexOp* rop = (ReindexOp*)op;
vector<string> ss = rop->indexes;
for (auto& s : rop->overflow_conditions) ss.push_back(s);
for (auto& s : ss) {
if (s.find("//") != string::npos) {
LOGf << "Arguments of reindex op should not contain the '//' operation, please replace 'a//b' with 'int(a/b)'. Arguments of reindex op: " << s << ss;
}
}
}
}
}
} // jittor
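The new check rejects Python-style floor division inside reindex index strings; the index strings end up in the generated C++ kernel, where '//' would start a comment, so integer division should be written with int() instead, as the error message suggests. A hypothetical illustration:

    import jittor as jt

    a = jt.zeros([10, 10])
    # ["i0 // 2", "i1"] would be rejected; write the index as a C expression:
    b = a.reindex([8, 8], ["int(i0 / 2)", "i1"])
    b.sync()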

View File

@ -197,12 +197,14 @@ void parallel_compile_all_ops(vector<int>& queue, vector<int>& range, FusedOp& f
#endif
static std::atomic<int> ai;
static volatile int has_error;
static string error_msg;
static vector<vector<std::tuple<int,int,void*,string>>> op_entrys(thread_num);
// <int,int,void*,string> represents: task id, is_fused_op, entry or context, new_jit_key
threads.create_threads(thread_num);
static std::mutex entry_lock;
ai = 0;
has_error = 0;
error_msg = "";
int n = op_needs_compile.size();
LOGvv << "Total number of op needs compile" << op_needs_compile.size()
<< "thread_num:" << thread_num;
@ -267,14 +269,16 @@ void parallel_compile_all_ops(vector<int>& queue, vector<int>& range, FusedOp& f
// log jit_key and file location
op->do_prepare(jkl);
string jit_src_path = Op::get_filename_from_jit_key(jkl.to_cstring(), ".cc");
LOGe << "[Error] source file location:" << jit_src_path;
std::stringstream ss;
ss << "[Error] source file location:" << jit_src_path << '\n';
if (is_fused_op) {
LOGe << "Compile fused operator(" >> i >> '/' >> n >> ")"
<< "failed:" << ((FusedOp*)op)->ops << "\n\nReason: " >> e.what();
ss << "Compile fused operator(" << i << '/' << n << ")"
<< "failed:" << ((FusedOp*)op)->ops << "\n\nReason: " << e.what() << '\n';
} else
LOGe << "Compile operator(" >> i >> '/' >> n >> ")"
<< "failed:" << op << "\n\nReason: " >> e.what();
ss << "Compile operator(" << i << '/' << n << ")"
<< "failed:" << op << "\n\nReason: " << e.what() << '\n';
error_msg = ss.str();
has_error = 1;
break;
}
@ -322,7 +326,7 @@ void parallel_compile_all_ops(vector<int>& queue, vector<int>& range, FusedOp& f
if (has_error) {
threads.wait_all();
LOGf << "Error happend during compilation, see error above.";
LOGf << "Error happend during compilation:\n" << error_msg;
}
// fill all op entry

View File

@ -183,7 +183,7 @@ void process(string src, vector<string>& input_names, string& cmd) {
}
}
if (l-k>2 && src[k] == 'J' && src[k+1] == 'T' && j-i==6 && src.substr(i,j-i) == "#ifdef") {
auto inc = src.substr(k, l-k);
auto inc = strip(src.substr(k, l-k));
auto env = getenv(inc.c_str());
if (env && string(env)!="0") {
auto senv = string(env);

View File

@ -54,7 +54,7 @@ bool check_async_executor_error(const std::exception& e, std::ostream& os) {
SEH_HOOK;
void init_subprocess() {
#ifdef __linux__
#if defined(__linux__) && defined(PR_SET_PDEATHSIG)
prctl(PR_SET_PDEATHSIG, SIGKILL);
#endif
}

View File

@ -243,7 +243,7 @@ void segfault_sigaction(int signal, siginfo_t *si, void *arg) {
}
if (signal == SIGCHLD) {
if (si->si_code != CLD_EXITED && si->si_status != SIGTERM && _pid == getpid()) {
LOGe << "Caught SIGCHLD"
LOGe << "Caught SIGCHLD. Maybe out of memory, please reduce your worker size."
<< "si_errno:" << si->si_errno
<< "si_code:" << si->si_code
<< "si_status:" << si->si_status

View File

@ -41,9 +41,9 @@ vector<string> split(const string& s, const string& sep, int max_split) {
string strip(const string& s) {
int i=0;
while (i<s.size() && (s[i]==' ' || s[i]=='\t' || s[i]=='\n')) i++;
while (i<s.size() && (s[i]==' ' || s[i]=='\t' || s[i]=='\n' || s[i]=='\r')) i++;
int j = s.size();
while (j>i && (s[j]==' ' || s[j]=='\t' || s[j]=='\n')) j--;
while (j>i && (s[j-1]==' ' || s[j-1]=='\t' || s[j-1]=='\n' || s[j-1]=='\r')) j--;
return s.substr(i,j-i);
}

View File

@ -117,7 +117,7 @@ void setter_gdb_attach(int v) {
exit(1);
} else {
// allow children ptrace parent
#ifdef __linux__
#if defined(__linux__) && defined(PR_SET_PTRACER)
prctl(PR_SET_PTRACER, child_pid, 0, 0, 0);
#endif
// sleep 5s, wait gdb attach
@ -177,7 +177,7 @@ void print_trace() {
exit(0);
} else {
// allow children ptrace parent
#ifdef __linux__
#if defined(__linux__) && defined(PR_SET_PTRACER)
prctl(PR_SET_PTRACER, child_pid, 0, 0, 0);
#endif
waitpid(child_pid,NULL,0);

View File

@ -240,12 +240,12 @@ struct VarHolder {
// @pyjt(__set__data)
inline void set_data(ArrayArgs&& array) {
sync(true);
ASSERT(array.dtype.dsize() == var->dtype().dsize()
CHECK(array.dtype.dsize() == var->dtype().dsize()
&& array.dtype.is_int() == var->dtype().is_int());
int64 size = array.dtype.dsize();
for (int i=0; i<array.shape.size(); i++)
size *= array.shape[i];
ASSERT(size==var->size);
CHECK(size==var->size);
#ifdef HAS_CUDA
migrate_to_cpu(var, exe.allocator);
#endif

View File

@ -47,9 +47,16 @@ struct VarSlices {
inline VarSlices(VarSlices&& other) : slices(other.slices), n(other.n) {
other.slices = nullptr;
}
inline VarSlices(const VarSlices& other) : slices(new VarSlice[other.n]), n(other.n) {
for (int i=0; i<n; i++)
inline VarSlices(const VarSlices& other, bool negtive_set_none=false) : slices(new VarSlice[other.n]), n(other.n) {
for (int i=0; i<n; i++) {
slices[i] = other.slices[i];
if (negtive_set_none &&
slices[i].is_slice() &&
slices[i].slice.step < 0 &&
slices[i].slice.stop < 0) {
slices[i].slice.mask |= 2;
}
}
}
inline void operator=(VarSlices&& other) {
if (slices) delete[] slices;

View File

@ -29,13 +29,13 @@ class TestConvTranspose(unittest.TestCase):
self.test()
def test(self):
def check(data_shape, weights_shape, stride=1, dilation=1):
def check(data_shape, weights_shape, stride=1, dilation=1, groups=1):
N,C,H,W = data_shape
i,o,h,w = weights_shape
img = np.random.rand(N,C,H,W).astype("float32")
weights = np.random.rand(i,o,h,w).astype("float32")
m1 = jt.nn.ConvTranspose(i,o,h, stride=stride, dilation=dilation, bias=False)
m2 = torch.nn.ConvTranspose2d(i,o,h, stride=stride, dilation=dilation, bias=False)
weights = np.random.rand(i,o//groups,h,w).astype("float32")
m1 = jt.nn.ConvTranspose(i,o,h, stride=stride, dilation=dilation, bias=False, groups=groups)
m2 = torch.nn.ConvTranspose2d(i,o,h, stride=stride, dilation=dilation, bias=False, groups=groups)
m1.weight.data = weights
m2.weight.data = torch.Tensor(weights)
x = jt.array(img)
@ -61,6 +61,7 @@ class TestConvTranspose(unittest.TestCase):
check((4, 5, 100, 100), (5, 6, 5, 5), 1, 2)
check((4, 5, 100, 100), (5, 6, 5, 5), 2, 2)
check((4, 5, 100, 100), (5, 6, 5, 5), 2, 3)
check((4, 6, 100, 100), (6, 6, 5, 5), 2, 3, 2)
def test_function(self):
def check(data_shape, weights_shape, stride=1, dilation=1):

View File

@ -99,5 +99,12 @@ class TestCore(unittest.TestCase):
net.conv1.save(pkl_name)
net.conv1.load(pkl_name)
def test_module(self):
a = jt.Module()
a.__setattr__("x", 1)
assert a.__getattr__("x") == 1
a.y = 2
assert a.y == 2
if __name__ == "__main__":
unittest.main()

View File

@ -236,6 +236,11 @@ class TestFusedOp(unittest.TestCase):
check(64, 60, 64, 1, 0, 42)
check(64, 60, 64, 0, 0, 30) # TODO: why slower?
def test_array_reindex(self):
a = jt.array([1])
b = a.reindex([3], ['i0-1'])
np.testing.assert_allclose(b.data, [0,1,0])
@unittest.skipIf(skip_slow_test, "Skip slow test")
def test_profile_fused_op_restride(self):

View File

@ -250,5 +250,9 @@ class TestOther(unittest.TestCase):
assert x[3]['a'] == [1,2,3]
assert (x[3]['b'] == np.array([1,2,3])).all()
def test_arctan2(self):
a = jt.arctan2(jt.array([1,1.0,0]), jt.array([1,0.0,-1]))
np.testing.assert_allclose(a.data, [0.7853982,1.5707964,3.1415927])
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,66 @@
# ***************************************************************
# Copyright (c) 2021 Jittor. All Rights Reserved.
# Maintainers:
# Guowei Yang <471184555@qq.com>
# Dun Liang <randonlang@gmail.com>.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import os, sys
import jittor as jt
import numpy as np
from jittor import nn
from jittor import dataset
mpi = jt.compile_extern.mpi
class Model(nn.Module):
def __init__(self, input_size):
self.linear1 = nn.Linear(input_size, 10)
self.relu1 = nn.ReLU()
self.linear2 = nn.Linear(10, 10)
def execute(self, x):
x = self.linear1(x)
x = self.relu1(x)
return self.linear2(x)
def fork_with_mpi(num_procs=4):
import sys
if jt.in_mpi:
# mute the output of the other processes
if jt.rank != 0:
sys.stdout = open("/dev/null", "w")
return
else:
print(sys.argv)
cmd = " ".join(["mpirun", "-np", str(num_procs), sys.executable] + sys.argv)
print("[RUN CMD]:", cmd)
os.system(cmd)
exit(0)
def main():
mnist = dataset.MNIST()
model = Model(mnist[0][0].size)
sgd = jt.optim.SGD(model.parameters(), 1e-3)
fork_with_mpi()
for data, label in mnist:
pred = model(data.reshape(data.shape[0], -1))
# print(data.shape, label.shape, pred.shape)
loss = nn.cross_entropy_loss(pred, label)
sgd.step(loss)
print(jt.rank, mnist.epoch_id, mnist.batch_id, loss)
# break
# class TestMpiInPy(unittest.TestCase):
# def test(self):
# main()
if __name__ == "__main__":
# unittest.main()
main()

View File

@ -280,6 +280,21 @@ class TestReindexOp(unittest.TestCase):
def test_doc(self):
assert "Reindex Operator" in jt.reindex.__doc__
def test_reindex_fuse_error(self):
a = jt.zeros([10,10])
b = jt.array([1])
c = a.reindex([8,8], ["@e0(0)", "@e1(0,i0 / @e0(0))"], extras=[b, jt.ones([10,10])])
c.sync()
# print(c)
def test_reindex_wrong_op(self):
a = jt.zeros([10,10])
b = jt.array([1])
c = a.reindex([8,8], ["@e0(0) // 1", "@e0(0)"], extras=[b, b])
expect_error(lambda: c.sync())
@unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found")

View File

@ -130,6 +130,12 @@ class TestResizeAndCrop(unittest.TestCase):
arr = np.random.randn(1,1,2,2)
check_equal(arr, jnn.Resize((4,4)), tnn.Upsample(scale_factor=2))
# check_equal(arr, jnn.Upsample(scale_factor=0.5), tnn.Upsample(scale_factor=0.5))
def test_interpolate(self):
a = jt.rand(1,3,64,64)
b = jt.nn.interpolate(a, scale_factor=0.5)
b.sync()
assert b.shape == (1,3,32,32)
if __name__ == "__main__":

View File

@ -327,6 +327,19 @@ class TestRNN(unittest.TestCase):
np.testing.assert_allclose(th.detach().cpu().numpy(), jh.data, rtol=1e-03, atol=1e-06)
np.testing.assert_allclose(tc.detach().cpu().numpy(), jc.data, rtol=1e-03, atol=1e-06)
def test_twobilinear_lstm(self):
x = jt.rand(5, 4, 10)
rnn1 = nn.LSTM(10, 20, bidirectional=True)
out1, _ = rnn1(x)
rnn2 = nn.LSTM(40, 20, bidirectional=True)
out2, _ = rnn2(out1)
target = jt.zeros_like(out2)
loss = nn.mse_loss(out2, target)
from jittor import optim
optimizer = optim.RMSprop(rnn1.parameters())
optimizer.step(loss)
@skipIf(not jt.has_cuda, "No Cuda found")
@jt.flag_scope(use_cuda=1)
def test_cudnn_rnn(self):

View File

@ -225,8 +225,14 @@ class TestSetitem(unittest.TestCase):
b = a[...,:,None,:2]
assert b.shape == [2,4,1,2]
np.testing.assert_allclose(b.data, a.data[...,:,None,:2])
def test_flip_grad(self):
a = jt.rand(10)
b = a[::-1]
c = b[::-1]
d = c.sum()
jt.grad(d, [a])
if __name__ == "__main__":
unittest.main()

View File

@ -408,6 +408,9 @@ def to_tensor(pic):
img = Image.open(...)
img_ = transform.to_tensor(img)
"""
if isinstance(pic, jt.Var):
return pic
if isinstance(pic, tuple):
# try convert ten crop tuple
pic = ( to_tensor(p) for p in pic )

View File

@ -14,6 +14,8 @@ If conda is used, please install with command:
import os
from jittor_utils import cache_path, LOG
disable_lock = os.environ.get("disable_lock", "0") == "1"
class Lock:
def __init__(self, filename):
self.handle = open(filename, 'w')
@ -21,6 +23,8 @@ class Lock:
self.is_locked = False
def lock(self):
if disable_lock:
return
if fcntl:
fcntl.flock(self.handle, fcntl.LOCK_EX)
else:
@ -30,6 +34,8 @@ class Lock:
LOG.vv(f'LOCK PID: {os.getpid()}')
def unlock(self):
if disable_lock:
return
if fcntl:
fcntl.flock(self.handle, fcntl.LOCK_UN)
else: