mirror of https://github.com/Jittor/Jittor
fix init no_grad
This commit is contained in:
parent
bb2d187dd3
commit
d87b2d8f27
|
@ -8,7 +8,7 @@
|
|||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
__version__ = '1.2.2.61'
|
||||
__version__ = '1.2.2.62'
|
||||
from . import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -92,14 +92,12 @@ def calculate_std(var,mode,nonlinearity,param=0.01):
|
|||
|
||||
def kaiming_uniform_(var, a=0, mode='fan_in', nonlinearity='leaky_relu'):
|
||||
std = calculate_std(var,mode,nonlinearity,a)
|
||||
bound = math.sqrt(3.0) * std
|
||||
with jt.no_grad():
|
||||
return uniform_(var,-bound, bound)
|
||||
bound = math.sqrt(3.0) * std
|
||||
return uniform_(var,-bound, bound)
|
||||
|
||||
def kaiming_normal_(var, a=0, mode='fan_in', nonlinearity='leaky_relu'):
|
||||
std = calculate_std(var,mode,nonlinearity,a)
|
||||
with jt.no_grad():
|
||||
return gauss_(var,0, std)
|
||||
return gauss_(var,0, std)
|
||||
|
||||
|
||||
#TODO: bound = gain * math.sqrt(6.0/fan) ??
|
||||
|
|
Loading…
Reference in New Issue