mirror of https://github.com/Jittor/Jittor
add softplus
This commit is contained in:
parent
cb52493db9
commit
98545bede4
|
@ -249,6 +249,9 @@ def full(shape,val,dtype="float32"):
|
|||
shape = (shape,)
|
||||
return unary(val, dtype).broadcast(shape)
|
||||
|
||||
def full_like(x,val):
    '''Return a var with the same shape and dtype as ``x``, filled with ``val``.

    Mirrors torch.full_like; delegates to the module-level ``full``.
    '''
    return full(x.shape,val,x.dtype)
|
||||
|
||||
def zeros_like(x):
    '''Return a zero-filled var with the same shape and dtype as ``x``.

    Mirrors torch.zeros_like; delegates to the module-level ``zeros``.
    '''
    return zeros(x.shape,x.dtype)
|
||||
|
||||
|
|
|
@ -98,6 +98,24 @@ def expand(x, shape):
|
|||
jt.Var.expand = expand
|
||||
|
||||
|
||||
def t(x):
    '''Return ``x`` with its last two dimensions swapped.

    Generalizes torch.Tensor.t-style transposition to inputs with more
    than two dimensions by permuting only the trailing axis pair.
    '''
    perm = list(range(x.ndim))
    perm[-1], perm[-2] = perm[-2], perm[-1]
    return x.transpose(*perm)
jt.Var.t = t
|
||||
|
||||
def any(x,dim=0,keepdim=False):
    '''Reduce along ``dim``: true where at least one entry is nonzero.

    Implemented by summing the int-cast input over ``dim`` and testing
    the count against zero. Intentionally shadows the builtin name so it
    can be attached as ``jt.Var.any``.
    '''
    nonzero_count = x.int().sum(dim, keepdims=keepdim)
    return nonzero_count > 0
jt.Var.any = any
|
||||
|
||||
def all(x,dim=0,keepdim=False):
    '''Reduce along ``dim``: true where every entry is nonzero.

    Sums the int-cast input over ``dim`` and compares the count with the
    size of that dimension.
    # NOTE(review): this equals-count test assumes the int cast yields
    # 0/1 values (boolean-like input) — entries casting to >=2 would
    # break the comparison; confirm intended input domain.
    '''
    nonzero_count = x.int().sum(dim, keepdims=keepdim)
    return nonzero_count == x.shape[dim]

jt.Var.all = all
|
||||
|
||||
|
||||
def median(x,dim=None,keepdim=False):
|
||||
if dim is None:
|
||||
x = x.reshape(-1)
|
||||
|
@ -406,6 +424,8 @@ def meshgrid(*tensors):
|
|||
Take N tensors, each of which can be 1-dimensional vector, and create N n-dimensional grids,
|
||||
where the i th grid is defined by expanding the i th input over dimensions defined by other inputs.
|
||||
'''
|
||||
if len(tensors)==1 and isinstance(tensors[0], list):
|
||||
tensors = tensors[0]
|
||||
size = len(tensors)
|
||||
shape = []
|
||||
for i in range(size):
|
||||
|
@ -470,6 +490,25 @@ def tolist(x):
|
|||
return x.numpy().tolist()
|
||||
jt.Var.tolist = tolist
|
||||
|
||||
def view_as(x,y):
    '''Reshape ``x`` to the shape of ``y`` (torch.Tensor.view_as analog).

    ``y`` is only read for its shape; its values are never used.
    '''
    return x.reshape(y.shape)
jt.Var.view_as = view_as
|
||||
|
||||
def diag(x,diagonal=0):
    '''Build or extract a diagonal, mirroring torch.diag.

    For a 1-D input, returns a square matrix with ``x`` on the
    ``diagonal``-th diagonal (zeros elsewhere). For a 2-D square input,
    returns the ``diagonal``-th diagonal as a 1-D var.

    :param x: 1-D var, or 2-D var with equal dims (enforced by assert).
    :param diagonal: 0 = main diagonal, >0 above it, <0 below it.
    '''
    assert x.ndim==1 or (x.ndim==2 and x.shape[0]==x.shape[1])
    # d is |diagonal|; d_str is the signed offset used inside the
    # reindex index-expression strings below.
    d = diagonal if diagonal>=0 else -diagonal
    d_str = f'+{diagonal}' if diagonal>=0 else f'{diagonal}'

    if x.ndim==1:
        # Embed the length-n vector into an (n+|d|)^2 matrix: cells off
        # the chosen diagonal hit the overflow condition and read as 0.
        output_shape = (x.shape[0]+d,)*2
        return x.reindex(output_shape,[f'i1-{d}' if diagonal>=0 else f'i0-{d}'],overflow_conditions=[f'i0{d_str}!=i1'])
    else:
        # Read the length-(n-|d|) diagonal out of the square matrix by
        # offsetting the row index (diagonal<0) or column index (>0).
        output_shape = (x.shape[0]-d,)
        return x.reindex(output_shape,[f'i0+{d}' if diagonal<=0 else 'i0',f'i0+{d}' if diagonal>=0 else 'i0'])

jt.Var.diag = diag
|
||||
|
||||
|
||||
def topk(input, k, dim=None, largest=True, sorted=True):
|
||||
if input.numel()==0:
|
||||
return jt.array([],dtype=input.dtype),jt.array([],dtype='int32')
|
||||
|
|
|
@ -881,6 +881,13 @@ class Sigmoid(Module):
|
|||
def execute(self, x) :
    # Forward pass: apply the elementwise logistic sigmoid to the input.
    return x.sigmoid()
|
||||
|
||||
def softplus(x, beta=1, threshold=20):
    '''Softplus activation: ``1/beta * log(1 + exp(beta * x))``.

    A smooth, always-positive approximation of ReLU, matching the
    signature of torch.nn.functional.softplus.

    :param x: input var.
    :param beta: scale applied inside the exponent. Fixes the original
        signature's typo ``bata``, which left the body's ``beta`` an
        undefined name and made every call raise NameError.
    :param threshold: accepted for API parity with PyTorch.
        # NOTE(review): not yet applied — PyTorch reverts to the linear
        # function x where beta*x > threshold to avoid exp() overflow;
        # confirm whether that stabilization should be added here.
    '''
    return 1 / beta * jt.log(1 + (beta * x).exp())
|
||||
|
||||
def hardtanh(x,min_val=-1,max_val=1):
    '''Clamp ``x`` elementwise to [min_val, max_val] (torch hardtanh analog).'''
    return jt.clamp(x,min_v=min_val,max_v=max_val)
|
||||
|
||||
|
||||
class Softplus(Module):
|
||||
r'''
|
||||
SoftPlus is a smooth approximation to the ReLU function and can be used to constrain the output of a machine to always be positive.
|
||||
|
@ -1219,6 +1226,9 @@ class Sequential(Module):
|
|||
else:
|
||||
self.append(mod)
|
||||
def __getitem__(self, idx):
    '''Fetch a submodule: by its key when ``idx`` is a known layer key,
    otherwise treat ``idx`` as a positional index into the layers in
    insertion order.'''
    if idx in self.layers:
        return self.layers[idx]
    return list(self.layers.values())[idx]
|
||||
def __iter__(self):
    '''Iterate over the contained submodules in insertion order.'''
    return iter(self.layers.values())
|
||||
|
|
Loading…
Reference in New Issue