mirror of https://github.com/Jittor/Jittor
320 lines
10 KiB
Python
320 lines
10 KiB
Python
# ***************************************************************
|
|
# Copyright (c) 2021 Jittor. All Rights Reserved.
|
|
# Maintainers: Dun Liang <randonlang@gmail.com>.
|
|
# This file is subject to the terms and conditions defined in
|
|
# file 'LICENSE.txt', which is part of this source code package.
|
|
# ***************************************************************
|
|
import unittest
|
|
import jittor as jt
|
|
import numpy as np
|
|
from .test_core import expect_error
|
|
from .test_grad import ngrad
|
|
|
|
def conv(x, w):
    """Express a 2-D convolution as reindex -> broadcast -> multiply -> sum.

    ``x.shape`` is unpacked as (N, H, W, C) and ``w.shape`` as
    (Kh, Kw, C, Kc); the two channel counts must agree.

    Returns a pair ``(y, yy)``:
      * ``yy`` is the elementwise product laid out over
        [N, H+Kh-1, W+Kw-1, Kh, Kw, C, Kc];
      * ``y`` is ``yy`` summed over axes 3, 4 and 5 (Kh, Kw, C).
    """
    batch, height, width, in_ch = x.shape
    k_rows, k_cols, w_in_ch, out_ch = w.shape
    assert in_ch==w_in_ch
    expanded_shape = [
        batch, height+k_rows-1, width+k_cols-1, k_rows, k_cols, in_ch, out_ch,
    ]
    # One index expression per axis of x, written in terms of the axes
    # (i0..i6) of the expanded tensor.
    source_index = [
        'i0',     # batch index
        'i1-i3',  # output row minus kernel row
        'i2-i4',  # output column minus kernel column
        'i5',     # input channel
    ]
    xx = x.reindex(expanded_shape, source_index)
    yy = xx * w.broadcast_var(xx)
    return yy.sum([3,4,5]), yy
|
|
|
|
def conv_naive(x, w):
    """Loop-based NumPy reference for the convolution built by ``conv``.

    ``x`` is indexed as [N, H, W, C] and ``w`` as [Kh, Kw, C, Kc].
    Returns a float array of shape [N, H+Kh-1, W+Kw-1, Kc] where

        y[n, r, c, o] = sum over kr, kc, ch of
                        x[n, r-kr, c-kc, ch] * w[kr, kc, ch, o]

    and terms whose source position falls outside ``x`` are skipped.
    """
    batch, height, width, in_ch = x.shape
    k_rows, k_cols, w_in_ch, out_ch = w.shape
    assert in_ch==w_in_ch
    out_rows, out_cols = height+k_rows-1, width+k_cols-1
    y = np.zeros([batch, out_rows, out_cols, out_ch])
    for n in range(batch):
        for r in range(out_rows):
            for c in range(out_cols):
                for kr in range(k_rows):
                    src_r = r-kr
                    # Source row outside the input: nothing to accumulate.
                    if not 0<=src_r<height:
                        continue
                    for kc in range(k_cols):
                        src_c = c-kc
                        if not 0<=src_c<width:
                            continue
                        for ch in range(in_ch):
                            for oc in range(out_ch):
                                y[n, r, c, oc] += x[n, src_r, src_c, ch] * w[kr, kc, ch, oc]
    return y
|
|
|
|
def conv_transpose(x, w):
    """Express a stride-2 transposed convolution as reindex -> broadcast -> sum.

    ``x.shape`` is unpacked as (N, H, W, C) and ``w.shape`` as
    (Kh, Kw, C, Kc); the two channel counts must agree.

    Returns a pair ``(y, yy)``:
      * ``yy`` is the elementwise product laid out over
        [N, H*2+Kh-1, W*2+Kw-1, Kh, Kw, C, Kc];
      * ``y`` is ``yy`` summed over axes 3, 4 and 5 (Kh, Kw, C).
    """
    batch, height, width, in_ch = x.shape
    k_rows, k_cols, w_in_ch, out_ch = w.shape
    assert in_ch==w_in_ch
    expanded_shape = [
        batch, height*2+k_rows-1, width*2+k_cols-1, k_rows, k_cols, in_ch, out_ch,
    ]
    # One index expression per axis of x, written in terms of the axes
    # (i0..i6) of the expanded tensor.
    source_index = [
        'i0',         # batch index
        '(i1-i3)/2',  # (output row - kernel row) / 2
        '(i2-i4)/2',  # (output column - kernel column) / 2
        'i5',         # input channel
    ]
    # Passed positionally after the value 0 -- presumably reindex's overflow
    # value and extra overflow conditions (TODO confirm against the reindex
    # docs); conv_transpose_naive skips exactly the positions where these
    # expressions are non-zero.
    odd_offset = ['(i1-i3)%2', '(i2-i4)%2']
    xx = x.reindex(expanded_shape, source_index, 0, odd_offset)
    yy = xx * w.broadcast_var(xx)
    return yy.sum([3,4,5]), yy
