mirror of https://github.com/Jittor/Jittor
add normal, log_prob, entropy, kl_divergence
This commit is contained in:
parent 04be711ddb
commit 6c93dbb5fb
@@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************

-__version__ = '1.2.2.68'
+__version__ = '1.2.2.69'
from . import lock
with lock.lock_scope():
    ori_int = int
@@ -1,11 +1,14 @@
# ***************************************************************
# Copyright (c) 2021 Jittor. All Rights Reserved.
# Maintainers:
# Haoyang Peng <2247838039@qq.com>
# Dun Liang <randonlang@gmail.com>.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import math
import numpy as np
import jittor as jt

def simple_presum(x):
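The body of simple_presum is not part of this hunk; judging from how it is used below (cum_probs sliced into left and right interval edges), it is presumably a prefix sum over the last axis with a leading zero. A minimal NumPy sketch of that presumed behaviour (presum_sketch is an illustrative name, not part of the patch):

import numpy as np

def presum_sketch(p):
    # Prefix sum along the last axis with a leading zero, so that
    # out[..., i] == p[..., :i].sum(); the output has one more entry
    # than the input along the last dimension.
    zeros = np.zeros(p.shape[:-1] + (1,), dtype=p.dtype)
    return np.concatenate([zeros, np.cumsum(p, axis=-1)], axis=-1)

print(presum_sketch(np.array([0.2, 0.3, 0.5])))  # [0.  0.2 0.5 1. ]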
@@ -29,20 +32,31 @@ class OneHotCategorical
        if probs is None:
            # cannot align to pytorch
            probs = jt.sigmoid(logits)
        elif logits is None:
            logits = jt.log(probs)
        with jt.no_grad():
            self.probs = probs / probs.sum(-1, True)
            self.cum_probs = simple_presum(self.probs)
            self.cum_probs_l = self.cum_probs[..., :-1]
            self.cum_probs_r = self.cum_probs[..., 1:]
            self.logits = logits

    def sample(self, sample_shape=[]):
        shape = sample_shape + self.probs.shape[:-1] + (1,)
        rand = jt.rand(shape)
        one_hot = jt.logical_and(self.cum_probs_l < rand, rand <= self.cum_probs_r).float()
        return one_hot

    def log_prob(self, x):
        return jt.log(self.probs)[0, x]

    def entropy(self):
        min_real = -(math.pow(2,23)-1) / math.pow(2,22) * math.pow(2,127)
        logits = jt.clamp(self.logits, min_v=min_real)
        p_log_p = logits * self.probs
        return -p_log_p.sum(-1)


class Categorical:
    def __init__(self, probs=None, logits=None):
        OneHotCategorical.__init__(self, probs, logits)
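Sampling here is an inverse-CDF lookup: a uniform draw falls into exactly one interval (cum_probs_l, cum_probs_r], and that interval becomes the one-hot category; log_prob indexes the first (and, as written, only) batch row, and entropy clamps logits to roughly the most negative finite float32 so that zero-probability entries contribute 0 rather than NaN. A standalone NumPy sketch of the sampling idea (illustrative names, not part of the patch):

import numpy as np

def one_hot_sample_sketch(probs, n):
    # probs: 1-D probability vector summing to 1. Build the CDF edges,
    # then bucket each uniform draw into exactly one (left, right]
    # interval; the hit interval is the sampled category.
    edges = np.concatenate([[0.0], np.cumsum(probs)])
    left, right = edges[:-1], edges[1:]
    rand = np.random.rand(n, 1)  # shape (n, 1), broadcasts against the edges
    return np.logical_and(left < rand, rand <= right).astype(float)

print(one_hot_sample_sketch(np.array([0.2, 0.3, 0.5]), 4))  # 4 one-hot rows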
@@ -53,3 +67,43 @@ class Categorical
        one_hot = jt.logical_and(self.cum_probs_l < rand, rand <= self.cum_probs_r)
        index = one_hot.index(one_hot.ndim-1)
        return (one_hot * index).sum(-1)

    def log_prob(self, x):
        return jt.log(self.probs)[0, x]

    def entropy(self):
        min_real = -(math.pow(2,23)-1) / math.pow(2,22) * math.pow(2,127)
        logits = jt.clamp(self.logits, min_v=min_real)
        p_log_p = logits * self.probs
        return -p_log_p.sum(-1)


class Normal:
    def __init__(self, mu, sigma):
        self.mu = mu
        self.sigma = sigma

    def sample(self, sample_shape):
        return jt.normal(self.mu, self.sigma, sample_shape)

    def log_prob(self, x):
        var = self.sigma**2
        log_scale = jt.log(self.sigma)
        return -((x-self.mu)**2) / (2*var) - log_scale - np.log(np.sqrt(2*np.pi))

    def entropy(self):
        return 0.5 + 0.5*np.log(2*np.pi) + jt.log(self.sigma)


def kl_divergence(cur_dist, old_dist):
    assert isinstance(cur_dist, type(old_dist))
    if isinstance(cur_dist, Normal):
        vr = (cur_dist.sigma / old_dist.sigma)**2
        t1 = ((cur_dist.mu - old_dist.mu) / old_dist.sigma)**2
        return 0.5*(vr + t1 - 1 - jt.log(vr))
    if isinstance(cur_dist, Categorical) or isinstance(cur_dist, OneHotCategorical):  # ?
        t = cur_dist.probs * (cur_dist.logits - old_dist.logits)
        t[jt.array((old_dist.probs == 0))] = math.inf
        t[jt.array((cur_dist.probs == 0))] = 0
        return t.sum(-1)
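For reference, the Normal branch of kl_divergence matches the closed-form Gaussian KL, KL(N(mu1, sigma1) || N(mu2, sigma2)) = log(sigma2/sigma1) + (sigma1^2 + (mu1 - mu2)^2) / (2 * sigma2^2) - 1/2, which is exactly what 0.5*(vr + t1 - 1 - log(vr)) expands to. A small usage sketch mirroring the tests below, assuming the patched module is importable as jittor.distributions (the tests' jd alias); the values are illustrative:

from jittor import distributions as jd

p = jd.Normal(0.0, 1.0)
q = jd.Normal(1.0, 2.0)
print(p.log_prob(0.5))         # log-density of N(0, 1) at x = 0.5
print(p.entropy())             # 0.5 + 0.5*log(2*pi) + log(sigma)
print(jd.kl_divergence(p, q))  # closed-form KL(N(0, 1) || N(1, 2))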
@@ -40,8 +40,37 @@ class TestOneHot(unittest.TestCase):
        y = a.sample([2,3])
        y.sync()
        assert y.shape == [2,3]

    def test_normal(self):
        import torch
        for _ in range(10):
            mu = np.random.uniform(-1, 1)
            sigma = np.random.uniform(0, 2)
            jn = jd.Normal(mu, sigma)
            tn = torch.distributions.Normal(mu, sigma)
            assert np.allclose(jn.entropy().data, tn.entropy().numpy())
            x = np.random.uniform(-1, 1)
            # print(jn.log_prob(x))
            # print(tn.log_prob(torch.tensor(x)))
            assert np.allclose(jn.log_prob(x), tn.log_prob(torch.tensor(x)))
            mu2 = np.random.uniform(-1, 1)
            sigma2 = np.random.uniform(0, 2)
            jn2 = jd.Normal(mu2, sigma2)
            tn2 = torch.distributions.Normal(mu2, sigma2)
            assert np.allclose(jd.kl_divergence(jn, jn2).data, torch.distributions.kl_divergence(tn, tn2).numpy())

    def test_categorical(self):
        import torch
        for _ in range(10):
            probs, probs2 = np.random.uniform(0, 1, (10)), np.random.uniform(0, 1, (10))
            probs, probs2 = probs / probs.sum(), probs2 / probs2.sum()
            jc, jc2 = jd.Categorical(jt.array(probs).reshape(1, -1)), jd.Categorical(jt.array(probs2).reshape(1, -1))
            tc, tc2 = torch.distributions.Categorical(torch.tensor(probs)), torch.distributions.Categorical(torch.tensor(probs2))
            assert np.allclose(jc.entropy().data, tc.entropy().numpy())
            x = np.random.randint(0, 10)
            # print(jc.log_prob(x), tc.log_prob(x))
            assert np.allclose(jc.log_prob(x), tc.log_prob(torch.tensor(x)))
            assert np.allclose(jd.kl_divergence(jc, jc2), torch.distributions.kl_divergence(tc, tc2))

if __name__ == "__main__":
    unittest.main()
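The tests above validate the new methods against torch.distributions; a torch-free sanity check of the categorical branch (KL(p || q) = sum_i p_i * log(p_i / q_i)) could look like the sketch below, again assuming the jittor.distributions import path and with illustrative probability vectors:

import numpy as np
import jittor as jt
from jittor import distributions as jd

p = np.array([0.2, 0.3, 0.5])
q = np.array([0.5, 0.25, 0.25])
jc = jd.Categorical(jt.array(p).reshape(1, -1))
jc2 = jd.Categorical(jt.array(q).reshape(1, -1))
expected = np.sum(p * np.log(p / q))  # definition of KL(p || q)
assert np.allclose(jd.kl_divergence(jc, jc2), expected)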