Merge remote-tracking branch 'origin/gmh' into models

This commit is contained in:
Dun Liang 2020-04-20 15:28:23 +08:00
commit 2c8124d144
1 changed files with 2 additions and 1 deletions

View File

@ -220,10 +220,11 @@ class Dropout(Module):
if self.p > 0 and self.is_train:
if self.p == 1:
noise = jt.zeros(input.shape)
output = output * noise
else:
noise = jt.random(input.shape)
noise = (noise > self.p).int()
output = output * noise
output = output * noise / (1.0 - self.p) # div keep prob
return output
class Linear(Module):