mirror of https://github.com/Jittor/Jittor
Merge remote-tracking branch 'origin/gmh' into models
This commit is contained in:
commit
2c8124d144
|
@ -220,10 +220,11 @@ class Dropout(Module):
|
|||
if self.p > 0 and self.is_train:
|
||||
if self.p == 1:
|
||||
noise = jt.zeros(input.shape)
|
||||
output = output * noise
|
||||
else:
|
||||
noise = jt.random(input.shape)
|
||||
noise = (noise > self.p).int()
|
||||
output = output * noise
|
||||
output = output * noise / (1.0 - self.p) # div keep prob
|
||||
return output
|
||||
|
||||
class Linear(Module):
|
||||
|
|
Loading…
Reference in New Issue