mirror of https://github.com/Jittor/Jittor
add dropout 2d
This commit is contained in:
parent
f4f327bd12
commit
74932f3c32
|
@ -9,7 +9,7 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
__version__ = '1.3.5.28'
|
||||
__version__ = '1.3.5.30'
|
||||
from jittor_utils import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -546,6 +546,34 @@ class Dropout(Module):
|
|||
def dropout(x,p=0.5,is_train=False):
|
||||
return Dropout(p=p,is_train=is_train)(x)
|
||||
|
||||
class Dropout2d(Module):
|
||||
def __init__(self, p=0.5, is_train=False):
|
||||
'''
|
||||
Randomly zero out entire channels, from "Efficient Object Localization Using Convolutional Networks"
|
||||
input:
|
||||
x: [N,C,H,W] or [N,C,L]
|
||||
output:
|
||||
y: same shape as x
|
||||
'''
|
||||
assert p >= 0 and p <= 1, "dropout probability has to be between 0 and 1, but got {}".format(p)
|
||||
self.p = p
|
||||
self.is_train = is_train
|
||||
#TODO: test model.train() to change self.is_train
|
||||
def execute(self, input):
|
||||
output = input
|
||||
shape = input.shape[:-2]
|
||||
if self.p > 0 and self.is_train:
|
||||
if self.p == 1:
|
||||
output = jt.zeros(input.shape)
|
||||
else:
|
||||
noise = jt.random(shape)
|
||||
noise = (noise > self.p).int()
|
||||
output = output * noise.broadcast(input.shape, dims=[-2,-1]) / (1.0 - self.p) # div keep prob
|
||||
return output
|
||||
|
||||
def dropout2d(x,p=0.5,is_train=False):
|
||||
return Dropout(p=p,is_train=is_train)(x)
|
||||
|
||||
class DropPath(Module):
|
||||
'''Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||
'''
|
||||
|
|
|
@ -336,5 +336,12 @@ class TestOther(unittest.TestCase):
|
|||
with jt.flag_scope(use_cuda=1):
|
||||
self.test_nan()
|
||||
|
||||
def test_dropout2d(self):
|
||||
m = jt.nn.Dropout2d(p=0.2)
|
||||
m.train()
|
||||
input = jt.randn(1, 10, 4, 3)
|
||||
output = m(input)
|
||||
output.sync()
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in New Issue