mirror of https://github.com/Jittor/Jittor
set default amp level
This commit is contained in:
parent
df8628a3a5
commit
16b7966a9a
|
@ -9,7 +9,7 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
__version__ = '1.3.7.0'
|
||||
__version__ = '1.3.7.1'
|
||||
from jittor_utils import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
@ -1732,9 +1732,12 @@ Arguments of hook are defined as::
|
|||
|
||||
def float16(self):
|
||||
'''convert all parameters to float16'''
|
||||
self._amp_level = 3 if flags.th_mode else 4
|
||||
cls = self.__class__
|
||||
cls.__call__ = cls.__half_call__
|
||||
# self._amp_level = 3 if flags.th_mode else 4
|
||||
# amp level better set globally
|
||||
self._amp_level = -1
|
||||
if self._amp_level >= 0:
|
||||
cls = self.__class__
|
||||
cls.__call__ = cls.__half_call__
|
||||
for p in self.parameters():
|
||||
if p.dtype.is_float():
|
||||
p.assign(p.float16())
|
||||
|
|
Loading…
Reference in New Issue