mirror of https://github.com/Jittor/Jittor

add bmm_transpose and pad and merge_loop_flags

parent 8d5acd276a, commit 06eb505eea
@ -76,6 +76,8 @@ def chunk(x, chunks, dim=0):
        >>> print(res[0].shape, res[1].shape)
        [5,3,3,] [5,3,3,]
    '''
    if dim<0:
        dim += x.ndim
    l = x.shape[dim]
    res = []
    if l <= chunks:
@ -395,7 +397,7 @@ jt.Var.log2 = log2

def item(x):
    assert x.ndim==1 and x.shape[0]==1
-    return x.data[0]
+    return x.numpy().item()

jt.Var.item = item

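A minimal usage sketch of the updated item() (the value 3.14 is just an illustration, not part of the patch):

import jittor as jt

v = jt.array([3.14])
# item() asserts the Var holds exactly one element, then returns it
# as a plain Python scalar, now going through numpy().item().
print(v.item())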
@ -515,7 +517,7 @@ def gather(x,dim,index):
        ins.append(jt.index(index.shape,dim=i))
    ins[dim]=index
    return x.reindex(ins)

jt.Var.gather = gather

def prod(x,dim=0):
    x = jt.log(x)
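Because gather's reindex-based semantics can be hard to read from the code alone, here is a small hedged sketch of what it computes for dim=0, mirroring torch.gather (the input values are illustrative):

import numpy as np
import jittor as jt

x = jt.array(np.array([[1, 2], [3, 4]], dtype=np.float32))
index = jt.array(np.array([[0, 0], [1, 0]]))
# Gathering along dim 0: out[i][j] = x[index[i][j]][j]
print(x.gather(0, index).numpy())   # expected [[1. 2.], [3. 2.]]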
@ -37,6 +37,18 @@ def matmul_transpose(a, b):
    b = b.broadcast(shape)
    return (a*b).sum(len(shape)-1)


def bmm_transpose(a, b):
    '''
    returns a * b^T
    '''
    if jt.flags.use_cuda:
        return jt.compile_extern.cublas_ops.cublas_batched_matmul(a, b, 0, 1)
    t = list(range(b.ndim))
    t[-1], t[-2] = t[-2], t[-1]
    return bmm(a, b.transpose(t))


def bmm(a, b):
    ''' batch matrix multiply,
    shape of input a is [batch, n, m],
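As a sanity check, the sketch below shows how the new bmm_transpose could be exercised against numpy. It assumes the helper is exposed from jittor.nn next to bmm; the shapes and tolerance are made up for illustration:

import numpy as np
import jittor as jt
from jittor import nn

# batch of 4 matrix pairs: a is [4, 5, 6], b is [4, 7, 6], so a @ b^T is [4, 5, 7]
a = jt.array(np.random.rand(4, 5, 6).astype(np.float32))
b = jt.array(np.random.rand(4, 7, 6).astype(np.float32))

out = nn.bmm_transpose(a, b)   # per-batch a * b^T
ref = np.matmul(a.numpy(), b.numpy().transpose(0, 2, 1))
assert np.allclose(out.numpy(), ref, atol=1e-5)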
@ -276,6 +288,14 @@ def log_softmax(x,dim=None):
def log_sigmoid(x):
    return jt.log(jt.sigmoid(x))


class Identity(Module):
    def __init__(self, *args, **kwargs):
        super(Identity, self).__init__()

    def execute(self, input):
        return input

class Dropout(Module):
    def __init__(self, p=0.5, is_train=False):
        assert p >= 0 and p <= 1, "dropout probability has to be between 0 and 1, but got {}".format(p)
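A tiny sketch of the new Identity module in use; the shape is arbitrary and only illustrates the pass-through behaviour:

import numpy as np
import jittor as jt
from jittor import nn

layer = nn.Identity()          # pass-through module, handy as a placeholder
x = jt.random((2, 3))
assert np.allclose(layer(x).numpy(), x.numpy())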
@ -700,6 +720,30 @@ class ConvTranspose(Module):
        return y


def pad(x,padding, mode='constant', value=0):
    assert mode in ['constant','replicate','reflect','circular'],'only support constant,replicate,reflect,circular pad'
    assert len(padding)%2==0 and len(padding)//2<=x.ndim

    padding = list(padding)
    left = [0]*(x.ndim-len(padding)//2)+padding[::2][::-1]
    right = [0]*(x.ndim-len(padding)//2)+padding[1::2][::-1]

    out_dims = []
    out_shape = []
    for i,n,l,r in zip(range(x.ndim),x.shape,left,right):
        out_shape.append(n+l+r)
        if mode == 'constant':
            out_dims.append(f'i{i}-{l}')
        elif mode == 'replicate':
            out_dims.append(f"i{i}<{l} ? 0 : i{i} > {n+l-1} ? {n-1} : i{i}-{l}")
        elif mode == 'reflect':
            out_dims.append(f"i{i}<{l} ? {l}-i{i} : i{i} > {n+l-1} ? {2*(n-1)+l}-i{i} : i{i}-{l}")
        elif mode == 'circular':
            out_dims.append(f"i{i}<{l} ? {n-l}+i{i} : i{i} > {n+l-1} ? i{i}-{n+l} : i{i}-{l}")

    return x.reindex(out_shape,out_dims,overflow_value=value)


class ReflectionPad2d(Module):
    def __init__(self, padding):
        self.padding = padding
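For orientation, a small sketch of what the new pad is expected to do in 'constant' mode. It assumes the function is reachable as jittor.nn.pad (as the test below suggests); the shapes and padding values are made up:

import numpy as np
import jittor as jt
from jittor import nn

x = jt.array(np.arange(6, dtype=np.float32).reshape(2, 3))
# The padding tuple follows the PyTorch convention: (left, right) for the
# last dim, then (top, bottom) for the dim before it, and so on.
y = nn.pad(x, (1, 2, 1, 0), mode='constant', value=0)
print(y.shape)   # expected [3, 6]: 1 row on top, 1 column left, 2 columns right
ref = np.pad(x.numpy(), ((1, 0), (1, 2)), mode='constant')
assert np.allclose(y.numpy(), ref)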
@ -69,5 +69,18 @@ class TestPad(unittest.TestCase):
        check_equal(arr, jnn.ReflectionPad2d((10,123,34,1)), tnn.ReflectionPad2d((10,123,34,1)))
        check_equal(arr, jnn.ReflectionPad2d((100)), tnn.ReflectionPad2d((100)))

        # ***************************************************************
        # Test function pad
        # ***************************************************************
        arr = np.random.randn(16,3,224,224)
        padding = (10,11,2,3)
        for mode in ['constant','replicate','reflect','circular']:
            j_data = jt.array(arr)
            t_data = torch.tensor(arr)
            t_output = tnn.functional.pad(t_data,padding,mode=mode).detach().numpy()
            j_output = jnn.pad(j_data,padding,mode).numpy()
            assert np.allclose(t_output,j_output)


if __name__ == "__main__":
    unittest.main()
@ -238,12 +238,14 @@ class Hook:
            return
        if mod_name != "":
            mod_name = "<" + mod_name + ">"
-        def forward_hook(self2, input, output):
+        def forward_hook(self2, input, output, kw=None):
            ex_name = '[' + self2.__class__.__name__ + ']'
            if "relu" not in self2.__class__.__name__.lower():
                # not test relu, because input may be inplaced
                self.record(self2.__ad_mod_name__+".input", input, ex_name)
            self.record(self2.__ad_mod_name__+".output", output, ex_name)
            if kw is not None:
                self.record(self2.__ad_mod_name__+".kw", kw, ex_name)

        names = []
        for name, module in mod.named_modules():
@ -10,6 +10,9 @@

namespace jittor {


DEFINE_FLAG(int, merge_loop_mismatch_threshold, 2, "");

void MergeLoopPass::run() {
    auto choice = op->get_loop_option("merge", 1);
    if (!choice) return;
@ -44,7 +47,7 @@ void MergeLoopPass::run() {
        while (cpx < ki.size() && cpx<kj.size() && ki[cpx] == kj[cpx]) cpx++;
        int mismatch = std::max(ki.size(), kj.size()) - cpx;
        LOGvvvv << "loop key " << ki << kj << "mismatch" << mismatch;
-        if (mismatch>=2 || cpx==0)
+        if (mismatch>=merge_loop_mismatch_threshold || cpx==0)
            continue;
        loops[i]->insert(0, loops[j]->children);
        loops[i]->merge_loop();
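If the new merge_loop_mismatch_threshold flag is registered like other Jittor DEFINE_FLAGs, it should be tunable from Python; the snippet below is only a hedged sketch of that assumption, and the value 3 is arbitrary:

import jittor as jt

# Raise the threshold so loop pairs whose keys mismatch by 2 entries are
# still considered for merging (the default of 2 would skip them).
jt.flags.merge_loop_mismatch_threshold = 3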