# Mirror of https://github.com/Jittor/Jittor (Python, 206 lines, 7.2 KiB)
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
#     Guowei Yang <471184555@qq.com>
#     Guoye Yang <498731903@qq.com>
#     Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
|
|
import jittor as jt
|
|
import numpy as np
|
|
|
|
def argmax_pool(x, size, stride, padding=0):
    """Max-pool axes 2 and 3 of a 4-D input with a hand-written CPU kernel.

    Each output element is the maximum over the in-bounds input elements of a
    ``size`` x ``size`` window whose origin is
    ``(k*stride - padding, l*stride - padding)``.  Window positions that fall
    outside the input are skipped rather than read.  A window with no
    in-bounds position at all yields 0.

    The backward kernel zeroes the input gradient and then, for every output
    element, adds its gradient to the first window position (rows scanned
    before columns) whose input value equals the pooled value.

    ``size``, ``stride`` and ``padding`` are interpolated as text into the
    kernel source, so they must format as plain integer literals.

    Args:
        x: input variable; the kernels index it with four subscripts.
        size: side length of the square pooling window.
        stride: step between consecutive windows on both pooled axes.
        padding: number of implicit border positions on each side of the
            pooled axes.

    Returns:
        The variable produced by ``jt.code``.  Its shape equals ``x.shape``
        except that axes 2 and 3 become
        ``(extent + 2*padding - size)//stride + 1``; its dtype is ``x.dtype``.
    """
    y_shape = list(x.shape)
    y_shape[2] = (x.shape[2] + padding * 2 - size) // stride + 1
    y_shape[3] = (x.shape[3] + padding * 2 - size) // stride + 1

    # The window spans exactly `size` positions starting at its origin, which
    # matches the output-shape formula above for odd and even sizes alike.
    # (Scanning centre +/- size/2 would cover size+1 positions for even sizes.)
    # `hit` records whether an in-bounds element has been seen, so the maximum
    # is seeded from real data instead of a possibly out-of-range centre read.
    y = jt.code(y_shape, x.dtype, [x],
        cpu_src=f'''
            for (int i=0; i<out_shape0; i++)
            for (int j=0; j<out_shape1; j++)
            for (int k=0; k<out_shape2; k++)
            for (int l=0; l<out_shape3; l++) {{
                int kx=k*{stride}-{padding};
                int ky=l*{stride}-{padding};
                int hit=0;
                @out(i,j,k,l) = 0;
                for (int p=kx;p<kx+{size};p++)
                for (int q=ky;q<ky+{size};q++)
                    if (p>=0 && q>=0 && p<in0_shape2 && q<in0_shape3)
                        if (!hit || @out(i,j,k,l) < @in0(i,j,p,q)) {{
                            @out(i,j,k,l) = @in0(i,j,p,q);
                            hit=1;
                        }}
            }}
        ''',
        cpu_grad_src = [f'''
            for (int i=0; i<out_shape0; i++)
            for (int j=0; j<out_shape1; j++)
            for (int k=0; k<out_shape2; k++)
            for (int l=0; l<out_shape3; l++) @out(i,j,k,l) = 0;

            for (int i=0; i<pout_shape0; i++)
            for (int j=0; j<pout_shape1; j++)
            for (int k=0; k<pout_shape2; k++)
            for (int l=0; l<pout_shape3; l++) {{
                int kx=k*{stride}-{padding};
                int ky=l*{stride}-{padding};
                int bo=1;
                for (int p=kx;p<kx+{size} && bo;p++)
                for (int q=ky;q<ky+{size} && bo;q++)
                    if (p>=0 && q>=0 && p<in0_shape2 && q<in0_shape3)
                        if (@pout(i,j,k,l) == @in0(i,j,p,q)) {{
                            @out(i,j,p,q) += @dout(i,j,k,l);
                            bo=0;
                        }}
            }}
        '''])
    return y
