add distribution for rl

This commit is contained in:
Dun Liang 2021-04-29 23:03:16 +08:00
parent b75b6def89
commit 862d564d85
9 changed files with 192 additions and 32 deletions

View File

@ -8,7 +8,7 @@
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.2.2.62'
__version__ = '1.2.2.63'
from . import lock
with lock.lock_scope():
ori_int = int

View File

@ -0,0 +1,56 @@
# ***************************************************************
# Copyright (c) 2021 Jittor. All Rights Reserved.
# Maintainers:
# Dun Liang <randonlang@gmail.com>.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import jittor as jt
def simple_presum(x):
src = '''
__inline_static__
@python.jittor.auto_parallel(1)
void kernel(int n0, int i0, in0_type* x, in0_type* out, int nl) {
out[i0*(nl+1)] = 0;
for (int i=0; i<nl; i++)
out[i0*(nl+1)+i+1] = out[i0*(nl+1)+i] + x[i0*(nl+1)+i];
}
kernel(in0->num/in0->shape[in0->shape.size()-1], 0, in0_p, out0_p, in0->num);
'''
return jt.code(x.shape[:-1]+(x.shape[-1]+1,), x.dtype, [x],
cpu_src=src, cuda_src=src)
class OneHotCategorical:
def __init__(self, probs=None, logits=None):
assert not (probs is None and logits is None)
if probs is None:
# cannot align to pytorch
probs = jt.sigmoid(logits)
with jt.no_grad():
self.probs = probs / probs.sum(-1, True)
self.cum_probs = simple_presum(probs)
self.cum_probs_l = self.cum_probs[..., :-1]
self.cum_probs_r = self.cum_probs[..., 1:]
def sample(self, sample_shape=[]):
shape = sample_shape + self.probs.shape[:-1] + (1,)
rand = jt.rand(shape)
one_hot = jt.logical_and(self.cum_probs_l < rand, rand <= self.cum_probs_r).float()
return one_hot
class Categorical:
def __init__(self, probs=None, logits=None):
OneHotCategorical.__init__(self, probs, logits)
def sample(self, sample_shape=[]):
shape = sample_shape + self.probs.shape[:-1] + (1,)
rand = jt.rand(shape)
one_hot = jt.logical_and(self.cum_probs_l < rand, rand <= self.cum_probs_r)
index = one_hot.index(one_hot.ndim-1)
return (one_hot * index).sum(-1)

View File

@ -787,6 +787,39 @@ def print_tree(now, max_memory_size, prefix1, prefix2, build_by):
return out
def get_max_memory_treemap(build_by=0, do_print=True):
'''show treemap of max memory consumption
Example::
net = jt.models.resnet18()
with jt.flag_scope(trace_py_var=3, profile_memory_enable=1):
imgs = jt.randn((1,3,224,224))
net(imgs).sync()
jt.get_max_memory_treemap()
Output::
|
./python/jittor/test/test_memory_profiler.py:100(test_sample)
| [19.03 MB; 29.67%]
| ./python/jittor/test/test_memory_profiler.py:100
| |
| ./python/jittor/__init__.py:730(__call__)
| [19.03 MB; 29.67%]
| ./python/jittor/__init__.py:730
| |
| ./python/jittor/models/resnet.py:152(execute)
| [19.03 MB; 29.67%]
| ./python/jittor/models/resnet.py:152
| |
| ./python/jittor/models/resnet.py:142(_forward_impl)
| | [6.13 MB; 9.55%]
| | ./python/jittor/models/resnet.py:142
| | |
'''
div1 = "[!@#div1!@#]"
div2 = "[!@#div2!@#]"
div3 = "[!@#div3!@#]"

View File

@ -338,6 +338,8 @@ def log_softmax(x,dim=None):
def log_sigmoid(x):
return jt.log(jt.sigmoid(x))
def logsumexp(x, dim, keepdim=False):
return x.exp().sum(dim, keepdim).log()
class Identity(Module):
def __init__(self, *args, **kwargs):

View File

@ -0,0 +1,47 @@
# ***************************************************************
# Copyright (c) 2021 Jittor. All Rights Reserved.
# Maintainers:
# Wenyang Zhou <576825820@qq.com>
# Dun Liang <randonlang@gmail.com>.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import jittor as jt
import numpy as np
import jittor.distributions as jd
class TestOneHot(unittest.TestCase):
def test_presum(self):
a = jt.array([[1,2,3,4]])
b = jd.simple_presum(a)
assert (b.data == [[0,1,3,6,10]]).all()
def test_one_hot(self):
a = jd.OneHotCategorical(jt.array([0.25, 0.25, 0.25, 0.25]))
x = a.sample().numpy()
for i in range(1000):
x += a.sample().numpy()
print(x)
assert (x > 200).all()
y = a.sample([2,3])
y.sync()
assert y.shape == [2,3,4]
def test_cate(self):
a = jd.Categorical(jt.array([0.25, 0.25, 0.25, 0.25]))
x =np.array([0,0,0,0])
for i in range(1000):
x[a.sample().item()]+=1
assert (x > 200).all()
y = a.sample([2,3])
y.sync()
assert y.shape == [2,3]
if __name__ == "__main__":
unittest.main()

