mirror of https://github.com/Jittor/Jittor
polish distributions
This commit is contained in:
parent
25def1f399
commit
0922678bc8
|
@ -9,7 +9,7 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
__version__ = '1.2.2.71'
|
||||
__version__ = '1.2.2.72'
|
||||
from . import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -10,6 +10,7 @@
|
|||
import math
|
||||
import numpy as np
|
||||
import jittor as jt
|
||||
from jittor.nn import binary_cross_entropy_with_logits
|
||||
|
||||
def simple_presum(x):
|
||||
src = '''
|
||||
|
@ -48,7 +49,11 @@ class OneHotCategorical:
|
|||
return one_hot
|
||||
|
||||
def log_prob(self,x):
|
||||
return jt.log(self.probs)[0,x]
|
||||
if len(x.shape) == 1:
|
||||
x = x.unsqueeze(0)
|
||||
logits = self.logits.broadcast(x.shape)
|
||||
indices = jt.argmax(x, dim=-1)[0]
|
||||
return logits.gather(1, indices.unsqueeze(-1)).reshape(-1)
|
||||
|
||||
def entropy(self):
|
||||
min_real = -(math.pow(2,23)-1) / math.pow(2,22) * math.pow(2,127)
|
||||
|
@ -59,7 +64,18 @@ class OneHotCategorical:
|
|||
|
||||
class Categorical:
|
||||
def __init__(self, probs=None, logits=None):
|
||||
OneHotCategorical.__init__(self, probs, logits)
|
||||
assert not (probs is None and logits is None)
|
||||
if probs is None:
|
||||
# cannot align to pytorch
|
||||
probs = jt.sigmoid(logits)
|
||||
elif logits is None:
|
||||
logits = jt.log(probs)
|
||||
with jt.no_grad():
|
||||
self.probs = probs / probs.sum(-1, True)
|
||||
self.logits = logits
|
||||
self.cum_probs = simple_presum(probs)
|
||||
self.cum_probs_l = self.cum_probs[..., :-1]
|
||||
self.cum_probs_r = self.cum_probs[..., 1:]
|
||||
|
||||
def sample(self, sample_shape=[]):
|
||||
shape = sample_shape + self.probs.shape[:-1] + (1,)
|
||||
|
@ -79,12 +95,12 @@ class Categorical:
|
|||
|
||||
|
||||
class Normal:
|
||||
def __init__(self,mu,sigma):
|
||||
def __init__(self, mu, sigma):
|
||||
self.mu = mu
|
||||
self.sigma = sigma
|
||||
|
||||
def sample(self,sample_shape):
|
||||
return jt.normal(self.mu, self.sigma, sample_shape)
|
||||
def sample(self, sample_shape=None):
|
||||
return jt.normal(jt.array(self.mu), jt.array(self.sigma),size=sample_shape)
|
||||
|
||||
def log_prob(self, x):
|
||||
var = self.sigma**2
|
||||
|
@ -95,15 +111,62 @@ class Normal:
|
|||
return 0.5+0.5*np.log(2*np.pi)+jt.log(self.sigma)
|
||||
|
||||
|
||||
def kl_divergence(cur_dist,old_dist):
|
||||
assert isinstance(cur_dist,type(old_dist))
|
||||
if isinstance(cur_dist,Normal):
|
||||
class Uniform:
|
||||
def __init__(self,low,high):
|
||||
self.low = low
|
||||
self.high = high
|
||||
assert high > low
|
||||
|
||||
def sample(self,sample_shape):
|
||||
return jt.uniform(self.low,self.high,sample_shape)
|
||||
|
||||
def log_prob(self,x):
|
||||
if x < self.low or x >= self.high:
|
||||
return math.inf
|
||||
return -jt.log(self.high - self.low)
|
||||
|
||||
def entropy(self):
|
||||
return jt.log(self.high - self.low)
|
||||
|
||||
|
||||
class Geometric:
|
||||
def __init__(self,p=None,logits=None):
|
||||
assert (p is not None) or (logits is not None)
|
||||
assert 0 < p and p < 1
|
||||
if p is None:
|
||||
self.prob = jt.sigmoid(logits)
|
||||
self.logits = logits
|
||||
elif logits is None:
|
||||
self.prob = p
|
||||
self.logits = -jt.log(1. / p - 1)
|
||||
|
||||
def sample(self, sample_shape):
|
||||
tiny = jt.info(self.probs.dtype).tiny
|
||||
u = jt.clamp(jt.rand(sample_shape),min_v=tiny)
|
||||
return (jt.log(u) / (jt.log(-self.probs+1))).floor()
|
||||
|
||||
def log_prob(self, x):
|
||||
return x*jt.log(-self.prob+1)+jt.log(self.prob)
|
||||
|
||||
def entropy(self):
|
||||
return binary_cross_entropy_with_logits(jt.array(self.logits),jt.array(self.prob)) / self.prob
|
||||
|
||||
|
||||
def kl_divergence(cur_dist, old_dist):
|
||||
assert isinstance(cur_dist, type(old_dist))
|
||||
if isinstance(cur_dist, Normal):
|
||||
vr = (cur_dist.sigma / old_dist.sigma)**2
|
||||
t1 = ((cur_dist.mu - old_dist.mu) / old_dist.sigma)**2
|
||||
return 0.5*(vr+t1-1-jt.log(vr))
|
||||
if isinstance(cur_dist,Categorical) or isinstance(cur_dist,OneHotCategorical):# ?
|
||||
if isinstance(cur_dist, Categorical) or isinstance(cur_dist,OneHotCategorical):
|
||||
t = cur_dist.probs * (cur_dist.logits-old_dist.logits)
|
||||
t[jt.array((old_dist.probs == 0))] = math.inf
|
||||
t[jt.array((cur_dist.probs == 0))] = 0
|
||||
return t.sum(-1)
|
||||
|
||||
if isinstance(cur_dist, Uniform):
|
||||
res = jt.log((old_dist.high - old_dist.low) / (cur_dist.high - cur_dist.low))
|
||||
if old_dist.low > cur_dist.low or old_dist.high < cur_dist.high:
|
||||
res = math.inf
|
||||
return res
|
||||
if isinstance(cur_dist, Geometric):
|
||||
return -cur_dist.entropy() - jt.log(-old_dist.prob+1) / cur_dist.prob - old_dist.logits
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
|
||||
# ***************************************************************
|
||||
# Copyright (c) 2021 Jittor. All Rights Reserved.
|
||||
# Maintainers:
|
||||
|
@ -25,11 +24,22 @@ class TestOneHot(unittest.TestCase):
|
|||
x = a.sample().numpy()
|
||||
for i in range(1000):
|
||||
x += a.sample().numpy()
|
||||
print(x)
|
||||
assert (x > 200).all()
|
||||
y = a.sample([2,3])
|
||||
y.sync()
|
||||
assert y.shape == [2,3,4]
|
||||
probs,probs2 = np.random.uniform(0,1,(10)), np.random.uniform(0,1,(10))
|
||||
probs,probs2 = probs / probs.sum(),probs2 / probs2.sum()
|
||||
import torch
|
||||
jc, jc2 = jd.OneHotCategorical(jt.array(probs).reshape(1,-1)),jd.OneHotCategorical(jt.array(probs2).reshape(1,-1))
|
||||
tc, tc2 = torch.distributions.OneHotCategorical(torch.tensor(probs)),torch.distributions.OneHotCategorical(torch.tensor(probs2))
|
||||
assert np.allclose(jc.entropy().data,tc.entropy().numpy())
|
||||
x = np.zeros((4,10))
|
||||
for _ in range(4):
|
||||
nx = np.random.randint(0,9)
|
||||
x[_,nx] = 1
|
||||
assert np.allclose(jc.log_prob(jt.array(x)),tc.log_prob(torch.tensor(x)))
|
||||
assert np.allclose(jd.kl_divergence(jc,jc2),torch.distributions.kl_divergence(tc,tc2))
|
||||
|
||||
def test_cate(self):
|
||||
a = jd.Categorical(jt.array([0.25, 0.25, 0.25, 0.25]))
|
||||
|
@ -43,15 +53,13 @@ class TestOneHot(unittest.TestCase):
|
|||
|
||||
def test_normal(self):
|
||||
import torch
|
||||
for _ in range(10):
|
||||
for _ in range(4):
|
||||
mu = np.random.uniform(-1,1)
|
||||
sigma = np.random.uniform(0,2)
|
||||
jn = jd.Normal(mu,sigma)
|
||||
tn = torch.distributions.Normal(mu,sigma)
|
||||
assert np.allclose(jn.entropy().data,tn.entropy().numpy())
|
||||
x = np.random.uniform(-1,1)
|
||||
# print(jn.log_prob(x))
|
||||
# print(tn.log_prob(torch.tensor(x)))
|
||||
assert np.allclose(jn.log_prob(x),tn.log_prob(torch.tensor(x)))
|
||||
mu2 = np.random.uniform(-1,1)
|
||||
sigma2 = np.random.uniform(0,2)
|
||||
|
@ -61,16 +69,40 @@ class TestOneHot(unittest.TestCase):
|
|||
|
||||
def test_categorical(self):
|
||||
import torch
|
||||
for _ in range(10):
|
||||
for _ in range(4):
|
||||
probs,probs2 = np.random.uniform(0,1,(10)), np.random.uniform(0,1,(10))
|
||||
probs,probs2 = probs / probs.sum(),probs2 / probs2.sum()
|
||||
jc, jc2 = jd.Categorical(jt.array(probs).reshape(1,-1)),jd.Categorical(jt.array(probs2).reshape(1,-1))
|
||||
tc, tc2 = torch.distributions.Categorical(torch.tensor(probs)),torch.distributions.Categorical(torch.tensor(probs2))
|
||||
assert np.allclose(jc.entropy().data,tc.entropy().numpy())
|
||||
x = np.random.randint(0,10)
|
||||
# print(jc.log_prob(x),tc.log_prob(x))
|
||||
assert np.allclose(jc.log_prob(x),tc.log_prob(torch.tensor(x)))
|
||||
assert np.allclose(jc.entropy().data, tc.entropy().numpy()), (jc.entropy().data, tc.entropy().numpy())
|
||||
x = np.random.randint(0,10,(4))
|
||||
assert np.allclose(jc.log_prob(x), tc.log_prob(torch.tensor(x)))
|
||||
assert np.allclose(jd.kl_divergence(jc,jc2),torch.distributions.kl_divergence(tc,tc2))
|
||||
|
||||
def test_uniform(self):
|
||||
import torch
|
||||
for _ in range(4):
|
||||
low, low2 = np.random.randint(-1,2), np.random.randint(-1,2)
|
||||
leng, leng2 = np.random.uniform(0,2), np.random.uniform(0,2)
|
||||
high, high2 = low + leng, low2 + leng2
|
||||
ju, ju2 = jd.Uniform(low,high),jd.Uniform(low2,high2)
|
||||
tu, tu2 = torch.distributions.Uniform(low,high),torch.distributions.Uniform(low2,high2)
|
||||
assert np.allclose(ju.entropy().data,tu.entropy().numpy())
|
||||
x = np.random.uniform(low,high)
|
||||
assert np.allclose(ju.log_prob(x),tu.log_prob(torch.tensor(x)))
|
||||
assert np.allclose(jd.kl_divergence(ju,ju2),torch.distributions.kl_divergence(tu,tu2))
|
||||
|
||||
def test_geometric(self):
|
||||
import torch
|
||||
for _ in range(4):
|
||||
prob, prob2 = np.random.uniform(0,1), np.random.uniform(0,1)
|
||||
jg, jg2 = jd.Geometric(prob),jd.Geometric(prob2)
|
||||
tg, tg2 = torch.distributions.Geometric(prob),torch.distributions.Geometric(prob2)
|
||||
assert np.allclose(jg.entropy().data,tg.entropy().numpy())
|
||||
x = np.random.randint(1,10)
|
||||
assert np.allclose(jg.log_prob(x),tg.log_prob(torch.tensor(x)))
|
||||
# print(jd.kl_divergence(jg,jg2),torch.distributions.kl_divergence(tg,tg2))
|
||||
assert np.allclose(jd.kl_divergence(jg,jg2),torch.distributions.kl_divergence(tg,tg2))
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in New Issue