|
|
|
|
def concat(arr, dim):
    """Concatenate a sequence of variables along axis ``dim``.

    Every input is re-indexed into the full output shape, shifted by the
    combined extent of the inputs before it along ``dim``, and the re-indexed
    pieces are added together.
    NOTE(review): this relies on ``reindex`` producing zeros outside each
    input's own range -- presumably its default; confirm.

    Args:
        arr: non-empty sequence of variables exposing ``shape`` and
            ``reindex``.  The rank is taken from ``arr[0]``.
        dim: axis to concatenate along; a negative value counts from the last
            axis of ``arr[0]``.

    Returns:
        The accumulated variable, whose extent along ``dim`` is the sum of the
        inputs' extents along that axis.

    Raises:
        ValueError: if ``arr`` is empty.
    """
    # TODO: low performance when concat lots of vars
    if len(arr) == 0:
        raise ValueError("concat expects a non-empty sequence of variables")
    ndim = len(arr[0].shape)
    if dim < 0:
        dim += ndim
    total_dim = sum(a.shape[dim] for a in arr)
    # ugly fix for preventing large fused op
    break_fusion = len(arr) >= 10
    indexes = [f"i{i}" for i in range(ndim)]
    cdim = 0
    s = None
    for a in arr:
        shape = list(a.shape)
        shape[dim] = total_dim
        # Output position i along `dim` reads input position i - cdim.
        indexes[dim] = f"i{dim}-{cdim}"
        b = a.reindex(shape, indexes)
        if break_fusion:
            b.stop_fuse()
        if s is None:
            s = b
        else:
            s += b
        cdim += a.shape[dim]
    return s
|
|
|
|
def check(bc):
    """Validate that a set of shapes broadcast together and return the result.

    Args:
        bc: sequence of shapes, already padded to equal length, convertible
            to a NumPy array.

    Returns:
        NumPy array holding the per-axis maximum of the given shapes.

    Raises:
        Exception: if any entry is neither 1 nor the maximum of its axis.
    """
    shapes = np.array(bc)
    target = shapes.max(0)
    # An entry is compatible when it is 1 or already equals the axis maximum.
    incompatible = np.logical_and(shapes != 1, shapes != target)
    if incompatible.any():
        raise Exception("Shape not match.")
    return target
