add softplus

This commit is contained in:
li-xl 2020-11-08 16:04:12 +08:00
parent cb52493db9
commit 98545bede4
3 changed files with 52 additions and 0 deletions

View File

@ -249,6 +249,9 @@ def full(shape,val,dtype="float32"):
shape = (shape,)
return unary(val, dtype).broadcast(shape)
def full_like(x,val):
return full(x.shape,val,x.dtype)
def zeros_like(x):
return zeros(x.shape,x.dtype)

View File

@ -98,6 +98,24 @@ def expand(x, shape):
jt.Var.expand = expand
def t(x):
pose = [i for i in range(x.ndim)]
pose[-1], pose[-2] = pose[-2], pose[-1]
return x.transpose(*pose)
jt.Var.t = t
def any(x,dim=0,keepdim=False):
y = x.int().sum(dim,keepdims=keepdim)
return y>0
jt.Var.any = any
def all(x,dim=0,keepdim=False):
y = x.int().sum(dim,keepdims=keepdim)
return y==x.shape[dim]
jt.Var.all = all
def median(x,dim=None,keepdim=False):
if dim is None:
x = x.reshape(-1)
@ -406,6 +424,8 @@ def meshgrid(*tensors):
Take N tensors, each of which can be 1-dimensional vector, and create N n-dimensional grids,
where the i th grid is defined by expanding the i th input over dimensions defined by other inputs.
'''
if len(tensors)==1 and isinstance(tensors[0], list):
tensors = tensors[0]
size = len(tensors)
shape = []
for i in range(size):
@ -470,6 +490,25 @@ def tolist(x):
return x.numpy().tolist()
jt.Var.tolist = tolist
def view_as(x,y):
return x.reshape(y.shape)
jt.Var.view_as = view_as
def diag(x,diagonal=0):
assert x.ndim==1 or (x.ndim==2 and x.shape[0]==x.shape[1])
d = diagonal if diagonal>=0 else -diagonal
d_str = f'+{diagonal}' if diagonal>=0 else f'{diagonal}'
if x.ndim==1:
output_shape = (x.shape[0]+d,)*2
return x.reindex(output_shape,[f'i1-{d}' if diagonal>=0 else f'i0-{d}'],overflow_conditions=[f'i0{d_str}!=i1'])
else:
output_shape = (x.shape[0]-d,)
return x.reindex(output_shape,[f'i0+{d}' if diagonal<=0 else 'i0',f'i0+{d}' if diagonal>=0 else 'i0'])
jt.Var.diag = diag
def topk(input, k, dim=None, largest=True, sorted=True):
if input.numel()==0:
return jt.array([],dtype=input.dtype),jt.array([],dtype='int32')

View File

@ -881,6 +881,13 @@ class Sigmoid(Module):
def execute(self, x) :
return x.sigmoid()
def softplus(x,bata=1,threshold=20):
return 1 / beta * jt.log(1 + (beta * x).exp())
def hardtanh(x,min_val=-1,max_val=1):
return jt.clamp(x,min_v=min_val,max_v=max_val)
class Softplus(Module):
r'''
SoftPlus is a smooth approximation to the ReLU function and can be used to constrain the output of a machine to always be positive.
@ -1219,6 +1226,9 @@ class Sequential(Module):
else:
self.append(mod)
def __getitem__(self, idx):
if idx not in self.layers:
return list(self.layers.values())[idx]
return self.layers[idx]
def __iter__(self):
return self.layers.values().__iter__()