mirror of https://github.com/Jittor/Jittor
update nn.dropout
This commit is contained in:
parent
68608dd74b
commit
a4d4bfe16a
|
@ -219,10 +219,11 @@ class Dropout(Module):
|
|||
if self.p > 0 and self.is_train:
|
||||
if self.p == 1:
|
||||
noise = jt.zeros(input.shape)
|
||||
output = output * noise
|
||||
else:
|
||||
noise = jt.random(input.shape)
|
||||
noise = (noise > self.p).int()
|
||||
output = output * noise
|
||||
output = output * noise / (1.0 - self.p) # div keep prob
|
||||
return output
|
||||
|
||||
class Linear(Module):
|
||||
|
|
Loading…
Reference in New Issue