|
|
|
|
def conv_transpose_naive(x, w):
    """Loop-based NumPy reference for the stride-2 transposed convolution.

    ``x`` is indexed as [N, H, W, C] and ``w`` as [Kh, Kw, C, Kc].
    Returns a float array of shape [N, H*2+Kh-1, W*2+Kw-1, Kc].  A term
    x[n, (r-kr)//2, (c-kc)//2, ch] * w[kr, kc, ch, o] contributes to
    y[n, r, c, o] only when both offsets ``r-kr`` and ``c-kc`` are even and
    the halved position lies inside ``x``.
    """
    batch, height, width, in_ch = x.shape
    k_rows, k_cols, w_in_ch, out_ch = w.shape
    assert in_ch==w_in_ch
    out_rows, out_cols = height*2+k_rows-1, width*2+k_cols-1
    y = np.zeros([batch, out_rows, out_cols, out_ch])
    for n in range(batch):
        for r in range(out_rows):
            for c in range(out_cols):
                for kr in range(k_rows):
                    # divmod floors like // and % do, including for
                    # negative offsets.
                    src_r, r_rem = divmod(r-kr, 2)
                    if r_rem or not 0<=src_r<height:
                        continue
                    for kc in range(k_cols):
                        src_c, c_rem = divmod(c-kc, 2)
                        if c_rem or not 0<=src_c<width:
                            continue
                        for ch in range(in_ch):
                            for oc in range(out_ch):
                                y[n, r, c, oc] += x[n, src_r, src_c, ch] * w[kr, kc, ch, oc]
    return y
|
|
|
|
|
|
def is_fused(x):
    """Return True when the text from ``x.debug_msg()`` contains ``'s0'``."""
    debug_text = x.debug_msg()
    return 's0' in debug_text
|
|
|
|
def check_fused(dim):
    """Assert that every graph node with more than ``dim`` dimensions is fused.

    Calls ``jt.clean()``, dumps all graphs, and inspects each entry of
    ``graph.nodes_info``.  A node's dimension count is taken as the number
    of comma-separated fields after its last ``'['``, minus one.  The
    assertion fails (reporting ``graph.nodes_info``) when no node exceeds
    ``dim`` dimensions, or when any such node lacks the ``'s0'`` marker.
    """
    jt.clean()
    graph = jt.dump_all_graphs()
    high_dim_nodes = []
    for info in graph.nodes_info:
        shape_fields = info.split('[')[-1].split(',')
        if len(shape_fields)-1 > dim:
            high_dim_nodes.append(info)
    # Nodes with at most `dim` dimensions are deliberately not checked.
    all_fused = all('s0' in info for info in high_dim_nodes)
    found_any = len(high_dim_nodes) > 0
    assert all_fused and found_any, graph.nodes_info
