mirror of https://github.com/Jittor/Jittor

RMSprop

This commit is contained in:
parent 981ca334b3
commit 1abbd8cbda
@@ -358,7 +358,7 @@ def norm(x, k, dim):
     if k==1:
         return x.abs().sum(dim)
     if k==2:
-        return (x**2).sum(dim).maximum(1e-6).sqrt()
+        return (x.sqr()).sum(dim).maximum(1e-6).sqrt()
 Var.norm = norm
 
 origin_reshape = reshape
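For reference, the patched helper only swaps the Python power operator for jittor's sqr() op; the result is unchanged. A minimal usage sketch of the Var.norm binding shown above (toy values, assuming a build where this hunk is applied):

```
import jittor as jt

x = jt.array([[3.0, 4.0], [0.0, 0.0]])
print(x.norm(1, 1))   # L1 norm per row -> [7, 0]
print(x.norm(2, 1))   # L2 norm per row; an all-zero row comes out as sqrt(1e-6) because of the maximum() clamp
```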
@@ -15,6 +15,7 @@ from jittor import init, Module
 import numpy as np
 import math
 from jittor.pool import Pool, pool, AdaptiveAvgPool2d
+from jittor.optim import *
 
 def matmul_transpose(a, b):
     '''
@@ -115,8 +116,11 @@ def cross_entropy_loss(output, target, ignore_index=None):
 def mse_loss(output, target):
     return (output-target).sqr().mean()
 
-def bce_loss(output, target):
-    return - (target * jt.log(jt.maximum(output, 1e-20)) + (1 - target) * jt.log(jt.maximum(1 - output, 1e-20))).mean()
+def bce_loss(output, target, size_average=True):
+    if size_average:
+        return - (target * jt.log(jt.maximum(output, 1e-20)) + (1 - target) * jt.log(jt.maximum(1 - output, 1e-20))).mean()
+    else:
+        return - (target * jt.log(jt.maximum(output, 1e-20)) + (1 - target) * jt.log(jt.maximum(1 - output, 1e-20))).sum()
 
 def l1_loss(output, target):
     return (output-target).abs().mean()
@@ -136,8 +140,8 @@ class MSELoss(Module):
 class BCELoss(Module):
     def __init__(self):
         pass
-    def execute(self, output, target):
-        return bce_loss(output, target)
+    def execute(self, output, target, size_average=True):
+        return bce_loss(output, target, size_average)
 
 class L1Loss(Module):
     def __init__(self):
@@ -149,9 +153,9 @@ class BCEWithLogitsLoss(Module):
     def __init__(self):
         self.sigmoid = Sigmoid()
         self.bce = BCELoss()
-    def execute(self, output, target):
+    def execute(self, output, target, size_average=True):
         output = self.sigmoid(output)
-        output = self.bce(output, target)
+        output = self.bce(output, target, size_average)
         return output
 
 def softmax(x, dim = None):
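The new size_average switch chooses between averaging and summing the per-element binary cross entropy, and BCEWithLogitsLoss simply forwards it after applying a sigmoid. A minimal sketch of the new keyword, assuming the functions are used through jittor.nn (the predictions and labels are made-up toy values):

```
import jittor as jt
from jittor import nn

pred = jt.array([[0.9], [0.2], [0.7], [0.1]])   # probabilities, as bce_loss expects
label = jt.array([[1.0], [0.0], [1.0], [0.0]])

loss_mean = nn.bce_loss(pred, label)                      # default size_average=True: mean over elements
loss_sum = nn.bce_loss(pred, label, size_average=False)   # summed instead of averaged

# BCEWithLogitsLoss applies a sigmoid first, then forwards size_average to BCELoss
logits = jt.array([[2.0], [-1.5], [0.3], [-0.8]])
loss = nn.BCEWithLogitsLoss()(logits, label, size_average=False)
```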
@@ -228,6 +232,64 @@ class BatchNorm(Module):
         w = self.weight.broadcast(x, [0,2,3])
         b = self.bias.broadcast(x, [0,2,3])
         return norm_x * w + b
+
+class BatchNorm1d(Module):
+    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=None, is_train=True, sync=True):
+        assert affine == None
+        self.sync = sync
+        self.num_features = num_features
+        self.is_train = is_train
+        self.eps = eps
+        self.momentum = momentum
+        self.weight = init.constant((num_features,), "float32", 1.0)
+        self.bias = init.constant((num_features,), "float32", 0.0)
+        self.running_mean = init.constant((num_features,), "float32", 0.0).stop_grad()
+        self.running_var = init.constant((num_features,), "float32", 1.0).stop_grad()
+
+    def execute(self, x):
+        if self.is_train:
+            xmean = jt.mean(x, dims=[0], keepdims=1)
+            x2mean = jt.mean(x*x, dims=[0], keepdims=1)
+
+            if self.sync and jt.mpi:
+                xmean = xmean.mpi_all_reduce("mean")
+                x2mean = x2mean.mpi_all_reduce("mean")
+
+            xvar = x2mean-xmean*xmean
+            norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)
+            self.running_mean += (xmean.sum([0])-self.running_mean)*self.momentum
+            self.running_var += (xvar.sum([0])-self.running_var)*self.momentum
+        else:
+            running_mean = self.running_mean.broadcast(x, [0])
+            running_var = self.running_var.broadcast(x, [0])
+            norm_x = (x-running_mean)/jt.sqrt(running_var+self.eps)
+        w = self.weight.broadcast(x, [0])
+        b = self.bias.broadcast(x, [0])
+        return norm_x * w + b
+
+class InstanceNorm2d(Module):
+    def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=None, is_train=True, sync=True):
+        assert affine == None
+        self.sync = sync
+        self.num_features = num_features
+        self.is_train = is_train
+        self.eps = eps
+        self.momentum = momentum
+        self.weight = init.constant((num_features,), "float32", 1.0)
+        self.bias = init.constant((num_features,), "float32", 0.0)
+
+    def execute(self, x):
+        xmean = jt.mean(x, dims=[2,3], keepdims=1)
+        x2mean = jt.mean(x*x, dims=[2,3], keepdims=1)
+        if self.sync and jt.mpi:
+            xmean = xmean.mpi_all_reduce("mean")
+            x2mean = x2mean.mpi_all_reduce("mean")
+
+        xvar = jt.maximum(x2mean-xmean*xmean, 0)
+        norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)
+        w = self.weight.broadcast(x, [0,2,3])
+        b = self.bias.broadcast(x, [0,2,3])
+        return norm_x * w + b
 
 Relu = jt.make_module(relu)
 ReLU = Relu
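The two new layers follow the existing BatchNorm: BatchNorm1d reduces over the batch dimension of an (N, C) input and keeps running statistics, while InstanceNorm2d normalizes each sample and channel over the spatial dimensions of an (N, C, H, W) input and keeps no running statistics. A minimal sketch, assuming both are exposed through jittor.nn like the existing BatchNorm:

```
import jittor as jt
from jittor import nn

bn = nn.BatchNorm1d(8)            # expects (N, C); running_mean/running_var updated when is_train=True
x = jt.random((16, 8))
y = bn(x)

inorm = nn.InstanceNorm2d(3)      # expects (N, C, H, W); statistics computed per sample and channel
img = jt.random((2, 3, 32, 32))
out = inorm(img)
```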
@@ -459,6 +521,16 @@ class ReplicationPad2d(Module):
             f"i3<{l} ? 0 : i3 > {r} ? {w-1} : i3-{l}"
         ])
 
+class Embedding(Module):
+    def __init__(self, num, dim):
+        self.num = num
+        self.dim = dim
+        self.weight = jt.init.gauss([num,dim],'float32').stop_grad()
+
+    def execute(self, x):
+        res = self.weight[x].reshape([x.shape[0],self.dim])
+        return res
+
 class PixelShuffle(Module):
     def __init__(self, upscale_factor):
         self.upscale_factor = upscale_factor
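Embedding is a plain lookup table: a (num, dim) Gaussian-initialized weight indexed by a 1-D integer Var and reshaped to (batch, dim). Note that the weight is created with stop_grad(), so it is not updated by the optimizers in this version. A minimal sketch, assuming the class is reachable as jittor.nn.Embedding:

```
import jittor as jt
from jittor import nn

emb = nn.Embedding(10, 4)        # 10 entries, 4-dim vectors
idx = jt.array([1, 5, 5, 9])     # 1-D integer indices, as assumed by execute()
vecs = emb(idx)                  # shape (4, 4): one weight row per index
```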
@@ -521,7 +593,7 @@ class Upsample(Module):
 
 class Sequential(Module):
     def __init__(self, *args):
-        self.layers = args
+        self.layers = list(args)
     def __getitem__(self, idx):
         return self.layers[idx]
     def execute(self, x):
@@ -539,3 +611,7 @@ class Sequential(Module):
             parents.pop()
         if callback_leave:
             callback_leave(parents, k, self, n_children)
+    def append(self, mod):
+        self.layers.append(mod)
+
+ModuleList = Sequential
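Storing the layers as a list (rather than the raw *args tuple) is what makes the new append() possible, and ModuleList becomes an alias for Sequential. A minimal sketch; nn.Linear and nn.Relu are assumed from the rest of jittor.nn, not from this diff:

```
import jittor as jt
from jittor import nn

net = nn.Sequential(
    nn.Linear(10, 32),
    nn.Relu(),
)
net.append(nn.Linear(32, 1))     # works now that self.layers is a mutable list
y = net(jt.random((4, 10)))

blocks = nn.ModuleList()         # same class under the new alias
blocks.append(nn.Relu())
```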
@@ -132,6 +132,39 @@ class SGD(Optimizer):
                     p -= v * lr
                 p.detach_inplace()
 
+class RMSprop(Optimizer):
+    """ RMSprop Optimizer.
+    Example:
+    ```
+    optimizer = nn.RMSprop(model.parameters(), lr, eps=1e-8, betas=(0.9, 0.999))
+    optimizer.step(loss)
+    ```
+    """
+    def __init__(self, params, lr=1e-2, eps=1e-8, alpha=0.99):
+        # def __init__(self, params, lr, eps=1e-8, betas=(0.9, 0.999), weight_decay=0):
+        super().__init__(params, lr)
+        self.eps = eps
+        self.alpha = alpha
+
+        # initialize required arguments for each param_groups
+        for pg in self.param_groups:
+            values = pg["values"] = []
+            for p in pg["params"]:
+                values.append(jt.zeros(p.shape, p.dtype).stop_fuse().stop_grad())
+
+    def step(self, loss):
+        self.pre_step(loss)
+        for pg in self.param_groups:
+            # get arguments from each param_groups
+            lr = pg.get("lr", self.lr)
+            eps = pg.get("eps", self.eps)
+            alpha = pg.get("alpha", self.alpha)
+            for p, g, v in zip(pg["params"], pg["grads"], pg["values"]):
+                if p.is_stop_grad(): continue
+                v.assign(alpha * v + (1-alpha) * g * g)
+                p -= lr * g / (jt.sqrt(v) + eps)
+                p.detach_inplace()
+
 class Adam(Optimizer):
     """ Adam Optimizer.
     Example:
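One caveat in the new class: its docstring example appears to be copied from Adam and passes betas=(0.9, 0.999), which this __init__ does not accept; the actual parameters are lr, eps and alpha (the smoothing constant of the squared-gradient running average). A minimal usage sketch matching the signature that was added; nn.Linear is assumed from jittor.nn, not from this diff:

```
import jittor as jt
from jittor import nn

model = nn.Linear(10, 1)
optimizer = nn.RMSprop(model.parameters(), lr=1e-2, eps=1e-8, alpha=0.99)

x = jt.random((8, 10))
loss = (model(x) - 1.0).sqr().mean()
# computes gradients, then applies v = alpha*v + (1-alpha)*g*g and p -= lr*g / (sqrt(v) + eps)
optimizer.step(loss)
```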