polish distributions

Dun Liang 2021-06-17 21:45:48 +08:00
parent 9c0f3cfdf4
commit 4ec2bfacb2
4 changed files with 69 additions and 68 deletions

View File

@@ -9,7 +9,7 @@
 # file 'LICENSE.txt', which is part of this source code package.
 # ***************************************************************
-__version__ = '1.2.3.31'
+__version__ = '1.2.3.32'
 from jittor_utils import lock
 with lock.lock_scope():
     ori_int = int
@@ -437,11 +437,11 @@ def pow(x, y):
 Var.pow = Var.__pow__ = pow
 def argmax(x, dim, keepdims:bool=False):
-    return x.arg_reduce("max", dim, keepdims)
+    return jt.arg_reduce(x, "max", dim, keepdims)
 Var.argmax = argmax
 def argmin(x, dim, keepdims:bool=False):
-    return x.arg_reduce("min", dim, keepdims)
+    return jt.arg_reduce(x, "min", dim, keepdims)
 Var.argmin = argmin
 def randn(*size, dtype="float32", requires_grad=True) -> Var:
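The two wrappers now route through jt.arg_reduce instead of the Var method. A minimal sketch of the resulting behaviour, assuming the usual `import jittor as jt`; jt.argmax returns an (indices, values) pair, which is why the distributions code below takes element [0]:

    import jittor as jt

    x = jt.array([[0.1, 2.0, 0.3],
                  [4.0, 0.5, 0.6]])
    # argmax dispatches through jt.arg_reduce and yields (indices, values);
    # callers that only want the positions take element [0] of the pair.
    idx, val = jt.argmax(x, dim=-1)
    print(idx.data)   # expected [1, 0]
    print(val.data)   # expected [2.0, 4.0]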

View File

@@ -29,18 +29,7 @@ kernel(in0->num/in0->shape[in0->shape.size()-1], 0, in0_p, out0_p, in0->shape[in
 class OneHotCategorical:
     def __init__(self, probs=None, logits=None):
         assert not (probs is None and logits is None)
-        if probs is None:
-            # cannot align to pytorch
-            probs = jt.sigmoid(logits)
-        elif logits is None:
-            logits = jt.safe_log(probs)
-        with jt.no_grad():
-            self.probs = probs / probs.sum(-1, True)
-            self.cum_probs = simple_presum(self.probs)
-            self.cum_probs_l = self.cum_probs[..., :-1]
-            self.cum_probs_r = self.cum_probs[..., 1:]
-            self.logits = logits
+        Categorical.__init__(self, probs, logits)
     def sample(self, sample_shape=[]):
         shape = sample_shape + self.probs.shape[:-1] + (1,)
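With this change OneHotCategorical no longer duplicates the constructor logic; probability normalization, the logits, and the cumulative sums are all deferred to Categorical.__init__. A small sketch of the intended effect, assuming jd is jittor.distributions as in the tests below:

    import jittor as jt
    import jittor.distributions as jd

    # Un-normalized probabilities are accepted; normalization now happens
    # once, inside Categorical.__init__.
    d = jd.OneHotCategorical(jt.array([2.0, 1.0, 1.0]))
    print(d.probs.data)   # expected [0.5, 0.25, 0.25]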
@@ -48,17 +37,12 @@ class OneHotCategorical:
         one_hot = jt.logical_and(self.cum_probs_l < rand, rand <= self.cum_probs_r).float()
         return one_hot
-    def log_prob(self,x):
-        if len(x.shape) == 1:
-            x = x.unsqueeze(0)
-        logits = self.logits.broadcast(x.shape)
-        indices = jt.argmax(x, dim=-1)[0]
-        return logits.gather(1, indices.unsqueeze(-1)).reshape(-1)
+    def log_prob(self, x):
+        x = jt.argmax(x, dim=-1)[0]
+        return Categorical.log_prob(self, x)
     def entropy(self):
-        min_real = -(math.pow(2,23)-1) / math.pow(2,22) * math.pow(2,127)
-        logits = jt.clamp(self.logits,min_v=min_real)
-        p_log_p = logits * self.probs
+        p_log_p = self.logits * self.probs
        return -p_log_p.sum(-1)
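log_prob now collapses a one-hot sample to class indices with jt.argmax and reuses Categorical.log_prob, and entropy drops the float32 clamp of the logits. A hedged usage sketch, under the same jittor.distributions assumption as above:

    import jittor as jt
    import jittor.distributions as jd

    d = jd.OneHotCategorical(jt.array([0.5, 0.25, 0.25]))
    x = jt.array([[0., 1., 0.],
                  [1., 0., 0.]])
    # the one-hot rows reduce to indices [1, 0], so the result should be
    # log([0.25, 0.5]), roughly [-1.386, -0.693]
    print(d.log_prob(x).data)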
@@ -68,29 +52,32 @@ class Categorical:
         if probs is None:
             # cannot align to pytorch
             probs = jt.sigmoid(logits)
-        elif logits is None:
+        probs = probs / probs.sum(-1, True)
+        if logits is None:
             logits = jt.safe_log(probs)
         with jt.no_grad():
-            self.probs = probs / probs.sum(-1, True)
+            self.probs = probs
             self.logits = logits
-            self.cum_probs = simple_presum(probs)
+            self.cum_probs = simple_presum(self.probs)
             self.cum_probs_l = self.cum_probs[..., :-1]
             self.cum_probs_r = self.cum_probs[..., 1:]
-    def sample(self, sample_shape=[]):
+    def sample(self, sample_shape=()):
         shape = sample_shape + self.probs.shape[:-1] + (1,)
         rand = jt.rand(shape)
         one_hot = jt.logical_and(self.cum_probs_l < rand, rand <= self.cum_probs_r)
-        index = one_hot.index(one_hot.ndim-1)
+        index = one_hot.index(one_hot.ndim - 1)
         return (one_hot * index).sum(-1)
     def log_prob(self, x):
-        return jt.safe_log(self.probs)[0,x]
+        a = self.probs.ndim
+        b = x.ndim
+        indexes = tuple( f'i{i}' for i in range(b-a+1, b) )
+        indexes = indexes + (x,)
+        return jt.safe_log(self.probs).getitem(indexes)
     def entropy(self):
-        min_real = -(math.pow(2,23)-1) / math.pow(2,22) * math.pow(2,127)
-        logits = jt.clamp(self.logits,min_v=min_real)
-        p_log_p = logits * self.probs
+        p_log_p = self.logits * self.probs
         return -p_log_p.sum(-1)
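Categorical.log_prob now supports batched probability tables: for probs of shape batch_shape + (K,) and an index tensor x of shape sample_shape + batch_shape, it gathers log-probabilities along the class axis while broadcasting over the leading sample dimensions (the f'i{i}' string indexes are how the new code lines up the batch axes of probs with those of x). A numpy sketch of the intended semantics, for illustration only:

    import numpy as np

    probs = np.random.uniform(0, 1, (2, 3, 5))        # batch (2, 3), 5 classes
    probs /= probs.sum(-1, keepdims=True)
    x = np.random.randint(0, 5, (4, 2, 3))            # sample_shape (4,) + batch (2, 3)
    logp = np.broadcast_to(np.log(probs), (4, 2, 3, 5))
    out = np.take_along_axis(logp, x[..., None], axis=-1)[..., 0]
    print(out.shape)                                  # (4, 2, 3), matching torch's log_prob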
@@ -141,8 +128,7 @@ class Geometric:
         self.logits = -jt.safe_log(1. / p - 1)
     def sample(self, sample_shape):
-        tiny = jt.info(self.probs.dtype).tiny
-        u = jt.clamp(jt.rand(sample_shape),min_v=tiny)
+        u = jt.rand(sample_shape)
         return (jt.safe_log(u) / (jt.safe_log(-self.probs+1))).floor()
     def log_prob(self, x):
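Geometric.sample keeps the standard inverse-CDF construction floor(log(u) / log(1 - p)) for the {0, 1, 2, ...} support, now without clamping u away from zero. A numpy check of the formula:

    import numpy as np

    # Inverse-CDF sampling of a Geometric(p) counting failures before the
    # first success; its mean is (1 - p) / p.
    p = 0.3
    u = np.random.uniform(size=100000)
    samples = np.floor(np.log(u) / np.log(1.0 - p))
    print(samples.mean())   # should be close to (1 - p) / p = 2.33...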
@@ -160,12 +146,6 @@ def kl_divergence(cur_dist, old_dist):
         return 0.5*(vr+t1-1-jt.safe_log(vr))
     if isinstance(cur_dist, Categorical) or isinstance(cur_dist,OneHotCategorical):
         t = cur_dist.probs * (cur_dist.logits-old_dist.logits)
-        # print("t:", t)
-        # print("old_dist.probs:", old_dist.probs)
-        # print("old_dist.probs:", (old_dist.probs==0).sum())
-        # print("cur_dist.probs:", cur_dist.probs)
-        # t[jt.array((old_dist.probs == 0))] = math.inf
-        # t[jt.array((cur_dist.probs == 0))] = 0
         return t.sum(-1)
     if isinstance(cur_dist, Uniform):
         res = jt.safe_log((old_dist.high - old_dist.low) / (cur_dist.high - cur_dist.low))
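The debug prints are dropped; what remains in the categorical branch is the closed form KL(p || q) = sum_k p_k * (log p_k - log q_k) over the class axis. A small numpy check:

    import numpy as np

    p = np.array([0.5, 0.25, 0.25])
    q = np.array([0.25, 0.5, 0.25])
    kl = (p * (np.log(p) - np.log(q))).sum(-1)
    print(kl)   # approximately 0.1733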

View File

@@ -634,23 +634,6 @@ def kthvalue(input, k, dim=None, keepdim=False):
 jt.Var.kthvalue = kthvalue
-def gather(x,dim,index):
-    if dim<0:
-        dim+=index.ndim
-    x_shape = list(x.shape )
-    i_shape = list(index.shape)
-    assert i_shape[dim]>0
-    assert x.ndim == index.ndim
-    i_shape[dim]=x_shape[dim]
-    assert i_shape == x_shape
-    ins = []
-    for i in range(index.ndim):
-        ins.append(jt.index(index.shape,dim=i))
-    ins[dim]=index
-    return x.reindex(ins)
-jt.Var.gather = gather
 def _prod(x,dim=0):
     x = jt.log(x)
     x = x.sum(dim=dim)
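The Python gather helper is removed from this file by this commit. For reference, the reindex-based implementation followed torch.gather semantics with the restriction that index must match x's shape: for dim=1, out[i, j] = x[i, index[i, j]]. A numpy illustration of those semantics:

    import numpy as np

    x = np.array([[1, 2],
                  [3, 4]])
    index = np.array([[0, 0],
                      [1, 0]])
    # take_along_axis reproduces the gather rule out[i, j] = x[i, index[i, j]]
    out = np.take_along_axis(x, index, axis=1)
    print(out)   # [[1 1]
                 #  [4 3]]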

View File

@@ -31,14 +31,14 @@ class TestOneHot(unittest.TestCase):
             probs,probs2 = np.random.uniform(0,1,(10)), np.random.uniform(0,1,(10))
             probs,probs2 = probs / probs.sum(),probs2 / probs2.sum()
             import torch
-            jc, jc2 = jd.OneHotCategorical(jt.array(probs).reshape(1,-1)),jd.OneHotCategorical(jt.array(probs2).reshape(1,-1))
+            jc, jc2 = jd.OneHotCategorical(jt.array(probs)),jd.OneHotCategorical(jt.array(probs2))
             tc, tc2 = torch.distributions.OneHotCategorical(torch.tensor(probs)),torch.distributions.OneHotCategorical(torch.tensor(probs2))
             assert np.allclose(jc.entropy().data,tc.entropy().numpy())
             x = np.zeros((4,10))
             for _ in range(4):
                 nx = np.random.randint(0,9)
                 x[_,nx] = 1
-            assert np.allclose(jc.log_prob(jt.array(x)),tc.log_prob(torch.tensor(x)))
+            np.testing.assert_allclose(jc.log_prob(jt.array(x)),tc.log_prob(torch.tensor(x)), atol=1e-5)
             assert np.allclose(jd.kl_divergence(jc,jc2),torch.distributions.kl_divergence(tc,tc2))
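The comparisons move from bare np.allclose asserts to np.testing.assert_allclose with an explicit atol, which reports the offending values when a check fails. A one-line contrast:

    import numpy as np

    a = np.array([1.0, 2.0])
    b = np.array([1.0, 2.0 + 2e-6])
    assert np.allclose(a, b)                     # boolean check, no diagnostics on failure
    np.testing.assert_allclose(a, b, atol=1e-5)  # raises with a detailed report on failure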
def test_cate(self):
@@ -67,17 +67,55 @@ class TestOneHot(unittest.TestCase):
             tn2 = torch.distributions.Normal(mu2,sigma2)
             assert np.allclose(jd.kl_divergence(jn,jn2).data,torch.distributions.kl_divergence(tn,tn2).numpy())
-    def test_categorical(self):
+    def test_categorical1(self):
         import torch
         for _ in range(4):
             probs,probs2 = np.random.uniform(0,1,(10)), np.random.uniform(0,1,(10))
             probs,probs2 = probs / probs.sum(),probs2 / probs2.sum()
-            jc, jc2 = jd.Categorical(jt.array(probs).reshape(1,-1)),jd.Categorical(jt.array(probs2).reshape(1,-1))
+            jc, jc2 = jd.Categorical(jt.array(probs)),jd.Categorical(jt.array(probs2))
             tc, tc2 = torch.distributions.Categorical(torch.tensor(probs)),torch.distributions.Categorical(torch.tensor(probs2))
             assert np.allclose(jc.entropy().data, tc.entropy().numpy()), (jc.entropy().data, tc.entropy().numpy())
             x = np.random.randint(0,10,(4))
-            assert np.allclose(jc.log_prob(x), tc.log_prob(torch.tensor(x)))
+            np.testing.assert_allclose(jc.log_prob(x), tc.log_prob(torch.tensor(x)), atol=1e-5)
             assert np.allclose(jd.kl_divergence(jc,jc2),torch.distributions.kl_divergence(tc,tc2))
+    def test_categorical2(self):
+        def check(prob_shape, sample_shape):
+            import torch
+            for _ in range(4):
+                probs,probs2 = np.random.uniform(0,1,prob_shape), np.random.uniform(0,1, prob_shape)
+                jc, jc2 = jd.Categorical(jt.array(probs)),jd.Categorical(jt.array(probs2))
+                tc, tc2 = torch.distributions.Categorical(torch.tensor(probs)),torch.distributions.Categorical(torch.tensor(probs2))
+                assert np.allclose(jc.entropy().data, tc.entropy().numpy()), (jc.entropy().data, tc.entropy().numpy())
+                x1 = jc.sample(sample_shape)
+                x2 = tc.sample(sample_shape)
+                assert tuple(x1.shape) == tuple(x2.shape)
+                x = np.random.randint(0,prob_shape[-1], tuple(x1.shape))
+                np.testing.assert_allclose(jc.log_prob(x), tc.log_prob(torch.tensor(x)), atol=1e-5)
+                np.testing.assert_allclose(jd.kl_divergence(jc,jc2), torch.distributions.kl_divergence(tc,tc2), atol=1e-5)
+        check((10,), (4,))
+        check((2,3), (4,))
+        check((3,4,5,6), (2,))
+    def test_one_hot_categorical2(self):
+        def check(prob_shape, sample_shape):
+            import torch
+            for _ in range(4):
+                probs,probs2 = np.random.uniform(0,1,prob_shape), np.random.uniform(0,1, prob_shape)
+                jc, jc2 = jd.OneHotCategorical(jt.array(probs)),jd.OneHotCategorical(jt.array(probs2))
+                tc, tc2 = torch.distributions.OneHotCategorical(torch.tensor(probs)),torch.distributions.OneHotCategorical(torch.tensor(probs2))
+                assert np.allclose(jc.entropy().data, tc.entropy().numpy()), (jc.entropy().data, tc.entropy().numpy())
+                x1 = jc.sample(sample_shape)
+                x2 = tc.sample(sample_shape)
+                assert tuple(x1.shape) == tuple(x2.shape)
+                x = np.random.randint(0,prob_shape[-1], tuple(x1.shape))
+                np.testing.assert_allclose(jc.log_prob(x), tc.log_prob(torch.tensor(x)), atol=1e-5)
+                np.testing.assert_allclose(jd.kl_divergence(jc,jc2), torch.distributions.kl_divergence(tc,tc2), atol=1e-5)
+        check((10,), (4,))
+        check((2,3), (4,))
+        check((3,4,5,6), (2,))
     def test_uniform(self):
         import torch
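test_categorical2 and test_one_hot_categorical2 exercise batched probability tables and check that the sample shape follows the sample_shape + batch_shape convention of torch.distributions. A short sketch of the shapes involved, assuming jd is jittor.distributions as in this file:

    import numpy as np
    import jittor as jt
    import jittor.distributions as jd

    probs = jt.array(np.random.uniform(0, 1, (3, 4, 5, 6)))   # batch (3, 4, 5), 6 classes
    d = jd.Categorical(probs)
    s = d.sample((2,))
    print(s.shape)   # expected (2, 3, 4, 5): sample_shape + batch_shape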
@@ -98,11 +136,11 @@ class TestOneHot(unittest.TestCase):
             prob, prob2 = np.random.uniform(0,1), np.random.uniform(0,1)
             jg, jg2 = jd.Geometric(prob),jd.Geometric(prob2)
             tg, tg2 = torch.distributions.Geometric(prob),torch.distributions.Geometric(prob2)
-            assert np.allclose(jg.entropy().data,tg.entropy().numpy())
+            np.testing.assert_allclose(jg.entropy().data,tg.entropy().numpy(), atol=1e-4)
             x = np.random.randint(1,10)
-            assert np.allclose(jg.log_prob(x),tg.log_prob(torch.tensor(x)))
+            np.testing.assert_allclose(jg.log_prob(x),tg.log_prob(torch.tensor(x)), atol=1e-4)
             # print(jd.kl_divergence(jg,jg2),torch.distributions.kl_divergence(tg,tg2))
-            assert np.allclose(jd.kl_divergence(jg,jg2),torch.distributions.kl_divergence(tg,tg2))
+            np.testing.assert_allclose(jd.kl_divergence(jg,jg2),torch.distributions.kl_divergence(tg,tg2), atol=1e-4)
 if __name__ == "__main__":
     unittest.main()