mirror of https://github.com/Jittor/Jittor
Memory Leak
This commit is contained in:
parent 87f51fc13d
commit f38058bbe7
@@ -326,6 +326,66 @@ class BatchNorm(Module):
        w = self.weight.broadcast(x, [0,2,3])
        b = self.bias.broadcast(x, [0,2,3])
        return norm_x * w + b

class BatchNorm1d(Module):
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=None, is_train=True, sync=True):
        assert affine == None
        self.sync = sync
        self.num_features = num_features
        self.is_train = is_train
        self.eps = eps
        self.momentum = momentum
        self.weight = init.constant((num_features,), "float32", 1.0)
        self.bias = init.constant((num_features,), "float32", 0.0)
        self.running_mean = init.constant((num_features,), "float32", 0.0).stop_grad()
        self.running_var = init.constant((num_features,), "float32", 1.0).stop_grad()

    def execute(self, x):
        if self.is_train:
            xmean = jt.mean(x, dims=[0], keepdims=1)
            x2mean = jt.mean(x*x, dims=[0], keepdims=1)

            if self.sync and jt.mpi:
                xmean = xmean.mpi_all_reduce("mean")
                x2mean = x2mean.mpi_all_reduce("mean")

            xvar = x2mean-xmean*xmean
            norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)
            self.running_mean += (xmean.sum([0])-self.running_mean)*self.momentum
            self.running_var += (xvar.sum([0])-self.running_var)*self.momentum
        else:
            running_mean = self.running_mean.broadcast(x, [0])
            running_var = self.running_var.broadcast(x, [0])
            norm_x = (x-running_mean)/jt.sqrt(running_var+self.eps)
        w = self.weight.broadcast(x, [0])
        b = self.bias.broadcast(x, [0])
        return norm_x * w + b
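
For reference (not part of the diff): a minimal usage sketch of the new BatchNorm1d, assuming it is exposed as jittor.nn.BatchNorm1d. Statistics are reduced over dim 0 only, so the module expects a 2D input of shape [batch, num_features].

import jittor as jt
from jittor import nn

bn = nn.BatchNorm1d(8)        # weight=1, bias=0 by default; running stats start at 0/1
x = jt.random([16, 8])        # [batch, num_features]
y = bn(x)                     # is_train=True: normalize with batch stats, update running stats
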
class InstanceNorm2d(Module):
    def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=None, is_train=True, sync=True):
        assert affine == None
        self.sync = sync
        self.num_features = num_features
        self.is_train = is_train
        self.eps = eps
        self.momentum = momentum
        self.weight = init.constant((num_features,), "float32", 1.0)
        self.bias = init.constant((num_features,), "float32", 0.0)
        self.running_mean = init.constant((num_features,), "float32", 0.0).stop_grad()
        self.running_var = init.constant((num_features,), "float32", 1.0).stop_grad()

    def execute(self, x):
        xmean = jt.mean(x, dims=[2,3], keepdims=1)
        x2mean = jt.mean(x*x, dims=[2,3], keepdims=1)
        if self.sync and jt.mpi:
            xmean = xmean.mpi_all_reduce("mean")
            x2mean = x2mean.mpi_all_reduce("mean")

        xvar = x2mean-xmean*xmean
        norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)
        w = self.weight.broadcast(x, [0,2,3])
        b = self.bias.broadcast(x, [0,2,3])
        return norm_x * w + b
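
Similarly for InstanceNorm2d (again a sketch, assuming exposure as jittor.nn.InstanceNorm2d): means and variances are taken per sample and per channel over the spatial dims [2,3]. Note that running_mean and running_var are initialized here but never read in execute.

import jittor as jt
from jittor import nn

inorm = nn.InstanceNorm2d(3)
x = jt.random([4, 3, 32, 32])   # [N, C, H, W]
y = inorm(x)                    # each (sample, channel) plane is normalized independently
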
Relu = jt.make_module(relu)
ReLU = Relu
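
jt.make_module lifts the functional relu into a Module class, so the op is usable in both styles; a small sketch (not part of the diff):

import jittor as jt
from jittor import nn

x = jt.array([-1.0, 0.5, 2.0])
y1 = nn.relu(x)       # functional form
y2 = nn.ReLU()(x)     # module form generated by jt.make_module(relu)
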
@@ -608,6 +668,47 @@ class ReplicationPad2d(Module):
                y_idx[j,i] = b
        return x.reindex([n,c,oh,ow], ["i0","i1","@e1(i2,i3)","@e0(i2,i3)"], extras=[jt.array(x_idx - self.pl), jt.array(y_idx - self.pt)])

class BCELoss(Module):
    def __init__(self):
        pass
    def execute(self, output, target):
        return -(target*jt.log(jt.maximum(output,1e-20))+(1-target)*jt.log(jt.maximum(1-output,1e-20))).mean()

class BCEWithLogitsLoss(Module):
    def __init__(self):
        pass
    def execute(self, output, target):
        x = 1 / (1 + jt.exp(-output))
        return -(target*jt.log(jt.maximum(x,1e-20))+(1-target)*jt.log(jt.maximum(1-x,1e-20))).mean()
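
Both BCE variants clamp the arguments of jt.log at 1e-20 to avoid log(0). BCEWithLogitsLoss applies an explicit sigmoid rather than a fused log-sum-exp formulation, so very large-magnitude logits can still saturate before the clamp. A usage sketch (not part of the diff):

import jittor as jt
from jittor import nn

logits = jt.array([2.0, -1.5, 0.3])
target = jt.array([1.0, 0.0, 1.0])

loss = nn.BCEWithLogitsLoss()(logits, target)   # sigmoid applied internally

probs = 1 / (1 + jt.exp(-logits))               # same sigmoid as in execute above
same = nn.BCELoss()(probs, target)              # identical value by construction
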
class L1Loss(Module):
    def __init__(self):
        pass
    def execute(self, output, target):
        return jt.abs(output - target).mean()

class MSELoss(Module):
    def __init__(self):
        pass
    def execute(self, output, target):
        return (output - target).sqr().mean()

class CrossEntropyLoss(Module):
    def __init__(self):
        pass
    def execute(self, output, target, ignore_index=None):
        return cross_entropy_loss(output, target, ignore_index=ignore_index)
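
These three wrappers keep no state and simply defer to elementwise ops or to the existing cross_entropy_loss function; usage is uniform across them (a sketch, not part of the diff):

import jittor as jt
from jittor import nn

pred = jt.random([4, 10])
target = jt.random([4, 10])
l1 = nn.L1Loss()(pred, target)      # mean absolute error
mse = nn.MSELoss()(pred, target)    # mean squared error

logits = jt.random([4, 10])
labels = jt.array([1, 3, 0, 7])     # integer class ids
ce = nn.CrossEntropyLoss()(logits, labels)
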
class Embedding(Module):
    def __init__(self, num, dim):
        self.num = num
        self.dim = dim
        self.weight = jt.init.gauss([num,dim],'float32').stop_grad()

    def execute(self, x):
        res = self.weight[x].reshape([x.shape[0],self.dim])
        return res
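
Note that Embedding's weight table is created with .stop_grad(), so it is not trainable as written, and the reshape to [x.shape[0], dim] assumes a 1D index input. A sketch (not part of the diff):

import jittor as jt
from jittor import nn

emb = nn.Embedding(100, 16)     # 100 ids, 16-dim vectors
ids = jt.array([3, 7, 7, 42])   # 1D indices
vecs = emb(ids)                 # shape [4, 16]; no gradient flows into emb.weight
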
class PixelShuffle(Module):
    def __init__(self, upscale_factor):
        self.upscale_factor = upscale_factor
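
The hunk is truncated before PixelShuffle.execute. For context only, pixel shuffle rearranges [N, C*r*r, H, W] into [N, C, H*r, W*r]; a reshape/transpose illustration of the operation (not the committed implementation):

import jittor as jt

def pixel_shuffle_sketch(x, r):
    # [N, C*r*r, H, W] -> [N, C, H*r, W*r]
    n, crr, h, w = x.shape
    c = crr // (r * r)
    x = x.reshape([n, c, r, r, h, w])
    x = jt.transpose(x, (0, 1, 4, 2, 5, 3))   # -> [n, c, h, r, w, r]
    return x.reshape([n, c, h * r, w * r])
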