set default amp level

This commit is contained in:
Dun Liang 2023-04-03 01:27:55 +08:00
parent df8628a3a5
commit 16b7966a9a
1 changed files with 7 additions and 4 deletions

View File

@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.3.7.0'
__version__ = '1.3.7.1'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int
@ -1732,9 +1732,12 @@ Arguments of hook are defined as::
def float16(self):
'''convert all parameters to float16'''
self._amp_level = 3 if flags.th_mode else 4
cls = self.__class__
cls.__call__ = cls.__half_call__
# self._amp_level = 3 if flags.th_mode else 4
# amp level better set globally
self._amp_level = -1
if self._amp_level >= 0:
cls = self.__class__
cls.__call__ = cls.__half_call__
for p in self.parameters():
if p.dtype.is_float():
p.assign(p.float16())