# ***************************************************************
# Copyright (c) 2023 Jittor. All Rights Reserved.
# Maintainers:
#     Haoyang Peng <2247838039@qq.com>
#     Dun Liang <randonlang@gmail.com>.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import math
import os

import numpy as np

import jittor as jt
from jittor import nn
from jittor.nn import binary_cross_entropy_with_logits
from jittor import lgamma, igamma
from jittor.math_util.gamma import gamma_grad, sample_gamma

def simple_presum(x):
    # Prefix sum along the last axis with a leading zero, so the output has
    # shape x.shape[:-1] + (x.shape[-1] + 1,). Used below to build cumulative
    # probabilities for categorical sampling.
    src = '''
__inline_static__
@python.jittor.auto_parallel(1)
void kernel(int n0, int i0, in0_type* x, in0_type* out, int nl) {
    out[i0*(nl+1)] = 0;
    for (int i=0; i<nl; i++)
        out[i0*(nl+1)+i+1] = out[i0*(nl+1)+i] + x[i0*nl+i];
}
kernel(in0->num/in0->shape[in0->shape.size()-1], 0, in0_p, out0_p, in0->shape[in0->shape.size()-1]);
    '''
    return jt.code(x.shape[:-1]+(x.shape[-1]+1,), x.dtype, [x],
        cpu_src=src, cuda_src=src)
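
# Usage sketch (illustrative, not part of the library): simple_presum prepends a
# zero and accumulates along the last axis, so for a probability row it returns
# the cumulative distribution starting at 0. The values below are arbitrary.
def _presum_example():
    x = jt.array([[0.2, 0.3, 0.5]])
    cum = simple_presum(x)                               # shape (1, 4): 0.0, 0.2, 0.5, 1.0
    ref = np.concatenate([[0.0], np.cumsum([0.2, 0.3, 0.5])])
    return cum, ref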

class OneHotCategorical:
    def __init__(self, probs=None, logits=None):
        # Reuses Categorical's initialization (defined below); sampling and
        # log_prob then operate on the same cumulative probabilities.
        Categorical.__init__(self, probs, logits)

    def sample(self, sample_shape=()):
        shape = sample_shape + self.probs.shape[:-1] + (1,)
        rand = jt.rand(shape)
        one_hot = jt.logical_and(self.cum_probs_l < rand, rand <= self.cum_probs_r).float()
        return one_hot

    def log_prob(self, x):
        x = jt.argmax(x, dim=-1)[0]
        return Categorical.log_prob(self, x)

    def entropy(self):
        p_log_p = self.logits * self.probs
        return -p_log_p.sum(-1)

class Categorical:
    def __init__(self, probs=None, logits=None):
        assert not (probs is None and logits is None)
        if probs is None:
            # Cannot align to PyTorch, which normalizes logits with softmax;
            # here logits are mapped through sigmoid before normalization.
            probs = jt.sigmoid(logits)
        probs = probs / probs.sum(-1, True)
        if logits is None:
            logits = jt.safe_log(probs)
        with jt.no_grad():
            self.probs = probs
            self.logits = logits
            # Cumulative probabilities [0, p0, p0+p1, ...] for inverse-CDF sampling.
            self.cum_probs = simple_presum(self.probs)
            self.cum_probs_l = self.cum_probs[..., :-1]
            self.cum_probs_r = self.cum_probs[..., 1:]

    def sample(self, sample_shape=()):
        shape = sample_shape + self.probs.shape[:-1] + (1,)
        rand = jt.rand(shape)
        # A uniform draw falls into exactly one (cum_l, cum_r] interval;
        # the index of that interval is the sampled category.
        one_hot = jt.logical_and(self.cum_probs_l < rand, rand <= self.cum_probs_r)
        index = one_hot.index(one_hot.ndim - 1)
        return (one_hot * index).sum(-1)

    def log_prob(self, x):
        a = self.probs.ndim
        b = x.ndim
        indexes = tuple(f'i{i}' for i in range(b-a+1, b))
        indexes = indexes + (x,)
        return jt.safe_log(self.probs).getitem(indexes)

    def entropy(self):
        p_log_p = self.logits * self.probs
        return -p_log_p.sum(-1)
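
# Usage sketch (illustrative, not part of the library): draw indices from a
# three-way categorical and one-hot samples from the matching OneHotCategorical.
# The probability values and sample shape below are arbitrary.
def _categorical_example():
    probs = jt.array([0.2, 0.3, 0.5])
    dist = Categorical(probs=probs)
    idx = dist.sample((4,))        # four integer samples in {0, 1, 2}
    lp = dist.log_prob(idx)        # log of the corresponding entries of probs
    one_hot = OneHotCategorical(probs=probs).sample((4,))  # shape (4, 3), one-hot rows
    return idx, lp, one_hot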

class Normal:
    def __init__(self, mu, sigma):
        self.mu = mu
        self.sigma = sigma

    def sample(self, sample_shape=None):
        return jt.normal(jt.array(self.mu), jt.array(self.sigma), size=sample_shape)

    def log_prob(self, x):
        # log N(x; mu, sigma) = -(x-mu)^2 / (2*sigma^2) - log(sigma) - log(sqrt(2*pi))
        var = self.sigma**2
        log_scale = jt.safe_log(self.sigma)
        return -((x-self.mu)**2) / (2*var) - log_scale - np.log(np.sqrt(2*np.pi))

    def entropy(self):
        return 0.5 + 0.5*np.log(2*np.pi) + jt.safe_log(self.sigma)
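
# Usage sketch (illustrative, not part of the library): a standard normal with
# mu and sigma passed as jt arrays. log_prob(0) is about -0.5*log(2*pi) and the
# entropy is 0.5 + 0.5*log(2*pi). Assumes jt.normal falls back to mu's shape
# when size is None; shapes below are arbitrary.
def _normal_example():
    dist = Normal(jt.array([0.0]), jt.array([1.0]))
    samples = dist.sample()                # one draw per element of mu/sigma
    lp0 = dist.log_prob(jt.array([0.0]))   # ~ -0.9189
    return samples, lp0, dist.entropy()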

class Uniform:
    def __init__(self, low, high):
        self.low = low
        self.high = high
        assert high > low

    def sample(self, sample_shape):
        return jt.uniform(self.low, self.high, sample_shape)

    def log_prob(self, x):
        # Density is 1/(high-low) on [low, high) and zero elsewhere,
        # so the log-probability outside the support is -inf.
        if x < self.low or x >= self.high:
            return -math.inf
        return -jt.safe_log(self.high - self.low)

    def entropy(self):
        return jt.safe_log(self.high - self.low)
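
# Worked check (illustrative, not part of the library): for a Uniform(0, 2) the
# density is 1/2 on [0, 2), so the log-probability of any in-support point is
# -log(2) and the entropy is log(2); plain math reproduces the numbers.
def _uniform_reference(low=0.0, high=2.0):
    log_p = -math.log(high - low)   # -0.6931...
    ent = math.log(high - low)      # 0.6931...
    return log_p, ent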

class Geometric:
    def __init__(self, p=None, logits=None):
        # Distribution over the number of failures before the first success.
        assert (p is not None) or (logits is not None)
        if p is None:
            self.prob = jt.sigmoid(logits)
            self.logits = logits
        else:
            assert 0 < p and p < 1
            self.prob = p
            self.logits = -jt.safe_log(1. / p - 1)

    def sample(self, sample_shape):
        # Inverse-CDF sampling: floor(log(u) / log(1-p)) with u ~ Uniform(0, 1).
        u = jt.rand(sample_shape)
        return (jt.safe_log(u) / (jt.safe_log(-self.prob+1))).floor_int()

    def log_prob(self, x):
        return x*jt.safe_log(-self.prob+1) + jt.safe_log(self.prob)

    def entropy(self):
        return binary_cross_entropy_with_logits(jt.array(self.logits), jt.array(self.prob)) / self.prob
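
# Numerical sketch (illustrative, not part of the library) of the inverse-CDF
# trick used in Geometric.sample above: floor(log(u) / log(1-p)) for uniform u
# is geometrically distributed with success probability p, so its mean should
# be close to (1-p)/p. Pure numpy; p and the sample count are arbitrary.
def _geometric_sampling_check(p=0.3, n=100000):
    u = np.random.rand(n)
    k = np.floor(np.log(u) / np.log(1.0 - p))
    return k.mean(), (1.0 - p) / p   # the two numbers should roughly agree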

class GammaDistribution:
    '''
    Gamma distribution parameterized by concentration (alpha) and rate (beta).
    For now only the gamma distribution is supported.
    '''
    def __init__(self, concentration, rate):
        self.concentration = concentration
        self.rate = rate
        self.lgamma_alpha = lgamma.apply(jt.array([concentration,]))

    def sample(self, shape):
        return sample_gamma(self.concentration, shape)

    def cdf(self, value):
        return igamma(self.concentration, value)

    def log_prob(self, value):
        # log p(x) = alpha*log(beta) + (alpha-1)*log(x) - beta*x - log(Gamma(alpha))
        return (self.concentration * jt.log(self.rate) +
                (self.concentration - 1) * jt.log(value) -
                self.rate * value - self.lgamma_alpha)

    def mean(self):
        return self.concentration / self.rate

    def mode(self):
        # The mode is (alpha - 1) / beta for alpha >= 1, and 0 otherwise.
        return np.maximum((self.concentration - 1) / self.rate, 0)

    def variance(self):
        return self.concentration / (self.rate * self.rate)
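
# Worked check (illustrative, not part of the library): the gamma log-density
# formula used in GammaDistribution.log_prob, evaluated with plain math for
# arbitrary alpha, beta and x, so the jittor result can be compared against it.
def _gamma_log_prob_reference(alpha=2.0, beta=1.5, x=0.7):
    # log p(x) = alpha*log(beta) + (alpha-1)*log(x) - beta*x - lgamma(alpha)
    return (alpha * math.log(beta) + (alpha - 1.0) * math.log(x)
            - beta * x - math.lgamma(alpha))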

def kl_divergence(cur_dist, old_dist):
    # KL(cur_dist || old_dist) for two distributions of the same type.
    assert isinstance(cur_dist, type(old_dist))
    if isinstance(cur_dist, Normal):
        vr = (cur_dist.sigma / old_dist.sigma)**2
        t1 = ((cur_dist.mu - old_dist.mu) / old_dist.sigma)**2
        return 0.5*(vr + t1 - 1 - jt.safe_log(vr))
    if isinstance(cur_dist, Categorical) or isinstance(cur_dist, OneHotCategorical):
        t = cur_dist.probs * (cur_dist.logits - old_dist.logits)
        return t.sum(-1)
    if isinstance(cur_dist, Uniform):
        res = jt.safe_log((old_dist.high - old_dist.low) / (cur_dist.high - cur_dist.low))
        # The KL is infinite unless cur_dist's support lies inside old_dist's.
        if old_dist.low > cur_dist.low or old_dist.high < cur_dist.high:
            res = math.inf
        return res
    if isinstance(cur_dist, Geometric):
        return -cur_dist.entropy() - jt.safe_log(-old_dist.prob+1) / cur_dist.prob - old_dist.logits
    raise NotImplementedError(f"kl_divergence is not implemented for {type(cur_dist).__name__}")
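
# Numerical sketch (illustrative, not part of the library): the closed form used
# above for two Normals, KL = 0.5*(vr + t1 - 1 - log(vr)) with vr = (s1/s2)^2 and
# t1 = ((m1-m2)/s2)^2, compared against a Monte Carlo estimate of
# E_p[log p(x) - log q(x)]. Parameters and sample size below are arbitrary.
def _normal_kl_check(m1=0.0, s1=1.0, m2=0.5, s2=2.0, n=200000):
    vr = (s1 / s2)**2
    t1 = ((m1 - m2) / s2)**2
    closed_form = 0.5 * (vr + t1 - 1.0 - math.log(vr))
    x = np.random.normal(m1, s1, n)
    log_p = -((x - m1)**2) / (2*s1**2) - math.log(s1) - 0.5*math.log(2*math.pi)
    log_q = -((x - m2)**2) / (2*s2**2) - math.log(s2) - 0.5*math.log(2*math.pi)
    return closed_form, float(np.mean(log_p - log_q))   # should roughly agree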