polish distributions

Dun Liang 2021-05-12 21:32:39 +08:00
parent 25def1f399
commit 0922678bc8
3 changed files with 116 additions and 21 deletions

View File

@@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.2.2.71'
__version__ = '1.2.2.72'
from . import lock
with lock.lock_scope():
ori_int = int

View File

@@ -10,6 +10,7 @@
import math
import numpy as np
import jittor as jt
from jittor.nn import binary_cross_entropy_with_logits
def simple_presum(x):
src = '''
@@ -48,7 +49,11 @@ class OneHotCategorical:
return one_hot
def log_prob(self,x):
return jt.log(self.probs)[0,x]
if len(x.shape) == 1:
x = x.unsqueeze(0)
logits = self.logits.broadcast(x.shape)
indices = jt.argmax(x, dim=-1)[0]
return logits.gather(1, indices.unsqueeze(-1)).reshape(-1)
def entropy(self):
min_real = -(math.pow(2,23)-1) / math.pow(2,22) * math.pow(2,127)
@@ -59,7 +64,18 @@ class OneHotCategorical:
class Categorical:
def __init__(self, probs=None, logits=None):
OneHotCategorical.__init__(self, probs, logits)
assert not (probs is None and logits is None)
if probs is None:
# note: cannot align to PyTorch here, which derives probs via softmax(logits) rather than sigmoid
probs = jt.sigmoid(logits)
elif logits is None:
logits = jt.log(probs)
with jt.no_grad():
self.probs = probs / probs.sum(-1, True)
self.logits = logits
self.cum_probs = simple_presum(probs)
self.cum_probs_l = self.cum_probs[..., :-1]
self.cum_probs_r = self.cum_probs[..., 1:]
def sample(self, sample_shape=[]):
shape = sample_shape + self.probs.shape[:-1] + (1,)
@@ -79,12 +95,12 @@ class Categorical:
class Normal:
def __init__(self,mu,sigma):
def __init__(self, mu, sigma):
self.mu = mu
self.sigma = sigma
def sample(self,sample_shape):
return jt.normal(self.mu, self.sigma, sample_shape)
def sample(self, sample_shape=None):
return jt.normal(jt.array(self.mu), jt.array(self.sigma), size=sample_shape)
def log_prob(self, x):
var = self.sigma**2
@@ -95,15 +111,62 @@ class Normal:
return 0.5+0.5*np.log(2*np.pi)+jt.log(self.sigma)
def kl_divergence(cur_dist,old_dist):
assert isinstance(cur_dist,type(old_dist))
if isinstance(cur_dist,Normal):
class Uniform:
def __init__(self,low,high):
self.low = low
self.high = high
assert high > low
def sample(self,sample_shape):
return jt.uniform(self.low,self.high,sample_shape)
def log_prob(self,x):
if x < self.low or x >= self.high:
# zero density outside [low, high): the log-probability there is -inf
return -math.inf
return -jt.log(self.high - self.low)
def entropy(self):
return jt.log(self.high - self.low)
class Geometric:
def __init__(self, p=None, logits=None):
assert (p is not None) or (logits is not None)
# only validate p when it is given explicitly; p may be None if logits are used
assert p is None or 0 < p < 1
if p is None:
self.prob = jt.sigmoid(logits)
self.logits = logits
elif logits is None:
self.prob = p
self.logits = -jt.log(1. / p - 1)
def sample(self, sample_shape):
# inverse-CDF sampling: floor(log(u) / log(1 - p)) with u ~ Uniform(0, 1)
tiny = jt.info(self.prob.dtype).tiny
u = jt.clamp(jt.rand(sample_shape), min_v=tiny)
return (jt.log(u) / (jt.log(-self.prob+1))).floor()
def log_prob(self, x):
return x*jt.log(-self.prob+1)+jt.log(self.prob)
def entropy(self):
return binary_cross_entropy_with_logits(jt.array(self.logits),jt.array(self.prob)) / self.prob
def kl_divergence(cur_dist, old_dist):
assert isinstance(cur_dist, type(old_dist))
if isinstance(cur_dist, Normal):
vr = (cur_dist.sigma / old_dist.sigma)**2
t1 = ((cur_dist.mu - old_dist.mu) / old_dist.sigma)**2
return 0.5*(vr+t1-1-jt.log(vr))
if isinstance(cur_dist,Categorical) or isinstance(cur_dist,OneHotCategorical):# ?
if isinstance(cur_dist, Categorical) or isinstance(cur_dist,OneHotCategorical):
t = cur_dist.probs * (cur_dist.logits-old_dist.logits)
t[jt.array((old_dist.probs == 0))] = math.inf
t[jt.array((cur_dist.probs == 0))] = 0
return t.sum(-1)
if isinstance(cur_dist, Uniform):
res = jt.log((old_dist.high - old_dist.low) / (cur_dist.high - cur_dist.low))
if old_dist.low > cur_dist.low or old_dist.high < cur_dist.high:
res = math.inf
return res
if isinstance(cur_dist, Geometric):
return -cur_dist.entropy() - jt.log(-old_dist.prob+1) / cur_dist.prob - old_dist.logits
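
A minimal usage sketch of the distributions touched above (not part of the commit), assuming the module is importable as jittor.distributions and following the calling conventions used in the tests below:

import numpy as np
import jittor as jt
import jittor.distributions as jd

# Categorical over one row of normalized probabilities
probs = np.array([0.1, 0.2, 0.7])
c = jd.Categorical(jt.array(probs).reshape(1, -1))
print(c.entropy())                    # -(p * log p).sum(-1)
print(c.log_prob(np.array([0, 2])))   # log-probability of each index

# Uniform on [low, high): entropy is log(high - low)
u1, u2 = jd.Uniform(0., 1.), jd.Uniform(-1., 2.)
print(u1.entropy())
print(jd.kl_divergence(u1, u2))       # finite because [0, 1) lies inside [-1, 2)

# Geometric with success probability p: log_prob(k) = k*log(1-p) + log(p)
g1, g2 = jd.Geometric(0.3), jd.Geometric(0.5)
print(g1.log_prob(2))
print(g1.entropy())
print(jd.kl_divergence(g1, g2))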

View File

@@ -1,4 +1,3 @@
# ***************************************************************
# Copyright (c) 2021 Jittor. All Rights Reserved.
# Maintainers:
@@ -25,11 +24,22 @@ class TestOneHot(unittest.TestCase):
x = a.sample().numpy()
for i in range(1000):
x += a.sample().numpy()
print(x)
assert (x > 200).all()
y = a.sample([2,3])
y.sync()
assert y.shape == [2,3,4]
probs,probs2 = np.random.uniform(0,1,(10)), np.random.uniform(0,1,(10))
probs,probs2 = probs / probs.sum(),probs2 / probs2.sum()
import torch
jc, jc2 = jd.OneHotCategorical(jt.array(probs).reshape(1,-1)),jd.OneHotCategorical(jt.array(probs2).reshape(1,-1))
tc, tc2 = torch.distributions.OneHotCategorical(torch.tensor(probs)),torch.distributions.OneHotCategorical(torch.tensor(probs2))
assert np.allclose(jc.entropy().data,tc.entropy().numpy())
x = np.zeros((4,10))
for i in range(4):
x[i, np.random.randint(0,10)] = 1
assert np.allclose(jc.log_prob(jt.array(x)),tc.log_prob(torch.tensor(x)))
assert np.allclose(jd.kl_divergence(jc,jc2),torch.distributions.kl_divergence(tc,tc2))
def test_cate(self):
a = jd.Categorical(jt.array([0.25, 0.25, 0.25, 0.25]))
@@ -43,15 +53,13 @@ class TestOneHot(unittest.TestCase):
def test_normal(self):
import torch
for _ in range(10):
for _ in range(4):
mu = np.random.uniform(-1,1)
sigma = np.random.uniform(0,2)
jn = jd.Normal(mu,sigma)
tn = torch.distributions.Normal(mu,sigma)
assert np.allclose(jn.entropy().data,tn.entropy().numpy())
x = np.random.uniform(-1,1)
# print(jn.log_prob(x))
# print(tn.log_prob(torch.tensor(x)))
assert np.allclose(jn.log_prob(x),tn.log_prob(torch.tensor(x)))
mu2 = np.random.uniform(-1,1)
sigma2 = np.random.uniform(0,2)
@@ -61,16 +69,40 @@ class TestOneHot(unittest.TestCase):
def test_categorical(self):
import torch
for _ in range(10):
for _ in range(4):
probs,probs2 = np.random.uniform(0,1,(10)), np.random.uniform(0,1,(10))
probs,probs2 = probs / probs.sum(),probs2 / probs2.sum()
jc, jc2 = jd.Categorical(jt.array(probs).reshape(1,-1)),jd.Categorical(jt.array(probs2).reshape(1,-1))
tc, tc2 = torch.distributions.Categorical(torch.tensor(probs)),torch.distributions.Categorical(torch.tensor(probs2))
assert np.allclose(jc.entropy().data,tc.entropy().numpy())
x = np.random.randint(0,10)
# print(jc.log_prob(x),tc.log_prob(x))
assert np.allclose(jc.log_prob(x),tc.log_prob(torch.tensor(x)))
assert np.allclose(jc.entropy().data, tc.entropy().numpy()), (jc.entropy().data, tc.entropy().numpy())
x = np.random.randint(0,10,(4))
assert np.allclose(jc.log_prob(x), tc.log_prob(torch.tensor(x)))
assert np.allclose(jd.kl_divergence(jc,jc2),torch.distributions.kl_divergence(tc,tc2))
def test_uniform(self):
import torch
for _ in range(4):
low, low2 = np.random.randint(-1,2), np.random.randint(-1,2)
leng, leng2 = np.random.uniform(0,2), np.random.uniform(0,2)
high, high2 = low + leng, low2 + leng2
ju, ju2 = jd.Uniform(low,high),jd.Uniform(low2,high2)
tu, tu2 = torch.distributions.Uniform(low,high),torch.distributions.Uniform(low2,high2)
assert np.allclose(ju.entropy().data,tu.entropy().numpy())
x = np.random.uniform(low,high)
assert np.allclose(ju.log_prob(x),tu.log_prob(torch.tensor(x)))
assert np.allclose(jd.kl_divergence(ju,ju2),torch.distributions.kl_divergence(tu,tu2))
def test_geometric(self):
import torch
for _ in range(4):
prob, prob2 = np.random.uniform(0,1), np.random.uniform(0,1)
jg, jg2 = jd.Geometric(prob),jd.Geometric(prob2)
tg, tg2 = torch.distributions.Geometric(prob),torch.distributions.Geometric(prob2)
assert np.allclose(jg.entropy().data,tg.entropy().numpy())
x = np.random.randint(1,10)
assert np.allclose(jg.log_prob(x),tg.log_prob(torch.tensor(x)))
# print(jd.kl_divergence(jg,jg2),torch.distributions.kl_divergence(tg,tg2))
assert np.allclose(jd.kl_divergence(jg,jg2),torch.distributions.kl_divergence(tg,tg2))
if __name__ == "__main__":
unittest.main()
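
For reference (not part of the commit), the batched OneHotCategorical.log_prob exercised in the one-hot test above amounts to picking the log-probability at each row's hot index; a NumPy sketch of the expected values, assuming a normalized probs vector:

import numpy as np

probs = np.random.uniform(0, 1, 10)
probs = probs / probs.sum()

# a batch of 4 one-hot rows, as constructed in the test
x = np.zeros((4, 10))
x[np.arange(4), np.random.randint(0, 10, 4)] = 1

# reference values: gather log(probs) at each row's argmax
expected = np.log(probs)[x.argmax(-1)]
print(expected)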