mirror of https://github.com/Jittor/Jittor
Merge branch 'master' of https://github.com/Jittor/jittor
This commit is contained in:
commit
b7fff3072a
|
@ -100,7 +100,6 @@ def kaiming_normal_(var, a=0, mode='fan_in', nonlinearity='leaky_relu'):
|
|||
return gauss_(var,0, std)
|
||||
|
||||
|
||||
#TODO: bound = gain * math.sqrt(6.0/fan) ??
|
||||
def xavier_uniform(shape, dtype, gain=1.0):
|
||||
assert len(shape)>1
|
||||
|
||||
|
@ -108,7 +107,7 @@ def xavier_uniform(shape, dtype, gain=1.0):
|
|||
for i in shape[2:]:
|
||||
matsize *= i
|
||||
fan = (shape[1] * matsize) + (shape[0] * matsize)
|
||||
bound = gain * math.sqrt(1.0/fan)
|
||||
bound = gain * math.sqrt(6.0/fan)
|
||||
return uniform(shape, dtype, -bound, bound)
|
||||
|
||||
def xavier_uniform_(var, gain=1.0):
|
||||
|
|
|
@ -7,8 +7,8 @@ import jittor_utils
|
|||
from jittor_utils import LOG
|
||||
import sys
|
||||
|
||||
jittor_utils.try_import_jit_utils_core()
|
||||
|
||||
with jittor_utils.import_scope(os.RTLD_GLOBAL | os.RTLD_NOW):
|
||||
jittor_utils.try_import_jit_utils_core()
|
||||
|
||||
has_error = 0
|
||||
|
||||
|
|
Loading…
Reference in New Issue