add normal, log_prob,entropy ,kl_divergence

This commit is contained in:
Dun Liang 2021-05-07 17:34:33 +08:00
parent 04be711ddb
commit 6c93dbb5fb
3 changed files with 88 additions and 5 deletions

View File

@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.2.2.68'
__version__ = '1.2.2.69'
from . import lock
with lock.lock_scope():
ori_int = int

View File

@ -1,11 +1,14 @@
# ***************************************************************
# Copyright (c) 2021 Jittor. All Rights Reserved.
# Maintainers:
# Haoyang Peng <2247838039@qq.com>
# Dun Liang <randonlang@gmail.com>.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import math
import numpy as np
import jittor as jt
def simple_presum(x):
@ -29,20 +32,31 @@ class OneHotCategorical:
if probs is None:
# cannot align to pytorch
probs = jt.sigmoid(logits)
elif logits is None:
logits = jt.log(probs)
with jt.no_grad():
self.probs = probs / probs.sum(-1, True)
self.cum_probs = simple_presum(self.probs)
self.cum_probs_l = self.cum_probs[..., :-1]
self.cum_probs_r = self.cum_probs[..., 1:]
self.logits = logits
def sample(self, sample_shape=[]):
shape = sample_shape + self.probs.shape[:-1] + (1,)
rand = jt.rand(shape)
one_hot = jt.logical_and(self.cum_probs_l < rand, rand <= self.cum_probs_r).float()
return one_hot
def log_prob(self,x):
return jt.log(self.probs)[0,x]
def entropy(self):
min_real = -(math.pow(2,23)-1) / math.pow(2,22) * math.pow(2,127)
logits = jt.clamp(self.logits,min_v=min_real)
p_log_p = logits * self.probs
return -p_log_p.sum(-1)
class Categorical:
def __init__(self, probs=None, logits=None):
OneHotCategorical.__init__(self, probs, logits)
@ -53,3 +67,43 @@ class Categorical:
one_hot = jt.logical_and(self.cum_probs_l < rand, rand <= self.cum_probs_r)
index = one_hot.index(one_hot.ndim-1)
return (one_hot * index).sum(-1)
def log_prob(self, x):
return jt.log(self.probs)[0,x]
def entropy(self):
min_real = -(math.pow(2,23)-1) / math.pow(2,22) * math.pow(2,127)
logits = jt.clamp(self.logits,min_v=min_real)
p_log_p = logits * self.probs
return -p_log_p.sum(-1)
class Normal:
def __init__(self,mu,sigma):
self.mu = mu
self.sigma = sigma
def sample(self,sample_shape):
return jt.normal(self.mu, self.sigma, sample_shape)
def log_prob(self, x):
var = self.sigma**2
log_scale = jt.log(self.sigma)
return -((x-self.mu)**2) / (2*var) - log_scale-np.log(np.sqrt(2*np.pi))
def entropy(self):
return 0.5+0.5*np.log(2*np.pi)+jt.log(self.sigma)
def kl_divergence(cur_dist,old_dist):
assert isinstance(cur_dist,type(old_dist))
if isinstance(cur_dist,Normal):
vr = (cur_dist.sigma / old_dist.sigma)**2
t1 = ((cur_dist.mu - old_dist.mu) / old_dist.sigma)**2
return 0.5*(vr+t1-1-jt.log(vr))
if isinstance(cur_dist,Categorical) or isinstance(cur_dist,OneHotCategorical):# ?
t = cur_dist.probs * (cur_dist.logits-old_dist.logits)
t[jt.array((old_dist.probs == 0))] = math.inf
t[jt.array((cur_dist.probs == 0))] = 0
return t.sum(-1)

View File

@ -40,8 +40,37 @@ class TestOneHot(unittest.TestCase):
y = a.sample([2,3])
y.sync()
assert y.shape == [2,3]
def test_normal(self):
import torch
for _ in range(10):
mu = np.random.uniform(-1,1)
sigma = np.random.uniform(0,2)
jn = jd.Normal(mu,sigma)
tn = torch.distributions.Normal(mu,sigma)
assert np.allclose(jn.entropy().data,tn.entropy().numpy())
x = np.random.uniform(-1,1)
# print(jn.log_prob(x))
# print(tn.log_prob(torch.tensor(x)))
assert np.allclose(jn.log_prob(x),tn.log_prob(torch.tensor(x)))
mu2 = np.random.uniform(-1,1)
sigma2 = np.random.uniform(0,2)
jn2 = jd.Normal(mu2,sigma2)
tn2 = torch.distributions.Normal(mu2,sigma2)
assert np.allclose(jd.kl_divergence(jn,jn2).data,torch.distributions.kl_divergence(tn,tn2).numpy())
def test_categorical(self):
import torch
for _ in range(10):
probs,probs2 = np.random.uniform(0,1,(10)), np.random.uniform(0,1,(10))
probs,probs2 = probs / probs.sum(),probs2 / probs2.sum()
jc, jc2 = jd.Categorical(jt.array(probs).reshape(1,-1)),jd.Categorical(jt.array(probs2).reshape(1,-1))
tc, tc2 = torch.distributions.Categorical(torch.tensor(probs)),torch.distributions.Categorical(torch.tensor(probs2))
assert np.allclose(jc.entropy().data,tc.entropy().numpy())
x = np.random.randint(0,10)
# print(jc.log_prob(x),tc.log_prob(x))
assert np.allclose(jc.log_prob(x),tc.log_prob(torch.tensor(x)))
assert np.allclose(jd.kl_divergence(jc,jc2),torch.distributions.kl_divergence(tc,tc2))
if __name__ == "__main__":
unittest.main()