diff --git a/python/jittor/nn.py b/python/jittor/nn.py
index 5ac11954..5c13fcee 100644
--- a/python/jittor/nn.py
+++ b/python/jittor/nn.py
@@ -154,6 +154,7 @@ def get_init_var_rand(shape, dtype):
 def relu(x): return jt.ternary((x>0.0), x, jt.broadcast_var(0.0, x))
 def leaky_relu(x, scale=0.01): return jt.ternary(x>0, x, x*scale)
 def relu6(x): return jt.minimum(jt.maximum(x, 0.0), 6.0)
+def elu(x, alpha=1.0): return jt.ternary(x>0, x, alpha*(x.exp()-1))
 def sign(x):
     one = jt.ones(x.shape)
     x = jt.ternary(x>0, one, x)
@@ -165,6 +166,13 @@ def gelu(x):
     r = erf*x*.5
     return r
 
+class ELU(Module):
+    def __init__(self, alpha=1.0):
+        self.alpha = alpha
+
+    def execute(self, x):
+        return elu(x, self.alpha)
+
 class PReLU(Module):
     def __init__(self, num_parameters=1, init_=0.25):
         self.num_parameters = num_parameters
@@ -238,6 +246,30 @@ def smooth_l1_loss(y_true, y_pred,reduction="mean"):
     else:
         raise ValueError(f'not support {reduction}')
 
+def nll_loss(output, target, weight=None, ignore_index=-100, reduction='mean'):
+    assert output.ndim<=2 and output.ndim>0 and target.ndim==1
+    n_classes = output.shape[-1]
+    assert weight is None or weight.numel()==n_classes
+    assert ignore_index<0 or ignore_index<n_classes
+    if weight is None:
+        weight = jt.ones((n_classes,))
+    if ignore_index>=0:
+        weight[ignore_index] = 0
+    if output.ndim==2:
+        index = jt.index((output.shape[0],), dim=0)
+        loss = -output[index, target]*weight[target]
+    else:
+        loss = -output[target[0]]*weight[target[0]]
+    if reduction=="mean":
+        total_weight = weight[target].sum() if output.ndim==2 else weight[target[0]].sum()
+        return loss.sum()/total_weight
+    elif reduction=="sum":
+        return loss.sum()
+    elif reduction=="none":
+        return loss
+    else:
+        raise ValueError(f'not support {reduction}')
+
 class CrossEntropyLoss(Module):
     def __init__(self,ignore_index=None):
         self.ignore_index = ignore_index
@@ -330,6 +362,9 @@ class Dropout(Module):
         output = output * noise / (1.0 - self.p)  # div keep prob
         return output
 
+def dropout(x, p=0.5, is_train=False):
+    return Dropout(p=p, is_train=is_train)(x)
+
 class Linear(Module):
     def __init__(self, in_features, out_features, bias=True):
         self.in_features = in_features
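
For reference, a minimal usage sketch of the elu/ELU additions. This assumes a working Jittor install; the input values are made up for illustration.

import jittor as jt
from jittor import nn

# Toy input; ELU(x) = x for x > 0 and alpha*(exp(x) - 1) otherwise.
x = jt.array([-2.0, -0.5, 0.0, 1.5])

y = nn.elu(x, alpha=1.0)        # functional form added in this diff
layer = nn.ELU(alpha=1.0)       # module form, usable inside nn.Sequential
assert (layer(x).numpy() == y.numpy()).all()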
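The new nll_loss follows the semantics of torch.nn.functional.nll_loss: output holds per-class scores (typically log-probabilities) with shape (N, C) or (C,), target holds class indices, and the 'mean' reduction divides by the total weight of the selected classes rather than the batch size. A sketch with made-up values:

import jittor as jt
from jittor import nn

# Made-up log-probabilities for a batch of 3 samples over 4 classes.
log_probs = jt.array([[-0.2, -2.0, -3.0, -2.5],
                      [-1.5, -0.4, -2.2, -3.0],
                      [-2.0, -2.5, -0.3, -1.8]])
target = jt.array([0, 1, 2])

mean_loss  = nn.nll_loss(log_probs, target)                   # default: 'mean'
per_sample = nn.nll_loss(log_probs, target, reduction='none')

# Per-class weights rescale each sample's loss; ignore_index zeroes that
# class's weight, so ignored samples drop out of the weighted mean.
w = jt.ones((4,))
loss_skip0 = nn.nll_loss(log_probs, target, weight=w, ignore_index=0)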
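The functional dropout wrapper constructs a Dropout module per call and defaults to is_train=False, so it is an identity at inference; during training it zeroes elements with probability p and scales the survivors by 1/(1-p), matching the module's "div keep prob" step. A brief sketch, under the same Jittor-install assumption:

import jittor as jt
from jittor import nn

x = jt.ones((2, 5))
y_eval  = nn.dropout(x)                        # is_train=False: identity
y_train = nn.dropout(x, p=0.5, is_train=True)  # ~half zeroed, rest scaled by 1/(1-p)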