|
|
|
|
def resize_and_crop(x, bbox, interpolation="nearest"):
    """Sample ``x`` at bbox-defined positions using reindex operators.

    Args:
        x: 2-D variable; its shape is unpacked as (H, W).
        bbox: 2-D variable whose shape is unpacked as (N, 4).  For each row,
            columns 0 and 2 drive the coordinate along the first axis of
            ``x`` and columns 1 and 3 the coordinate along the second axis.
        interpolation: ``"nearest"`` or ``"bilinear"``.

    Returns:
        The sampled variable, built from index variables of shape [N, H, W].

    Raises:
        ValueError: if ``interpolation`` is not one of the supported modes.
            (The previous ``raise(f"...")`` raised a plain string, which
            Python rejects with a TypeError instead of the intended message.)
    """
    # Validate first so an unsupported mode fails fast with a clear message,
    # before any graph nodes are created.
    if interpolation not in ("nearest", "bilinear"):
        raise ValueError(f"Not support {interpolation}")

    N, k = bbox.shape
    H, W = x.shape
    assert k==4
    shape = [N,H,W]
    #      fx   x  cx
    #    +------------>
    # fy | a dx |  b
    #    | dy
    #  y | -    o  -
    #    |
    # cy | c    |  d
    #    v
    img = x
    # bb[i] is column i of bbox, expanded to `shape` along the box axis.
    bb = [ bbox.reindex(shape, ["i0", str(i)]) for i in range(4) ]
    hid = jt.index(shape, 1)
    wid = jt.index(shape, 2)
    one = jt.float(1).broadcast(shape)
    # Fractional sampling coordinates along the first and second axis of img.
    x = bb[0]*jt.float(H-1)+hid*(bb[2]-bb[0])
    y = bb[1]*jt.float(W-1)+wid*(bb[3]-bb[1])
    if interpolation=="nearest":
        return img.reindex_var([x.round_int(), y.round_int()])

    # bilinear: blend the four neighbours a, b, c, d from the diagram above.
    fx, fy = x.floor_int(), y.floor_int()
    cx, cy = fx+one, fy+one
    dx, dy = x-fx, y-fy
    a = img.reindex_var([fx, fy])
    b = img.reindex_var([cx, fy])
    c = img.reindex_var([fx, cy])
    d = img.reindex_var([cx, cy])
    dnx, dny = one-dx, one-dy
    ab = dx*b + dnx*a
    cd = dx*d + dnx*c
    o = ab*dny + cd*dy
    return o
|
|
|
|
|
|
def resize_and_crop_naive(x, bbox, interpolation="nearest"):
    """Loop-based NumPy reference for ``resize_and_crop``.

    ``x`` is a 2-D array (H, W) and ``bbox`` a 2-D array (N, 4).  For box
    ``i`` and output position (j, k) the sampling coordinates are

        row = bbox[i,0]*(H-1) + j*(bbox[i,2]-bbox[i,0])
        col = bbox[i,1]*(W-1) + k*(bbox[i,3]-bbox[i,1])

    ``"nearest"`` rounds both coordinates; any other ``interpolation`` value
    takes the bilinear path.  Reads outside ``x`` contribute 0.  Returns a
    float array of shape [N, H, W].
    """
    from math import floor

    num_boxes, box_len = bbox.shape
    height, width = x.shape
    assert box_len==4
    result = np.zeros([num_boxes, height, width])

    def sample(row, col):
        # Zero outside the image, the stored value inside it.
        if 0<=row<height and 0<=col<width:
            return x[row, col]
        return 0

    use_nearest = interpolation=="nearest"
    for box in range(num_boxes):
        row_origin = bbox[box,0]*(height-1)
        row_step = bbox[box,2]-bbox[box,0]
        col_origin = bbox[box,1]*(width-1)
        col_step = bbox[box,3]-bbox[box,1]
        for j in range(height):
            src_row = row_origin+j*row_step
            for k in range(width):
                src_col = col_origin+k*col_step
                if use_nearest:
                    result[box,j,k] = sample(int(round(src_row)), int(round(src_col)))
                    continue
                # Bilinear: blend the four integer neighbours of
                # (src_row, src_col), weighted by the fractional offsets.
                lo_row, lo_col = floor(src_row), floor(src_col)
                hi_row, hi_col = lo_row+1, lo_col+1
                v_ll = sample(lo_row, lo_col)
                v_hl = sample(hi_row, lo_col)
                v_lh = sample(lo_row, hi_col)
                v_hh = sample(hi_row, hi_col)
                frac_row, frac_col = src_row-lo_row, src_col-lo_col
                inv_row, inv_col = 1-frac_row, 1-frac_col
                near_col = frac_row*v_hl + inv_row*v_ll
                far_col = frac_row*v_hh + inv_row*v_lh
                result[box,j,k] = near_col*inv_col + far_col*frac_col
    return result
