mirror of https://github.com/Jittor/Jittor
polish distributions
commit 4ec2bfacb2
parent 9c0f3cfdf4
@@ -9,7 +9,7 @@
 # file 'LICENSE.txt', which is part of this source code package.
 # ***************************************************************

-__version__ = '1.2.3.31'
+__version__ = '1.2.3.32'
 from jittor_utils import lock
 with lock.lock_scope():
     ori_int = int
@@ -437,11 +437,11 @@ def pow(x, y):
 Var.pow = Var.__pow__ = pow

 def argmax(x, dim, keepdims:bool=False):
-    return x.arg_reduce("max", dim, keepdims)
+    return jt.arg_reduce(x, "max", dim, keepdims)
 Var.argmax = argmax

 def argmin(x, dim, keepdims:bool=False):
-    return x.arg_reduce("min", dim, keepdims)
+    return jt.arg_reduce(x, "min", dim, keepdims)
 Var.argmin = argmin

 def randn(*size, dtype="float32", requires_grad=True) -> Var:
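For context, a minimal usage sketch of the reworked reductions (illustrative only; it assumes jt.argmax still returns an (indices, values) pair, which is what the jt.argmax(x, dim=-1)[0] calls in the distributions code below rely on):

    import jittor as jt

    x = jt.array([[0.1, 0.7, 0.2]])
    idx, val = jt.argmax(x, dim=-1)   # assumed pair: index of the max and the max value
    # idx -> 1 along the last axis, val -> 0.7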
@@ -29,18 +29,7 @@ kernel(in0->num/in0->shape[in0->shape.size()-1], 0, in0_p, out0_p, in0->shape[in

 class OneHotCategorical:
     def __init__(self, probs=None, logits=None):
-        assert not (probs is None and logits is None)
-        if probs is None:
-            # cannot align to pytorch
-            probs = jt.sigmoid(logits)
-        elif logits is None:
-            logits = jt.safe_log(probs)
-        with jt.no_grad():
-            self.probs = probs / probs.sum(-1, True)
-            self.cum_probs = simple_presum(self.probs)
-            self.cum_probs_l = self.cum_probs[..., :-1]
-            self.cum_probs_r = self.cum_probs[..., 1:]
-            self.logits = logits
+        Categorical.__init__(self, probs, logits)

     def sample(self, sample_shape=[]):
         shape = sample_shape + self.probs.shape[:-1] + (1,)
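A minimal usage sketch of the refactored class (illustrative only; it assumes the module is importable as jittor.distributions, as the tests below do via jd):

    import jittor as jt
    import jittor.distributions as jd

    d = jd.OneHotCategorical(jt.array([0.2, 0.3, 0.5]))
    y = d.sample((4,))     # one-hot samples, shape (4, 3)
    lp = d.log_prob(y)     # argmax of y, then Categorical.log_prob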
@@ -48,17 +37,12 @@ class OneHotCategorical:
         one_hot = jt.logical_and(self.cum_probs_l < rand, rand <= self.cum_probs_r).float()
         return one_hot

-    def log_prob(self,x):
-        if len(x.shape) == 1:
-            x = x.unsqueeze(0)
-        logits = self.logits.broadcast(x.shape)
-        indices = jt.argmax(x, dim=-1)[0]
-        return logits.gather(1, indices.unsqueeze(-1)).reshape(-1)
+    def log_prob(self, x):
+        x = jt.argmax(x, dim=-1)[0]
+        return Categorical.log_prob(self, x)

     def entropy(self):
-        min_real = -(math.pow(2,23)-1) / math.pow(2,22) * math.pow(2,127)
-        logits = jt.clamp(self.logits,min_v=min_real)
-        p_log_p = logits * self.probs
+        p_log_p = self.logits * self.probs
         return -p_log_p.sum(-1)

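For reference, the surviving two lines compute the usual categorical entropy H = -sum_i p_i * log(p_i), with the stored logits standing in for log(p). The dropped clamp constant is approximately the most negative finite float32 (the role torch.finfo(dtype).min plays in PyTorch), shown here only for orientation:

    min_real = -(2**23 - 1) / 2**22 * 2.0**127   # ≈ -3.4028e+38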
@@ -68,29 +52,32 @@ class Categorical:
         if probs is None:
             # cannot align to pytorch
             probs = jt.sigmoid(logits)
         elif logits is None:
+            probs = probs / probs.sum(-1, True)
+        if logits is None:
             logits = jt.safe_log(probs)
         with jt.no_grad():
-            self.probs = probs / probs.sum(-1, True)
+            self.probs = probs
             self.logits = logits
-            self.cum_probs = simple_presum(probs)
+            self.cum_probs = simple_presum(self.probs)
             self.cum_probs_l = self.cum_probs[..., :-1]
             self.cum_probs_r = self.cum_probs[..., 1:]

-    def sample(self, sample_shape=[]):
+    def sample(self, sample_shape=()):
         shape = sample_shape + self.probs.shape[:-1] + (1,)
         rand = jt.rand(shape)
         one_hot = jt.logical_and(self.cum_probs_l < rand, rand <= self.cum_probs_r)
-        index = one_hot.index(one_hot.ndim-1)
+        index = one_hot.index(one_hot.ndim - 1)
         return (one_hot * index).sum(-1)

-
     def log_prob(self, x):
-        return jt.safe_log(self.probs)[0,x]
+        a = self.probs.ndim
+        b = x.ndim
+        indexes = tuple( f'i{i}' for i in range(b-a+1, b) )
+        indexes = indexes + (x,)
+        return jt.safe_log(self.probs).getitem(indexes)

     def entropy(self):
-        min_real = -(math.pow(2,23)-1) / math.pow(2,22) * math.pow(2,127)
-        logits = jt.clamp(self.logits,min_v=min_real)
-        p_log_p = logits * self.probs
+        p_log_p = self.logits * self.probs
         return -p_log_p.sum(-1)

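A small NumPy sketch of the inverse-CDF trick that sample() relies on (illustrative only; it assumes simple_presum prepends a zero to the running sum, so that a uniform draw falling in (cum_l, cum_r] selects exactly one category):

    import numpy as np

    probs = np.array([0.2, 0.3, 0.5])
    cum = np.concatenate([[0.0], np.cumsum(probs)])   # assumed simple_presum analogue
    cum_l, cum_r = cum[:-1], cum[1:]
    rand = np.random.rand(4, 1)                       # sample_shape (4,) plus the trailing 1
    one_hot = (cum_l < rand) & (rand <= cum_r)        # one True per row
    index = one_hot.argmax(-1)                        # same as (one_hot * arange(3)).sum(-1)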
@@ -141,8 +128,7 @@ class Geometric:
         self.logits = -jt.safe_log(1. / p - 1)

     def sample(self, sample_shape):
-        tiny = jt.info(self.probs.dtype).tiny
-        u = jt.clamp(jt.rand(sample_shape),min_v=tiny)
+        u = jt.rand(sample_shape)
         return (jt.safe_log(u) / (jt.safe_log(-self.probs+1))).floor()

     def log_prob(self, x):
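The sampler above is the standard inverse-CDF construction for a geometric distribution, k = floor(log(u) / log(1 - p)) with u drawn from Uniform(0, 1); a NumPy version of the same formula for comparison (illustrative only):

    import numpy as np

    p = 0.3
    u = np.random.rand(5)
    k = np.floor(np.log(u) / np.log(1.0 - p))   # failures before the first success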
@@ -160,12 +146,6 @@ def kl_divergence(cur_dist, old_dist):
         return 0.5*(vr+t1-1-jt.safe_log(vr))
     if isinstance(cur_dist, Categorical) or isinstance(cur_dist,OneHotCategorical):
         t = cur_dist.probs * (cur_dist.logits-old_dist.logits)
-        # print("t:", t)
-        # print("old_dist.probs:", old_dist.probs)
-        # print("old_dist.probs:", (old_dist.probs==0).sum())
-        # print("cur_dist.probs:", cur_dist.probs)
-        # t[jt.array((old_dist.probs == 0))] = math.inf
-        # t[jt.array((cur_dist.probs == 0))] = 0
         return t.sum(-1)
     if isinstance(cur_dist, Uniform):
         res = jt.safe_log((old_dist.high - old_dist.low) / (cur_dist.high - cur_dist.low))
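The categorical branch kept above implements KL(p || q) = sum_i p_i * (log p_i - log q_i); a NumPy cross-check of the same formula (illustrative only):

    import numpy as np

    p = np.array([0.2, 0.3, 0.5])
    q = np.array([0.4, 0.4, 0.2])
    kl = (p * (np.log(p) - np.log(q))).sum(-1)   # matches t.sum(-1) above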
@@ -634,23 +634,6 @@ def kthvalue(input, k, dim=None, keepdim=False):

 jt.Var.kthvalue = kthvalue
-
-
-def gather(x,dim,index):
-    if dim<0:
-        dim+=index.ndim
-    x_shape = list(x.shape )
-    i_shape = list(index.shape)
-    assert i_shape[dim]>0
-    assert x.ndim == index.ndim
-    i_shape[dim]=x_shape[dim]
-    assert i_shape == x_shape
-    ins = []
-    for i in range(index.ndim):
-        ins.append(jt.index(index.shape,dim=i))
-    ins[dim]=index
-    return x.reindex(ins)
-jt.Var.gather = gather

 def _prod(x,dim=0):
     x = jt.log(x)
     x = x.sum(dim=dim)
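For readers tracking the removal: the deleted helper implemented the usual gather semantics along one axis, out[i][j] = x[i][index[i][j]] for dim=1, via reindex. A NumPy equivalent of that behaviour (illustrative only):

    import numpy as np

    x = np.array([[1, 2], [3, 4]])
    index = np.array([[0, 0], [1, 0]])
    out = np.take_along_axis(x, index, axis=1)   # [[1, 1], [4, 3]]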
@@ -31,14 +31,14 @@ class TestOneHot(unittest.TestCase):
         probs,probs2 = np.random.uniform(0,1,(10)), np.random.uniform(0,1,(10))
         probs,probs2 = probs / probs.sum(),probs2 / probs2.sum()
         import torch
-        jc, jc2 = jd.OneHotCategorical(jt.array(probs).reshape(1,-1)),jd.OneHotCategorical(jt.array(probs2).reshape(1,-1))
+        jc, jc2 = jd.OneHotCategorical(jt.array(probs)),jd.OneHotCategorical(jt.array(probs2))
         tc, tc2 = torch.distributions.OneHotCategorical(torch.tensor(probs)),torch.distributions.OneHotCategorical(torch.tensor(probs2))
         assert np.allclose(jc.entropy().data,tc.entropy().numpy())
         x = np.zeros((4,10))
         for _ in range(4):
             nx = np.random.randint(0,9)
             x[_,nx] = 1
-        assert np.allclose(jc.log_prob(jt.array(x)),tc.log_prob(torch.tensor(x)))
+        np.testing.assert_allclose(jc.log_prob(jt.array(x)),tc.log_prob(torch.tensor(x)), atol=1e-5)
         assert np.allclose(jd.kl_divergence(jc,jc2),torch.distributions.kl_divergence(tc,tc2))

     def test_cate(self):
@@ -67,17 +67,55 @@ class TestOneHot(unittest.TestCase):
         tn2 = torch.distributions.Normal(mu2,sigma2)
         assert np.allclose(jd.kl_divergence(jn,jn2).data,torch.distributions.kl_divergence(tn,tn2).numpy())

-    def test_categorical(self):
+    def test_categorical1(self):
         import torch
         for _ in range(4):
             probs,probs2 = np.random.uniform(0,1,(10)), np.random.uniform(0,1,(10))
             probs,probs2 = probs / probs.sum(),probs2 / probs2.sum()
-            jc, jc2 = jd.Categorical(jt.array(probs).reshape(1,-1)),jd.Categorical(jt.array(probs2).reshape(1,-1))
+            jc, jc2 = jd.Categorical(jt.array(probs)),jd.Categorical(jt.array(probs2))
             tc, tc2 = torch.distributions.Categorical(torch.tensor(probs)),torch.distributions.Categorical(torch.tensor(probs2))
             assert np.allclose(jc.entropy().data, tc.entropy().numpy()), (jc.entropy().data, tc.entropy().numpy())
             x = np.random.randint(0,10,(4))
-            assert np.allclose(jc.log_prob(x), tc.log_prob(torch.tensor(x)))
+            np.testing.assert_allclose(jc.log_prob(x), tc.log_prob(torch.tensor(x)), atol=1e-5)
             assert np.allclose(jd.kl_divergence(jc,jc2),torch.distributions.kl_divergence(tc,tc2))

+    def test_categorical2(self):
+        def check(prob_shape, sample_shape):
+            import torch
+            for _ in range(4):
+                probs,probs2 = np.random.uniform(0,1,prob_shape), np.random.uniform(0,1, prob_shape)
+
+                jc, jc2 = jd.Categorical(jt.array(probs)),jd.Categorical(jt.array(probs2))
+                tc, tc2 = torch.distributions.Categorical(torch.tensor(probs)),torch.distributions.Categorical(torch.tensor(probs2))
+                assert np.allclose(jc.entropy().data, tc.entropy().numpy()), (jc.entropy().data, tc.entropy().numpy())
+                x1 = jc.sample(sample_shape)
+                x2 = tc.sample(sample_shape)
+                assert tuple(x1.shape) == tuple(x2.shape)
+                x = np.random.randint(0,prob_shape[-1], tuple(x1.shape))
+                np.testing.assert_allclose(jc.log_prob(x), tc.log_prob(torch.tensor(x)), atol=1e-5)
+                np.testing.assert_allclose(jd.kl_divergence(jc,jc2), torch.distributions.kl_divergence(tc,tc2), atol=1e-5)
+        check((10,), (4,))
+        check((2,3), (4,))
+        check((3,4,5,6), (2,))
+
+    def test_one_hot_categorical2(self):
+        def check(prob_shape, sample_shape):
+            import torch
+            for _ in range(4):
+                probs,probs2 = np.random.uniform(0,1,prob_shape), np.random.uniform(0,1, prob_shape)
+
+                jc, jc2 = jd.OneHotCategorical(jt.array(probs)),jd.OneHotCategorical(jt.array(probs2))
+                tc, tc2 = torch.distributions.OneHotCategorical(torch.tensor(probs)),torch.distributions.OneHotCategorical(torch.tensor(probs2))
+                assert np.allclose(jc.entropy().data, tc.entropy().numpy()), (jc.entropy().data, tc.entropy().numpy())
+                x1 = jc.sample(sample_shape)
+                x2 = tc.sample(sample_shape)
+                assert tuple(x1.shape) == tuple(x2.shape)
+                x = np.random.randint(0,prob_shape[-1], tuple(x1.shape))
+                np.testing.assert_allclose(jc.log_prob(x), tc.log_prob(torch.tensor(x)), atol=1e-5)
+                np.testing.assert_allclose(jd.kl_divergence(jc,jc2), torch.distributions.kl_divergence(tc,tc2), atol=1e-5)
+        check((10,), (4,))
+        check((2,3), (4,))
+        check((3,4,5,6), (2,))
+
     def test_uniform(self):
         import torch
@@ -98,11 +136,11 @@ class TestOneHot(unittest.TestCase):
         prob, prob2 = np.random.uniform(0,1), np.random.uniform(0,1)
         jg, jg2 = jd.Geometric(prob),jd.Geometric(prob2)
         tg, tg2 = torch.distributions.Geometric(prob),torch.distributions.Geometric(prob2)
-        assert np.allclose(jg.entropy().data,tg.entropy().numpy())
+        np.testing.assert_allclose(jg.entropy().data,tg.entropy().numpy(), atol=1e-4)
         x = np.random.randint(1,10)
-        assert np.allclose(jg.log_prob(x),tg.log_prob(torch.tensor(x)))
+        np.testing.assert_allclose(jg.log_prob(x),tg.log_prob(torch.tensor(x)), atol=1e-4)
         # print(jd.kl_divergence(jg,jg2),torch.distributions.kl_divergence(tg,tg2))
-        assert np.allclose(jd.kl_divergence(jg,jg2),torch.distributions.kl_divergence(tg,tg2))
+        np.testing.assert_allclose(jd.kl_divergence(jg,jg2),torch.distributions.kl_divergence(tg,tg2), atol=1e-4)

 if __name__ == "__main__":
     unittest.main()