|
|
|
|
def slice_var_index(x, slices):
    """Translate an indexing expression on ``x`` into ``x.reindex`` arguments.

    ``slices`` is a single index item or a tuple of them, one per leading axis
    of ``x``; axes without an item behave like ``:``.  Supported items:

    * integer (Python ``int`` or NumPy integer): selects one position, a
      negative value counts from the end, and the axis is dropped from the
      output shape;
    * ``slice``: negative ``start``/``stop`` are wrapped once by the axis
      extent and then clamped to ``[0, extent]``; ``step`` is used as given;
    * ``jt.Var``, ``np.ndarray`` or ``list``: array indices.  All of them are
      broadcast to one common shape, which is inserted into the output shape
      at the position of the first array index.

    A single boolean ``jt.Var`` is treated as a mask: the function returns the
    1-tuple ``(slices[0].where(),)`` immediately, without touching ``x``.

    In every other case ``x.stop_fuse()`` is called before returning.

    Args:
        x: variable being indexed; only ``shape`` and ``stop_fuse`` are used.
        slices: index item or tuple of index items as described above.

    Returns:
        ``(out_shape, out_index, 0, [], extras)``, where ``out_shape`` is the
        list of output extents, ``out_index`` holds one index expression
        string per axis of ``x`` and ``extras`` lists the broadcast array
        indices referenced as ``@e<k>(...)``.  The caller unpacks the tuple
        into ``x.reindex``.  NOTE(review): the ``0`` and ``[]`` are presumably
        the overflow value and overflow conditions of ``reindex`` -- confirm
        against its signature.

    Raises:
        Exception: if the array indices cannot be broadcast together (raised
            by ``check``) or an index item has an unsupported type.
    """
    if not isinstance(slices, tuple):
        slices = (slices,)
    # Testing the length first keeps an empty tuple index from raising
    # IndexError on slices[0].
    if len(slices) == 1 and isinstance(slices[0], jt.Var):
        if slices[0].dtype == "bool":
            return (slices[0].where(),)

    # Gather the shapes of all array-like indices and the largest rank.
    bc = []
    ml = -1
    for s in slices:
        if isinstance(s, jt.Var):
            # list() so the left-padding below works whatever sequence type
            # Var.shape returns.
            shape = list(s.shape)
        elif isinstance(s, np.ndarray):
            shape = list(s.shape)
        elif isinstance(s, list):
            shape = list(np.array(s).shape)
        else:
            continue
        ml = max(ml, len(shape))
        bc.append(shape)
    # Left-pad shorter shapes with 1s so all shapes have rank `ml`.
    bc = [(ml - len(shape)) * [1] + shape for shape in bc]
    if len(bc) >= 1:
        bc_shape = check(bc)
        target = bc_shape.tolist()
        ss = []
        for s in slices:
            if isinstance(s, (np.ndarray, list)):
                ss.append(jt.array(s).broadcast(target))
            elif isinstance(s, jt.Var):
                ss.append(s.broadcast(target))
            else:
                ss.append(s)
        slices = ss

    out_shape = []
    out_index = []
    shape = x.shape
    cnt_list = 0
    extras_idx = []
    extras = []
    for i in range(len(shape)):
        s = slices[i] if i < len(slices) else slice(None)
        sp = shape[i]
        # Position of the next output axis this input axis would map to.
        j = len(out_shape)
        if isinstance(s, (int, np.integer)):
            if s < 0:
                s += sp
            out_index.append(str(s))
        elif isinstance(s, slice):
            if s == slice(None):
                out_shape.append(sp)
                out_index.append(f"i{j}")
                continue
            start = 0 if s.start is None else s.start
            stop = sp if s.stop is None else s.stop
            step = 1 if s.step is None else s.step
            if start < 0:
                start += sp
            if stop < 0:
                stop += sp
            # Clamp to the axis so out-of-range bounds do not inflate the
            # output extent beyond the data that actually exists.
            start = min(max(start, 0), sp)
            stop = min(max(stop, 0), sp)
            # NOTE: the extent is at least 1 even when stop <= start.
            # NOTE(review): negative steps get no reversed defaults here.
            out_shape.append(1 + int(max(0, (stop - start - 1) // step)))
            out_index.append(f"{start}+i{j}*{step}")
        elif isinstance(s, jt.Var):
            if cnt_list == 0:
                # The first array index reserves the broadcast axes in the
                # output; later array indices reuse the same axes.
                extras_idx = [f"i{j + idx}" for idx in range(len(bc_shape))]
                out_shape += bc_shape.tolist()
            out_index.append(f"@e{cnt_list}(" + ",".join(extras_idx) + ")")
            cnt_list += 1
            extras.append(s)
        else:
            raise Exception(f"Not support slice {s}")
    if len(out_shape) == 0:
        out_shape = [1]
    # Stop fuse both input and output, prevent recompile
    x.stop_fuse()
    return (out_shape, out_index, 0, [], extras)
|
|
|
|
def slice_var(x, slices):
    """Return the part of ``x`` selected by ``slices`` (read indexing).

    The index expression is translated by ``slice_var_index`` and the
    resulting tuple is unpacked into ``x.reindex``.  ``stop_fuse`` is called
    on both ``x`` and the result.

    Args:
        x: variable to read from.
        slices: index item or tuple of index items accepted by
            ``slice_var_index``.

    Returns:
        The re-indexed variable, as returned by its ``stop_fuse()`` call.
    """
    args = slice_var_index(x, slices)
    x.stop_fuse()
    selected = x.reindex(*args)
    return selected.stop_fuse()
|
|
|
|
def setitem(x, slices, value):
    """Assign ``value`` to the region of ``x`` selected by ``slices``.

    The selected region is read with ``reindex`` to obtain its shape.
    ``value`` and a constant 1 are broadcast to that shape and scattered back
    to ``x.shape`` with ``reindex_reduce("add", ...)``.  The scattered ones
    act as a mask: where it is non-zero the scattered value is taken,
    elsewhere the old content of ``x`` is kept.  The result is written back
    with ``x.assign``.

    Args:
        x: variable to update.
        slices: index item or tuple of index items accepted by
            ``slice_var_index``, including a single boolean mask variable.
        value: value to store; broadcast to the selected region.

    Returns:
        ``x`` itself, after the assignment.
    """
    reindex_args = slice_var_index(x, slices)
    if len(reindex_args) == 1:
        # Boolean-mask path: slice_var_index returned only the where() result,
        # which has no index expressions to scatter with.  Feed the index
        # variables back in as ordinary array indices to get the full tuple.
        # NOTE(review): assumes where() yields one index variable per masked
        # axis -- confirm.
        index_vars = reindex_args[0]
        if isinstance(index_vars, jt.Var):
            index_vars = (index_vars,)
        reindex_args = slice_var_index(x, tuple(index_vars))
    # Same index expressions and extras as the read, but targeting x.shape.
    reindex_reduce_args = (x.shape, reindex_args[1]) + reindex_args[3:]
    xslice = x.stop_fuse().reindex(*reindex_args).stop_fuse()
    value = jt.broadcast(value, xslice)
    one = jt.broadcast(1, xslice)
    mask = one.reindex_reduce("add", *reindex_reduce_args)
    data = value.reindex_reduce("add", *reindex_reduce_args)
    # Stop fuse both input and output, prevent recompile
    out = mask.ternary(data, x).stop_fuse()
    x.assign(out)
    return x
|
|
|
|
# Install the helpers above on jt.Var so that `v[idx]` and `v.slice_var(idx)`
# call slice_var, and `v[idx] = value` calls setitem.
jt.Var.__getitem__ = slice_var
jt.Var.slice_var = slice_var
jt.Var.__setitem__ = setitem
|
|
|
|
def adam(model, loss, lr=3e-4, betas=(0.9, 0.999), eps=1e-8):
    """Build one Adam update step for the variables found under ``model``.

    For every variable returned by ``jt.find_vars(model)`` two state
    variables (first and second moment) are created with ``jt.make_var``
    inside a variable scope named ``"<model>_adam"``.  They are updated from
    the gradient of ``loss``, and the variable is decremented in place by the
    bias-corrected step.  A step counter created in the same scope is
    incremented once per call.

    Args:
        model: scope name; must be a ``str`` because it is joined with
            ``"adam"`` to name the optimizer scope.
            NOTE(review): presumably the name the model's variables were
            created under -- confirm against ``jt.find_vars``.
        loss: variable to differentiate.
        lr: learning rate.
        betas: pair ``(beta1, beta2)`` of decay rates for the first and
            second moment; only indexed, never modified.
        eps: constant added to the denominator.

    Returns:
        None; the parameters are updated through in-place operators.
    """
    beta1, beta2 = betas[0], betas[1]
    ps = jt.find_vars(model)
    gs = jt.grad(loss, ps)
    with jt.var_scope('_'.join([model, 'adam']), unique=True):
        adam_step = jt.make_var([1], init=jt.zeros)
        adam_step += 1
        # The bias-corrected step size depends only on the step counter, so
        # it is built once and shared by all parameters.
        step_size = lr * jt.sqrt(1 - beta2 ** adam_step) / (1 - beta1 ** adam_step)
        for p, g in zip(ps, gs):
            m = jt.make_var(p.shape, init=jt.zeros)
            v = jt.make_var(p.shape, init=jt.zeros)

            m.assign(beta1 * m + (1 - beta1) * g)
            v.assign(beta2 * v + (1 - beta2) * g * g)
            p -= m * step_size / (jt.sqrt(v) + eps)
|
|
|