|
|
|
|
class TestReindexOp(unittest.TestCase):
    """Checks ``reindex``-based constructions against NumPy references.

    Several cases also assert, through ``is_fused`` / ``check_fused``, that
    the high-dimensional intermediate variables carry the ``'s0'`` marker.
    Statement order matters in those cases: ``check_fused`` inspects the
    global graph dump, so it sees whichever variables are alive at the call.
    """

    def test_pad(self):
        """Padding via reindex: shifted copy in the middle, -1 on both sides."""
        size = 10
        lpad = 3
        rpad = 4
        a = jt.random([size])
        # Output element i0 reads a[i0-lpad]; -1 is the value expected below
        # wherever that index falls outside a.
        b = a.reindex([size+lpad+rpad], [f"i0-{lpad}"], -1)
        na, nb = jt.fetch_sync([a, b])
        assert (nb[lpad:lpad+size]==na).all()
        assert (nb[:lpad]==-1).all()
        assert (nb[-rpad:]==-1).all()

    def test_matmul(self):
        """Matrix product as two reindexes, a multiply and a sum over axis 1."""
        size = 10
        a = jt.random([size,size])
        b = jt.random([size,size])
        # cc[i0,i1,i2] = a[i0,i1] * b[i1,i2]
        cc = a.reindex([size,size,size],["i0","i1"]) * \
            b.reindex([size,size,size],["i1","i2"])
        c = cc.sum(dim=1)
        na, nb, nc = jt.fetch_sync([a, b, c])
        # The 3-D product must be marked fused; the 2-D result must not be.
        assert is_fused(cc)
        assert not is_fused(c)
        check_fused(len(a.shape))
        npc = np.matmul(na,nb)
        assert np.allclose(npc, nc)

    def test_conv(self):
        """``conv`` matches ``conv_naive`` and its 7-D intermediate is fused."""
        N,H,W,C = 3,10,10,3
        Kh, Kw, Kc = 3, 3, 4
        x = jt.random([N,H,W,C])
        w = jt.random([Kh,Kw,C,Kc])
        y, yy = conv(x, w)
        ny = y.data
        assert ny.shape == (N, H+Kh-1, W+Kw-1, Kc), (ny.shape, [N, H+Kh-1, W+Kw-1, Kc])
        assert is_fused(yy)
        check_fused(len(x.shape))
        npy = conv_naive(x.data, w.data)
        assert np.allclose(npy, ny)

    def test_conv_transpose(self):
        """``conv_transpose`` matches ``conv_transpose_naive`` and is fused."""
        N,H,W,C = 3,10,10,3
        Kh, Kw, Kc = 3, 3, 4
        x = jt.random([N,H,W,C])
        w = jt.random([Kh,Kw,C,Kc])
        y, yy = conv_transpose(x, w)
        ny = y.data
        assert is_fused(yy)
        check_fused(len(x.shape))
        npy = conv_transpose_naive(x.data, w.data)
        # On failure, report where the two results differ plus a small corner
        # of each for inspection.
        assert np.allclose(npy, ny), (np.where(np.abs(npy-ny)>1e-4), npy[0,:4,:4,0], ny[0,:4,:4,0])


    def test_conv_transpose_group(self):
        """Grouped ``jt.nn.ConvTranspose`` module agrees with ``jt.nn.conv_transpose``."""
        N,C,H,W = 3,6,10,10
        i,o,h,w = 6,2,3,3
        g = 2
        x = jt.random([N,C,H,W])
        ww = jt.random([i,o,h,w])
        ct = jt.nn.ConvTranspose(i,o*g,(h,w), groups=2, bias=False)
        # The module's weight must already have the shape of ww before it is
        # replaced, so both code paths run on the same weights.
        assert ct.weight.shape == ww.shape, (ct.weight.shape, ww.shape)
        ct.weight = ww
        y = ct(x)
        y2 = jt.nn.conv_transpose(x, ww, groups=2)
        np.testing.assert_allclose(y.data, y2.data)

    def test_conv_transpose_grad(self):
        """Gradients of ``conv_transpose`` match numerical gradients of the naive version."""
        N,H,W,C = 1,5,5,2
        Kh, Kw, Kc = 3, 3, 2
        x = jt.random([N,H,W,C])
        w = jt.random([Kh,Kw,C,Kc])
        y, yy = conv_transpose(x, w)
        # A random mask turns the output into a scalar loss with a
        # non-uniform weight per output element.
        mask = jt.random(y.shape)
        loss = (y*mask).sum()
        dx, dw = jt.grad(loss, [x, w])
        jdx, jdw = jt.fetch_sync([dx, dw])
        check_fused(len(x.shape))
        nmask = mask.data
        # Numerical reference from the sibling ngrad helper, evaluated on
        # float64 copies of the inputs.
        _, (ndx, ndw) = ngrad(lambda args: \
            (conv_transpose_naive(args[0], args[1])*nmask).sum(),
            [np.float64(x.data), np.float64(w.data)], 1e-7)
        assert np.allclose(ndx, jdx), (ndx, jdx, ndx-jdx)
        assert np.allclose(ndw, jdw), (ndw, jdw)

    def test_resize_and_crop(self):
        """``resize_and_crop`` and its gradients match the naive reference in both modes."""
        jt.set_seed(3)
        N, H, W = 4, 5, 5
        for interpolation in ["bilinear", "nearest"]:
            x = jt.random([H, W])
            # x = jt.ones([H, W])
            bbox = jt.random([N, 4])
            # bbox = jt.float([[0.51,0.71,0.61,0.81]])
            # bbox = jt.float([[0,0,1,1]])
            y = resize_and_crop(x, bbox, interpolation)
            ny = resize_and_crop_naive(x.data, bbox.data, interpolation)
            assert np.allclose(y.data, ny), (y.data, ny, x.data)

            # test grad
            mask = jt.random(y.shape)
            # mask = jt.ones(y.shape)
            nmask = mask.data
            import gc; gc.collect()
            # NOTE(review): loss is left non-scalar here while the numerical
            # reference below sums it -- presumably jt.grad treats the two
            # the same; confirm against the jt.grad documentation.
            loss = y*mask
            dx, dbbox = jt.grad(loss, [x, bbox])
            _, (ndx, ndbbox) = ngrad(lambda args: \
                (resize_and_crop_naive(args[0], args[1], interpolation)*nmask).sum(),
                [np.float64(x.data), np.float64(bbox.data)], 1e-7)
            # The forward result is re-checked after the gradient has been built.
            assert np.allclose(y.data, ny), (y.data, ny, x.data)
            # Gradients are compared with a relative tolerance of 1e-2.
            assert np.allclose(ndx, dx.data, 1e-2), (ndx, dx.data)
            assert np.allclose(ndbbox, dbbox.data, 1e-2), (ndbbox, dbbox.data)


    def test_doc(self):
        """``jt.reindex`` exposes a docstring mentioning "Reindex Operator"."""
        assert "Reindex Operator" in jt.reindex.__doc__


    def test_reindex_fuse_error(self):
        """A reindex whose index expressions read from ``extras`` syncs without raising."""
        a = jt.zeros([10,10])
        b = jt.array([1])
        # The second index expression nests a reference to the first extra
        # inside a lookup into the second one.  There is no assertion; the
        # test passes as long as sync() completes.
        c = a.reindex([8,8], ["@e0(0)", "@e1(0,i0 / @e0(0))"], extras=[b, jt.ones([10,10])])
        c.sync()
        # print(c)

    def test_reindex_wrong_op(self):
        """Syncing a reindex whose index expression uses ``//`` is expected to fail."""
        a = jt.zeros([10,10])
        b = jt.array([1])
        c = a.reindex([8,8], ["@e0(0) // 1", "@e0(0)"], extras=[b, b])
        # expect_error comes from the sibling test_core module; presumably it
        # asserts that the callable raises.
        expect_error(lambda: c.sync())
|
|
|
|
|
|
@unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found")
class TestReindexOpCuda(TestReindexOp):
    """Re-runs every inherited TestReindexOp case with ``jt.flags.use_cuda`` set to 1.

    The whole class is skipped when ``jt.compiler.has_cuda`` is falsy.
    """

    def setUp(self):
        """Turn the CUDA flag on before each test."""
        # TODO: replace to 2
        jt.flags.use_cuda = 1

    def tearDown(self):
        """Reset the CUDA flag to 0 after each test (not to its previous value)."""
        jt.flags.use_cuda = 0
|
|
|
|
# Run the unittest command-line runner when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()