This commit is contained in:
guowei yang 2020-05-26 17:18:58 +08:00
parent 981ca334b3
commit 1abbd8cbda
3 changed files with 117 additions and 8 deletions

View File

@ -358,7 +358,7 @@ def norm(x, k, dim):
if k==1:
return x.abs().sum(dim)
if k==2:
return (x**2).sum(dim).maximum(1e-6).sqrt()
return (x.sqr()).sum(dim).maximum(1e-6).sqrt()
Var.norm = norm
origin_reshape = reshape

View File

@ -15,6 +15,7 @@ from jittor import init, Module
import numpy as np
import math
from jittor.pool import Pool, pool, AdaptiveAvgPool2d
from jittor.optim import *
def matmul_transpose(a, b):
'''
@ -115,8 +116,11 @@ def cross_entropy_loss(output, target, ignore_index=None):
def mse_loss(output, target):
return (output-target).sqr().mean()
def bce_loss(output, target):
return - (target * jt.log(jt.maximum(output, 1e-20)) + (1 - target) * jt.log(jt.maximum(1 - output, 1e-20))).mean()
def bce_loss(output, target, size_average=True):
if size_average:
return - (target * jt.log(jt.maximum(output, 1e-20)) + (1 - target) * jt.log(jt.maximum(1 - output, 1e-20))).mean()
else:
return - (target * jt.log(jt.maximum(output, 1e-20)) + (1 - target) * jt.log(jt.maximum(1 - output, 1e-20))).sum()
def l1_loss(output, target):
return (output-target).abs().mean()
@ -136,8 +140,8 @@ class MSELoss(Module):
class BCELoss(Module):
def __init__(self):
pass
def execute(self, output, target):
return bce_loss(output, target)
def execute(self, output, target, size_average=True):
return bce_loss(output, target, size_average)
class L1Loss(Module):
def __init__(self):
@ -149,9 +153,9 @@ class BCEWithLogitsLoss(Module):
def __init__(self):
self.sigmoid = Sigmoid()
self.bce = BCELoss()
def execute(self, output, target):
def execute(self, output, target, size_average=True):
output = self.sigmoid(output)
output = self.bce(output, target)
output = self.bce(output, target, size_average)
return output
def softmax(x, dim = None):
@ -228,6 +232,64 @@ class BatchNorm(Module):
w = self.weight.broadcast(x, [0,2,3])
b = self.bias.broadcast(x, [0,2,3])
return norm_x * w + b
class BatchNorm1d(Module):
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=None, is_train=True, sync=True):
assert affine == None
self.sync = sync
self.num_features = num_features
self.is_train = is_train
self.eps = eps
self.momentum = momentum
self.weight = init.constant((num_features,), "float32", 1.0)
self.bias = init.constant((num_features,), "float32", 0.0)
self.running_mean = init.constant((num_features,), "float32", 0.0).stop_grad()
self.running_var = init.constant((num_features,), "float32", 1.0).stop_grad()
def execute(self, x):
if self.is_train:
xmean = jt.mean(x, dims=[0], keepdims=1)
x2mean = jt.mean(x*x, dims=[0], keepdims=1)
if self.sync and jt.mpi:
xmean = xmean.mpi_all_reduce("mean")
x2mean = x2mean.mpi_all_reduce("mean")
xvar = x2mean-xmean*xmean
norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)
self.running_mean += (xmean.sum([0])-self.running_mean)*self.momentum
self.running_var += (xvar.sum([0])-self.running_var)*self.momentum
else:
running_mean = self.running_mean.broadcast(x, [0])
running_var = self.running_var.broadcast(x, [0])
norm_x = (x-running_mean)/jt.sqrt(running_var+self.eps)
w = self.weight.broadcast(x, [0])
b = self.bias.broadcast(x, [0])
return norm_x * w + b
class InstanceNorm2d(Module):
def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=None, is_train=True, sync=True):
assert affine == None
self.sync = sync
self.num_features = num_features
self.is_train = is_train
self.eps = eps
self.momentum = momentum
self.weight = init.constant((num_features,), "float32", 1.0)
self.bias = init.constant((num_features,), "float32", 0.0)
def execute(self, x):
xmean = jt.mean(x, dims=[2,3], keepdims=1)
x2mean = jt.mean(x*x, dims=[2,3], keepdims=1)
if self.sync and jt.mpi:
xmean = xmean.mpi_all_reduce("mean")
x2mean = x2mean.mpi_all_reduce("mean")
xvar = jt.maximum(x2mean-xmean*xmean, 0)
norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)
w = self.weight.broadcast(x, [0,2,3])
b = self.bias.broadcast(x, [0,2,3])
return norm_x * w + b
Relu = jt.make_module(relu)
ReLU = Relu
@ -459,6 +521,16 @@ class ReplicationPad2d(Module):
f"i3<{l} ? 0 : i3 > {r} ? {w-1} : i3-{l}"
])
class Embedding(Module):
def __init__(self, num, dim):
self.num = num
self.dim = dim
self.weight = jt.init.gauss([num,dim],'float32').stop_grad()
def execute(self, x):
res = self.weight[x].reshape([x.shape[0],self.dim])
return res
class PixelShuffle(Module):
def __init__(self, upscale_factor):
self.upscale_factor = upscale_factor
@ -521,7 +593,7 @@ class Upsample(Module):
class Sequential(Module):
def __init__(self, *args):
self.layers = args
self.layers = list(args)
def __getitem__(self, idx):
return self.layers[idx]
def execute(self, x):
@ -539,3 +611,7 @@ class Sequential(Module):
parents.pop()
if callback_leave:
callback_leave(parents, k, self, n_children)
def append(self, mod):
self.layers.append(mod)
ModuleList = Sequential

View File

@ -132,6 +132,39 @@ class SGD(Optimizer):
p -= v * lr
p.detach_inplace()
class RMSprop(Optimizer):
""" RMSprop Optimizer.
Example:
```
optimizer = nn.RMSprop(model.parameters(), lr, eps=1e-8, betas=(0.9, 0.999))
optimizer.step(loss)
```
"""
def __init__(self, params, lr=1e-2, eps=1e-8, alpha=0.99):
# def __init__(self, params, lr, eps=1e-8, betas=(0.9, 0.999), weight_decay=0):
super().__init__(params, lr)
self.eps = eps
self.alpha = alpha
# initialize required arguments for each param_groups
for pg in self.param_groups:
values = pg["values"] = []
for p in pg["params"]:
values.append(jt.zeros(p.shape, p.dtype).stop_fuse().stop_grad())
def step(self, loss):
self.pre_step(loss)
for pg in self.param_groups:
# get arguments from each param_groups
lr = pg.get("lr", self.lr)
eps = pg.get("eps", self.eps)
alpha = pg.get("alpha", self.alpha)
for p, g, v in zip(pg["params"], pg["grads"], pg["values"]):
if p.is_stop_grad(): continue
v.assign(alpha * v + (1-alpha) * g * g)
p -= lr * g / (jt.sqrt(v) + eps)
p.detach_inplace()
class Adam(Optimizer):
""" Adam Optimizer.
Example: