fix trunc_normal_ & add some inplace function

This commit is contained in:
lidongyang 2022-08-31 19:47:33 +08:00 committed by Zheng-Ning Liu
parent 7164b1cc0f
commit 59e49b064d
2 changed files with 33 additions and 1 deletions

View File

@ -617,6 +617,38 @@ def clamp(x, min_v=None, max_v=None):
Var.clamp = clamp Var.clamp = clamp
def clamp_(x, min_v=None, max_v=None):
return x.assign(x.clamp(min_v=min_v, max_v=max_v))
Var.clamp_ = clamp_
def erfinv_(x):
return x.assign(x.erfinv())
Var.erfinv_ = erfinv_
def erf_(x):
return x.assign(x.erf())
Var.erf_ = erf_
def abs_(x):
return x.assign(x.abs())
Var.abs_ = abs_
def sigmoid_(x):
return x.assign(x.sigmoid())
Var.sigmoid_ = sigmoid_
def sqrt_(x):
return x.assign(x.sqrt())
Var.sqrt_ = sqrt_
def add_(x, y):
return x.assign(x.add(y))
Var.add_ = add_
def multiply_(x, y):
return x.assign(x.multiply(y))
Var.multiply_ = multiply_
def type_as(a, b): def type_as(a, b):
return a.unary(op=b.dtype) return a.unary(op=b.dtype)
Var.type_as = type_as Var.type_as = type_as

View File

@ -691,7 +691,7 @@ def trunc_normal_(var, mean=0., std=1., a=-2., b=2.):
print(linear.weight) print(linear.weight)
linear.weight.trunc_normal_(std=.02) # This is ok too linear.weight.trunc_normal_(std=.02) # This is ok too
""" """
return _no_grad_trunc_normal_(var, mean, std, a, b) return var.assign(_no_grad_trunc_normal_(var, mean, std, a, b))
Var.trunc_normal_ = trunc_normal_ Var.trunc_normal_ = trunc_normal_
def _no_grad_trunc_normal_(var, mean, std, a, b): def _no_grad_trunc_normal_(var, mean, std, a, b):