From 59e49b064d06586ee8255337f5e36ebd64b58435 Mon Sep 17 00:00:00 2001 From: lidongyang Date: Wed, 31 Aug 2022 19:47:33 +0800 Subject: [PATCH] fix trunc_normal_ & add some inplace function --- python/jittor/__init__.py | 32 ++++++++++++++++++++++++++++++++ python/jittor/init.py | 2 +- 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py index b5c30ab2..b51279ff 100644 --- a/python/jittor/__init__.py +++ b/python/jittor/__init__.py @@ -617,6 +617,38 @@ def clamp(x, min_v=None, max_v=None): Var.clamp = clamp +def clamp_(x, min_v=None, max_v=None): + return x.assign(x.clamp(min_v=min_v, max_v=max_v)) +Var.clamp_ = clamp_ + +def erfinv_(x): + return x.assign(x.erfinv()) +Var.erfinv_ = erfinv_ + +def erf_(x): + return x.assign(x.erf()) +Var.erf_ = erf_ + +def abs_(x): + return x.assign(x.abs()) +Var.abs_ = abs_ + +def sigmoid_(x): + return x.assign(x.sigmoid()) +Var.sigmoid_ = sigmoid_ + +def sqrt_(x): + return x.assign(x.sqrt()) +Var.sqrt_ = sqrt_ + +def add_(x, y): + return x.assign(x.add(y)) +Var.add_ = add_ + +def multiply_(x, y): + return x.assign(x.multiply(y)) +Var.multiply_ = multiply_ + def type_as(a, b): return a.unary(op=b.dtype) Var.type_as = type_as diff --git a/python/jittor/init.py b/python/jittor/init.py index 1943cd12..6338fa01 100644 --- a/python/jittor/init.py +++ b/python/jittor/init.py @@ -691,7 +691,7 @@ def trunc_normal_(var, mean=0., std=1., a=-2., b=2.): print(linear.weight) linear.weight.trunc_normal_(std=.02) # This is ok too """ - return _no_grad_trunc_normal_(var, mean, std, a, b) + return var.assign(_no_grad_trunc_normal_(var, mean, std, a, b)) Var.trunc_normal_ = trunc_normal_ def _no_grad_trunc_normal_(var, mean, std, a, b):