mirror of https://github.com/Jittor/Jittor

polish distributions

commit 4ec2bfacb2
parent 9c0f3cfdf4
@@ -9,7 +9,7 @@
 # file 'LICENSE.txt', which is part of this source code package.
 # ***************************************************************

-__version__ = '1.2.3.31'
+__version__ = '1.2.3.32'
 from jittor_utils import lock
 with lock.lock_scope():
     ori_int = int
@@ -437,11 +437,11 @@ def pow(x, y):
 Var.pow = Var.__pow__ = pow

 def argmax(x, dim, keepdims:bool=False):
-    return x.arg_reduce("max", dim, keepdims)
+    return jt.arg_reduce(x, "max", dim, keepdims)
 Var.argmax = argmax

 def argmin(x, dim, keepdims:bool=False):
-    return x.arg_reduce("min", dim, keepdims)
+    return jt.arg_reduce(x, "min", dim, keepdims)
 Var.argmin = argmin

 def randn(*size, dtype="float32", requires_grad=True) -> Var:
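Note on the argmax/argmin change: elsewhere in this diff jt.argmax is consumed as jt.argmax(x, dim=-1)[0], i.e. the reduction yields an (indices, values) pair. A minimal usage sketch (an illustration, not part of the commit, assuming jittor is installed):

    import jittor as jt
    x = jt.array([[1.0, 3.0, 2.0]])
    idx, val = jt.argmax(x, dim=-1)   # idx -> [1], val -> [3.0]; taking [0] elsewhere in this diff picks the indices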
@@ -29,18 +29,7 @@ kernel(in0->num/in0->shape[in0->shape.size()-1], 0, in0_p, out0_p, in0->shape[in

 class OneHotCategorical:
     def __init__(self, probs=None, logits=None):
-        assert not (probs is None and logits is None)
-        if probs is None:
-            # cannot align to pytorch
-            probs = jt.sigmoid(logits)
-        elif logits is None:
-            logits = jt.safe_log(probs)
-        with jt.no_grad():
-            self.probs = probs / probs.sum(-1, True)
-            self.cum_probs = simple_presum(self.probs)
-            self.cum_probs_l = self.cum_probs[..., :-1]
-            self.cum_probs_r = self.cum_probs[..., 1:]
-            self.logits = logits
+        Categorical.__init__(self, probs, logits)

     def sample(self, sample_shape=[]):
         shape = sample_shape + self.probs.shape[:-1] + (1,)
@@ -48,17 +37,12 @@ class OneHotCategorical:
         one_hot = jt.logical_and(self.cum_probs_l < rand, rand <= self.cum_probs_r).float()
         return one_hot

-    def log_prob(self,x):
-        if len(x.shape) == 1:
-            x = x.unsqueeze(0)
-        logits = self.logits.broadcast(x.shape)
-        indices = jt.argmax(x, dim=-1)[0]
-        return logits.gather(1, indices.unsqueeze(-1)).reshape(-1)
+    def log_prob(self, x):
+        x = jt.argmax(x, dim=-1)[0]
+        return Categorical.log_prob(self, x)

     def entropy(self):
-        min_real = -(math.pow(2,23)-1) / math.pow(2,22) * math.pow(2,127)
-        logits = jt.clamp(self.logits,min_v=min_real)
-        p_log_p = logits * self.probs
+        p_log_p = self.logits * self.probs
         return -p_log_p.sum(-1)
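OneHotCategorical.log_prob now maps the one-hot sample back to class indices with argmax and defers to Categorical.log_prob. A plain NumPy sketch of the equivalence (illustration only, single unbatched distribution assumed):

    import numpy as np
    probs = np.array([0.2, 0.3, 0.5])
    one_hot = np.array([0.0, 0.0, 1.0])
    k = one_hot.argmax(-1)                                    # class index 2
    assert np.isclose(np.log(probs)[k], np.log(probs[2]))     # same log-probability either way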
@@ -68,29 +52,32 @@ class Categorical:
         if probs is None:
             # cannot align to pytorch
             probs = jt.sigmoid(logits)
-        elif logits is None:
+        probs = probs / probs.sum(-1, True)
+        if logits is None:
             logits = jt.safe_log(probs)
         with jt.no_grad():
-            self.probs = probs / probs.sum(-1, True)
+            self.probs = probs
             self.logits = logits
-            self.cum_probs = simple_presum(probs)
+            self.cum_probs = simple_presum(self.probs)
             self.cum_probs_l = self.cum_probs[..., :-1]
             self.cum_probs_r = self.cum_probs[..., 1:]

-    def sample(self, sample_shape=[]):
+    def sample(self, sample_shape=()):
         shape = sample_shape + self.probs.shape[:-1] + (1,)
         rand = jt.rand(shape)
         one_hot = jt.logical_and(self.cum_probs_l < rand, rand <= self.cum_probs_r)
-        index = one_hot.index(one_hot.ndim-1)
+        index = one_hot.index(one_hot.ndim - 1)
         return (one_hot * index).sum(-1)

     def log_prob(self, x):
-        return jt.safe_log(self.probs)[0,x]
+        a = self.probs.ndim
+        b = x.ndim
+        indexes = tuple( f'i{i}' for i in range(b-a+1, b) )
+        indexes = indexes + (x,)
+        return jt.safe_log(self.probs).getitem(indexes)

     def entropy(self):
-        min_real = -(math.pow(2,23)-1) / math.pow(2,22) * math.pow(2,127)
-        logits = jt.clamp(self.logits,min_v=min_real)
-        p_log_p = logits * self.probs
+        p_log_p = self.logits * self.probs
         return -p_log_p.sum(-1)
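The reworked Categorical.log_prob indexes safe_log(self.probs) so that the batch dimensions of probs line up with the trailing dimensions of x, and the last axis is picked by the value of x. A rough NumPy sketch of the same computation (illustration only, not the Jittor code path; assumes probs has shape (..., K) and x prepends any sample dimensions):

    import numpy as np

    def categorical_log_prob(probs, x):
        # align probs' batch dims with x's trailing dims, then pick class x on the last axis
        logp = np.log(probs)
        logp = np.broadcast_to(logp, x.shape + logp.shape[-1:])
        return np.take_along_axis(logp, x[..., None], axis=-1)[..., 0]

    probs = np.random.dirichlet(np.ones(5), size=(2, 3))   # batch of distributions, shape (2, 3, 5)
    x = np.random.randint(0, 5, size=(4, 2, 3))            # 4 samples per distribution
    print(categorical_log_prob(probs, x).shape)            # (4, 2, 3)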
@@ -141,8 +128,7 @@ class Geometric:
         self.logits = -jt.safe_log(1. / p - 1)

     def sample(self, sample_shape):
-        tiny = jt.info(self.probs.dtype).tiny
-        u = jt.clamp(jt.rand(sample_shape),min_v=tiny)
+        u = jt.rand(sample_shape)
         return (jt.safe_log(u) / (jt.safe_log(-self.probs+1))).floor()

     def log_prob(self, x):
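Geometric.sample draws u ~ Uniform(0, 1) and applies the inverse CDF of the geometric distribution (counting failures before the first success), floor(log u / log(1 - p)). A minimal NumPy sketch of the same recipe (illustration only):

    import numpy as np
    p = 0.3
    u = np.random.rand(5)
    samples = np.floor(np.log(u) / np.log(1.0 - p))   # geometric samples: failures before the first success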
@@ -160,12 +146,6 @@ def kl_divergence(cur_dist, old_dist):
         return 0.5*(vr+t1-1-jt.safe_log(vr))
     if isinstance(cur_dist, Categorical) or isinstance(cur_dist,OneHotCategorical):
         t = cur_dist.probs * (cur_dist.logits-old_dist.logits)
-        # print("t:", t)
-        # print("old_dist.probs:", old_dist.probs)
-        # print("old_dist.probs:", (old_dist.probs==0).sum())
-        # print("cur_dist.probs:", cur_dist.probs)
-        # t[jt.array((old_dist.probs == 0))] = math.inf
-        # t[jt.array((cur_dist.probs == 0))] = 0
         return t.sum(-1)
     if isinstance(cur_dist, Uniform):
         res = jt.safe_log((old_dist.high - old_dist.low) / (cur_dist.high - cur_dist.low))
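For a pair of Categorical (or OneHotCategorical) distributions this branch computes the standard discrete KL divergence, KL(cur || old) = sum_i p_cur(i) * (log p_cur(i) - log p_old(i)), from the stored probs and logits; after the __init__ rework above, logits is safe_log(probs) when probabilities are given.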
@@ -634,23 +634,6 @@ def kthvalue(input, k, dim=None, keepdim=False):

 jt.Var.kthvalue = kthvalue

-
-def gather(x,dim,index):
-    if dim<0:
-        dim+=index.ndim
-    x_shape = list(x.shape )
-    i_shape = list(index.shape)
-    assert i_shape[dim]>0
-    assert x.ndim == index.ndim
-    i_shape[dim]=x_shape[dim]
-    assert i_shape == x_shape
-    ins = []
-    for i in range(index.ndim):
-        ins.append(jt.index(index.shape,dim=i))
-    ins[dim]=index
-    return x.reindex(ins)
-jt.Var.gather = gather
-
 def _prod(x,dim=0):
     x = jt.log(x)
     x = x.sum(dim=dim)
@@ -31,14 +31,14 @@ class TestOneHot(unittest.TestCase):
         probs,probs2 = np.random.uniform(0,1,(10)), np.random.uniform(0,1,(10))
         probs,probs2 = probs / probs.sum(),probs2 / probs2.sum()
         import torch
-        jc, jc2 = jd.OneHotCategorical(jt.array(probs).reshape(1,-1)),jd.OneHotCategorical(jt.array(probs2).reshape(1,-1))
+        jc, jc2 = jd.OneHotCategorical(jt.array(probs)),jd.OneHotCategorical(jt.array(probs2))
         tc, tc2 = torch.distributions.OneHotCategorical(torch.tensor(probs)),torch.distributions.OneHotCategorical(torch.tensor(probs2))
         assert np.allclose(jc.entropy().data,tc.entropy().numpy())
         x = np.zeros((4,10))
         for _ in range(4):
             nx = np.random.randint(0,9)
             x[_,nx] = 1
-        assert np.allclose(jc.log_prob(jt.array(x)),tc.log_prob(torch.tensor(x)))
+        np.testing.assert_allclose(jc.log_prob(jt.array(x)),tc.log_prob(torch.tensor(x)), atol=1e-5)
         assert np.allclose(jd.kl_divergence(jc,jc2),torch.distributions.kl_divergence(tc,tc2))

     def test_cate(self):
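The .reshape(1,-1) wrappers are dropped because, after the __init__ rework in this commit, Categorical and OneHotCategorical accept a plain 1-D probability vector directly (for example jd.OneHotCategorical(jt.array(probs)) with probs of shape (10,)); the new test_categorical2 and test_one_hot_categorical2 below extend the same checks to batched prob shapes.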
@@ -67,17 +67,55 @@ class TestOneHot(unittest.TestCase):
         tn2 = torch.distributions.Normal(mu2,sigma2)
         assert np.allclose(jd.kl_divergence(jn,jn2).data,torch.distributions.kl_divergence(tn,tn2).numpy())

-    def test_categorical(self):
+    def test_categorical1(self):
         import torch
         for _ in range(4):
             probs,probs2 = np.random.uniform(0,1,(10)), np.random.uniform(0,1,(10))
             probs,probs2 = probs / probs.sum(),probs2 / probs2.sum()
-            jc, jc2 = jd.Categorical(jt.array(probs).reshape(1,-1)),jd.Categorical(jt.array(probs2).reshape(1,-1))
+            jc, jc2 = jd.Categorical(jt.array(probs)),jd.Categorical(jt.array(probs2))
             tc, tc2 = torch.distributions.Categorical(torch.tensor(probs)),torch.distributions.Categorical(torch.tensor(probs2))
             assert np.allclose(jc.entropy().data, tc.entropy().numpy()), (jc.entropy().data, tc.entropy().numpy())
             x = np.random.randint(0,10,(4))
-            assert np.allclose(jc.log_prob(x), tc.log_prob(torch.tensor(x)))
+            np.testing.assert_allclose(jc.log_prob(x), tc.log_prob(torch.tensor(x)), atol=1e-5)
             assert np.allclose(jd.kl_divergence(jc,jc2),torch.distributions.kl_divergence(tc,tc2))

+    def test_categorical2(self):
+        def check(prob_shape, sample_shape):
+            import torch
+            for _ in range(4):
+                probs,probs2 = np.random.uniform(0,1,prob_shape), np.random.uniform(0,1, prob_shape)
+
+                jc, jc2 = jd.Categorical(jt.array(probs)),jd.Categorical(jt.array(probs2))
+                tc, tc2 = torch.distributions.Categorical(torch.tensor(probs)),torch.distributions.Categorical(torch.tensor(probs2))
+                assert np.allclose(jc.entropy().data, tc.entropy().numpy()), (jc.entropy().data, tc.entropy().numpy())
+                x1 = jc.sample(sample_shape)
+                x2 = tc.sample(sample_shape)
+                assert tuple(x1.shape) == tuple(x2.shape)
+                x = np.random.randint(0,prob_shape[-1], tuple(x1.shape))
+                np.testing.assert_allclose(jc.log_prob(x), tc.log_prob(torch.tensor(x)), atol=1e-5)
+                np.testing.assert_allclose(jd.kl_divergence(jc,jc2), torch.distributions.kl_divergence(tc,tc2), atol=1e-5)
+        check((10,), (4,))
+        check((2,3), (4,))
+        check((3,4,5,6), (2,))
+
+    def test_one_hot_categorical2(self):
+        def check(prob_shape, sample_shape):
+            import torch
+            for _ in range(4):
+                probs,probs2 = np.random.uniform(0,1,prob_shape), np.random.uniform(0,1, prob_shape)
+
+                jc, jc2 = jd.OneHotCategorical(jt.array(probs)),jd.OneHotCategorical(jt.array(probs2))
+                tc, tc2 = torch.distributions.OneHotCategorical(torch.tensor(probs)),torch.distributions.OneHotCategorical(torch.tensor(probs2))
+                assert np.allclose(jc.entropy().data, tc.entropy().numpy()), (jc.entropy().data, tc.entropy().numpy())
+                x1 = jc.sample(sample_shape)
+                x2 = tc.sample(sample_shape)
+                assert tuple(x1.shape) == tuple(x2.shape)
+                x = np.random.randint(0,prob_shape[-1], tuple(x1.shape))
+                np.testing.assert_allclose(jc.log_prob(x), tc.log_prob(torch.tensor(x)), atol=1e-5)
+                np.testing.assert_allclose(jd.kl_divergence(jc,jc2), torch.distributions.kl_divergence(tc,tc2), atol=1e-5)
+        check((10,), (4,))
+        check((2,3), (4,))
+        check((3,4,5,6), (2,))
+
     def test_uniform(self):
         import torch
@@ -98,11 +136,11 @@
         prob, prob2 = np.random.uniform(0,1), np.random.uniform(0,1)
         jg, jg2 = jd.Geometric(prob),jd.Geometric(prob2)
         tg, tg2 = torch.distributions.Geometric(prob),torch.distributions.Geometric(prob2)
-        assert np.allclose(jg.entropy().data,tg.entropy().numpy())
+        np.testing.assert_allclose(jg.entropy().data,tg.entropy().numpy(), atol=1e-4)
         x = np.random.randint(1,10)
-        assert np.allclose(jg.log_prob(x),tg.log_prob(torch.tensor(x)))
+        np.testing.assert_allclose(jg.log_prob(x),tg.log_prob(torch.tensor(x)), atol=1e-4)
         # print(jd.kl_divergence(jg,jg2),torch.distributions.kl_divergence(tg,tg2))
-        assert np.allclose(jd.kl_divergence(jg,jg2),torch.distributions.kl_divergence(tg,tg2))
+        np.testing.assert_allclose(jd.kl_divergence(jg,jg2),torch.distributions.kl_divergence(tg,tg2), atol=1e-4)

 if __name__ == "__main__":
     unittest.main()