update nn.dropout

This commit is contained in:
gmh 2020-04-16 22:52:19 +08:00
parent 68608dd74b
commit a4d4bfe16a
1 changed files with 2 additions and 1 deletions

View File

@ -219,10 +219,11 @@ class Dropout(Module):
if self.p > 0 and self.is_train:
if self.p == 1:
noise = jt.zeros(input.shape)
output = output * noise
else:
noise = jt.random(input.shape)
noise = (noise > self.p).int()
output = output * noise
output = output * noise / (1.0 - self.p) # div keep prob
return output
class Linear(Module):