View File

@ -93,5 +93,14 @@ class TestMemoryProfiler(unittest.TestCase):
assert(out_[4].endswith('(_run_module_as_main)'))
assert(out_[8].endswith('(_run_code)'))
def test_sample(self):
net = jt.models.resnet18()
with jt.flag_scope(trace_py_var=3, profile_memory_enable=1):
imgs = jt.randn((1,3,224,224))
net(imgs).sync()
jt.get_max_memory_treemap()
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,12 @@
from socketserver import ThreadingTCPServer
import socket
def handle_connect(req:socket.socket, c_addr, server):
print("get connect", c_addr, req)
while True:
buf = req.recv(2048)
print(buf)
server = ThreadingTCPServer(("127.0.0.1", 8900), handle_connect)
server.serve_forever()

View File

@ -9,6 +9,6 @@ docker build --tag jittor/converter_server -f /tmp/converter_server.dockerfile .
# docker run --rm -it -m 16g --cpus=8 -p 0.0.0.0:5000:5000 jittor/converter_server bash -c "python3.7 -m pip install -U git+https://github.com/Jittor/jittor.git && python3.7 -m jittor.utils.converter_server"
while true; do
timeout --foreground 24h docker run --rm -it -m 16g --cpus=8 -p 0.0.0.0:5000:5000 -v ~/https:/https jittor/converter_server bash -c "python3.7 -m pip install -U git+https://github.com/Jittor/jittor.git && python3.7 -m jittor.test.test_core && FLASK_APP=/usr/local/lib/python3.7/dist-packages/jittor/utils/converter_server python3.7 -m flask run --cert=/https/fullchain.pem --key=/https/privkey.pem --host=0.0.0.0"
timeout --foreground 24h docker run --rm -it -m 16g --cpus=8 -p 0.0.0.0:5000:5000 -v /etc/letsencrypt/:/https jittor/converter_server bash -c "python3.7 -m pip install -U jittor && python3.7 -m jittor.test.test_core && FLASK_APP=/usr/local/lib/python3.7/dist-packages/jittor/utils/converter_server python3.7 -m flask run --cert=/https/live/randonl.me/fullchain.pem --key=/https/live/randonl.me/privkey.pem --host=0.0.0.0"
sleep 10
done

View File

@ -10,6 +10,8 @@
#include "ops/op_register.h"
#include "misc/cuda_flags.h"
#define __inline_static__ inline static
#ifndef JIT
namespace jittor {
@ -130,38 +132,37 @@ void CodeOp::jit_prepare(JK& jk) {
jk << _CS("][out") << JK::hex(i) << _CS("_type:")
<< _outputs[i]->dtype();
}
if (flags.get(NodeFlags::_cuda)) {
jk << _CS("][HEADER:") << cuda_header;
ASSERT(cuda_src.size());
jk << _CS("\nnamespace jittor {\n");
int i=0;
// move cuda kernel function into header
for (; i<cuda_src.size(); i++) {
if (cuda_src[i] == ' ' || cuda_src[i] == '\t' || cuda_src[i] == '\n') {
jk << cuda_src[i];
} else
if (cuda_src[i] == '_') {
int presum = 0;
while (i < cuda_src.size()) {
jk << cuda_src[i];
if (cuda_src[i] == '{') presum ++;
else if (cuda_src[i] == '}') {
presum--;
if (presum==0)
break;
}
i++;
string& header = flags.get(NodeFlags::_cuda) ?
cuda_header : cpu_header;
string& src = flags.get(NodeFlags::_cuda) ?
cuda_src : cpu_src;
jk << _CS("][HEADER:") << header;
CHECK(src.size());
jk << _CS("\nnamespace jittor {\n");
int i=0;
// move cuda kernel function into header
for (; i<src.size(); i++) {
if (src[i] == ' ' || src[i] == '\t' || src[i] == '\n') {
jk << src[i];
} else
if (src[i] == '_') {
int presum = 0;
while (i < src.size()) {
jk << src[i];
if (src[i] == '{') presum ++;
else if (src[i] == '}') {
presum--;
if (presum==0)
break;
}
} else break;
}
jk << _CS("}][CODE:");
for (; i<cuda_src.size(); i++) jk << cuda_src[i];
jk << ']';
} else {
jk << _CS("][HEADER:") << cpu_header;
jk << _CS("][CODE:") << cpu_src << ']';
ASSERT(cpu_src.size());
i++;
}
} else break;
}
jk << _CS("}][CODE:");
for (; i<src.size(); i++) jk << src[i];
jk << ']';
}
} // jittor