add bmm_transpose, pad, and merge_loop_flags

This commit is contained in:
li-xl 2020-11-03 17:04:25 +08:00
parent 8d5acd276a
commit 06eb505eea
5 changed files with 68 additions and 4 deletions

View File

@ -76,6 +76,8 @@ def chunk(x, chunks, dim=0):
        >>> print(res[0].shape, res[1].shape)
        [5,3,3,] [5,3,3,]
    '''
    if dim<0:
        dim += x.ndim
    l = x.shape[dim]
    res = []
    if l <= chunks:
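With the two added lines, chunk now accepts a negative dim and wraps it to the corresponding positive axis before splitting. A minimal sketch of the new behavior, assuming chunk is exposed as jt.chunk like the rest of misc.py:

```python
import jittor as jt

x = jt.random((10, 3, 3))
# dim=-3 now wraps to dim=0, so this is equivalent to jt.chunk(x, 2, dim=0)
res = jt.chunk(x, 2, dim=-3)
print(res[0].shape, res[1].shape)  # [5,3,3,] [5,3,3,]
```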
@ -395,7 +397,7 @@ jt.Var.log2 = log2
def item(x):
    assert x.ndim==1 and x.shape[0]==1
-    return x.data[0]
+    return x.numpy().item()
jt.Var.item = item
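item() still requires a single-element Var; the change swaps raw .data indexing for NumPy's scalar conversion, which yields a plain Python scalar rather than a numpy value. For example:

```python
import jittor as jt

v = jt.array([1.5])
x = v.item()  # plain Python scalar via numpy().item()
assert isinstance(x, float) and x == 1.5
```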
@ -515,7 +517,7 @@ def gather(x,dim,index):
        ins.append(jt.index(index.shape,dim=i))
    ins[dim]=index
    return x.reindex(ins)
jt.Var.gather = gather

def prod(x,dim=0):
    x = jt.log(x)
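For context, gather above follows the PyTorch convention, out[i][j] = x[index[i][j]][j] for dim=0: it builds one jt.index per axis, substitutes index on the gathered axis, and performs a single reindex. A small sanity check:

```python
import jittor as jt

x = jt.array([[1, 2], [3, 4]])
index = jt.array([[0, 0], [1, 0]])
# out[i][j] = x[index[i][j]][j] for dim=0
print(x.gather(0, index))  # [[1, 2], [3, 2]]
```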

View File

@ -37,6 +37,18 @@ def matmul_transpose(a, b):
b = b.broadcast(shape)
return (a*b).sum(len(shape)-1)
def bmm_transpose(a, b):
'''
returns a * b^T
'''
if jt.flags.use_cuda:
return jt.compile_extern.cublas_ops.cublas_batched_matmul(a, b, 0, 1)
t = list(range(b.ndim))
t[-1], t[-2] = t[-2], t[-1]
return bmm(a, b.transpose(t))
def bmm(a, b):
''' batch matrix multiply,
shape of input a is [batch, n, m],
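bmm_transpose computes a @ b^T batch-wise: on CUDA it hands the transposition off to the cuBLAS batched-matmul kernel (the trailing 0, 1 flags select which operand is transposed), otherwise it swaps b's last two axes and falls back to bmm. A quick equivalence check, assuming the function is exported as jt.nn.bmm_transpose:

```python
import jittor as jt
from jittor import nn
import numpy as np

a = jt.random((4, 5, 6))
b = jt.random((4, 7, 6))
out = nn.bmm_transpose(a, b)             # shape [4, 5, 7]
ref = nn.bmm(a, b.transpose([0, 2, 1]))  # explicit transpose of b's last two axes
assert np.allclose(out.numpy(), ref.numpy(), atol=1e-5)
```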
@ -276,6 +288,14 @@ def log_softmax(x,dim=None):
def log_sigmoid(x):
    return jt.log(jt.sigmoid(x))

class Identity(Module):
    def __init__(self, *args, **kwargs):
        super(Identity, self).__init__()

    def execute(self, input):
        return input

class Dropout(Module):
    def __init__(self, p=0.5, is_train=False):
        assert p >= 0 and p <= 1, "dropout probability has to be between 0 and 1, but got {}".format(p)
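Identity mirrors torch.nn.Identity: constructor arguments are accepted and ignored, and execute returns its input untouched, which makes it a convenient drop-in placeholder when disabling a layer. For instance:

```python
import jittor as jt
from jittor import nn

m = nn.Identity(54, unused_kwarg=0.1)  # arguments are accepted but ignored
x = jt.random((2, 3))
y = m(x)                               # y is x, unchanged
```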
@ -700,6 +720,30 @@ class ConvTranspose(Module):
        return y

def pad(x,padding, mode='constant', value=0):
    assert mode in ['constant','replicate','reflect','circular'],'only support constant,replicate,reflect,circular pad'
    assert len(padding)%2==0 and len(padding)//2<=x.ndim
    padding = list(padding)
    left = [0]*(x.ndim-len(padding)//2)+padding[::2][::-1]
    right = [0]*(x.ndim-len(padding)//2)+padding[1::2][::-1]
    out_dims = []
    out_shape = []
    for i,n,l,r in zip(range(x.ndim),x.shape,left,right):
        out_shape.append(n+l+r)
        if mode == 'constant':
            out_dims.append(f'i{i}-{l}')
        elif mode == 'replicate':
            out_dims.append(f"i{i}<{l} ? 0 : i{i} > {n+l-1} ? {n-1} : i{i}-{l}")
        elif mode == 'reflect':
            out_dims.append(f"i{i}<{l} ? {l}-i{i} : i{i} > {n+l-1} ? {2*(n-1)+l}-i{i} : i{i}-{l}")
        elif mode == 'circular':
            out_dims.append(f"i{i}<{l} ? {n-l}+i{i} : i{i} > {n+l-1} ? i{i}-{n+l} : i{i}-{l}")
    return x.reindex(out_shape,out_dims,overflow_value=value)

class ReflectionPad2d(Module):
    def __init__(self, padding):
        self.padding = padding
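pad follows the torch.nn.functional.pad convention: padding lists (left, right) pairs starting from the last dimension, and each mode is expressed as a single reindex over the padded shape. constant relies on reindex's overflow_value for out-of-range indices, while replicate, reflect, and circular remap those indices back into the input. The expected outputs below are worked by hand from the index formulas:

```python
import jittor as jt
from jittor import nn

x = jt.array([[1., 2., 3.]])           # shape [1, 3], pad last dim by (2, 1)
print(nn.pad(x, (2, 1)))               # constant:  [[0, 0, 1, 2, 3, 0]]
print(nn.pad(x, (2, 1), 'replicate'))  # [[1, 1, 1, 2, 3, 3]]
print(nn.pad(x, (2, 1), 'reflect'))    # [[3, 2, 1, 2, 3, 2]]
print(nn.pad(x, (2, 1), 'circular'))   # [[2, 3, 1, 2, 3, 1]]
```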

View File

@ -69,5 +69,18 @@ class TestPad(unittest.TestCase):
        check_equal(arr, jnn.ReflectionPad2d((10,123,34,1)), tnn.ReflectionPad2d((10,123,34,1)))
        check_equal(arr, jnn.ReflectionPad2d((100)), tnn.ReflectionPad2d((100)))

        # ***************************************************************
        # Test function pad
        # ***************************************************************
        arr = np.random.randn(16,3,224,224)
        padding = (10,11,2,3)
        for mode in ['constant','replicate','reflect','circular']:
            j_data = jt.array(arr)
            t_data = torch.tensor(arr)
            t_output = tnn.functional.pad(t_data,padding,mode=mode).detach().numpy()
            j_output = jnn.pad(j_data,padding,mode).numpy()
            assert np.allclose(t_output,j_output)

if __name__ == "__main__":
    unittest.main()

View File

@ -238,12 +238,14 @@ class Hook:
            return
        if mod_name != "":
            mod_name = "<" + mod_name + ">"

-        def forward_hook(self2, input, output):
+        def forward_hook(self2, input, output, kw=None):
            ex_name = '[' + self2.__class__.__name__ + ']'
            if "relu" not in self2.__class__.__name__.lower():
                # skip relu: its input may have been modified in-place
                self.record(self2.__ad_mod_name__+".input", input, ex_name)
                self.record(self2.__ad_mod_name__+".output", output, ex_name)
            if kw is not None:
                self.record(self2.__ad_mod_name__+".kw", kw, ex_name)

        names = []
        for name, module in mod.named_modules():
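The new kw=None parameter lets forward_hook receive keyword arguments from modules that are called with them, while staying compatible with callers that pass only input and output; kwargs are recorded only when actually present. The same pattern in isolation (record here is a hypothetical stand-in for Hook.record):

```python
def make_forward_hook(record):
    # kw defaults to None so hooks on modules that are called
    # purely positionally keep working unchanged
    def forward_hook(mod, input, output, kw=None):
        record("input", input)
        record("output", output)
        if kw is not None:  # only record kwargs when the module got any
            record("kw", kw)
    return forward_hook
```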

View File

@ -10,6 +10,9 @@
namespace jittor {

DEFINE_FLAG(int, merge_loop_mismatch_threshold, 2, "");

void MergeLoopPass::run() {
    auto choice = op->get_loop_option("merge", 1);
    if (!choice) return;
@ -44,7 +47,7 @@ void MergeLoopPass::run() {
        while (cpx < ki.size() && cpx<kj.size() && ki[cpx] == kj[cpx]) cpx++;
        int mismatch = std::max(ki.size(), kj.size()) - cpx;
        LOGvvvv << "loop key " << ki << kj << "mismatch" << mismatch;
-        if (mismatch>=2 || cpx==0)
+        if (mismatch>=merge_loop_mismatch_threshold || cpx==0)
            continue;
        loops[i]->insert(0, loops[j]->children);
        loops[i]->merge_loop();
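DEFINE_FLAG turns the previously hard-coded mismatch limit into a runtime-tunable flag (default 2): two candidate loops are merged only when their loop keys diverge on fewer than merge_loop_mismatch_threshold levels. Jittor flags declared this way are normally reachable from Python through jt.flags, so tuning the pass should look roughly like this (a sketch, assuming the flag is exposed under the same name):

```python
import jittor as jt

# allow merging of loops whose keys mismatch on up to 2 levels
# (the default threshold of 2 would skip them)
jt.flags.merge_loop_mismatch_threshold = 3
```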