mirror of https://github.com/Jittor/Jittor
add distribution for rl
This commit is contained in:
parent
46e6a66411
commit
c31360b96c
|
@ -8,7 +8,7 @@
|
|||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
__version__ = '1.2.2.62'
|
||||
__version__ = '1.2.2.63'
|
||||
from . import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -0,0 +1,56 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2021 Jittor. All Rights Reserved.
|
||||
# Maintainers:
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
#
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import jittor as jt
|
||||
|
||||
def simple_presum(x):
|
||||
src = '''
|
||||
__inline_static__
|
||||
@python.jittor.auto_parallel(1)
|
||||
void kernel(int n0, int i0, in0_type* x, in0_type* out, int nl) {
|
||||
out[i0*(nl+1)] = 0;
|
||||
for (int i=0; i<nl; i++)
|
||||
out[i0*(nl+1)+i+1] = out[i0*(nl+1)+i] + x[i0*(nl+1)+i];
|
||||
}
|
||||
|
||||
kernel(in0->num/in0->shape[in0->shape.size()-1], 0, in0_p, out0_p, in0->num);
|
||||
'''
|
||||
return jt.code(x.shape[:-1]+(x.shape[-1]+1,), x.dtype, [x],
|
||||
cpu_src=src, cuda_src=src)
|
||||
|
||||
|
||||
class OneHotCategorical:
|
||||
def __init__(self, probs=None, logits=None):
|
||||
assert not (probs is None and logits is None)
|
||||
if probs is None:
|
||||
# cannot align to pytorch
|
||||
probs = jt.sigmoid(logits)
|
||||
with jt.no_grad():
|
||||
self.probs = probs / probs.sum(-1, True)
|
||||
self.cum_probs = simple_presum(probs)
|
||||
self.cum_probs_l = self.cum_probs[..., :-1]
|
||||
self.cum_probs_r = self.cum_probs[..., 1:]
|
||||
|
||||
def sample(self, sample_shape=[]):
|
||||
shape = sample_shape + self.probs.shape[:-1] + (1,)
|
||||
rand = jt.rand(shape)
|
||||
one_hot = jt.logical_and(self.cum_probs_l < rand, rand <= self.cum_probs_r).float()
|
||||
return one_hot
|
||||
|
||||
|
||||
|
||||
class Categorical:
|
||||
def __init__(self, probs=None, logits=None):
|
||||
OneHotCategorical.__init__(self, probs, logits)
|
||||
|
||||
def sample(self, sample_shape=[]):
|
||||
shape = sample_shape + self.probs.shape[:-1] + (1,)
|
||||
rand = jt.rand(shape)
|
||||
one_hot = jt.logical_and(self.cum_probs_l < rand, rand <= self.cum_probs_r)
|
||||
index = one_hot.index(one_hot.ndim-1)
|
||||
return (one_hot * index).sum(-1)
|
|
@ -787,6 +787,39 @@ def print_tree(now, max_memory_size, prefix1, prefix2, build_by):
|
|||
return out
|
||||
|
||||
def get_max_memory_treemap(build_by=0, do_print=True):
|
||||
'''show treemap of max memory consumption
|
||||
|
||||
Example::
|
||||
|
||||
net = jt.models.resnet18()
|
||||
with jt.flag_scope(trace_py_var=3, profile_memory_enable=1):
|
||||
imgs = jt.randn((1,3,224,224))
|
||||
net(imgs).sync()
|
||||
jt.get_max_memory_treemap()
|
||||
|
||||
Output::
|
||||
|
||||
|
|
||||
├─./python/jittor/test/test_memory_profiler.py:100(test_sample)
|
||||
| [19.03 MB; 29.67%]
|
||||
| ./python/jittor/test/test_memory_profiler.py:100
|
||||
| |
|
||||
| └─./python/jittor/__init__.py:730(__call__)
|
||||
| [19.03 MB; 29.67%]
|
||||
| ./python/jittor/__init__.py:730
|
||||
| |
|
||||
| └─./python/jittor/models/resnet.py:152(execute)
|
||||
| [19.03 MB; 29.67%]
|
||||
| ./python/jittor/models/resnet.py:152
|
||||
| |
|
||||
| ├─./python/jittor/models/resnet.py:142(_forward_impl)
|
||||
| | [6.13 MB; 9.55%]
|
||||
| | ./python/jittor/models/resnet.py:142
|
||||
| | |
|
||||
|
||||
|
||||
|
||||
'''
|
||||
div1 = "[!@#div1!@#]"
|
||||
div2 = "[!@#div2!@#]"
|
||||
div3 = "[!@#div3!@#]"
|
||||
|
|
|
@ -338,6 +338,8 @@ def log_softmax(x,dim=None):
|
|||
def log_sigmoid(x):
|
||||
return jt.log(jt.sigmoid(x))
|
||||
|
||||
def logsumexp(x, dim, keepdim=False):
|
||||
return x.exp().sum(dim, keepdim).log()
|
||||
|
||||
class Identity(Module):
|
||||
def __init__(self, *args, **kwargs):
|
||||
|
|
|
@ -0,0 +1,47 @@
|
|||
|
||||
# ***************************************************************
|
||||
# Copyright (c) 2021 Jittor. All Rights Reserved.
|
||||
# Maintainers:
|
||||
# Wenyang Zhou <576825820@qq.com>
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
#
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import unittest
|
||||
import jittor as jt
|
||||
import numpy as np
|
||||
import jittor.distributions as jd
|
||||
|
||||
|
||||
class TestOneHot(unittest.TestCase):
|
||||
def test_presum(self):
|
||||
a = jt.array([[1,2,3,4]])
|
||||
b = jd.simple_presum(a)
|
||||
assert (b.data == [[0,1,3,6,10]]).all()
|
||||
|
||||
def test_one_hot(self):
|
||||
a = jd.OneHotCategorical(jt.array([0.25, 0.25, 0.25, 0.25]))
|
||||
x = a.sample().numpy()
|
||||
for i in range(1000):
|
||||
x += a.sample().numpy()
|
||||
print(x)
|
||||
assert (x > 200).all()
|
||||
y = a.sample([2,3])
|
||||
y.sync()
|
||||
assert y.shape == [2,3,4]
|
||||
|
||||
def test_cate(self):
|
||||
a = jd.Categorical(jt.array([0.25, 0.25, 0.25, 0.25]))
|
||||
x =np.array([0,0,0,0])
|
||||
for i in range(1000):
|
||||
x[a.sample().item()]+=1
|
||||
assert (x > 200).all()
|
||||
y = a.sample([2,3])
|
||||
y.sync()
|
||||
assert y.shape == [2,3]
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -93,5 +93,14 @@ class TestMemoryProfiler(unittest.TestCase):
|
|||
assert(out_[4].endswith('(_run_module_as_main)'))
|
||||
assert(out_[8].endswith('(_run_code)'))
|
||||
|
||||
def test_sample(self):
|
||||
net = jt.models.resnet18()
|
||||
with jt.flag_scope(trace_py_var=3, profile_memory_enable=1):
|
||||
imgs = jt.randn((1,3,224,224))
|
||||
net(imgs).sync()
|
||||
jt.get_max_memory_treemap()
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
@ -0,0 +1,12 @@
|
|||
from socketserver import ThreadingTCPServer
|
||||
import socket
|
||||
|
||||
def handle_connect(req:socket.socket, c_addr, server):
|
||||
print("get connect", c_addr, req)
|
||||
while True:
|
||||
buf = req.recv(2048)
|
||||
|
||||
print(buf)
|
||||
|
||||
server = ThreadingTCPServer(("127.0.0.1", 8900), handle_connect)
|
||||
server.serve_forever()
|
|
@ -9,6 +9,6 @@ docker build --tag jittor/converter_server -f /tmp/converter_server.dockerfile .
|
|||
|
||||
# docker run --rm -it -m 16g --cpus=8 -p 0.0.0.0:5000:5000 jittor/converter_server bash -c "python3.7 -m pip install -U git+https://github.com/Jittor/jittor.git && python3.7 -m jittor.utils.converter_server"
|
||||
while true; do
|
||||
timeout --foreground 24h docker run --rm -it -m 16g --cpus=8 -p 0.0.0.0:5000:5000 -v ~/https:/https jittor/converter_server bash -c "python3.7 -m pip install -U git+https://github.com/Jittor/jittor.git && python3.7 -m jittor.test.test_core && FLASK_APP=/usr/local/lib/python3.7/dist-packages/jittor/utils/converter_server python3.7 -m flask run --cert=/https/fullchain.pem --key=/https/privkey.pem --host=0.0.0.0"
|
||||
timeout --foreground 24h docker run --rm -it -m 16g --cpus=8 -p 0.0.0.0:5000:5000 -v /etc/letsencrypt/:/https jittor/converter_server bash -c "python3.7 -m pip install -U jittor && python3.7 -m jittor.test.test_core && FLASK_APP=/usr/local/lib/python3.7/dist-packages/jittor/utils/converter_server python3.7 -m flask run --cert=/https/live/randonl.me/fullchain.pem --key=/https/live/randonl.me/privkey.pem --host=0.0.0.0"
|
||||
sleep 10
|
||||
done
|
|
@ -10,6 +10,8 @@
|
|||
#include "ops/op_register.h"
|
||||
#include "misc/cuda_flags.h"
|
||||
|
||||
#define __inline_static__ inline static
|
||||
|
||||
#ifndef JIT
|
||||
|
||||
namespace jittor {
|
||||
|
@ -130,38 +132,37 @@ void CodeOp::jit_prepare(JK& jk) {
|
|||
jk << _CS("][out") << JK::hex(i) << _CS("_type:")
|
||||
<< _outputs[i]->dtype();
|
||||
}
|
||||
if (flags.get(NodeFlags::_cuda)) {
|
||||
jk << _CS("][HEADER:") << cuda_header;
|
||||
ASSERT(cuda_src.size());
|
||||
jk << _CS("\nnamespace jittor {\n");
|
||||
int i=0;
|
||||
// move cuda kernel function into header
|
||||
for (; i<cuda_src.size(); i++) {
|
||||
if (cuda_src[i] == ' ' || cuda_src[i] == '\t' || cuda_src[i] == '\n') {
|
||||
jk << cuda_src[i];
|
||||
} else
|
||||
if (cuda_src[i] == '_') {
|
||||
int presum = 0;
|
||||
while (i < cuda_src.size()) {
|
||||
jk << cuda_src[i];
|
||||
if (cuda_src[i] == '{') presum ++;
|
||||
else if (cuda_src[i] == '}') {
|
||||
presum--;
|
||||
if (presum==0)
|
||||
break;
|
||||
}
|
||||
i++;
|
||||
string& header = flags.get(NodeFlags::_cuda) ?
|
||||
cuda_header : cpu_header;
|
||||
string& src = flags.get(NodeFlags::_cuda) ?
|
||||
cuda_src : cpu_src;
|
||||
|
||||
jk << _CS("][HEADER:") << header;
|
||||
CHECK(src.size());
|
||||
jk << _CS("\nnamespace jittor {\n");
|
||||
int i=0;
|
||||
// move cuda kernel function into header
|
||||
for (; i<src.size(); i++) {
|
||||
if (src[i] == ' ' || src[i] == '\t' || src[i] == '\n') {
|
||||
jk << src[i];
|
||||
} else
|
||||
if (src[i] == '_') {
|
||||
int presum = 0;
|
||||
while (i < src.size()) {
|
||||
jk << src[i];
|
||||
if (src[i] == '{') presum ++;
|
||||
else if (src[i] == '}') {
|
||||
presum--;
|
||||
if (presum==0)
|
||||
break;
|
||||
}
|
||||
} else break;
|
||||
}
|
||||
jk << _CS("}][CODE:");
|
||||
for (; i<cuda_src.size(); i++) jk << cuda_src[i];
|
||||
jk << ']';
|
||||
} else {
|
||||
jk << _CS("][HEADER:") << cpu_header;
|
||||
jk << _CS("][CODE:") << cpu_src << ']';
|
||||
ASSERT(cpu_src.size());
|
||||
i++;
|
||||
}
|
||||
} else break;
|
||||
}
|
||||
jk << _CS("}][CODE:");
|
||||
for (; i<src.size(); i++) jk << src[i];
|
||||
jk << ']';
|
||||
}
|
||||
|
||||
} // jittor
|
||||
|
|
Loading…
Reference in New Issue