mirror of https://github.com/Jittor/Jittor
optimize ring buffer and copy free array op
This commit is contained in:
parent
1e21b660bb
commit
143bb01d8e
|
@ -7,9 +7,12 @@
|
|||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
__version__ = '1.2.0.7'
|
||||
__version__ = '1.2.0.8'
|
||||
from . import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
ori_float = float
|
||||
ori_bool = bool
|
||||
from . import compiler
|
||||
from .compiler import LOG, has_cuda
|
||||
from .compiler import compile_custom_ops, compile_custom_op
|
||||
|
@ -874,16 +877,13 @@ def to_float(v):
|
|||
def to_bool(v):
|
||||
dtype = str(v.dtype)
|
||||
assert dtype.startswith("int") or dtype=="bool"
|
||||
return bool(v.item())
|
||||
return ori_bool(v.item())
|
||||
|
||||
Var.item = item
|
||||
Var.__int__ = to_int
|
||||
Var.__float__ = to_float
|
||||
Var.__bool__ = to_bool
|
||||
|
||||
ori_int = int
|
||||
ori_float = float
|
||||
|
||||
int = int32
|
||||
Var.int = Var.int32
|
||||
float = float32
|
||||
|
|
|
@ -15,8 +15,6 @@ from jittor.dataset.utils import get_random_list, get_order_list, collate_batch,
|
|||
from collections.abc import Sequence, Mapping
|
||||
import pathlib
|
||||
from PIL import Image
|
||||
from jittor_utils import ring_buffer
|
||||
from jittor_utils.ring_buffer import RingBuffer
|
||||
import multiprocessing as mp
|
||||
import signal
|
||||
from jittor_utils import LOG
|
||||
|
@ -30,8 +28,8 @@ img_open_hook = HookTimer(Image, "open")
|
|||
|
||||
class Worker:
|
||||
def __init__(self, target, args, buffer_size):
|
||||
buffer = mp.Array('c', buffer_size, lock=False)
|
||||
self.buffer = RingBuffer(buffer)
|
||||
self.buffer = jt.RingBuffer(buffer_size)
|
||||
|
||||
self.status = mp.Array('f', 5, lock=False)
|
||||
self.p = mp.Process(target=target, args=args+(self.buffer,self.status))
|
||||
self.p.daemon = True
|
||||
|
@ -253,13 +251,12 @@ Example::
|
|||
msg.append(f"progress:{self.last_id}/{self.batch_len}")
|
||||
msg.append(f"batch(s): {self.batch_time:.3f}\twait(s):{self.wait_time:.3f}")
|
||||
msg.append(f"recv(s): {self.recv_time:.3f}\tto_jittor(s):{self.to_jittor_time:.3f}")
|
||||
msg.append(f"recv_raw_call: {ring_buffer.recv_raw_call}")
|
||||
msg.append(f"last 10 workers: {self.idmap[max(0, self.last_id-9):self.last_id+1]}")
|
||||
msg.append(f"ID\twait(s)\topen(s)\tload(s)\tsend(s)\ttotal(s)")
|
||||
for i in range(self.num_workers):
|
||||
w = self.workers[i]
|
||||
s = w.status
|
||||
msg.append(f"#{i}\t{s[0]:.3f}\t{s[4]:.3f}\t{s[1]:.3f}\t{s[2]:.3f}\t{s[3]:.3f}\t{w.buffer.allocator}")
|
||||
msg.append(f"#{i}\t{s[0]:.3f}\t{s[4]:.3f}\t{s[1]:.3f}\t{s[2]:.3f}\t{s[3]:.3f}\t{w.buffer}")
|
||||
LOG.i('\n'.join(msg))
|
||||
|
||||
def _stop_all_workers(self):
|
||||
|
|
|
@ -31,8 +31,10 @@ pytype_map = {
|
|||
"uint": ["PyLong_AsUnsignedLong", "PyLong_FromUnsignedLong", "PyLong_CheckExact"],
|
||||
"uint64": ["PyLong_AsUnsignedLongLong", "PyLong_FromUnsignedLongLong", "PyLong_CheckExact"],
|
||||
"void": ["...", "GET_PY_NONE", "..."],
|
||||
"PyObject*": ["","",""],
|
||||
}
|
||||
def get_pytype_map(T, i):
|
||||
assert T != ""
|
||||
if T in pytype_map:
|
||||
return pytype_map[T][i]
|
||||
return ["from_py_object", "to_py_object", "is_type"][i]+"<"+T+">"
|
||||
|
@ -204,7 +206,7 @@ def get_def_code(df, scope_name, pyname, self_as_arg0=False):
|
|||
func_call = f"(GET_RAW_PTR({scope_name},self))->" + func_call
|
||||
if pyname == "__init__":
|
||||
# XXX->xxx(...) ---> new XXX xxx(...)
|
||||
assert "->" in func_call
|
||||
assert "->" in func_call, func_call
|
||||
func_call = "new " + func_call.replace("->", " ")
|
||||
if no_need_convert:
|
||||
func_quick_check_runable = ""
|
||||
|
|
|
@ -39,7 +39,7 @@ class TestReduceOp(unittest.TestCase):
|
|||
idims = [(), (0,), (1,), (2,), (3,), (0, 2), (1,3), (1,2,3), 2, 3]
|
||||
|
||||
iop = [ op[7:] for op in dir(jt) if op.startswith("reduce_")]
|
||||
assert len(iop) >= 10
|
||||
assert len(iop) >= 10, iop
|
||||
for a in ia:
|
||||
check(a, iop[0], idims[0])
|
||||
for op in iop:
|
||||
|
|
|
@ -65,8 +65,9 @@ class TestRelu(unittest.TestCase):
|
|||
# ***************************************************************
|
||||
# Test GELU Layer
|
||||
# ***************************************************************
|
||||
arr = np.random.randn(16,10,224,224)
|
||||
check_equal(arr, jnn.GELU(), tnn.GELU())
|
||||
if hasattr(tnn, "GELU"):
|
||||
arr = np.random.randn(16,10,224,224)
|
||||
check_equal(arr, jnn.GELU(), tnn.GELU())
|
||||
|
||||
# ***************************************************************
|
||||
# Test Softplus Layer
|
||||
|
|
|
@ -0,0 +1,86 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
# All Rights Reserved.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import jittor as jt
|
||||
import unittest
|
||||
import numpy as np
|
||||
import random
|
||||
from .test_core import expect_error
|
||||
from jittor.dataset.mnist import MNIST
|
||||
import jittor.transform as trans
|
||||
from tqdm import tqdm
|
||||
|
||||
def test_ring_buffer():
|
||||
buffer = jt.RingBuffer(1000)
|
||||
def test_send_recv(data):
|
||||
print("test send recv", type(data))
|
||||
buffer.push(data)
|
||||
recv = buffer.pop()
|
||||
if isinstance(data, np.ndarray):
|
||||
assert (recv == data).all()
|
||||
else:
|
||||
assert data == recv
|
||||
|
||||
n_byte = 0
|
||||
test_send_recv(1)
|
||||
n_byte += 1 + 8
|
||||
assert n_byte == buffer.total_pop() and n_byte == buffer.total_push()
|
||||
test_send_recv(100000000000)
|
||||
n_byte += 1 + 8
|
||||
assert n_byte == buffer.total_pop() and n_byte == buffer.total_push()
|
||||
|
||||
test_send_recv(1e-5)
|
||||
n_byte += 1 + 8
|
||||
assert n_byte == buffer.total_pop() and n_byte == buffer.total_push()
|
||||
test_send_recv(100000000000.0)
|
||||
n_byte += 1 + 8
|
||||
assert n_byte == buffer.total_pop() and n_byte == buffer.total_push()
|
||||
|
||||
test_send_recv("float32")
|
||||
n_byte += 1 + 8 + 7
|
||||
assert n_byte == buffer.total_pop() and n_byte == buffer.total_push()
|
||||
test_send_recv("")
|
||||
n_byte += 1 + 8 + 0
|
||||
assert n_byte == buffer.total_pop() and n_byte == buffer.total_push()
|
||||
test_send_recv("xxxxxxxxxx")
|
||||
n_byte += 1 + 8 + 10
|
||||
assert n_byte == buffer.total_pop() and n_byte == buffer.total_push()
|
||||
|
||||
test_send_recv([1,0.2])
|
||||
n_byte += 1 + 8 + 1 + 8 + 1 + 8
|
||||
assert n_byte == buffer.total_pop() and n_byte == buffer.total_push()
|
||||
test_send_recv({'asd':1})
|
||||
n_byte += 1 + 8 + 1 + 8 + 3 + 1 + 8
|
||||
assert n_byte == buffer.total_pop() and n_byte == buffer.total_push()
|
||||
|
||||
test_send_recv(np.random.rand(10,10))
|
||||
n_byte += 1 + 16 + 2 + 10*10*8
|
||||
assert n_byte == buffer.total_pop() and n_byte == buffer.total_push()
|
||||
test_send_recv(test_ring_buffer)
|
||||
|
||||
expect_error(lambda: test_send_recv(np.random.rand(10,1000)))
|
||||
|
||||
|
||||
class TestRingBuffer(unittest.TestCase):
|
||||
|
||||
def test_ring_buffer(self):
|
||||
test_ring_buffer()
|
||||
|
||||
def test_dataset(self):
|
||||
return
|
||||
self.train_loader = MNIST(train=True, transform=trans.Resize(224)) \
|
||||
.set_attrs(batch_size=300, shuffle=True)
|
||||
self.train_loader.num_workers = 1
|
||||
for batch_idx, (data, target) in tqdm(enumerate(self.train_loader)):
|
||||
# self.train_loader.display_worker_status()
|
||||
if batch_idx > 30:
|
||||
break
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -213,7 +213,7 @@ def to_tensor(img):
|
|||
img_ = transform.to_tensor(img)
|
||||
"""
|
||||
if isinstance(img, Image.Image):
|
||||
return np.array(img).transpose((2,0,1)) / np.float32(255)
|
||||
return np.array(img).transpose((2,0,1)) * np.float32(1.0/255.0)
|
||||
return img
|
||||
|
||||
|
||||
|
@ -323,7 +323,7 @@ class ImageNormalize:
|
|||
if isinstance(img, Image.Image):
|
||||
img = (np.array(img).transpose((2,0,1)) \
|
||||
- self.mean*np.float32(255.)) \
|
||||
/ (self.std*np.float32(255.))
|
||||
* (np.float32(1./255.)/self.std)
|
||||
else:
|
||||
img = (img - self.mean) / self.std
|
||||
return img
|
||||
|
|
|
@ -127,7 +127,7 @@ def str_to_char_array(s, array_len):
|
|||
return a
|
||||
|
||||
def char_array_to_str(a):
|
||||
return str(a.tostring(), 'ascii').strip()
|
||||
return str(a.tobytes(), 'ascii').strip()
|
||||
|
||||
class RingBuffer:
|
||||
def __init__(self, buffer):
|
||||
|
|
|
@ -64,7 +64,7 @@ struct NanoVector {
|
|||
// @pyjt(__init__)
|
||||
inline NanoVector(const NanoVector& nv) : data(nv.data), offset(nv.offset) {}
|
||||
|
||||
void clear() { data = offset = 0; }
|
||||
inline void clear() { data = offset = 0; }
|
||||
|
||||
// @pyjt(__len__, __map_len__)
|
||||
inline int size() const {
|
||||
|
@ -158,10 +158,22 @@ struct NanoVector {
|
|||
for (auto a : v) push_back_check_overflow(a);
|
||||
}
|
||||
|
||||
inline static NanoVector make(const int64* v, int n) {
|
||||
NanoVector nv;
|
||||
for (int i=0; i<n; i++) nv.push_back_check_overflow(v[i]);
|
||||
return nv;
|
||||
}
|
||||
|
||||
inline static NanoVector make(const int32* v, int n) {
|
||||
NanoVector nv;
|
||||
for (int i=0; i<n; i++) nv.push_back_check_overflow(v[i]);
|
||||
return nv;
|
||||
}
|
||||
|
||||
inline NanoVector(int64 x) { push_back(x); }
|
||||
|
||||
// @pyjt(__repr__)
|
||||
string to_string() {
|
||||
inline string to_string() {
|
||||
string s="[";
|
||||
for (int i=0; i<size(); i++) {
|
||||
s += S(at(i));
|
||||
|
|
|
@ -0,0 +1,91 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
// Authors: Dun Liang <randonlang@gmail.com>.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include <chrono>
|
||||
#include <thread>
|
||||
#include <sys/mman.h>
|
||||
#include "common.h"
|
||||
#include "misc/ring_buffer.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
RingBuffer::RingBuffer(uint64 size, bool multiprocess) : m(multiprocess), cv(multiprocess) {
|
||||
int i=0;
|
||||
for (;(1ll<<i)<size;i++);
|
||||
size_mask = (1ll<<i)-1;
|
||||
this->size = size_mask+1;
|
||||
size_bit = i;
|
||||
l = r = is_wait = 0;
|
||||
is_multiprocess = multiprocess;
|
||||
}
|
||||
|
||||
RingBuffer::~RingBuffer() {
|
||||
}
|
||||
|
||||
|
||||
RingBuffer* RingBuffer::make_ring_buffer(uint64 size, bool multiprocess) {
|
||||
int i=0;
|
||||
for (;(1ll<<i)<size;i++);
|
||||
uint64 size_mask = (1ll<<i)-1;
|
||||
size = size_mask+1;
|
||||
uint64 total_size = sizeof(RingBuffer) + size;
|
||||
void* ptr = multiprocess ?
|
||||
// mmap(NULL, total_size, PROT_READ | PROT_WRITE, MAP_SHARED | MAP_ANONYMOUS | MAP_HUGETLB, -1, 0) :
|
||||
mmap(NULL, total_size, PROT_READ | PROT_WRITE, MAP_SHARED | MAP_ANONYMOUS, -1, 0) :
|
||||
// mmap(NULL, total_size, PROT_READ | PROT_WRITE, MAP_SHARED, -1, 0) :
|
||||
(void*)malloc(total_size);
|
||||
std::memset(ptr, 0, total_size);
|
||||
auto rb = (RingBuffer*)ptr;
|
||||
new (rb) RingBuffer(size, multiprocess);
|
||||
return rb;
|
||||
}
|
||||
|
||||
void RingBuffer::free_ring_buffer(RingBuffer* rb) {
|
||||
uint64 total_size = sizeof(RingBuffer) + rb->size;
|
||||
auto is_multiprocess = rb->is_multiprocess;
|
||||
rb->~RingBuffer();
|
||||
if (is_multiprocess) {
|
||||
munmap(rb, total_size);
|
||||
} else {
|
||||
rb->~RingBuffer();
|
||||
free((void*)rb);
|
||||
}
|
||||
}
|
||||
|
||||
// test
|
||||
|
||||
JIT_TEST(ring_buffer_benchmark) {
|
||||
size_t n = 1ll << 20;
|
||||
size_t size = 1<<15;
|
||||
// size_t n = 1ll << 30;
|
||||
// size_t size = 1<<20;
|
||||
// size_t n = 1ll << 10;
|
||||
// size_t size = 1<<5;
|
||||
RingBuffer* rb = RingBuffer::make_ring_buffer(size, 0);
|
||||
std::thread p([&]() {
|
||||
for (size_t i=0; i<n; i++) {
|
||||
rb->push_t<int>(i);
|
||||
}
|
||||
});
|
||||
auto start = std::chrono::high_resolution_clock::now();
|
||||
size_t s = 0;
|
||||
for (size_t i=0; i<n; i++) {
|
||||
auto x = rb->pop_t<int>();
|
||||
s += x;
|
||||
}
|
||||
auto finish = std::chrono::high_resolution_clock::now();
|
||||
auto tt = std::chrono::duration_cast<std::chrono::nanoseconds>(finish-start).count();
|
||||
p.join();
|
||||
expect_error([&]() { rb->push(size+1); });
|
||||
RingBuffer::free_ring_buffer(rb);
|
||||
|
||||
LOGi << tt << tt*1.0/n;
|
||||
LOGi << s << (n*(n-1)/2);
|
||||
ASSERTop(s,==,(n*(n-1)/2));
|
||||
ASSERTop(tt*1.0/n,<=,50);
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,205 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
// Authors: Dun Liang <randonlang@gmail.com>.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include <pthread.h>
|
||||
#include <sys/mman.h>
|
||||
#include <cstring>
|
||||
#include "common.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
struct RingBuffer {
|
||||
|
||||
struct Mutex {
|
||||
pthread_mutex_t m;
|
||||
inline Mutex(bool multiprocess=0) {
|
||||
pthread_mutexattr_t attr;
|
||||
pthread_mutexattr_init(&attr);
|
||||
if (multiprocess)
|
||||
pthread_mutexattr_setpshared(&attr, PTHREAD_PROCESS_SHARED);
|
||||
ASSERT(0 == pthread_mutex_init((pthread_mutex_t*)&m, &attr));
|
||||
}
|
||||
|
||||
inline ~Mutex() {
|
||||
pthread_mutex_destroy(&m);
|
||||
}
|
||||
|
||||
inline void lock() {
|
||||
pthread_mutex_lock(&m);
|
||||
}
|
||||
|
||||
inline void unlock() {
|
||||
pthread_mutex_unlock(&m);
|
||||
}
|
||||
};
|
||||
|
||||
struct Cond {
|
||||
pthread_cond_t cv;
|
||||
inline Cond(bool multiprocess=0) {
|
||||
pthread_condattr_t attr;
|
||||
pthread_condattr_init(&attr);
|
||||
if (multiprocess)
|
||||
pthread_condattr_setpshared(&attr, PTHREAD_PROCESS_SHARED);
|
||||
ASSERT(0 == pthread_cond_init((pthread_cond_t*)&cv, &attr));
|
||||
}
|
||||
|
||||
inline ~Cond() {
|
||||
pthread_cond_destroy(&cv);
|
||||
}
|
||||
|
||||
inline void wait(Mutex& m) {
|
||||
pthread_cond_wait(&cv, &m.m);
|
||||
}
|
||||
|
||||
inline void notify() {
|
||||
pthread_cond_signal(&cv);
|
||||
}
|
||||
};
|
||||
|
||||
struct MutexScope {
|
||||
Mutex* m;
|
||||
inline MutexScope(Mutex& m) : m(&m) { m.lock(); }
|
||||
inline ~MutexScope() { m->unlock(); }
|
||||
};
|
||||
|
||||
uint64 size;
|
||||
uint64 size_mask;
|
||||
uint64 size_bit;
|
||||
volatile uint64 l;
|
||||
volatile uint64 r;
|
||||
volatile int is_wait;
|
||||
bool is_multiprocess;
|
||||
Mutex m;
|
||||
Cond cv;
|
||||
char _ptr;
|
||||
|
||||
RingBuffer(uint64 size, bool multiprocess=false);
|
||||
~RingBuffer();
|
||||
static RingBuffer* make_ring_buffer(uint64 size, bool multiprocess);
|
||||
static void free_ring_buffer(RingBuffer* rb);
|
||||
|
||||
inline void wait() {
|
||||
MutexScope _(m);
|
||||
if (is_wait) {
|
||||
cv.notify();
|
||||
is_wait = 0;
|
||||
}
|
||||
is_wait = 1;
|
||||
cv.wait(m);
|
||||
}
|
||||
|
||||
inline void notify() {
|
||||
MutexScope _(m);
|
||||
cv.notify();
|
||||
is_wait = 0;
|
||||
}
|
||||
|
||||
inline void push(uint64 size, uint64& __restrict__ offset) {
|
||||
auto rr = offset;
|
||||
auto rr_next = rr + size;
|
||||
auto c1 = rr >> size_bit;
|
||||
auto c2 = (rr_next-1) >> size_bit;
|
||||
if (c1 != c2) {
|
||||
// if cross boundary
|
||||
rr = c2 << size_bit;
|
||||
rr_next = rr + size;
|
||||
}
|
||||
while (rr_next > l + this->size) {
|
||||
CHECKop(size,<=,this->size) << "Buffer size too small, please increase buffer size.";
|
||||
wait();
|
||||
}
|
||||
offset = rr_next;
|
||||
}
|
||||
|
||||
inline void commit_push(uint64 offset) {
|
||||
r = offset;
|
||||
if (is_wait)
|
||||
notify();
|
||||
}
|
||||
|
||||
inline void pop(uint64 size, uint64& __restrict__ offset) {
|
||||
auto ll = offset;
|
||||
auto ll_next = ll + size;
|
||||
auto c1 = ll >> size_bit;
|
||||
auto c2 = (ll_next-1) >> size_bit;
|
||||
if (c1 != c2) {
|
||||
// if cross boundary
|
||||
ll = c2 << size_bit;
|
||||
ll_next = ll + size;
|
||||
}
|
||||
while (ll_next > r) {
|
||||
ASSERT(size<=this->size);
|
||||
wait();
|
||||
}
|
||||
offset = ll_next;
|
||||
}
|
||||
|
||||
inline void commit_pop(uint64 offset) {
|
||||
l = offset;
|
||||
if (is_wait)
|
||||
notify();
|
||||
}
|
||||
|
||||
inline uint64 push(uint64 size) {
|
||||
auto offset = r;
|
||||
push(size, offset);
|
||||
return offset;
|
||||
}
|
||||
inline uint64 pop(uint64 size) {
|
||||
auto offset = l;
|
||||
pop(size, offset);
|
||||
return offset;
|
||||
}
|
||||
|
||||
inline char* get_ptr(uint64 size, uint64 offset) { return ((&_ptr)+((offset-size)&size_mask)); }
|
||||
|
||||
template<class T>
|
||||
inline T& get(uint64 offset) { return *(T*)((&_ptr)+((offset-sizeof(T))&size_mask)); }
|
||||
|
||||
template<class T>
|
||||
inline void push_t(const T& data, uint64& __restrict__ offset) {
|
||||
push(sizeof(T), offset);
|
||||
get<T>(offset) = data;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
inline T& pop_t(uint64& __restrict__ offset) {
|
||||
pop(sizeof(T), offset);
|
||||
return get<T>(offset);
|
||||
}
|
||||
|
||||
inline void push_string(const string& data, uint64& __restrict__ offset) {
|
||||
push_t<int64>(data.size(), offset);
|
||||
push(data.size(), offset);
|
||||
auto ptr = get_ptr(data.size(), offset);
|
||||
std::memcpy(ptr, data.c_str(), data.size());
|
||||
}
|
||||
|
||||
inline string pop_string(uint64& __restrict__ offset) {
|
||||
auto size = pop_t<int64>(offset);
|
||||
pop(size, offset);
|
||||
auto ptr = get_ptr(size, offset);
|
||||
return string(ptr, size);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
inline void push_t(const T& data) {
|
||||
auto offset = push(sizeof(T));
|
||||
get<T>(offset) = data;
|
||||
commit_push(offset);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
inline T pop_t() {
|
||||
auto offset = pop(sizeof(T));
|
||||
T data = get<T>(offset);
|
||||
commit_pop(offset);
|
||||
return data;
|
||||
}
|
||||
};
|
||||
|
||||
}
|
|
@ -7,6 +7,8 @@
|
|||
#include "op.h"
|
||||
#include "mem/allocator.h"
|
||||
|
||||
typedef struct _object PyObject;
|
||||
|
||||
namespace jittor {
|
||||
|
||||
struct ArrayArgs {
|
||||
|
@ -22,7 +24,10 @@ struct ArrayOp : Op {
|
|||
// @pybind(None)
|
||||
ArrayOp(const void* ptr, NanoVector shape, NanoString dtype=ns_float32);
|
||||
|
||||
// @pybind(array_)
|
||||
ArrayOp(ArrayArgs&& args);
|
||||
|
||||
ArrayOp(PyObject* obj);
|
||||
template<class T>
|
||||
inline T* ptr() { return (T*)allocation.ptr; }
|
||||
|
||||
|
|
|
@ -34,7 +34,9 @@ unordered_set<string> reduce_ops = {
|
|||
"add",
|
||||
// @pybind(prod, product, reduce_multiply)
|
||||
"multiply",
|
||||
// @pybind(reduce_logical_and, all)
|
||||
"logical_and",
|
||||
// @pybind(reduce_logical_or, any)
|
||||
"logical_or",
|
||||
"logical_xor",
|
||||
"bitwise_and",
|
||||
|
|
|
@ -44,7 +44,7 @@ void MergeLoopPass::run() {
|
|||
while (cpx < ki.size() && cpx<kj.size() && ki[cpx] == kj[cpx]) cpx++;
|
||||
int mismatch = std::max(ki.size(), kj.size()) - cpx;
|
||||
LOGvvvv << "loop key " << ki << kj << "mismatch" << mismatch;
|
||||
if (mismatch>=2 || cpx==0)
|
||||
if (mismatch>=1 || cpx==0)
|
||||
continue;
|
||||
loops[i]->insert(0, loops[j]->children);
|
||||
loops[i]->merge_loop();
|
||||
|
|
|
@ -26,13 +26,15 @@ struct SimpleProfiler {
|
|||
string name;
|
||||
int64 cnt;
|
||||
int64 total_ns;
|
||||
int64 sum;
|
||||
int64 pcnt[7] = {0};
|
||||
int64 pns[7] = {0};
|
||||
int64 last[7] = {0};
|
||||
|
||||
inline SimpleProfiler(string&& name): name(move(name)), cnt(0), total_ns(0) {}
|
||||
inline ~SimpleProfiler() {
|
||||
std::cerr << "=============================\nSimpleProfiler [" << name << "] cnt: " << cnt << " total: ";
|
||||
void report() {
|
||||
std::cerr << "=============================\nSimpleProfiler [" << name << "] cnt: " << cnt
|
||||
<< " sum: " << sum << " speed: " << std::setprecision(3) << (sum*1.0/total_ns)
|
||||
<< " total: " ;
|
||||
if (total_ns < 1.e3)
|
||||
std::cerr << total_ns << " ns" << std::endl;
|
||||
else if (total_ns < 1.e6)
|
||||
|
@ -52,11 +54,15 @@ struct SimpleProfiler {
|
|||
for (int i=0; i<7; i++) std::cerr << std::setw(9) << last[i];
|
||||
std::cerr << std::endl;
|
||||
}
|
||||
inline void add(int64 time) {
|
||||
|
||||
inline SimpleProfiler(string&& name): name(move(name)), cnt(0), total_ns(0), sum(0) {}
|
||||
inline ~SimpleProfiler() { report(); }
|
||||
inline void add(int64 time, int64 s) {
|
||||
auto nbit = 64 - _lzcnt(time);
|
||||
auto i = (nbit-1) / 5;
|
||||
if (i>6) i=6;
|
||||
cnt ++;
|
||||
sum += s;
|
||||
total_ns += time;
|
||||
pcnt[i] ++;
|
||||
pns[i] += time;
|
||||
|
@ -74,14 +80,20 @@ example:
|
|||
*/
|
||||
struct SimpleProfilerGuard {
|
||||
SimpleProfiler* p;
|
||||
int64 s;
|
||||
std::chrono::high_resolution_clock::time_point start;
|
||||
inline SimpleProfilerGuard(SimpleProfiler& p) : p(&p) {
|
||||
inline SimpleProfilerGuard(SimpleProfiler& p, int64 s=1) : p(&p), s(s) {
|
||||
start = std::chrono::high_resolution_clock::now();
|
||||
}
|
||||
void finish() {
|
||||
this->~SimpleProfilerGuard();
|
||||
s = 0;
|
||||
}
|
||||
inline ~SimpleProfilerGuard() {
|
||||
if (!s) return;
|
||||
auto finish = std::chrono::high_resolution_clock::now();
|
||||
auto total_ns = (int64_t)std::chrono::duration_cast<std::chrono::nanoseconds>(finish-start).count();
|
||||
p->add(total_ns);
|
||||
p->add(total_ns, s);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -17,6 +17,7 @@ PyObject* (*PyArray_FromAny)(PyObject *, PyArrayDescr_Proxy *, int, int, int, Py
|
|||
unsigned int (*PyArray_GetNDArrayCFeatureVersion)();
|
||||
int (*PyArray_SetBaseObject)(PyObject *arr, PyObject *obj);
|
||||
PyObject* (*PyArray_NewCopy)(PyObject *, int);
|
||||
int (*PyArray_CopyInto)(PyObject *, PyObject *);
|
||||
|
||||
tmp_data_t tmp_data;
|
||||
|
||||
|
@ -34,6 +35,7 @@ void numpy_init() {
|
|||
fill(PyArray_GetNDArrayCFeatureVersion, 211);
|
||||
fill(PyArray_SetBaseObject, 282);
|
||||
fill(PyArray_NewCopy, 85);
|
||||
fill(PyArray_CopyInto, 82);
|
||||
|
||||
ASSERT(PyArray_GetNDArrayCFeatureVersion()>=7);
|
||||
}
|
||||
|
|
|
@ -99,6 +99,8 @@ extern PyObject* (*PyArray_FromAny)(PyObject *, PyArrayDescr_Proxy *, int, int,
|
|||
extern unsigned int (*PyArray_GetNDArrayCFeatureVersion)();
|
||||
extern int (*PyArray_SetBaseObject)(PyObject *arr, PyObject *obj);
|
||||
extern PyObject* (*PyArray_NewCopy)(PyObject *, int);
|
||||
extern int (*PyArray_CopyInto)(PyObject *, PyObject *);
|
||||
|
||||
#define PyArray_Copy(obj) PyArray_NewCopy(obj, 0)
|
||||
|
||||
#define NPY_ARRAY_ALIGNED 0x0100
|
||||
|
|
|
@ -0,0 +1,129 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
// Authors: Dun Liang <randonlang@gmail.com>.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#ifdef HAS_CUDA
|
||||
#include <cuda_runtime.h>
|
||||
#include <helper_cuda.h>
|
||||
#include "mem/allocator.h"
|
||||
#include "mem/allocator/cuda_dual_allocator.h"
|
||||
#include "event_queue.h"
|
||||
#endif
|
||||
#include <Python.h>
|
||||
#include "pyjt/py_obj_holder.h"
|
||||
#include "pyjt/py_converter.h"
|
||||
#include "pyjt/numpy.h"
|
||||
#include "ops/array_op.h"
|
||||
#include "var.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
ArrayOp::ArrayOp(PyObject* obj) {
|
||||
ArrayArgs args;
|
||||
PyObjHolder holder;
|
||||
args.ptr = nullptr;
|
||||
if (PyFloat_CheckExact(obj)) {
|
||||
tmp_data.f32 = PyFloat_AS_DOUBLE(obj);
|
||||
args = {&tmp_data, 1, ns_float32};
|
||||
} else
|
||||
if (PyLong_CheckExact(obj)) {
|
||||
tmp_data.i32 = PyLong_AsLong(obj);
|
||||
args = {&tmp_data, 1, ns_int32};
|
||||
} else
|
||||
if (PyBool_Check(obj)) {
|
||||
tmp_data.i8 = obj == Py_True;
|
||||
args = {&tmp_data, 1, ns_bool};
|
||||
} else
|
||||
if (Py_TYPE(obj) == &PyjtVarHolder.ht_type) {
|
||||
auto ptr = GET_RAW_PTR(VarHolder, obj);
|
||||
args = move(fetch_sync({ptr}).at(0));
|
||||
} else
|
||||
if (Py_TYPE(obj) == PyArray_Type ||
|
||||
PyList_CheckExact(obj) ||
|
||||
PyObject_TypeCheck(obj, PyNumberArrType_Type)
|
||||
) {
|
||||
if (Py_TYPE(obj) != PyArray_Type) {
|
||||
holder.assign(PyArray_FROM_O(obj));
|
||||
obj = holder.obj;
|
||||
}
|
||||
auto arr = (PyArray_Proxy*)obj;
|
||||
if (arr->nd)
|
||||
args.shape = NanoVector::make(arr->dimensions, arr->nd);
|
||||
else
|
||||
args.shape.push_back(1);
|
||||
args.dtype = get_type_str(arr);
|
||||
if (is_c_style(arr))
|
||||
args.ptr = arr->data;
|
||||
|
||||
// use 32-bit by default
|
||||
if (holder.obj && args.dtype.dsize() == 8) {
|
||||
auto num = PyArray_Size(arr)/8;
|
||||
if (args.dtype.is_int()) {
|
||||
auto* __restrict__ i64 = (int64*)args.ptr;
|
||||
auto* __restrict__ i32 = (int32*)args.ptr;
|
||||
for (int i=0; i<num; i++)
|
||||
i32[i] = (int32)i64[i];
|
||||
args.dtype = ns_int32;
|
||||
} else if (args.dtype.is_float()) {
|
||||
auto* __restrict__ f64 = (float64*)args.ptr;
|
||||
auto* __restrict__ f32 = (float32*)args.ptr;
|
||||
for (int i=0; i<num; i++)
|
||||
f32[i] = (float32)f64[i];
|
||||
args.dtype = ns_float32;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
NanoVector shape = args.shape;
|
||||
output = create_output(shape, args.dtype);
|
||||
int64 size = output->size;
|
||||
if (shape.size() == 1 && shape[0] == 1) {
|
||||
output->flags.set(NodeFlags::_force_fuse);
|
||||
set_type(OpType::element);
|
||||
}
|
||||
void* host_ptr = nullptr;
|
||||
#ifdef HAS_CUDA
|
||||
if (use_cuda) {
|
||||
flags.set(NodeFlags::_cpu, 0);
|
||||
flags.set(NodeFlags::_cuda, 1);
|
||||
if (!output->flags.get(NodeFlags::_force_fuse)) {
|
||||
// free prev allocation first
|
||||
event_queue.flush();
|
||||
// alloc new allocation
|
||||
auto size = output->size;
|
||||
new (&allocation) Allocation(&cuda_dual_allocator, size);
|
||||
host_ptr = cuda_dual_allocator.get_dual_allocation(allocation.allocation).host_ptr;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
if (!host_ptr) {
|
||||
new (&allocation) Allocation(cpu_allocator, output->size);
|
||||
host_ptr = allocation.ptr;
|
||||
}
|
||||
|
||||
if (args.ptr) {
|
||||
// if has ptr, copy from ptr
|
||||
std::memcpy(host_ptr, args.ptr, size);
|
||||
} else {
|
||||
// this is non-continue numpy array
|
||||
int64 dims[args.shape.size()];
|
||||
for (int i=0; i<args.shape.size(); i++)
|
||||
dims[i] = args.shape[i];
|
||||
holder.assign(PyArray_New(
|
||||
PyArray_Type, // subtype
|
||||
args.shape.size(), // nd
|
||||
dims, // dims
|
||||
get_typenum(args.dtype), // type_num
|
||||
NULL, // strides
|
||||
(void*)host_ptr, // data
|
||||
0, // itemsize
|
||||
NPY_ARRAY_C_CONTIGUOUS | NPY_ARRAY_WRITEABLE, // flags
|
||||
NULL // obj
|
||||
));
|
||||
ASSERT(0==PyArray_CopyInto(holder.obj, obj));
|
||||
}
|
||||
}
|
||||
|
||||
}
|
|
@ -318,7 +318,7 @@ DEF_IS(ArrayArgs, T) from_py_object(PyObject* obj) {
|
|||
int64 size = PyArray_Size(arr);
|
||||
T args;
|
||||
if (arr->nd)
|
||||
args.shape = vector<int64>(arr->dimensions, arr->dimensions+arr->nd);
|
||||
args.shape = NanoVector::make(arr->dimensions, arr->nd);
|
||||
else
|
||||
args.shape.push_back(1);
|
||||
args.dtype = get_type_str(arr);
|
||||
|
|
|
@ -11,6 +11,14 @@ namespace jittor {
|
|||
|
||||
struct PyObjHolder {
|
||||
PyObject* obj;
|
||||
inline PyObjHolder() : obj(nullptr) {
|
||||
}
|
||||
void assign(PyObject* obj) {
|
||||
if (!obj) {
|
||||
LOGf << "Python error occur";
|
||||
}
|
||||
this->obj = obj;
|
||||
}
|
||||
inline PyObjHolder(PyObject* obj) : obj(obj) {
|
||||
if (!obj) {
|
||||
LOGf << "Python error occur";
|
||||
|
|
|
@ -0,0 +1,228 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
// Authors: Dun Liang <randonlang@gmail.com>.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include "pyjt/py_ring_buffer.h"
|
||||
#include "pyjt/py_obj_holder.h"
|
||||
#include "pyjt/py_converter.h"
|
||||
#include "ops/array_op.h"
|
||||
#include "var_holder.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
static void push_py_object_pickle(RingBuffer* rb, PyObject* obj, uint64& __restrict__ offset) {
|
||||
PyObjHolder pickle(PyImport_ImportModule("pickle"));
|
||||
PyObjHolder dumps(PyObject_GetAttrString(pickle.obj, "dumps"));
|
||||
rb->push_t<uint8>(6, offset);
|
||||
PyObjHolder ret(PyObject_CallFunctionObjArgs(dumps.obj, obj, nullptr));
|
||||
obj = ret.obj;
|
||||
Py_ssize_t size;
|
||||
char* s;
|
||||
ASSERT(0 == PyBytes_AsStringAndSize(ret.obj, &s, &size));
|
||||
rb->push_t<int64>(size, offset);
|
||||
rb->push(size, offset);
|
||||
LOGir << string(rb->get_ptr(size, offset), size);
|
||||
std::memcpy(rb->get_ptr(size, offset), s, size);
|
||||
return;
|
||||
}
|
||||
|
||||
static PyObject* pop_py_object_pickle(RingBuffer* rb, uint64& __restrict__ offset) {
|
||||
PyObjHolder pickle(PyImport_ImportModule("pickle"));
|
||||
PyObjHolder loads(PyObject_GetAttrString(pickle.obj, "loads"));
|
||||
|
||||
auto size = rb->pop_t<int64>(offset);
|
||||
rb->pop(size, offset);
|
||||
PyObjHolder s(PyBytes_FromStringAndSize(rb->get_ptr(size, offset), size));
|
||||
|
||||
PyObjHolder ret(PyObject_CallFunctionObjArgs(loads.obj, s.obj, nullptr));
|
||||
return ret.release();
|
||||
}
|
||||
|
||||
|
||||
static void push_py_object(RingBuffer* rb, PyObject* obj, uint64& __restrict__ offset) {
|
||||
if (PyLong_CheckExact(obj)) {
|
||||
int64 x = PyLong_AsLongLong(obj);
|
||||
rb->push_t<uint8>(0, offset);
|
||||
rb->push_t<int64>(x, offset);
|
||||
return;
|
||||
}
|
||||
if (PyFloat_CheckExact(obj)) {
|
||||
float64 x = PyFloat_AS_DOUBLE(obj);
|
||||
rb->push_t<uint8>(1, offset);
|
||||
rb->push_t<float64>(x, offset);
|
||||
return;
|
||||
}
|
||||
if (PyUnicode_CheckExact(obj)) {
|
||||
Py_ssize_t size;
|
||||
const char* s = PyUnicode_AsUTF8AndSize(obj, &size);
|
||||
rb->push_t<uint8>(2, offset);
|
||||
rb->push_t<int64>(size, offset);
|
||||
rb->push(size, offset);
|
||||
std::memcpy(rb->get_ptr(size, offset), s, size);
|
||||
return;
|
||||
}
|
||||
if (PyList_CheckExact(obj) || PyTuple_CheckExact(obj)) {
|
||||
rb->push_t<uint8>(3, offset);
|
||||
auto size = Py_SIZE(obj);
|
||||
auto arr = PySequence_Fast_ITEMS(obj);
|
||||
rb->push_t<int64>(size, offset);
|
||||
for (int64 i=0; i<size; i++) {
|
||||
auto oi = arr[i];
|
||||
push_py_object(rb, oi, offset);
|
||||
}
|
||||
return;
|
||||
}
|
||||
if (PyDict_CheckExact(obj)) {
|
||||
rb->push_t<uint8>(4, offset);
|
||||
auto size = Py_SIZE(obj);
|
||||
rb->push_t<int64>(size, offset);
|
||||
PyObject *key, *value;
|
||||
Py_ssize_t pos = 0;
|
||||
while (PyDict_Next(obj, &pos, &key, &value)) {
|
||||
push_py_object(rb, key, offset);
|
||||
push_py_object(rb, value, offset);
|
||||
}
|
||||
return;
|
||||
}
|
||||
if (Py_TYPE(obj) == &PyjtVarHolder.ht_type ||
|
||||
Py_TYPE(obj) == PyArray_Type) {
|
||||
ArrayArgs args;
|
||||
int64 size=0;
|
||||
rb->push_t<uint8>(5, offset);
|
||||
if (Py_TYPE(obj) == &PyjtVarHolder.ht_type) {
|
||||
auto ptr = GET_RAW_PTR(VarHolder, obj);
|
||||
args = move(fetch_sync({ptr}).at(0));
|
||||
size = ptr->var->size;
|
||||
} else {
|
||||
auto arr = (PyArray_Proxy*)obj;
|
||||
if (arr->nd)
|
||||
args.shape = NanoVector::make(arr->dimensions, arr->nd);
|
||||
else
|
||||
args.shape.push_back(1);
|
||||
args.dtype = get_type_str(arr);
|
||||
size = PyArray_Size(arr);
|
||||
if (!is_c_style(arr)) {
|
||||
rb->push_t<NanoVector>(args.shape, offset);
|
||||
rb->push_t<NanoString>(args.dtype, offset);
|
||||
rb->push(size, offset);
|
||||
args.ptr = rb->get_ptr(size, offset);
|
||||
int64 dims[args.shape.size()];
|
||||
for (int i=0; i<args.shape.size(); i++)
|
||||
dims[i] = args.shape[i];
|
||||
PyObjHolder oh(PyArray_New(
|
||||
PyArray_Type, // subtype
|
||||
args.shape.size(), // nd
|
||||
dims, // dims
|
||||
get_typenum(args.dtype), // type_num
|
||||
NULL, // strides
|
||||
(void*)args.ptr, // data
|
||||
0, // itemsize
|
||||
NPY_ARRAY_C_CONTIGUOUS | NPY_ARRAY_WRITEABLE, // flags
|
||||
NULL // obj
|
||||
));
|
||||
ASSERT(0==PyArray_CopyInto(oh.obj, obj));
|
||||
return;
|
||||
} else {
|
||||
args.ptr = arr->data;
|
||||
}
|
||||
}
|
||||
rb->push_t<NanoVector>(args.shape, offset);
|
||||
rb->push_t<NanoString>(args.dtype, offset);
|
||||
rb->push(size, offset);
|
||||
std::memcpy(rb->get_ptr(size, offset), args.ptr, size);
|
||||
return;
|
||||
}
|
||||
push_py_object_pickle(rb, obj, offset);
|
||||
}
|
||||
|
||||
|
||||
static PyObject* to_py_object3(ArrayArgs&& a) {
|
||||
return to_py_object(jit_op_maker::array_(move(a)));
|
||||
}
|
||||
|
||||
static PyObject* pop_py_object(RingBuffer* rb, uint64& __restrict__ offset) {
|
||||
auto t = rb->pop_t<uint8>(offset);
|
||||
if (t==0) {
|
||||
auto x = rb->pop_t<int64>(offset);
|
||||
return PyLong_FromLongLong(x);
|
||||
}
|
||||
if (t==1) {
|
||||
auto x = rb->pop_t<float64>(offset);
|
||||
return PyFloat_FromDouble(x);
|
||||
}
|
||||
if (t==2) {
|
||||
auto size = rb->pop_t<int64>(offset);
|
||||
rb->pop(size, offset);
|
||||
return PyUnicode_FromStringAndSize(rb->get_ptr(size, offset), size);
|
||||
}
|
||||
if (t==3) {
|
||||
auto size = rb->pop_t<int64>(offset);
|
||||
PyObjHolder list(PyList_New(size));
|
||||
for (uint i=0; i<size; i++) {
|
||||
PyObject* o = pop_py_object(rb, offset);
|
||||
PyList_SET_ITEM(list.obj, i, o);
|
||||
}
|
||||
return list.release();
|
||||
}
|
||||
if (t==4) {
|
||||
auto size = rb->pop_t<int64>(offset);
|
||||
PyObjHolder dict(PyDict_New());
|
||||
for (int64 i=0; i<size; i++) {
|
||||
PyObject* key = pop_py_object(rb, offset);
|
||||
PyObject* value = pop_py_object(rb, offset);
|
||||
PyDict_SetItem(dict.obj, key, value);
|
||||
}
|
||||
return dict.release();
|
||||
}
|
||||
if (t==5) {
|
||||
ArrayArgs args;
|
||||
args.shape = rb->pop_t<NanoVector>(offset);
|
||||
args.dtype = rb->pop_t<NanoString>(offset);
|
||||
int64 size = args.dtype.dsize();
|
||||
for (int i=0; i<args.shape.size(); i++)
|
||||
size *= args.shape[i];
|
||||
rb->pop(size, offset);
|
||||
args.ptr = rb->get_ptr(size, offset);
|
||||
return to_py_object3(move(args));
|
||||
}
|
||||
if (t==6) {
|
||||
return pop_py_object_pickle(rb, offset);
|
||||
}
|
||||
if (t == 255) {
|
||||
LOGf << "WorkerError:" << rb->pop_string(offset);
|
||||
} else
|
||||
LOGf << "unsupport type:" << (int)t;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void PyMultiprocessRingBuffer::push(PyObject* obj) {
|
||||
auto offset = rb->r;
|
||||
auto offset_bk = offset;
|
||||
try {
|
||||
push_py_object(rb, obj, offset);
|
||||
} catch (const std::exception& e) {
|
||||
offset = offset_bk;
|
||||
rb->push_t<uint8>(255, offset);
|
||||
rb->push_string(string(e.what()), offset);
|
||||
}
|
||||
rb->commit_push(offset);
|
||||
}
|
||||
|
||||
PyObject* PyMultiprocessRingBuffer::pop() {
|
||||
auto offset = rb->l;
|
||||
auto obj = pop_py_object(rb, offset);
|
||||
rb->commit_pop(offset);
|
||||
return obj;
|
||||
}
|
||||
|
||||
PyMultiprocessRingBuffer::PyMultiprocessRingBuffer(uint64 size) {
|
||||
rb = RingBuffer::make_ring_buffer(size, 1);
|
||||
}
|
||||
|
||||
PyMultiprocessRingBuffer::~PyMultiprocessRingBuffer() {
|
||||
RingBuffer::free_ring_buffer(rb);
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,45 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
// Authors: Dun Liang <randonlang@gmail.com>.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include <Python.h>
|
||||
#include "misc/ring_buffer.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
// @pyjt(RingBuffer)
|
||||
struct PyMultiprocessRingBuffer {
|
||||
RingBuffer* rb;
|
||||
// @pyjt(__init__)
|
||||
PyMultiprocessRingBuffer(uint64 size);
|
||||
// @pyjt(__dealloc__)
|
||||
~PyMultiprocessRingBuffer();
|
||||
// @pyjt(push,send)
|
||||
void push(PyObject* obj);
|
||||
// @pyjt(pop,recv)
|
||||
PyObject* pop();
|
||||
// @pyjt(clear)
|
||||
inline void clear() { rb->l = rb->r = 0; }
|
||||
|
||||
// @pyjt(total_pop)
|
||||
inline uint64 total_pop() { return rb->l; }
|
||||
// @pyjt(total_push)
|
||||
inline uint64 total_push() { return rb->r; }
|
||||
// @pyjt(__repr__)
|
||||
inline string to_string() {
|
||||
string s="Buffer(free=";
|
||||
auto size = rb->size;
|
||||
auto used = rb->r - rb->l;
|
||||
s += S(100 - used*100.0/size);
|
||||
s += "% size=";
|
||||
s += S(size);
|
||||
s += ")";
|
||||
return s;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
}
|
Loading…
Reference in New Issue