mirror of https://github.com/Jittor/Jittor
add safe_clip and safe_log
This commit is contained in:
parent
fd798d9925
commit
826513a156
|
@ -9,7 +9,7 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
__version__ = '1.2.3.24'
|
||||
__version__ = '1.2.3.25'
|
||||
from jittor_utils import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -1049,7 +1049,7 @@ if os.path.isfile(version_file) and not os.path.isdir(os.path.join(jittor_path,
|
|||
compile(cc_path, cc_flags+opt_flags, files, 'jittor_core'+extension_suffix)
|
||||
|
||||
# TODO: move to compile_extern.py
|
||||
compile_extern()
|
||||
# compile_extern()
|
||||
|
||||
with jit_utils.import_scope(import_flags):
|
||||
import jittor_core as core
|
||||
|
|
|
@ -34,7 +34,7 @@ class OneHotCategorical:
|
|||
# cannot align to pytorch
|
||||
probs = jt.sigmoid(logits)
|
||||
elif logits is None:
|
||||
logits = jt.log(probs)
|
||||
logits = jt.safe_log(probs)
|
||||
with jt.no_grad():
|
||||
self.probs = probs / probs.sum(-1, True)
|
||||
self.cum_probs = simple_presum(self.probs)
|
||||
|
@ -69,7 +69,7 @@ class Categorical:
|
|||
# cannot align to pytorch
|
||||
probs = jt.sigmoid(logits)
|
||||
elif logits is None:
|
||||
logits = jt.log(probs)
|
||||
logits = jt.safe_log(probs)
|
||||
with jt.no_grad():
|
||||
self.probs = probs / probs.sum(-1, True)
|
||||
self.logits = logits
|
||||
|
@ -85,7 +85,7 @@ class Categorical:
|
|||
return (one_hot * index).sum(-1)
|
||||
|
||||
def log_prob(self, x):
|
||||
return jt.log(self.probs)[0,x]
|
||||
return jt.safe_log(self.probs)[0,x]
|
||||
|
||||
def entropy(self):
|
||||
min_real = -(math.pow(2,23)-1) / math.pow(2,22) * math.pow(2,127)
|
||||
|
@ -104,11 +104,11 @@ class Normal:
|
|||
|
||||
def log_prob(self, x):
|
||||
var = self.sigma**2
|
||||
log_scale = jt.log(self.sigma)
|
||||
log_scale = jt.safe_log(self.sigma)
|
||||
return -((x-self.mu)**2) / (2*var) - log_scale-np.log(np.sqrt(2*np.pi))
|
||||
|
||||
def entropy(self):
|
||||
return 0.5+0.5*np.log(2*np.pi)+jt.log(self.sigma)
|
||||
return 0.5+0.5*np.log(2*np.pi)+jt.safe_log(self.sigma)
|
||||
|
||||
|
||||
class Uniform:
|
||||
|
@ -123,10 +123,10 @@ class Uniform:
|
|||
def log_prob(self,x):
|
||||
if x < self.low or x >= self.high:
|
||||
return math.inf
|
||||
return -jt.log(self.high - self.low)
|
||||
return -jt.safe_log(self.high - self.low)
|
||||
|
||||
def entropy(self):
|
||||
return jt.log(self.high - self.low)
|
||||
return jt.safe_log(self.high - self.low)
|
||||
|
||||
|
||||
class Geometric:
|
||||
|
@ -138,15 +138,15 @@ class Geometric:
|
|||
self.logits = logits
|
||||
elif logits is None:
|
||||
self.prob = p
|
||||
self.logits = -jt.log(1. / p - 1)
|
||||
self.logits = -jt.safe_log(1. / p - 1)
|
||||
|
||||
def sample(self, sample_shape):
|
||||
tiny = jt.info(self.probs.dtype).tiny
|
||||
u = jt.clamp(jt.rand(sample_shape),min_v=tiny)
|
||||
return (jt.log(u) / (jt.log(-self.probs+1))).floor()
|
||||
return (jt.safe_log(u) / (jt.safe_log(-self.probs+1))).floor()
|
||||
|
||||
def log_prob(self, x):
|
||||
return x*jt.log(-self.prob+1)+jt.log(self.prob)
|
||||
return x*jt.safe_log(-self.prob+1)+jt.safe_log(self.prob)
|
||||
|
||||
def entropy(self):
|
||||
return binary_cross_entropy_with_logits(jt.array(self.logits),jt.array(self.prob)) / self.prob
|
||||
|
@ -157,16 +157,20 @@ def kl_divergence(cur_dist, old_dist):
|
|||
if isinstance(cur_dist, Normal):
|
||||
vr = (cur_dist.sigma / old_dist.sigma)**2
|
||||
t1 = ((cur_dist.mu - old_dist.mu) / old_dist.sigma)**2
|
||||
return 0.5*(vr+t1-1-jt.log(vr))
|
||||
return 0.5*(vr+t1-1-jt.safe_log(vr))
|
||||
if isinstance(cur_dist, Categorical) or isinstance(cur_dist,OneHotCategorical):
|
||||
t = cur_dist.probs * (cur_dist.logits-old_dist.logits)
|
||||
t[jt.array((old_dist.probs == 0))] = math.inf
|
||||
t[jt.array((cur_dist.probs == 0))] = 0
|
||||
# print("t:", t)
|
||||
# print("old_dist.probs:", old_dist.probs)
|
||||
# print("old_dist.probs:", (old_dist.probs==0).sum())
|
||||
# print("cur_dist.probs:", cur_dist.probs)
|
||||
# t[jt.array((old_dist.probs == 0))] = math.inf
|
||||
# t[jt.array((cur_dist.probs == 0))] = 0
|
||||
return t.sum(-1)
|
||||
if isinstance(cur_dist, Uniform):
|
||||
res = jt.log((old_dist.high - old_dist.low) / (cur_dist.high - cur_dist.low))
|
||||
res = jt.safe_log((old_dist.high - old_dist.low) / (cur_dist.high - cur_dist.low))
|
||||
if old_dist.low > cur_dist.low or old_dist.high < cur_dist.high:
|
||||
res = math.inf
|
||||
return res
|
||||
if isinstance(cur_dist, Geometric):
|
||||
return -cur_dist.entropy() - jt.log(-old_dist.prob+1) / cur_dist.prob - old_dist.logits
|
||||
return -cur_dist.entropy() - jt.safe_log(-old_dist.prob+1) / cur_dist.prob - old_dist.logits
|
||||
|
|
|
@ -1255,3 +1255,7 @@ Examples::
|
|||
return x.reindex(x.shape, ids)
|
||||
|
||||
jt.Var.roll = roll
|
||||
|
||||
def safe_log(x):
|
||||
return jt.safe_clip(x, 1e-30, 1e30).log()
|
||||
jt.Var.safe_log = safe_log
|
||||
|
|
|
@ -28,12 +28,12 @@ def matmul_transpose(a, b):
|
|||
'''
|
||||
returns a * b^T
|
||||
'''
|
||||
assert len(a.shape) >= 2 and len(b.shape) == 2
|
||||
assert a.shape[-1] == b.shape[-1], (a.shape, b.shape)
|
||||
if len(a.shape)>2:
|
||||
if len(a.shape) != 2:
|
||||
aa = a.reshape((-1, a.shape[-1]))
|
||||
cc = matmul_transpose(aa, b)
|
||||
return cc.reshape(a.shape[:-1]+(-1,))
|
||||
assert len(a.shape) == 2 and len(b.shape) == 2
|
||||
|
||||
shape = list(a.shape)[:-1] + list(b.shape)
|
||||
a = a.broadcast(shape, [len(shape)-2])
|
||||
|
|
|
@ -104,6 +104,8 @@ int OpCompiler::total_member_count() {
|
|||
// array need a extra local var
|
||||
if (op->ops[i]->name()==string("array"))
|
||||
member_count += 1;
|
||||
if (op->ops[i]->name()==string("safe_clip"))
|
||||
member_count += 2;
|
||||
member_count += v.size();
|
||||
i += 1;
|
||||
}
|
||||
|
@ -826,11 +828,15 @@ string OpCompiler::__get_fused_src(
|
|||
const unordered_set<string> members = {
|
||||
"x", "y", "z", "cond", "output", "extras"
|
||||
};
|
||||
const unordered_set<string> scalar_members = {
|
||||
"left", "right"
|
||||
};
|
||||
const unordered_set<string> unchanged = {
|
||||
"for", "const", "auto", "get_random_engine",
|
||||
"int", "float", "bool", "CHECK", "STRINGIZE",
|
||||
"void", "__restrict__", "if", "true", "false",
|
||||
"Op", "Var", "Node", "itof", "assert", "ASSERT"
|
||||
"Op", "Var", "Node", "itof", "assert", "ASSERT",
|
||||
"float64"
|
||||
};
|
||||
auto not_change = [&](const string& s) -> bool {
|
||||
if (unchanged.count(s)) return true;
|
||||
|
@ -941,7 +947,8 @@ string OpCompiler::__get_fused_src(
|
|||
while (l<src.size() && isvar(src[l])) l++;
|
||||
auto var = src.substr(j, l-j);
|
||||
if (var[0] == ':' || isdigit(var[0]) || not_change(var) || src[j-1]=='.' || src[j-1]=='>') {} else
|
||||
if (members.count(var)) {
|
||||
if (members.count(var) || scalar_members.count(var)) {
|
||||
bool is_member = members.count(var);
|
||||
string arg_name = "op" + S(oi) + "_" + var;
|
||||
if (l<src.size() && src[l]=='[') {
|
||||
// handle extras[...]
|
||||
|
@ -964,7 +971,8 @@ string OpCompiler::__get_fused_src(
|
|||
" = (("+name3+"Op*)(ops[" + S(oi) + "]))->" + var;
|
||||
fused_kernel_args += ";\n";
|
||||
kernel_args.insert(arg_name);
|
||||
op_members[oi].push_back(arg_name);
|
||||
if (is_member)
|
||||
op_members[oi].push_back(arg_name);
|
||||
}
|
||||
fused_kernel += arg_name;
|
||||
j = l-1;
|
||||
|
|
|
@ -0,0 +1,47 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2021 Jittor. All Rights Reserved.
|
||||
// Maintainers: Dun Liang <randonlang@gmail.com>.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include <cmath>
|
||||
#include "var.h"
|
||||
#include "ops/safe_clip_op.h"
|
||||
#include "ops/op_register.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
#ifndef JIT
|
||||
|
||||
SafeClipOp::SafeClipOp(Var* x, float64 left, float64 right) : x(x), left(left), right(right) {
|
||||
flags.set(NodeFlags::_cpu);
|
||||
flags.set(NodeFlags::_cuda);
|
||||
set_type(OpType::element);
|
||||
y = create_output(nullptr, x->dtype());
|
||||
}
|
||||
|
||||
VarPtr SafeClipOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
||||
return dout;
|
||||
}
|
||||
|
||||
void SafeClipOp::infer_shape() {
|
||||
y->set_shape(x->shape);
|
||||
}
|
||||
|
||||
void SafeClipOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Tx:") << x->dtype() <<']';
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
void SafeClipOp::jit_run() {
|
||||
auto* __restrict__ xp = x->ptr<Tx>();
|
||||
Tx left_value = (Tx)std::max((float64)std::numeric_limits<Tx>::lowest(), left);
|
||||
Tx right_value = (Tx)std::min((float64)std::numeric_limits<Tx>::max(), right);
|
||||
auto* __restrict__ yp = y->ptr<Tx>();
|
||||
index_t num = y->num;
|
||||
for (index_t i=0; i<num; i++)
|
||||
yp[i] = xp[i] < left_value ? left_value : (xp[i] > right_value ? right_value : xp[i]);
|
||||
}
|
||||
#endif // JIT
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,33 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2021 Jittor. All Rights Reserved.
|
||||
// Maintainers: Dun Liang <randonlang@gmail.com>.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include "op.h"
|
||||
|
||||
|
||||
namespace jittor {
|
||||
|
||||
struct SafeClipOp : Op {
|
||||
Var* x, * y;
|
||||
float64 left, right;
|
||||
/** Safe clip value to a range, and keep
|
||||
the gradient pass thought.
|
||||
|
||||
* [in] x: input value
|
||||
* [in] left: float64 clip min value.
|
||||
* [in] right: float64 clip max value.
|
||||
|
||||
*/
|
||||
// @pybind(safe_clip)
|
||||
SafeClipOp(Var* x, float64 left, float64 right);
|
||||
|
||||
const char* name() const override { return "safe_clip"; }
|
||||
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
|
||||
void infer_shape() override;
|
||||
DECLARE_jit_run;
|
||||
};
|
||||
|
||||
} // jittor
|
|
@ -67,6 +67,10 @@ void LoopToFuncPass::run() {
|
|||
args.push_back(d.get());
|
||||
continue;
|
||||
}
|
||||
if (endswith(d->attrs["lvalue"], "_value")) {
|
||||
args.push_back(d.get());
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
func->push_back(d->clone());
|
||||
|
|
|
@ -347,5 +347,11 @@ class TestMatmul(unittest.TestCase):
|
|||
def test_matmul_example2_cuda(self):
|
||||
self.test_matmul_example2()
|
||||
|
||||
def test_linear1d(self):
|
||||
linear = jt.nn.Linear(10,20)
|
||||
a = jt.random((10,))
|
||||
b = linear(a)
|
||||
assert b.shape == (20,)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -69,6 +69,13 @@ class TestUnaryOp(unittest.TestCase):
|
|||
b1 = b.sigmoid().numpy()
|
||||
assert np.isnan(b1).any() == False
|
||||
|
||||
def test_safe_clip(self):
|
||||
a = jt.array([-1.0,0,0.4,1,2,3])
|
||||
b = a.safe_clip(0.1, 0.5)
|
||||
assert np.allclose(b.data, [0.1,0.1,0.4,0.5,0.5,0.5])
|
||||
da = jt.grad(b, a)
|
||||
assert (da.data == 1).all()
|
||||
|
||||
class TestUnaryOpCuda(TestUnaryOp, test_cuda(2)):
|
||||
pass
|
||||
|
||||
|
|
Loading…
Reference in New Issue