mirror of https://github.com/Jittor/Jittor
version 1.2
This commit is contained in:
parent d533f3960b
commit c792b23f48
@ -49,6 +49,8 @@ void CubArgReduceOp::infer_shape() {
    if (keepdims) {
        shape.push_back(1);
    }
    if (shape.size() == 0)
        shape.push_back(1);
    y->set_shape(shape);
    y_key->set_shape(shape);
}

@ -104,4 +106,4 @@ void CubArgReduceOp::jit_run() {
#endif // JIT_cuda
#endif // JIT

} // jittor
} // jittor

@ -42,8 +42,8 @@ static inline void set_shape(Var* x, const char* f, const string& format, int a,
        shape[0], shape[1], shape[2], shape[3]));
}

CudnnConvBackwardWOp::CudnnConvBackwardWOp(Var* x, Var* dy, int kernel_size, int stride, int padding, int dilation, int groups, string xformat, string wformat, string yformat)
    : x(x), dy(dy), kernel_size(kernel_size), stride(stride), padding(padding), dilation(dilation), groups(groups),
CudnnConvBackwardWOp::CudnnConvBackwardWOp(Var* x, Var* dy, int kh, int kw, int stride, int padding, int dilation, int groups, string xformat, string wformat, string yformat)
    : x(x), dy(dy), kh(kh), kw(kw), stride(stride), padding(padding), dilation(dilation), groups(groups),
    xformat(move(xformat)), wformat(move(wformat)), yformat(move(yformat)) {
    flags.set(NodeFlags::_cuda, 1);
    flags.set(NodeFlags::_cpu, 0);

@ -57,8 +57,8 @@ void CudnnConvBackwardWOp::infer_shape() {
    get_shape(x, "abcd", xformat, xn, xc, xh, xw);
    get_shape(dy, "abcd", yformat, yn, yc, yh, yw);
    wco = yc, wci = xc / groups;
    wh = kernel_size;
    ww = kernel_size;
    wh = kh;
    ww = kw;
    set_shape(dw, "oihw", wformat, wco, wci, wh, ww);
}

@ -13,14 +13,14 @@ namespace jittor {

struct CudnnConvBackwardWOp : Op {
    Var* x, * dy, * dw;
    int kernel_size, stride, padding, dilation, groups;
    int kh, kw, stride, padding, dilation, groups;
    string xformat, wformat, yformat;

    CudnnConvBackwardWOp(Var* x, Var* y, int kernel_size, int stride, int padding, int dilation, int groups=1, string xformat="abcd", string wformat="oihw", string yformat="abcd");
    CudnnConvBackwardWOp(Var* x, Var* y, int kh, int kw, int stride, int padding, int dilation, int groups=1, string xformat="abcd", string wformat="oihw", string yformat="abcd");

    const char* name() const override { return "cudnn_conv_backward_w"; }
    void infer_shape() override;
    DECLARE_jit_run;
};

} // jittor
} // jittor

@ -45,8 +45,8 @@ static inline void set_shape(Var* x, const char* f, const string& format, int a,
        shape[0], shape[1], shape[2], shape[3]));
}

MklConvBackwardWOp::MklConvBackwardWOp(Var* x, Var* dy, int kernel_size, int stride, int padding, int dilation, int groups, string xformat, string wformat, string yformat)
    : x(x), dy(dy), kernel_size(kernel_size), stride(stride), padding(padding), dilation(dilation), groups(groups),
MklConvBackwardWOp::MklConvBackwardWOp(Var* x, Var* dy, int kh, int kw, int stride, int padding, int dilation, int groups, string xformat, string wformat, string yformat)
    : x(x), dy(dy), kh(kh), kw(kw), stride(stride), padding(padding), dilation(dilation), groups(groups),
    xformat(move(xformat)), wformat(move(wformat)), yformat(move(yformat)) {
    dw = create_output(nullptr, dtype_infer(dy->ns, x->ns));
}

@ -58,8 +58,8 @@ void MklConvBackwardWOp::infer_shape() {
    get_shape(x, "abcd", xformat, xn, xc, xh, xw);
    get_shape(dy, "abcd", yformat, yn, yc, yh, yw);
    wco = yc, wci = xc / groups;
    wh = kernel_size;
    ww = kernel_size;
    wh = kh;
    ww = kw;
    set_shape(dw, "oihw", wformat, wco, wci, wh, ww);
}

@ -97,7 +97,8 @@ void MklConvBackwardWOp::jit_run() {
    int height = x->shape[findc("@XFORMAT",'c')];
    int width = x->shape[findc("@XFORMAT",'d')];
    int ch_out = dw->shape[findc("@WFORMAT",'o')];
    int kernel_size = dw->shape[findc("@WFORMAT",'h')];
    int kh = dw->shape[findc("@WFORMAT",'h')];
    int kw = dw->shape[findc("@WFORMAT",'w')];

    auto* __restrict__ net_src = x->ptr<Txd>();
    auto* __restrict__ net_diff_dst = dy->ptr<Tyd>();

@ -114,9 +115,9 @@ void MklConvBackwardWOp::jit_run() {

    memory::dims conv_src_tz = {batch, ch_in, height, width};
    memory::dims conv_weights_tz = groups>1
        ? memory::dims{groups, ch_out/groups, ch_in/groups, kernel_size, kernel_size}
        : memory::dims{ch_out, ch_in, kernel_size, kernel_size};
    memory::dims conv_dst_tz = {batch, ch_out, (height+padding*2-kernel_size*dilation+dilation-1)/stride+1, (width+padding*2-kernel_size*dilation+dilation-1)/stride+1};
        ? memory::dims{groups, ch_out/groups, ch_in/groups, kh, kw}
        : memory::dims{ch_out, ch_in, kh, kw};
    memory::dims conv_dst_tz = {batch, ch_out, (height+padding*2-kh*dilation+dilation-1)/stride+1, (width+padding*2-kw*dilation+dilation-1)/stride+1};
    memory::dims conv_strides = {stride, stride};
    memory::dims conv_padding = {padding, padding};
    memory::dims conv_dilation = {dilation-1, dilation-1};
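
Illustrative sketch (not from the commit): the hunks above split the square `kernel_size` into separate `kh`/`kw`. The output-size formula used for `conv_dst_tz` can be checked quickly in Python with made-up values::

    def conv_out(in_size, k, stride, padding, dilation):
        # mirrors (in+padding*2-k*dilation+dilation-1)/stride+1 from the hunk above
        return (in_size + padding*2 - k*dilation + dilation - 1) // stride + 1

    assert conv_out(32, 3, 1, 1, 1) == 32   # 3x3 kernel, stride 1, pad 1 keeps the size
    assert conv_out(32, 5, 2, 2, 1) == 16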

@ -13,14 +13,14 @@ namespace jittor {

struct MklConvBackwardWOp : Op {
    Var* x, * dy, * dw;
    int kernel_size, stride, padding, dilation, groups;
    int kh, kw, stride, padding, dilation, groups;
    string xformat, wformat, yformat;

    MklConvBackwardWOp(Var* x, Var* y, int kernel_size, int stride, int padding, int dilation, int groups=1, string xformat="abcd", string wformat="oihw", string yformat="abcd");
    MklConvBackwardWOp(Var* x, Var* y, int kh, int kw, int stride, int padding, int dilation, int groups=1, string xformat="abcd", string wformat="oihw", string yformat="abcd");

    const char* name() const override { return "mkl_conv_backward_w"; }
    void infer_shape() override;
    DECLARE_jit_run;
};

} // jittor
} // jittor

@ -7,7 +7,7 @@
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.1.7.20'
__version__ = '1.2.0.0'
from . import lock
with lock.lock_scope():
    from . import compiler

@ -233,11 +233,22 @@ def ones(shape, dtype="float32"):
        shape = (shape,)
    return unary(1, dtype).broadcast(shape)

def ones_like(x):
    return ones(x.shape,x.dtype)

def zeros(shape, dtype="float32"):
    if not isinstance(shape, (NanoVector, Sequence)):
        shape = (shape,)
    return unary(0, dtype).broadcast(shape)

def full(shape,val,dtype="float32"):
    if not isinstance(shape, (NanoVector, Sequence)):
        shape = (shape,)
    return unary(val, dtype).broadcast(shape)

def zeros_like(x):
    return zeros(x.shape,x.dtype)
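
Illustrative usage sketch (not from the commit) of the helpers added above, assuming they are exposed on the `jittor` namespace as this hunk suggests::

    import jittor as jt
    a = jt.ones((2, 3))
    b = jt.zeros_like(a)       # same shape and dtype as a, filled with 0
    c = jt.full((2, 3), 7.0)   # filled with 7.0, dtype defaults to float32
    d = jt.ones_like(a)
    print(b.shape, c.shape, d.shape)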

flags = core.flags()

def std(x):

@ -311,9 +322,17 @@ def squeeze(x, dim):
    return x.reshape(shape[:dim] + shape[dim+1:])
Var.squeeze = squeeze

def clamp(x, min_v, max_v):
    assert min_v <= max_v
    return x.maximum(min_v).minimum(max_v)
def clamp(x, min_v=None, max_v=None):
    if x.shape[0]==0:
        return x
    if min_v is not None and max_v is not None:
        assert min_v <= max_v
    if min_v is not None:
        x = x.maximum(min_v)
    if max_v is not None:
        x = x.minimum(max_v)
    return x

Var.clamp = clamp
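
Illustrative sketch (not from the commit) of the new optional bounds in `clamp`::

    import jittor as jt
    x = jt.array([-2.0, 0.5, 3.0])
    print(x.clamp(min_v=0.0).data)             # lower bound only
    print(x.clamp(max_v=1.0).data)             # upper bound only
    print(x.clamp(min_v=0.0, max_v=1.0).data)  # both bounds, as before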

def type_as(a, b):

@ -574,6 +593,8 @@ class Module:
            else:
                if hasattr(v, k):
                    v = getattr(v, k)
                    assert isinstance(v, (Module, Var)), \
                        f"expect a jittor Module or Var, but got <{v.__class__.__name__}>, key: {key}"
                else:
                    end = 1
                    break

@ -582,6 +603,8 @@ class Module:
            n_failed += 1
            LOG.w(f'load parameter {key} failed ...')
        else:
            assert isinstance(v, Var), \
                f"expect a jittor Var, but got <{v.__class__.__name__}>, key: {key}"
            LOG.v(f'load parameter {key} success ...')
            if isinstance(params[key], np.ndarray) or isinstance(params[key], list):
                v.update(array(params[key]))

@ -872,4 +895,4 @@ from .nn import matmul
from . import contrib
from . import numpy2cupy
from .contrib import concat
from .misc import *
from .misc import *

@ -241,7 +241,7 @@ def gen_jit_op_maker(op_headers, export=False, extra_flags=""):
        if "multiple_outputs" not in attrs:
            jit_cc_src.append(f"""
    VarPtr make_{cc_func_name}({", ".join(cc_make_args)}) {{
        Op* _op = new {op_name}({", ".join(op_make_args)});
        auto _op = new {op_name}({", ".join(op_make_args)});
        if (_op->outputs_holder.size() != 1) {{
            delete _op;
            LOGf << "Wrong output size of" << \"{op_name}\";

@ -261,7 +261,7 @@ def gen_jit_op_maker(op_headers, export=False, extra_flags=""):
        else:
            jit_cc_src.append(f"""
    vector<VarPtr> make_{cc_func_name}({", ".join(cc_make_args)}) {{
        Op* _op = new {op_name}({", ".join(op_make_args)});
        auto _op = new {op_name}({", ".join(op_make_args)});
        if (_op->flags.get(NodeFlags::_forwarded)) {{
            vector<VarPtr> outputs = move(_op->outputs_holder);
            delete _op;

@ -408,6 +408,15 @@ def gen_jit_op_maker(op_headers, export=False, extra_flags=""):
                arg_type.replace("Var", "VarHolder")+' '+arg)
            new_args.append(arg)
            more_src.append(f"_op->add_inputs({arg});")
        elif arg_type.startswith("VarSlices"):
            new_args_def.append(arg_def)
            new_args.append(arg)
            more_src.append(f"""
        vector<Var*> svars;
        for (int i=0; i<_op->vs.n; i++)
            if (_op->vs.slices[i].is_var())
                svars.push_back(_op->vs.slices[i].var);
        _op->add_inputs(svars);""")
        else:
            new_args_def.append(arg_def)
            new_args.append(arg)

@ -42,7 +42,7 @@ Example::
        indexes[dim] = f"i{dim}-{cdim}"
        b = a.reindex(shape, indexes)
        # ugly fix for preventing large fused op
        if len(arr)>=10:
        if len(arr)>=100:
            b.stop_fuse()
        if s is None:
            s = b

@ -99,6 +99,20 @@ def slice_var_index(x, slices):
    cnt_list = 0
    extras_idx = []
    extras = []
    has_ellipse = 0
    ellipse_index = 0
    for s,i in zip(slices,range(len(slices))):
        if isinstance(s,type(...)):
            has_ellipse+=1
            ellipse_index = i
    if has_ellipse>1:
        raise Exception(f"There are more than one ...")
    elif has_ellipse==1:
        slices = list(slices)
        del slices[ellipse_index]
        while len(slices)<len(shape):
            slices.insert(ellipse_index,slice(None))

    for i in range(len(shape)):
        if i>=len(slices):
            s = slice(None)

@ -119,6 +133,7 @@ def slice_var_index(x, slices):
            step = 1 if s.step is None else s.step
            if start<0: start += sp
            if stop<0: stop += sp
            if stop>sp+1: stop = sp
            out_shape.append(1+int(max(0, (stop-start-1)//step)))
            out_index.append(f"{start}+i{j}*{step}")
        elif isinstance(s, jt.Var):

@ -160,3 +175,57 @@ def setitem(x, slices, value):

jt.Var.__getitem__ = jt.Var.slice_var = slice_var
jt.Var.__setitem__ = setitem

# PATCH
def getitem(x, slices):
    if isinstance(slices, jt.Var) and slices.dtype == "bool":
        return getitem(x, slices.where())
    if isinstance(slices, list):
        slices = tuple(slices)
    return x.getitem(slices)

def setitem(x, slices, value):
    if isinstance(slices, jt.Var) and slices.dtype == "bool":
        mask = jt.broadcast(slices, x)
        value = jt.broadcast(value, x)
        return mask.ternary(value, mask)
    if isinstance(slices, list):
        slices = tuple(slices)
    return x.assign(x.setitem(slices, value))

jt.Var.__getitem__ = jt.Var.slice_var = getitem
jt.Var.__setitem__ = setitem
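
Illustrative sketch (not from the commit) of the boolean-mask indexing enabled by the patch above; the mask is routed through `where()` by `getitem`::

    import jittor as jt
    x = jt.array([1.0, -2.0, 3.0, -4.0])
    mask = x < 0
    neg = x[mask]        # dispatches to getitem(x, mask.where())
    print(neg.data)      # expected: [-2. -4.]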

def concat(arr, dim):
    '''Concat Operator can concat a list of jt Var at a specfic dimension.

    * [in] x: input var list for concat

    * [in] dim: concat which dim

    * [out] out: concat result

    Example::

        jt.concat([jt.array([[1],[2]]), jt.array([[2],[2]])], dim=1)
        # return [[1],[2],[2],[2]]
    '''
    # TODO: low performance when concat lots of vars
    total_dim = 0
    if dim < 0: dim += len(arr[0].shape)
    for a in arr:
        total_dim += a.shape[dim]
    cdim = 0
    shape = list(a.shape)
    shape[dim] = total_dim
    s = jt.empty(shape, a.dtype)
    slices = [slice(None)]*len(a.shape)
    for a in arr:
        if a.shape[dim] == 0:
            continue
        slices[dim] = slice(cdim, cdim+a.shape[dim])
        # print(slices, type(a))
        s = s.setitem(tuple(slices), a)
        # s = jt.setitem(s, tuple(slices), a)
        cdim += a.shape[dim]
    return s

@ -56,6 +56,45 @@ def relu_invariant_gauss(shape, dtype, mode="fan_in"):
def relu_invariant_gauss_(var, mode="fan_in"):
    var.assign(relu_invariant_gauss(tuple(var.shape), var.dtype, mode))

def calculate_std(var,mode,nonlinearity,param=0.01):
    mode = mode.lower()
    assert isinstance(param,(int,float))
    assert var.ndim>=2
    assert mode in ['fan_in', 'fan_out']

    fan = var.shape[1] if mode == 'fan_in' else var.shape[0]
    fan *= var[0][0].numel()

    gains = {
        'linear':1,
        'conv1d':1,
        'conv2d':1,
        'conv3d':1,
        'conv_transpose1d':1,
        'conv_transpose2d':1,
        'conv_transpose3d':1,
        'sigmoid':1,
        'tanh':5.0/3,
        'relu':math.sqrt(2.0),
        'leaky_relu':math.sqrt(2.0 / (1 + param ** 2)),
    }
    gain = gains[nonlinearity]
    std = gain/math.sqrt(fan)
    return std


def kaiming_uniform_(var, a=0, mode='fan_in', nonlinearity='leaky_relu'):
    std = calculate_std(var,mode,nonlinearity,a)
    bound = math.sqrt(3.0) * std
    with jt.no_grad():
        return uniform_(var,-bound, bound)

def kaiming_normal_(var, a=0, mode='fan_in', nonlinearity='leaky_relu'):
    std = calculate_std(var,mode,nonlinearity,a)
    with jt.no_grad():
        return gauss_(var,0, std)
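
Illustrative sketch (not from the commit) of how `calculate_std` feeds the kaiming initializers above; the tensor and numbers below are made up::

    import math
    import jittor as jt
    from jittor import init

    w = jt.random((64, 32, 3, 3))                  # fan_in = 32 * 3 * 3 = 288
    init.kaiming_uniform_(w, a=0, mode='fan_in', nonlinearity='leaky_relu')
    gain = math.sqrt(2.0 / (1 + 0 ** 2))           # leaky_relu gain with a=0
    bound = math.sqrt(3.0) * gain / math.sqrt(288)
    print(float(w.max().data), "should not exceed", bound)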


#TODO: bound = gain * math.sqrt(6.0/fan) ??
def xavier_uniform(shape, dtype, gain=1.0):
    assert len(shape)>1

@ -81,4 +120,4 @@ def xavier_gauss(shape, dtype, gain=1.0):
    return gauss(shape, dtype, 0, std)

def xavier_gauss_(var, gain=1.0):
    var.assign(xavier_gauss(tuple(var.shape), var.dtype, gain))
    var.assign(xavier_gauss(tuple(var.shape), var.dtype, gain))

@ -10,7 +10,8 @@
import jittor as jt
import numpy as np
import math
from collections.abc import Sequence
from collections.abc import Sequence,Iterable


def repeat(x, *shape):
    r'''

@ -94,6 +95,22 @@ def expand(x, shape):
    return x.broadcast(shape)
jt.Var.expand = expand


def median(x,dim=None,keepdim=False):
    if dim is None:
        x = x.reshape(-1)
        dim=0
    _,x = x.argsort(dim)
    slices = [slice(None) for i in range(dim-1)]
    k = (x.shape[dim]-1)//2
    if keepdim:
        slices.append(slice(k,k+1))
    else:
        slices.append(k)
    return x[tuple(slices)]

jt.Var.median = median

def stack(x, dim=0):
    r'''
    Concatenates sequence of vars along a new dimension.

@ -116,8 +133,10 @@ def stack(x, dim=0):
    [[[1 2 3]
    [[4 5 6]]]
    '''
    assert isinstance(x, list)
    assert len(x) >= 2
    assert isinstance(x, Sequence)
    if len(x) < 2:
        return x[0].unsqueeze(dim)

    res = [x_.unsqueeze(dim) for x_ in x]
    return jt.contrib.concat(res, dim=dim)
jt.Var.stack = stack

@ -140,6 +159,10 @@ def flip(x, dim=0):
    [[4 3 2 1]]
    '''
    assert isinstance(dim, int)
    if dim<0:
        dim+=x.ndim
    assert dim>=0 and dim<len(x.shape)

    tar_dims = []
    for i in range(len(x.shape)):
        if i == dim:

@ -258,6 +281,8 @@ def unbind(x, dim=0):
    if dim < 0: dim += len(x.shape)
    return [x[(slice(None),)*dim+(i,)] for i in range(x.shape[dim])]

jt.Var.unbind = unbind

def make_grid(x, nrow=8, padding=2, normalize=False, range=None, scale_each=False, pad_value=0):
    assert range == None
    assert scale_each == False

@ -268,3 +293,321 @@ def make_grid(x, nrow=8, padding=2, normalize=False, range=None, scale_each=Fals
    return x.reindex([c, h*ncol+(ncol+1)*padding, w*nrow+(nrow+1)*padding],
        [f"i1/{padding+h}*{nrow}+i2/{padding+w}", "i0",
        f"i1-i1/{padding+h}*{padding+h}-{padding}", f"i2-i2/{padding+w}*{padding+w}-{padding}"], overflow_value=pad_value)


def _ntuple(n):
    def parse(x):
        if isinstance(x, Iterable):
            return x
        return tuple([x]*n)
    return parse

_single = _ntuple(1)
_pair = _ntuple(2)
_triple = _ntuple(3)
_quadruple = _ntuple(4)


def unique(x):
    r'''
    Returns the unique elements of the input tensor.

    Args:

        x– the input tensor.
    '''
    x = x.reshape(-1)
    _,x = jt.argsort(x)
    index2 = [i for i in range(1,x.shape[0])]
    index1 = [i for i in range(x.shape[0]-1)]
    y = x[1:][x[index2] != x[index1]]
    x = jt.contrib.concat([x[:1],y],dim=0)
    return x

jt.Var.unique = unique


def hypot(a,b):
    return jt.sqrt(a.sqr()+b.sqr())

def rad2deg(x):
    return 180 * x / np.pi

jt.Var.rad2deg = rad2deg

def deg2rad(x):
    return x * np.pi / 180.

jt.Var.deg2rad = deg2rad

def arctan2(y,x):
    angle = jt.zeros(x.shape,dtype=x.dtype)
    mask = x!=0.0
    if angle[mask].numel()>0:
        angle[mask] = jt.arctan(y[mask]/x[mask])

    mask = (y<0) & (x<0)
    if angle[mask].numel()>0:
        angle[mask] -= np.pi

    mask = (y>0) &(x<0)
    if angle[mask].numel()>0:
        angle[mask] +=np.pi
    return angle


def nonzero(x):
    r'''
    Return the index of the elements of input tensor which are not equal to zero.
    '''
    x = jt.where(x)
    x = [xx.unsqueeze(1) for xx in x]
    if len(x)<2:
        return x[0]
    x = jt.contrib.concat(x,dim=1)
    return x

jt.Var.nonzero = nonzero


def arange(start=0, end=None, step=1,dtype=None):
    if end is None:
        end,start = start,0
    l = round((end-start)//step)+1
    if (l-1)*step+start>=end:
        l-=1
    x = jt.index((l,),0)
    x = x*step+start
    if dtype is not None:
        x= x.cast(dtype)
    return x
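
Quick illustrative check of `arange` (not from the commit), assuming it is re-exported through `from .misc import *`::

    import jittor as jt
    print(jt.arange(5).data)           # expected [0 1 2 3 4]
    print(jt.arange(1, 7, 2).data)     # expected [1 3 5]
    print(jt.arange(0, 1, 0.25).data)  # expected [0.   0.25 0.5  0.75]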

def randperm(n, dtype="int64"):
    x = np.arange(n)
    np.random.shuffle(x)
    return jt.array(x).cast(dtype)

def log2(x):
    return jt.log(x)/math.log(2.0)

jt.Var.log2 = log2

def item(x):
    assert x.ndim==1 and x.shape[0]==1
    return x.data[0]

jt.Var.item = item

def meshgrid(*tensors):
    r'''
    Take N tensors, each of which can be 1-dimensional vector, and create N n-dimensional grids,
    where the i th grid is defined by expanding the i th input over dimensions defined by other inputs.
    '''
    size = len(tensors)
    shape = []
    for i in range(size):
        assert isinstance(tensors[i],jt.Var) and tensors[i].ndim==1
        shape.append(tensors[i].shape[0])
    grids = []
    view_shape = [1]*size
    for i in range(size):
        vs = view_shape[:]
        vs[i]=-1
        grids.append(tensors[i].reshape(vs).expand(shape))

    return grids


def split(d,split_size,dim):
    r'''
    Splits the tensor into chunks. Each chunk is a view of the original tensor.

    If split_size is an integer type, then tensor will be split into equally sized chunks (if possible). Last chunk will be smaller if the tensor size along the given dimension dim is not divisible by split_size.

    If split_size is a list, then tensor will be split into len(split_size) chunks with sizes in dim according to split_size_or_sections.

    Args:
        d (Tensor) – tensor to split.

        split_size (int) or (list(int)) – size of a single chunk or list of sizes for each chunk

        dim (int) – dimension along which to split the tensor.
    '''
    if isinstance(split_size,int):
        shape = d.shape[dim]
        if shape % split_size == 0:
            split_size = [split_size]*(shape//split_size)
        else:
            split_size = [split_size]*(shape//split_size)+[shape%split_size]
    if isinstance(split_size, Iterable):
        assert sum(split_size)==d.shape[dim]

    if dim<0:
        dim+=d.ndim

    ans = []
    last = 0
    for i in split_size:
        if i==0:
            shape = list(d.shape)
            shape[dim]=0
            new_d = jt.zeros(tuple(shape),dtype=d.dtype)
            ans.append(new_d)
            continue

        ss = (slice(None),)*dim+(slice(last,last+i),)
        new_d = d[ss]
        last +=i
        ans.append(new_d)
    return tuple(ans)

jt.Var.split = split
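
Illustrative `split` usage (not from the commit), mirroring the docstring semantics above::

    import jittor as jt
    d = jt.random((5, 4))
    parts = d.split(2, dim=0)         # chunk sizes 2, 2, 1 along dim 0
    print([p.shape for p in parts])
    a, b = d.split([1, 3], dim=1)     # explicit chunk sizes along dim 1
    print(a.shape, b.shape)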

def tolist(x):
    return x.numpy().tolist()
jt.Var.tolist = tolist

def topk(input, k, dim=None, largest=True, sorted=True):
    if input.numel()==0:
        return jt.array([],dtype=input.dtype),jt.array([],dtype='int32')
    if dim is None:
        dim = -1
    if dim<0:
        dim+=input.ndim

    index,values = jt.argsort(input,dim=dim,descending=largest)
    dims = (slice(None),)*dim+(slice(0,k),)
    indices = index[dims]
    values = values[dims]
    return values,indices

jt.Var.topk = topk
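
Illustrative `topk` usage (not from the commit)::

    import jittor as jt
    x = jt.array([[3.0, 1.0, 2.0], [0.0, 5.0, 4.0]])
    values, indices = x.topk(2, dim=1, largest=True)
    print(values.data)    # expected [[3. 2.] [5. 4.]]
    print(indices.data)   # expected [[0 2] [1 2]]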

def kthvalue(input, k, dim=None, keepdim=False):
    if dim is None:
        dim = -1
    if dim<0:
        dim+=input.ndim
    index,values = jt.argsort(input,dim=dim)
    dims = (slice(None),)*dim+(slice(k-1,k),)
    indices = index[dims]
    values = values[dims]
    if not keepdim and indices.ndim>1:
        indices = indices.squeeze(dim)
        values = values.squeeze(dim)
    return values,indices

jt.Var.kthvalue = kthvalue


def gather(x,dim,index):
    if dim<0:
        dim+=index.ndim
    x_shape = list(x.shape )
    i_shape = list(index.shape)
    assert i_shape[dim]>0
    assert x.ndim == index.ndim
    i_shape[dim]=x_shape[dim]
    assert i_shape == x_shape
    ins = []
    for i in range(index.ndim):
        ins.append(jt.index(index.shape,dim=i))
    ins[dim]=index
    return x.reindex(ins)
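
Illustrative `gather` usage (not from the commit); the function is defined in misc but not attached to jt.Var here, so it is called through the module::

    import jittor as jt
    from jittor import misc
    x = jt.array([[1.0, 2.0], [3.0, 4.0]])
    idx = jt.array([[0, 0], [1, 0]])
    print(misc.gather(x, 1, idx).data)   # expected [[1. 1.] [4. 3.]]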


def prod(x,dim=0):
    x = jt.log(x)
    x = x.sum(dim=dim)
    return jt.exp(x)

jt.Var.prod = prod

def cumsum_forward(np, data):
    a = data['inputs'][0]
    b = data['outputs'][0]
    np.cumsum(a, axis=1, out=b)

def cumsum_backward(np, data):
    dout = data['dout']
    out = data['outputs'][0]
    np.cumsum(dout[:, ::-1], axis=1, out=out)
    np.copyto(out, out[:, ::-1])

def cumsum(x, dim=None):
    '''
    Parameters:
    -----------
    x: [batch_size, N], jt.var

    Returns:
    --------
    the cumulative sum of x
    '''
    return jt.numpy_code(x.shape, x.dtype, [x], cumsum_forward, [cumsum_backward])

jt.Var.cumsum = cumsum

def cumprod(x,dim=0):
    x = jt.log(x)
    x = cumsum(x,dim=dim)
    return jt.exp(x)

jt.Var.cumprod=cumprod
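
Note and sketch (not from the commit): `cumprod` above is exp(cumsum(log(x))), so it assumes strictly positive inputs, and the underlying `cumsum` kernel works along axis 1::

    import jittor as jt
    x = jt.array([[1.0, 2.0, 3.0, 4.0]])
    print(x.cumprod(dim=1).data)   # approximately [[ 1.  2.  6. 24.]]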

def nms(dets,thresh):
    '''
    dets jt.array [x1,y1,x2,y2,score]
    x(:,0)->x1,x(:,1)->y1,x(:,2)->x2,x(:,3)->y2,x(:,4)->score
    '''
    threshold = str(thresh)
    order = jt.argsort(dets[:,4],descending=True)[0]
    dets = dets[order]
    s_1 = '(@x(j,2)-@x(j,0)+1)*(@x(j,3)-@x(j,1)+1)'
    s_2 = '(@x(i,2)-@x(i,0)+1)*(@x(i,3)-@x(i,1)+1)'
    s_inter_w = 'max((Tx)0,min(@x(j,2),@x(i,2))-max(@x(j,0),@x(i,0))+1)'
    s_inter_h = 'max((Tx)0,min(@x(j,3),@x(i,3))-max(@x(j,1),@x(i,1))+1)'
    s_inter = s_inter_h+'*'+s_inter_w
    iou = s_inter + '/(' + s_1 +'+' + s_2 + '-' + s_inter + ')'
    fail_cond = iou+'>'+threshold
    selected = jt.candidate(dets, fail_cond)
    return order[selected]


jt.Var.expand = jt.Var.broadcast
jt.Var.expand_as = jt.Var.broadcast_var


def index_fill_(x,dim,indexs,val):
    r'''
    Fills the elements of the input tensor with value val by selecting the indices in the order given in index.

    Args:
        x - the input tensor
        dim - dimension along which to index
        index – indices of input tensor to fill in
        val – the value to fill with
    '''
    overflow_conditions = [f'i{dim}=={i}'for i in indexs]
    indexs = [f'i{i}' for i in range(len(x.shape))]
    return x.reindex(shape = x.shape,indexes = indexs,overflow_conditions=overflow_conditions,overflow_value=val)

def triu_(x,diagonal=0):
    r'''
    Returns the upper triangular part of a matrix (2-D tensor) or batch of matrices input, the other elements of the result tensor out are set to 0.

    The upper triangular part of the matrix is defined as the elements on and above the diagonal.

    Args:
        x – the input tensor.

        diagonal – the diagonal to consider,default =0
    '''
    l = len(x.shape)
    assert l>1
    overflow_conditions=[f'i{l-1}<i{l-2}+{diagonal}']
    indexs = [f'i{i}' for i in range(l)]
    return x.reindex(x.shape,indexs,overflow_conditions=overflow_conditions,overflow_value=0)

jt.Var.triu_ = triu_

@ -15,8 +15,10 @@ from jittor import init, Module
import numpy as np
import collections
import math
from collections import OrderedDict
from jittor.pool import Pool, pool, AdaptiveAvgPool2d
from jittor.optim import *
from jittor.misc import _pair


def matmul_transpose(a, b):

@ -25,6 +27,10 @@ def matmul_transpose(a, b):
    '''
    assert len(a.shape) >= 2 and len(b.shape) == 2
    assert a.shape[-1] == b.shape[-1]
    if len(a.shape)>2:
        aa = a.reshape((-1, a.shape[-1]))
        cc = matmul_transpose(aa, b)
        return cc.reshape(a.shape[:-1]+(-1,))

    shape = list(a.shape)[:-1] + list(b.shape)
    a = a.broadcast(shape, [len(shape)-2])

@ -108,6 +114,12 @@ Example::
    # cc:[..., n, m, k]
    # -->
    # 012
    if len_b == 2 and len_a>2:
        # TODO:ugly implementation for tuner
        aa = a.reshape((-1, m))
        cc = matmul(aa, b)
        print(a.shape, b.shape, cc.shape)
        return cc.reshape(a.shape[:-1] + [k])
    for i in range(len_c-2):
        ai = len_a-(len_c-i)
        bi = len_b-(len_c-i)

@ -182,11 +194,34 @@ def bce_loss(output, target, weight=None, size_average=True):
def l1_loss(output, target):
    return (output-target).abs().mean()


def smooth_l1_loss(y_true, y_pred,reduction="mean"):
    """Implements Smooth-L1 loss.
    y_true and y_pred are typically: [N, 4], but could be any shape.

    Args:
        y_true - ground truth
        y_pred - predictions
        reduction - the mode of cal loss which must be in ['mean','sum','none']
    """
    diff = jt.abs(y_true - y_pred)
    less_than_one = (diff<1.0).float32()
    loss = (less_than_one * 0.5 * diff.sqr()) + (1 - less_than_one) * (diff - 0.5)
    if reduction=="mean":
        return loss.mean()
    elif reduction=="sum":
        return loss.sum()
    elif reduction=="none":
        return loss
    else:
        raise ValueError(f'not support {reduction}')
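
Illustrative use of `smooth_l1_loss` (not from the commit); the inputs are made up::

    import jittor as jt
    from jittor import nn
    pred = jt.array([0.0, 2.0])
    target = jt.array([0.5, 0.0])
    # |diff| = [0.5, 2.0] -> per-element loss = [0.5*0.5**2, 2.0-0.5] = [0.125, 1.5]
    print(nn.smooth_l1_loss(target, pred, reduction='none').data)
    print(nn.smooth_l1_loss(target, pred).data)   # mean -> 0.8125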

class CrossEntropyLoss(Module):
    def __init__(self):
        pass
    def __init__(self,ignore_index=None):
        self.ignore_index = ignore_index

    def execute(self, output, target):
        return cross_entropy_loss(output, target)
        return cross_entropy_loss(output, target,self.ignore_index)

class MSELoss(Module):
    def __init__(self):

@ -228,6 +263,13 @@ def softmax(x, dim = None):
    ret = x / x.sum(dim, keepdims=True)
    return ret

def log_softmax(x,dim=None):
    x = softmax(x,dim=dim)
    return jt.log(x)

def log_sigmoid(x):
    return jt.log(jt.sigmoid(x))

class Dropout(Module):
    def __init__(self, p=0.5, is_train=False):
        assert p >= 0 and p <= 1, "dropout probability has to be between 0 and 1, but got {}".format(p)

@ -267,13 +309,12 @@ class BatchNorm(Module):
        self.is_train = is_train
        self.eps = eps
        self.momentum = momentum
        self.running_mean = init.constant((num_features,), "float32", 0.0).stop_grad()
        self.running_var = init.constant((num_features,), "float32", 1.0).stop_grad()

        self.affine = affine
        if self.affine:
        if affine:
            self.weight = init.constant((num_features,), "float32", 1.0)
            self.bias = init.constant((num_features,), "float32", 0.0)
        self.running_mean = init.constant((num_features,), "float32", 0.0).stop_grad()
        self.running_var = init.constant((num_features,), "float32", 1.0).stop_grad()

    def execute(self, x):
        if self.is_train:

@ -300,43 +341,63 @@ class BatchNorm(Module):
        return norm_x * w + b

class BatchNorm1d(Module):
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, is_train=True, sync=True):
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=None, is_train=True, sync=True):
        assert affine == None
        self.sync = sync
        self.num_features = num_features
        self.is_train = is_train
        self.eps = eps
        self.momentum = momentum
        self.weight = init.constant((num_features,), "float32", 1.0)
        self.bias = init.constant((num_features,), "float32", 0.0)
        self.running_mean = init.constant((num_features,), "float32", 0.0).stop_grad()
        self.running_var = init.constant((num_features,), "float32", 1.0).stop_grad()

        self.affine = affine
        if self.affine:
            self.weight = init.constant((num_features,), "float32", 1.0)
            self.bias = init.constant((num_features,), "float32", 0.0)

    def execute(self, x):
        if self.is_train:
            xmean = jt.mean(x, dims=[0], keepdims=1)
            x2mean = jt.mean(x*x, dims=[0], keepdims=1)
        if len(x.shape) == 3:
            if self.is_train:
                xmean = jt.mean(x, dims=[0, 2], keepdims=1)
                x2mean = jt.mean(x*x, dims=[0, 2], keepdims=1)

            if self.sync and jt.in_mpi:
                xmean = xmean.mpi_all_reduce("mean")
                x2mean = x2mean.mpi_all_reduce("mean")
                if self.sync and jt.in_mpi:
                    xmean = xmean.mpi_all_reduce("mean")
                    x2mean = x2mean.mpi_all_reduce("mean")

            xvar = x2mean-xmean*xmean
            norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)
            self.running_mean.update(self.running_mean +
                (xmean.sum([0])-self.running_mean)*self.momentum)
            self.running_var.update(self.running_var +
                (xvar.sum([0])-self.running_var)*self.momentum)
        else:
            running_mean = self.running_mean.broadcast(x, [0])
            running_var = self.running_var.broadcast(x, [0])
            norm_x = (x-running_mean)/jt.sqrt(running_var+self.eps)
                xvar = x2mean-xmean*xmean
                norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)
                self.running_mean.update(self.running_mean +
                    (xmean.sum([0, 2])-self.running_mean)*self.momentum)
                self.running_var.update(self.running_var +
                    (xvar.sum([0, 2])-self.running_var)*self.momentum)
            else:
                running_mean = self.running_mean.broadcast(x, [0, 2])
                running_var = self.running_var.broadcast(x, [0, 2])
                norm_x = (x-running_mean)/jt.sqrt(running_var+self.eps)
            w = self.weight.broadcast(x, [0, 2])
            b = self.bias.broadcast(x, [0, 2])
        else:
            if self.is_train:
                xmean = jt.mean(x, dims=[0], keepdims=1)
                x2mean = jt.mean(x*x, dims=[0], keepdims=1)

                if self.sync and jt.in_mpi:
                    xmean = xmean.mpi_all_reduce("mean")
                    x2mean = x2mean.mpi_all_reduce("mean")

                xvar = x2mean-xmean*xmean
                norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)
                self.running_mean.update(self.running_mean +
                    (xmean.sum([0])-self.running_mean)*self.momentum)
                self.running_var.update(self.running_var +
                    (xvar.sum([0])-self.running_var)*self.momentum)
            else:
                running_mean = self.running_mean.broadcast(x, [0])
                running_var = self.running_var.broadcast(x, [0])
                norm_x = (x-running_mean)/jt.sqrt(running_var+self.eps)
            w = self.weight.broadcast(x, [0])
            b = self.bias.broadcast(x, [0])
        if not self.affine:
            return norm_x
        w = self.weight.broadcast(x, [0])
        b = self.bias.broadcast(x, [0])
        return norm_x * w + b

class InstanceNorm2d(Module):

@ -379,19 +440,23 @@ class GroupNorm(Module):
        self.bias = init.constant((num_channels,), "float32", 0.0)

    def execute(self, x):
        N,C,H,W = x.shape
        assert C == self.num_channels
        N = x.shape[0]
        C = self.num_channels
        output_shape = (N,-1)
        # TODO: 3d group norm
        if x.ndim==4:
            output_shape = x.shape
        assert C % self.num_groups == 0
        x = x.reshape((N, self.num_groups, int(C/self.num_groups), H*W))
        x = x.reshape((N, self.num_groups, int(C/self.num_groups), -1))
        xmean = jt.mean(x, dims=[2,3], keepdims=1)
        x2mean = jt.mean(x*x, dims=[2,3], keepdims=1)
        xvar = jt.maximum(x2mean-xmean*xmean, 0)
        norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)
        if not self.affine:
            return norm_x
            return norm_x.reshape(output_shape)
        w = self.weight.reshape((1,self.num_groups,C//self.num_groups,1))
        b = self.bias.reshape((1,self.num_groups,C//self.num_groups,1))
        return (norm_x * w + b).reshape((N,C,H,W))
        return (norm_x * w + b).reshape(output_shape)

Relu = jt.make_module(relu)
ReLU = Relu

@ -482,6 +547,86 @@ class Conv(Module):
            y = y + b
        return y

class Conv1d(Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = (kernel_size, 1)
        self.stride = (stride, 1)
        self.padding = (padding, 0)
        self.dilation = (dilation, 1)
        self.groups = groups
        self.bias = bias
        assert in_channels % groups == 0, 'in_channels must be divisible by groups'
        assert out_channels % groups == 0, 'out_channels must be divisible by groups'
        self.conv = Conv(self.in_channels, self.out_channels, self.kernel_size, self.stride, self.padding, self.dilation, self.groups, self.bias)

    def execute(self, x):
        N,C,D = x.shape
        assert C==self.in_channels
        x = x.unsqueeze(-1)
        x = self.conv(x)
        y = x.squeeze(-1)
        return y
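
Illustrative `Conv1d` usage (not from the commit): the module wraps `Conv` with a trailing unit dimension, so the input is (N, C, D)::

    import jittor as jt
    from jittor import nn
    conv = nn.Conv1d(4, 8, kernel_size=3, padding=1)
    x = jt.random((2, 4, 16))
    y = conv(x)
    print(y.shape)   # expected (2, 8, 16)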


def conv2d(x, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    padding = _pair(padding)
    stride = _pair(stride)
    dilation = _pair(dilation)
    out_channels = weight.shape[0]

    if groups == 1:
        N,C,H,W = x.shape
        Kh, Kw = weight.shape[-2:]
        oh = (H+padding[0]*2-Kh*dilation[0]+dilation[0]-1)//stride[0]+1
        ow = (W+padding[1]*2-Kw*dilation[1]+dilation[1]-1)//stride[1]+1
        xx = x.reindex([N,out_channels,C,oh,ow,Kh,Kw], [
            'i0', # Nid
            'i2', # Cid
            f'i3*{stride[0]}-{padding[0]}+i5*{dilation[0]}', # Hid+Khid
            f'i4*{stride[1]}-{padding[1]}+i6*{dilation[1]}', # Wid+KWid
        ])
        ww = weight.broadcast(xx.shape, [0,3,4])
        yy = xx*ww
        y = yy.sum([2,5,6]) # Kc, Kh, Kw
        if bias is not None:
            b = bias.broadcast(y.shape, [0,2,3])
            y = y + b
        return y
    else:
        N,C,H,W = x.shape
        Kh, Kw = weight.shape[-2:]
        G = groups
        CpG = C // G # channels per group
        oc = out_channels
        oh = (H+padding[0]*2-Kh*dilation[0]+dilation[0]-1)//stride[0]+1
        ow = (W+padding[1]*2-Kw*dilation[1]+dilation[1]-1)//stride[1]+1
        xx = x.reindex([N,G,oc//G,CpG,oh,ow,Kh,Kw], [
            'i0', # Nid
            f'i1*{CpG}+i3', # Gid
            f'i4*{stride[0]}-{padding[0]}+i6*{dilation[0]}', # Hid+Khid
            f'i5*{stride[1]}-{padding[1]}+i7*{dilation[1]}', # Wid+KWid
        ])
        xx.compile_options = {"G":G}
        # w: [oc, CpG, Kh, Kw]
        ww = weight.reindex([N, G, oc//G, CpG, oh, ow, Kh, Kw], [
            f'i1*{oc//G}+i2',
            'i3',
            'i6',
            'i7'
        ])
        yy = xx*ww
        y = yy.reindex_reduce('add', [N, oc, oh, ow], [
            'i0',
            f'i1*{oc//G}+i2',
            'i4',
            'i5'
        ])
        if bias is not None:
            b = bias.broadcast(y.shape, [0,2,3])
            y = y + b
        return y
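
Illustrative use of the functional `conv2d` above (not from the commit)::

    import jittor as jt
    from jittor import nn
    x = jt.random((1, 3, 8, 8))
    w = jt.random((6, 3, 3, 3))    # (out_channels, in_channels // groups, Kh, Kw)
    y = nn.conv2d(x, w, bias=None, stride=1, padding=1)
    print(y.shape)                 # expected (1, 6, 8, 8)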
|
||||
|
||||
class ConvTranspose(Module):
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride=1, \
|
||||
|
@ -745,7 +890,17 @@ def upsample(img, size, mode="nearest", align_corners=False):
|
|||
y = wid * (w / W)
|
||||
return _interpolate(img, x, y, (nid,cid), mode)
|
||||
|
||||
def grid_sample(input, grid, mode='bilinear', padding_mode='zeros'):
|
||||
def interpolate(X,size=None,scale_factor=None,mode='bilinear',align_corners=False):
|
||||
if scale_factor is not None:
|
||||
size = [X.shape[-2]*scale_factor,X.shape[-1]*scale_factor]
|
||||
if isinstance(size,int):
|
||||
size = (size,size)
|
||||
if scale_factor is not None and scale_factor>1:
|
||||
return upsample(X,size,mode,align_corners)
|
||||
else:
|
||||
return resize(X,size,mode,align_corners)
|
||||
|
||||
def grid_sample_v0(input, grid, mode='bilinear', padding_mode='zeros'):
|
||||
r'''
|
||||
Given an input and a flow-field grid, computes the output using input values and pixel locations from grid.
|
||||
|
||||
|
@ -789,6 +944,195 @@ def grid_sample(input, grid, mode='bilinear', padding_mode='zeros'):
|
|||
y = ((grid[:,:,:,0].unsqueeze(1).repeat([1,Ci,1,1]) + 1) / 2) * (Wi - 1)
|
||||
return _interpolate(input, x, y, (nid,cid), mode)
|
||||
|
||||
|
||||
def linspace_from_neg_one(grid,num_steps,align_corners):
|
||||
if num_steps <= 1:
|
||||
return jt.array([],dtype=grid.dtype)
|
||||
# TODO: use jt.index
|
||||
ra = np.linspace(-1,1,num_steps)
|
||||
if not align_corners:
|
||||
ra = ra*(num_steps-1)/num_steps
|
||||
return jt.array(ra,dtype=grid.dtype)
|
||||
|
||||
def make_base_grid_4D(theta,N,C,H,W,align_corners):
|
||||
base_grid = jt.zeros((N, H, W, 3), dtype=theta.dtype);
|
||||
base_grid[...,0] = linspace_from_neg_one(theta, W, align_corners)
|
||||
base_grid[...,1] = jt.unsqueeze(linspace_from_neg_one(theta, H, align_corners),-1)
|
||||
base_grid[...,-1] = 1
|
||||
return base_grid
|
||||
|
||||
def make_base_grid_5D(theta,N,C,D,H,W,align_corners):
|
||||
base_grid = jt.zeros((N, D, H, W, 4), dtype=theta.dtype)
|
||||
base_grid[...,0] = linspace_from_neg_one(theta, W, align_corners)
|
||||
base_grid[...,1] = jt.unsqueeze(linspace_from_neg_one(theta, H, align_corners),-1)
|
||||
base_grid[...,2] = jt.unsqueeze(jt.unsqueeze(linspace_from_neg_one(theta, D, align_corners),-1),-1)
|
||||
base_grid[...,-1] = 1
|
||||
return base_grid
|
||||
|
||||
def affine_grid_generator_4D(theta,N,C,H,W,align_corners):
|
||||
base_grid = make_base_grid_4D(theta, N, C, H, W, align_corners)
|
||||
grid = jt.nn.bmm(base_grid.reshape(N, H * W, 3),theta.transpose(0,2,1))
|
||||
return grid.reshape(N, H, W, 2)
|
||||
|
||||
def affine_grid_generator_5D(theta,N,C,D,H,W,align_corners):
|
||||
base_grid = make_base_grid_5D(theta, N, C, D, H, W, align_corners)
|
||||
grid = jt.nn.bmm(base_grid.reshape(N, D * H * W, 4),theta.transpose(0,2,1))
|
||||
return grid.reshape(N, D, H, W, 3)
|
||||
|
||||
def affine_grid(theta, size, align_corners=False):
|
||||
assert str(theta.dtype) in ['float','float32','float64']
|
||||
assert min(size)>0
|
||||
assert len(size) in [4,5]
|
||||
if len(size)== 4:
|
||||
assert theta.ndim == 3 and theta.shape[-2] == 2 and theta.shape[-1] == 3
|
||||
return affine_grid_generator_4D(theta, size[0], size[1], size[2], size[3], align_corners)
|
||||
elif len(size)==5:
|
||||
assert theta.ndim == 3 and theta.shape[-2] == 3 and theta.shape[-1] == 4
|
||||
return affine_grid_generator_5D(theta, size[0], size[1], size[2], size[3], size[4], align_corners)
|
||||
|
||||
|
||||
def grid_sampler_unnormalize(coord,size,align_corners):
|
||||
if align_corners:
|
||||
#unnormalize coord from [-1, 1] to [0, size - 1]
|
||||
return ((coord + 1) / 2) * (size - 1)
|
||||
else:
|
||||
#unnormalize coord from [-1, 1] to [-0.5, size - 0.5]
|
||||
return ((coord + 1) * size - 1) / 2
|
||||
|
||||
|
||||
def clip_coordinates(x,clip_limit):
|
||||
return jt.clamp(x,min_v=0,max_v=clip_limit-1)
|
||||
|
||||
def reflect_coordinates(x,twice_low,twice_high):
|
||||
if twice_low == twice_high:
|
||||
return jt.zeros_like(x)
|
||||
m = twice_low / 2
|
||||
span = (twice_high - twice_low) / 2
|
||||
x = (x - m).abs()
|
||||
#`fmod` returns same sign as `in`, which is positive after the `fabs` above.
|
||||
extra = x.mod(span)
|
||||
flips = (x / span).floor()
|
||||
result1 = extra+m
|
||||
result2 = span-extra+m
|
||||
con = flips%2==0
|
||||
not_con = flips%2!=0
|
||||
result1[not_con]=0.0
|
||||
result2[con]=0.0
|
||||
return result1+result2
|
||||
|
||||
|
||||
def grid_sampler_compute_source_index(coord,size,padding_mode,align_corners):
|
||||
coord = grid_sampler_unnormalize(coord, size, align_corners)
|
||||
if padding_mode == 'border':
|
||||
#clip coordinates to image borders
|
||||
coord = clip_coordinates(coord, size)
|
||||
elif padding_mode == 'reflection':
|
||||
#reflect coordinates by image borders
|
||||
if align_corners:
|
||||
coord = reflect_coordinates(coord, 0, 2*(size - 1))
|
||||
else:
|
||||
coord = reflect_coordinates(coord, -1, 2*size - 1)
|
||||
#clip coordinates to image borders
|
||||
coord = clip_coordinates(coord, size)
|
||||
return coord
|
||||
|
||||
|
||||
|
||||
def grid_sampler_3d(X,grid,mode,padding_mode,align_corners):
|
||||
N = X.shape[0]
|
||||
C = X.shape[1]
|
||||
inp_D = X.shape[2]
|
||||
inp_H = X.shape[3]
|
||||
inp_W = X.shape[4]
|
||||
|
||||
D = grid.shape[1]
|
||||
H = grid.shape[2]
|
||||
W = grid.shape[3]
|
||||
x = grid[:,:,:,:,0]
|
||||
y = grid[:,:,:,:,1]
|
||||
z = grid[:,:,:,:,2]
|
||||
shape = [N,C,D,H,W]
|
||||
cid = jt.index(shape, dim=1)
|
||||
nid = jt.index(shape, dim=0)
|
||||
|
||||
x = grid_sampler_compute_source_index(x,inp_W,padding_mode,align_corners)
|
||||
y = grid_sampler_compute_source_index(y,inp_H,padding_mode,align_corners)
|
||||
z = grid_sampler_compute_source_index(z,inp_D,padding_mode,align_corners)
|
||||
xid = x.reindex(shape,['i0','i2','i3','i4'])
|
||||
yid = y.reindex(shape,['i0','i2','i3','i4'])
|
||||
zid = z.reindex(shape,['i0','i2','i3','i4'])
|
||||
|
||||
if mode=='nearest':
|
||||
return X.reindex([nid,cid,zid.round(),yid.round(),xid.round()])
|
||||
elif mode=='bilinear':
|
||||
fx,fy,fz = xid.floor(),yid.floor(),zid.floor()
|
||||
cx,cy,cz = fx+1,fy+1,fz+1
|
||||
dx,dy,dz = xid-fx,yid-fy,zid-fz
|
||||
dnx,dny,dnz = cx-xid,cy-yid,cz-zid
|
||||
a = X.reindex([nid,cid,fz,fy,fx])
|
||||
b = X.reindex([nid,cid,cz,fy,fx])
|
||||
c = X.reindex([nid,cid,fz,cy,fx])
|
||||
d = X.reindex([nid,cid,fz,fy,cx])
|
||||
e = X.reindex([nid,cid,fz,cy,cx])
|
||||
f = X.reindex([nid,cid,cz,fy,cx])
|
||||
g = X.reindex([nid,cid,cz,cy,fx])
|
||||
h = X.reindex([nid,cid,cz,cy,cx])
|
||||
o = a*dnx*dny*dnz+b*dnx*dny*dz+c*dnx*dy*dnz+d*dx*dny*dnz+e*dx*dy*dnz+f*dx*dny*dz+g*dnx*dy*dz+h*dx*dy*dz
|
||||
return o
|
||||
|
||||
def grid_sampler_2d(X,grid,mode,padding_mode,align_corners):
|
||||
N = X.shape[0]
|
||||
C = X.shape[1]
|
||||
inp_H = X.shape[2]
|
||||
inp_W = X.shape[3]
|
||||
|
||||
H = grid.shape[1]
|
||||
W = grid.shape[2]
|
||||
x = grid[:,:,:,0]
|
||||
y = grid[:,:,:,1]
|
||||
shape = [N,C,H,W]
|
||||
cid = jt.index(shape, dim=1)
|
||||
nid = jt.index(shape, dim=0)
|
||||
|
||||
x = grid_sampler_compute_source_index(x,inp_W,padding_mode,align_corners)
|
||||
y = grid_sampler_compute_source_index(y,inp_H,padding_mode,align_corners)
|
||||
xid = x.reindex(shape,['i0','i2','i3'])
|
||||
yid = y.reindex(shape,['i0','i2','i3'])
|
||||
|
||||
if mode=='nearest':
|
||||
return X.reindex([nid,cid,yid.round(),xid.round()])
|
||||
elif mode=='bilinear':
|
||||
#xid,yid = (xid+0.00001),(yid+0.00001)
|
||||
fx,fy = (xid).floor(),(yid).floor()
|
||||
cx,cy = fx+1,fy+1
|
||||
dx,dy = xid-fx,yid-fy
|
||||
dnx,dny = cx-xid,cy-yid
|
||||
|
||||
a = X.reindex([nid,cid,fy,fx],overflow_value=0.0)
|
||||
b = X.reindex([nid,cid,cy,fx],overflow_value=0.0)
|
||||
c = X.reindex([nid,cid,fy,cx],overflow_value=0.0)
|
||||
d = X.reindex([nid,cid,cy,cx],overflow_value=0.0)
|
||||
o = a*dnx*dny+b*dnx*dy+c*dx*dny+d*dx*dy
|
||||
return o
|
||||
|
||||
|
||||
def grid_sampler(X, grid, mode, padding_mode, align_corners):
|
||||
assert X.dtype==grid.dtype
|
||||
assert ((X.ndim==4 or X.ndim==5) and X.ndim==grid.ndim)
|
||||
assert X.shape[0]==grid.shape[0] and grid.shape[-1]==X.ndim-2
|
||||
assert X.numel()>0
|
||||
if X.ndim == 4:
|
||||
return grid_sampler_2d(X, grid, mode, padding_mode, align_corners)
|
||||
else:
|
||||
return grid_sampler_3d(X, grid, mode, padding_mode, align_corners)
|
||||
|
||||
|
||||
def grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corners=False):
|
||||
assert mode in ['bilinear','nearest']
|
||||
assert padding_mode in ['zeros','border','reflection']
|
||||
return grid_sampler(input, grid, mode, padding_mode, align_corners)
|
||||
|
||||
|
||||
class Upsample(Module):
|
||||
def __init__(self, scale_factor=None, mode='nearest'):
|
||||
self.scale_factor = scale_factor if isinstance(scale_factor, tuple) else (scale_factor, scale_factor)
|
||||
|
@ -808,6 +1152,9 @@ class Sequential(Module):
|
|||
if isinstance(mod, collections.OrderedDict):
|
||||
for k, m in mod.items():
|
||||
self.add_module(k, m)
|
||||
elif isinstance(mod,list):
|
||||
for m in mod:
|
||||
self.append(m)
|
||||
else:
|
||||
self.append(mod)
|
||||
def __getitem__(self, idx):
|
||||
|
@ -836,4 +1183,7 @@ class Sequential(Module):
|
|||
assert not isinstance(mod, type), f"Module is not a type"
|
||||
self.layers[name]=mod
|
||||
|
||||
def __len__(self):
|
||||
return len(self.layers)
|
||||
|
||||
ModuleList = Sequential
|
||||
|
|
|
@ -109,6 +109,8 @@ def get_def_code(df, scope_name, pyname, self_as_arg0=False):
|
|||
func_args_convert = ""
|
||||
func_call = df["func_name"]+"("
|
||||
pytypes = [ get_pytype_map(a[0],0) for a in args ]
|
||||
holder_dec_array = []
|
||||
holder_set_array = []
|
||||
for tid, tpc in enumerate(pytypes):
|
||||
check = get_pytype_map(args[tid][0],2)
|
||||
default_arg = args[tid][2]
|
||||
|
@ -118,6 +120,11 @@ def get_def_code(df, scope_name, pyname, self_as_arg0=False):
|
|||
if jtp == "VarHolder*":
|
||||
holder_dec = f"unique_ptr<VarHolder> arg{tid}_holder"
|
||||
holder_set = f", arg{tid}_holder"
|
||||
if jtp == "VarSlices":
|
||||
holder_dec = f"vector<unique_ptr<VarHolder>> arg{tid}_holder"
|
||||
holder_set = f", arg{tid}_holder"
|
||||
holder_dec_array.append(holder_dec)
|
||||
holder_set_array.append(holder_set)
|
||||
if len(default_arg):
|
||||
func_args_convert += f"""
|
||||
{holder_dec};
|
||||
|
@ -165,7 +172,7 @@ def get_def_code(df, scope_name, pyname, self_as_arg0=False):
|
|||
if (khash == {get_hash(args[aid][1])}u) {{
|
||||
// hash match {args[aid][1]}
|
||||
CHECK(({get_pytype_map(args[aid][0],2)}(vo)));
|
||||
arg{aid} = {pytypes[aid]}(vo);
|
||||
arg{aid} = {pytypes[aid]}(vo{holder_set_array[aid]});
|
||||
arg_filled |= 1ull << {aid};
|
||||
continue;
|
||||
}}
|
||||
|
@ -759,7 +766,11 @@ def compile_src(src, h, basename):
|
|||
core_name = submodule_info["attrs"]["core_name"]
|
||||
has_map = class_name in ["VarHolder", "NanoVector"]
|
||||
has_seq = class_name == "NanoVector"
|
||||
code = f"""
|
||||
# add extra include to avoid compile error
|
||||
src_code = ""
|
||||
if include_name.endswith("var_slices.h"):
|
||||
src_code += '#include "var_holder.h"\n'
|
||||
src_code += f"""
|
||||
#include "pyjt/py_converter.h"
|
||||
#include "pyjt/py_arg_printer.h"
|
||||
#include "common.h"
|
||||
|
@ -830,7 +841,7 @@ def compile_src(src, h, basename):
|
|||
|
||||
}}
|
||||
"""
|
||||
return code
|
||||
return src_code
|
||||
|
||||
def compile_single(head_file_name, src_file_name, src=None):
|
||||
basename = head_file_name.split("/")[-1].split(".")[0]
|
||||
|
|
|
@ -0,0 +1,67 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import unittest
|
||||
import jittor as jt
|
||||
import numpy as np
|
||||
from jittor.nn import affine_grid,grid_sample
|
||||
|
||||
|
||||
class TestAffineGrid(unittest.TestCase):
|
||||
def test_affine_grid_2d(self):
|
||||
import torch.nn.functional as F
|
||||
import torch
|
||||
N = 8
|
||||
C = 3
|
||||
H = 256
|
||||
W = 128
|
||||
theta = np.random.randn(N,2,3).astype(np.float32)
|
||||
features = np.random.randint(256,size=(N,C,H,W)).astype(np.float32)
|
||||
|
||||
torch_theta = torch.Tensor(theta)
|
||||
torch_features = torch.Tensor(features)
|
||||
torch_grid = F.affine_grid(torch_theta,size=(N,C,H,W),align_corners=False)
|
||||
torch_sample = F.grid_sample(torch_features,torch_grid,mode='bilinear',padding_mode='zeros',align_corners=False)
|
||||
|
||||
jt_theta = jt.array(theta)
|
||||
jt_features = jt.array(features)
|
||||
jt_grid = affine_grid(jt_theta,size=(N,C,H,W),align_corners=False)
|
||||
jt_sample = grid_sample(jt_features,jt_grid,mode='bilinear',padding_mode='zeros',align_corners=False)
|
||||
|
||||
assert np.allclose(jt_theta.numpy(),torch_theta.numpy())
|
||||
assert np.allclose(jt_features.numpy(),torch_features.numpy())
|
||||
assert np.allclose(jt_grid.numpy(),torch_grid.numpy(),atol=1e-05)
|
||||
assert np.allclose(torch_sample.numpy(),jt_sample.numpy(),atol=1e-01)
|
||||
|
||||
|
||||
def test_affine_grid_3d(self):
|
||||
import torch.nn.functional as F
|
||||
import torch
|
||||
N = 8
|
||||
C = 3
|
||||
D = 64
|
||||
H = 256
|
||||
W = 128
|
||||
theta = np.random.randn(N,3,4).astype(np.float32)
|
||||
features = np.random.randint(256,size=(N,C,D,H,W)).astype(np.float32)
|
||||
|
||||
torch_theta = torch.Tensor(theta)
|
||||
torch_features = torch.Tensor(features)
|
||||
torch_grid = F.affine_grid(torch_theta,size=(N,C,D,H,W),align_corners=False)
|
||||
torch_sample = F.grid_sample(torch_features,torch_grid,mode='bilinear',padding_mode='zeros',align_corners=False)
|
||||
|
||||
jt_theta = jt.array(theta)
|
||||
jt_features = jt.array(features)
|
||||
jt_grid = affine_grid(jt_theta,size=(N,C,D,H,W),align_corners=False)
|
||||
jt_sample = grid_sample(jt_features,jt_grid,mode='bilinear',padding_mode='zeros',align_corners=False)
|
||||
|
||||
assert np.allclose(jt_theta.numpy(),torch_theta.numpy())
|
||||
assert np.allclose(jt_features.numpy(),torch_features.numpy())
|
||||
assert np.allclose(jt_grid.numpy(),torch_grid.numpy(),atol=1e-05)
|
||||
assert np.allclose(torch_sample.numpy(),jt_sample.numpy(),atol=1e-01)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -55,6 +55,7 @@ class TestArgReduceOp(unittest.TestCase):
|
|||
check_backward([5,5,5], 'min', 0, True)
|
||||
check_backward([5,5,5], 'min', 2, True)
|
||||
check_backward([5,5,5], 'min', 1, True)
|
||||
check_backward([5,], 'min', 0, True)
|
||||
check_backward([20,20,20,20], 'max', 0, True)
|
||||
check_backward([20,20,20,20], 'max', 2, True)
|
||||
check_backward([20,20,20,20], 'max', 1, True)
|
||||
|
@ -62,6 +63,7 @@ class TestArgReduceOp(unittest.TestCase):
|
|||
check_backward([5,5,5], 'min', 0, False)
|
||||
check_backward([5,5,5], 'min', 2, False)
|
||||
check_backward([5,5,5], 'min', 1, False)
|
||||
check_backward([5,], 'min', 0, False)
|
||||
check_backward([20,20,20,20], 'max', 0, False)
|
||||
check_backward([20,20,20,20], 'max', 2, False)
|
||||
check_backward([20,20,20,20], 'max', 1, False)
|
||||
|
@ -73,6 +75,7 @@ class TestArgReduceOp(unittest.TestCase):
|
|||
check_backward([5,5,5], 'min', 0, True)
|
||||
check_backward([5,5,5], 'min', 2, True)
|
||||
check_backward([5,5,5], 'min', 1, True)
|
||||
check_backward([5,], 'min', 0, True)
|
||||
check_backward([20,20,20,20], 'max', 0, True)
|
||||
check_backward([20,20,20,20], 'max', 2, True)
|
||||
check_backward([20,20,20,20], 'max', 1, True)
|
||||
|
@ -80,6 +83,7 @@ class TestArgReduceOp(unittest.TestCase):
|
|||
check_backward([5,5,5], 'min', 0, False)
|
||||
check_backward([5,5,5], 'min', 2, False)
|
||||
check_backward([5,5,5], 'min', 1, False)
|
||||
check_backward([5,], 'min', 0, False)
|
||||
check_backward([20,20,20,20], 'max', 0, False)
|
||||
check_backward([20,20,20,20], 'max', 2, False)
|
||||
check_backward([20,20,20,20], 'max', 1, False)
|
||||
|
@ -89,6 +93,7 @@ class TestArgReduceOp(unittest.TestCase):
|
|||
check_reduce([5,5,5], 'min', 0, True)
|
||||
check_reduce([5,5,5], 'min', 2, True)
|
||||
check_reduce([5,5,5], 'min', 1, True)
|
||||
check_reduce([5], 'min', 0, True)
|
||||
check_reduce([20,20,20,20], 'max', 0, True)
|
||||
check_reduce([20,20,20,20], 'max', 2, True)
|
||||
check_reduce([20,20,20,20], 'max', 1, True)
|
||||
|
@ -96,6 +101,7 @@ class TestArgReduceOp(unittest.TestCase):
|
|||
check_reduce([5,5,5], 'min', 0, False)
|
||||
check_reduce([5,5,5], 'min', 2, False)
|
||||
check_reduce([5,5,5], 'min', 1, False)
|
||||
check_reduce([5], 'min', 0, False)
|
||||
check_reduce([20,20,20,20], 'max', 0, False)
|
||||
check_reduce([20,20,20,20], 'max', 2, False)
|
||||
check_reduce([20,20,20,20], 'max', 1, False)
|
||||
|
@ -107,6 +113,7 @@ class TestArgReduceOp(unittest.TestCase):
|
|||
check_reduce([5,5,5], 'min', 0, True, True)
|
||||
check_reduce([5,5,5], 'min', 2, True, True)
|
||||
check_reduce([5,5,5], 'min', 1, True, True)
|
||||
check_reduce([5], 'min', 0, True)
|
||||
check_reduce([20,20,20,20], 'max', 0, True, True)
|
||||
check_reduce([20,20,20,20], 'max', 2, True, True)
|
||||
check_reduce([20,20,20,20], 'max', 1, True, True)
|
||||
|
@ -114,9 +121,10 @@ class TestArgReduceOp(unittest.TestCase):
|
|||
check_reduce([5,5], 'min', 0, False, True)
|
||||
check_reduce([5,5,5], 'min', 2, False, True)
|
||||
check_reduce([5,5,5], 'min', 1, False, True)
|
||||
check_reduce([5], 'min', 0, False)
|
||||
check_reduce([20,20,20,20], 'max', 0, False, True)
|
||||
check_reduce([20,20,20,20], 'max', 2, False, True)
|
||||
check_reduce([20,20,20,20], 'max', 1, False, True)
|
||||
check_reduce([20,20,20,20], 'max', 3, False, True)
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
unittest.main()
|
||||
|
|
|
@ -7,6 +7,37 @@ import unittest
import jittor as jt
import numpy as np

def concat2(arr, dim):
    '''Concat Operator can concat a list of jt Var along a specific dimension.

    * [in] x: input var list for concat

    * [in] dim: the dimension to concat along

    * [out] out: concat result

    Example::

        jt.concat([jt.array([[1],[2]]), jt.array([[2],[2]])], dim=1)
        # return [[1,2],[2,2]]
    '''
    # TODO: low performance when concat lots of vars
    total_dim = 0
    if dim < 0: dim += len(arr[0].shape)
    for a in arr:
        total_dim += a.shape[dim]
    cdim = 0
    shape = list(a.shape)
    shape[dim] = total_dim
    s = jt.empty(shape, a.dtype)
    slices = [slice(None)]*len(a.shape)
    for a in arr:
        slices[dim] = slice(cdim, cdim+a.shape[dim])
        # print(slices, type(a))
        s = s.setitem(tuple(slices), a)
        # s = jt.setitem(s, tuple(slices), a)
        cdim += a.shape[dim]
    return s
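# (added illustrative sketch, not part of the original patch) concat2 allocates the
# output with jt.empty and fills it slice by slice via setitem, so it should agree
# with numpy's concatenate:
#   a = jt.array(np.arange(6).reshape(2, 3))
#   b = jt.array(np.arange(6, 12).reshape(2, 3))
#   c = concat2([a, b], dim=0)   # shape (4, 3)
#   assert (c.data == np.concatenate([a.data, b.data], axis=0)).all()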
class TestConcatOp(unittest.TestCase):
    def test_concat_op(self):
@ -18,7 +49,117 @@ class TestConcatOp(unittest.TestCase):
check([jt.array(np.array(range(24))).reshape((1,2,3,4)), jt.array(np.array(range(24))).reshape((1,2,3,4))])
check([jt.array(np.array(range(120))).reshape((5,2,3,4)), jt.array(np.array(range(24))).reshape((1,2,3,4))])
check([jt.array(np.array(range(5))).reshape((5,1)), jt.array(np.array(range(1))).reshape((1,1))])
print('concat success...')
|
||||
@jt.flag_scope(use_cuda = 1)
|
||||
def test_concat_perf(self):
|
||||
def check(dim, size, backward=False):
|
||||
n = 64
|
||||
a = jt.random((n,n,n,n))
|
||||
a.sync()
|
||||
m = n // size
|
||||
arr = []
|
||||
for i in range(m):
|
||||
arr.append(a[(slice(None),)*dim + (slice(i*size,i*size+size),)])
|
||||
b = jt.contrib.concat(arr, dim)
|
||||
if backward:
|
||||
loss = b * a
|
||||
b = jt.grad(loss, a)
|
||||
with jt.profile_scope(1, 0) as rep:
|
||||
b.sync()
|
||||
# print(rep)
|
||||
i = rep[0].index("TotalTime")
|
||||
stime = 0
|
||||
for r in rep[1:]:
|
||||
stime += float(r[i])
|
||||
bw = 4*64**4*2*2 / stime
|
||||
# sizeof(float) * numel * (split and concat) * (read and write)
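# (added note) Units: 4 bytes per float32 element, 64**4 elements, moved twice
# (split + concat) and touched twice (read + write), so the numerator is
# 4*64**4*2*2 = 268435456 bytes. Assuming the profiler's TotalTime is reported in
# nanoseconds, bytes/ns is numerically GB/s, e.g. a hypothetical stime of 1e6 ns
# (1 ms) would give roughly 268 GB/s.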
|
||||
print(f"{dim} {size} {stime/1e6}ms, {bw}GB/s")
|
||||
return bw
|
||||
ndim = 4
|
||||
splits = [1, 2, 4, 8, 16, 32, 64]
|
||||
m = len(splits)
|
||||
result = np.zeros((4, m))
|
||||
result_back = np.zeros((4, m))
|
||||
for i in range(ndim):
|
||||
for j in range(m):
|
||||
result[i,j] = check(i, splits[j])
|
||||
result_back[i,j] = check(i, splits[j], True)
|
||||
print(result.T)
|
||||
print(result_back.T)
|
||||
'''
|
||||
[[ 17.02802497 17.12933081 17.10814418 15.49217942]
|
||||
[ 33.10922467 33.01865886 33.08940182 30.24637466]
|
||||
[ 62.27219795 62.06702029 61.90039457 58.68727009]
|
||||
[112.31933307 111.89659519 111.02357161 108.98520165]
|
||||
[187.24806534 190.68837367 186.73965711 186.32242015]
|
||||
[280.28594579 278.94498734 284.42015302 284.98722929]
|
||||
[387.03887468 386.14916854 386.47551229 385.28621521]]
|
||||
|
||||
[[ 5.04141217 4.55677858 4.55677363 3.79321142]
|
||||
[ 9.05243799 8.99777599 8.96021333 7.49345194]
|
||||
[ 17.45032635 17.36882645 17.14316909 14.98928307]
|
||||
[ 35.60450372 35.55333375 35.32826879 32.00750909]
|
||||
[ 61.72854251 62.285231 61.64460882 58.17541776]
|
||||
[ 97.44981525 96.79104909 95.38118155 95.09154931]
|
||||
[135.11495888 134.60444658 135.41807381 135.38139881]]
|
||||
|
||||
'''
|
||||
|
||||
@jt.flag_scope(use_cuda = 1)
|
||||
def test_concat2_perf(self):
|
||||
def check(dim, size, backward=False):
|
||||
n = 64
|
||||
a = jt.random((n,n,n,n))
|
||||
a.sync()
|
||||
m = n // size
|
||||
arr = []
|
||||
for i in range(m):
|
||||
arr.append(a.getitem((slice(None),)*dim + (slice(i*size,i*size+size),)))
|
||||
b = concat2(arr, dim)
|
||||
if backward:
|
||||
loss = b * a
|
||||
b = jt.grad(loss, a)
|
||||
with jt.profile_scope(1, 0) as rep:
|
||||
b.sync()
|
||||
# print(rep)
|
||||
i = rep[0].index("TotalTime")
|
||||
stime = 0
|
||||
for r in rep[1:]:
|
||||
stime += float(r[i])
|
||||
bw = 4*64**4*2*2 / stime
|
||||
# sizeof(float) * numel * (split and concat) * (read and write)
|
||||
print(f"{dim} {size} {stime/1e6}ms, {bw}GB/s")
|
||||
return bw
|
||||
ndim = 4
|
||||
splits = [1, 2, 4, 8, 16, 32, 64]
|
||||
m = len(splits)
|
||||
result = np.zeros((4, m))
|
||||
result_back = np.zeros((4, m))
|
||||
for i in range(ndim):
|
||||
for j in range(m):
|
||||
result[i,j] = check(i, splits[j])
|
||||
result_back[i,j] = check(i, splits[j], True)
|
||||
print(result.T)
|
||||
print(result_back.T)
|
||||
'''
|
||||
[[ 15.59142118 15.8001291 15.77589713 11.79319714]
|
||||
[ 31.33130734 31.2476813 31.20394782 23.19700034]
|
||||
[ 57.90763098 57.71203221 58.02228419 45.60297828]
|
||||
[104.20428796 104.08291412 104.18568373 91.648383 ]
|
||||
[175.21896606 175.44422637 176.57915576 168.33344684]
|
||||
[264.35929995 267.63202466 262.92687504 268.41854563]
|
||||
[352.36998687 355.89200025 360.95753527 361.34916742]]
|
||||
[[ 3.39802237 3.42782551 3.43126375 2.85884566]
|
||||
[ 7.12993628 7.11445323 7.11482319 5.90134142]
|
||||
[ 15.13540229 15.11031669 15.12954432 12.76302703]
|
||||
[ 28.08930928 28.09445985 28.01005224 25.43536254]
|
||||
[ 49.58246623 49.70843778 49.49253912 48.07459389]
|
||||
[ 80.3745414 80.85044884 79.74203591 80.97114412]
|
||||
[117.14450249 119.22320442 119.2380328 119.63622556]]
|
||||
|
||||
'''
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -9,7 +9,7 @@
|
|||
import unittest
|
||||
import jittor as jt
|
||||
import numpy as np
|
||||
|
||||
from jittor.test.test_grad import ngrad
|
||||
|
||||
class TestSlice(unittest.TestCase):
|
||||
def test_slice_bool(self):
|
||||
|
@ -17,7 +17,122 @@ class TestSlice(unittest.TestCase):
|
|||
a[1] = True
|
||||
a[2] = 1
|
||||
assert a.dtype == "bool"
|
||||
print(a)
|
||||
a.sync()
|
||||
|
||||
def test_var_slices(self):
|
||||
def check(slices, msg):
|
||||
with jt.log_capture_scope() as logs:
|
||||
jt.core._print_var_slice(slices)
|
||||
s = logs[0]['msg']
|
||||
assert s == msg, s
|
||||
check((1), "[1,]")
|
||||
check(([[0],[1]],slice(None),[1,2],1), "[int32[2,1,],::,int32[2,],1,]")
|
||||
check((slice(None),slice(None),slice(None),slice(None)), "[::,::,::,::,]")
|
||||
check(([0,1],[0,1],[0,1],[0,1]), "[int32[2,],int32[2,],int32[2,],int32[2,],]")
|
||||
check(([0,1],-2,slice(None),[0,1]), "[int32[2,],-2,::,int32[2,],]")
|
||||
check(([0,1],slice(1,2,2),[1,2],1), "[int32[2,],1:2:2,int32[2,],1,]")
|
||||
check(([0,1],slice(None),[1,2],1), "[int32[2,],::,int32[2,],1,]")
|
||||
check((slice(1,None,2),slice(-1,None,2),[1,2],-4), "[1::2,-1::2,int32[2,],-4,]")
|
||||
check(0, "[0,]")
|
||||
check(10, "[10,]")
|
||||
check(-10, "[-10,]")
|
||||
check(1, "[1,]")
|
||||
check((1,slice(None),2), "[1,::,2,]")
|
||||
check((-2,slice(None),2,slice(1,9,2)), "[-2,::,2,1:9:2,]")
|
||||
check((None,1,None,2,None), "[-,1,-,2,-,]")
|
||||
check((...,1,...,2,...), "[...,1,...,2,...,]")
|
||||
|
||||
@unittest.skipIf(not jt.has_cuda, "No cuda")
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
def test_getitem(self):
|
||||
def check(shape, slices, i_to_vs="", i_to_o="", o_shape=""):
|
||||
# print(slices)
|
||||
x = jt.random(shape)
|
||||
|
||||
with jt.log_capture_scope(log_vprefix="getitem=1000") as logs:
|
||||
a = x.getitem(slices)
|
||||
a.sync()
|
||||
b = x.data[slices]
|
||||
bshape = b.shape if len(b.shape) else (1,)
|
||||
assert a.shape == bshape, (a.shape, bshape)
|
||||
s = logs[-1]['msg']
|
||||
assert "i_to_vs: "+i_to_vs in s
|
||||
assert "i_to_o: "+i_to_o in s
|
||||
assert "o_shape: "+o_shape in s
|
||||
aa = a.numpy()
|
||||
assert (aa==b).all(), (aa, b)
|
||||
|
||||
y = x.numpy()
|
||||
v = jt.random(a.shape)
|
||||
z = x.setitem(slices, v)
|
||||
y[slices] = v.data
|
||||
assert (z.data==y).all(), (z.data, y, v.data, x.data)
|
||||
|
||||
# test_setitem broadcast
|
||||
adim = len(a.shape)
|
||||
for mask in range(1<<adim):
|
||||
new_shape = list(a.shape)
|
||||
for i in range(adim):
|
||||
if (mask>>i)&1:
|
||||
new_shape[i] = 1
|
||||
y = x.numpy()
|
||||
v = jt.random(new_shape)
|
||||
z = x.setitem(slices, v)
|
||||
y[slices] = v.data
|
||||
assert (z.data==y).all(), (z.data, y, v.data, x.data)
|
||||
|
||||
|
||||
# TODO: when slice same row/col many times and assign value, numpy will retain the last value but we assign their sum. eg: check([3,3,3,3], ([[0,1,1]],slice(None),[[1],[2],[0]],1))
|
||||
check([3], (1), "[0,]", "[-1,]", "[]")
|
||||
check([3,3,3,3], ([[0],[1]],slice(None),[1,2],1), "[0,-1,2,3,]", "[-1,2,-1,-1,]", "[2,2,3,]")
|
||||
check([3,3,3,3], (slice(None),slice(None),slice(None),slice(None)), "[-1,-2,-2,-2,]", "[0,0,0,0,]", "[81,]")
|
||||
check([3,3,3,3], ([0,1],[0,1],[0,1],[0,1]), "[0,1,2,3,]", "[-1,-1,-1,-1,]", "[2,]")
|
||||
check([3,3,3,3], ([0,1],-2,slice(None),[0,1]), "[0,1,-1,3,]", "[-1,-1,1,-1,]", "[2,3,]")
|
||||
check([3,3,3,3], ([0,1],slice(1,2,2),[1,2],1), "[0,1,2,3,]", "[-1,1,-1,-1,]", "[2,1,]")
|
||||
check([3,3,3,3], ([0,1],slice(None),[1,2],1), "[0,-1,2,3,]", "[-1,1,-1,-1,]", "[2,3,]")
|
||||
check([3,3,3,3], (slice(1,10,1),...,slice(2,None,-1)), "[0,-1,-2,2,]", "[0,1,1,2,]", "[2,9,3,]")
|
||||
check([10,10,10,10], (slice(1,None,2),slice(-1,None,2),[1,2],-4), "[0,1,2,3,]", "[0,1,-1,-1,]", "")
|
||||
check([20], 0, "[0,]", "[-1,]", "[]")
|
||||
check([20], 10, "[0,]", "[-1,]", "[]")
|
||||
check([20], -10, "[0,]", "[-1,]", "[]")
|
||||
check([10,10,10,10], 1, "[0,-1,-2,-2,]", "[-1,0,0,0,]", "[1000,]")
|
||||
check([10,10,10,10], (1,slice(None),2), "[0,-1,2,-1,]", "[-1,0,-1,1,]", "")
|
||||
check([10,10,10,10], (-2,slice(None),2,slice(1,9,2)), "[0,-1,2,3,]", "[-1,0,-1,1,]")
|
||||
|
||||
def test_getitem_grad(self):
|
||||
shape = (10,)
|
||||
slices = slice(2,4)
|
||||
|
||||
a = jt.random(shape)
|
||||
b = a.getitem(slices)
|
||||
mask = jt.random(b.shape)
|
||||
loss = b*mask
|
||||
da = jt.grad(loss, a)
|
||||
|
||||
_, np_grad = ngrad(lambda vars: (vars[0][slices]*mask.data).sum(), [a.numpy()], 1e-3)
|
||||
assert np.allclose(da.numpy(), np_grad, atol = 1e-3), (da.numpy(), np_grad)
|
||||
|
||||
shape = (10,)
|
||||
slices = slice(2,4)
|
||||
|
||||
a = jt.random(shape)
|
||||
b = a.getitem(slices)
|
||||
b = jt.random(b.shape)
|
||||
c = a.setitem(slices, b)
|
||||
mask = jt.random(c.shape)
|
||||
loss = c*mask
|
||||
da,db = jt.grad(loss, [a,b])
|
||||
|
||||
def numpy_grad(vars):
|
||||
a, b = vars
|
||||
a = a.copy()
|
||||
a[slices] = b
|
||||
return (a*mask.data).sum()
|
||||
|
||||
_, (nda, ndb) = ngrad(numpy_grad, [a.numpy(), b.numpy()], 1e-3)
|
||||
assert np.allclose(da.numpy(), nda, atol = 1e-3)
|
||||
assert np.allclose(db.numpy(), ndb, atol = 1e-3)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -12,6 +12,7 @@ import math
|
|||
import numpy as np
|
||||
import warnings
|
||||
from collections.abc import Sequence, Mapping
|
||||
import jittor as jt
|
||||
|
||||
def crop(img, top, left, height, width):
|
||||
'''
|
||||
|
@ -215,6 +216,90 @@ def to_tensor(img):
|
|||
return np.array(img).transpose((2,0,1)) / np.float32(255)
|
||||
return img
|
||||
|
||||
|
||||
def to_pil_image(pic, mode=None):
|
||||
"""Convert a tensor or an ndarray to PIL Image.
|
||||
Args:
|
||||
pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.
|
||||
mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).
|
||||
.. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes
|
||||
Returns:
|
||||
PIL Image: Image converted to PIL Image.
|
||||
"""
|
||||
if not(isinstance(pic, jt.Var) or isinstance(pic, np.ndarray)):
|
||||
raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(type(pic)))
|
||||
|
||||
elif isinstance(pic, jt.Var):
|
||||
if pic.ndim not in {2, 3}:
|
||||
raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndimension()))
|
||||
|
||||
elif pic.ndim == 2:
|
||||
# if 2D image, add channel dimension (CHW)
|
||||
pic = pic.unsqueeze(0)
|
||||
|
||||
elif isinstance(pic, np.ndarray):
|
||||
if pic.ndim not in {2, 3}:
|
||||
raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndim))
|
||||
|
||||
elif pic.ndim == 2:
|
||||
# if 2D image, add channel dimension (HWC)
|
||||
pic = np.expand_dims(pic, 2)
|
||||
|
||||
npimg = pic
|
||||
if isinstance(pic, jt.Var):
|
||||
if 'float' in str(pic.dtype) and mode != 'F':
|
||||
pic = pic.mul(255).uint8()
|
||||
npimg = np.transpose(pic.numpy(), (1, 2, 0))
|
||||
|
||||
if not isinstance(npimg, np.ndarray):
|
||||
raise TypeError('Input pic must be a jt.Var or NumPy ndarray, ' +
|
||||
'not {}'.format(type(npimg)))
|
||||
|
||||
if npimg.shape[2] == 1:
|
||||
expected_mode = None
|
||||
npimg = npimg[:, :, 0]
|
||||
if npimg.dtype == np.uint8:
|
||||
expected_mode = 'L'
|
||||
elif npimg.dtype == np.int16:
|
||||
expected_mode = 'I;16'
|
||||
elif npimg.dtype == np.int32:
|
||||
expected_mode = 'I'
|
||||
elif npimg.dtype == np.float32:
|
||||
expected_mode = 'F'
|
||||
if mode is not None and mode != expected_mode:
|
||||
raise ValueError("Incorrect mode ({}) supplied for input type {}. Should be {}"
|
||||
.format(mode, npimg.dtype, expected_mode))
|
||||
mode = expected_mode
|
||||
|
||||
elif npimg.shape[2] == 2:
|
||||
permitted_2_channel_modes = ['LA']
|
||||
if mode is not None and mode not in permitted_2_channel_modes:
|
||||
raise ValueError("Only modes {} are supported for 2D inputs".format(permitted_2_channel_modes))
|
||||
|
||||
if mode is None and npimg.dtype == np.uint8:
|
||||
mode = 'LA'
|
||||
|
||||
elif npimg.shape[2] == 4:
|
||||
permitted_4_channel_modes = ['RGBA', 'CMYK', 'RGBX']
|
||||
if mode is not None and mode not in permitted_4_channel_modes:
|
||||
raise ValueError("Only modes {} are supported for 4D inputs".format(permitted_4_channel_modes))
|
||||
|
||||
if mode is None and npimg.dtype == np.uint8:
|
||||
mode = 'RGBA'
|
||||
else:
|
||||
permitted_3_channel_modes = ['RGB', 'YCbCr', 'HSV']
|
||||
if mode is not None and mode not in permitted_3_channel_modes:
|
||||
raise ValueError("Only modes {} are supported for 3D inputs".format(permitted_3_channel_modes))
|
||||
if mode is None and npimg.dtype == np.uint8:
|
||||
mode = 'RGB'
|
||||
|
||||
if mode is None:
|
||||
raise TypeError('Input type {} is not supported'.format(npimg.dtype))
|
||||
|
||||
return Image.fromarray(npimg, mode=mode)
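# (added illustrative sketch, not part of the original patch) Typical uses of
# to_pil_image: a float CHW jt.Var in [0,1) is scaled to uint8, transposed to HWC
# and returned as an 'RGB' image; a uint8 HWC ndarray is wrapped directly.
#   chw = jt.random((3, 32, 32))                 # float32 CHW
#   img1 = to_pil_image(chw)                     # PIL Image, mode 'RGB'
#   hwc = np.zeros((32, 32, 3), dtype=np.uint8)
#   img2 = to_pil_image(hwc)                     # PIL Image, mode 'RGB'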
|
||||
|
||||
|
||||
|
||||
class ImageNormalize:
|
||||
'''
|
||||
Class for normalizing the input image.
|
||||
|
@ -328,4 +413,53 @@ class RandomCrop:
|
|||
top = np.random.randint(0,height-self.size[0]+1)
|
||||
left = np.random.randint(0,width-self.size[1]+1)
|
||||
return crop(img, top, left, self.size[0], self.size[1])
|
||||
|
||||
|
||||
class Lambda:
|
||||
"""Apply a user-defined lambda as a transform.
|
||||
Args:
|
||||
lambd (function): Lambda/function to be used for transform.
|
||||
"""
|
||||
|
||||
def __init__(self, lambd):
|
||||
assert callable(lambd), repr(type(lambd).__name__) + " object is not callable"
|
||||
self.lambd = lambd
|
||||
|
||||
def __call__(self, img):
|
||||
return self.lambd(img)
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + '()'
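# (added illustrative sketch) Lambda just wraps a callable so it composes like any
# other transform:
#   double = Lambda(lambda img: img * 2)
#   out = double(np.ones((3, 4)))   # array filled with 2.0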
|
||||
|
||||
|
||||
class ToTensor:
|
||||
def __call__(self, pic):
|
||||
"""
|
||||
Args:
|
||||
pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
|
||||
Returns:
|
||||
Tensor: Converted image.
|
||||
"""
|
||||
return to_tensor(pic)
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + '()'
|
||||
|
||||
class ToPILImage(object):
|
||||
def __init__(self, mode=None):
|
||||
self.mode = mode
|
||||
|
||||
def __call__(self, pic):
|
||||
"""
|
||||
Args:
|
||||
pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.
|
||||
Returns:
|
||||
PIL Image: Image converted to PIL Image.
|
||||
"""
|
||||
return to_pil_image(pic, self.mode)
|
||||
|
||||
def __repr__(self):
|
||||
format_string = self.__class__.__name__ + '('
|
||||
if self.mode is not None:
|
||||
format_string += 'mode={0}'.format(self.mode)
|
||||
format_string += ')'
|
||||
return format_string
|
||||
|
|
|
@ -222,6 +222,16 @@ class Hook:
|
|||
pickle.dump(ps, f)
|
||||
LOG.i(f"save params ok")
|
||||
|
||||
def hook_function(self, func):
|
||||
name = func.__name__
|
||||
def new_func(*args, **kw):
|
||||
ret = func(*args, **kw)
|
||||
self.record(name+".args", args)
|
||||
self.record(name+".kw", kw)
|
||||
self.record(name+".ret", ret)
|
||||
return ret
|
||||
return new_func
|
||||
|
||||
|
||||
def hook_module(self, mod, mod_name=""):
|
||||
if os.environ.get("use_auto_diff", '1') == '0':
|
||||
|
|
|
@ -74,8 +74,12 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
|
|||
vector<Node*> bfs_q;
|
||||
bfs_q.reserve(vars.size());
|
||||
int start_var_num = 0;
|
||||
{
|
||||
while (1) {
|
||||
op_num = 0;
|
||||
start_var_num = 0;
|
||||
bfs_q.clear();
|
||||
// get all nodes need to be executed
|
||||
int need_opt = 0;
|
||||
auto t = ++Node::tflag_count;
|
||||
for (Var* v : vars)
|
||||
if (!v->is_finished() && v->tflag != t) {
|
||||
|
@ -89,6 +93,7 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
|
|||
for (auto i : node->_inputs)
|
||||
if (i.node->tflag != t && !i.node->is_finished()) {
|
||||
i.node->tflag = t;
|
||||
need_opt += i.node->flags.get(NodeFlags::_has_gopt);
|
||||
bfs_q.push_back(i.node);
|
||||
}
|
||||
// this var has been fetched
|
||||
|
@ -99,11 +104,19 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
|
|||
!n.node->is_finished() &&
|
||||
n.node->flags.get(NodeFlags::_fetch)) {
|
||||
n.node->tflag = t;
|
||||
need_opt += n.node->flags.get(NodeFlags::_has_gopt);
|
||||
bfs_q.push_back(n.node);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!need_opt) break;
|
||||
for (Node* n : bfs_q) {
|
||||
if (n->flags.get(NodeFlags::_has_gopt)) {
|
||||
n->op()->graph_optimize();
|
||||
n->flags.set(NodeFlags::_has_gopt, 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
auto tt = Node::tflag_count;
|
||||
vector<Op*> ops;
|
||||
|
|
|
@ -55,9 +55,6 @@ vector<VarPtr> grad(Var* loss, vector<Var*> targets) {
|
|||
return false;
|
||||
if (node->is_stop_grad())
|
||||
return false;
|
||||
// int value has zero grad
|
||||
if (node->is_var())
|
||||
return node->var()->is_float();
|
||||
return true;
|
||||
});
|
||||
LOGvv << "Size of grad nodes:" << gnodes.size();
|
||||
|
|
|
@ -44,6 +44,11 @@ struct JitKey {
|
|||
explicit hex1(uint data) : data(data) {}
|
||||
};
|
||||
|
||||
struct shex1 {
|
||||
int data;
|
||||
explicit shex1(int data) : data(data) {}
|
||||
};
|
||||
|
||||
struct hex2 {
|
||||
uint data;
|
||||
explicit hex2(uint data) : data(data) {}
|
||||
|
@ -90,6 +95,13 @@ inline JK& operator<<(JK& jk, const JK::hex1& h) {
|
|||
return jk << (char)((data<10) ? data+'0' : data-10+'a');
|
||||
}
|
||||
|
||||
inline JK& operator<<(JK& jk, const JK::shex1& h) {
|
||||
if (h.data<0)
|
||||
return jk << '-' << JK::hex1(-h.data);
|
||||
else
|
||||
return jk << JK::hex1(h.data);
|
||||
}
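// Example (added note): shex1 emits a signed single hex digit into the jit key,
// e.g. jk << JK::shex1(-3) appends "-3" and jk << JK::shex1(10) appends "a".
// GetitemOp/SetitemOp use it for the possibly negative IV/IO indices.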
|
||||
|
||||
inline JK& operator<<(JK& jk, const JK::hex2& h) {
|
||||
return jk << JK::hex1(h.data>>4) << JK::hex1(h.data);
|
||||
}
|
||||
|
|
|
@ -212,14 +212,13 @@ void* SFRLAllocator::alloc(size_t size, size_t& allocation) {
|
|||
if (block == nullptr) {
|
||||
free_all_sfrl_allocators();
|
||||
size_t alloc_size = allocation_size(size);
|
||||
void* ptr = underlying->alloc(alloc_size, allocation);
|
||||
if (ptr == nullptr) {
|
||||
void* ptr = nullptr;
|
||||
try {
|
||||
ptr = underlying->alloc(alloc_size, allocation);
|
||||
} catch (...) {
|
||||
unused_memory -= large_blocks.free_all_cached_blocks(underlying);
|
||||
unused_memory -= small_blocks.free_all_cached_blocks(underlying);
|
||||
void* ptr = underlying->alloc(alloc_size, allocation);
|
||||
if (ptr == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
ptr = underlying->alloc(alloc_size, allocation);
|
||||
}
|
||||
block = new CachingBlock(alloc_size, blocks, ptr);
|
||||
} else {
|
||||
|
|
|
@ -22,8 +22,33 @@ static inline int lzcnt(int64 v) {
|
|||
|
||||
struct Slice {
|
||||
int64 start, stop, step, mask;
|
||||
inline void fill(int64 size) {
|
||||
if (step>0) {
|
||||
if (mask&2)
|
||||
stop = size;
|
||||
else if (stop<0)
|
||||
stop += size;
|
||||
else
|
||||
stop = std::min(size, stop);
|
||||
} else {
|
||||
if (mask&1) start = size-1;
|
||||
if (mask&2)
|
||||
stop = -1;
|
||||
else if (stop<0)
|
||||
stop = std::max((int64)0, stop+size);
|
||||
}
|
||||
if (start<0) start += size;
|
||||
mask = 0;
|
||||
ASSERT(start>=0 && stop>=-1 && start<size && stop<=size)
|
||||
<< "slice overflow:" << start << stop << step;
|
||||
}
|
||||
};
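// Worked examples for Slice::fill (added note, hand-traced from the code above):
//   size=5, slice 2:-1:1 -> start=2, stop=4 (negative stop wraps around);
//   size=5, slice ::-1   -> start=4, stop=-1 (mask 1|2: start and stop omitted),
//   so a descending loop "for (i=start; i>stop; i+=step)" visits 4,3,2,1,0.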
|
||||
|
||||
// return a / ceil_to_2_pow(b)
|
||||
inline uint64 fast_div(uint64 a, uint64 b) {
|
||||
return a >> (64 - lzcnt(b));
|
||||
}
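// Example (added note): fast_div(100, 5) rounds the divisor up to the next power
// of two (8) and shifts: 100 >> 3 == 12, i.e. a / ceil_to_2_pow(b) as stated above.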
|
||||
|
||||
// @pyjt(NanoVector)
|
||||
struct NanoVector {
|
||||
int64 data=0, offset=0;
|
||||
|
@ -108,22 +133,13 @@ struct NanoVector {
|
|||
|
||||
// @pyjt(__map_getitem__)
|
||||
inline NanoVector slice(Slice slice) {
|
||||
if (slice.step>0) {
|
||||
if (slice.mask&2) slice.stop = size();
|
||||
} else {
|
||||
if (slice.mask&1) slice.start = size()-1;
|
||||
if (slice.mask&2) slice.stop = 0;
|
||||
}
|
||||
if (slice.start<0) slice.start += size();
|
||||
if (slice.stop<0) slice.stop += size();
|
||||
ASSERT(slice.start>=0 && slice.stop>=0 && slice.start<size() && slice.stop<=size())
|
||||
<< "slice overflow:" << slice.start << slice.stop << slice.step;
|
||||
slice.fill(size());
|
||||
NanoVector v;
|
||||
if (slice.step>0) {
|
||||
for (int i=slice.start; i<slice.stop; i+=slice.step)
|
||||
v.push_back(this->operator[](i));
|
||||
} else {
|
||||
for (int i=slice.start; i>=slice.stop; i+=slice.step)
|
||||
for (int i=slice.start; i>slice.stop; i+=slice.step)
|
||||
v.push_back(this->operator[](i));
|
||||
}
|
||||
return v;
|
||||
|
@ -166,14 +182,14 @@ struct NanoVector {
|
|||
struct Iter {
|
||||
const NanoVector* self;
|
||||
int index;
|
||||
int64 operator*() { return self->at(index); }
|
||||
Iter& operator++() { index++; return *this; }
|
||||
Iter operator+(int i) { return {self, i+index}; }
|
||||
bool operator!=(Iter& other) { return index != other.index; }
|
||||
inline int64 operator*() { return self->at(index); }
|
||||
inline Iter& operator++() { index++; return *this; }
|
||||
inline Iter operator+(int i) { return {self, i+index}; }
|
||||
inline bool operator!=(Iter& other) { return index != other.index; }
|
||||
};
|
||||
|
||||
Iter begin() { return {this, 0}; }
|
||||
Iter end() { return {this, size()}; }
|
||||
inline Iter begin() { return {this, 0}; }
|
||||
inline Iter end() { return {this, size()}; }
|
||||
|
||||
inline void pop_back() { offset--; data &= (1ll<<total_bits())-1; }
|
||||
inline void push_back(Iter s, Iter t) {
|
||||
|
|
|
@ -0,0 +1,55 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
// Authors: Dun Liang <randonlang@gmail.com>.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include "common.h"
|
||||
#include "misc/nano_vector.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
template<class T=int64, int N=10>
|
||||
struct StackVector {
|
||||
int n;
|
||||
T a[N+1];
|
||||
inline T& front() { return a[0]; }
|
||||
inline T& back() { return a[n-1]; }
|
||||
inline int size() { return n;}
|
||||
inline StackVector(int n=0) : n(n) {}
|
||||
|
||||
struct Iter {
|
||||
const StackVector<T,N>* self;
|
||||
int index;
|
||||
inline T operator*() { return self->at(index); }
|
||||
inline Iter& operator++() { index++; return *this; }
|
||||
inline Iter operator+(int i) { return {self, i+index}; }
|
||||
inline bool operator!=(Iter& other) { return index != other.index; }
|
||||
};
|
||||
|
||||
inline Iter begin() { return {this, 0}; }
|
||||
inline Iter end() { return {this, size()}; }
|
||||
inline T& operator[](int i) { return a[i]; }
|
||||
|
||||
inline void pop_back() { n--; }
|
||||
inline void push_back(T v) { if (n<N) a[n++] = v; }
|
||||
inline void check() { ASSERT(n<N); }
|
||||
inline NanoVector to_nano_vector() {
|
||||
check();
|
||||
NanoVector nv;
|
||||
for (int i=0; i<n; i++)
|
||||
nv.push_back_check_overflow(a[i]);
|
||||
return nv;
|
||||
}
|
||||
};
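// Usage sketch (added note): StackVector is a fixed-capacity, stack-allocated
// vector (default int64 elements, capacity N=10) used by getitem/setitem shape
// inference, e.g.
//   StackVector<> v;
//   v.push_back(3); v.push_back(4);
//   NanoVector nv = v.to_nano_vector();  // checked copy into a NanoVector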
|
||||
|
||||
template<class T, int N>
|
||||
inline std::ostream& operator<<(std::ostream& os, const StackVector<T,N>& v) {
|
||||
os << '[';
|
||||
for (int i=0; i<v.n; i++)
|
||||
os << v.a[i] << ',';
|
||||
return os << ']';
|
||||
}
|
||||
|
||||
} // jittor
|
|
@ -48,6 +48,8 @@ struct NodeFlags {
|
|||
_op_type=_n+4, _op_type_nbits=2,
|
||||
// bit6: backprop grad at ones
|
||||
_grads=_n+6,
|
||||
// bit7: has graph optimize
|
||||
_has_gopt=_n+7,
|
||||
};
|
||||
|
||||
inline void set(Flags f, int a=1, int nbits=1) {
|
||||
|
|
|
@ -75,6 +75,12 @@ bool Op::shape_infered() {
|
|||
return true;
|
||||
}
|
||||
|
||||
void Op::compile_optimize(string& src) {}
|
||||
|
||||
void Op::infer_shape() {}
|
||||
void Op::run() {}
|
||||
void Op::jit_prepare() {}
|
||||
void Op::graph_optimize() {}
|
||||
|
||||
string Op::name_ex() const {
|
||||
string a=name();
|
||||
|
|
src/op.h
|
@ -38,9 +38,9 @@ struct Op : Node {
|
|||
|
||||
virtual VarPtr grad(Var* out, Var* dout, Var* v, int v_index);
|
||||
virtual void grads(Var** douts, VarPtr* dins);
|
||||
virtual void infer_shape() {}
|
||||
virtual void run() {};
|
||||
virtual void jit_prepare() {};
|
||||
virtual void infer_shape();
|
||||
virtual void run();
|
||||
virtual void jit_prepare();
|
||||
virtual void do_jit_prepare();
|
||||
virtual const char* name() const = 0;
|
||||
virtual void statistics(uint64_t& in, uint64_t& out, uint64_t& compute);
|
||||
|
@ -48,6 +48,8 @@ struct Op : Node {
|
|||
virtual void do_run_after_prepare();
|
||||
virtual void do_run();
|
||||
virtual VarPtr duplicate();
|
||||
virtual void compile_optimize(string& src);
|
||||
virtual void graph_optimize();
|
||||
void jit_run();
|
||||
|
||||
string name_ex() const;
|
||||
|
|
|
@ -416,6 +416,7 @@ string precompile(unordered_map<string,string> defs, string src, unordered_map<s
|
|||
vector<string> args;
|
||||
size_t l = k+1;
|
||||
if (expr == "for" || expr == "if" || expr == "expand_macro" ||
|
||||
expr == "is_def" ||
|
||||
(k<src.size() && src[k]=='(')) {
|
||||
ASSERT(src[k] == '(');
|
||||
comma.push_back(k);
|
||||
|
@ -447,6 +448,9 @@ string precompile(unordered_map<string,string> defs, string src, unordered_map<s
|
|||
if (args.size() >= 5) {
|
||||
step = OpCompiler::eval(vs, defs);
|
||||
vs = args[4];
|
||||
for (int i=5; i<args.size(); i++) {
|
||||
vs += "," + args[i];
|
||||
}
|
||||
}
|
||||
auto new_defs = defs;
|
||||
LOGvvv << "Expand for" << expr >> "[" >> vil >> "," >> vir >> "," >> step >> "]";
|
||||
|
@ -473,6 +477,18 @@ string precompile(unordered_map<string,string> defs, string src, unordered_map<s
|
|||
i = l-1;
|
||||
continue;
|
||||
} else
|
||||
if (expr == "is_def") {
|
||||
ASSERT(args.size()==1)
|
||||
<< "Jit error: is_def wrong arguments.";
|
||||
string vdef = args[0];
|
||||
vdef = precompile(defs, vdef, macros);
|
||||
if (defs.count(vdef) || macros.count(vdef))
|
||||
new_src += "1";
|
||||
else
|
||||
new_src += "0";
|
||||
i = l-1;
|
||||
continue;
|
||||
} else
|
||||
if (expr == "expand_macro") {
|
||||
// syntax: @expand_macro(macro, args)
|
||||
// ij k l
|
||||
|
@ -983,8 +999,9 @@ jit_op_entry_t OpCompiler::do_compile(Op* op) {
|
|||
src_after_passes = tm.tune();
|
||||
src = &src_after_passes;
|
||||
}
|
||||
op->compile_optimize(*src);
|
||||
auto ret = oc.compile(op->get_jit_key(), *src);
|
||||
return ret;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
|
|
@ -149,6 +149,8 @@ void ArgReduceOp::infer_shape() {
|
|||
shape.push_back(x->shape[i]);
|
||||
}
|
||||
}
|
||||
if (shape.size() == 0)
|
||||
shape.push_back(1);
|
||||
y->set_shape(shape);
|
||||
y_key->set_shape(shape);
|
||||
}
|
||||
|
@ -205,4 +207,4 @@ void ArgReduceOp::jit_run() {
|
|||
}
|
||||
#endif // JIT
|
||||
|
||||
} // jittor
|
||||
|
|
|
@ -162,7 +162,7 @@ void BinaryOp::infer_shape() {
|
|||
// -1 1 need b
|
||||
// has 1, b, both 1, not b, 0, error
|
||||
if ((xshape == 1 || yshape == 1) && (xshape != yshape)) {
|
||||
CHECK(xshape && yshape) << "Shape can not broadcast to 0.";
|
||||
// CHECK(xshape && yshape) << "Shape can not broadcast to 0.";
|
||||
need_broadcast = true;
|
||||
continue;
|
||||
}
|
||||
|
@ -198,4 +198,4 @@ void BinaryOp::jit_run() {
|
|||
}
|
||||
#endif // JIT
|
||||
|
||||
} // jittor
|
||||
|
|
|
@ -5,12 +5,16 @@
|
|||
// ***************************************************************
|
||||
#include "var.h"
|
||||
#include "ops/candidate_op.h"
|
||||
#ifdef JIT_cuda
|
||||
#include "executor.h"
|
||||
#endif
|
||||
|
||||
namespace jittor {
|
||||
|
||||
#ifndef JIT
|
||||
CandidateOp::CandidateOp(Var* x, string&& fail_cond, NanoString dtype) : x(x), fail_cond(move(fail_cond)) {
|
||||
flags.set(NodeFlags::_cpu);
|
||||
flags.set(NodeFlags::_cuda);
|
||||
flags.set(NodeFlags::_vary_shape);
|
||||
y = create_output(nullptr, dtype);
|
||||
}
|
||||
|
@ -27,10 +31,76 @@ void CandidateOp::jit_prepare() {
|
|||
}
|
||||
|
||||
#else // JIT
|
||||
|
||||
#ifdef JIT_cuda
|
||||
|
||||
__global__ static void candidate_kernel(
|
||||
@for(i, 0, XDIM, 1, index_t xshape@i, )
|
||||
Tx* __restrict__ xp,
|
||||
Ty* __restrict__ yp,
|
||||
bool* __restrict__ maskp,
|
||||
int* __restrict__ np
|
||||
) {
|
||||
int n=0;
|
||||
int tid = threadIdx.x;
|
||||
int tnum = blockDim.x;
|
||||
|
||||
// define cond stride
|
||||
index_t xstride@{XDIM-1} = 1;
|
||||
@for(i, XDIM-2, -1, -1, auto xstride@i = xstride@{i+1} * xshape@{i+1};)
|
||||
|
||||
// generate d-for loop
|
||||
for (index_t i=0; i < xshape0; i++) {
|
||||
__syncthreads();
|
||||
if (!maskp[i]) continue;
|
||||
if (tid == 0) {
|
||||
yp[n] = i;
|
||||
n++;
|
||||
}
|
||||
for (index_t j=i+1+tid; j < xshape0; j+=tnum) {
|
||||
if (@FUNC) maskp[j] = 0;
|
||||
}
|
||||
}
|
||||
if (tid == 0) {
|
||||
np[0] = n;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void CandidateOp::jit_run() {
|
||||
auto* __restrict__ xp = x->ptr<Tx>();
|
||||
// define cond shape
|
||||
@for(i, 0, XDIM, index_t xshape@i = x->shape[@i];)
|
||||
|
||||
// define ys
|
||||
auto* __restrict__ yp = y->ptr<Ty>();
|
||||
size_t n_allocation;
|
||||
int* np = (int*)exe.allocator->alloc(4, n_allocation);
|
||||
size_t mask_allocation;
|
||||
bool* maskp = (bool*)exe.allocator->alloc(xshape0, mask_allocation);
|
||||
checkCudaErrors(cudaMemsetAsync(maskp, 1, xshape0));
|
||||
|
||||
candidate_kernel<<<1, std::max(1, std::min(1024, xshape0)) >>>(
|
||||
@for(i, 0, XDIM, 1, xshape@i, )
|
||||
xp,
|
||||
yp,
|
||||
maskp,
|
||||
np
|
||||
);
|
||||
|
||||
int n=0;
|
||||
// checkCudaErrors(cudaDeviceSynchronize());
|
||||
checkCudaErrors(cudaMemcpy(&n, np, 4, cudaMemcpyDefault));
|
||||
y->set_shape({n});
|
||||
exe.allocator->free(np, 4, n_allocation);
|
||||
exe.allocator->free(maskp, xshape0, mask_allocation);
|
||||
}
|
||||
#else
|
||||
void CandidateOp::jit_run() {
|
||||
using namespace std;
|
||||
auto* __restrict__ xp = x->ptr<Tx>();
|
||||
// define cond shape
|
||||
@for(i, 0, XDIM, index_t xshape@i = x->shape[@i];)
|
||||
// define cond stride
|
||||
index_t xstride@{XDIM-1} = 1;
|
||||
@for(i, XDIM-2, -1, -1, auto xstride@i = xstride@{i+1} * xshape@{i+1};)
|
||||
|
@ -56,6 +126,7 @@ void CandidateOp::jit_run() {
|
|||
}
|
||||
y->set_shape({n});
|
||||
}
|
||||
#endif
|
||||
#endif // JIT
|
||||
|
||||
} // jittor
|
||||
} // jittor
|
||||
|
|
|
@ -0,0 +1,499 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
// Authors: Dun Liang <randonlang@gmail.com>.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include <cmath>
|
||||
#include "var.h"
|
||||
#include "ops/getitem_op.h"
|
||||
#include "ops/op_register.h"
|
||||
#ifdef JIT_cuda
|
||||
#include <cuda_runtime.h>
|
||||
#include <helper_cuda.h>
|
||||
#endif
|
||||
#ifndef JIT
|
||||
#include "misc/stack_vector.h"
|
||||
#include "opt/kernel_ir.h"
|
||||
#ifdef HAS_CUDA
|
||||
#include "misc/cuda_flags.h"
|
||||
#endif
|
||||
#endif
|
||||
|
||||
namespace jittor {
|
||||
|
||||
#ifndef JIT
|
||||
|
||||
|
||||
static auto make_number = get_op_info("number")
|
||||
.get_constructor<VarPtr, float, Var*>();
|
||||
static auto make_setitem = get_op_info("setitem")
|
||||
.get_constructor<VarPtr, Var*, VarSlices&&, Var*, NanoString>();
|
||||
|
||||
GetitemOp::GetitemOp(Var* x, VarSlices&& slices)
|
||||
: vs(move(slices)) {
|
||||
flags.set(NodeFlags::_cpu);
|
||||
flags.set(NodeFlags::_cuda);
|
||||
flags.set(NodeFlags::_has_gopt);
|
||||
create_output(nullptr, x->dtype());
|
||||
}
|
||||
|
||||
void GetitemOp::infer_slices(
|
||||
StackVector<>& __restrict__ i_to_vs,
|
||||
StackVector<>& __restrict__ i_to_o,
|
||||
StackVector<>& __restrict__ out_shape
|
||||
) {
|
||||
auto in = inputs().front();
|
||||
auto in_shape = in->shape;
|
||||
auto nin = in_shape.size();
|
||||
i_to_vs.n = i_to_o.n = nin;
|
||||
out_shape.n = 0;
|
||||
|
||||
int vid = 0;
|
||||
first_oid_of_var = -1;
|
||||
var_dim = 0;
|
||||
for (int i=0; i<nin; i++) {
|
||||
auto& s = vs.slices[vid];
|
||||
if (vid >= vs.n) {
|
||||
// i i i
|
||||
// | | |
|
||||
// v v v --> overflow
|
||||
// s s
|
||||
i_to_vs[i] = -1;
|
||||
i_to_o[i] = out_shape.size();
|
||||
out_shape.push_back(in_shape[i]);
|
||||
} else
|
||||
if (s.is_var()) {
|
||||
// i --> s ---> o
|
||||
// + ---> o
|
||||
// var maybe multiple dims
|
||||
if (first_oid_of_var == -1) {
|
||||
for (int i=0; i<vs.n; i++)
|
||||
if (vs.slices[i].is_var())
|
||||
var_dim = std::max(var_dim, vs.slices[i].var->shape.size());
|
||||
first_oid_of_var = out_shape.size();
|
||||
for (int j=0; j<var_dim; j++) {
|
||||
out_shape.push_back(1);
|
||||
}
|
||||
}
|
||||
i_to_vs[i] = vid++;
|
||||
i_to_o[i] = -1;
|
||||
auto iv = s.var;
|
||||
auto iv_shape = iv->shape;
|
||||
auto niv = iv_shape.size();
|
||||
for (int j=0; j<niv; j++) {
|
||||
auto iv_shape_j = iv_shape[niv-j-1];
|
||||
auto& out_shape_j = out_shape[first_oid_of_var+var_dim-j-1];
|
||||
if (out_shape_j == 1)
|
||||
out_shape_j = iv_shape_j;
|
||||
else
|
||||
ASSERT(out_shape_j == iv_shape_j || out_shape_j < 0 || iv_shape_j < 0)
|
||||
<< out_shape_j << iv_shape_j << out_shape;
|
||||
}
|
||||
} else
|
||||
if (s.is_ellipsis()) {
|
||||
auto remain_slice = vs.n-vid-1;
|
||||
auto remain_idims = nin-i;
|
||||
auto ellipsis_size = remain_idims - remain_slice;
|
||||
ASSERT(ellipsis_size>=0) << "NDims not match";
|
||||
for (int j=0; j<ellipsis_size; j++) {
|
||||
i_to_vs[i+j] = -1;
|
||||
i_to_o[i+j] = out_shape.size();
|
||||
out_shape.push_back(in_shape[i+j]);
|
||||
}
|
||||
vid ++;
|
||||
i += ellipsis_size-1;
|
||||
} else
|
||||
if (s.is_none()) {
|
||||
i--;
|
||||
out_shape.push_back(1);
|
||||
vid++;
|
||||
continue;
|
||||
} else
|
||||
if (s.is_int()) {
|
||||
i_to_vs[i] = vid++;
|
||||
i_to_o[i] = -1;
|
||||
auto in_shape_i = in_shape[i];
|
||||
auto& v = s.slice.start;
|
||||
if (v<0) v += in_shape_i;
|
||||
CHECK(v>=0 && v<in_shape_i) << "slice overflow, " << v << "not in [0,">>in_shape_i>>")";
|
||||
} else {
|
||||
// slice
|
||||
auto& slice = s.slice;
|
||||
auto in_shape_i = in_shape[i];
|
||||
auto out_shape_j = in_shape_i;
|
||||
if (slice.mask == 7) {
|
||||
// slice is a[::]
|
||||
// start, stop, step is not filled
|
||||
vid++;
|
||||
i_to_vs[i] = -1;
|
||||
i_to_o[i] = out_shape.size();
|
||||
out_shape.push_back(out_shape_j);
|
||||
} else {
|
||||
i_to_vs[i] = vid++;
|
||||
i_to_o[i] = out_shape.size();
|
||||
if (in_shape_i > 0) {
|
||||
slice.fill(in_shape_i);
|
||||
if (abs(slice.step) <= 1)
|
||||
out_shape_j = (slice.stop - slice.start) * slice.step;
|
||||
else if (slice.step>0)
|
||||
out_shape_j = (slice.stop - slice.start - 1) / slice.step + 1;
|
||||
else
|
||||
out_shape_j = (slice.start - slice.stop - 1) / -slice.step + 1;
|
||||
out_shape_j = std::max(0l, out_shape_j);
|
||||
}
|
||||
out_shape.push_back(out_shape_j);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void cuda_loop_schedule(NanoVector o_shape, int* masks, int* tdims) {
|
||||
// bz by bx tz ty tx
|
||||
// 5 4 3 2 1 0
|
||||
// LOi: bitmask of used dims of loop i
|
||||
// LOi bit 6: need for
|
||||
// if need for, keep for range: for (int i@i=tid; tid<range; tid+=tnum)
|
||||
// if not need for, replace range -> tnum, for -> int i@i = tid
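// Worked example (added note, hand-traced, treat as a sketch): for o_shape = {1024}
// the first phase maps the whole loop onto threadIdx.x: masks[0] = 1 (bit 0 = tx,
// bit 6 clear, so the for-loop collapses to a single index) and
// tdims = {1024,1,1,1,1,1}. Larger or multi-dimensional shapes also spill into
// blockIdx.x/y/z, setting bit 6 so the loop is kept and strided by tnum.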
|
||||
int rtnum = 1024;
|
||||
// int max_tnum = {1024, 1024, 64, (1u<<31)-1, 65535, 65535};
|
||||
int loop_id = (int)o_shape.size()-1;
|
||||
int tid = 0;
|
||||
int64 block_size = 1;
|
||||
int thread_size = 1;
|
||||
for (int i=0; i<6; i++) tdims[i] = 1;
|
||||
for (; tid<3 && loop_id>=0 && rtnum>1; tid++) {
|
||||
int64 si = o_shape[loop_id];
|
||||
int mask = 1<<tid;
|
||||
if (tid==2) rtnum = std::min(64, rtnum);
|
||||
if (si>rtnum*4) {
|
||||
// need for, use tid(1<<i) and bx(8)
|
||||
mask |= 8|(1<<6);
|
||||
block_size = (si-1)/rtnum+1;
|
||||
tdims[tid] = rtnum;
|
||||
tdims[3] = block_size;
|
||||
tid = 3;
|
||||
thread_size *= rtnum;
|
||||
rtnum = 0;
|
||||
} else
|
||||
if (si>rtnum) {
|
||||
mask |= (1<<6);
|
||||
thread_size *= rtnum;
|
||||
tdims[tid] = rtnum;
|
||||
rtnum = 0;
|
||||
} else {
|
||||
rtnum = rtnum / std::max(si, (int64)1);
|
||||
thread_size *= si;
|
||||
tdims[tid] = si;
|
||||
}
|
||||
masks[loop_id] = mask;
|
||||
loop_id --;
|
||||
}
|
||||
int64 total_size = (int64)block_size*thread_size;
|
||||
if (tid<3) tid=3;
|
||||
for (; tid<6 && loop_id>=0 && total_size<(256*1024); tid++) {
|
||||
int64 si = o_shape[loop_id];
|
||||
int mask = 1<<tid;
|
||||
int64 max_thread = tid>=4 ? 65535 : (1u<<31)-1;
|
||||
if (si > max_thread) {
|
||||
si = max_thread;
|
||||
mask |= 1<<6;
|
||||
}
|
||||
total_size *= si;
|
||||
tdims[tid] = si;
|
||||
masks[loop_id] = mask;
|
||||
loop_id --;
|
||||
}
|
||||
while (loop_id>=0) {
|
||||
masks[loop_id--] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
void GetitemOp::compile_optimize(string& src) {
|
||||
_compile_optimize(src);
|
||||
}
|
||||
|
||||
void GetitemOp::_compile_optimize(string& src) {
|
||||
if (!flags.get(NodeFlags::_cuda))
|
||||
return;
|
||||
|
||||
auto jd = get_jit_define();
|
||||
map<string,string> jd_map(jd.begin(), jd.end());
|
||||
|
||||
KernelIR main(src);
|
||||
auto& func = main.children.back()->children.back();
|
||||
// auto& loop = func->children.back();
|
||||
|
||||
func->push_back("void func() {}", &func->before);
|
||||
|
||||
auto& new_func = func->before.back();
|
||||
// auto new_func = func->before.back()->move_out();
|
||||
|
||||
new_func->attrs["dtype"] = "static __global__ void";
|
||||
// LOGir << main.to_string();
|
||||
src = main.to_string();
|
||||
string arg_call = "";
|
||||
const char* tname[] = {"threadIdx.x", "threadIdx.y", "threadIdx.z", "blockIdx.x", "blockIdx.y", "blockIdx.z"};
|
||||
const char* tname2[] = {"blockDim.x", "blockDim.y", "blockDim.z", "gridDim.x", "gridDim.y", "gridDim.z"};
|
||||
for (auto& ir : func->children) {
|
||||
if (ir->type == "define") {
|
||||
string& rvalue = ir->attrs.at("rvalue");
|
||||
string& lvalue = ir->attrs.at("lvalue");
|
||||
string& dtype = ir->attrs.at("dtype");
|
||||
if (startswith(rvalue, "input")
|
||||
|| startswith(rvalue, "output")
|
||||
|| startswith(rvalue, "vs.")
|
||||
|| rvalue.back() == ')'
|
||||
|| rvalue.back() == ']')
|
||||
{
|
||||
if (dtype == "auto")
|
||||
LOGvvvv << "keep" << rvalue;
|
||||
else {
|
||||
LOGvvvv << "args" << rvalue;
|
||||
if (arg_call.size()) arg_call += ", ";
|
||||
arg_call += lvalue;
|
||||
LOGvvvv << dtype+" "+lvalue;
|
||||
new_func->push_back(dtype+" "+lvalue+";", &new_func->inner);
|
||||
}
|
||||
} else {
|
||||
LOGvvvv << "move" <<rvalue;
|
||||
new_func->push_back(ir->clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
new_func->push_back(func->children.back()->move_out());
|
||||
auto& loop = new_func->children.back();
|
||||
int no = o_shape.size();
|
||||
KernelIR* loops[no];
|
||||
if (!no) {
|
||||
func->push_back("func<<<1,1>>>("+arg_call+");");
|
||||
} else {
|
||||
loops[0] = loop.get();
|
||||
for (int i=1; i<no; i++)
|
||||
loops[i] = loops[i-1]->children.back().get();
|
||||
for (int i=0; i<no; i++) {
|
||||
auto l = loops[i];
|
||||
ASSERT(l->inner.size() == 3);
|
||||
auto lo = l->find_define("LO"+S(i));
|
||||
ASSERT(lo);
|
||||
auto loi = std::stoi(lo->attrs.at("rvalue"));
|
||||
string tid = "";
|
||||
string tnum = "";
|
||||
for (int j=0; j<6; j++) {
|
||||
if ((loi>>j)&1) {
|
||||
if (tid.size()) {
|
||||
tid += string("+")+tnum+"*"+tname[j];
|
||||
tnum += string("*")+tname2[j];
|
||||
} else {
|
||||
tid = tname[j];
|
||||
tnum = tname2[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!tid.size()) {
|
||||
continue;
|
||||
}
|
||||
if (loi&(1<<6)) {
|
||||
l->inner.at(0)->attrs.at("rvalue") = tid;
|
||||
l->inner.at(2)->attrs.at("code") = "i"+S(i)+"+="+tnum+";";
|
||||
} else {
|
||||
// no need for
|
||||
while (l->inner.size())
|
||||
l->inner.at(0)->erase();
|
||||
l->push_front("index_t i"+S(i)+" = "+tid+";");
|
||||
}
|
||||
}
|
||||
func->push_back("int no = o_shape.size();");
|
||||
func->push_back("int masks[no];");
|
||||
func->push_back("int tdims[6];");
|
||||
func->push_back("cuda_loop_schedule(o_shape, masks, tdims);");
|
||||
func->push_back("dim3 grid_dim(tdims[3],tdims[4],tdims[5]);");
|
||||
func->push_back("dim3 block_dim(tdims[0],tdims[1],tdims[2]);");
|
||||
func->push_back("func<<<grid_dim, block_dim>>>("+arg_call+");");
|
||||
}
|
||||
src = main.to_string();
|
||||
}
|
||||
|
||||
void GetitemOp::infer_shape() {
|
||||
auto in = inputs().front();
|
||||
auto out = outputs().front();
|
||||
auto in_shape = in->shape;
|
||||
auto nin = in_shape.size();
|
||||
|
||||
StackVector<> i_to_vs(nin);
|
||||
StackVector<> i_to_o(nin);
|
||||
// shape return to use
|
||||
StackVector<> out_shape;
|
||||
infer_slices(i_to_vs, i_to_o, out_shape);
|
||||
|
||||
// optimized shape (each dim is a loop var)
|
||||
StackVector<> o_shape;
|
||||
int fov = -1;
|
||||
for (int i=0; i<nin; i++) {
|
||||
auto& vid = i_to_vs[i];
|
||||
auto& oid = i_to_o[i];
|
||||
auto os = out_shape[oid];
|
||||
if (oid>=0) {
|
||||
if (vid==-1 && i && i_to_vs[i-1]<0) {
|
||||
vid = -2;
|
||||
o_shape.back() *= os;
|
||||
} else
|
||||
o_shape.push_back(os);
|
||||
oid = o_shape.size()-1;
|
||||
} else {
|
||||
auto& s = vs.slices[vid];
|
||||
if (s.is_var() && fov == -1) {
|
||||
fov = o_shape.size();
|
||||
for (int i=0; i<var_dim; i++)
|
||||
o_shape.push_back(out_shape[first_oid_of_var+i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
first_oid_of_var = fov;
|
||||
|
||||
if (!out_shape.size()) out_shape.push_back(1);
|
||||
out->set_shape(out_shape.to_nano_vector());
|
||||
|
||||
this->i_to_vs = i_to_vs.to_nano_vector();
|
||||
this->i_to_o = i_to_o.to_nano_vector();
|
||||
this->o_shape = o_shape.to_nano_vector();
|
||||
|
||||
LOGvvvv << "\ni_to_vs:" << i_to_vs
|
||||
<< "\ni_to_o:" << i_to_o
|
||||
<< "\no_shape:" << o_shape;
|
||||
}
|
||||
|
||||
VarPtr GetitemOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
||||
if (v_index)
|
||||
return nullptr;
|
||||
auto zeros = make_number(0, v);
|
||||
// TODO: maybe add here?
|
||||
// need analysis the overlap attr os var slices
|
||||
return make_setitem(zeros, VarSlices(vs), dout, ns_void);
|
||||
}
|
||||
|
||||
void GetitemOp::jit_prepare() {
|
||||
auto in = inputs().front();
|
||||
int idim = i_to_vs.size();
|
||||
add_jit_define("Ti", in->dtype());
|
||||
add_jit_define("IDIM", JK::hex1(i_to_vs.size()));
|
||||
add_jit_define("ODIM", JK::hex1(o_shape.size()));
|
||||
if (first_oid_of_var>=0) {
|
||||
add_jit_define("FOV", JK::hex1(first_oid_of_var));
|
||||
add_jit_define("VD", JK::hex1(var_dim));
|
||||
}
|
||||
for (int i=0; i<idim; i++) {
|
||||
auto iv = i_to_vs[i];
|
||||
auto io = i_to_o[i];
|
||||
add_jit_define("IV", JK::hex1(i), JK::shex1(iv));
|
||||
add_jit_define("IO", JK::hex1(i), JK::shex1(io));
|
||||
auto& v = vs.slices[iv];
|
||||
if (iv>=0 && io==-1) {
|
||||
if (v.is_int()) {
|
||||
add_jit_define("VS", JK::hex1(i), "-1");
|
||||
} else {
|
||||
ASSERT(v.is_var());
|
||||
auto var = v.var;
|
||||
auto vshape = var->shape;
|
||||
auto vdim = vshape.size();
|
||||
int vsmask = 0;
|
||||
for (int j=0; j<vdim; j++) {
|
||||
int k = first_oid_of_var+j+var_dim-vdim;
|
||||
if (vshape[j] == o_shape[k])
|
||||
vsmask |= 1<<(j+var_dim-vdim);
|
||||
}
|
||||
add_jit_define("VS", JK::hex1(i), JK::hex(vsmask));
|
||||
add_jit_define("VST", JK::hex1(i), var->dtype());
|
||||
}
|
||||
} else
|
||||
if (iv>=0 && io>=0) {
|
||||
ASSERT(v.is_slice());
|
||||
if (std::abs(v.slice.step) <= 1)
|
||||
add_jit_define("VS", JK::hex1(i), JK::shex1(v.slice.step));
|
||||
else
|
||||
add_jit_define("VS", JK::hex1(i), "0");
|
||||
}
|
||||
}
|
||||
#ifdef HAS_CUDA
|
||||
if (use_cuda) {
|
||||
int no = o_shape.size();
|
||||
int masks[no];
|
||||
int tdims[6];
|
||||
cuda_loop_schedule(o_shape, masks, tdims);
|
||||
for (int i=0; i<no; i++) {
|
||||
add_jit_define("LO", JK::hex1(i), JK::hex(masks[i]));
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
||||
#pragma GCC diagnostic ignored "-Wunused-variable"
|
||||
|
||||
void GetitemOp::jit_run() {
|
||||
auto in = inputs().front();
|
||||
auto out = outputs().front();
|
||||
if (out->num == 0) return;
|
||||
|
||||
@for(i, 0, ODIM, index_t oshape@i = o_shape[@i];)
|
||||
@if(ODIM>0,
|
||||
index_t ostride@{ODIM-1} = 1;
|
||||
@for(i, ODIM-2, -1, -1, index_t ostride@i = ostride@{i+1} * oshape@{i+1};)
|
||||
)
|
||||
Ti* op = out->ptr<Ti>();
|
||||
|
||||
Ti* ip = in->ptr<Ti>();
|
||||
@for(i, 0, IDIM, index_t ishape@i =
|
||||
@if(IV@i==-1,oshape@{IO@i},
|
||||
@if(IV@i==-2,1,in->shape[@i]));
|
||||
)
|
||||
index_t istride@{IDIM-1} = 1;
|
||||
@for(i, IDIM-2, -1, -1, index_t istride@i = istride@{i+1} * ishape@{i+1};)
|
||||
|
||||
|
||||
@for(i, 0, IDIM,
|
||||
@if(IV@i>=0 && IO@i>=0,
|
||||
index_t vstart@i = vs.slices[@{IV@i}].slice.start;
|
||||
index_t vstep@i = @if(VS@i==0,vs.slices[@{IV@i}].slice.step;,@{VS@i});
|
||||
)
|
||||
)
|
||||
|
||||
@for(i, 0, IDIM,
|
||||
@if(IV@i>=0 && IO@i<0,
|
||||
@if(VS@i==-1,index_t vi@i = vs.slices[@{IV@i}].slice.start;);
|
||||
)
|
||||
)
|
||||
|
||||
@for(i, 0, IDIM,
|
||||
@if(IV@i>=0 && IO@i<0,
|
||||
@if(VS@i>=0,
|
||||
index_t vs@i@@s@{VD-1} = 1;
|
||||
VST@i* vp@i = vs.slices[IV@i].var->ptr<VST@i>();
|
||||
@for(j,VD-2,-1,-1,index_t vs@i@@s@j = vs@i@@s@{j+1} *
|
||||
@if((VS@i>>(j+1))&1,oshape@{j+1+FOV},1);
|
||||
)
|
||||
);
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
|
||||
@for(d, 0, ODIM, for (index_t i@d=0; i@d < oshape@d; i@d++)) {
|
||||
index_t oid = 0 @for(d, 0, ODIM, + i@d * ostride@d);
|
||||
@for(d, 0, IDIM, index_t iid@d =
|
||||
@if(IV@d==-1, i@{IO@d},
|
||||
@if(IV@d==-2, 0,
|
||||
@if(IO@d!=-1, (i@{IO@d}*vstep@d+vstart@d),
|
||||
@if(VS@d==-1, vi@d,
|
||||
@if(VS@d>=0,
|
||||
index_t(vp@d[0 @for(j,0,VD,@if((VS@d>>j)&1, + i@{j+FOV} * vs@d@@s@j,))])
|
||||
, ??? )))));
|
||||
)
|
||||
auto iid = 0 @for(d, 0, IDIM, + iid@d * istride@d);
|
||||
op[oid] = ip[iid];
|
||||
}
|
||||
}
|
||||
#endif // JIT
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,42 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
// Authors: Dun Liang <randonlang@gmail.com>.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include "op.h"
|
||||
#include "var_slices.h"
|
||||
#include "misc/stack_vector.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
struct GetitemOp : Op {
|
||||
VarSlices vs;
|
||||
// map i to related var slice
|
||||
NanoVector i_to_vs;
|
||||
// map i to related o
|
||||
NanoVector i_to_o;
|
||||
NanoVector o_shape;
|
||||
int first_oid_of_var, var_dim;
|
||||
|
||||
GetitemOp(Var* x, VarSlices&& slices);
|
||||
|
||||
const char* name() const override { return "getitem"; }
|
||||
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
|
||||
void infer_shape() override;
|
||||
void compile_optimize(string& src) override;
|
||||
void graph_optimize() override;
|
||||
DECLARE_jit_run;
|
||||
|
||||
void infer_slices(
|
||||
StackVector<>& __restrict__ i_to_vs,
|
||||
StackVector<>& __restrict__ i_to_o,
|
||||
StackVector<>& __restrict__ out_shape
|
||||
);
|
||||
void _compile_optimize(string& src);
|
||||
};
|
||||
|
||||
void cuda_loop_schedule(NanoVector o_shape, int* masks, int* tdims);
|
||||
|
||||
} // jittor
|
|
@ -16,9 +16,22 @@ static auto make_broadcast_to = get_op_info("broadcast_to")
|
|||
.get_constructor<VarPtr, Var*, Var*, NanoVector>();
|
||||
|
||||
VarPtr make_number(float number, Var* x) {
|
||||
VarPtr nums = make_array(&number, 1, ns_float32);
|
||||
nums = make_broadcast_to(nums, x, {});
|
||||
return make_unary(nums, x->dtype());
|
||||
union Number {
|
||||
float32 f32;
|
||||
float64 f64;
|
||||
int32 i32;
|
||||
int64 i64;
|
||||
} v;
|
||||
if (x->dtype() == ns_float32) v.f32 = number; else
|
||||
if (x->dtype() == ns_float64) v.f64 = number; else
|
||||
if (x->dtype() == ns_int32) v.i32 = number; else
|
||||
if (x->dtype() == ns_int64) v.i64 = number; else {
|
||||
VarPtr nums = make_array(&number, 1, ns_float32);
|
||||
nums = make_broadcast_to(nums, x, {});
|
||||
return make_unary(nums, x->dtype());
|
||||
}
|
||||
VarPtr nums = make_array(&v, 1, x->dtype());
|
||||
return make_broadcast_to(nums, x, {});
|
||||
}
|
||||
|
||||
static void init() {
|
||||
|
|
|
@ -0,0 +1,327 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
// Authors: Dun Liang <randonlang@gmail.com>.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include <cmath>
|
||||
#include "var.h"
|
||||
#include "ops/setitem_op.h"
|
||||
#include "ops/getitem_op.h"
|
||||
#ifdef JIT
|
||||
#include "ops/binary_op_defs.h"
|
||||
#ifdef JIT_cuda
|
||||
#include <cuda_runtime.h>
|
||||
#include <helper_cuda.h>
|
||||
#endif
|
||||
#else
|
||||
#include "ops/op_register.h"
|
||||
#ifdef HAS_CUDA
|
||||
#include "misc/cuda_flags.h"
|
||||
#endif
|
||||
#endif
|
||||
|
||||
namespace jittor {
|
||||
|
||||
#ifndef JIT
|
||||
|
||||
static auto make_array = get_op_info("array")
|
||||
.get_constructor<VarPtr, const void*, NanoVector, NanoString>();
|
||||
static auto make_number = get_op_info("number")
|
||||
.get_constructor<VarPtr, float, Var*>();
|
||||
static auto make_getitem = get_op_info("getitem")
|
||||
.get_constructor<VarPtr, Var*, VarSlices&&>();
|
||||
static auto make_setitem = get_op_info("setitem")
|
||||
.get_constructor<VarPtr, Var*, VarSlices&&, Var*, NanoString>();
|
||||
static auto make_binary = get_op_info("binary")
|
||||
.get_constructor<VarPtr, Var*, Var*, NanoString>();
|
||||
static auto make_unary = get_op_info("unary")
|
||||
.get_constructor<VarPtr, Var*, NanoString>();
|
||||
|
||||
SetitemOp::SetitemOp(Var* x, VarSlices&& slices, Var* y, NanoString op)
|
||||
: vs(move(slices)), op(op) {
|
||||
flags.set(NodeFlags::_cpu);
|
||||
flags.set(NodeFlags::_cuda);
|
||||
flags.set(NodeFlags::_has_gopt);
|
||||
ASSERT(ns == ns_void || ns.is_binary());
|
||||
create_output(nullptr, x->dtype());
|
||||
}
|
||||
|
||||
void SetitemOp::infer_shape() {
|
||||
auto in = inputs().front();
|
||||
auto data = input(1);
|
||||
auto out = outputs().front();
|
||||
auto in_shape = in->shape;
|
||||
auto nin = in_shape.size();
|
||||
|
||||
StackVector<> i_to_vs(nin);
|
||||
StackVector<> i_to_o(nin);
|
||||
// shape return to use
|
||||
StackVector<> out_shape;
|
||||
((GetitemOp*)this)->infer_slices(i_to_vs, i_to_o, out_shape);
|
||||
if (!out_shape.size()) out_shape.push_back(1);
|
||||
|
||||
// get broadcast mask of set value
|
||||
auto data_shape = data->shape;
|
||||
auto data_dim = data_shape.size();
|
||||
int bmask = 0;
|
||||
int bmask2 = 0;
|
||||
|
||||
ASSERTop(data_dim,<=,out_shape.size()) << "Data dimension not match";
|
||||
for (int i=0; i<data_dim; i++) {
|
||||
int j = i - data_dim + out_shape.size();
|
||||
if (!(data_shape[i]==1 && out_shape[j]!=-1)) {
|
||||
ASSERTop(data_shape[i],==,out_shape[j]) << "Data shape not match" << data_shape << out_shape;
|
||||
bmask |= 1<<j;
|
||||
}
|
||||
}
|
||||
|
||||
// optimized shape (each dim is a loop var)
|
||||
StackVector<> o_shape;
|
||||
int fov = -1;
|
||||
for (int i=0; i<nin; i++) {
|
||||
auto& vid = i_to_vs[i];
|
||||
auto& oid = i_to_o[i];
|
||||
auto os = out_shape[oid];
|
||||
if (oid>=0) {
|
||||
if (vid==-1 && i && i_to_vs[i-1]<0
|
||||
&& ((bmask>>oid)&1) == ((bmask>>(oid-1))&1))
|
||||
// same broadcast condition with prev dim
|
||||
{
|
||||
vid = -2;
|
||||
o_shape.back() *= os;
|
||||
} else {
|
||||
o_shape.push_back(os);
|
||||
// fix bmask2 offset
|
||||
bmask2 |= ((bmask>>oid)&1) << (o_shape.size()-1);
|
||||
}
|
||||
oid = o_shape.size()-1;
|
||||
} else {
|
||||
auto& s = vs.slices[vid];
|
||||
if (s.is_var() && fov == -1) {
|
||||
fov = o_shape.size();
|
||||
for (int i=0; i<var_dim; i++) {
|
||||
o_shape.push_back(out_shape[first_oid_of_var+i]);
|
||||
// fix bmask2 offset
|
||||
bmask2 |= ((bmask>>(first_oid_of_var+i))&1) << (o_shape.size()-1);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
first_oid_of_var = fov;
|
||||
this->bmask = bmask2;
|
||||
|
||||
out->set_shape(in_shape);
|
||||
|
||||
this->i_to_vs = i_to_vs.to_nano_vector();
|
||||
this->i_to_o = i_to_o.to_nano_vector();
|
||||
this->o_shape = o_shape.to_nano_vector();
|
||||
|
||||
LOGvvvv << "\ni_to_vs:" << i_to_vs
|
||||
<< "\ni_to_o:" << i_to_o
|
||||
<< "\no_shape:" << o_shape;
|
||||
}
|
||||
|
||||
VarPtr SetitemOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
||||
if (v_index >= 2)
|
||||
return nullptr;
|
||||
if (op == ns_void) {
|
||||
if (v_index == 0) {
|
||||
float32 number = 0;
|
||||
VarPtr zero = make_array(&number, 1, ns_float32);
|
||||
return make_setitem(dout, VarSlices(vs), zero, ns_void);
|
||||
} else {
|
||||
return make_getitem(dout, VarSlices(vs));
|
||||
}
|
||||
}
|
||||
if (op == ns_add) {
|
||||
if (v_index == 0) {
|
||||
return dout;
|
||||
} else {
|
||||
return make_getitem(dout, VarSlices(vs));
|
||||
}
|
||||
}
|
||||
if (op == ns_subtract) {
|
||||
if (v_index == 0) {
|
||||
return dout;
|
||||
} else {
|
||||
return make_unary(make_getitem(dout, VarSlices(vs)), ns_negative);
|
||||
}
|
||||
}
|
||||
if (op == ns_multiply) {
|
||||
if (v_index == 0) {
|
||||
return make_setitem(dout, VarSlices(vs), input(1), ns_multiply);
|
||||
} else {
|
||||
return make_binary(
|
||||
make_getitem(inputs().front(), VarSlices(vs)),
|
||||
make_getitem(dout, VarSlices(vs)), ns_multiply);
|
||||
}
|
||||
}
|
||||
if (op == ns_divide) {
|
||||
if (v_index == 0) {
|
||||
return make_setitem(dout, VarSlices(vs), input(1), ns_divide);
|
||||
} else {
|
||||
// dy = -dz*x / y^2
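// (added note) Derivation: this op computes out[slices] = x[slices] / y, so
// d out[slices] / d y = -x[slices] / y^2; below, dout is sliced, negated,
// multiplied by x[slices] and divided by y*y to form the gradient w.r.t. y.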
|
||||
auto dout2 = make_getitem(dout, VarSlices(vs));
|
||||
auto x = make_getitem(inputs().front(), VarSlices(vs));
|
||||
auto y = v;
|
||||
auto ndz = make_unary(dout2, ns_negative);
|
||||
auto ndzx = make_binary(ndz, x, ns_multiply);
|
||||
auto y2 = make_binary(y, y, ns_multiply);
|
||||
return make_binary(ndzx, y2, ns_divide);
|
||||
}
|
||||
}
|
||||
LOGf << "Setitem grad of op" << op << "is not supported yet";
|
||||
return nullptr;
|
||||
}
|
||||
|
void SetitemOp::jit_prepare() {
    auto data = input(1);
    add_jit_define("OP", op);
    add_jit_define("Td", data->dtype());
    add_jit_define("BMASK", JK::hex(bmask));
    // TODO: merge code
    auto in = inputs().front();
    int idim = i_to_vs.size();
    add_jit_define("Ti", in->dtype());
    add_jit_define("IDIM", JK::hex1(i_to_vs.size()));
    add_jit_define("ODIM", JK::hex1(o_shape.size()));
    if (first_oid_of_var>=0) {
        add_jit_define("FOV", JK::hex1(first_oid_of_var));
        add_jit_define("VD", JK::hex1(var_dim));
    }
    for (int i=0; i<idim; i++) {
        auto iv = i_to_vs[i];
        auto io = i_to_o[i];
        add_jit_define("IV", JK::hex1(i), JK::shex1(iv));
        add_jit_define("IO", JK::hex1(i), JK::shex1(io));
        auto& v = vs.slices[iv];
        if (iv>=0 && io==-1) {
            if (v.is_int()) {
                add_jit_define("VS", JK::hex1(i), "-1");
            } else {
                ASSERT(v.is_var());
                auto var = v.var;
                auto vshape = var->shape;
                auto vdim = vshape.size();
                int vsmask = 0;
                for (int j=0; j<vdim; j++) {
                    int k = first_oid_of_var+j+var_dim-vdim;
                    if (vshape[j] == o_shape[k])
                        vsmask |= 1<<(j+var_dim-vdim);
                }
                add_jit_define("VS", JK::hex1(i), JK::hex(vsmask));
                add_jit_define("VST", JK::hex1(i), var->dtype());
            }
        } else
        if (iv>=0 && io>=0) {
            ASSERT(v.is_slice());
            if (std::abs(v.slice.step) <= 1)
                add_jit_define("VS", JK::hex1(i), JK::shex1(v.slice.step));
            else
                add_jit_define("VS", JK::hex1(i), "0");
        }
    }
    #ifdef HAS_CUDA
    if (use_cuda) {
        int no = o_shape.size();
        int masks[no];
        int tdims[6];
        cuda_loop_schedule(o_shape, masks, tdims);
        for (int i=0; i<no; i++) {
            add_jit_define("LO", JK::hex1(i), JK::hex(masks[i]));
        }
    }
    #endif
}

void SetitemOp::compile_optimize(string& src) {
    ((GetitemOp*)this)->_compile_optimize(src);
}

#else // JIT

#pragma GCC diagnostic ignored "-Wunused-variable"

void SetitemOp::jit_run() {
    auto in = inputs().front();
    auto data = input(1);
    auto out = outputs().front();
    if (out->num == 0) return;

    @for(i, 0, ODIM, index_t oshape@i = o_shape[@i];)
    @if(ODIM>0,
        index_t ostride@{ODIM-1} = 1;
        @for(i, ODIM-2, -1, -1, index_t ostride@i = ostride@{i+1} * oshape@{i+1};)
    )
    Ti* op = out->ptr<Ti>();

    Ti* ip = in->ptr<Ti>();
    @for(i, 0, IDIM, index_t ishape@i =
        @if(IV@i==-1,oshape@{IO@i},
        @if(IV@i==-2,1,in->shape[@i]));
    )
    index_t istride@{IDIM-1} = 1;
    @for(i, IDIM-2, -1, -1, index_t istride@i = istride@{i+1} * ishape@{i+1};)

    @for(i, 0, IDIM,
        @if(IV@i>=0 && IO@i>=0,
            index_t vstart@i = vs.slices[@{IV@i}].slice.start;
            index_t vstep@i = @if(VS@i==0,vs.slices[@{IV@i}].slice.step;,@{VS@i});
        )
    )

    @for(i, 0, IDIM,
        @if(IV@i>=0 && IO@i<0,
            @if(VS@i==-1,index_t vi@i = vs.slices[@{IV@i}].slice.start;);
        )
    )

    @for(i, 0, IDIM,
        @if(IV@i>=0 && IO@i<0,
            @if(VS@i>=0,
                index_t vs@i@@s@{VD-1} = 1;
                VST@i* vp@i = vs.slices[IV@i].var->ptr<VST@i>();
                @for(j,VD-2,-1,-1,index_t vs@i@@s@j = vs@i@@s@{j+1} *
                    @if((VS@i>>(j+1))&1,oshape@{j+1+FOV},1);
                )
            );
        )
    )

    Td* dp = data->ptr<Td>();
    @if(ODIM>0,
        index_t dstride@{ODIM-1} = 1;
        @for(i, ODIM-2, -1, -1, index_t dstride@i = dstride@{i+1} * @if((BMASK>>(i+1))&1,oshape@{i+1},1);)
    )
    #ifdef JIT_cpu
    if (op != ip)
        std::memcpy(op, ip, out->size);
    #else
    if (op != ip)
        checkCudaErrors(cudaMemcpyAsync(op, ip, out->size, cudaMemcpyDefault, 0));
    #endif

    @for(d, 0, ODIM, for (index_t i@d=0; i@d < oshape@d; i@d++)) {
        index_t did = 0 @for(d, 0, ODIM, @if((BMASK>>d)&1,+ i@d * dstride@d));
        @for(d, 0, IDIM, index_t iid@d =
            @if(IV@d==-1, i@{IO@d},
            @if(IV@d==-2, 0,
            @if(IO@d!=-1, (i@{IO@d}*vstep@d+vstart@d),
            @if(VS@d==-1, vi@d,
            @if(VS@d>=0,
                index_t(vp@d[0 @for(j,0,VD,@if((VS@d>>j)&1, + i@{j+FOV} * vs@d@@s@j,))])
            , ??? )))));
        )
        auto iid = 0 @for(d, 0, IDIM, + iid@d * istride@d);

        @if(@strcmp(@OP,void)==0,
            op[iid] = (Ti)dp[did],
            op[iid] = @expand_macro(@OP, Ti, op[iid], dp[did])
        );
    }
}
#endif // JIT

} // jittor
@@ -0,0 +1,34 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. All Rights Reserved.
// Authors: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "op.h"
#include "var_slices.h"

namespace jittor {

struct SetitemOp : Op {
    VarSlices vs;
    // map i to related var slice
    NanoVector i_to_vs;
    // map i to related o
    NanoVector i_to_o;
    NanoVector o_shape;
    int first_oid_of_var, var_dim;
    int bmask;
    NanoString op;

    SetitemOp(Var* x, VarSlices&& slices, Var* y, NanoString op=ns_void);

    const char* name() const override { return "setitem"; }
    VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
    void infer_shape() override;
    void compile_optimize(string& src) override;
    void graph_optimize() override;
    DECLARE_jit_run;
};

} // jittor
@@ -49,8 +49,11 @@ Tapes::Tapes(
    const vector<VarHolder*>& taped_outputs,
    GradCallback&& grad_callback
) {
    callback = move(grad_callback);
    flags.set(NodeFlags::_cpu);
    flags.set(NodeFlags::_cuda);
    flags.set(NodeFlags::_grads);
    callback = move(grad_callback);

    /*
        stop grad stop grad
@@ -89,4 +92,4 @@ void tape_together(
    new Tapes(taped_inputs, taped_outputs, move(grad_callback));
}

} // jittor
} // jittor
@@ -5,12 +5,17 @@
// ***************************************************************
#include "var.h"
#include "ops/where_op.h"
#ifdef JIT_cuda
#include "executor.h"
#include <assert.h>
#endif

namespace jittor {

#ifndef JIT
WhereOp::WhereOp(Var* cond, NanoString dtype) : cond(cond) {
    flags.set(NodeFlags::_cpu);
    flags.set(NodeFlags::_cuda);
    flags.set(NodeFlags::_vary_shape);
    auto ndim = cond->shape.size();
    outs.reset(new Var*[ndim]);
@@ -33,6 +38,186 @@ void WhereOp::jit_prepare() {
}

#else // JIT
#ifdef JIT_cuda

__global__ static void where_kernel(
    @for(i, 0, NDIM, 1, index_t condshape@i, )
    Ti* __restrict__ condp,
    @for(i, 0, NDIM, 1, To* __restrict__ outs@i@@p, )
    int* __restrict__ np
) {
    __shared__ uint n;
    int tid = threadIdx.x;
    int tnum = blockDim.x;
    if (tid == 0)
        n = 0;
    // define cond stride
    index_t condstride@{NDIM-1} = 1;
    @for(i, NDIM-2, -1, -1, auto condstride@i = condstride@{i+1} * condshape@{i+1};)
    __syncthreads();

    // generate d-for loop
    @for(d, 0, NDIM-1, for (index_t i@d=0; i@d < condshape@d; i@d++))
    for (index_t i@{NDIM-1}=tid; i@{NDIM-1}<condshape@{NDIM-1}; i@{NDIM-1}+=tnum)
    {
        auto condid = @for(d, 0, NDIM, + i@d * condstride@d);
        if (condp[condid]) {
            uint cn = atomicInc(&n, 1u<<30);
            @for(i, 0, NDIM, outs@i@@p[cn] = i@i;)
        }
    }
    __syncthreads();
    if (tid == 0)
        (*np) = n;
}


__device__ inline uint prefix_sum(uint val, uint lane_id) {
    #define FULL_MASK 0xffffffff
    #pragma unroll
    for (int offset = 16; offset > 0; offset /= 2) {
        uint x = __shfl_up_sync(FULL_MASK, val, offset);
        val += lane_id>=offset? x : 0;
    }
    return val;
}

__device__ inline uint bc(uint val, uint lane_id) {
    return __shfl_sync(FULL_MASK, val, lane_id);
}
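The stable kernels below lean on prefix_sum, a warp-level inclusive scan built from __shfl_up_sync with shrinking offsets 16, 8, ..., 1, and on bc, which broadcasts lane 31's scanned value (the warp total) to every lane. A minimal host-side C++ sketch of the same recurrence on a plain 32-slot array (names here are illustrative, not taken from the source) may help when reading them:

#include <cstdio>
#include <vector>

// Shuffle-style inclusive scan over a 32-element "warp": at each step,
// slot i adds the pre-step value sitting `offset` slots below it, which is
// exactly what __shfl_up_sync feeds to each lane in prefix_sum above.
static void warp_inclusive_scan(std::vector<unsigned>& lanes) {
    for (int offset = 16; offset > 0; offset /= 2) {
        std::vector<unsigned> shifted(lanes.size(), 0);
        for (size_t i = offset; i < lanes.size(); i++)
            shifted[i] = lanes[i - offset];      // value the shuffle would return
        for (size_t i = 0; i < lanes.size(); i++)
            lanes[i] += shifted[i];              // lanes below `offset` add 0
    }
}

int main() {
    // 1 marks a "true" cond element handled by this lane, as in the kernels.
    std::vector<unsigned> lanes = {1,0,1,1,0,0,1,0, 0,0,0,0,0,0,0,0,
                                   0,0,0,0,0,0,0,0, 0,0,0,0,0,0,0,1};
    warp_inclusive_scan(lanes);
    // lanes[i] now counts the true elements at positions <= i, so
    // lanes[i] - 1 is that element's output slot: cn = n + prefix_x - 1.
    printf("warp total = %u\n", lanes.back());   // 5
    return 0;
}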
__global__ static void where_kernel_one_warp(
    @for(i, 0, NDIM, 1, index_t condshape@i, )
    Ti* __restrict__ condp,
    @for(i, 0, NDIM, 1, To* __restrict__ outs@i@@p, )
    int* __restrict__ np
) {
    uint n = 0;
    int tid = threadIdx.x;
    int tnum = 32;
    // define cond stride
    index_t condstride@{NDIM-1} = 1;
    @for(i, NDIM-2, -1, -1, auto condstride@i = condstride@{i+1} * condshape@{i+1};)

    // generate d-for loop
    @for(d, 0, NDIM-1, for (index_t i@d=0; i@d < condshape@d; i@d++))
    for (index_t i=0; i<condshape@{NDIM-1}; i+=tnum)
    {
        index_t i@{NDIM-1} = i + tid;
        auto condid = @for(d, 0, NDIM, + i@d * condstride@d);
        uint x = i@{NDIM-1}<condshape@{NDIM-1} ? condp[condid] : 0;
        uint prefix_x = prefix_sum(x, tid);
        if (x) {
            uint cn = n + prefix_x - 1;
            @for(i, 0, NDIM, outs@i@@p[cn] = i@i;)
        }
        n += bc(prefix_x, 31);
    }
    if (tid == 0)
        (*np) = n;
}

#define WTN 1024

__global__ static void where_kernel_one_block(
    @for(i, 0, NDIM, 1, index_t condshape@i, )
    Ti* __restrict__ condp,
    @for(i, 0, NDIM, 1, To* __restrict__ outs@i@@p, )
    int* __restrict__ np
) {
    uint n = 0;
    int tid = threadIdx.x;
    int tnum = WTN;
    __shared__ uint s[WTN/32];
    int wid = tid / 32;
    int lid = tid % 32;
    int wnum = WTN / 32;
    // define cond stride
    index_t condstride@{NDIM-1} = 1;
    @for(i, NDIM-2, -1, -1, auto condstride@i = condstride@{i+1} * condshape@{i+1};)

    // generate d-for loop
    @for(d, 0, NDIM-1, for (index_t i@d=0; i@d < condshape@d; i@d++))
    for (index_t i=0; i<condshape@{NDIM-1}; i+=tnum)
    {
        index_t i@{NDIM-1} = i + tid;
        auto condid = @for(d, 0, NDIM, + i@d * condstride@d);
        uint x = i@{NDIM-1}<condshape@{NDIM-1} ? condp[condid] : 0;
        uint prefix_x = prefix_sum(x, lid);
        uint warp_sum = bc(prefix_x, 31);

        // prefix sum between warps
        if (lid == 0) {
            s[wid] = warp_sum;
        }
        __syncthreads();
        if (wid == 0) {
            s[lid] = prefix_sum(s[lid], lid);
        }
        __syncthreads();
        uint warp_prefix_sum = s[wid];
        uint block_sum = s[wnum-1];
        __syncthreads();

        if (x) {
            uint cn = n + prefix_x - 1 + warp_prefix_sum - warp_sum;
            @for(i, 0, NDIM, outs@i@@p[cn] = i@i;)
        }
        n += block_sum;
    }
    if (tid == 0)
        (*np) = n;
}

void WhereOp::jit_run() {
    auto* __restrict__ condp = cond->ptr<Ti>();
    // define cond shape
    @for(i, 0, NDIM, index_t condshape@i = cond->shape[@i];)

    // define outs
    @for(i, 0, NDIM, auto* __restrict__ outs@i@@p = outs[@i]->ptr<To>();)

    size_t n_allocation;
    int* np = (int*)exe.allocator->alloc(4, n_allocation);

    // one block kernel, result maybe unstable
    // int tnum = condshape@{NDIM-1};
    // tnum = std::max(1, std::min(1024, tnum));
    // where_kernel<<<1,tnum>>>(
    //     @for(i, 0, NDIM, 1, condshape@i, )
    //     condp,
    //     @for(i, 0, NDIM, 1, outs@i@@p, )
    //     np
    // );

    int tnum = condshape@{NDIM-1};
    if (tnum < 100) {
        // one warp kernel, result is stable
        where_kernel_one_warp<<<1,32>>>(
            @for(i, 0, NDIM, 1, condshape@i, )
            condp,
            @for(i, 0, NDIM, 1, outs@i@@p, )
            np
        );
    } else {
        // one block kernel, result is stable
        where_kernel_one_block<<<1,WTN>>>(
            @for(i, 0, NDIM, 1, condshape@i, )
            condp,
            @for(i, 0, NDIM, 1, outs@i@@p, )
            np
        );
    }

    int n=0;
    // checkCudaErrors(cudaDeviceSynchronize());
    checkCudaErrors(cudaMemcpy(&n, np, 4, cudaMemcpyDefault));
    @for(i, 0, NDIM, outs[@i]->set_shape({n});)
    exe.allocator->free(np, 4, n_allocation);
}
#else

void WhereOp::jit_run() {
    auto* __restrict__ condp = cond->ptr<Ti>();
    // define cond shape
@@ -55,6 +240,8 @@ void WhereOp::jit_run() {
    }
    @for(i, 0, NDIM, outs[@i]->set_shape({n});)
}

#endif
#endif // JIT

} // jittor
} // jittor
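The contract jit_run implements above: fill one output array per dimension with the coordinates of the nonzero elements of cond, count them into *np, then shrink every output to that length. A tiny single-threaded C++ reference of the same contract for the 2-D case (illustrative only; where2d and its types are placeholders, not Jittor's):

#include <cstdio>
#include <vector>

// Reference semantics of a 2-D "where": for every (i, j) with cond[i][j] != 0,
// append i to rows and j to cols in row-major order -- the stable ordering the
// one-warp / one-block kernels preserve.
static int where2d(const std::vector<std::vector<int>>& cond,
                   std::vector<int>& rows, std::vector<int>& cols) {
    int n = 0;
    for (size_t i = 0; i < cond.size(); i++)
        for (size_t j = 0; j < cond[i].size(); j++)
            if (cond[i][j]) {
                rows.push_back((int)i);
                cols.push_back((int)j);
                n++;
            }
    return n;   // plays the role of *np; each output is then set_shape({n})
}

int main() {
    std::vector<std::vector<int>> cond = {{0, 1, 0},
                                          {1, 0, 1}};
    std::vector<int> rows, cols;
    int n = where2d(cond, rows, cols);
    for (int k = 0; k < n; k++)
        printf("(%d, %d)\n", rows[k], cols[k]);   // (0,1) (1,0) (1,2)
    return 0;
}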
@@ -0,0 +1,119 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. All Rights Reserved.
// Authors: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include <cmath>
#include "var.h"
#include "ops/setitem_op.h"
#include "ops/getitem_op.h"

namespace jittor {

inline static bool fast_strcmp(const char* a, const char* b) {
    while (*b && *a == *b) a++, b++;
    return !*b;
}

static void setitem_inplace(SetitemOp* op) {
    // LOGir << "setitem_inplace";
    auto input = op->inputs().front();
    if (!(input->outputs().size() == 1 &&
        input->forward_liveness<=1 &&
        (op->op == ns_void || op->op == ns_add || op->op == ns_subtract))) {
        return;
    }
    auto input_op = input->input();
    if (input_op) {
        // make sure input op will not use input
        auto input_name = input_op->name();
        if (!(input_op->type() == OpType::broadcast ||
            fast_strcmp(input_name, "array") ||
            fast_strcmp(input_name, "empty") ||
            fast_strcmp(input_name, "setitem") ||
            fast_strcmp(input_name, "getitem")))
            // TODO: inplace getitem maybe risky, getitem maybe inplace too
            return;
    }
    auto output = op->outputs().front();
    output->share_with(input);
    // LOGir << "apply setitem_inplace on" << op << "input:" << input << "output:" << output;
}

struct BBox {
    int n = 0;
    int* minmax = nullptr;

    void load_var_slice(const VarSlice& vs) {

    }
};

static void setitem_grad_opt(GetitemOp* op) {
    if (!op->flags.get(NodeFlags::_has_gopt))
        return;
    auto get_in = op->inputs().front();
    auto get_in_op = get_in->input();
    if (!get_in_op)
        return;
    auto name = get_in_op->name();
    if (!fast_strcmp(name, "setitem"))
        return;
    // find setitem op chain
    auto first_set = (SetitemOp*)get_in_op;
    vector<SetitemOp*> chain;
    while (1) {
        auto next = first_set->inputs().front()->input();
        if (!next) break;
        if (!fast_strcmp(next->name(), "setitem"))
            break;
        chain.push_back(first_set);
        first_set = (SetitemOp*)next;
    }
    chain.push_back(first_set);
    for (int i=0; i<chain.size()/2; i++)
        std::swap(chain[i], chain[chain.size()-1-i]);
    auto last_set = (SetitemOp*)get_in_op;
    while (1) {
        SetitemOp* next = nullptr;
        auto out_var = last_set->outputs().front();
        for (auto* out : out_var->outputs()) {
            if (fast_strcmp(out->name(), "setitem")) {
                next = (SetitemOp*)out;
                break;
            }
        }
        if (!next) break;
        last_set = next;
        chain.push_back(next);
    }
    // LOGir << "find setitem chain" << chain.size() << chain;
    for (auto* sop : chain) {
        // LOGig << sop << sop->vs;
        auto out_var = sop->outputs().front();
        for (auto* out : out_var->outputs()) {
            if (fast_strcmp(out->name(), "getitem")) {
                out->flags.set(NodeFlags::_has_gopt, 0);
            }
        }
    }

}

void SetitemOp::graph_optimize() {
    // LOGir << "hello graph_optimize";
    setitem_inplace(this);
}

void GetitemOp::graph_optimize() {
    // This optimize is still WIP
    // LOGir << "hello getitem graph_optimize";
    // setitem_grad_opt(this);
    (void)setitem_grad_opt;
}

}
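setitem_grad_opt above gathers the run of consecutive setitem ops around the one feeding the getitem: it first walks producer ops backwards (then reverses so the oldest comes first) and then walks consumer ops forwards. A stripped-down C++ sketch of that traversal on a toy graph; Node, producer, consumers and collect_chain are illustrative stand-ins, not Jittor classes:

#include <algorithm>
#include <cstdio>
#include <string>
#include <vector>

// Toy graph node: one producer, many consumers, an op kind and an id.
struct Node {
    std::string op;
    int id = 0;
    Node* producer = nullptr;
    std::vector<Node*> consumers;
};

// Gather the run of "setitem" nodes containing `start`: walk producers
// first (then reverse so the oldest comes first), then walk consumers.
static std::vector<Node*> collect_chain(Node* start) {
    std::vector<Node*> chain;
    Node* cur = start;
    while (cur->producer && cur->producer->op == "setitem") {
        chain.push_back(cur);
        cur = cur->producer;
    }
    chain.push_back(cur);
    std::reverse(chain.begin(), chain.end());
    cur = start;
    while (true) {
        Node* next = nullptr;
        for (Node* c : cur->consumers)
            if (c->op == "setitem") { next = c; break; }
        if (!next) break;
        chain.push_back(next);
        cur = next;
    }
    return chain;
}

int main() {
    Node a{"setitem", 1}, b{"setitem", 2}, c{"setitem", 3}, d{"getitem", 4};
    b.producer = &a; a.consumers = {&b};
    c.producer = &b; b.consumers = {&c};
    d.producer = &c; c.consumers = {&d};
    for (Node* n : collect_chain(&b))
        printf("setitem #%d\n", n->id);   // 1, 2, 3
    return 0;
}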
@@ -66,13 +66,13 @@ string strip(const string& s) {
}

void KernelIR::del_scope() {
    if (father && (type=="define" || type=="func")) {
    if (father && (type=="define" || type=="func" || type=="macro")) {
        father->scope[attrs["lvalue"]].remove(this);
    }
}

void KernelIR::add_scope() {
    if (father && (type=="define" || type=="func"))
    if (father && (type=="define" || type=="func" || type=="macro"))
        father->scope[get_attr("lvalue")].push_back(this);
}

@@ -485,6 +485,7 @@ string KernelIR::to_string(int level, bool debug) {
        auto iter = attrs.find("code");
        ASSERT(iter != attrs.end()) << attrs << type << father;
        s << iter->second << "\n";
        has_bc = attrs.count("has_bc");
    } else {
        s << "\n";
    }
@@ -590,6 +591,10 @@ KernelIR::KernelIR(const string& src, bool raw) {
            if (i==start && k==end) {
                attrs["code"] = src;
                type = "macro";
                auto v = split(src, " ", 3);
                ASSERT(v.size()>1);
                attrs["lvalue"] = v.at(1);
                attrs["rvalue"] = v.size()>2 ? v.at(2) : "";
                return;
            } else {
                push_back(src.substr(j, k-j), nullptr, raw);
@@ -650,6 +655,17 @@ KernelIR::KernelIR(const string& src, bool raw) {
        // func define
        if (s.size()>=2 && s.back()=='}') {
            int l = s.find("{");
            ASSERT(l != string::npos);
            if (startswith(s, "namespace ")) {
                // namespace xxx {...}
                // l
                attrs["code"] = s.substr(0, l);
                attrs["has_bc"] = "1";
                type = "";
                i = j + l;
                end--;
                continue;
            }
            int ll = s.rfind("(", l);
            int rr = s.rfind(")", l);
            // if () not found, maybe src like this:
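The new macro branch in the KernelIR constructor above records a #define line's name as "lvalue" and the remainder as "rvalue" by splitting on spaces into at most three pieces. A self-contained C++ sketch of that split; split3 is an illustrative stand-in for Jittor's split helper, whose exact semantics I am assuming here:

#include <cstdio>
#include <string>
#include <vector>

// Split `s` on single spaces into at most `limit` pieces; the last piece
// keeps the remainder untouched, mirroring split(src, " ", 3) in the source.
static std::vector<std::string> split3(const std::string& s, size_t limit) {
    std::vector<std::string> out;
    size_t pos = 0;
    while (out.size() + 1 < limit) {
        size_t sp = s.find(' ', pos);
        if (sp == std::string::npos) break;
        out.push_back(s.substr(pos, sp - pos));
        pos = sp + 1;
    }
    out.push_back(s.substr(pos));
    return out;
}

int main() {
    std::string src = "#define FULL_MASK 0xffffffff";
    auto v = split3(src, 3);
    std::string lvalue = v.at(1);                       // "FULL_MASK"
    std::string rvalue = v.size() > 2 ? v.at(2) : "";   // "0xffffffff"
    printf("lvalue=%s rvalue=%s\n", lvalue.c_str(), rvalue.c_str());
    return 0;
}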
@@ -63,6 +63,12 @@ void PassManager::run_passes() {
    LOGvvvv << "KernelIR:\n" << ir.to_string();
    if (oc->op->ops.size() == 1 && oc->op->ops[0]->name() == string("array")) {
        ir.remove_all_unused();
        if (oc->op->flags.get(NodeFlags::_cuda)) {
            ir.children.back()->erase();
            string type = oc->op->ops[0]->outputs().front()->dtype().to_cstring();
            ir.push_back("kernel<<<1,1>>>(op0_outputp, op0->ptr<"+type+">()[0]);");
            ir.push_back("__global__ static void kernel(jittor::"+type+"* xp, jittor::"+type+" x) { xp[0] = x; } ", &ir.before, true);
        }
        return;
    }
    run_pass<MarkRawPass>();
@@ -392,14 +392,10 @@ void ConvTuner::forwardTune(FusedOp* fop) {
            continue;
        auto kh = w->shape[wformat.find("h")];
        auto kw = w->shape[wformat.find("w")];
        if (kh != kw) {
            LOGvvvv << "TODO: relay conv_backward_w when kh != kw" << kh << kw;
            continue;
        }
        LOGvvvv << x << y << kh << stride << padding << dilation << groups << xformat << wformat << yformat;
        auto make_conv_w = get_op_info(relay_conv_name)
            .get_constructor<VarPtr, Var*, Var*, int, int, int, int, int, string, string, string>();
        rvar = make_conv_w(x, y, kh, stride, padding, dilation, groups, xformat, wformat, yformat);
            .get_constructor<VarPtr, Var*, Var*, int, int, int, int, int, int, string, string, string>();
        rvar = make_conv_w(x, y, kh, kw, stride, padding, dilation, groups, xformat, wformat, yformat);
    }

    LOGvvvv << relay_conv_name << "output:" << rvar;
@@ -191,6 +191,7 @@ void parallel_compile_all_ops(vector<int>& queue, vector<int>& range, FusedOp& f
        FusedOp& fused_op = *fop_needs_compile[-rid-1];
        op = &fused_op;
        LOGvv << "Compile FusedOp:" << op;
        LOGV(11) << "FusedOps:" << fused_op.ops;
        fused_op.context = new FusedOpContext();
        fused_op.context->setup(&fused_op);
        fused_op.do_prepare();
@@ -238,11 +239,11 @@ void parallel_compile_all_ops(vector<int>& queue, vector<int>& range, FusedOp& f
            }
        }
    }; // end of threads.launch_all
    int active_threads = std::min(thread_num, (int)op_needs_compile.size());
    threads.launch_all(active_threads, func);

    typedef std::chrono::high_resolution_clock Time;
    auto start = Time::now();
    int active_threads = std::min(thread_num, (int)op_needs_compile.size());
    threads.launch_all(active_threads, func);
    int prev_i = 0;
    bool change_line = false;
    int sleep_us = 10;
@@ -298,4 +299,4 @@ void parallel_compile_all_ops(vector<int>& queue, vector<int>& range, FusedOp& f
}


} // jittor
} // jittor
@@ -691,4 +691,54 @@ DEF_IS(GradCallback, T) from_py_object(PyObject* obj) {
    return func;
}

struct VarSlices;
// Slice
DEF_IS(VarSlices, bool) is_type(PyObject* obj) {
    return PyTuple_CheckExact(obj) ||
        PyLong_CheckExact(obj) ||
        PySlice_Check(obj) ||
        (Py_TYPE(obj) == &PyEllipsis_Type) ||
        obj == Py_None ||
        is_type<VarHolder*>(obj);
}

template<class T>
void load_var_slice(PyObject* obj, T* var_slice, vector<unique_ptr<VarHolder>>& holders) {
    if (PyLong_CheckExact(obj)) {
        var_slice->set_int(PyLong_AsLong(obj));
    } else
    if (PySlice_Check(obj)) {
        var_slice->slice = from_py_object<decltype(var_slice->slice)>(obj);
    } else
    if (Py_TYPE(obj) == &PyEllipsis_Type) {
        var_slice->set_ellipsis();
    } else
    if (obj == Py_None) {
        var_slice->set_none();
    } else {
        holders.emplace_back();
        auto* vh = from_py_object<VarHolder*>(obj, holders.back());
        auto vv = (Var**)vh;
        var_slice->set_var(vv[0]);
    }
}

DEF_IS(VarSlices, T) from_py_object(PyObject* obj, vector<unique_ptr<VarHolder>>& holders) {
    if (PyTuple_CheckExact(obj)) {
        auto size = Py_SIZE(obj);
        T vs(size);
        auto arr = PySequence_Fast_ITEMS(obj);
        for (int i=0; i<size; i++) {
            auto oi = arr[i];
            load_var_slice(oi, vs.slices+i, holders);
        }
        return vs;
    } else {
        T vs(1);
        load_var_slice(obj, vs.slices, holders);
        return vs;
    }
}


} // jittor
@@ -0,0 +1,41 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. All Rights Reserved.
// Authors: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "common.h"
#include "misc/nano_vector.h"

namespace jittor {

void cuda_loop_schedule(NanoVector o_shape, int* masks, int* tdims);

JIT_TEST(cuda_loop_schedule) {
    auto check = [&](const vector<int64>& shape, const vector<int>& masks, vector<int> tdims={}) {
        int masks2[shape.size()];
        int tdims2[6];
        cuda_loop_schedule(shape, masks2, tdims2);
        while (tdims.size() < 6) tdims.push_back(1);
        for (int i=0; i<shape.size(); i++)
            ASSERT(masks2[i] == masks[i]) << i << shape << masks << vector<int>(masks2, masks2+shape.size());
        for (int i=0; i<6; i++)
            ASSERT(tdims.at(i)==tdims2[i]) << tdims << vector<int>(tdims2, tdims2+6);
    };
    check({0}, {1}, {0,1,1,1,1,1});
    check({2,2,2,2}, {8, 4, 2, 1}, {2,2,2,2,1,1});
    check({2048,1024}, {8, 1}, {1024,1,1,2048,1,1});
    check({2048,1025}, {8, 1+(1<<6)}, {1024,1,1,2048,1,1});
    check({2048,3025}, {8, 1+(1<<6)}, {1024,1,1,2048,1,1});
    check({2048,4425}, {16, 1+8+(1<<6)}, {1024,1,1,5,2048,1});
    check({2048, 2048,4425}, {0, 16, 1+8+(1<<6)}, {1024,1,1,5,2048,1});
    check({3,3,3,4425}, {0, 32, 16, 1+8+(1<<6)}, {1024,1,1,5,3,3});
    check({3,3,3,4425, 3,3}, {0, 32, 16, 8+4+(1<<6), 2, 1}, {3,3,64,70,3,3});
    check({3,3,3,12, 9,9}, {32, 16, 8, 4, 2, 1}, {9,9,12,3,3,3});
    check({3,3,3,13, 9,9}, {32, 16, 8, 4+64, 2, 1}, {9,9,12,3,3,3});
    check({3,3,3,13*4, 9,9}, {0, 32, 16, 8+4+64, 2, 1}, {9,9,12,5,3,3});
    check({3,3,3,100, 3,3}, {32, 16, 8, 4+64, 2, 1}, {3,3,64,3,3,3});
    check({3,3,3,400, 3,3}, {0, 32, 16, 8+4+64, 2, 1}, {3,3,64,7,3,3});
}

} // jittor
@@ -305,9 +305,9 @@ int system_popen(const char* cmd) {

void system_with_check(const char* cmd) {
    auto ret = system_popen(cmd);
    CHECK(ret!=-1) << "Run cmd failed:" << cmd <<
    CHECK(ret>=0 && ret<=256) << "Run cmd failed:" << cmd <<
        "\nreturn -1. This might be an overcommit issue or out of memory."
        << "Try : echo 1 >/proc/sys/vm/overcommit_memory";
        << "Try : sudo sysctl vm.overcommit_memory=1";
    CHECKop(ret,==,0) << "Run cmd failed:" << cmd;
}

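The widened check above distinguishes "the process could not be spawned" from "the command ran and failed". A small standalone C++ sketch of the same distinction using plain popen/pclose; run_cmd is my own helper name, not Jittor's system_popen:

#include <cstdio>
#include <sys/wait.h>

// Run a command, return -1 if the process could not be spawned,
// otherwise the command's exit status (0 means success).
static int run_cmd(const char* cmd) {
    FILE* pipe = popen(cmd, "r");
    if (!pipe) return -1;                     // spawn failed (fork / overcommit)
    char buf[256];
    while (fgets(buf, sizeof(buf), pipe))     // drain the command's output
        fputs(buf, stdout);
    int status = pclose(pipe);
    if (status == -1) return -1;
    return WIFEXITED(status) ? WEXITSTATUS(status) : -1;
}

int main() {
    int ret = run_cmd("echo hello");
    if (ret < 0)  { fprintf(stderr, "could not run command\n"); return 1; }
    if (ret != 0) { fprintf(stderr, "command failed with %d\n", ret); return 1; }
    return 0;
}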
@@ -0,0 +1,36 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. All Rights Reserved.
// Authors: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "var_slices.h"
#include "var.h"

namespace jittor {

std::ostream& operator<<(std::ostream& os, const VarSlices& vs) {
    os << '[';
    for (int i=0; i<vs.n; i++)
        os << vs.slices[i] << ",";
    return os << ']';
}

std::ostream& operator<<(std::ostream& os, const VarSlice& s) {
    if (s.is_var()) return os << s.var->dtype() << s.var->shape;
    if (s.is_ellipsis()) return os << "...";
    if (s.is_slice()) return os << s.slice;
    if (s.is_int()) return os << s.i;
    return os << "-";
}

std::ostream& operator<<(std::ostream& os, const Slice& s) {
    if (!(s.mask & 1)) os << s.start;
    os << ':';
    if (!(s.mask & 2)) os << s.stop;
    os << ':';
    if (!(s.mask & 4)) os << s.step;
    return os;
}

} // jittor
@@ -0,0 +1,67 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. All Rights Reserved.
// Authors: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "common.h"
#include "misc/nano_vector.h"

namespace jittor {

struct Slice;

union VarSlice {
    Slice slice;
    Var* var;
    int64 i;
    inline bool is_var() const { return slice.mask == -1; }
    inline bool is_ellipsis() const { return slice.mask == -2; }
    inline bool is_none() const { return slice.mask == -3; }
    inline bool is_int() const { return slice.mask == -4; }
    inline bool is_slice() const { return slice.mask >= 0; }
    inline void set_var(Var* v) { slice.mask = -1; var = v; }
    inline void set_ellipsis() { slice.mask = -2; }
    inline void set_none() { slice.mask = -3; }
    inline void set_int(int64 v) { slice.mask = -4; i = v; }
};

struct VarSlices {
    VarSlice* slices;
    int n;
    inline VarSlices() : slices(nullptr) {}
    inline VarSlices(int n) : slices(new VarSlice[n]), n(n) {}
    inline ~VarSlices() {if (slices) delete[] slices;}
    inline VarSlices(VarSlices&& other) : slices(other.slices), n(other.n) {
        other.slices = nullptr;
    }
    inline VarSlices(const VarSlices& other) : slices(new VarSlice[other.n]), n(other.n) {
        for (int i=0; i<n; i++)
            slices[i] = other.slices[i];
    }
    inline void operator=(VarSlices&& other) {
        if (slices) delete[] slices;
        n = other.n;
        slices = other.slices;
        other.slices = nullptr;
    }
    inline void operator=(const VarSlices& other) {
        if (slices) delete[] slices;
        slices = new VarSlice[other.n];
        n = other.n;
        for (int i=0; i<n; i++)
            slices[i] = other.slices[i];
    }
};

std::ostream& operator<<(std::ostream& os, const VarSlices& vs);
std::ostream& operator<<(std::ostream& os, const VarSlice& s);
std::ostream& operator<<(std::ostream& os, const Slice& s);

// @pyjt(_print_var_slice)
inline void _print_var_slice(VarSlices&& vs) {
    LOGi << vs;
}

} // jittor
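VarSlice above is a tagged union that reuses Slice's mask field as the discriminant: -1 marks a Var, -2 an ellipsis, -3 None, -4 a plain integer index, and any non-negative mask means a real slice. A small self-contained C++ sketch of the same idea; MiniSlice and MiniVarSlice are illustrative stand-ins, not Jittor types, and the sketch assumes (as the real helpers do) that the mask field does not overlap the int/pointer members:

#include <cstdint>
#include <cstdio>

// Mirror of the VarSlice layout: the tag lives in slice.mask, which sits past
// the first 8 bytes, so writing the int64 (or pointer) member leaves it intact.
struct MiniSlice {
    int64_t start, stop, step;
    int64_t mask;   // for a real slice: bit0/1/2 set => start/stop/step omitted
};

union MiniVarSlice {
    MiniSlice slice;
    int64_t i;
    bool is_int()   const { return slice.mask == -4; }
    bool is_slice() const { return slice.mask >= 0; }
    void set_int(int64_t v) { slice.mask = -4; i = v; }
};

int main() {
    MiniVarSlice a;
    a.set_int(3);                         // tag -4 => plain integer index
    printf("is_int=%d value=%lld\n", (int)a.is_int(), (long long)a.i);

    MiniVarSlice b;
    b.slice = MiniSlice{0, 10, 2, 0};     // mask 0 => real slice 0:10:2
    printf("is_slice=%d start=%lld step=%lld\n",
           (int)b.is_slice(), (long long)b.slice.start, (long long)b.slice.step);
    return 0;
}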