mirror of https://github.com/Jittor/Jittor
Merge branch 'master' of https://github.com/Jittor/jittor
This commit is contained in:
commit
b7fff3072a
|
@ -100,7 +100,6 @@ def kaiming_normal_(var, a=0, mode='fan_in', nonlinearity='leaky_relu'):
|
||||||
return gauss_(var,0, std)
|
return gauss_(var,0, std)
|
||||||
|
|
||||||
|
|
||||||
#TODO: bound = gain * math.sqrt(6.0/fan) ??
|
|
||||||
def xavier_uniform(shape, dtype, gain=1.0):
|
def xavier_uniform(shape, dtype, gain=1.0):
|
||||||
assert len(shape)>1
|
assert len(shape)>1
|
||||||
|
|
||||||
|
@ -108,7 +107,7 @@ def xavier_uniform(shape, dtype, gain=1.0):
|
||||||
for i in shape[2:]:
|
for i in shape[2:]:
|
||||||
matsize *= i
|
matsize *= i
|
||||||
fan = (shape[1] * matsize) + (shape[0] * matsize)
|
fan = (shape[1] * matsize) + (shape[0] * matsize)
|
||||||
bound = gain * math.sqrt(1.0/fan)
|
bound = gain * math.sqrt(6.0/fan)
|
||||||
return uniform(shape, dtype, -bound, bound)
|
return uniform(shape, dtype, -bound, bound)
|
||||||
|
|
||||||
def xavier_uniform_(var, gain=1.0):
|
def xavier_uniform_(var, gain=1.0):
|
||||||
|
|
|
@ -7,8 +7,8 @@ import jittor_utils
|
||||||
from jittor_utils import LOG
|
from jittor_utils import LOG
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
jittor_utils.try_import_jit_utils_core()
|
with jittor_utils.import_scope(os.RTLD_GLOBAL | os.RTLD_NOW):
|
||||||
|
jittor_utils.try_import_jit_utils_core()
|
||||||
|
|
||||||
has_error = 0
|
has_error = 0
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue