mirror of https://github.com/Jittor/Jittor
fix trunc_normal_ & add some inplace function
This commit is contained in:
parent
7164b1cc0f
commit
59e49b064d
|
@ -617,6 +617,38 @@ def clamp(x, min_v=None, max_v=None):
|
||||||
|
|
||||||
Var.clamp = clamp
|
Var.clamp = clamp
|
||||||
|
|
||||||
|
def clamp_(x, min_v=None, max_v=None):
|
||||||
|
return x.assign(x.clamp(min_v=min_v, max_v=max_v))
|
||||||
|
Var.clamp_ = clamp_
|
||||||
|
|
||||||
|
def erfinv_(x):
|
||||||
|
return x.assign(x.erfinv())
|
||||||
|
Var.erfinv_ = erfinv_
|
||||||
|
|
||||||
|
def erf_(x):
|
||||||
|
return x.assign(x.erf())
|
||||||
|
Var.erf_ = erf_
|
||||||
|
|
||||||
|
def abs_(x):
|
||||||
|
return x.assign(x.abs())
|
||||||
|
Var.abs_ = abs_
|
||||||
|
|
||||||
|
def sigmoid_(x):
|
||||||
|
return x.assign(x.sigmoid())
|
||||||
|
Var.sigmoid_ = sigmoid_
|
||||||
|
|
||||||
|
def sqrt_(x):
|
||||||
|
return x.assign(x.sqrt())
|
||||||
|
Var.sqrt_ = sqrt_
|
||||||
|
|
||||||
|
def add_(x, y):
|
||||||
|
return x.assign(x.add(y))
|
||||||
|
Var.add_ = add_
|
||||||
|
|
||||||
|
def multiply_(x, y):
|
||||||
|
return x.assign(x.multiply(y))
|
||||||
|
Var.multiply_ = multiply_
|
||||||
|
|
||||||
def type_as(a, b):
|
def type_as(a, b):
|
||||||
return a.unary(op=b.dtype)
|
return a.unary(op=b.dtype)
|
||||||
Var.type_as = type_as
|
Var.type_as = type_as
|
||||||
|
|
|
@ -691,7 +691,7 @@ def trunc_normal_(var, mean=0., std=1., a=-2., b=2.):
|
||||||
print(linear.weight)
|
print(linear.weight)
|
||||||
linear.weight.trunc_normal_(std=.02) # This is ok too
|
linear.weight.trunc_normal_(std=.02) # This is ok too
|
||||||
"""
|
"""
|
||||||
return _no_grad_trunc_normal_(var, mean, std, a, b)
|
return var.assign(_no_grad_trunc_normal_(var, mean, std, a, b))
|
||||||
Var.trunc_normal_ = trunc_normal_
|
Var.trunc_normal_ = trunc_normal_
|
||||||
|
|
||||||
def _no_grad_trunc_normal_(var, mean, std, a, b):
|
def _no_grad_trunc_normal_(var, mean, std, a, b):
|
||||||
|
|
Loading…
Reference in New Issue