version 1.2

This commit is contained in:
Dun Liang 2020-10-10 16:38:21 +08:00
parent d533f3960b
commit c792b23f48
48 changed files with 3081 additions and 130 deletions

View File

@ -49,6 +49,8 @@ void CubArgReduceOp::infer_shape() {
if (keepdims) {
shape.push_back(1);
}
if (shape.size() == 0)
shape.push_back(1);
y->set_shape(shape);
y_key->set_shape(shape);
}
@ -104,4 +106,4 @@ void CubArgReduceOp::jit_run() {
#endif // JIT_cuda
#endif // JIT
} // jittor
} // jittor

View File

@ -42,8 +42,8 @@ static inline void set_shape(Var* x, const char* f, const string& format, int a,
shape[0], shape[1], shape[2], shape[3]));
}
CudnnConvBackwardWOp::CudnnConvBackwardWOp(Var* x, Var* dy, int kernel_size, int stride, int padding, int dilation, int groups, string xformat, string wformat, string yformat)
: x(x), dy(dy), kernel_size(kernel_size), stride(stride), padding(padding), dilation(dilation), groups(groups),
CudnnConvBackwardWOp::CudnnConvBackwardWOp(Var* x, Var* dy, int kh, int kw, int stride, int padding, int dilation, int groups, string xformat, string wformat, string yformat)
: x(x), dy(dy), kh(kh), kw(kw), stride(stride), padding(padding), dilation(dilation), groups(groups),
xformat(move(xformat)), wformat(move(wformat)), yformat(move(yformat)) {
flags.set(NodeFlags::_cuda, 1);
flags.set(NodeFlags::_cpu, 0);
@ -57,8 +57,8 @@ void CudnnConvBackwardWOp::infer_shape() {
get_shape(x, "abcd", xformat, xn, xc, xh, xw);
get_shape(dy, "abcd", yformat, yn, yc, yh, yw);
wco = yc, wci = xc / groups;
wh = kernel_size;
ww = kernel_size;
wh = kh;
ww = kw;
set_shape(dw, "oihw", wformat, wco, wci, wh, ww);
}

View File

@ -13,14 +13,14 @@ namespace jittor {
struct CudnnConvBackwardWOp : Op {
Var* x, * dy, * dw;
int kernel_size, stride, padding, dilation, groups;
int kh, kw, stride, padding, dilation, groups;
string xformat, wformat, yformat;
CudnnConvBackwardWOp(Var* x, Var* y, int kernel_size, int stride, int padding, int dilation, int groups=1, string xformat="abcd", string wformat="oihw", string yformat="abcd");
CudnnConvBackwardWOp(Var* x, Var* y, int kh, int kw, int stride, int padding, int dilation, int groups=1, string xformat="abcd", string wformat="oihw", string yformat="abcd");
const char* name() const override { return "cudnn_conv_backward_w"; }
void infer_shape() override;
DECLARE_jit_run;
};
} // jittor
} // jittor

View File

