This commit is contained in:
Dun Liang 2021-07-30 20:40:31 +08:00
commit b7fff3072a
2 changed files with 3 additions and 4 deletions

View File

@ -100,7 +100,6 @@ def kaiming_normal_(var, a=0, mode='fan_in', nonlinearity='leaky_relu'):
return gauss_(var,0, std)
#TODO: bound = gain * math.sqrt(6.0/fan) ??
def xavier_uniform(shape, dtype, gain=1.0):
assert len(shape)>1
@ -108,7 +107,7 @@ def xavier_uniform(shape, dtype, gain=1.0):
for i in shape[2:]:
matsize *= i
fan = (shape[1] * matsize) + (shape[0] * matsize)
bound = gain * math.sqrt(1.0/fan)
bound = gain * math.sqrt(6.0/fan)
return uniform(shape, dtype, -bound, bound)
def xavier_uniform_(var, gain=1.0):

View File

@ -7,9 +7,9 @@ import jittor_utils
from jittor_utils import LOG
import sys
with jittor_utils.import_scope(os.RTLD_GLOBAL | os.RTLD_NOW):
jittor_utils.try_import_jit_utils_core()
has_error = 0
def convert(data):