diff --git a/python/jittor/init.py b/python/jittor/init.py index 2770b3b1..f65661d8 100644 --- a/python/jittor/init.py +++ b/python/jittor/init.py @@ -100,7 +100,6 @@ def kaiming_normal_(var, a=0, mode='fan_in', nonlinearity='leaky_relu'): return gauss_(var,0, std) -#TODO: bound = gain * math.sqrt(6.0/fan) ?? def xavier_uniform(shape, dtype, gain=1.0): assert len(shape)>1 @@ -108,7 +107,7 @@ def xavier_uniform(shape, dtype, gain=1.0): for i in shape[2:]: matsize *= i fan = (shape[1] * matsize) + (shape[0] * matsize) - bound = gain * math.sqrt(1.0/fan) + bound = gain * math.sqrt(6.0/fan) return uniform(shape, dtype, -bound, bound) def xavier_uniform_(var, gain=1.0): diff --git a/python/jittor_utils/auto_diff.py b/python/jittor_utils/auto_diff.py index 0bb60353..1235ca31 100644 --- a/python/jittor_utils/auto_diff.py +++ b/python/jittor_utils/auto_diff.py @@ -7,8 +7,8 @@ import jittor_utils from jittor_utils import LOG import sys -jittor_utils.try_import_jit_utils_core() - +with jittor_utils.import_scope(os.RTLD_GLOBAL | os.RTLD_NOW): + jittor_utils.try_import_jit_utils_core() has_error = 0