@ -45,8 +45,8 @@ static inline void set_shape(Var* x, const char* f, const string& format, int a,
shape[0], shape[1], shape[2], shape[3]));
}
MklConvBackwardWOp::MklConvBackwardWOp(Var* x, Var* dy, int kernel_size, int stride, int padding, int dilation, int groups, string xformat, string wformat, string yformat)
: x(x), dy(dy), kernel_size(kernel_size), stride(stride), padding(padding), dilation(dilation), groups(groups),
MklConvBackwardWOp::MklConvBackwardWOp(Var* x, Var* dy, int kh, int kw, int stride, int padding, int dilation, int groups, string xformat, string wformat, string yformat)
: x(x), dy(dy), kh(kh), kw(kw), stride(stride), padding(padding), dilation(dilation), groups(groups),
xformat(move(xformat)), wformat(move(wformat)), yformat(move(yformat)) {
dw = create_output(nullptr, dtype_infer(dy->ns, x->ns));
}
@ -58,8 +58,8 @@ void MklConvBackwardWOp::infer_shape() {
get_shape(x, "abcd", xformat, xn, xc, xh, xw);
get_shape(dy, "abcd", yformat, yn, yc, yh, yw);
wco = yc, wci = xc / groups;
wh = kernel_size;
ww = kernel_size;
wh = kh;
ww = kw;
set_shape(dw, "oihw", wformat, wco, wci, wh, ww);
}
@ -97,7 +97,8 @@ void MklConvBackwardWOp::jit_run() {
int height = x->shape[findc("@XFORMAT",'c')];
int width = x->shape[findc("@XFORMAT",'d')];
int ch_out = dw->shape[findc("@WFORMAT",'o')];
int kernel_size = dw->shape[findc("@WFORMAT",'h')];
int kh = dw->shape[findc("@WFORMAT",'h')];
int kw = dw->shape[findc("@WFORMAT",'w')];
auto* __restrict__ net_src = x->ptr<Txd>();
auto* __restrict__ net_diff_dst = dy->ptr<Tyd>();
@ -114,9 +115,9 @@ void MklConvBackwardWOp::jit_run() {
memory::dims conv_src_tz = {batch, ch_in, height, width};
memory::dims conv_weights_tz = groups>1
? memory::dims{groups, ch_out/groups, ch_in/groups, kernel_size, kernel_size}
: memory::dims{ch_out, ch_in, kernel_size, kernel_size};
memory::dims conv_dst_tz = {batch, ch_out, (height+padding*2-kernel_size*dilation+dilation-1)/stride+1, (width+padding*2-kernel_size*dilation+dilation-1)/stride+1};
? memory::dims{groups, ch_out/groups, ch_in/groups, kh, kw}
: memory::dims{ch_out, ch_in, kh, kw};
memory::dims conv_dst_tz = {batch, ch_out, (height+padding*2-kh*dilation+dilation-1)/stride+1, (width+padding*2-kw*dilation+dilation-1)/stride+1};
memory::dims conv_strides = {stride, stride};
memory::dims conv_padding = {padding, padding};
memory::dims conv_dilation = {dilation-1, dilation-1};

View File

@ -13,14 +13,14 @@ namespace jittor {
struct MklConvBackwardWOp : Op {
Var* x, * dy, * dw;
int kernel_size, stride, padding, dilation, groups;
int kh, kw, stride, padding, dilation, groups;
string xformat, wformat, yformat;
MklConvBackwardWOp(Var* x, Var* y, int kernel_size, int stride, int padding, int dilation, int groups=1, string xformat="abcd", string wformat="oihw", string yformat="abcd");
MklConvBackwardWOp(Var* x, Var* y, int kh, int kw, int stride, int padding, int dilation, int groups=1, string xformat="abcd", string wformat="oihw", string yformat="abcd");
const char* name() const override { return "mkl_conv_backward_w"; }
void infer_shape() override;
DECLARE_jit_run;
};
} // jittor
} // jittor

View File

@ -7,7 +7,7 @@
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.1.7.20'
__version__ = '1.2.0.0'
from . import lock
with lock.lock_scope():
from . import compiler
@ -233,11 +233,22 @@ def ones(shape, dtype="float32"):
shape = (shape,)
return unary(1, dtype).broadcast(shape)
def ones_like(x):
return ones(x.shape,x.dtype)
def zeros(shape, dtype="float32"):
if not isinstance(shape, (NanoVector, Sequence)):
shape = (shape,)
return unary(0, dtype).broadcast(shape)
def full(shape,val,dtype="float32"):
if not isinstance(shape, (NanoVector, Sequence)):
shape = (shape,)
return unary(val, dtype).broadcast(shape)
def zeros_like(x):
return zeros(x.shape,x.dtype)
flags = core.flags()
def std(x):
@ -311,9 +322,17 @@ def squeeze(x, dim):
return x.reshape(shape[:dim] + shape[dim+1:])
Var.squeeze = squeeze
def clamp(x, min_v, max_v):
assert min_v <= max_v
return x.maximum(min_v).minimum(max_v)
def clamp(x, min_v=None, max_v=None):
if x.shape[0]==0:
return x
if min_v is not None and max_v is not None:
assert min_v <= max_v
if min_v is not None:
x = x.maximum(min_v)
if max_v is not None:
x = x.minimum(max_v)
return x
Var.clamp = clamp
def type_as(a, b):
@ -574,6 +593,8 @@ class Module:
else:
if hasattr(v, k):
v = getattr(v, k)
assert isinstance(v, (Module, Var)), \
f"expect a jittor Module or Var, but got <{v.__class__.__name__}>, key: {key}"
else:
end = 1
break
@ -582,6 +603,8 @@ class Module:
n_failed += 1
LOG.w(f'load parameter {key} failed ...')
else:
assert isinstance(v, Var), \
f"expect a jittor Var, but got <{v.__class__.__name__}>, key: {key}"
LOG.v(f'load parameter {key} success ...')
if isinstance(params[key], np.ndarray) or isinstance(params[key], list):
v.update(array(params[key]))
@ -872,4 +895,4 @@ from .nn import matmul
from . import contrib
from . import numpy2cupy
from .contrib import concat
from .misc import *
from .misc import *

View File

@ -241,7 +241,7 @@ def gen_jit_op_maker(op_headers, export=False, extra_flags=""):
if "multiple_outputs" not in attrs:
jit_cc_src.append(f"""
VarPtr make_{cc_func_name}({", ".join(cc_make_args)}) {{
Op* _op = new {op_name}({", ".join(op_make_args)});
auto _op = new {op_name}({", ".join(op_make_args)});
if (_op->outputs_holder.size() != 1) {{
delete _op;
LOGf << "Wrong output size of" << \"{op_name}\";
@ -261,7 +261,7 @@ def gen_jit_op_maker(op_headers, export=False, extra_flags=""):
else:
jit_cc_src.append(f"""
vector<VarPtr> make_{cc_func_name}({", ".join(cc_make_args)}) {{
Op* _op = new {op_name}({", ".join(op_make_args)});
auto _op = new {op_name}({", ".join(op_make_args)});
if (_op->flags.get(NodeFlags::_forwarded)) {{
vector<VarPtr> outputs = move(_op->outputs_holder);
delete _op;
@ -408,6 +408,15 @@ def gen_jit_op_maker(op_headers, export=False, extra_flags=""):
arg_type.replace("Var", "VarHolder")+' '+arg)
new_args.append(arg)
more_src.append(f"_op->add_inputs({arg});")
elif arg_type.startswith("VarSlices"):
new_args_def.append(arg_def)
new_args.append(arg)
more_src.append(f"""
vector<Var*> svars;
for (int i=0; i<_op->vs.n; i++)
if (_op->vs.slices[i].is_var())
svars.push_back(_op->vs.slices[i].var);
_op->add_inputs(svars);""")
else:
new_args_def.append(arg_def)
new_args.append(arg)

View File

@ -42,7 +42,7 @@ Example::
indexes[dim] = f"i{dim}-{cdim}"
b = a.reindex(shape, indexes)
# ugly fix for preventing large fused op
if len(arr)>=10:
if len(arr)>=100:
b.stop_fuse()
if s is None:
s = b
@ -99,6 +99,20 @@ def slice_var_index(x, slices):
cnt_list = 0
extras_idx = []
extras = []
has_ellipse = 0
ellipse_index = 0
for s,i in zip(slices,range(len(slices))):
if isinstance(s,type(...)):
has_ellipse+=1
ellipse_index = i
if has_ellipse>1:
raise Exception(f"There are more than one ...")
elif has_ellipse==1:
slices = list(slices)
del slices[ellipse_index]
while len(slices)<len(shape):
slices.insert(ellipse_index,slice(None))
for i in range(len(shape)):
if i>=len(slices):
s = slice(None)
@ -119,6 +133,7 @@ def slice_var_index(x, slices):
step = 1 if s.step is None else s.step
if start<0: start += sp
if stop<0: stop += sp
if stop>sp+1: stop = sp
out_shape.append(1+int(max(0, (stop-start-1)//step)))
out_index.append(f"{start}+i{j}*{step}")
elif isinstance(s, jt.Var):
@ -160,3 +175,57 @@ def setitem(x, slices, value):
jt.Var.__getitem__ = jt.Var.slice_var = slice_var
jt.Var.__setitem__ = setitem
# PATCH
def getitem(x, slices):
if isinstance(slices, jt.Var) and slices.dtype == "bool":
return getitem(x, slices.where())
if isinstance(slices, list):
slices = tuple(slices)
return x.getitem(slices)
def setitem(x, slices, value):
    '''Write ``value`` into ``x`` at ``slices`` (backs ``Var.__setitem__``).

    Args:
        x: the var that is updated through ``x.assign``.
        slices: an index expression, or a var whose dtype is "bool" used as
            an element-wise mask.
        value: the data to store.

    Returns:
        The result of ``x.assign(...)``.
    '''
    if isinstance(slices, jt.Var) and slices.dtype == "bool":
        mask = jt.broadcast(slices, x)
        value = jt.broadcast(value, x)
        # Take `value` where the mask is set and keep the old content of `x`
        # elsewhere; the result has to be assigned back, otherwise
        # `x[mask] = value` would silently do nothing.
        return x.assign(mask.ternary(value, x))
    if isinstance(slices, list):
        slices = tuple(slices)
    return x.assign(x.setitem(slices, value))
jt.Var.__getitem__ = jt.Var.slice_var = getitem
jt.Var.__setitem__ = setitem
def concat(arr, dim):
    '''Concat Operator can concat a list of jt Var at a specific dimension.

    * [in] x: input var list for concat
    * [in] dim: concat which dim (negative values count from the end)
    * [out] out: concat result

    Raises ValueError when the input list is empty.

    Example::
        jt.concat([jt.array([[1],[2]]), jt.array([[2],[2]])], dim=1)
        # return [[1,2],[2,2]]
    '''
    # TODO: low performance when concat lots of vars
    if len(arr) == 0:
        raise ValueError("concat expects a non-empty list of vars")
    if dim < 0: dim += len(arr[0].shape)
    total_dim = 0
    for a in arr:
        total_dim += a.shape[dim]
    # The output takes its remaining extents and dtype from the last input.
    last = arr[-1]
    shape = list(last.shape)
    shape[dim] = total_dim
    s = jt.empty(shape, last.dtype)
    slices = [slice(None)]*len(last.shape)
    cdim = 0
    for a in arr:
        extent = a.shape[dim]
        if extent == 0:
            # nothing to copy for an empty piece
            continue
        slices[dim] = slice(cdim, cdim+extent)
        s = s.setitem(tuple(slices), a)
        cdim += extent
    return s

View File

@ -56,6 +56,45 @@ def relu_invariant_gauss(shape, dtype, mode="fan_in"):
def relu_invariant_gauss_(var, mode="fan_in"):
var.assign(relu_invariant_gauss(tuple(var.shape), var.dtype, mode))
def calculate_std(var,mode,nonlinearity,param=0.01):
    '''Compute the standard deviation ``gain / sqrt(fan)`` for an initializer.

    Args:
        var: the parameter to initialise; it must have at least two dims.
        mode: 'fan_in' (uses ``shape[1]``) or 'fan_out' (uses ``shape[0]``),
            case-insensitive.
        nonlinearity: key into the gain table below.
        param: number used as the negative slope in the 'leaky_relu' gain.

    Returns:
        The standard deviation as a Python float.
    '''
    mode = mode.lower()
    assert isinstance(param,(int,float))
    assert var.ndim>=2
    assert mode in ['fan_in', 'fan_out']

    shape = list(var.shape)
    fan = shape[1] if mode == 'fan_in' else shape[0]
    # Multiply by the number of elements of one [out, in] slice. This is read
    # from the shape instead of indexing the var, so no extra ops are created
    # and empty leading dims do not break the computation.
    for extent in shape[2:]:
        fan *= extent

    gains = {
        'linear':1,
        'conv1d':1,
        'conv2d':1,
        'conv3d':1,
        'conv_transpose1d':1,
        'conv_transpose2d':1,
        'conv_transpose3d':1,
        'sigmoid':1,
        'tanh':5.0/3,
        'relu':math.sqrt(2.0),
        'leaky_relu':math.sqrt(2.0 / (1 + param ** 2)),
    }
    gain = gains[nonlinearity]
    return gain/math.sqrt(fan)
def kaiming_uniform_(var, a=0, mode='fan_in', nonlinearity='leaky_relu'):
std = calculate_std(var,mode,nonlinearity,a)
bound = math.sqrt(3.0) * std
with jt.no_grad():
return uniform_(var,-bound, bound)
def kaiming_normal_(var, a=0, mode='fan_in', nonlinearity='leaky_relu'):
std = calculate_std(var,mode,nonlinearity,a)
with jt.no_grad():
return gauss_(var,0, std)
#TODO: bound = gain * math.sqrt(6.0/fan) ??
def xavier_uniform(shape, dtype, gain=1.0):
assert len(shape)>1
@ -81,4 +120,4 @@ def xavier_gauss(shape, dtype, gain=1.0):
return gauss(shape, dtype, 0, std)
def xavier_gauss_(var, gain=1.0):
var.assign(xavier_gauss(tuple(var.shape), var.dtype, gain))
var.assign(xavier_gauss(tuple(var.shape), var.dtype, gain))

View File

@ -10,7 +10,8 @@
import jittor as jt
import numpy as np
import math
from collections.abc import Sequence
from collections.abc import Sequence,Iterable
def repeat(x, *shape):
r'''
@ -94,6 +95,22 @@ def expand(x, shape):
return x.broadcast(shape)
jt.Var.expand = expand
def median(x,dim=None,keepdim=False):
    r'''
    Returns the (lower) median of x along one dimension.

    Args:
        x: the input var.
        dim: dimension to reduce; negative values count from the end. When it
            is None, x is flattened first and the median of all elements is taken.
        keepdim: keep the reduced dimension with size 1 when True.
    '''
    if dim is None:
        x = x.reshape(-1)
        dim=0
    if dim<0:
        dim+=x.ndim
    # argsort returns (indices, sorted values); only the values are needed
    _,x = x.argsort(dim)
    # leave every dimension before `dim` untouched
    slices = [slice(None)]*dim
    # element (n-1)//2 is the lower median for even lengths
    k = (x.shape[dim]-1)//2
    if keepdim:
        slices.append(slice(k,k+1))
    else:
        slices.append(k)
    return x[tuple(slices)]
jt.Var.median = median
def stack(x, dim=0):
r'''
Concatenates sequence of vars along a new dimension.
@ -116,8 +133,10 @@ def stack(x, dim=0):
[[[1 2 3]
[[4 5 6]]]
'''
assert isinstance(x, list)
assert len(x) >= 2
assert isinstance(x, Sequence)
if len(x) < 2:
return x[0].unsqueeze(dim)
res = [x_.unsqueeze(dim) for x_ in x]
return jt.contrib.concat(res, dim=dim)
jt.Var.stack = stack
@ -140,6 +159,10 @@ def flip(x, dim=0):
[[4 3 2 1]]
'''
assert isinstance(dim, int)
if dim<0:
dim+=x.ndim
assert dim>=0 and dim<len(x.shape)
tar_dims = []
for i in range(len(x.shape)):
if i == dim:
@ -258,6 +281,8 @@ def unbind(x, dim=0):
if dim < 0: dim += len(x.shape)
return [x[(slice(None),)*dim+(i,)] for i in range(x.shape[dim])]
jt.Var.unbind = unbind
def make_grid(x, nrow=8, padding=2, normalize=False, range=None, scale_each=False, pad_value=0):
assert range == None
assert scale_each == False
@ -268,3 +293,321 @@ def make_grid(x, nrow=8, padding=2, normalize=False, range=None, scale_each=Fals
return x.reindex([c, h*ncol+(ncol+1)*padding, w*nrow+(nrow+1)*padding],
[f"i1/{padding+h}*{nrow}+i2/{padding+w}", "i0",
f"i1-i1/{padding+h}*{padding+h}-{padding}", f"i2-i2/{padding+w}*{padding+w}-{padding}"], overflow_value=pad_value)
def _ntuple(n):
def parse(x):
if isinstance(x, Iterable):
return x
return tuple([x]*n)
return parse
_single = _ntuple(1)
_pair = _ntuple(2)
_triple = _ntuple(3)
_quadruple = _ntuple(4)
def unique(x):
    r'''
    Returns the unique elements of the input tensor, in sorted order.

    Args:
        x: the input tensor; it is flattened before processing.
    '''
    x = x.reshape(-1)
    _,x = jt.argsort(x)
    if x.shape[0] <= 1:
        # nothing to compare against
        return x
    # After sorting, an element is new exactly when it differs from its
    # predecessor. Slices are used instead of Python index lists, which would
    # be turned into a multi-dimensional index by __getitem__.
    is_new = x[1:] != x[:-1]
    y = x[1:][is_new]
    return jt.contrib.concat([x[:1],y],dim=0)
jt.Var.unique = unique
def hypot(a,b):
return jt.sqrt(a.sqr()+b.sqr())
def rad2deg(x):
return 180 * x / np.pi
jt.Var.rad2deg = rad2deg
def deg2rad(x):
return x * np.pi / 180.
jt.Var.deg2rad = deg2rad
def arctan2(y,x):
    r'''
    Element-wise arc tangent of y/x choosing the quadrant from the signs of
    both arguments; the result lies in [-pi, pi].

    Args:
        y: the var holding the y-coordinates.
        x: the var holding the x-coordinates (y is expected to have the same shape).
    '''
    dtype = x.dtype
    x_zero = (x == 0.0)
    x_neg = (x < 0)
    # Put 1 where x is 0 so the division stays finite; that part of the
    # result is masked out right below.
    safe_x = x + x_zero.cast(dtype)
    angle = jt.arctan(y/safe_x) * (1 - x_zero.cast(dtype))
    # Left half-plane: shift by +pi (y >= 0) or -pi (y < 0).
    angle = angle + (x_neg & (y >= 0)).cast(dtype) * np.pi
    angle = angle - (x_neg & (y < 0)).cast(dtype) * np.pi
    # On the y axis the angle is +-pi/2; the origin stays at 0.
    angle = angle + (x_zero & (y > 0)).cast(dtype) * (np.pi / 2)
    angle = angle - (x_zero & (y < 0)).cast(dtype) * (np.pi / 2)
    return angle
def nonzero(x):
r'''
Return the index of the elements of input tensor which are not equal to zero.
'''
x = jt.where(x)
x = [xx.unsqueeze(1) for xx in x]
if len(x)<2:
return x[0]
x = jt.contrib.concat(x,dim=1)
return x
jt.Var.nonzero = nonzero
def arange(start=0, end=None, step=1,dtype=None):
    r'''
    Returns a 1-D var with the values start, start+step, ... that lie in the
    half-open interval [start, end).

    Args:
        start: first value; when ``end`` is None this argument is the end and
            the sequence starts at 0.
        end: exclusive upper (or, for a negative step, lower) bound.
        step: distance between two values; may be negative.
        dtype: optional dtype the result is cast to.
    '''
    if end is None:
        end,start = start,0
    # Number of steps that fit before reaching `end`; clamped at 0 so that an
    # empty range never asks for a negative length. Works for both signs of step.
    l = max(0, int(math.ceil((end-start)/step)))
    x = jt.index((l,),0)
    x = x*step+start
    if dtype is not None:
        x = x.cast(dtype)
    return x
def randperm(n, dtype="int64"):
x = np.arange(n)
np.random.shuffle(x)
return jt.array(x).cast(dtype)
def log2(x):
return jt.log(x)/math.log(2.0)
jt.Var.log2 = log2
def item(x):
assert x.ndim==1 and x.shape[0]==1
return x.data[0]
jt.Var.item = item
def meshgrid(*tensors):
r'''
Take N tensors, each of which can be 1-dimensional vector, and create N n-dimensional grids,
where the i th grid is defined by expanding the i th input over dimensions defined by other inputs.
'''
size = len(tensors)
shape = []
for i in range(size):
assert isinstance(tensors[i],jt.Var) and tensors[i].ndim==1
shape.append(tensors[i].shape[0])
grids = []
view_shape = [1]*size
for i in range(size):
vs = view_shape[:]
vs[i]=-1
grids.append(tensors[i].reshape(vs).expand(shape))
return grids
def split(d,split_size,dim=0):
    r'''
    Splits the tensor into chunks along one dimension.

    If split_size is an integer type, then tensor will be split into equally sized chunks (if possible). Last chunk will be smaller if the tensor size along the given dimension dim is not divisible by split_size.

    If split_size is a list, then tensor will be split into len(split_size) chunks with sizes in dim according to split_size.

    Args:
        d (Tensor) - tensor to split.
        split_size (int) or (list(int)) - size of a single chunk or list of sizes for each chunk
        dim (int) - dimension along which to split the tensor, default 0.

    Returns:
        A tuple with the chunks in order.
    '''
    if dim<0:
        dim+=d.ndim
    total = d.shape[dim]
    if isinstance(split_size,int):
        if split_size <= 0:
            raise ValueError(f"split_size must be positive, but got {split_size}")
        sizes = [split_size]*(total//split_size)
        if total % split_size != 0:
            # the remainder forms a smaller last chunk
            sizes.append(total % split_size)
    else:
        sizes = list(split_size)
    assert sum(sizes)==total

    ans = []
    last = 0
    prefix = (slice(None),)*dim
    for size in sizes:
        if size==0:
            # an empty chunk is created directly instead of slicing
            shape = list(d.shape)
            shape[dim]=0
            ans.append(jt.zeros(tuple(shape),dtype=d.dtype))
            continue
        ans.append(d[prefix+(slice(last,last+size),)])
        last += size
    return tuple(ans)
jt.Var.split = split
def tolist(x):
return x.numpy().tolist()
jt.Var.tolist = tolist
def topk(input, k, dim=None, largest=True, sorted=True):
if input.numel()==0:
return jt.array([],dtype=input.dtype),jt.array([],dtype='int32')
if dim is None:
dim = -1
if dim<0:
dim+=input.ndim
index,values = jt.argsort(input,dim=dim,descending=largest)
dims = (slice(None),)*dim+(slice(0,k),)
indices = index[dims]
values = values[dims]
return values,indices
jt.Var.topk = topk
def kthvalue(input, k, dim=None, keepdim=False):
if dim is None:
dim = -1
if dim<0:
dim+=input.ndim
index,values = jt.argsort(input,dim=dim)
dims = (slice(None),)*dim+(slice(k-1,k),)
indices = index[dims]
values = values[dims]
if not keepdim and indices.ndim>1:
indices = indices.squeeze(dim)
values = values.squeeze(dim)
return values,indices
jt.Var.kthvalue = kthvalue
def gather(x,dim,index):
if dim<0:
dim+=index.ndim
x_shape = list(x.shape )
i_shape = list(index.shape)
assert i_shape[dim]>0
assert x.ndim == index.ndim
i_shape[dim]=x_shape[dim]
assert i_shape == x_shape
ins = []
for i in range(index.ndim):
ins.append(jt.index(index.shape,dim=i))
ins[dim]=index
return x.reindex(ins)
def prod(x,dim=0):
x = jt.log(x)
x = x.sum(dim=dim)
return jt.exp(x)
jt.Var.prod = prod
def cumsum_forward(np, data):
a = data['inputs'][0]
b = data['outputs'][0]
np.cumsum(a, axis=1, out=b)
def cumsum_backward(np, data):
dout = data['dout']
out = data['outputs'][0]
np.cumsum(dout[:, ::-1], axis=1, out=out)
np.copyto(out, out[:, ::-1])
def cumsum(x, dim=None):
'''
Parameters:
-----------
x: [batch_size, N], jt.var
Returns:
--------
the cumulative sum of x
'''
return jt.numpy_code(x.shape, x.dtype, [x], cumsum_forward, [cumsum_backward])
jt.Var.cumsum = cumsum
def cumprod(x,dim=0):
x = jt.log(x)
x = cumsum(x,dim=dim)
return jt.exp(x)
jt.Var.cumprod=cumprod
def nms(dets,thresh):
'''
dets jt.array [x1,y1,x2,y2,score]
x(:,0)->x1,x(:,1)->y1,x(:,2)->x2,x(:,3)->y2,x(:,4)->score
'''
threshold = str(thresh)
order = jt.argsort(dets[:,4],descending=True)[0]
dets = dets[order]
s_1 = '(@x(j,2)-@x(j,0)+1)*(@x(j,3)-@x(j,1)+1)'
s_2 = '(@x(i,2)-@x(i,0)+1)*(@x(i,3)-@x(i,1)+1)'
s_inter_w = 'max((Tx)0,min(@x(j,2),@x(i,2))-max(@x(j,0),@x(i,0))+1)'
s_inter_h = 'max((Tx)0,min(@x(j,3),@x(i,3))-max(@x(j,1),@x(i,1))+1)'
s_inter = s_inter_h+'*'+s_inter_w
iou = s_inter + '/(' + s_1 +'+' + s_2 + '-' + s_inter + ')'
fail_cond = iou+'>'+threshold
selected = jt.candidate(dets, fail_cond)
return order[selected]
jt.Var.expand = jt.Var.broadcast
jt.Var.expand_as = jt.Var.broadcast_var
def index_fill_(x,dim,indexs,val):
r'''
Fills the elements of the input tensor with value val by selecting the indices in the order given in index.
Args:
x - the input tensor
dim - dimension along which to index
index indices of input tensor to fill in
val the value to fill with
'''
overflow_conditions = [f'i{dim}=={i}'for i in indexs]
indexs = [f'i{i}' for i in range(len(x.shape))]
return x.reindex(shape = x.shape,indexes = indexs,overflow_conditions=overflow_conditions,overflow_value=val)
def triu_(x,diagonal=0):
r'''
Returns the upper triangular part of a matrix (2-D tensor) or batch of matrices input, the other elements of the result tensor out are set to 0.
The upper triangular part of the matrix is defined as the elements on and above the diagonal.
Args:
x the input tensor.
diagonal the diagonal to consider,default =0
'''
l = len(x.shape)
assert l>1
overflow_conditions=[f'i{l-1}<i{l-2}+{diagonal}']
indexs = [f'i{i}' for i in range(l)]
return x.reindex(x.shape,indexs,overflow_conditions=overflow_conditions,overflow_value=0)
jt.Var.triu_ = triu_

View File

@ -15,8 +15,10 @@ from jittor import init, Module
import numpy as np
import collections
import math
from collections import OrderedDict
from jittor.pool import Pool, pool, AdaptiveAvgPool2d
from jittor.optim import *
from jittor.misc import _pair
def matmul_transpose(a, b):
@ -25,6 +27,10 @@ def matmul_transpose(a, b):
'''
assert len(a.shape) >= 2 and len(b.shape) == 2
assert a.shape[-1] == b.shape[-1]
if len(a.shape)>2:
aa = a.reshape((-1, a.shape[-1]))
cc = matmul_transpose(aa, b)
return cc.reshape(a.shape[:-1]+(-1,))
shape = list(a.shape)[:-1] + list(b.shape)
a = a.broadcast(shape, [len(shape)-2])
@ -108,6 +114,12 @@ Example::
# cc:[..., n, m, k]
# -->
# 012
if len_b == 2 and len_a>2:
# TODO:ugly implementation for tuner
aa = a.reshape((-1, m))
cc = matmul(aa, b)
print(a.shape, b.shape, cc.shape)
return cc.reshape(a.shape[:-1] + [k])
for i in range(len_c-2):
ai = len_a-(len_c-i)
bi = len_b-(len_c-i)
@ -182,11 +194,34 @@ def bce_loss(output, target, weight=None, size_average=True):
def l1_loss(output, target):
return (output-target).abs().mean()
def smooth_l1_loss(y_true, y_pred,reduction="mean"):
"""Implements Smooth-L1 loss.
y_true and y_pred are typically: [N, 4], but could be any shape.
Args:
y_true - ground truth
y_pred - predictions
reduction - the mode of cal loss which must be in ['mean','sum','none']
"""
diff = jt.abs(y_true - y_pred)
less_than_one = (diff<1.0).float32()
loss = (less_than_one * 0.5 * diff.sqr()) + (1 - less_than_one) * (diff - 0.5)
if reduction=="mean":
return loss.mean()
elif reduction=="sum":
return loss.sum()
elif reduction=="none":
return loss
else:
raise ValueError(f'not support {reduction}')
class CrossEntropyLoss(Module):
def __init__(self):
pass
def __init__(self,ignore_index=None):
self.ignore_index = ignore_index
def execute(self, output, target):
return cross_entropy_loss(output, target)
return cross_entropy_loss(output, target,self.ignore_index)
class MSELoss(Module):
def __init__(self):
@ -228,6 +263,13 @@ def softmax(x, dim = None):
ret = x / x.sum(dim, keepdims=True)
return ret
def log_softmax(x,dim=None):
    '''Logarithm of the softmax of x.

    Computed as ``(x - max) - log(sum(exp(x - max)))`` rather than
    ``log(softmax(x))``, so that very small probabilities do not underflow
    to 0 and turn into -inf.

    Args:
        x: the input var.
        dim: dimension to normalise over; None normalises over all elements.
            NOTE(review): the None case assumes softmax(x, None) reduces over
            every element - confirm against softmax.
    '''
    if dim is None:
        shifted = x - x.max()
        return shifted - jt.log(jt.exp(shifted).sum())
    shifted = x - x.max(dim, keepdims=True)
    return shifted - jt.log(jt.exp(shifted).sum(dim, keepdims=True))
def log_sigmoid(x):
    '''Logarithm of the sigmoid of x.

    Uses the identity ``log(sigmoid(x)) = min(x, 0) - log(1 + exp(-|x|))``;
    the argument of exp is never positive, so the result stays finite for
    large negative x where ``log(sigmoid(x))`` would give -inf.
    '''
    return x.minimum(0) - jt.log(1 + jt.exp(-jt.abs(x)))
class Dropout(Module):
def __init__(self, p=0.5, is_train=False):
assert p >= 0 and p <= 1, "dropout probability has to be between 0 and 1, but got {}".format(p)
@ -267,13 +309,12 @@ class BatchNorm(Module):
self.is_train = is_train
self.eps = eps
self.momentum = momentum
self.running_mean = init.constant((num_features,), "float32", 0.0).stop_grad()
self.running_var = init.constant((num_features,), "float32", 1.0).stop_grad()
self.affine = affine
if self.affine:
if affine:
self.weight = init.constant((num_features,), "float32", 1.0)
self.bias = init.constant((num_features,), "float32", 0.0)
self.running_mean = init.constant((num_features,), "float32", 0.0).stop_grad()
self.running_var = init.constant((num_features,), "float32", 1.0).stop_grad()
def execute(self, x):
if self.is_train:
@ -300,43 +341,63 @@ class BatchNorm(Module):
return norm_x * w + b
class BatchNorm1d(Module):
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, is_train=True, sync=True):
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=None, is_train=True, sync=True):
assert affine == None
self.sync = sync
self.num_features = num_features
self.is_train = is_train
self.eps = eps
self.momentum = momentum
self.weight = init.constant((num_features,), "float32", 1.0)
self.bias = init.constant((num_features,), "float32", 0.0)
self.running_mean = init.constant((num_features,), "float32", 0.0).stop_grad()
self.running_var = init.constant((num_features,), "float32", 1.0).stop_grad()
self.affine = affine
if self.affine:
self.weight = init.constant((num_features,), "float32", 1.0)
self.bias = init.constant((num_features,), "float32", 0.0)
def execute(self, x):
if self.is_train:
xmean = jt.mean(x, dims=[0], keepdims=1)
x2mean = jt.mean(x*x, dims=[0], keepdims=1)
if len(x.shape) == 3:
if self.is_train:
xmean = jt.mean(x, dims=[0, 2], keepdims=1)
x2mean = jt.mean(x*x, dims=[0, 2], keepdims=1)
if self.sync and jt.in_mpi:
xmean = xmean.mpi_all_reduce("mean")
x2mean = x2mean.mpi_all_reduce("mean")
if self.sync and jt.in_mpi:
xmean = xmean.mpi_all_reduce("mean")
x2mean = x2mean.mpi_all_reduce("mean")
xvar = x2mean-xmean*xmean
norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)
self.running_mean.update(self.running_mean +
(xmean.sum([0])-self.running_mean)*self.momentum)
self.running_var.update(self.running_var +
(xvar.sum([0])-self.running_var)*self.momentum)
else:
running_mean = self.running_mean.broadcast(x, [0])
running_var = self.running_var.broadcast(x, [0])
norm_x = (x-running_mean)/jt.sqrt(running_var+self.eps)
xvar = x2mean-xmean*xmean
norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)
self.running_mean.update(self.running_mean +
(xmean.sum([0, 2])-self.running_mean)*self.momentum)
self.running_var.update(self.running_var +
(xvar.sum([0, 2])-self.running_var)*self.momentum)
else:
running_mean = self.running_mean.broadcast(x, [0, 2])
running_var = self.running_var.broadcast(x, [0, 2])
norm_x = (x-running_mean)/jt.sqrt(running_var+self.eps)
w = self.weight.broadcast(x, [0, 2])
b = self.bias.broadcast(x, [0, 2])
else:
if self.is_train:
xmean = jt.mean(x, dims=[0], keepdims=1)
x2mean = jt.mean(x*x, dims=[0], keepdims=1)
if self.sync and jt.in_mpi:
xmean = xmean.mpi_all_reduce("mean")
x2mean = x2mean.mpi_all_reduce("mean")
xvar = x2mean-xmean*xmean
norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)
self.running_mean.update(self.running_mean +
(xmean.sum([0])-self.running_mean)*self.momentum)
self.running_var.update(self.running_var +
(xvar.sum([0])-self.running_var)*self.momentum)
else:
running_mean = self.running_mean.broadcast(x, [0])
running_var = self.running_var.broadcast(x, [0])
norm_x = (x-running_mean)/jt.sqrt(running_var+self.eps)
w = self.weight.broadcast(x, [0])
b = self.bias.broadcast(x, [0])
if not self.affine:
return norm_x
w = self.weight.broadcast(x, [0])
b = self.bias.broadcast(x, [0])
return norm_x * w + b
class InstanceNorm2d(Module):
@ -379,19 +440,23 @@ class GroupNorm(Module):
self.bias = init.constant((num_channels,), "float32", 0.0)
def execute(self, x):
N,C,H,W = x.shape
assert C == self.num_channels
N = x.shape[0]
C = self.num_channels
output_shape = (N,-1)
# TODO: 3d group norm
if x.ndim==4:
output_shape = x.shape
assert C % self.num_groups == 0
x = x.reshape((N, self.num_groups, int(C/self.num_groups), H*W))
x = x.reshape((N, self.num_groups, int(C/self.num_groups), -1))
xmean = jt.mean(x, dims=[2,3], keepdims=1)
x2mean = jt.mean(x*x, dims=[2,3], keepdims=1)
xvar = jt.maximum(x2mean-xmean*xmean, 0)
norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)
if not self.affine:
return norm_x
return norm_x.reshape(output_shape)
w = self.weight.reshape((1,self.num_groups,C//self.num_groups,1))
b = self.bias.reshape((1,self.num_groups,C//self.num_groups,1))
return (norm_x * w + b).reshape((N,C,H,W))
return (norm_x * w + b).reshape(output_shape)
Relu = jt.make_module(relu)
ReLU = Relu
@ -482,6 +547,86 @@ class Conv(Module):
y = y + b
return y
class Conv1d(Module):
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = (kernel_size, 1)
self.stride = (stride, 1)
self.padding = (padding, 0)
self.dilation = (dilation, 1)
self.groups = groups
self.bias = bias
assert in_channels % groups == 0, 'in_channels must be divisible by groups'
assert out_channels % groups == 0, 'out_channels must be divisible by groups'
self.conv = Conv(self.in_channels, self.out_channels, self.kernel_size, self.stride, self.padding, self.dilation, self.groups, self.bias)
def execute(self, x):
N,C,D = x.shape
assert C==self.in_channels
x = x.unsqueeze(-1)
x = self.conv(x)
y = x.squeeze(-1)
return y
def conv2d(x, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
padding = _pair(padding)
stride = _pair(stride)
dilation = _pair(dilation)
out_channels = weight.shape[0]
if groups == 1:
N,C,H,W = x.shape
Kh, Kw = weight.shape[-2:]
oh = (H+padding[0]*2-Kh*dilation[0]+dilation[0]-1)//stride[0]+1
ow = (W+padding[1]*2-Kw*dilation[1]+dilation[1]-1)//stride[1]+1
xx = x.reindex([N,out_channels,C,oh,ow,Kh,Kw], [
'i0', # Nid
'i2', # Cid
f'i3*{stride[0]}-{padding[0]}+i5*{dilation[0]}', # Hid+Khid
f'i4*{stride[1]}-{padding[1]}+i6*{dilation[1]}', # Wid+KWid
])
ww = weight.broadcast(xx.shape, [0,3,4])
yy = xx*ww
y = yy.sum([2,5,6]) # Kc, Kh, Kw
if bias is not None:
b = bias.broadcast(y.shape, [0,2,3])
y = y + b
return y
else:
N,C,H,W = x.shape
Kh, Kw = weight.shape[-2:]
G = groups
CpG = C // G # channels per group
oc = out_channels
oh = (H+padding[0]*2-Kh*dilation[0]+dilation[0]-1)//stride[0]+1
ow = (W+padding[1]*2-Kw*dilation[1]+dilation[1]-1)//stride[1]+1
xx = x.reindex([N,G,oc//G,CpG,oh,ow,Kh,Kw], [
'i0', # Nid
f'i1*{CpG}+i3', # Gid
f'i4*{stride[0]}-{padding[0]}+i6*{dilation[0]}', # Hid+Khid
f'i5*{stride[1]}-{padding[1]}+i7*{dilation[1]}', # Wid+KWid
])
xx.compile_options = {"G":G}
# w: [oc, CpG, Kh, Kw]
ww = weight.reindex([N, G, oc//G, CpG, oh, ow, Kh, Kw], [
f'i1*{oc//G}+i2',
'i3',
'i6',
'i7'
])
yy = xx*ww
y = yy.reindex_reduce('add', [N, oc, oh, ow], [
'i0',
f'i1*{oc//G}+i2',
'i4',
'i5'
])
if bias is not None:
b = bias.broadcast(y.shape, [0,2,3])
y = y + b
return y
class ConvTranspose(Module):
def __init__(self, in_channels, out_channels, kernel_size, stride=1, \
@ -745,7 +890,17 @@ def upsample(img, size, mode="nearest", align_corners=False):
y = wid * (w / W)
return _interpolate(img, x, y, (nid,cid), mode)
def grid_sample(input, grid, mode='bilinear', padding_mode='zeros'):
def interpolate(X,size=None,scale_factor=None,mode='bilinear',align_corners=False):
    '''Resample the last two dimensions of X.

    Args:
        X: the input var.
        size: output size, either one int used for both dimensions or a pair.
        scale_factor: multiplier applied to the last two dimensions of X;
            takes precedence over ``size`` when given.
        mode: forwarded to upsample / resize.
        align_corners: forwarded to upsample / resize.

    Raises ValueError when neither ``size`` nor ``scale_factor`` is given.
    '''
    if scale_factor is not None:
        # sizes must be integers, also for fractional scale factors
        size = [int(X.shape[-2]*scale_factor),int(X.shape[-1]*scale_factor)]
    if size is None:
        raise ValueError("either size or scale_factor should be defined")
    if isinstance(size,int):
        size = (size,size)
    if scale_factor is not None and scale_factor>1:
        return upsample(X,size,mode,align_corners)
    return resize(X,size,mode,align_corners)
def grid_sample_v0(input, grid, mode='bilinear', padding_mode='zeros'):
r'''
Given an input and a flow-field grid, computes the output using input values and pixel locations from grid.
@ -789,6 +944,195 @@ def grid_sample(input, grid, mode='bilinear', padding_mode='zeros'):
y = ((grid[:,:,:,0].unsqueeze(1).repeat([1,Ci,1,1]) + 1) / 2) * (Wi - 1)
return _interpolate(input, x, y, (nid,cid), mode)
def linspace_from_neg_one(grid,num_steps,align_corners):
    '''Build the 1-D normalized coordinates used for one axis of an affine grid.

    Args:
        grid: var whose dtype is used for the result (only grid.dtype is read).
        num_steps: number of samples along the axis.
        align_corners: when False the [-1, 1] range is shrunk by
            (num_steps-1)/num_steps.

    Returns:
        A jt var with num_steps values. num_steps == 1 yields the single
        value 0; num_steps <= 0 yields an empty var.
    '''
    if num_steps <= 0:
        return jt.array([],dtype=grid.dtype)
    if num_steps == 1:
        # An axis of extent 1 needs exactly one coordinate; an empty array
        # could not be assigned into the base grid. The reference
        # implementation uses 0 here.
        return jt.array([0.0],dtype=grid.dtype)
    # TODO: use jt.index
    ra = np.linspace(-1,1,num_steps)
    if not align_corners:
        ra = ra*(num_steps-1)/num_steps
    return jt.array(ra,dtype=grid.dtype)
def make_base_grid_4D(theta,N,C,H,W,align_corners):
    '''Create the (N, H, W, 3) homogeneous base grid for a 2-D affine transform.

    Channel 0 holds the W-axis coordinates, channel 1 the H-axis coordinates
    and the last channel is the constant 1. C is accepted but not used.
    '''
    w_coords = linspace_from_neg_one(theta, W, align_corners)
    # trailing unit axis so the H coordinates vary along H, not W
    h_coords = jt.unsqueeze(linspace_from_neg_one(theta, H, align_corners),-1)
    out = jt.zeros((N, H, W, 3), dtype=theta.dtype)
    out[...,0] = w_coords
    out[...,1] = h_coords
    out[...,-1] = 1
    return out
def make_base_grid_5D(theta,N,C,D,H,W,align_corners):
    '''Create the (N, D, H, W, 4) homogeneous base grid for a 3-D affine transform.

    Channels 0, 1 and 2 hold the W-, H- and D-axis coordinates respectively;
    the last channel is the constant 1. C is accepted but not used.
    '''
    w_coords = linspace_from_neg_one(theta, W, align_corners)
    h_coords = jt.unsqueeze(linspace_from_neg_one(theta, H, align_corners),-1)
    # two trailing unit axes so the D coordinates vary along D only
    d_coords = jt.unsqueeze(jt.unsqueeze(linspace_from_neg_one(theta, D, align_corners),-1),-1)
    out = jt.zeros((N, D, H, W, 4), dtype=theta.dtype)
    out[...,0] = w_coords
    out[...,1] = h_coords
    out[...,2] = d_coords
    out[...,-1] = 1
    return out
def affine_grid_generator_4D(theta,N,C,H,W,align_corners):
    '''Apply the batched 2-D affine matrices in theta to the base grid.

    Returns a var reshaped to (N, H, W, 2).
    '''
    flat_grid = make_base_grid_4D(theta, N, C, H, W, align_corners).reshape(N, H * W, 3)
    theta_t = theta.transpose(0,2,1)
    warped = jt.nn.bmm(flat_grid, theta_t)
    return warped.reshape(N, H, W, 2)
def affine_grid_generator_5D(theta,N,C,D,H,W,align_corners):
    '''Apply the batched 3-D affine matrices in theta to the base grid.

    Returns a var reshaped to (N, D, H, W, 3).
    '''
    flat_grid = make_base_grid_5D(theta, N, C, D, H, W, align_corners).reshape(N, D * H * W, 4)
    theta_t = theta.transpose(0,2,1)
    warped = jt.nn.bmm(flat_grid, theta_t)
    return warped.reshape(N, D, H, W, 3)
def affine_grid(theta, size, align_corners=False):
    '''Generate a sampling grid from a batch of affine matrices.

    Args:
        theta: float var of shape (N, 2, 3) for a 4-element size, or
            (N, 3, 4) for a 5-element size.
        size: target size with 4 or 5 strictly positive entries.
        align_corners: forwarded to the grid generators.

    Returns:
        The grid produced by the 4-D or 5-D generator.
    '''
    assert str(theta.dtype) in ['float','float32','float64']
    assert min(size)>0
    rank = len(size)
    assert rank in [4,5]
    if rank == 4:
        assert theta.ndim == 3 and theta.shape[-2] == 2 and theta.shape[-1] == 3
        return affine_grid_generator_4D(theta, size[0], size[1], size[2], size[3], align_corners)
    if rank == 5:
        assert theta.ndim == 3 and theta.shape[-2] == 3 and theta.shape[-1] == 4
        return affine_grid_generator_5D(theta, size[0], size[1], size[2], size[3], size[4], align_corners)
def grid_sampler_unnormalize(coord,size,align_corners):
    '''Map a normalized coordinate in [-1, 1] to pixel space.

    With align_corners the result lies in [0, size - 1]; otherwise it lies
    in [-0.5, size - 0.5].
    '''
    shifted = coord + 1
    if not align_corners:
        # [-1, 1] -> [-0.5, size - 0.5]
        return (shifted * size - 1) / 2
    # [-1, 1] -> [0, size - 1]
    return (shifted / 2) * (size - 1)
def clip_coordinates(x,clip_limit):
    '''Clamp x into the closed range [0, clip_limit - 1].'''
    upper = clip_limit-1
    return jt.clamp(x,min_v=0,max_v=upper)
def reflect_coordinates(x,twice_low,twice_high):
    '''Reflect coordinates back into the interval [twice_low/2, twice_high/2].

    The bounds are passed doubled. When both bounds are equal the result is
    all zeros.
    '''
    if twice_low == twice_high:
        return jt.zeros_like(x)
    low = twice_low / 2
    half_period = (twice_high - twice_low) / 2
    dist = (x - low).abs()
    # dist is non-negative after abs(), so the remainder is non-negative too
    remainder = dist.mod(half_period)
    n_flips = (dist / half_period).floor()
    # candidate for an even number of flips / for an odd number of flips
    forward = remainder+low
    backward = half_period-remainder+low
    is_even = n_flips%2==0
    is_odd = n_flips%2!=0
    # zero out the candidate that does not apply, then merge by addition
    forward[is_odd]=0.0
    backward[is_even]=0.0
    return forward+backward
def grid_sampler_compute_source_index(coord,size,padding_mode,align_corners):
    '''Turn normalized grid coordinates into source pixel coordinates.

    The coordinate is first un-normalized. 'border' padding then clips it to
    the image; 'reflection' padding reflects it at the image borders and
    clips afterwards. Any other padding_mode returns the un-normalized value.
    '''
    coord = grid_sampler_unnormalize(coord, size, align_corners)
    if padding_mode == 'reflection':
        # reflection bounds are passed doubled to reflect_coordinates
        if align_corners:
            twice_low, twice_high = 0, 2*(size - 1)
        else:
            twice_low, twice_high = -1, 2*size - 1
        coord = reflect_coordinates(coord, twice_low, twice_high)
        return clip_coordinates(coord, size)
    if padding_mode == 'border':
        return clip_coordinates(coord, size)
    return coord
def grid_sampler_3d(X,grid,mode,padding_mode,align_corners):
    '''Sample a 5-D input X at the locations given by a 5-D grid.

    grid[..., 0], grid[..., 1] and grid[..., 2] are used as the coordinates
    along the last, second-to-last and third-to-last axes of X. Supports
    'nearest' and 'bilinear' modes; any other mode returns None.
    '''
    N, C = X.shape[0], X.shape[1]
    inp_D, inp_H, inp_W = X.shape[2], X.shape[3], X.shape[4]
    D, H, W = grid.shape[1], grid.shape[2], grid.shape[3]
    shape = [N,C,D,H,W]
    cid = jt.index(shape, dim=1)
    nid = jt.index(shape, dim=0)

    def source_index(channel, extent):
        # un-normalize one grid channel and expand it over the C axis
        idx = grid_sampler_compute_source_index(grid[:,:,:,:,channel],extent,padding_mode,align_corners)
        return idx.reindex(shape,['i0','i2','i3','i4'])

    xid = source_index(0, inp_W)
    yid = source_index(1, inp_H)
    zid = source_index(2, inp_D)

    if mode=='nearest':
        return X.reindex([nid,cid,zid.round(),yid.round(),xid.round()])
    if mode=='bilinear':
        def fetch(zi, yi, xi):
            return X.reindex([nid,cid,zi,yi,xi])
        # lower ("lo") and upper ("hi") neighbours on each axis
        x_lo,y_lo,z_lo = xid.floor(),yid.floor(),zid.floor()
        x_hi,y_hi,z_hi = x_lo+1,y_lo+1,z_lo+1
        # w*_hi weights the upper neighbour, w*_lo the lower neighbour
        wx_hi,wy_hi,wz_hi = xid-x_lo,yid-y_lo,zid-z_lo
        wx_lo,wy_lo,wz_lo = x_hi-xid,y_hi-yid,z_hi-zid
        c_lll = fetch(z_lo,y_lo,x_lo)
        c_hll = fetch(z_hi,y_lo,x_lo)
        c_lhl = fetch(z_lo,y_hi,x_lo)
        c_llh = fetch(z_lo,y_lo,x_hi)
        c_lhh = fetch(z_lo,y_hi,x_hi)
        c_hlh = fetch(z_hi,y_lo,x_hi)
        c_hhl = fetch(z_hi,y_hi,x_lo)
        c_hhh = fetch(z_hi,y_hi,x_hi)
        return c_lll*wx_lo*wy_lo*wz_lo+c_hll*wx_lo*wy_lo*wz_hi+c_lhl*wx_lo*wy_hi*wz_lo+c_llh*wx_hi*wy_lo*wz_lo+c_lhh*wx_hi*wy_hi*wz_lo+c_hlh*wx_hi*wy_lo*wz_hi+c_hhl*wx_lo*wy_hi*wz_hi+c_hhh*wx_hi*wy_hi*wz_hi
def grid_sampler_2d(X,grid,mode,padding_mode,align_corners):
    '''Sample a 4-D input X at the locations given by a 4-D grid.

    grid[..., 0] and grid[..., 1] are used as the coordinates along the last
    and second-to-last axes of X. Supports 'nearest' and 'bilinear' modes;
    any other mode returns None.
    '''
    N, C = X.shape[0], X.shape[1]
    inp_H, inp_W = X.shape[2], X.shape[3]
    H, W = grid.shape[1], grid.shape[2]
    shape = [N,C,H,W]
    cid = jt.index(shape, dim=1)
    nid = jt.index(shape, dim=0)

    def source_index(channel, extent):
        # un-normalize one grid channel and expand it over the C axis
        idx = grid_sampler_compute_source_index(grid[:,:,:,channel],extent,padding_mode,align_corners)
        return idx.reindex(shape,['i0','i2','i3'])

    xid = source_index(0, inp_W)
    yid = source_index(1, inp_H)

    if mode=='nearest':
        return X.reindex([nid,cid,yid.round(),xid.round()])
    if mode=='bilinear':
        def fetch(yi, xi):
            # reads outside X contribute 0
            return X.reindex([nid,cid,yi,xi],overflow_value=0.0)
        x_lo,y_lo = (xid).floor(),(yid).floor()
        x_hi,y_hi = x_lo+1,y_lo+1
        # w*_hi weights the upper neighbour, w*_lo the lower neighbour
        wx_hi,wy_hi = xid-x_lo,yid-y_lo
        wx_lo,wy_lo = x_hi-xid,y_hi-yid
        c_ll = fetch(y_lo,x_lo)
        c_hl = fetch(y_hi,x_lo)
        c_lh = fetch(y_lo,x_hi)
        c_hh = fetch(y_hi,x_hi)
        return c_ll*wx_lo*wy_lo+c_hl*wx_lo*wy_hi+c_lh*wx_hi*wy_lo+c_hh*wx_hi*wy_hi
def grid_sampler(X, grid, mode, padding_mode, align_corners):
    '''Validate X / grid and dispatch to the 2-D or 3-D sampler.

    X and grid must share dtype, rank (4 or 5) and batch size; the last
    grid axis must have X.ndim - 2 entries and X must be non-empty.
    '''
    assert X.dtype==grid.dtype
    rank = X.ndim
    assert (rank in (4, 5) and rank==grid.ndim)
    assert X.shape[0]==grid.shape[0] and grid.shape[-1]==rank-2
    assert X.numel()>0
    sampler = grid_sampler_2d if rank == 4 else grid_sampler_3d
    return sampler(X, grid, mode, padding_mode, align_corners)
def grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corners=False):
    '''Sample input at the locations in grid.

    Args:
        input: 4-D or 5-D var to sample from.
        grid: sampling locations with the same rank as input.
        mode: 'bilinear' or 'nearest'.
        padding_mode: 'zeros', 'border' or 'reflection'.
        align_corners: forwarded to the sampler.
    '''
    supported_modes = ('bilinear','nearest')
    supported_padding = ('zeros','border','reflection')
    assert mode in supported_modes
    assert padding_mode in supported_padding
    return grid_sampler(input, grid, mode, padding_mode, align_corners)
class Upsample(Module):
def __init__(self, scale_factor=None, mode='nearest'):
self.scale_factor = scale_factor if isinstance(scale_factor, tuple) else (scale_factor, scale_factor)
@ -808,6 +1152,9 @@ class Sequential(Module):
if isinstance(mod, collections.OrderedDict):
for k, m in mod.items():
self.add_module(k, m)
elif isinstance(mod,list):
for m in mod:
self.append(m)
else:
self.append(mod)
def __getitem__(self, idx):
@ -836,4 +1183,7 @@ class Sequential(Module):
assert not isinstance(mod, type), f"Module is not a type"
self.layers[name]=mod
def __len__(self):
return len(self.layers)
ModuleList = Sequential

View File

@ -109,6 +109,8 @@ def get_def_code(df, scope_name, pyname, self_as_arg0=False):
func_args_convert = ""
func_call = df["func_name"]+"("
pytypes = [ get_pytype_map(a[0],0) for a in args ]
holder_dec_array = []
holder_set_array = []
for tid, tpc in enumerate(pytypes):
check = get_pytype_map(args[tid][0],2)
default_arg = args[tid][2]
@ -118,6 +120,11 @@ def get_def_code(df, scope_name, pyname, self_as_arg0=False):
if jtp == "VarHolder*":
holder_dec = f"unique_ptr<VarHolder> arg{tid}_holder"
holder_set = f", arg{tid}_holder"
if jtp == "VarSlices":
holder_dec = f"vector<unique_ptr<VarHolder>> arg{tid}_holder"
holder_set = f", arg{tid}_holder"
holder_dec_array.append(holder_dec)
holder_set_array.append(holder_set)
if len(default_arg):
func_args_convert += f"""
{holder_dec};
@ -165,7 +172,7 @@ def get_def_code(df, scope_name, pyname, self_as_arg0=False):
if (khash == {get_hash(args[aid][1])}u) {{
// hash match {args[aid][1]}
CHECK(({get_pytype_map(args[aid][0],2)}(vo)));
arg{aid} = {pytypes[aid]}(vo);
arg{aid} = {pytypes[aid]}(vo{holder_set_array[aid]});
arg_filled |= 1ull << {aid};
continue;
}}
@ -759,7 +766,11 @@ def compile_src(src, h, basename):
core_name = submodule_info["attrs"]["core_name"]
has_map = class_name in ["VarHolder", "NanoVector"]
has_seq = class_name == "NanoVector"
code = f"""
# add extra include to avoid compile error
src_code = ""
if include_name.endswith("var_slices.h"):
src_code += '#include "var_holder.h"\n'
src_code += f"""
#include "pyjt/py_converter.h"
#include "pyjt/py_arg_printer.h"
#include "common.h"
@ -830,7 +841,7 @@ def compile_src(src, h, basename):
}}
"""
return code
return src_code
def compile_single(head_file_name, src_file_name, src=None):
basename = head_file_name.split("/")[-1].split(".")[0]

View File

@ -0,0 +1,67 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import jittor as jt
import numpy as np
from jittor.nn import affine_grid,grid_sample
class TestAffineGrid(unittest.TestCase):
    # Compares jittor's affine_grid / grid_sample with the PyTorch results.

    def _check_against_torch(self, size, theta_shape):
        # Shared body of the 2-D and 3-D tests: build random inputs, run both
        # frameworks with identical arguments and compare the outputs.
        try:
            import torch.nn.functional as F
            import torch
        except ImportError:
            # PyTorch is only the reference; skip instead of erroring out
            self.skipTest("No Torch found")
        theta = np.random.randn(*theta_shape).astype(np.float32)
        features = np.random.randint(256,size=size).astype(np.float32)

        torch_theta = torch.Tensor(theta)
        torch_features = torch.Tensor(features)
        torch_grid = F.affine_grid(torch_theta,size=size,align_corners=False)
        torch_sample = F.grid_sample(torch_features,torch_grid,mode='bilinear',padding_mode='zeros',align_corners=False)

        jt_theta = jt.array(theta)
        jt_features = jt.array(features)
        jt_grid = affine_grid(jt_theta,size=size,align_corners=False)
        jt_sample = grid_sample(jt_features,jt_grid,mode='bilinear',padding_mode='zeros',align_corners=False)

        assert np.allclose(jt_theta.numpy(),torch_theta.numpy())
        assert np.allclose(jt_features.numpy(),torch_features.numpy())
        assert np.allclose(jt_grid.numpy(),torch_grid.numpy(),atol=1e-05)
        assert np.allclose(torch_sample.numpy(),jt_sample.numpy(),atol=1e-01)

    def test_affine_grid_2d(self):
        N = 8
        C = 3
        H = 256
        W = 128
        self._check_against_torch((N,C,H,W), (N,2,3))

    def test_affine_grid_3d(self):
        N = 8
        C = 3
        D = 64
        H = 256
        W = 128
        self._check_against_torch((N,C,D,H,W), (N,3,4))
if __name__ == "__main__":
unittest.main()

View File

@ -55,6 +55,7 @@ class TestArgReduceOp(unittest.TestCase):
check_backward([5,5,5], 'min', 0, True)
check_backward([5,5,5], 'min', 2, True)
check_backward([5,5,5], 'min', 1, True)
check_backward([5,], 'min', 0, True)
check_backward([20,20,20,20], 'max', 0, True)
check_backward([20,20,20,20], 'max', 2, True)
check_backward([20,20,20,20], 'max', 1, True)
@ -62,6 +63,7 @@ class TestArgReduceOp(unittest.TestCase):
check_backward([5,5,5], 'min', 0, False)
check_backward([5,5,5], 'min', 2, False)
check_backward([5,5,5], 'min', 1, False)
check_backward([5,], 'min', 0, False)
check_backward([20,20,20,20], 'max', 0, False)
check_backward([20,20,20,20], 'max', 2, False)
check_backward([20,20,20,20], 'max', 1, False)
@ -73,6 +75,7 @@ class TestArgReduceOp(unittest.TestCase):
check_backward([5,5,5], 'min', 0, True)
check_backward([5,5,5], 'min', 2, True)
check_backward([5,5,5], 'min', 1, True)
check_backward([5,], 'min', 0, True)
check_backward([20,20,20,20], 'max', 0, True)
check_backward([20,20,20,20], 'max', 2, True)
check_backward([20,20,20,20], 'max', 1, True)
@ -80,6 +83,7 @@ class TestArgReduceOp(unittest.TestCase):
check_backward([5,5,5], 'min', 0, False)
check_backward([5,5,5], 'min', 2, False)
check_backward([5,5,5], 'min', 1, False)
check_backward([5,], 'min', 0, False)
check_backward([20,20,20,20], 'max', 0, False)
check_backward([20,20,20,20], 'max', 2, False)
check_backward([20,20,20,20], 'max', 1, False)
@ -89,6 +93,7 @@ class TestArgReduceOp(unittest.TestCase):
check_reduce([5,5,5], 'min', 0, True)
check_reduce([5,5,5], 'min', 2, True)
check_reduce([5,5,5], 'min', 1, True)
check_reduce([5], 'min', 0, True)
check_reduce([20,20,20,20], 'max', 0, True)
check_reduce([20,20,20,20], 'max', 2, True)
check_reduce([20,20,20,20], 'max', 1, True)
@ -96,6 +101,7 @@ class TestArgReduceOp(unittest.TestCase):
check_reduce([5,5,5], 'min', 0, False)
check_reduce([5,5,5], 'min', 2, False)
check_reduce([5,5,5], 'min', 1, False)
check_reduce([5], 'min', 0, False)
check_reduce([20,20,20,20], 'max', 0, False)
check_reduce([20,20,20,20], 'max', 2, False)
check_reduce([20,20,20,20], 'max', 1, False)
@ -107,6 +113,7 @@ class TestArgReduceOp(unittest.TestCase):
check_reduce([5,5,5], 'min', 0, True, True)
check_reduce([5,5,5], 'min', 2, True, True)
check_reduce([5,5,5], 'min', 1, True, True)
check_reduce([5], 'min', 0, True)
check_reduce([20,20,20,20], 'max', 0, True, True)
check_reduce([20,20,20,20], 'max', 2, True, True)
check_reduce([20,20,20,20], 'max', 1, True, True)
@ -114,9 +121,10 @@ class TestArgReduceOp(unittest.TestCase):
check_reduce([5,5], 'min', 0, False, True)
check_reduce([5,5,5], 'min', 2, False, True)
check_reduce([5,5,5], 'min', 1, False, True)
check_reduce([5], 'min', 0, False)
check_reduce([20,20,20,20], 'max', 0, False, True)
check_reduce([20,20,20,20], 'max', 2, False, True)
check_reduce([20,20,20,20], 'max', 1, False, True)
check_reduce([20,20,20,20], 'max', 3, False, True)
if __name__ == "__main__":
unittest.main()
unittest.main()

View File

@ -7,6 +7,37 @@ import unittest
import jittor as jt
import numpy as np
def concat2(arr, dim):
    '''Concat a list of jt Vars at a specific dimension, built on setitem.

    * [in] arr: input var list for concat (must not be empty)
    * [in] dim: concat which dim; a negative dim counts from the last dim
    * [out] out: concat result

    Example::

        concat2([jt.array([[1],[2]]), jt.array([[2],[2]])], dim=1)
        # return [[1,2],[2,2]]
    '''
    # TODO: low performance when concat lots of vars
    if len(arr) == 0:
        raise ValueError("concat2 needs at least one var to concat")
    if dim < 0: dim += len(arr[0].shape)
    total_dim = 0
    for a in arr:
        total_dim += a.shape[dim]
    # the output takes its remaining dims and dtype from the last var
    ref = arr[-1]
    shape = list(ref.shape)
    shape[dim] = total_dim
    s = jt.empty(shape, ref.dtype)
    slices = [slice(None)]*len(ref.shape)
    cdim = 0
    for a in arr:
        # write each var into its own range along dim
        slices[dim] = slice(cdim, cdim+a.shape[dim])
        s = s.setitem(tuple(slices), a)
        cdim += a.shape[dim]
    return s
class TestConcatOp(unittest.TestCase):
def test_concat_op(self):
@ -18,7 +49,117 @@ class TestConcatOp(unittest.TestCase):
check([jt.array(np.array(range(24))).reshape((1,2,3,4)), jt.array(np.array(range(24))).reshape((1,2,3,4))])
check([jt.array(np.array(range(120))).reshape((5,2,3,4)), jt.array(np.array(range(24))).reshape((1,2,3,4))])
check([jt.array(np.array(range(5))).reshape((5,1)), jt.array(np.array(range(1))).reshape((1,1))])
print('concat success...')
print('concat success...')
# Measures the bandwidth of jt.contrib.concat (forward and backward) for
# every axis of a 64^4 tensor and several split sizes. Needs CUDA, so it is
# skipped on machines without it instead of failing inside the flag scope.
@unittest.skipIf(not jt.has_cuda, "No cuda")
@jt.flag_scope(use_cuda = 1)
def test_concat_perf(self):
    def check(dim, size, backward=False):
        n = 64
        a = jt.random((n,n,n,n))
        a.sync()
        m = n // size
        arr = []
        for i in range(m):
            # take the i-th chunk of width `size` along axis `dim`
            arr.append(a[(slice(None),)*dim + (slice(i*size,i*size+size),)])
        b = jt.contrib.concat(arr, dim)
        if backward:
            loss = b * a
            b = jt.grad(loss, a)
        with jt.profile_scope(1, 0) as rep:
            b.sync()
        # print(rep)
        # rep[0] is the header row; sum the TotalTime column of the rest
        i = rep[0].index("TotalTime")
        stime = 0
        for r in rep[1:]:
            stime += float(r[i])
        bw = 4*64**4*2*2 / stime
        # sizeof(float) * numel * (split and concat) * (read and write)
        print(f"{dim} {size} {stime/1e6}ms, {bw}GB/s")
        return bw
    ndim = 4
    splits = [1, 2, 4, 8, 16, 32, 64]
    m = len(splits)
    result = np.zeros((ndim, m))
    result_back = np.zeros((ndim, m))
    for i in range(ndim):
        for j in range(m):
            result[i,j] = check(i, splits[j])
            result_back[i,j] = check(i, splits[j], True)
    print(result.T)
    print(result_back.T)
    '''
[[ 17.02802497 17.12933081 17.10814418 15.49217942]
[ 33.10922467 33.01865886 33.08940182 30.24637466]
[ 62.27219795 62.06702029 61.90039457 58.68727009]
[112.31933307 111.89659519 111.02357161 108.98520165]
[187.24806534 190.68837367 186.73965711 186.32242015]
[280.28594579 278.94498734 284.42015302 284.98722929]
[387.03887468 386.14916854 386.47551229 385.28621521]]
[[ 5.04141217 4.55677858 4.55677363 3.79321142]
[ 9.05243799 8.99777599 8.96021333 7.49345194]
[ 17.45032635 17.36882645 17.14316909 14.98928307]
[ 35.60450372 35.55333375 35.32826879 32.00750909]
[ 61.72854251 62.285231 61.64460882 58.17541776]
[ 97.44981525 96.79104909 95.38118155 95.09154931]
[135.11495888 134.60444658 135.41807381 135.38139881]]
    '''
# Same measurement as the jt.contrib.concat benchmark, but for the
# setitem-based concat2 helper. Needs CUDA, so it is skipped on machines
# without it instead of failing inside the flag scope.
@unittest.skipIf(not jt.has_cuda, "No cuda")
@jt.flag_scope(use_cuda = 1)
def test_concat2_perf(self):
    def check(dim, size, backward=False):
        n = 64
        a = jt.random((n,n,n,n))
        a.sync()
        m = n // size
        arr = []
        for i in range(m):
            # take the i-th chunk of width `size` along axis `dim`
            arr.append(a.getitem((slice(None),)*dim + (slice(i*size,i*size+size),)))
        b = concat2(arr, dim)
        if backward:
            loss = b * a
            b = jt.grad(loss, a)
        with jt.profile_scope(1, 0) as rep:
            b.sync()
        # print(rep)
        # rep[0] is the header row; sum the TotalTime column of the rest
        i = rep[0].index("TotalTime")
        stime = 0
        for r in rep[1:]:
            stime += float(r[i])
        bw = 4*64**4*2*2 / stime
        # sizeof(float) * numel * (split and concat) * (read and write)
        print(f"{dim} {size} {stime/1e6}ms, {bw}GB/s")
        return bw
    ndim = 4
    splits = [1, 2, 4, 8, 16, 32, 64]
    m = len(splits)
    result = np.zeros((ndim, m))
    result_back = np.zeros((ndim, m))
    for i in range(ndim):
        for j in range(m):
            result[i,j] = check(i, splits[j])
            result_back[i,j] = check(i, splits[j], True)
    print(result.T)
    print(result_back.T)
    '''
[[ 15.59142118 15.8001291 15.77589713 11.79319714]
[ 31.33130734 31.2476813 31.20394782 23.19700034]
[ 57.90763098 57.71203221 58.02228419 45.60297828]
[104.20428796 104.08291412 104.18568373 91.648383 ]
[175.21896606 175.44422637 176.57915576 168.33344684]
[264.35929995 267.63202466 262.92687504 268.41854563]
[352.36998687 355.89200025 360.95753527 361.34916742]]
[[ 3.39802237 3.42782551 3.43126375 2.85884566]
[ 7.12993628 7.11445323 7.11482319 5.90134142]
[ 15.13540229 15.11031669 15.12954432 12.76302703]
[ 28.08930928 28.09445985 28.01005224 25.43536254]
[ 49.58246623 49.70843778 49.49253912 48.07459389]
[ 80.3745414 80.85044884 79.74203591 80.97114412]
[117.14450249 119.22320442 119.2380328 119.63622556]]
    '''
if __name__ == "__main__":
unittest.main()

View File

@ -9,7 +9,7 @@
import unittest
import jittor as jt
import numpy as np
from jittor.test.test_grad import ngrad
class TestSlice(unittest.TestCase):
def test_slice_bool(self):
@ -17,7 +17,122 @@ class TestSlice(unittest.TestCase):
a[1] = True
a[2] = 1
assert a.dtype == "bool"
print(a)
a.sync()
def test_var_slices(self):
    # Checks the text jt.core._print_var_slice logs for a range of index
    # expressions: ints, slices, index lists, None and Ellipsis.
    def check(slices, msg):
        with jt.log_capture_scope() as logs:
            jt.core._print_var_slice(slices)
        printed = logs[0]['msg']
        assert printed == msg, printed

    # (index expression, expected log message), checked in order
    cases = [
        ((1), "[1,]"),
        (([[0],[1]],slice(None),[1,2],1), "[int32[2,1,],::,int32[2,],1,]"),
        ((slice(None),slice(None),slice(None),slice(None)), "[::,::,::,::,]"),
        (([0,1],[0,1],[0,1],[0,1]), "[int32[2,],int32[2,],int32[2,],int32[2,],]"),
        (([0,1],-2,slice(None),[0,1]), "[int32[2,],-2,::,int32[2,],]"),
        (([0,1],slice(1,2,2),[1,2],1), "[int32[2,],1:2:2,int32[2,],1,]"),
        (([0,1],slice(None),[1,2],1), "[int32[2,],::,int32[2,],1,]"),
        ((slice(1,None,2),slice(-1,None,2),[1,2],-4), "[1::2,-1::2,int32[2,],-4,]"),
        (0, "[0,]"),
        (10, "[10,]"),
        (-10, "[-10,]"),
        (1, "[1,]"),
        ((1,slice(None),2), "[1,::,2,]"),
        ((-2,slice(None),2,slice(1,9,2)), "[-2,::,2,1:9:2,]"),
        ((None,1,None,2,None), "[-,1,-,2,-,]"),
        ((...,1,...,2,...), "[...,1,...,2,...,]"),
    ]
    for slices, msg in cases:
        check(slices, msg)
# Compares Var.getitem / Var.setitem with numpy indexing on random data and
# checks the values the getitem op logs (i_to_vs, i_to_o, o_shape).
# Runs with use_cuda=1 and is skipped when CUDA is not available.
@unittest.skipIf(not jt.has_cuda, "No cuda")
@jt.flag_scope(use_cuda=1)
def test_getitem(self):
# check(shape, slices, ...): index a random var of `shape` with `slices`.
# The three string arguments are expected substrings of the last captured
# log message; the empty default matches any logged value.
def check(shape, slices, i_to_vs="", i_to_o="", o_shape=""):
# print(slices)
x = jt.random(shape)
with jt.log_capture_scope(log_vprefix="getitem=1000") as logs:
a = x.getitem(slices)
a.sync()
# numpy reference result; a 0-d numpy result is compared against shape (1,)
b = x.data[slices]
bshape = b.shape if len(b.shape) else (1,)
assert a.shape == bshape, (a.shape, bshape)
s = logs[-1]['msg']
assert "i_to_vs: "+i_to_vs in s
assert "i_to_o: "+i_to_o in s
assert "o_shape: "+o_shape in s
aa = a.numpy()
assert (aa==b).all(), (aa, b)
# setitem with a value of exactly the indexed shape
y = x.numpy()
v = jt.random(a.shape)
z = x.setitem(slices, v)
y[slices] = v.data
assert (z.data==y).all(), (z.data, y, v.data, x.data)
# test_setitem broadcast
# every bit set in `mask` collapses that axis of the value to size 1
adim = len(a.shape)
for mask in range(1<<adim):
new_shape = list(a.shape)
for i in range(adim):
if (mask>>i)&1:
new_shape[i] = 1
y = x.numpy()
v = jt.random(new_shape)
z = x.setitem(slices, v)
y[slices] = v.data
assert (z.data==y).all(), (z.data, y, v.data, x.data)
# TODO: when slice same row/col many times and assign value, numpy will retain the last value but we assign their sum. eg: check([3,3,3,3], ([[0,1,1]],slice(None),[[1],[2],[0]],1))
# arguments: shape, slices, expected i_to_vs, expected i_to_o, expected o_shape
check([3], (1), "[0,]", "[-1,]", "[]")
check([3,3,3,3], ([[0],[1]],slice(None),[1,2],1), "[0,-1,2,3,]", "[-1,2,-1,-1,]", "[2,2,3,]")
check([3,3,3,3], (slice(None),slice(None),slice(None),slice(None)), "[-1,-2,-2,-2,]", "[0,0,0,0,]", "[81,]")
check([3,3,3,3], ([0,1],[0,1],[0,1],[0,1]), "[0,1,2,3,]", "[-1,-1,-1,-1,]", "[2,]")
check([3,3,3,3], ([0,1],-2,slice(None),[0,1]), "[0,1,-1,3,]", "[-1,-1,1,-1,]", "[2,3,]")
check([3,3,3,3], ([0,1],slice(1,2,2),[1,2],1), "[0,1,2,3,]", "[-1,1,-1,-1,]", "[2,1,]")
check([3,3,3,3], ([0,1],slice(None),[1,2],1), "[0,-1,2,3,]", "[-1,1,-1,-1,]", "[2,3,]")
check([3,3,3,3], (slice(1,10,1),...,slice(2,None,-1)), "[0,-1,-2,2,]", "[0,1,1,2,]", "[2,9,3,]")
check([10,10,10,10], (slice(1,None,2),slice(-1,None,2),[1,2],-4), "[0,1,2,3,]", "[0,1,-1,-1,]", "")
check([20], 0, "[0,]", "[-1,]", "[]")
check([20], 10, "[0,]", "[-1,]", "[]")
check([20], -10, "[0,]", "[-1,]", "[]")
check([10,10,10,10], 1, "[0,-1,-2,-2,]", "[-1,0,0,0,]", "[1000,]")
check([10,10,10,10], (1,slice(None),2), "[0,-1,2,-1,]", "[-1,0,-1,1,]", "")
check([10,10,10,10], (-2,slice(None),2,slice(1,9,2)), "[0,-1,2,3,]", "[-1,0,-1,1,]")
def test_getitem_grad(self):
    # Compares jt.grad through getitem and setitem with the numerical
    # gradient computed by ngrad.

    # --- gradient of a weighted getitem with respect to the source ---
    shape = (10,)
    slices = slice(2,4)
    src = jt.random(shape)
    picked = src.getitem(slices)
    weight = jt.random(picked.shape)
    d_src = jt.grad(picked*weight, src)
    def getitem_loss(vars):
        return (vars[0][slices]*weight.data).sum()
    _, np_grad = ngrad(getitem_loss, [src.numpy()], 1e-3)
    assert np.allclose(d_src.numpy(), np_grad, atol = 1e-3), (d_src.numpy(), np_grad)

    # --- gradient of a weighted setitem with respect to source and value ---
    shape = (10,)
    slices = slice(2,4)
    base = jt.random(shape)
    # getitem is only used to learn the shape of the value to assign
    value = jt.random(base.getitem(slices).shape)
    merged = base.setitem(slices, value)
    weight2 = jt.random(merged.shape)
    d_base,d_value = jt.grad(merged*weight2, [base,value])
    def setitem_loss(vars):
        a, b = vars
        a = a.copy()
        a[slices] = b
        return (a*weight2.data).sum()
    _, (nd_base, nd_value) = ngrad(setitem_loss, [base.numpy(), value.numpy()], 1e-3)
    assert np.allclose(d_base.numpy(), nd_base, atol = 1e-3)
    assert np.allclose(d_value.numpy(), nd_value, atol = 1e-3)
if __name__ == "__main__":

View File

@ -12,6 +12,7 @@ import math
import numpy as np
import warnings
from collections.abc import Sequence, Mapping
import jittor as jt
def crop(img, top, left, height, width):
'''
@ -215,6 +216,90 @@ def to_tensor(img):
return np.array(img).transpose((2,0,1)) / np.float32(255)
return img
def to_pil_image(pic, mode=None):
    """Convert a tensor or an ndarray to PIL Image.

    Args:
        pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.
            A jt.Var is read as CHW (2-D input gets a leading channel axis),
            an ndarray as HWC (2-D input gets a trailing channel axis).
        mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).

    .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes

    Returns:
        PIL Image: Image converted to PIL Image.

    Raises:
        TypeError: if pic is neither a jt.Var nor an ndarray, or if no mode
            can be chosen for the input dtype.
        ValueError: if pic is not 2/3 dimensional, or if mode is not
            permitted for the number of channels.
    """
    if not(isinstance(pic, jt.Var) or isinstance(pic, np.ndarray)):
        raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(type(pic)))

    elif isinstance(pic, jt.Var):
        if pic.ndim not in {2, 3}:
            # report through ndim, the same attribute the check itself uses
            raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndim))

        elif pic.ndim == 2:
            # if 2D image, add channel dimension (CHW)
            pic = pic.unsqueeze(0)

    elif isinstance(pic, np.ndarray):
        if pic.ndim not in {2, 3}:
            raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndim))

        elif pic.ndim == 2:
            # if 2D image, add channel dimension (HWC)
            pic = np.expand_dims(pic, 2)

    npimg = pic
    if isinstance(pic, jt.Var):
        if 'float' in str(pic.dtype) and mode != 'F':
            # float vars are scaled by 255 and stored as uint8 unless mode 'F' is requested
            pic = pic.mul(255).uint8()
        # CHW -> HWC
        npimg = np.transpose(pic.numpy(), (1, 2, 0))

    if not isinstance(npimg, np.ndarray):
        raise TypeError('Input pic must be a jt.Var or NumPy ndarray, ' +
                        'not {}'.format(type(npimg)))

    if npimg.shape[2] == 1:
        expected_mode = None
        npimg = npimg[:, :, 0]
        if npimg.dtype == np.uint8:
            expected_mode = 'L'
        elif npimg.dtype == np.int16:
            expected_mode = 'I;16'
        elif npimg.dtype == np.int32:
            expected_mode = 'I'
        elif npimg.dtype == np.float32:
            expected_mode = 'F'
        if mode is not None and mode != expected_mode:
            # report the dtype of the actual input, not the np.dtype class
            raise ValueError("Incorrect mode ({}) supplied for input type {}. Should be {}"
                             .format(mode, npimg.dtype, expected_mode))
        mode = expected_mode

    elif npimg.shape[2] == 2:
        permitted_2_channel_modes = ['LA']
        if mode is not None and mode not in permitted_2_channel_modes:
            raise ValueError("Only modes {} are supported for 2D inputs".format(permitted_2_channel_modes))

        if mode is None and npimg.dtype == np.uint8:
            mode = 'LA'

    elif npimg.shape[2] == 4:
        permitted_4_channel_modes = ['RGBA', 'CMYK', 'RGBX']
        if mode is not None and mode not in permitted_4_channel_modes:
            raise ValueError("Only modes {} are supported for 4D inputs".format(permitted_4_channel_modes))

        if mode is None and npimg.dtype == np.uint8:
            mode = 'RGBA'
    else:
        permitted_3_channel_modes = ['RGB', 'YCbCr', 'HSV']
        if mode is not None and mode not in permitted_3_channel_modes:
            raise ValueError("Only modes {} are supported for 3D inputs".format(permitted_3_channel_modes))
        if mode is None and npimg.dtype == np.uint8:
            mode = 'RGB'

    if mode is None:
        raise TypeError('Input type {} is not supported'.format(npimg.dtype))

    return Image.fromarray(npimg, mode=mode)
class ImageNormalize:
'''
Class for normalizing the input image.
@ -328,4 +413,53 @@ class RandomCrop:
top = np.random.randint(0,height-self.size[0]+1)
left = np.random.randint(0,width-self.size[1]+1)
return crop(img, top, left, self.size[0], self.size[1])
class Lambda:
    """Wrap an arbitrary callable so it can be used as a transform.

    Args:
        lambd (function): Callable applied to the image in __call__.
    """
    def __init__(self, lambd):
        assert callable(lambd), "{!r} object is not callable".format(type(lambd).__name__)
        self.lambd = lambd

    def __call__(self, img):
        transform = self.lambd
        return transform(img)

    def __repr__(self):
        return '{}()'.format(self.__class__.__name__)
class ToTensor:
    """Transform object that forwards its input to to_tensor."""

    def __call__(self, pic):
        """
        Args:
            pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        converted = to_tensor(pic)
        return converted

    def __repr__(self):
        return '{}()'.format(self.__class__.__name__)
class ToPILImage(object):
    """Transform object that forwards its input to to_pil_image.

    Args:
        mode: mode passed on to to_pil_image; None leaves the choice to it.
    """
    def __init__(self, mode=None):
        self.mode = mode

    def __call__(self, pic):
        """
        Args:
            pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.

        Returns:
            PIL Image: Image converted to PIL Image.
        """
        return to_pil_image(pic, self.mode)

    def __repr__(self):
        # the mode is only shown when one was configured
        pieces = [self.__class__.__name__, '(']
        if self.mode is not None:
            pieces.append('mode={0}'.format(self.mode))
        pieces.append(')')
        return ''.join(pieces)

View File

@ -222,6 +222,16 @@ class Hook:
pickle.dump(ps, f)
LOG.i(f"save params ok")
def hook_function(self, func):
    '''Wrap func so that each call is recorded through self.record.

    After func returns, three entries are recorded under the names
    "<func name>.args", "<func name>.kw" and "<func name>.ret", holding the
    positional arguments, the keyword arguments and the return value.

    Returns:
        The wrapping function; it returns whatever func returns.
    '''
    import functools
    name = func.__name__
    # keep __name__ / __doc__ of the wrapped function, so hooking an already
    # hooked function still records under the original name
    @functools.wraps(func)
    def new_func(*args, **kw):
        ret = func(*args, **kw)
        self.record(name+".args", args)
        self.record(name+".kw", kw)
        self.record(name+".ret", ret)
        return ret
    return new_func
def hook_module(self, mod, mod_name=""):
if os.environ.get("use_auto_diff", '1') == '0':

View File

@ -74,8 +74,12 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
vector<Node*> bfs_q;
bfs_q.reserve(vars.size());
int start_var_num = 0;
{
while (1) {
op_num = 0;
start_var_num = 0;
bfs_q.clear();
// get all nodes need to be executed
int need_opt = 0;
auto t = ++Node::tflag_count;
for (Var* v : vars)
if (!v->is_finished() && v->tflag != t) {
@ -89,6 +93,7 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
for (auto i : node->_inputs)
if (i.node->tflag != t && !i.node->is_finished()) {
i.node->tflag = t;
need_opt += i.node->flags.get(NodeFlags::_has_gopt);
bfs_q.push_back(i.node);
}
// this var has been fetched
@ -99,11 +104,19 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
!n.node->is_finished() &&
n.node->flags.get(NodeFlags::_fetch)) {
n.node->tflag = t;
need_opt += n.node->flags.get(NodeFlags::_has_gopt);
bfs_q.push_back(n.node);
}
}
}
}
if (!need_opt) break;
for (Node* n : bfs_q) {
if (n->flags.get(NodeFlags::_has_gopt)) {
n->op()->graph_optimize();
n->flags.set(NodeFlags::_has_gopt, 0);
}
}
}
auto tt = Node::tflag_count;
vector<Op*> ops;

View File

@ -55,9 +55,6 @@ vector<VarPtr> grad(Var* loss, vector<Var*> targets) {
return false;
if (node->is_stop_grad())
return false;
// int value has zero grad
if (node->is_var())
return node->var()->is_float();
return true;
});
LOGvv << "Size of grad nodes:" << gnodes.size();

View File

@ -44,6 +44,11 @@ struct JitKey {
explicit hex1(uint data) : data(data) {}
};
// Tag type wrapping a signed int; it is streamed into a JitKey by its own
// operator<< overload.
struct shex1 {
    int data;
    // explicit: a plain int never converts to shex1 implicitly
    explicit shex1(int value) : data(value) {}
};
struct hex2 {
uint data;
explicit hex2(uint data) : data(data) {}
@ -90,6 +95,13 @@ inline JK& operator<<(JK& jk, const JK::hex1& h) {
return jk << (char)((data<10) ? data+'0' : data-10+'a');
}
// Append a signed value to the key: negative values are written as '-'
// followed by hex1 of the negated value, others as hex1 of the value.
inline JK& operator<<(JK& jk, const JK::shex1& h) {
    if (h.data >= 0)
        return jk << JK::hex1(h.data);
    return jk << '-' << JK::hex1(-h.data);
}
inline JK& operator<<(JK& jk, const JK::hex2& h) {
return jk << JK::hex1(h.data>>4) << JK::hex1(h.data);
}

View File

@ -212,14 +212,13 @@ void* SFRLAllocator::alloc(size_t size, size_t& allocation) {
if (block == nullptr) {
free_all_sfrl_allocators();
size_t alloc_size = allocation_size(size);
void* ptr = underlying->alloc(alloc_size, allocation);
if (ptr == nullptr) {
void* ptr = nullptr;
try {
ptr = underlying->alloc(alloc_size, allocation);
} catch (...) {
unused_memory -= large_blocks.free_all_cached_blocks(underlying);
unused_memory -= small_blocks.free_all_cached_blocks(underlying);
void* ptr = underlying->alloc(alloc_size, allocation);
if (ptr == nullptr) {
return nullptr;
}
ptr = underlying->alloc(alloc_size, allocation);
}
block = new CachingBlock(alloc_size, blocks, ptr);
} else {

View File

@ -22,8 +22,33 @@ static inline int lzcnt(int64 v) {
// One slice expression (start:stop:step). In fill(), mask bit 1 is treated
// as "start not given" and mask bit 2 as "stop not given".
struct Slice {
    int64 start, stop, step, mask;
    // Resolve omitted and negative bounds against a dimension of length
    // `size`, then clear mask. Afterwards [start, stop) walked with `step`
    // addresses valid elements; for a negative step, stop == -1 means
    // "run down to and including index 0".
    inline void fill(int64 size) {
        if (step>0) {
            if (mask&2)
                stop = size;
            else if (stop<0)
                stop += size;
            else
                stop = std::min(size, stop);
        } else {
            if (mask&1) start = size-1;
            if (mask&2)
                stop = -1;
            else if (stop<0)
                // stop is exclusive: a stop before the first element must
                // clamp to -1, not 0, otherwise element 0 is dropped
                stop = std::max((int64)-1, stop+size);
        }
        if (start<0) start += size;
        mask = 0;
        ASSERT(start>=0 && stop>=-1 && start<size && stop<=size)
            << "slice overflow:" << start << stop << step;
    }
};
// return a / ceil_to_2_pow(b)
// Implemented as a right shift of a by (64 - lzcnt(b)) bits.
// NOTE(review): the exact rounding depends on lzcnt, defined elsewhere in
// this file - confirm.
inline uint64 fast_div(uint64 a, uint64 b) {
    const int shift = 64 - lzcnt(b);
    return a >> shift;
}
// @pyjt(NanoVector)
struct NanoVector {
int64 data=0, offset=0;
@ -108,22 +133,13 @@ struct NanoVector {
// @pyjt(__map_getitem__)
inline NanoVector slice(Slice slice) {
if (slice.step>0) {
if (slice.mask&2) slice.stop = size();
} else {
if (slice.mask&1) slice.start = size()-1;
if (slice.mask&2) slice.stop = 0;
}
if (slice.start<0) slice.start += size();
if (slice.stop<0) slice.stop += size();
ASSERT(slice.start>=0 && slice.stop>=0 && slice.start<size() && slice.stop<=size())
<< "slice overflow:" << slice.start << slice.stop << slice.step;
slice.fill(size());
NanoVector v;
if (slice.step>0) {
for (int i=slice.start; i<slice.stop; i+=slice.step)
v.push_back(this->operator[](i));
} else {
for (int i=slice.start; i>=slice.stop; i+=slice.step)
for (int i=slice.start; i>slice.stop; i+=slice.step)
v.push_back(this->operator[](i));
}
return v;
@ -166,14 +182,14 @@ struct NanoVector {
struct Iter {
const NanoVector* self;
int index;
int64 operator*() { return self->at(index); }
Iter& operator++() { index++; return *this; }
Iter operator+(int i) { return {self, i+index}; }
bool operator!=(Iter& other) { return index != other.index; }
inline int64 operator*() { return self->at(index); }
inline Iter& operator++() { index++; return *this; }
inline Iter operator+(int i) { return {self, i+index}; }
inline bool operator!=(Iter& other) { return index != other.index; }
};
Iter begin() { return {this, 0}; }
Iter end() { return {this, size()}; }
inline Iter begin() { return {this, 0}; }
inline Iter end() { return {this, size()}; }
inline void pop_back() { offset--; data &= (1ll<<total_bits())-1; }
inline void push_back(Iter s, Iter t) {

55
src/misc/stack_vector.h Normal file
View File

@ -0,0 +1,55 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. All Rights Reserved.
// Authors: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "common.h"
#include "misc/nano_vector.h"
namespace jittor {
// Small fixed-capacity vector stored inline (no heap allocation).
// push_back() silently ignores elements once N are stored; check() treats a
// completely full vector (n == N) as overflow, so at most N-1 elements can be
// stored and still pass check().
template<class T=int64, int N=10>
struct StackVector {
    int n;
    T a[N+1];

    inline T& front() { return a[0]; }
    inline T& back() { return a[n-1]; }
    inline int size() const { return n;}
    // Elements are left uninitialized; only the logical length is set.
    inline StackVector(int n=0) : n(n) {}

    struct Iter {
        const StackVector<T,N>* self;
        int index;
        // StackVector has no at() member, so read the storage directly.
        inline T operator*() { return self->a[index]; }
        inline Iter& operator++() { index++; return *this; }
        inline Iter operator+(int i) { return {self, i+index}; }
        // Takes a const reference so temporaries (e.g. end()) can be compared.
        inline bool operator!=(const Iter& other) const { return index != other.index; }
    };
    inline Iter begin() { return {this, 0}; }
    inline Iter end() { return {this, size()}; }

    inline T& operator[](int i) { return a[i]; }
    inline void pop_back() { n--; }
    // Saturating append: dropped without error when the vector is full.
    inline void push_back(T v) { if (n<N) a[n++] = v; }
    // Fails when the vector is full, i.e. when a push_back may have been dropped.
    inline void check() { ASSERT(n<N); }

    // Copy the elements into a NanoVector, after verifying no overflow occurred.
    inline NanoVector to_nano_vector() {
        check();
        NanoVector nv;
        for (int i=0; i<n; i++)
            nv.push_back_check_overflow(a[i]);
        return nv;
    }
};
// Stream a StackVector as "[e0,e1,...,]" (every element is followed by a comma).
template<class T, int N>
inline std::ostream& operator<<(std::ostream& os, const StackVector<T,N>& v) {
    os << '[';
    int idx = 0;
    while (idx < v.n) {
        os << v.a[idx] << ',';
        ++idx;
    }
    os << ']';
    return os;
}
} // jittor

View File

@ -48,6 +48,8 @@ struct NodeFlags {
_op_type=_n+4, _op_type_nbits=2,
// bit6: backprop grad at ones
_grads=_n+6,
// bit7: has graph optimize
_has_gopt=_n+7,
};
inline void set(Flags f, int a=1, int nbits=1) {

View File

@ -75,6 +75,12 @@ bool Op::shape_infered() {
return true;
}
// Default implementations of the overridable Op hooks: each one does nothing.
// compile_optimize receives the generated source by reference and leaves it
// untouched.
void Op::compile_optimize(string& src) {}
void Op::infer_shape() {}
void Op::run() {}
void Op::jit_prepare() {}
void Op::graph_optimize() {}
string Op::name_ex() const {
string a=name();

View File

@ -38,9 +38,9 @@ struct Op : Node {
virtual VarPtr grad(Var* out, Var* dout, Var* v, int v_index);
virtual void grads(Var** douts, VarPtr* dins);
virtual void infer_shape() {}
virtual void run() {};
virtual void jit_prepare() {};
virtual void infer_shape();
virtual void run();
virtual void jit_prepare();
virtual void do_jit_prepare();
virtual const char* name() const = 0;
virtual void statistics(uint64_t& in, uint64_t& out, uint64_t& compute);
@ -48,6 +48,8 @@ struct Op : Node {
virtual void do_run_after_prepare();
virtual void do_run();
virtual VarPtr duplicate();
virtual void compile_optimize(string& src);
virtual void graph_optimize();
void jit_run();
string name_ex() const;

View File

@ -416,6 +416,7 @@ string precompile(unordered_map<string,string> defs, string src, unordered_map<s
vector<string> args;
size_t l = k+1;
if (expr == "for" || expr == "if" || expr == "expand_macro" ||
expr == "is_def" ||
(k<src.size() && src[k]=='(')) {
ASSERT(src[k] == '(');
comma.push_back(k);
@ -447,6 +448,9 @@ string precompile(unordered_map<string,string> defs, string src, unordered_map<s
if (args.size() >= 5) {
step = OpCompiler::eval(vs, defs);
vs = args[4];
for (int i=5; i<args.size(); i++) {
vs += "," + args[i];
}
}
auto new_defs = defs;
LOGvvv << "Expand for" << expr >> "[" >> vil >> "," >> vir >> "," >> step >> "]";
@ -473,6 +477,18 @@ string precompile(unordered_map<string,string> defs, string src, unordered_map<s
i = l-1;
continue;
} else
if (expr == "is_def") {
ASSERT(args.size()==1)
<< "Jit error: is_def wrong arguments.";
string vdef = args[0];
vdef = precompile(defs, vdef, macros);
if (defs.count(vdef) || macros.count(vdef))
new_src += "1";
else
new_src += "0";
i = l-1;
continue;
} else
if (expr == "expand_macro") {
// syntax: @expand_macro(macro, args)
// ij k l
@ -983,8 +999,9 @@ jit_op_entry_t OpCompiler::do_compile(Op* op) {
src_after_passes = tm.tune();
src = &src_after_passes;
}
op->compile_optimize(*src);
auto ret = oc.compile(op->get_jit_key(), *src);
return ret;
}
}
}

View File

@ -149,6 +149,8 @@ void ArgReduceOp::infer_shape() {
shape.push_back(x->shape[i]);
}
}
if (shape.size() == 0)
shape.push_back(1);
y->set_shape(shape);
y_key->set_shape(shape);
}
@ -205,4 +207,4 @@ void ArgReduceOp::jit_run() {
}
#endif // JIT
} // jittor
} // jittor

View File

@ -162,7 +162,7 @@ void BinaryOp::infer_shape() {
// -1 1 need b
// has 1, b, both 1, not b, 0, error
if ((xshape == 1 || yshape == 1) && (xshape != yshape)) {
CHECK(xshape && yshape) << "Shape can not broadcast to 0.";
// CHECK(xshape && yshape) << "Shape can not broadcast to 0.";
need_broadcast = true;
continue;
}
@ -198,4 +198,4 @@ void BinaryOp::jit_run() {
}
#endif // JIT
} // jittor
} // jittor

View File

@ -5,12 +5,16 @@
// ***************************************************************
#include "var.h"
#include "ops/candidate_op.h"
#ifdef JIT_cuda
#include "executor.h"
#endif
namespace jittor {
#ifndef JIT
// Build a candidate op over `x`; `fail_cond` is moved into the op and `dtype`
// is the dtype of the single output `y`.
CandidateOp::CandidateOp(Var* x, string&& fail_cond, NanoString dtype) : x(x), fail_cond(move(fail_cond)) {
// Mark the op as runnable on both CPU and CUDA.
flags.set(NodeFlags::_cpu);
flags.set(NodeFlags::_cuda);
// Output length is only known after running (jit_run calls y->set_shape).
flags.set(NodeFlags::_vary_shape);
y = create_output(nullptr, dtype);
}
@ -27,10 +31,76 @@ void CandidateOp::jit_prepare() {
}
#else // JIT
#ifdef JIT_cuda
// CUDA kernel for CandidateOp.
// Walks the first axis in order. For every index i whose mask entry is still
// set, thread 0 appends i to yp, and all threads then clear the mask of each
// later index j for which the FUNC condition holds. The number of kept
// indices is written to np[0] by thread 0.
// The host code in this file launches it with a single block.
__global__ static void candidate_kernel(
@for(i, 0, XDIM, 1, index_t xshape@i, )
Tx* __restrict__ xp,
Ty* __restrict__ yp,
bool* __restrict__ maskp,
int* __restrict__ np
) {
int n=0;
int tid = threadIdx.x;
int tnum = blockDim.x;
// define cond stride
index_t xstride@{XDIM-1} = 1;
@for(i, XDIM-2, -1, -1, auto xstride@i = xstride@{i+1} * xshape@{i+1};)
// generate d-for loop
for (index_t i=0; i < xshape0; i++) {
// Barrier before reading maskp[i]: it may have been cleared by another
// thread during an earlier iteration.
__syncthreads();
if (!maskp[i]) continue;
// Only thread 0 maintains the output cursor n.
if (tid == 0) {
yp[n] = i;
n++;
}
// Threads split the remaining indices j > i with a stride of blockDim.x.
for (index_t j=i+1+tid; j < xshape0; j+=tnum) {
if (@FUNC) maskp[j] = 0;
}
}
if (tid == 0) {
np[0] = n;
}
}
// CUDA host side of CandidateOp: allocates scratch buffers, launches
// candidate_kernel with one block, reads back the number of kept indices and
// sets the output shape accordingly.
void CandidateOp::jit_run() {
auto* __restrict__ xp = x->ptr<Tx>();
// define cond shape
@for(i, 0, XDIM, index_t xshape@i = x->shape[@i];)
// define ys
auto* __restrict__ yp = y->ptr<Ty>();
// 4-byte device counter receiving the number of selected indices.
size_t n_allocation;
int* np = (int*)exe.allocator->alloc(4, n_allocation);
// One byte per first-axis entry; every byte starts as 1 (still a candidate).
size_t mask_allocation;
bool* maskp = (bool*)exe.allocator->alloc(xshape0, mask_allocation);
checkCudaErrors(cudaMemsetAsync(maskp, 1, xshape0));
// Single block; thread count is xshape0 clamped to [1, 1024].
candidate_kernel<<<1, std::max(1, std::min(1024, xshape0)) >>>(
@for(i, 0, XDIM, 1, xshape@i, )
xp,
yp,
maskp,
np
);
int n=0;
// checkCudaErrors(cudaDeviceSynchronize());
// Copy the count back to the host before sizing the output.
checkCudaErrors(cudaMemcpy(&n, np, 4, cudaMemcpyDefault));
y->set_shape({n});
exe.allocator->free(np, 4, n_allocation);
exe.allocator->free(maskp, xshape0, mask_allocation);
}
#else
void CandidateOp::jit_run() {
using namespace std;
auto* __restrict__ xp = x->ptr<Tx>();
// define cond shape
@for(i, 0, XDIM, index_t xshape@i = x->shape[@i];)
// define cond stride
index_t xstride@{XDIM-1} = 1;
@for(i, XDIM-2, -1, -1, auto xstride@i = xstride@{i+1} * xshape@{i+1};)
@ -56,6 +126,7 @@ void CandidateOp::jit_run() {
}
y->set_shape({n});
}
#endif
#endif // JIT
} // jittor
} // jittor

499
src/ops/getitem_op.cc Normal file
View File

@ -0,0 +1,499 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. All Rights Reserved.
// Authors: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include <cmath>
#include "var.h"
#include "ops/getitem_op.h"
#include "ops/op_register.h"
#ifdef JIT_cuda
#include <cuda_runtime.h>
#include <helper_cuda.h>
#endif
#ifndef JIT
#include "misc/stack_vector.h"
#include "opt/kernel_ir.h"
#ifdef HAS_CUDA
#include "misc/cuda_flags.h"
#endif
#endif
namespace jittor {
#ifndef JIT
static auto make_number = get_op_info("number")
.get_constructor<VarPtr, float, Var*>();
static auto make_setitem = get_op_info("setitem")
.get_constructor<VarPtr, Var*, VarSlices&&, Var*, NanoString>();
// Build a getitem op: takes ownership of `slices` and creates one output
// with the same dtype as `x`.
GetitemOp::GetitemOp(Var* x, VarSlices&& slices)
: vs(move(slices)) {
// Runnable on both CPU and CUDA.
flags.set(NodeFlags::_cpu);
flags.set(NodeFlags::_cuda);
// Flag bit "has graph optimize" (see NodeFlags::_has_gopt).
flags.set(NodeFlags::_has_gopt);
create_output(nullptr, x->dtype());
}
// Walk the input dimensions and the slice list together and compute:
//   i_to_vs[i]  : index into vs.slices that consumes input dim i, or -1 when
//                 the dim is taken whole (no slice, ellipsis, or a[::]).
//   i_to_o[i]   : index into out_shape produced by input dim i, or -1 when the
//                 dim produces no output dim of its own (int or var index).
//   out_shape   : user-visible result shape.
// Also sets first_oid_of_var (position in out_shape of the broadcast index-var
// dims, -1 if none) and var_dim (max rank over all index vars).
// Slice entries are normalized in place (negative ints wrapped, Slice::fill).
void GetitemOp::infer_slices(
    StackVector<>& __restrict__ i_to_vs,
    StackVector<>& __restrict__ i_to_o,
    StackVector<>& __restrict__ out_shape
) {
    auto in = inputs().front();
    auto in_shape = in->shape;
    auto nin = in_shape.size();
    i_to_vs.n = i_to_o.n = nin;
    out_shape.n = 0;

    int vid = 0;
    first_oid_of_var = -1;
    var_dim = 0;
    for (int i=0; i<nin; i++) {
        if (vid >= vs.n) {
            // i i i
            // | | |
            // v v v --> overflow
            // s s
            // No slice left for this dim: it is copied through unchanged.
            i_to_vs[i] = -1;
            i_to_o[i] = out_shape.size();
            out_shape.push_back(in_shape[i]);
            continue;
        }
        // Only bind the slice once vid is known to be in range.
        auto& s = vs.slices[vid];
        if (s.is_var()) {
            // i --> s ---> o
            //          + ---> o
            // var maybe multiple dims
            if (first_oid_of_var == -1) {
                // First index var: reserve var_dim output dims (initially 1)
                // that all index vars broadcast into.
                for (int k=0; k<vs.n; k++)
                    if (vs.slices[k].is_var())
                        var_dim = std::max(var_dim, vs.slices[k].var->shape.size());
                first_oid_of_var = out_shape.size();
                for (int j=0; j<var_dim; j++) {
                    out_shape.push_back(1);
                }
            }
            i_to_vs[i] = vid++;
            i_to_o[i] = -1;
            auto iv = s.var;
            auto iv_shape = iv->shape;
            auto niv = iv_shape.size();
            // Broadcast this var's shape, right-aligned, into the reserved dims.
            for (int j=0; j<niv; j++) {
                auto iv_shape_j = iv_shape[niv-j-1];
                auto& out_shape_j = out_shape[first_oid_of_var+var_dim-j-1];
                if (out_shape_j == 1)
                    out_shape_j = iv_shape_j;
                else
                    ASSERT(out_shape_j == iv_shape_j || out_shape_j < 0 || iv_shape_j < 0)
                        << out_shape_j << iv_shape_j << out_shape;
            }
        } else
        if (s.is_ellipsis()) {
            // The ellipsis absorbs every input dim not needed by later slices.
            auto remain_slice = vs.n-vid-1;
            auto remain_idims = nin-i;
            auto ellipsis_size = remain_idims - remain_slice;
            ASSERT(ellipsis_size>=0) << "NDims not match";
            for (int j=0; j<ellipsis_size; j++) {
                i_to_vs[i+j] = -1;
                i_to_o[i+j] = out_shape.size();
                out_shape.push_back(in_shape[i+j]);
            }
            vid ++;
            i += ellipsis_size-1;
        } else
        if (s.is_none()) {
            // None inserts a new axis and consumes no input dim: retry dim i.
            i--;
            out_shape.push_back(1);
            vid++;
            continue;
        } else
        if (s.is_int()) {
            i_to_vs[i] = vid++;
            i_to_o[i] = -1;
            auto in_shape_i = in_shape[i];
            auto& v = s.slice.start;
            if (v<0) v += in_shape_i;
            CHECK(v>=0 && v<in_shape_i) << "slice overflow, " << v << "not in [0,">>in_shape_i>>")";
        } else {
            // slice
            auto& slice = s.slice;
            auto in_shape_i = in_shape[i];
            auto out_shape_j = in_shape_i;
            if (slice.mask == 7) {
                // slice is a[::]
                // start, stop, step is not filled
                vid++;
                i_to_vs[i] = -1;
                i_to_o[i] = out_shape.size();
                out_shape.push_back(out_shape_j);
            } else {
                i_to_vs[i] = vid++;
                i_to_o[i] = out_shape.size();
                // in_shape_i <= 0 is passed through unresolved.
                if (in_shape_i > 0) {
                    slice.fill(in_shape_i);
                    if (std::abs(slice.step) <= 1)
                        out_shape_j = (slice.stop - slice.start) * slice.step;
                    else if (slice.step>0)
                        out_shape_j = (slice.stop - slice.start - 1) / slice.step + 1;
                    else
                        out_shape_j = (slice.start - slice.stop - 1) / -slice.step + 1;
                    // (int64)0 rather than 0l: int64 is not `long` everywhere.
                    out_shape_j = std::max((int64)0, out_shape_j);
                }
                out_shape.push_back(out_shape_j);
            }
        }
    }
    // None entries after the last input dim (e.g. a[:, None] on a 1-D input)
    // are never reached by the loop above; they still add a trailing axis.
    while (vid < vs.n && vs.slices[vid].is_none()) {
        out_shape.push_back(1);
        vid++;
    }
}
// Map the loops of a loop nest with extents `o_shape` onto CUDA thread/block
// indices, innermost loop first.
//   masks[i] : per-loop bitmask, one entry per dim of o_shape (layout below);
//              0 means the loop uses no thread/block index.
//   tdims[6] : launch extents in the order
//              {block.x, block.y, block.z, grid.x, grid.y, grid.z},
//              all initialized to 1.
void cuda_loop_schedule(NanoVector o_shape, int* masks, int* tdims) {
// bz by bx tz ty tx
// 5 4 3 2 1 0
// LOi: bitmask of used dims of loop i
// LOi bit 6: need for
// if need for, keep for range: for (int i@i=tid; tid<range; tid+=tnum)
// if not need for, replace range -> tnum, for -> int i@i = tid
// rtnum: remaining threads-per-block budget.
int rtnum = 1024;
// int max_tnum = {1024, 1024, 64, (1u<<31)-1, 65535, 65535};
int loop_id = (int)o_shape.size()-1;
int tid = 0;
int64 block_size = 1;
int thread_size = 1;
for (int i=0; i<6; i++) tdims[i] = 1;
// Phase 1: assign innermost loops to threadIdx.x/y/z while budget remains.
for (; tid<3 && loop_id>=0 && rtnum>1; tid++) {
int64 si = o_shape[loop_id];
int mask = 1<<tid;
// The z thread dimension is limited to 64.
if (tid==2) rtnum = std::min(64, rtnum);
if (si>rtnum*4) {
// need for, use tid(1<<i) and bx(8)
mask |= 8|(1<<6);
block_size = (si-1)/rtnum+1;
tdims[tid] = rtnum;
tdims[3] = block_size;
// blockIdx.x is now taken; after the loop's tid++ phase 2 starts at by.
tid = 3;
thread_size *= rtnum;
rtnum = 0;
} else
if (si>rtnum) {
// Extent exceeds the budget only slightly: keep the for loop, threads only.
mask |= (1<<6);
thread_size *= rtnum;
tdims[tid] = rtnum;
rtnum = 0;
} else {
// Extent fits: one thread per iteration, shrink the remaining budget.
rtnum = rtnum / std::max(si, (int64)1);
thread_size *= si;
tdims[tid] = si;
}
masks[loop_id] = mask;
loop_id --;
}
int64 total_size = (int64)block_size*thread_size;
if (tid<3) tid=3;
// Phase 2: assign further loops to blockIdx.x/y/z until enough total
// parallelism (256*1024) has been reached.
for (; tid<6 && loop_id>=0 && total_size<(256*1024); tid++) {
int64 si = o_shape[loop_id];
int mask = 1<<tid;
int64 max_thread = tid>=4 ? 65535 : (1u<<31)-1;
// Extent larger than the grid limit: cap it and keep the for loop.
if (si > max_thread) {
si = max_thread;
mask |= 1<<6;
}
total_size *= si;
tdims[tid] = si;
masks[loop_id] = mask;
loop_id --;
}
// Remaining outer loops run sequentially inside each thread.
while (loop_id>=0) {
masks[loop_id--] = 0;
}
}
// Op hook override: forwards to _compile_optimize, which rewrites `src` in place.
void GetitemOp::compile_optimize(string& src) {
_compile_optimize(src);
}
// Rewrite the generated source `src` for CUDA: the loop nest of the last
// function is moved into a new `static __global__ void func`, loops are bound
// to thread/block indices according to their LO<i> defines, and the original
// function is extended with the launch code. Does nothing when the op is not
// flagged for CUDA. Also called on SetitemOp objects through a cast.
void GetitemOp::_compile_optimize(string& src) {
if (!flags.get(NodeFlags::_cuda))
return;
// NOTE(review): jd / jd_map are not used below.
auto jd = get_jit_define();
map<string,string> jd_map(jd.begin(), jd.end());
KernelIR main(src);
// Last child of the last top-level node; presumably the jit_run function
// inside the namespace -- confirm against the KernelIR layout.
auto& func = main.children.back()->children.back();
// auto& loop = func->children.back();
// Insert an empty kernel function in front of it.
func->push_back("void func() {}", &func->before);
auto& new_func = func->before.back();
// auto new_func = func->before.back()->move_out();
new_func->attrs["dtype"] = "static __global__ void";
// LOGir << main.to_string();
src = main.to_string();
string arg_call = "";
// Index j matches the bit order of the LO masks: tx, ty, tz, bx, by, bz.
const char* tname[] = {"threadIdx.x", "threadIdx.y", "threadIdx.z", "blockIdx.x", "blockIdx.y", "blockIdx.z"};
const char* tname2[] = {"blockDim.x", "blockDim.y", "blockDim.z", "gridDim.x", "gridDim.y", "gridDim.z"};
// Classify every variable definition of the host function:
//  - rvalue starts with input/output/vs. or ends with ')' or ']': stays on
//    the host; unless its dtype is auto it also becomes a kernel argument;
//  - anything else: cloned into the kernel body.
for (auto& ir : func->children) {
if (ir->type == "define") {
string& rvalue = ir->attrs.at("rvalue");
string& lvalue = ir->attrs.at("lvalue");
string& dtype = ir->attrs.at("dtype");
if (startswith(rvalue, "input")
|| startswith(rvalue, "output")
|| startswith(rvalue, "vs.")
|| rvalue.back() == ')'
|| rvalue.back() == ']')
{
if (dtype == "auto")
LOGvvvv << "keep" << rvalue;
else {
LOGvvvv << "args" << rvalue;
if (arg_call.size()) arg_call += ", ";
arg_call += lvalue;
LOGvvvv << dtype+" "+lvalue;
new_func->push_back(dtype+" "+lvalue+";", &new_func->inner);
}
} else {
LOGvvvv << "move" <<rvalue;
new_func->push_back(ir->clone());
}
}
}
// Move the last statement of the host function (the loop nest) into the kernel.
new_func->push_back(func->children.back()->move_out());
auto& loop = new_func->children.back();
int no = o_shape.size();
KernelIR* loops[no];
if (!no) {
// No loops at all: a single thread does the work.
func->push_back("func<<<1,1>>>("+arg_call+");");
} else {
// loops[i] is the i-th loop of the nest, outermost first.
loops[0] = loop.get();
for (int i=1; i<no; i++)
loops[i] = loops[i-1]->children.back().get();
for (int i=0; i<no; i++) {
auto l = loops[i];
// Each loop header is expected to have exactly init / condition / step.
ASSERT(l->inner.size() == 3);
auto lo = l->find_define("LO"+S(i));
ASSERT(lo);
// NOTE(review): std::stoi parses base 10 while jit_prepare emits LO via
// JK::hex -- confirm the rvalue seen here is a decimal literal.
auto loi = std::stoi(lo->attrs.at("rvalue"));
// tid: linear index built from the selected thread/block indices;
// tnum: product of the matching extents.
string tid = "";
string tnum = "";
for (int j=0; j<6; j++) {
if ((loi>>j)&1) {
if (tid.size()) {
tid += string("+")+tnum+"*"+tname[j];
tnum += string("*")+tname2[j];
} else {
tid = tname[j];
tnum = tname2[j];
}
}
}
// Mask without index bits: the loop stays sequential.
if (!tid.size()) {
continue;
}
if (loi&(1<<6)) {
// Keep the for loop: start at tid and advance by tnum.
l->inner.at(0)->attrs.at("rvalue") = tid;
l->inner.at(2)->attrs.at("code") = "i"+S(i)+"+="+tnum+";";
} else {
// no need for
while (l->inner.size())
l->inner.at(0)->erase();
l->push_front("index_t i"+S(i)+" = "+tid+";");
}
}
// Host side: recompute the schedule at run time and launch the kernel.
func->push_back("int no = o_shape.size();");
func->push_back("int masks[no];");
func->push_back("int tdims[6];");
func->push_back("cuda_loop_schedule(o_shape, masks, tdims);");
func->push_back("dim3 grid_dim(tdims[3],tdims[4],tdims[5]);");
func->push_back("dim3 block_dim(tdims[0],tdims[1],tdims[2]);");
func->push_back("func<<<grid_dim, block_dim>>>("+arg_call+");");
}
src = main.to_string();
}
// Infer the output shape and the loop-nest description used by the kernel.
// infer_slices yields the user-visible out_shape; this function additionally
// builds o_shape, the "optimized" shape in which adjacent untouched input
// dims are merged into a single loop, and remaps i_to_o / first_oid_of_var
// from out_shape indices to o_shape indices.
void GetitemOp::infer_shape() {
    auto in = inputs().front();
    auto out = outputs().front();
    auto in_shape = in->shape;
    auto nin = in_shape.size();

    StackVector<> i_to_vs(nin);
    StackVector<> i_to_o(nin);
    // shape return to use
    StackVector<> out_shape;
    infer_slices(i_to_vs, i_to_o, out_shape);

    // optimized shape (each dim is a loop var)
    StackVector<> o_shape;
    int fov = -1;
    for (int i=0; i<nin; i++) {
        auto& vid = i_to_vs[i];
        auto& oid = i_to_o[i];
        if (oid>=0) {
            // Read out_shape only once oid is known to be a valid index
            // (oid is -1 for int / var indexed dims).
            auto os = out_shape[oid];
            if (vid==-1 && i && i_to_vs[i-1]<0) {
                // This dim and the previous one are both taken whole:
                // fold them into one loop; -2 marks the merged dim.
                vid = -2;
                o_shape.back() *= os;
            } else
                o_shape.push_back(os);
            oid = o_shape.size()-1;
        } else {
            auto& s = vs.slices[vid];
            if (s.is_var() && fov == -1) {
                // First index var: emit the broadcast index-var dims here.
                fov = o_shape.size();
                for (int j=0; j<var_dim; j++)
                    o_shape.push_back(out_shape[first_oid_of_var+j]);
            }
        }
    }
    first_oid_of_var = fov;

    // A fully indexed result is still reported with shape [1].
    if (!out_shape.size()) out_shape.push_back(1);
    out->set_shape(out_shape.to_nano_vector());

    this->i_to_vs = i_to_vs.to_nano_vector();
    this->i_to_o = i_to_o.to_nano_vector();
    this->o_shape = o_shape.to_nano_vector();
    LOGvvvv << "\ni_to_vs:" << i_to_vs
        << "\ni_to_o:" << i_to_o
        << "\no_shape:" << o_shape;
}
// Gradient of getitem: only input 0 (the indexed tensor) gets a gradient,
// built by writing `dout` into a zero tensor shaped like `v` through the
// same slices.
VarPtr GetitemOp::grad(Var* out, Var* dout, Var* v, int v_index) {
    if (v_index == 0) {
        auto zero_base = make_number(0, v);
        // TODO: maybe add here?
        // need analysis the overlap attr os var slices
        return make_setitem(zero_base, VarSlices(vs), dout, ns_void);
    }
    return nullptr;
}
// Emit the JIT defines that specialize the kernel for this indexing pattern:
//   Ti / IDIM / ODIM : input dtype, number of input dims, number of loops
//   FOV / VD         : first loop of the index-var dims and their count
//   IV<i> / IO<i>    : per input dim, slice index and loop index (may be <0)
//   VS<i>            : -1 for an int index; broadcast bitmask for an index
//                      var; the step when |step| <= 1, else 0 (step is then
//                      read at run time), for a slice
//   VST<i>           : dtype of the index var
//   LO<i>            : CUDA loop schedule masks (CUDA builds only)
void GetitemOp::jit_prepare() {
    auto in = inputs().front();
    int idim = i_to_vs.size();
    add_jit_define("Ti", in->dtype());
    add_jit_define("IDIM", JK::hex1(i_to_vs.size()));
    add_jit_define("ODIM", JK::hex1(o_shape.size()));
    if (first_oid_of_var>=0) {
        add_jit_define("FOV", JK::hex1(first_oid_of_var));
        add_jit_define("VD", JK::hex1(var_dim));
    }
    for (int i=0; i<idim; i++) {
        auto iv = i_to_vs[i];
        auto io = i_to_o[i];
        add_jit_define("IV", JK::hex1(i), JK::shex1(iv));
        add_jit_define("IO", JK::hex1(i), JK::shex1(io));
        // iv is -1 / -2 for dims without a slice: nothing more to describe,
        // and vs.slices must not be indexed with a negative value.
        if (iv < 0)
            continue;
        auto& v = vs.slices[iv];
        if (io==-1) {
            if (v.is_int()) {
                add_jit_define("VS", JK::hex1(i), "-1");
            } else {
                ASSERT(v.is_var());
                auto var = v.var;
                auto vshape = var->shape;
                auto vdim = vshape.size();
                // Bit (j+var_dim-vdim) is set when the var's dim j spans the
                // full loop extent, i.e. is not broadcast.
                int vsmask = 0;
                for (int j=0; j<vdim; j++) {
                    int k = first_oid_of_var+j+var_dim-vdim;
                    if (vshape[j] == o_shape[k])
                        vsmask |= 1<<(j+var_dim-vdim);
                }
                add_jit_define("VS", JK::hex1(i), JK::hex(vsmask));
                add_jit_define("VST", JK::hex1(i), var->dtype());
            }
        } else
        if (io>=0) {
            ASSERT(v.is_slice());
            if (std::abs(v.slice.step) <= 1)
                add_jit_define("VS", JK::hex1(i), JK::shex1(v.slice.step));
            else
                add_jit_define("VS", JK::hex1(i), "0");
        }
    }
    #ifdef HAS_CUDA
    if (use_cuda) {
        int no = o_shape.size();
        int masks[no];
        int tdims[6];
        cuda_loop_schedule(o_shape, masks, tdims);
        for (int i=0; i<no; i++) {
            add_jit_define("LO", JK::hex1(i), JK::hex(masks[i]));
        }
    }
    #endif
}
#else // JIT
#pragma GCC diagnostic ignored "-Wunused-variable"
// JIT kernel template for getitem.
// Returns immediately when the output has no elements. Otherwise it iterates
// the ODIM-deep loop nest over o_shape and copies ip[iid] to op[oid], where
// oid is the row-major offset of the loop indices and iid is the row-major
// input offset. The per-dimension input index is chosen at compile time from
// the IV / IO / VS defines:
//   IV == -1           : the loop index itself
//   IV == -2           : 0 (the dim was merged into the previous loop)
//   IO != -1           : loop index * slice step + slice start
//   VS == -1           : the fixed integer index (slice.start)
//   VS >= 0            : value looked up in the index var, addressed with the
//                        broadcast strides selected by the VS bitmask
// Comments are kept outside the body because _compile_optimize re-parses it
// and relies on the shape of its statements.
void GetitemOp::jit_run() {
auto in = inputs().front();
auto out = outputs().front();
if (out->num == 0) return;
@for(i, 0, ODIM, index_t oshape@i = o_shape[@i];)
@if(ODIM>0,
index_t ostride@{ODIM-1} = 1;
@for(i, ODIM-2, -1, -1, index_t ostride@i = ostride@{i+1} * oshape@{i+1};)
)
Ti* op = out->ptr<Ti>();
Ti* ip = in->ptr<Ti>();
@for(i, 0, IDIM, index_t ishape@i =
@if(IV@i==-1,oshape@{IO@i},
@if(IV@i==-2,1,in->shape[@i]));
)
index_t istride@{IDIM-1} = 1;
@for(i, IDIM-2, -1, -1, index_t istride@i = istride@{i+1} * ishape@{i+1};)
@for(i, 0, IDIM,
@if(IV@i>=0 && IO@i>=0,
index_t vstart@i = vs.slices[@{IV@i}].slice.start;
index_t vstep@i = @if(VS@i==0,vs.slices[@{IV@i}].slice.step;,@{VS@i});
)
)
@for(i, 0, IDIM,
@if(IV@i>=0 && IO@i<0,
@if(VS@i==-1,index_t vi@i = vs.slices[@{IV@i}].slice.start;);
)
)
@for(i, 0, IDIM,
@if(IV@i>=0 && IO@i<0,
@if(VS@i>=0,
index_t vs@i@@s@{VD-1} = 1;
VST@i* vp@i = vs.slices[IV@i].var->ptr<VST@i>();
@for(j,VD-2,-1,-1,index_t vs@i@@s@j = vs@i@@s@{j+1} *
@if((VS@i>>(j+1))&1,oshape@{j+1+FOV},1);
)
);
)
)
@for(d, 0, ODIM, for (index_t i@d=0; i@d < oshape@d; i@d++)) {
index_t oid = 0 @for(d, 0, ODIM, + i@d * ostride@d);
@for(d, 0, IDIM, index_t iid@d =
@if(IV@d==-1, i@{IO@d},
@if(IV@d==-2, 0,
@if(IO@d!=-1, (i@{IO@d}*vstep@d+vstart@d),
@if(VS@d==-1, vi@d,
@if(VS@d>=0,
index_t(vp@d[0 @for(j,0,VD,@if((VS@d>>j)&1, + i@{j+FOV} * vs@d@@s@j,))])
, ??? )))));
)
auto iid = 0 @for(d, 0, IDIM, + iid@d * istride@d);
op[oid] = ip[iid];
}
}
#endif // JIT
} // jittor

42
src/ops/getitem_op.h Normal file
View File

@ -0,0 +1,42 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. All Rights Reserved.
// Authors: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "op.h"
#include "var_slices.h"
#include "misc/stack_vector.h"
namespace jittor {
// Op implementing indexing/slicing: out = x[slices].
// NOTE(review): setitem_op.cc casts a SetitemOp* to GetitemOp* to call
// infer_slices and _compile_optimize, so the members those functions touch
// (vs, o_shape, first_oid_of_var, var_dim) presumably must keep the same
// layout in both structs -- confirm before reordering fields.
struct GetitemOp : Op {
// The slice list applied to the input.
VarSlices vs;
// map i to related var slice
NanoVector i_to_vs;
// map i to related o
NanoVector i_to_o;
// Shape of the kernel loop nest (adjacent untouched dims merged).
NanoVector o_shape;
// first_oid_of_var: position of the index-var dims (-1 when there are none);
// var_dim: number of dims the index vars broadcast to.
int first_oid_of_var, var_dim;
GetitemOp(Var* x, VarSlices&& slices);
const char* name() const override { return "getitem"; }
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
void infer_shape() override;
void compile_optimize(string& src) override;
void graph_optimize() override;
DECLARE_jit_run;
// Compute the per-input-dim slice / output mappings and the result shape.
void infer_slices(
StackVector<>& __restrict__ i_to_vs,
StackVector<>& __restrict__ i_to_o,
StackVector<>& __restrict__ out_shape
);
// CUDA source rewrite shared with SetitemOp.
void _compile_optimize(string& src);
};
};
void cuda_loop_schedule(NanoVector o_shape, int* masks, int* tdims);
} // jittor

View File

@ -16,9 +16,22 @@ static auto make_broadcast_to = get_op_info("broadcast_to")
.get_constructor<VarPtr, Var*, Var*, NanoVector>();
VarPtr make_number(float number, Var* x) {
VarPtr nums = make_array(&number, 1, ns_float32);
nums = make_broadcast_to(nums, x, {});
return make_unary(nums, x->dtype());
union Number {
float32 f32;
float64 f64;
int32 i32;
int64 i64;
} v;
if (x->dtype() == ns_float32) v.f32 = number; else
if (x->dtype() == ns_float64) v.f64 = number; else
if (x->dtype() == ns_int32) v.i32 = number; else
if (x->dtype() == ns_int64) v.i64 = number; else {
VarPtr nums = make_array(&number, 1, ns_float32);
nums = make_broadcast_to(nums, x, {});
return make_unary(nums, x->dtype());
}
VarPtr nums = make_array(&v, 1, x->dtype());
return make_broadcast_to(nums, x, {});
}
static void init() {

327
src/ops/setitem_op.cc Normal file
View File

@ -0,0 +1,327 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. All Rights Reserved.
// Authors: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include <cmath>
#include "var.h"
#include "ops/setitem_op.h"
#include "ops/getitem_op.h"
#ifdef JIT
#include "ops/binary_op_defs.h"
#ifdef JIT_cuda
#include <cuda_runtime.h>
#include <helper_cuda.h>
#endif
#else
#include "ops/op_register.h"
#ifdef HAS_CUDA
#include "misc/cuda_flags.h"
#endif
#endif
namespace jittor {
#ifndef JIT
static auto make_array = get_op_info("array")
.get_constructor<VarPtr, const void*, NanoVector, NanoString>();
static auto make_number = get_op_info("number")
.get_constructor<VarPtr, float, Var*>();
static auto make_getitem = get_op_info("getitem")
.get_constructor<VarPtr, Var*, VarSlices&&>();
static auto make_setitem = get_op_info("setitem")
.get_constructor<VarPtr, Var*, VarSlices&&, Var*, NanoString>();
static auto make_binary = get_op_info("binary")
.get_constructor<VarPtr, Var*, Var*, NanoString>();
static auto make_unary = get_op_info("unary")
.get_constructor<VarPtr, Var*, NanoString>();
// Build a setitem op: out = x with x[slices] updated from y.
// `op` selects how y is combined with the existing values: ns_void means
// plain assignment, otherwise it must be a binary operator.
SetitemOp::SetitemOp(Var* x, VarSlices&& slices, Var* y, NanoString op)
    : vs(move(slices)), op(op) {
    flags.set(NodeFlags::_cpu);
    flags.set(NodeFlags::_cuda);
    flags.set(NodeFlags::_has_gopt);
    // Validate the operator that was actually passed in; the previous check
    // tested `ns`, which is not this constructor's argument.
    ASSERT(op == ns_void || op.is_binary());
    create_output(nullptr, x->dtype());
}
// Infer shapes for setitem. The output has the input's shape; the sliced
// region's shape (out_shape) is computed with GetitemOp::infer_slices and
// used to validate / broadcast `data`. Also builds o_shape (the kernel loop
// nest, with adjacent untouched dims merged when they share the same
// broadcast state) and bmask (bit d set when data is NOT broadcast along
// loop d).
void SetitemOp::infer_shape() {
    auto in = inputs().front();
    auto data = input(1);
    auto out = outputs().front();
    auto in_shape = in->shape;
    auto nin = in_shape.size();

    StackVector<> i_to_vs(nin);
    StackVector<> i_to_o(nin);
    // shape return to use
    StackVector<> out_shape;
    ((GetitemOp*)this)->infer_slices(i_to_vs, i_to_o, out_shape);
    if (!out_shape.size()) out_shape.push_back(1);

    // get broadcast mask of set value
    auto data_shape = data->shape;
    auto data_dim = data_shape.size();
    int bmask = 0;
    int bmask2 = 0;

    ASSERTop(data_dim,<=,out_shape.size()) << "Data dimension not match";
    // data dims are right-aligned against out_shape; a dim of size 1 is
    // broadcast, any other dim must match exactly.
    for (int i=0; i<data_dim; i++) {
        int j = i - data_dim + out_shape.size();
        if (!(data_shape[i]==1 && out_shape[j]!=-1)) {
            ASSERTop(data_shape[i],==,out_shape[j]) << "Data shape not match" << data_shape << out_shape;
            bmask |= 1<<j;
        }
    }

    // optimized shape (each dim is a loop var)
    StackVector<> o_shape;
    int fov = -1;
    for (int i=0; i<nin; i++) {
        auto& vid = i_to_vs[i];
        auto& oid = i_to_o[i];
        if (oid>=0) {
            // Read out_shape only once oid is known to be a valid index
            // (oid is -1 for int / var indexed dims).
            auto os = out_shape[oid];
            if (vid==-1 && i && i_to_vs[i-1]<0
                && ((bmask>>oid)&1) == ((bmask>>(oid-1))&1))
                // same broadcast condition with prev dim
            {
                vid = -2;
                o_shape.back() *= os;
            } else {
                o_shape.push_back(os);
                // fix bmask2 offset
                bmask2 |= ((bmask>>oid)&1) << (o_shape.size()-1);
            }
            oid = o_shape.size()-1;
        } else {
            auto& s = vs.slices[vid];
            if (s.is_var() && fov == -1) {
                fov = o_shape.size();
                for (int j=0; j<var_dim; j++) {
                    o_shape.push_back(out_shape[first_oid_of_var+j]);
                    // fix bmask2 offset
                    bmask2 |= ((bmask>>(first_oid_of_var+j))&1) << (o_shape.size()-1);
                }
            }
        }
    }
    first_oid_of_var = fov;
    this->bmask = bmask2;

    out->set_shape(in_shape);

    this->i_to_vs = i_to_vs.to_nano_vector();
    this->i_to_o = i_to_o.to_nano_vector();
    this->o_shape = o_shape.to_nano_vector();

    LOGvvvv << "\ni_to_vs:" << i_to_vs
        << "\ni_to_o:" << i_to_o
        << "\no_shape:" << o_shape;
}
// Gradient of setitem with respect to input 0 (the base tensor x) or
// input 1 (the written value y); any other input gets no gradient.
// Supported operators: void (assignment), add, subtract, multiply, divide.
VarPtr SetitemOp::grad(Var* out, Var* dout, Var* v, int v_index) {
    if (v_index >= 2)
        return nullptr;

    const bool wrt_base = (v_index == 0);
    // Helper: the sliced view t[vs] of a tensor.
    auto sliced = [&](Var* t) -> VarPtr {
        return make_getitem(t, VarSlices(vs));
    };

    if (op == ns_void) {
        if (!wrt_base)
            return sliced(dout);
        // Overwritten entries of x do not influence the output: zero them.
        float32 number = 0;
        VarPtr zero = make_array(&number, 1, ns_float32);
        return make_setitem(dout, VarSlices(vs), zero, ns_void);
    }
    if (op == ns_add) {
        if (wrt_base)
            return dout;
        return sliced(dout);
    }
    if (op == ns_subtract) {
        if (wrt_base)
            return dout;
        return make_unary(sliced(dout), ns_negative);
    }
    if (op == ns_multiply) {
        if (wrt_base)
            return make_setitem(dout, VarSlices(vs), input(1), ns_multiply);
        return make_binary(
            sliced(inputs().front()),
            sliced(dout), ns_multiply);
    }
    if (op == ns_divide) {
        if (wrt_base)
            return make_setitem(dout, VarSlices(vs), input(1), ns_divide);
        // dy = -dz*x / y^2
        auto dout2 = sliced(dout);
        auto x = sliced(inputs().front());
        auto y = v;
        auto ndz = make_unary(dout2, ns_negative);
        auto ndzx = make_binary(ndz, x, ns_multiply);
        auto y2 = make_binary(y, y, ns_multiply);
        return make_binary(ndzx, y2, ns_divide);
    }
    LOGf << "Setitem grad of op" << op << "is not supported yet";
    return nullptr;
}
// Emit the JIT defines that specialize the setitem kernel:
//   OP / Td / BMASK  : combine operator, dtype of data, data broadcast mask
//   Ti / IDIM / ODIM : input dtype, number of input dims, number of loops
//   FOV / VD         : first loop of the index-var dims and their count
//   IV<i> / IO<i>    : per input dim, slice index and loop index (may be <0)
//   VS<i>            : -1 for an int index; broadcast bitmask for an index
//                      var; the step when |step| <= 1, else 0 (step is then
//                      read at run time), for a slice
//   VST<i>           : dtype of the index var
//   LO<i>            : CUDA loop schedule masks (CUDA builds only)
void SetitemOp::jit_prepare() {
    auto data = input(1);
    add_jit_define("OP", op);
    add_jit_define("Td", data->dtype());
    add_jit_define("BMASK", JK::hex(bmask));
    // TODO: merge code
    auto in = inputs().front();
    int idim = i_to_vs.size();
    add_jit_define("Ti", in->dtype());
    add_jit_define("IDIM", JK::hex1(i_to_vs.size()));
    add_jit_define("ODIM", JK::hex1(o_shape.size()));
    if (first_oid_of_var>=0) {
        add_jit_define("FOV", JK::hex1(first_oid_of_var));
        add_jit_define("VD", JK::hex1(var_dim));
    }
    for (int i=0; i<idim; i++) {
        auto iv = i_to_vs[i];
        auto io = i_to_o[i];
        add_jit_define("IV", JK::hex1(i), JK::shex1(iv));
        add_jit_define("IO", JK::hex1(i), JK::shex1(io));
        // iv is -1 / -2 for dims without a slice: nothing more to describe,
        // and vs.slices must not be indexed with a negative value.
        if (iv < 0)
            continue;
        auto& v = vs.slices[iv];
        if (io==-1) {
            if (v.is_int()) {
                add_jit_define("VS", JK::hex1(i), "-1");
            } else {
                ASSERT(v.is_var());
                auto var = v.var;
                auto vshape = var->shape;
                auto vdim = vshape.size();
                // Bit (j+var_dim-vdim) is set when the var's dim j spans the
                // full loop extent, i.e. is not broadcast.
                int vsmask = 0;
                for (int j=0; j<vdim; j++) {
                    int k = first_oid_of_var+j+var_dim-vdim;
                    if (vshape[j] == o_shape[k])
                        vsmask |= 1<<(j+var_dim-vdim);
                }
                add_jit_define("VS", JK::hex1(i), JK::hex(vsmask));
                add_jit_define("VST", JK::hex1(i), var->dtype());
            }
        } else
        if (io>=0) {
            ASSERT(v.is_slice());
            if (std::abs(v.slice.step) <= 1)
                add_jit_define("VS", JK::hex1(i), JK::shex1(v.slice.step));
            else
                add_jit_define("VS", JK::hex1(i), "0");
        }
    }
    #ifdef HAS_CUDA
    if (use_cuda) {
        int no = o_shape.size();
        int masks[no];
        int tdims[6];
        cuda_loop_schedule(o_shape, masks, tdims);
        for (int i=0; i<no; i++) {
            add_jit_define("LO", JK::hex1(i), JK::hex(masks[i]));
        }
    }
    #endif
}
// Op hook override: reuses GetitemOp's CUDA source rewrite by casting this
// object to GetitemOp*.
// NOTE(review): SetitemOp does not derive from GetitemOp; this presumably
// relies on both structs declaring the shared members in the same order.
void SetitemOp::compile_optimize(string& src) {
((GetitemOp*)this)->_compile_optimize(src);
}
#else // JIT
#pragma GCC diagnostic ignored "-Wunused-variable"
// JIT kernel template for setitem.
// Returns immediately when the output has no elements. Otherwise the input
// buffer is first copied into the output buffer when the two pointers differ
// (memcpy on CPU, cudaMemcpyAsync otherwise). Then the ODIM-deep loop nest
// over o_shape computes, for each point:
//   did : offset into data, using only the loops whose BMASK bit is set
//         (the others are broadcast),
//   iid : row-major offset into the output, built per dimension from the
//         IV / IO / VS defines in the same way as in the getitem kernel,
// and stores either (Ti)dp[did] (OP is void) or OP(op[iid], dp[did]).
// Comments are kept outside the body because the CUDA source rewrite
// re-parses it and relies on the shape of its statements.
void SetitemOp::jit_run() {
auto in = inputs().front();
auto data = input(1);
auto out = outputs().front();
if (out->num == 0) return;
@for(i, 0, ODIM, index_t oshape@i = o_shape[@i];)
@if(ODIM>0,
index_t ostride@{ODIM-1} = 1;
@for(i, ODIM-2, -1, -1, index_t ostride@i = ostride@{i+1} * oshape@{i+1};)
)
Ti* op = out->ptr<Ti>();
Ti* ip = in->ptr<Ti>();
@for(i, 0, IDIM, index_t ishape@i =
@if(IV@i==-1,oshape@{IO@i},
@if(IV@i==-2,1,in->shape[@i]));
)
index_t istride@{IDIM-1} = 1;
@for(i, IDIM-2, -1, -1, index_t istride@i = istride@{i+1} * ishape@{i+1};)
@for(i, 0, IDIM,
@if(IV@i>=0 && IO@i>=0,
index_t vstart@i = vs.slices[@{IV@i}].slice.start;
index_t vstep@i = @if(VS@i==0,vs.slices[@{IV@i}].slice.step;,@{VS@i});
)
)
@for(i, 0, IDIM,
@if(IV@i>=0 && IO@i<0,
@if(VS@i==-1,index_t vi@i = vs.slices[@{IV@i}].slice.start;);
)
)
@for(i, 0, IDIM,
@if(IV@i>=0 && IO@i<0,
@if(VS@i>=0,
index_t vs@i@@s@{VD-1} = 1;
VST@i* vp@i = vs.slices[IV@i].var->ptr<VST@i>();
@for(j,VD-2,-1,-1,index_t vs@i@@s@j = vs@i@@s@{j+1} *
@if((VS@i>>(j+1))&1,oshape@{j+1+FOV},1);
)
);
)
)
Td* dp = data->ptr<Td>();
@if(ODIM>0,
index_t dstride@{ODIM-1} = 1;
@for(i, ODIM-2, -1, -1, index_t dstride@i = dstride@{i+1} * @if((BMASK>>(i+1))&1,oshape@{i+1},1);)
)
#ifdef JIT_cpu
if (op != ip)
std::memcpy(op, ip, out->size);
#else
if (op != ip)
checkCudaErrors(cudaMemcpyAsync(op, ip, out->size, cudaMemcpyDefault, 0));
#endif
@for(d, 0, ODIM, for (index_t i@d=0; i@d < oshape@d; i@d++)) {
index_t did = 0 @for(d, 0, ODIM, @if((BMASK>>d)&1,+ i@d * dstride@d));
@for(d, 0, IDIM, index_t iid@d =
@if(IV@d==-1, i@{IO@d},
@if(IV@d==-2, 0,
@if(IO@d!=-1, (i@{IO@d}*vstep@d+vstart@d),
@if(VS@d==-1, vi@d,
@if(VS@d>=0,
index_t(vp@d[0 @for(j,0,VD,@if((VS@d>>j)&1, + i@{j+FOV} * vs@d@@s@j,))])
, ??? )))));
)
auto iid = 0 @for(d, 0, IDIM, + iid@d * istride@d);
@if(@strcmp(@OP,void)==0,
op[iid] = (Ti)dp[did],
op[iid] = @expand_macro(@OP, Ti, op[iid], dp[did])
);
}
}
#endif // JIT
} // jittor

34
src/ops/setitem_op.h Normal file
View File

@ -0,0 +1,34 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. All Rights Reserved.
// Authors: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "op.h"
#include "var_slices.h"
namespace jittor {
// Op implementing sliced assignment: out = x with x[slices] (op)= y.
// NOTE(review): setitem_op.cc casts this object to GetitemOp* to reuse
// infer_slices and _compile_optimize, so the leading members (vs, i_to_vs,
// i_to_o, o_shape, first_oid_of_var, var_dim) presumably must mirror
// GetitemOp's layout -- confirm before reordering or inserting fields.
struct SetitemOp : Op {
// The slice list selecting the region to write.
VarSlices vs;
// map i to related var slice
NanoVector i_to_vs;
// map i to related o
NanoVector i_to_o;
// Shape of the kernel loop nest.
NanoVector o_shape;
// first_oid_of_var: position of the index-var dims (-1 when there are none);
// var_dim: number of dims the index vars broadcast to.
int first_oid_of_var, var_dim;
// Bit d set: data is not broadcast along loop d (computed in infer_shape).
int bmask;
// Combine operator; ns_void means plain assignment.
NanoString op;
SetitemOp(Var* x, VarSlices&& slices, Var* y, NanoString op=ns_void);
const char* name() const override { return "setitem"; }
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
void infer_shape() override;
void compile_optimize(string& src) override;
void graph_optimize() override;
DECLARE_jit_run;
};
};
} // jittor

View File

@ -49,8 +49,11 @@ Tapes::Tapes(
const vector<VarHolder*>& taped_outputs,
GradCallback&& grad_callback
) {
callback = move(grad_callback);
flags.set(NodeFlags::_cpu);
flags.set(NodeFlags::_cuda);
flags.set(NodeFlags::_grads);
callback = move(grad_callback);
/*
stop grad stop grad
@ -89,4 +92,4 @@ void tape_together(
new Tapes(taped_inputs, taped_outputs, move(grad_callback));
}
} // jittor
} // jittor

View File

@ -5,12 +5,17 @@
// ***************************************************************
#include "var.h"
#include "ops/where_op.h"
#ifdef JIT_cuda
#include "executor.h"
#include <assert.h>
#endif
namespace jittor {
#ifndef JIT
WhereOp::WhereOp(Var* cond, NanoString dtype) : cond(cond) {
flags.set(NodeFlags::_cpu);
flags.set(NodeFlags::_cuda);
flags.set(NodeFlags::_vary_shape);
auto ndim = cond->shape.size();
outs.reset(new Var*[ndim]);
@ -33,6 +38,186 @@ void WhereOp::jit_prepare() {
}
#else // JIT
#ifdef JIT_cuda
// Kernel that scans cond and, for every element that tests true, appends the
// element's NDIM coordinates to the per-dimension output arrays. The number of
// hits is stored in *np.
// Output slots are handed out by a shared-memory counter bumped with atomicInc,
// so the order of the emitted coordinates follows the order in which threads
// reach the atomic, not the element order.
// NOTE(review): only threadIdx/blockDim are used and the counter is per-block
// shared memory, so this looks written for a single-block launch; its only
// visible call site in this file is commented out.
__global__ static void where_kernel(
@for(i, 0, NDIM, 1, index_t condshape@i, )
Ti* __restrict__ condp,
@for(i, 0, NDIM, 1, To* __restrict__ outs@i@@p, )
int* __restrict__ np
) {
// number of hits found so far, shared by all threads of the block
__shared__ uint n;
int tid = threadIdx.x;
int tnum = blockDim.x;
if (tid == 0)
n = 0;
// define cond stride
index_t condstride@{NDIM-1} = 1;
@for(i, NDIM-2, -1, -1, auto condstride@i = condstride@{i+1} * condshape@{i+1};)
// make the zeroed counter visible to every thread before it is used
__syncthreads();
// generate d-for loop
// outer dimensions are walked by every thread; the last dimension is split
// across the threads of the block with a step of tnum
@for(d, 0, NDIM-1, for (index_t i@d=0; i@d < condshape@d; i@d++))
for (index_t i@{NDIM-1}=tid; i@{NDIM-1}<condshape@{NDIM-1}; i@{NDIM-1}+=tnum)
{
auto condid = @for(d, 0, NDIM, + i@d * condstride@d);
if (condp[condid]) {
// atomicInc returns the previous counter value, which becomes this hit's slot
uint cn = atomicInc(&n, 1u<<30);
@for(i, 0, NDIM, outs@i@@p[cn] = i@i;)
}
}
// wait until all threads have finished counting before publishing the total
__syncthreads();
if (tid == 0)
(*np) = n;
}
// Inclusive prefix sum of val across the 32 lanes of a warp, built from
// shuffle-up exchanges with strides 16, 8, 4, 2, 1. lane_id is the caller's
// lane index inside its warp; lanes below the current stride keep their value.
__device__ inline uint prefix_sum(uint val, uint lane_id) {
    // The mask macro is defined here and reused by code that follows this function.
    #define FULL_MASK 0xffffffff
    #pragma unroll
    for (int stride = 16; stride >= 1; stride >>= 1) {
        uint from_below = __shfl_up_sync(FULL_MASK, val, stride);
        if (lane_id >= stride)
            val += from_below;
    }
    return val;
}
// Warp shuffle read: returns the val held by the lane whose index is lane_id,
// using the full-warp mask. Relies on the FULL_MASK macro that is defined
// inside prefix_sum earlier in this file.
__device__ inline uint bc(uint val, uint lane_id) {
return __shfl_sync(FULL_MASK, val, lane_id);
}
// Single-warp variant of the where kernel; the stride of 32 below assumes a
// launch with exactly one warp of 32 threads.
// Each pass of the innermost loop examines 32 consecutive positions of the last
// dimension, one per lane. A warp prefix sum over the 0/1 hit flags gives every
// hit an output slot that follows element order, so the result order is stable.
// Outputs: the per-dimension coordinate arrays, and the hit count in *np.
__global__ static void where_kernel_one_warp(
    @for(i, 0, NDIM, 1, index_t condshape@i, )
    Ti* __restrict__ condp,
    @for(i, 0, NDIM, 1, To* __restrict__ outs@i@@p, )
    int* __restrict__ np
) {
    // hits found in all previous passes; every lane keeps the same value
    uint n = 0;
    int tid = threadIdx.x;
    int tnum = 32;
    // define cond stride
    index_t condstride@{NDIM-1} = 1;
    @for(i, NDIM-2, -1, -1, auto condstride@i = condstride@{i+1} * condshape@{i+1};)
    // generate d-for loop
    @for(d, 0, NDIM-1, for (index_t i@d=0; i@d < condshape@d; i@d++))
    for (index_t i=0; i<condshape@{NDIM-1}; i+=tnum)
    {
        index_t i@{NDIM-1} = i + tid;
        auto condid = @for(d, 0, NDIM, + i@d * condstride@d);
        // Normalise the condition to a 0/1 flag. Storing the raw value would
        // truncate fractional floats to 0 and would let values larger than 1
        // advance the prefix sum by more than one slot. Lanes past the end of
        // the row never read memory and contribute 0.
        uint x = (i@{NDIM-1}<condshape@{NDIM-1} && condp[condid]) ? 1u : 0u;
        // inclusive scan: prefix_x is the number of hits in lanes 0..tid
        uint prefix_x = prefix_sum(x, tid);
        if (x) {
            uint cn = n + prefix_x - 1;
            @for(i, 0, NDIM, outs@i@@p[cn] = i@i;)
        }
        // lane 31 holds the total for this pass; broadcast it to all lanes
        n += bc(prefix_x, 31);
    }
    if (tid == 0)
        (*np) = n;
}
// Thread count the one-block kernel is written for; it must be launched with
// exactly this many threads because the warp bookkeeping below is sized by it.
#define WTN 1024
// Single-block variant of the where kernel.
// Each pass of the innermost loop examines WTN consecutive positions of the
// last dimension. Output slots are computed in two levels: a prefix sum inside
// every warp, then a prefix sum over the per-warp totals done by warp 0 through
// shared memory. Slots therefore follow element order and the result is stable.
// Outputs: the per-dimension coordinate arrays, and the hit count in *np.
__global__ static void where_kernel_one_block(
    @for(i, 0, NDIM, 1, index_t condshape@i, )
    Ti* __restrict__ condp,
    @for(i, 0, NDIM, 1, To* __restrict__ outs@i@@p, )
    int* __restrict__ np
) {
    // hits found in all previous passes; every thread keeps the same value
    uint n = 0;
    int tid = threadIdx.x;
    int tnum = WTN;
    // one slot per warp: first the warp totals, then their inclusive prefix sums
    __shared__ uint s[WTN/32];
    int wid = tid / 32;
    int lid = tid % 32;
    int wnum = WTN / 32;
    // define cond stride
    index_t condstride@{NDIM-1} = 1;
    @for(i, NDIM-2, -1, -1, auto condstride@i = condstride@{i+1} * condshape@{i+1};)
    // generate d-for loop
    @for(d, 0, NDIM-1, for (index_t i@d=0; i@d < condshape@d; i@d++))
    for (index_t i=0; i<condshape@{NDIM-1}; i+=tnum)
    {
        index_t i@{NDIM-1} = i + tid;
        auto condid = @for(d, 0, NDIM, + i@d * condstride@d);
        // Normalise the condition to a 0/1 flag. Storing the raw value would
        // truncate fractional floats to 0 and would let values larger than 1
        // advance the prefix sum by more than one slot. Threads past the end of
        // the row never read memory and contribute 0.
        uint x = (i@{NDIM-1}<condshape@{NDIM-1} && condp[condid]) ? 1u : 0u;
        uint prefix_x = prefix_sum(x, lid);
        uint warp_sum = bc(prefix_x, 31);
        // prefix sum between warps
        if (lid == 0) {
            s[wid] = warp_sum;
        }
        __syncthreads();
        if (wid == 0) {
            s[lid] = prefix_sum(s[lid], lid);
        }
        __syncthreads();
        uint warp_prefix_sum = s[wid];
        uint block_sum = s[wnum-1];
        // everyone must have read s before the next pass overwrites it
        __syncthreads();
        if (x) {
            // hits before this pass + hits in earlier warps + hits in earlier lanes
            uint cn = n + prefix_x - 1 + warp_prefix_sum - warp_sum;
            @for(i, 0, NDIM, outs@i@@p[cn] = i@i;)
        }
        n += block_sum;
    }
    if (tid == 0)
        (*np) = n;
}
// CUDA implementation of where: launches one of the stable single-warp or
// single-block kernels, reads back the number of hits, and sets the shape of
// every output to that count.
// NOTE(review): the kernel launches are not followed by an explicit error
// check, and np is not released if the checked memcpy fails - confirm whether
// that matters for the executor allocator.
void WhereOp::jit_run() {
auto* __restrict__ condp = cond->ptr<Ti>();
// define cond shape
@for(i, 0, NDIM, index_t condshape@i = cond->shape[@i];)
// define outs
@for(i, 0, NDIM, auto* __restrict__ outs@i@@p = outs[@i]->ptr<To>();)
size_t n_allocation;
// 4-byte scratch buffer from the executor allocator; the kernel stores the hit count in it
int* np = (int*)exe.allocator->alloc(4, n_allocation);
// one block kernel, result maybe unstable
// int tnum = condshape@{NDIM-1};
// tnum = std::max(1, std::min(1024, tnum));
// where_kernel<<<1,tnum>>>(
// @for(i, 0, NDIM, 1, condshape@i, )
// condp,
// @for(i, 0, NDIM, 1, outs@i@@p, )
// np
// );
// pick the kernel by the extent of the last dimension: short rows use one warp
int tnum = condshape@{NDIM-1};
if (tnum < 100) {
// one warp kernel, result is stable
where_kernel_one_warp<<<1,32>>>(
@for(i, 0, NDIM, 1, condshape@i, )
condp,
@for(i, 0, NDIM, 1, outs@i@@p, )
np
);
} else {
// one block kernel, result is stable
where_kernel_one_block<<<1,WTN>>>(
@for(i, 0, NDIM, 1, condshape@i, )
condp,
@for(i, 0, NDIM, 1, outs@i@@p, )
np
);
}
int n=0;
// checkCudaErrors(cudaDeviceSynchronize());
// copy the hit count back to the host
checkCudaErrors(cudaMemcpy(&n, np, 4, cudaMemcpyDefault));
// every output holds one coordinate per hit
@for(i, 0, NDIM, outs[@i]->set_shape({n});)
exe.allocator->free(np, 4, n_allocation);
}
#else
void WhereOp::jit_run() {
auto* __restrict__ condp = cond->ptr<Ti>();
// define cond shape
@ -55,6 +240,8 @@ void WhereOp::jit_run() {
}
@for(i, 0, NDIM, outs[@i]->set_shape({n});)
}
#endif
#endif // JIT
} // jittor
} // jittor

View File

@ -0,0 +1,119 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. All Rights Reserved.
// Authors: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include <cmath>
#include "var.h"
#include "ops/setitem_op.h"
#include "ops/getitem_op.h"
namespace jittor {
// Prefix test: returns true when every character of `b` matches the start of
// `a`. An empty `b` always matches; despite the name this is not an equality
// comparison, since `a` may continue past the end of `b`.
inline static bool fast_strcmp(const char* a, const char* b) {
    for (; *b != '\0'; ++a, ++b) {
        if (*a != *b)
            return false;
    }
    return true;
}
// Lets the output of a setitem op share with its first input when that looks
// safe: the input has exactly one consumer, its forward_liveness is at most 1,
// the setitem operator is void/add/subtract, and the input's producer (if any)
// is one of the whitelisted kinds below.
static void setitem_inplace(SetitemOp* op) {
    auto target = op->inputs().front();
    if (target->outputs().size() != 1)
        return;
    if (target->forward_liveness > 1)
        return;
    const bool op_allowed =
        op->op == ns_void || op->op == ns_add || op->op == ns_subtract;
    if (!op_allowed)
        return;

    if (auto producer = target->input()) {
        // make sure the producing op will not use the input again
        auto producer_name = producer->name();
        // TODO: inplace getitem maybe risky, getitem maybe inplace too
        const bool producer_allowed =
            producer->type() == OpType::broadcast ||
            fast_strcmp(producer_name, "array") ||
            fast_strcmp(producer_name, "empty") ||
            fast_strcmp(producer_name, "setitem") ||
            fast_strcmp(producer_name, "getitem");
        if (!producer_allowed)
            return;
    }

    // presumably makes the result reuse the input's storage - see Var::share_with
    auto result = op->outputs().front();
    result->share_with(target);
}
// Unfinished helper: load_var_slice has an empty body and nothing in this file
// uses BBox yet.
// NOTE(review): the names suggest a per-dimension min/max bounding box of a
// slice, presumably for the work-in-progress gradient optimization - confirm.
struct BBox {
int n = 0;
int* minmax = nullptr;
void load_var_slice(const VarSlice& vs) {
}
};
// Work-in-progress pass: starting from a getitem whose input is produced by a
// setitem, collects the whole chain of consecutive setitem ops (earliest first)
// and clears the _has_gopt flag on every getitem that reads one of the chain's
// outputs. Does nothing unless `op` itself carries _has_gopt.
static void setitem_grad_opt(GetitemOp* op) {
    if (!op->flags.get(NodeFlags::_has_gopt))
        return;
    auto source_var = op->inputs().front();
    auto source_op = source_var->input();
    if (!source_op || !fast_strcmp(source_op->name(), "setitem"))
        return;

    // Walk towards the producers, gathering setitem ops from latest to earliest.
    vector<SetitemOp*> chain;
    auto cursor = (SetitemOp*)source_op;
    for (;;) {
        auto producer = cursor->inputs().front()->input();
        if (!producer || !fast_strcmp(producer->name(), "setitem"))
            break;
        chain.push_back(cursor);
        cursor = (SetitemOp*)producer;
    }
    chain.push_back(cursor);

    // Flip the collected part so the chain runs earliest to latest
    // (chain holds at least one element here).
    for (size_t lo = 0, hi = chain.size() - 1; lo < hi; ++lo, --hi)
        std::swap(chain[lo], chain[hi]);

    // Walk towards the consumers, following the first setitem reader each time.
    auto tail = (SetitemOp*)source_op;
    for (;;) {
        SetitemOp* follower = nullptr;
        auto tail_out = tail->outputs().front();
        for (auto* reader : tail_out->outputs()) {
            if (fast_strcmp(reader->name(), "setitem")) {
                follower = (SetitemOp*)reader;
                break;
            }
        }
        if (!follower)
            break;
        tail = follower;
        chain.push_back(follower);
    }

    // Clear the flag on every getitem that consumes an output of the chain.
    for (auto* link : chain) {
        auto link_out = link->outputs().front();
        for (auto* reader : link_out->outputs()) {
            if (fast_strcmp(reader->name(), "getitem"))
                reader->flags.set(NodeFlags::_has_gopt, 0);
        }
    }
}
// Graph-optimization hook for setitem: tries the in-place sharing rewrite
// implemented by setitem_inplace.
void SetitemOp::graph_optimize() {
// LOGir << "hello graph_optimize";
setitem_inplace(this);
}
// Graph-optimization hook for getitem. It currently performs no rewrite: the
// call into setitem_grad_opt is commented out.
void GetitemOp::graph_optimize() {
// This optimize is still WIP
// LOGir << "hello getitem graph_optimize";
// setitem_grad_opt(this);
// reference the static helper, presumably to avoid an unused-function warning
(void)setitem_grad_opt;
}
}

View File

@ -66,13 +66,13 @@ string strip(const string& s) {
}
void KernelIR::del_scope() {
if (father && (type=="define" || type=="func")) {
if (father && (type=="define" || type=="func" || type=="macro")) {
father->scope[attrs["lvalue"]].remove(this);
}
}
void KernelIR::add_scope() {
if (father && (type=="define" || type=="func"))
if (father && (type=="define" || type=="func" || type=="macro"))
father->scope[get_attr("lvalue")].push_back(this);
}
@ -485,6 +485,7 @@ string KernelIR::to_string(int level, bool debug) {
auto iter = attrs.find("code");
ASSERT(iter != attrs.end()) << attrs << type << father;
s << iter->second << "\n";
has_bc = attrs.count("has_bc");
} else {
s << "\n";
}
@ -590,6 +591,10 @@ KernelIR::KernelIR(const string& src, bool raw) {
if (i==start && k==end) {
attrs["code"] = src;
type = "macro";
auto v = split(src, " ", 3);
ASSERT(v.size()>1);
attrs["lvalue"] = v.at(1);
attrs["rvalue"] = v.size()>2 ? v.at(2) : "";
return;
} else {
push_back(src.substr(j, k-j), nullptr, raw);
@ -650,6 +655,17 @@ KernelIR::KernelIR(const string& src, bool raw) {
// func define
if (s.size()>=2 && s.back()=='}') {
int l = s.find("{");
ASSERT(l != string::npos);
if (startswith(s, "namespace ")) {
// namespace xxx {...}
// l
attrs["code"] = s.substr(0, l);
attrs["has_bc"] = "1";
type = "";
i = j + l;
end--;
continue;
}
int ll = s.rfind("(", l);
int rr = s.rfind(")", l);
// if () not found, maybe src like this:

View File

@ -63,6 +63,12 @@ void PassManager::run_passes() {
LOGvvvv << "KernelIR:\n" << ir.to_string();
if (oc->op->ops.size() == 1 && oc->op->ops[0]->name() == string("array")) {
ir.remove_all_unused();
if (oc->op->flags.get(NodeFlags::_cuda)) {
ir.children.back()->erase();
string type = oc->op->ops[0]->outputs().front()->dtype().to_cstring();
ir.push_back("kernel<<<1,1>>>(op0_outputp, op0->ptr<"+type+">()[0]);");
ir.push_back("__global__ static void kernel(jittor::"+type+"* xp, jittor::"+type+" x) { xp[0] = x; } ", &ir.before, true);
}
return;
}
run_pass<MarkRawPass>();

View File

@ -392,14 +392,10 @@ void ConvTuner::forwardTune(FusedOp* fop) {
continue;
auto kh = w->shape[wformat.find("h")];
auto kw = w->shape[wformat.find("w")];
if (kh != kw) {
LOGvvvv << "TODO: relay conv_backward_w when kh != kw" << kh << kw;
continue;
}
LOGvvvv << x << y << kh << stride << padding << dilation << groups << xformat << wformat << yformat;
auto make_conv_w = get_op_info(relay_conv_name)
.get_constructor<VarPtr, Var*, Var*, int, int, int, int, int, string, string, string>();
rvar = make_conv_w(x, y, kh, stride, padding, dilation, groups, xformat, wformat, yformat);
.get_constructor<VarPtr, Var*, Var*, int, int, int, int, int, int, string, string, string>();
rvar = make_conv_w(x, y, kh, kw, stride, padding, dilation, groups, xformat, wformat, yformat);
}
LOGvvvv << relay_conv_name << "output:" << rvar;

View File

@ -191,6 +191,7 @@ void parallel_compile_all_ops(vector<int>& queue, vector<int>& range, FusedOp& f
FusedOp& fused_op = *fop_needs_compile[-rid-1];
op = &fused_op;
LOGvv << "Compile FusedOp:" << op;
LOGV(11) << "FusedOps:" << fused_op.ops;
fused_op.context = new FusedOpContext();
fused_op.context->setup(&fused_op);
fused_op.do_prepare();
@ -238,11 +239,11 @@ void parallel_compile_all_ops(vector<int>& queue, vector<int>& range, FusedOp& f
}
}
}; // end of threads.launch_all
int active_threads = std::min(thread_num, (int)op_needs_compile.size());
threads.launch_all(active_threads, func);
typedef std::chrono::high_resolution_clock Time;
auto start = Time::now();
int active_threads = std::min(thread_num, (int)op_needs_compile.size());
threads.launch_all(active_threads, func);
int prev_i = 0;
bool change_line = false;
int sleep_us = 10;
@ -298,4 +299,4 @@ void parallel_compile_all_ops(vector<int>& queue, vector<int>& range, FusedOp& f
}
} // jittor
} // jittor

View File

@ -691,4 +691,54 @@ DEF_IS(GradCallback, T) from_py_object(PyObject* obj) {
return func;
}
struct VarSlices;
// Slice
// A Python object can be converted to VarSlices when it is an exact tuple, an
// exact int, a slice, Ellipsis, None, or anything accepted as a VarHolder*.
DEF_IS(VarSlices, bool) is_type(PyObject* obj) {
    if (PyTuple_CheckExact(obj)) return true;
    if (PyLong_CheckExact(obj)) return true;
    if (PySlice_Check(obj)) return true;
    if (Py_TYPE(obj) == &PyEllipsis_Type) return true;
    if (obj == Py_None) return true;
    return is_type<VarHolder*>(obj);
}
// Fills one slice entry from a single Python index object.
// Exact ints, slices, Ellipsis and None map to the matching setter; anything
// else is converted through the VarHolder* converter. A new entry is appended
// to `holders` for that conversion so any holder it creates stays alive.
template<class T>
void load_var_slice(PyObject* obj, T* var_slice, vector<unique_ptr<VarHolder>>& holders) {
    if (PyLong_CheckExact(obj)) {
        var_slice->set_int(PyLong_AsLong(obj));
        return;
    }
    if (PySlice_Check(obj)) {
        var_slice->slice = from_py_object<decltype(var_slice->slice)>(obj);
        return;
    }
    if (Py_TYPE(obj) == &PyEllipsis_Type) {
        var_slice->set_ellipsis();
        return;
    }
    if (obj == Py_None) {
        var_slice->set_none();
        return;
    }
    holders.emplace_back();
    auto* holder = from_py_object<VarHolder*>(obj, holders.back());
    // NOTE(review): reads the Var* through a cast of the holder pointer, which
    // assumes the Var* is the first member of VarHolder - confirm.
    auto var_ptrs = (Var**)holder;
    var_slice->set_var(var_ptrs[0]);
}
// Builds a VarSlices from a Python index expression: an exact tuple yields one
// entry per element, any other accepted object yields a single entry.
DEF_IS(VarSlices, T) from_py_object(PyObject* obj, vector<unique_ptr<VarHolder>>& holders) {
    if (!PyTuple_CheckExact(obj)) {
        T single(1);
        load_var_slice(obj, single.slices, holders);
        return single;
    }
    auto count = Py_SIZE(obj);
    T result(count);
    auto items = PySequence_Fast_ITEMS(obj);
    for (int k=0; k<count; k++)
        load_var_slice(items[k], result.slices+k, holders);
    return result;
}
} // jittor

View File

@ -0,0 +1,41 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. All Rights Reserved.
// Authors: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "common.h"
#include "misc/nano_vector.h"
namespace jittor {
void cuda_loop_schedule(NanoVector o_shape, int* masks, int* tdims);
// Checks cuda_loop_schedule against hand-computed expectations.
// For each case `check` takes the loop shape, the expected per-dimension masks,
// and the expected thread dims; thread dims that are left out default to 1 so
// that all six entries are compared.
JIT_TEST(cuda_loop_schedule) {
    auto check = [&](const vector<int64>& shape, const vector<int>& masks, vector<int> tdims={}) {
        const int ndim = shape.size();
        // heap-backed buffer instead of a variable-length array, which is a
        // compiler extension rather than standard C++
        vector<int> masks2(ndim);
        int tdims2[6];
        cuda_loop_schedule(shape, masks2.data(), tdims2);
        while (tdims.size() < 6) tdims.push_back(1);
        for (int i=0; i<ndim; i++)
            // .at() turns a too-short expectation list into an error instead of
            // an out-of-bounds read
            ASSERT(masks2[i] == masks.at(i)) << i << shape << masks << masks2;
        for (int i=0; i<6; i++)
            ASSERT(tdims.at(i)==tdims2[i]) << tdims << vector<int>(tdims2, tdims2+6);
    };
    check({0}, {1}, {0,1,1,1,1,1});
    check({2,2,2,2}, {8, 4, 2, 1}, {2,2,2,2,1,1});
    check({2048,1024}, {8, 1}, {1024,1,1,2048,1,1});
    check({2048,1025}, {8, 1+(1<<6)}, {1024,1,1,2048,1,1});
    check({2048,3025}, {8, 1+(1<<6)}, {1024,1,1,2048,1,1});
    check({2048,4425}, {16, 1+8+(1<<6)}, {1024,1,1,5,2048,1});
    check({2048, 2048,4425}, {0, 16, 1+8+(1<<6)}, {1024,1,1,5,2048,1});
    check({3,3,3,4425}, {0, 32, 16, 1+8+(1<<6)}, {1024,1,1,5,3,3});
    check({3,3,3,4425, 3,3}, {0, 32, 16, 8+4+(1<<6), 2, 1}, {3,3,64,70,3,3});
    check({3,3,3,12, 9,9}, {32, 16, 8, 4, 2, 1}, {9,9,12,3,3,3});
    check({3,3,3,13, 9,9}, {32, 16, 8, 4+64, 2, 1}, {9,9,12,3,3,3});
    check({3,3,3,13*4, 9,9}, {0, 32, 16, 8+4+64, 2, 1}, {9,9,12,5,3,3});
    check({3,3,3,100, 3,3}, {32, 16, 8, 4+64, 2, 1}, {3,3,64,3,3,3});
    check({3,3,3,400, 3,3}, {0, 32, 16, 8+4+64, 2, 1}, {3,3,64,7,3,3});
}
} // jittor

View File

@ -305,9 +305,9 @@ int system_popen(const char* cmd) {
void system_with_check(const char* cmd) {
auto ret = system_popen(cmd);
CHECK(ret!=-1) << "Run cmd failed:" << cmd <<
CHECK(ret>=0 && ret<=256) << "Run cmd failed:" << cmd <<
"\nreturn -1. This might be an overcommit issue or out of memory."
<< "Try : echo 1 >/proc/sys/vm/overcommit_memory";
<< "Try : sudo sysctl vm.overcommit_memory=1";
CHECKop(ret,==,0) << "Run cmd failed:" << cmd;
}

36
src/var_slices.cc Normal file
View File

@ -0,0 +1,36 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. All Rights Reserved.
// Authors: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "var_slices.h"
#include "var.h"
namespace jittor {
// Prints the slices as a bracketed list; every entry, including the last one,
// is followed by a comma.
std::ostream& operator<<(std::ostream& os, const VarSlices& vs) {
    os << '[';
    int k = 0;
    while (k < vs.n) {
        os << vs.slices[k] << ",";
        ++k;
    }
    os << ']';
    return os;
}
// Prints one slice entry according to its kind: dtype and shape for a var,
// "..." for an ellipsis, start:stop:step for a slice, the value for an int,
// and "-" for anything else.
std::ostream& operator<<(std::ostream& os, const VarSlice& s) {
    if (s.is_var()) {
        os << s.var->dtype() << s.var->shape;
    } else if (s.is_ellipsis()) {
        os << "...";
    } else if (s.is_slice()) {
        os << s.slice;
    } else if (s.is_int()) {
        os << s.i;
    } else {
        os << "-";
    }
    return os;
}
// Prints a slice as start:stop:step. Mask bits 1, 2 and 4 suppress the start,
// stop and step fields respectively; the two colons are always printed.
std::ostream& operator<<(std::ostream& os, const Slice& s) {
    const bool show_start = (s.mask & 1) == 0;
    const bool show_stop = (s.mask & 2) == 0;
    const bool show_step = (s.mask & 4) == 0;
    if (show_start)
        os << s.start;
    os << ':';
    if (show_stop)
        os << s.stop;
    os << ':';
    if (show_step)
        os << s.step;
    return os;
}
} // jittor

67
src/var_slices.h Normal file
View File

@ -0,0 +1,67 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. All Rights Reserved.
// Authors: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "common.h"
#include "misc/nano_vector.h"
namespace jittor {
struct Slice;
// One entry of an index expression. The kind is encoded in slice.mask:
//   -1 var, -2 ellipsis, -3 none, -4 int, and any value >= 0 means a slice.
// NOTE(review): var and i share storage with slice, and the setters write the
// tag through slice.mask before writing the payload. This only works if
// Slice::mask does not overlap the leading bytes used by var / i - confirm
// against the Slice definition.
union VarSlice {
Slice slice;
Var* var;
int64 i;
// kind queries, all based on the tag stored in slice.mask
inline bool is_var() const { return slice.mask == -1; }
inline bool is_ellipsis() const { return slice.mask == -2; }
inline bool is_none() const { return slice.mask == -3; }
inline bool is_int() const { return slice.mask == -4; }
inline bool is_slice() const { return slice.mask >= 0; }
// setters: store the tag, then the payload for kinds that carry one
inline void set_var(Var* v) { slice.mask = -1; var = v; }
inline void set_ellipsis() { slice.mask = -2; }
inline void set_none() { slice.mask = -3; }
inline void set_int(int64 v) { slice.mask = -4; i = v; }
};
struct VarSlices {
VarSlice* slices;
int n;
inline VarSlices() : slices(nullptr) {}
inline VarSlices(int n) : slices(new VarSlice[n]), n(n) {}
inline ~VarSlices() {if (slices) delete[] slices;}
inline VarSlices(VarSlices&& other) : slices(other.slices), n(other.n) {
other.slices = nullptr;
}
inline VarSlices(const VarSlices& other) : slices(new VarSlice[other.n]), n(other.n) {
for (int i=0; i<n; i++)
slices[i] = other.slices[i];
}
inline void operator=(VarSlices&& other) {
if (slices) delete[] slices;
n = other.n;
slices = other.slices;
other.slices = nullptr;
}
inline void operator=(const VarSlices& other) {
if (slices) delete[] slices;
slices = new VarSlice[other.n];
n = other.n;
for (int i=0; i<n; i++)
slices[i] = other.slices[i];
}
};
std::ostream& operator<<(std::ostream& os, const VarSlices& vs);
std::ostream& operator<<(std::ostream& os, const VarSlice& s);
std::ostream& operator<<(std::ostream& os, const Slice& s);
// @pyjt(_print_var_slice)
inline void _print_var_slice(VarSlices&& vs) {
// Debug helper: writes the slice list to the info log via operator<<.
LOGi << vs;
}
} // jittor