mirror of https://github.com/Jittor/Jittor
add exp lr and remove warning as error
parent ee020b60f7
commit 01974db52d
@@ -9,7 +9,7 @@
 # file 'LICENSE.txt', which is part of this source code package.
 # ***************************************************************
 
-__version__ = '1.2.3.80'
+__version__ = '1.2.3.81'
 from jittor_utils import lock
 with lock.lock_scope():
     ori_int = int
@@ -912,7 +912,8 @@ def check_clang_latest_supported_cpu():
     apple_cpus_id = max([int(cpu[7:]) for cpu in apple_cpus])
     return f'apple-a{apple_cpus_id}'
 
-cc_flags += " -Wall -Werror -Wno-unknown-pragmas -std=c++14 -fPIC "
+# cc_flags += " -Wall -Werror -Wno-unknown-pragmas -std=c++14 -fPIC "
+cc_flags += " -Wall -Wno-unknown-pragmas -std=c++14 -fPIC "
 # 1. Arch/CPU specific optimization
 if platform.machine() == "x86_64":
     cc_flags += " -march=native "
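The hunk above is the "remove warning as error" half of the commit title: with "-Werror" dropped from cc_flags, -Wall diagnostics are still printed but no longer abort compilation. A minimal sketch (not part of the commit) demonstrating the difference, assuming a C++ compiler is available on PATH as "c++":

import os
import subprocess
import tempfile

# Old and new flag sets, copied from the diff above.
old_flags = "-Wall -Werror -Wno-unknown-pragmas -std=c++14 -fPIC".split()
new_flags = "-Wall -Wno-unknown-pragmas -std=c++14 -fPIC".split()

with tempfile.TemporaryDirectory() as tmp:
    src = os.path.join(tmp, "warn.cc")
    with open(src, "w") as f:
        # Triggers -Wunused-variable under -Wall.
        f.write("int main() { int unused = 0; return 0; }\n")
    for name, flags in (("old", old_flags), ("new", new_flags)):
        out = os.path.join(tmp, name + ".o")
        r = subprocess.run(["c++", *flags, "-c", src, "-o", out])
        # With -Werror the warning is fatal, so the old flags fail
        # (nonzero return code) while the new flags compile cleanly.
        print(name, "flags ->", "failed" if r.returncode else "compiled")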
@@ -116,6 +116,33 @@ class CosineAnnealingLR(object):
             if param_group.get("lr") != None:
                 param_group["lr"] = self.get_lr(self.base_lr_pg[i], param_group["lr"])
 
 
+class ExponentialLR(object):
+    """ learning rate is multiplied by gamma in each step.
+    """
+    def __init__(self, optimizer, gamma, last_epoch=-1):
+        self.optimizer = optimizer
+        self.gamma = gamma
+        self.last_epoch = last_epoch
+        self.base_lr = optimizer.lr
+        self.base_lr_pg = [pg.get("lr") for pg in optimizer.param_groups]
+
+    def get_lr(self, base_lr, now_lr):
+        if self.last_epoch == 0:
+            return base_lr
+        return base_lr * self.gamma ** self.last_epoch
+
+    def step(self):
+        self.last_epoch += 1
+        self.update_lr()
+
+    def update_lr(self):
+        self.optimizer.lr = self.get_lr(self.base_lr, self.optimizer.lr)
+        for i, param_group in enumerate(self.optimizer.param_groups):
+            if param_group.get("lr") != None:
+                param_group["lr"] = self.get_lr(self.base_lr_pg[i], param_group["lr"])
+
+
 class StepLR(object):
     def __init__(self, optimizer, step_size, gamma=0.1, last_epoch=-1):
         self.optimizer = optimizer
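The new ExponentialLR is the "add exp lr" half of the commit: each step() sets the learning rate to base_lr * gamma ** last_epoch. A minimal usage sketch (not part of the commit; the module path jittor.lr_scheduler and the SGD/Linear names are assumptions based on Jittor's public API, matching where CosineAnnealingLR and StepLR live in this file):

import jittor as jt
from jittor import nn
from jittor.lr_scheduler import ExponentialLR  # assumed module path

model = nn.Linear(10, 2)
optimizer = jt.optim.SGD(model.parameters(), lr=0.1)
scheduler = ExponentialLR(optimizer, gamma=0.9)

for epoch in range(3):
    # ... forward pass, loss, optimizer.step(loss) would go here ...
    scheduler.step()
    # lr = 0.1 * 0.9 ** last_epoch -> 0.1, 0.09, 0.081
    print(epoch, optimizer.lr)

Note that the first step() moves last_epoch from -1 to 0, so the first epoch runs at the unscaled base_lr; decay begins from the second step onward.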
@@ -8,7 +8,7 @@
 #include "pyjt/py_converter.h"
 #include "pyjt/py_arg_printer.h"
 #ifdef __clang__
-#pragma clang diagnostic ignored "-Wdefaulted-function-deleted"
+// #pragma clang diagnostic ignored "-Wdefaulted-function-deleted"
 #endif
 #ifdef __GNUC__
 #endif