add safe_clip and safe_log

This commit is contained in:
Dun Liang 2021-06-11 14:11:58 +08:00
parent fd798d9925
commit 826513a156
11 changed files with 135 additions and 22 deletions

View File

@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.2.3.24'
__version__ = '1.2.3.25'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int

View File

@ -1049,7 +1049,7 @@ if os.path.isfile(version_file) and not os.path.isdir(os.path.join(jittor_path,
compile(cc_path, cc_flags+opt_flags, files, 'jittor_core'+extension_suffix)
# TODO: move to compile_extern.py
compile_extern()
# compile_extern()
with jit_utils.import_scope(import_flags):
import jittor_core as core

View File

@ -34,7 +34,7 @@ class OneHotCategorical:
# cannot align to pytorch
probs = jt.sigmoid(logits)
elif logits is None:
logits = jt.log(probs)
logits = jt.safe_log(probs)
with jt.no_grad():
self.probs = probs / probs.sum(-1, True)
self.cum_probs = simple_presum(self.probs)
@ -69,7 +69,7 @@ class Categorical:
# cannot align to pytorch
probs = jt.sigmoid(logits)
elif logits is None:
logits = jt.log(probs)
logits = jt.safe_log(probs)
with jt.no_grad():
self.probs = probs / probs.sum(-1, True)
self.logits = logits
@ -85,7 +85,7 @@ class Categorical:
return (one_hot * index).sum(-1)
def log_prob(self, x):
return jt.log(self.probs)[0,x]
return jt.safe_log(self.probs)[0,x]
def entropy(self):
min_real = -(math.pow(2,23)-1) / math.pow(2,22) * math.pow(2,127)
@ -104,11 +104,11 @@ class Normal:
def log_prob(self, x):
var = self.sigma**2
log_scale = jt.log(self.sigma)
log_scale = jt.safe_log(self.sigma)
return -((x-self.mu)**2) / (2*var) - log_scale-np.log(np.sqrt(2*np.pi))
def entropy(self):
return 0.5+0.5*np.log(2*np.pi)+jt.log(self.sigma)
return 0.5+0.5*np.log(2*np.pi)+jt.safe_log(self.sigma)
class Uniform:
@ -123,10 +123,10 @@ class Uniform:
def log_prob(self,x):
    '''Log-density of the uniform distribution at x.

    Outside the half-open interval [low, high) the density is zero, so the
    log-density is -inf. Inside it the value is -log(high - low), computed
    through jt.safe_log.

    :param x: point at which the log-density is evaluated.
    :return: -math.inf when x is outside [low, high), else -log(high - low).
    '''
    if x < self.low or x >= self.high:
        # log(0) is -inf; +inf would claim infinite likelihood for
        # a point that the distribution can never produce.
        return -math.inf
    return -jt.safe_log(self.high - self.low)
def entropy(self):
return jt.log(self.high - self.low)
return jt.safe_log(self.high - self.low)
class Geometric:
@ -138,15 +138,15 @@ class Geometric:
self.logits = logits
elif logits is None:
self.prob = p
self.logits = -jt.log(1. / p - 1)
self.logits = -jt.safe_log(1. / p - 1)
def sample(self, sample_shape):
tiny = jt.info(self.probs.dtype).tiny
u = jt.clamp(jt.rand(sample_shape),min_v=tiny)
return (jt.log(u) / (jt.log(-self.probs+1))).floor()
return (jt.safe_log(u) / (jt.safe_log(-self.probs+1))).floor()
def log_prob(self, x):
return x*jt.log(-self.prob+1)+jt.log(self.prob)
return x*jt.safe_log(-self.prob+1)+jt.safe_log(self.prob)
def entropy(self):
return binary_cross_entropy_with_logits(jt.array(self.logits),jt.array(self.prob)) / self.prob
@ -157,16 +157,20 @@ def kl_divergence(cur_dist, old_dist):
if isinstance(cur_dist, Normal):
vr = (cur_dist.sigma / old_dist.sigma)**2
t1 = ((cur_dist.mu - old_dist.mu) / old_dist.sigma)**2
return 0.5*(vr+t1-1-jt.log(vr))
return 0.5*(vr+t1-1-jt.safe_log(vr))
if isinstance(cur_dist, Categorical) or isinstance(cur_dist,OneHotCategorical):
t = cur_dist.probs * (cur_dist.logits-old_dist.logits)
t[jt.array((old_dist.probs == 0))] = math.inf
t[jt.array((cur_dist.probs == 0))] = 0
# print("t:", t)
# print("old_dist.probs:", old_dist.probs)
# print("old_dist.probs:", (old_dist.probs==0).sum())
# print("cur_dist.probs:", cur_dist.probs)
# t[jt.array((old_dist.probs == 0))] = math.inf
# t[jt.array((cur_dist.probs == 0))] = 0
return t.sum(-1)
if isinstance(cur_dist, Uniform):
res = jt.log((old_dist.high - old_dist.low) / (cur_dist.high - cur_dist.low))
res = jt.safe_log((old_dist.high - old_dist.low) / (cur_dist.high - cur_dist.low))
if old_dist.low > cur_dist.low or old_dist.high < cur_dist.high:
res = math.inf
return res
if isinstance(cur_dist, Geometric):
return -cur_dist.entropy() - jt.log(-old_dist.prob+1) / cur_dist.prob - old_dist.logits
return -cur_dist.entropy() - jt.safe_log(-old_dist.prob+1) / cur_dist.prob - old_dist.logits

View File

@ -1255,3 +1255,7 @@ Examples::
return x.reindex(x.shape, ids)
jt.Var.roll = roll
def safe_log(x):
    ''' Logarithm applied after clipping the input into [1e-30, 1e30].

    The input is first passed through jt.safe_clip with the bounds
    1e-30 and 1e30, then ``log`` is taken of the clipped result.

    :param x: input value.
    :return: log of the clipped input.
    '''
    clipped = jt.safe_clip(x, 1e-30, 1e30)
    return clipped.log()
# expose the function as a method on jt.Var as well
jt.Var.safe_log = safe_log

View File

@ -28,12 +28,12 @@ def matmul_transpose(a, b):
'''
returns a * b^T
'''
assert len(a.shape) >= 2 and len(b.shape) == 2
assert a.shape[-1] == b.shape[-1], (a.shape, b.shape)
if len(a.shape)>2:
if len(a.shape) != 2:
aa = a.reshape((-1, a.shape[-1]))
cc = matmul_transpose(aa, b)
return cc.reshape(a.shape[:-1]+(-1,))
assert len(a.shape) == 2 and len(b.shape) == 2
shape = list(a.shape)[:-1] + list(b.shape)
a = a.broadcast(shape, [len(shape)-2])

View File

@ -104,6 +104,8 @@ int OpCompiler::total_member_count() {
// array need a extra local var
if (op->ops[i]->name()==string("array"))
member_count += 1;
if (op->ops[i]->name()==string("safe_clip"))
member_count += 2;
member_count += v.size();
i += 1;
}
@ -826,11 +828,15 @@ string OpCompiler::__get_fused_src(
const unordered_set<string> members = {
"x", "y", "z", "cond", "output", "extras"
};
const unordered_set<string> scalar_members = {
"left", "right"
};
const unordered_set<string> unchanged = {
"for", "const", "auto", "get_random_engine",
"int", "float", "bool", "CHECK", "STRINGIZE",
"void", "__restrict__", "if", "true", "false",
"Op", "Var", "Node", "itof", "assert", "ASSERT"
"Op", "Var", "Node", "itof", "assert", "ASSERT",
"float64"
};
auto not_change = [&](const string& s) -> bool {
if (unchanged.count(s)) return true;
@ -941,7 +947,8 @@ string OpCompiler::__get_fused_src(
while (l<src.size() && isvar(src[l])) l++;
auto var = src.substr(j, l-j);
if (var[0] == ':' || isdigit(var[0]) || not_change(var) || src[j-1]=='.' || src[j-1]=='>') {} else
if (members.count(var)) {
if (members.count(var) || scalar_members.count(var)) {
bool is_member = members.count(var);
string arg_name = "op" + S(oi) + "_" + var;
if (l<src.size() && src[l]=='[') {
// handle extras[...]
@ -964,7 +971,8 @@ string OpCompiler::__get_fused_src(
" = (("+name3+"Op*)(ops[" + S(oi) + "]))->" + var;
fused_kernel_args += ";\n";
kernel_args.insert(arg_name);
op_members[oi].push_back(arg_name);
if (is_member)
op_members[oi].push_back(arg_name);
}
fused_kernel += arg_name;
j = l-1;

View File

@ -0,0 +1,47 @@
// ***************************************************************
// Copyright (c) 2021 Jittor. All Rights Reserved.
// Maintainers: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include <algorithm>
#include <cmath>
#include <limits>
#include "var.h"
#include "ops/safe_clip_op.h"
#include "ops/op_register.h"
namespace jittor {
#ifndef JIT
// Creates a safe_clip op over x with the clip bounds [left, right].
// The op is flagged for both CPU and CUDA, typed as an element-wise op,
// and owns one output whose dtype equals the dtype of x.
SafeClipOp::SafeClipOp(Var* x, float64 left, float64 right)
    : x(x),
      left(left),
      right(right)
{
    this->flags.set(NodeFlags::_cpu);
    this->flags.set(NodeFlags::_cuda);
    this->set_type(OpType::element);
    // shape is not given here (nullptr); it is filled in by infer_shape()
    this->y = this->create_output(nullptr, x->dtype());
}
// Backward rule: the gradient arriving at the output (dout) is handed to the
// input as-is, independent of whether an element was clipped.
// out, v and v_index are not used.
VarPtr SafeClipOp::grad(Var* out, Var* dout, Var* v, int v_index) {
    return VarPtr(dout);
}
void SafeClipOp::infer_shape() {
y->set_shape(x->shape);
}
// Writes the JIT key entry for this op: only the dtype of x (as Tx) is
// recorded; the clip bounds are not written into the key.
void SafeClipOp::jit_prepare(JK& jk) {
    jk << _CS("[Tx:");
    jk << x->dtype();
    jk << ']';
}
#else // JIT
// JIT kernel: element-wise clamp of the input buffer into
// [left_value, right_value], written to the output buffer.
// The float64 bounds are first tightened to the representable range of Tx
// (numeric_limits lowest and max) and then cast to Tx, once, before the loop.
// NOTE(review): the _value suffix of the two bound locals looks like the
// pattern LoopToFuncPass matches to forward them as loop arguments, and the
// member names are rewritten by OpCompiler - confirm before renaming anything.
void SafeClipOp::jit_run() {
auto* __restrict__ xp = x->ptr<Tx>();
Tx left_value = (Tx)std::max((float64)std::numeric_limits<Tx>::lowest(), left);
Tx right_value = (Tx)std::min((float64)std::numeric_limits<Tx>::max(), right);
auto* __restrict__ yp = y->ptr<Tx>();
index_t num = y->num;
for (index_t i=0; i<num; i++)
yp[i] = xp[i] < left_value ? left_value : (xp[i] > right_value ? right_value : xp[i]);
}
#endif // JIT
} // jittor

View File

@ -0,0 +1,33 @@
// ***************************************************************
// Copyright (c) 2021 Jittor. All Rights Reserved.
// Maintainers: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "op.h"
namespace jittor {
// Operator declaration for safe_clip (exposed to Python via the pybind
// annotation below).
struct SafeClipOp : Op {
// x: input var, y: output var
Var* x, * y;
// clip bounds, stored as float64 regardless of the dtype of x
float64 left, right;
/** Safe clip value to a range, and keep
the gradient pass through.

Values smaller than left become left, values larger than right
become right, other values are kept.

* [in] x: input value
* [in] left: float64 clip min value.
* [in] right: float64 clip max value.
*/
// @pybind(safe_clip)
SafeClipOp(Var* x, float64 left, float64 right);
const char* name() const override { return "safe_clip"; }
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
void infer_shape() override;
DECLARE_jit_run;
};
} // jittor

View File

@ -67,6 +67,10 @@ void LoopToFuncPass::run() {
args.push_back(d.get());
continue;
}
if (endswith(d->attrs["lvalue"], "_value")) {
args.push_back(d.get());
continue;
}
}
}
func->push_back(d->clone());

View File

@ -347,5 +347,11 @@ class TestMatmul(unittest.TestCase):
def test_matmul_example2_cuda(self):
self.test_matmul_example2()
def test_linear1d(self):
    # A Linear(10, 20) layer fed a 1-D input of length 10
    # must produce a 1-D output of length 20.
    layer = jt.nn.Linear(10,20)
    vec = jt.random((10,))
    out = layer(vec)
    assert out.shape == (20,)
if __name__ == "__main__":
unittest.main()

View File

@ -69,6 +69,13 @@ class TestUnaryOp(unittest.TestCase):
b1 = b.sigmoid().numpy()
assert np.isnan(b1).any() == False
def test_safe_clip(self):
    # forward: values are clipped into [0.1, 0.5]
    src = jt.array([-1.0,0,0.4,1,2,3])
    clipped = src.safe_clip(0.1, 0.5)
    expected = [0.1,0.1,0.4,0.5,0.5,0.5]
    assert np.allclose(clipped.data, expected)
    # backward: gradient is 1 for every element, clipped or not
    grad_src = jt.grad(clipped, src)
    assert (grad_src.data == 1).all()
class TestUnaryOpCuda(TestUnaryOp, test_cuda(2)):
    '''Inherits all TestUnaryOp cases, with test_cuda(2) as an extra base class.'''