floor/round/ceil keep the input dtype (new *_int variants return integers); better mem_info

This commit is contained in:
Dun Liang 2021-07-30 20:40:21 +08:00
parent 08eeb67de5
commit d482161be0
15 changed files with 242 additions and 33 deletions

View File

@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.2.3.84'
__version__ = '1.2.3.85'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int
@ -562,7 +562,7 @@ def randint(low, high=None, shape=(1,), dtype="int32") -> Var:
'''
if high is None: low, high = 0, low
v = (jt.random(shape) * (high - low) + low).clamp(low, high-0.5)
v = jt.floor(v)
v = jt.floor_int(v)
return v.astype(dtype)
def randint_like(x, low, high=None) -> Var:

View File

@ -26,6 +26,7 @@ dataset_root = os.path.join(pathlib.Path.home(), ".cache", "jittor", "dataset")
mp_log_v = os.environ.get("mp_log_v", 0)
mpi = jt.mpi
img_open_hook = HookTimer(Image, "open")
CHECK_MEMORY = int(os.environ.get("CHECK_MEMORY", "0"))
class Worker:
def __init__(self, target, args, buffer_size, keep_numpy_array=False):
@ -513,8 +514,12 @@ Example::
now = time.time()
self.batch_time = now - start
start = now
if CHECK_MEMORY and self.batch_id % CHECK_MEMORY == 0:
jt.display_memory_info()
else:
for _ in self._epochs():
self.batch_id = 0
batch_data = []
for idx in index_list:
batch_data.append(self[int(idx)])
@ -522,12 +527,16 @@ Example::
batch_data = self.collate_batch(batch_data)
batch_data = self.to_jittor(batch_data)
yield batch_data
self.batch_id += 1
if CHECK_MEMORY and self.batch_id % CHECK_MEMORY == 0:
jt.display_memory_info()
batch_data = []
# depend on drop_last
if not self.drop_last and len(batch_data) > 0:
batch_data = self.collate_batch(batch_data)
batch_data = self.to_jittor(batch_data)
self.batch_id += 1
yield batch_data

View File

@ -129,7 +129,7 @@ class Geometric:
def sample(self, sample_shape):
u = jt.rand(sample_shape)
return (jt.safe_log(u) / (jt.safe_log(-self.probs+1))).floor()
return (jt.safe_log(u) / (jt.safe_log(-self.probs+1))).floor_int()
# Evaluates x*safe_log(1 - prob) + safe_log(prob) element-wise for ``x``
# (presumably the geometric log-probability of ``x`` failures before the
# first success -- confirm against the class documentation).
# NOTE(review): this method reads ``self.prob`` while ``sample`` reads
# ``self.probs``; verify which attribute the constructor actually sets and
# make the two methods agree.
def log_prob(self, x):
return x*jt.safe_log(-self.prob+1)+jt.safe_log(self.prob)

View File

@ -1248,9 +1248,9 @@ def _bicubic(x, a, func):
def _interpolate(img, x, y, ids, mode):
if mode == "nearest":
return img.reindex([*ids, x.floor(), y.floor()])
return img.reindex([*ids, x.floor_int(), y.floor_int()])
if mode == "bilinear":
fx, fy = x.floor(), y.floor()
fx, fy = x.floor_int(), y.floor_int()
cx, cy = fx + 1, fy + 1
dx, dy = x - fx, y - fy
a = img.reindex_var([*ids, fx, fy])
@ -1264,7 +1264,7 @@ def _interpolate(img, x, y, ids, mode):
return o
if mode=="bicubic": # ugly ver.
n,c,h,w = img.shape
fx, fy = x.floor(), y.floor()
fx, fy = x.floor_int(), y.floor_int()
dix, diy = x - fx, y - fy
ax, ay = _bicubic(dix+1,-0.75,2), _bicubic(diy+1,-0.75,2)
bx, by = _bicubic(dix,-0.75,1), _bicubic(diy,-0.75,1)
@ -1434,7 +1434,7 @@ def reflect_coordinates(x,twice_low,twice_high):
x = (x - m).abs()
#`fmod` returns same sign as `in`, which is positive after the `fabs` above.
extra = x.mod(span)
flips = (x / span).floor()
flips = (x / span).floor_int()
result1 = extra+m
result2 = span-extra+m
con = flips%2==0
@ -1486,9 +1486,9 @@ def grid_sampler_3d(X,grid,mode,padding_mode,align_corners):
zid = z.reindex(shape,['i0','i2','i3','i4'])
if mode=='nearest':
return X.reindex([nid,cid,zid.round(),yid.round(),xid.round()])
return X.reindex([nid,cid,zid.round_int(),yid.round_int(),xid.round_int()])
elif mode=='bilinear':
fx,fy,fz = xid.floor(),yid.floor(),zid.floor()
fx,fy,fz = xid.floor_int(),yid.floor_int(),zid.floor_int()
cx,cy,cz = fx+1,fy+1,fz+1
dx,dy,dz = xid-fx,yid-fy,zid-fz
dnx,dny,dnz = cx-xid,cy-yid,cz-zid
@ -1523,10 +1523,10 @@ def grid_sampler_2d(X,grid,mode,padding_mode,align_corners):
yid = y.reindex(shape,['i0','i2','i3'])
if mode=='nearest':
return X.reindex([nid,cid,yid.round(),xid.round()])
return X.reindex([nid,cid,yid.round_int(),xid.round_int()])
elif mode=='bilinear':
#xid,yid = (xid+0.00001),(yid+0.00001)
fx,fy = (xid).floor(),(yid).floor()
fx,fy = (xid).floor_int(),(yid).floor_int()
cx,cy = fx+1,fy+1
dx,dy = xid-fx,yid-fy
dnx,dny = cx-xid,cy-yid

View File

@ -13,6 +13,7 @@
namespace jittor {
DEFINE_FLAG(int, use_temp_allocator, 1, "Enable temp allocator");
vector<TempAllocator*> TempAllocator::temp_allocators;
TempAllocator::~TempAllocator() {
while (!cached_blocks.empty()) {

View File

@ -24,6 +24,7 @@ struct TempCachingBlock {
struct TempAllocator : Allocator {
static const size_t ALIGN_SIZE = 512;
static const size_t ID_LIMIT = 1 << 18;
static vector<TempAllocator*> temp_allocators;
Allocator* underlying;
size_t cache_blocks_limit, used_memory, unused_memory;
std::map<unsigned long long, TempCachingBlock*> cached_blocks;
@ -33,6 +34,7 @@ struct TempAllocator : Allocator {
// Constructs an allocator with zeroed used/unused counters, a zero block-id
// counter and a freshly allocated table of ID_LIMIT block pointers, then
// records `this` in the static temp_allocators list so every live instance
// can be enumerated.
// NOTE(review): removal from temp_allocators is not visible in this hunk --
// confirm the destructor unregisters, otherwise the list may keep a dangling
// pointer after an allocator is destroyed.
inline TempAllocator(size_t cache_blocks_limit=2) : cache_blocks_limit(cache_blocks_limit), used_memory(0), unused_memory(0), tot_block_id(0), occupied_id_mapper(new TempCachingBlock*[ID_LIMIT]) {
temp_allocators.push_back(this);
}
inline TempAllocator(Allocator* underlying, size_t cache_blocks_limit=2) : TempAllocator(cache_blocks_limit) {
setup(underlying);

View File

@ -12,6 +12,7 @@
#elif defined(__APPLE__)
#include <sys/sysctl.h>
#endif
#include <unistd.h>
#include "var.h"
#include "op.h"
@ -105,16 +106,29 @@ void display_memory_info(const char* fileline, bool dump_var, bool red_color) {
>> "(" >> std::setprecision(p) >> a->unused_memory*100.0 / total >> "%)"
<< "total:" << FloatOutput{(double)total, " KMG", 1024, "B"} >> "\n";
}
if (use_temp_allocator && exe.temp_allocator) {
for (auto& a : TempAllocator::temp_allocators) {
auto total = a->used_memory + a->unused_memory;
all_total += total;
a->is_cuda() ? gpu_total += total : cpu_total += total;
log << "name:" << a->name() << "is_cuda:" << a->is_cuda()
<< "used:" << FloatOutput{(double)a->used_memory, " KMG", 1024, "B"}
>> "(" >> std::setprecision(p) >> a->used_memory*100.0 / total >> "%)"
<< "unused:" << FloatOutput{(double)a->unused_memory, " KMG", 1024, "B"}
>> "(" >> std::setprecision(p) >> a->unused_memory*100.0 / total >> "%)"
<< "total:" << FloatOutput{(double)total, " KMG", 1024, "B"} >> "\n";
}
}
log << "cpu&gpu:" << FloatOutput{(double)all_total, " KMG", 1024, "B"}
<< "gpu:" << FloatOutput{(double)gpu_total, " KMG", 1024, "B"}
<< "cpu:" << FloatOutput{(double)cpu_total, " KMG", 1024, "B"} >> '\n';
if (use_temp_allocator && exe.temp_allocator) {
TempAllocator* temp_allocator = (TempAllocator*)exe.temp_allocator;
log << "\nname:" << temp_allocator->name() << "\n";
log << "used_memory:" << FloatOutput{(double)temp_allocator->used_memory, " KMG", 1024, "B"} << "\n";
log << "unused_memory:" << FloatOutput{(double)temp_allocator->unused_memory, " KMG", 1024, "B"} << "\n";
}
auto cpu_free = get_avphys_pages() * sysconf(_SC_PAGESIZE);
size_t gpu_free = 0, _gpu_total = 0;
#ifdef HAS_CUDA
cudaMemGetInfo(&gpu_free, &_gpu_total);
#endif
log << "free: cpu(">>FloatOutput{(double)cpu_free, " KMG", 1024, "B"}
>> ") gpu(">>FloatOutput{(double)gpu_free, " KMG", 1024, "B"} >> ")\n";
if (dump_var) {
vector<Node*> queue;
unordered_set<Node*> visited;
@ -156,6 +170,12 @@ void display_memory_info(const char* fileline, bool dump_var, bool red_color) {
log.end();
}
extern vector<void(*)()> sigquit_callback;
// Zero-argument adapter around display_memory_info() so it can be stored in
// sigquit_callback, whose elements are plain void(*)() function pointers.
// NOTE(review): it runs from the SIGQUIT handler, so the logging and
// allocation done by display_memory_info() execute in signal context.
void meminfo_callback() {
display_memory_info();
}
MemInfo::MemInfo() {
#if defined(__linux__)
struct sysinfo info = {0};
@ -174,6 +194,7 @@ MemInfo::MemInfo() {
cudaGetDeviceProperties(&prop, 0);
total_cuda_ram = prop.totalGlobalMem;
#endif
sigquit_callback.push_back(&meminfo_callback);
}
MemInfo mem_info;

View File

@ -58,6 +58,9 @@ static unordered_set<string> unary_ops = {
"round",
"floor",
"ceil",
"round_int",
"floor_int",
"ceil_int",
"cast",
"sin",
"asin",
@ -81,9 +84,9 @@ static unordered_set<string> unary_float_ops = {
"sqrt",
};
static unordered_set<string> unary_int_ops = {
"round",
"floor",
"ceil",
"round_int",
"floor_int",
"ceil_int",
};
static unordered_set<string> binary_ops = {

View File

@ -62,6 +62,9 @@ constexpr int ns_max_len = 16;
m(round) \
m(floor) \
m(ceil) \
m(round_int) \
m(floor_int) \
m(ceil_int) \
m(cast) \
\
m(sin) \

View File

@ -168,7 +168,7 @@ static unordered_set<string> unary_ops = {
>>> a
jt.Var([ 2.101595 0.33055413 -0.44147047 -0.7720668 ], dtype=float32)
>>> jt.round(a)
jt.Var([ 2 0 0 -1], dtype=int32)
jt.Var([ 2.0 0.0 0.0 -1.0], dtype=float32)
*/
"round",
@ -185,7 +185,7 @@ static unordered_set<string> unary_ops = {
>>> a
jt.Var([-1.0339162 -0.7259972 -0.9220003 -0.8449701], dtype=float32)
>>> jt.floor(a)
jt.Var([-2 -1 -1 -1], dtype=int32)
jt.Var([-2.0 -1.0 -1.0 -1.0], dtype=float32)
*/
"floor",
@ -203,10 +203,63 @@ static unordered_set<string> unary_ops = {
>>> a
jt.Var([-1.0339162 -0.7259972 -0.9220003 -0.8449701], dtype=float32)
>>> jt.ceil(a)
jt.Var([-1 0 0 0], dtype=int32)
jt.Var([-1.0 0.0 0.0 0.0], dtype=float32)
*/
"ceil",
/**
Returns the integer closest to the input ``x``, as an integer-typed Var.
----------------
* [in] x: the input jt.Var.
----------------
Example-1::
>>> a = jt.randn(4)
>>> a
jt.Var([ 2.101595 0.33055413 -0.44147047 -0.7720668 ], dtype=float32)
>>> jt.round_int(a)
jt.Var([ 2 0 0 -1], dtype=int32)
*/
"round_int",
/**
Returns the largest integer less than or equal to the input ``x``, as an integer-typed Var.
----------------
* [in] x: the input jt.Var.
----------------
Example-1::
>>> a = jt.randn(4)
>>> a
jt.Var([-1.0339162 -0.7259972 -0.9220003 -0.8449701], dtype=float32)
>>> jt.floor_int(a)
jt.Var([-2 -1 -1 -1], dtype=int32)
*/
"floor_int",
/**
Returns the smallest integer greater than or equal to the input ``x``, as an integer-typed Var.
----------------
* [in] x: the input jt.Var.
----------------
Example-1::
>>> a = jt.randn(4)
>>> a
jt.Var([-1.0339162 -0.7259972 -0.9220003 -0.8449701], dtype=float32)
>>> jt.ceil_int(a)
jt.Var([-1 0 0 0], dtype=int32)
*/
"ceil_int",
/**
Returns the sine of the input ``x``.

View File

@ -21,6 +21,9 @@ namespace jittor {
#define round(T,x) ((T) ::roundf((x)))
#define floor(T,x) ((T) ::floorf((x)))
#define ceil(T,x) ((T) ::ceilf((x)))
#define round_int(T,x) ((T) ::roundf((x)))
#define floor_int(T,x) ((T) ::floorf((x)))
#define ceil_int(T,x) ((T) ::ceilf((x)))
#define sin(T,x) ((T) ::sinf((x)))
#define asin(T,x) ((T) ::asinf((x)))
@ -49,6 +52,9 @@ namespace jittor {
#define round(T,x) ((T)std::round((x)))
#define floor(T,x) ((T)std::floor((x)))
#define ceil(T,x) ((T)std::ceil((x)))
#define round_int(T,x) ((T)std::round((x)))
#define floor_int(T,x) ((T)std::floor((x)))
#define ceil_int(T,x) ((T)std::ceil((x)))
#define sin(T,x) ((T) std::sin((x)))
#define asin(T,x) ((T) std::asin((x)))

View File

@ -188,17 +188,52 @@ int segfault_happen = 0;
string thread_local thread_name;
static int _pid = getpid();
// Terminates the process immediately with exit status 1.
// Apple builds call _Exit(1); every other platform calls std::quick_exit(1)
// (presumably because quick_exit is unavailable on macOS -- confirm).
static inline void do_exit() {
#ifdef __APPLE__
_Exit(1);
#else
std::quick_exit(1);
#endif
}
vector<void(*)()> sigquit_callback;
int64 last_q_time;
void segfault_sigaction(int signal, siginfo_t *si, void *arg) {
if (signal == SIGQUIT) {
if (_pid == getpid()) {
std::cerr << "Caught SIGQUIT" << std::endl;
int64 now = clock();
if (now > last_q_time && last_q_time+CLOCKS_PER_SEC/10 > now) {
last_q_time = now;
std::cerr << "GDB attach..." << std::endl;
breakpoint();
} else {
last_q_time = now;
for (auto f : sigquit_callback)
f();
}
}
return;
}
if (signal == SIGCHLD) {
if (si->si_code != CLD_EXITED && si->si_status != SIGTERM) {
LOGe << "Caught SIGCHLD"
<< "si_errno:" << si->si_errno
<< "si_code:" << si->si_code
<< "si_status:" << si->si_status
<< ", quick exit";
exited = true;
do_exit();
}
return;
}
if (signal == SIGINT) {
if (_pid == getpid()) {
LOGe << "Caught SIGINT, quick exit";
}
exited = true;
#ifdef __APPLE__
_Exit(1);
#else
std::quick_exit(1);
#endif
do_exit();
}
std::cerr << "Caught segfault at address " << si->si_addr << ", "
<< "thread_name: '" << thread_name << "', flush log..." << std::endl;
@ -237,6 +272,7 @@ int register_sigaction() {
sigaction(SIGSTOP, &sa, NULL);
sigaction(SIGFPE, &sa, NULL);
sigaction(SIGINT, &sa, NULL);
sigaction(SIGCHLD, &sa, NULL);
sigaction(SIGILL, &sa, NULL);
sigaction(SIGBUS, &sa, NULL);
sigaction(SIGQUIT, &sa, NULL);

View File

@ -214,6 +214,81 @@ class TestDatasetSeed(unittest.TestCase):
assert x[i] == a
assert y[i] == b
assert z[i] == c
def test_children_died(self):
    """Check that the parent process aborts when a dataset worker is killed.

    A small script is written to the cache directory and run in a fresh
    interpreter. The script iterates a two-worker dataset and kills the
    first worker from inside the loop. The parent is expected to exit with
    a non-zero status after logging "Caught SIGCHLD ... quick exit" on
    stderr.
    """
    import subprocess as sp
    import sys

    # Child script: items with index > 100 spin forever, so the run can only
    # end through the SIGCHLD abort path, never by finishing the epoch.
    src = """
import jittor as jt
from jittor.dataset import Dataset
import numpy as np
class YourDataset(Dataset):
    def __init__(self):
        super().__init__()
        self.set_attrs(total_len=160)
    def __getitem__(self, k):
        if k>100:
            while 1:
                pass
        return { "a":np.array([1,2,3]) }
dataset = YourDataset()
dataset.set_attrs(num_workers=2)
for d in dataset:
    dataset.workers[0].p.kill()
    pass
"""
    fname = os.path.join(jt.flags.cache_path, "children_dead_test.py")
    with open(fname, 'w') as f:
        f.write(src)
    # Pass an argv list instead of a shell string: the interpreter path or
    # the cache path may contain spaces or shell metacharacters.
    cmd = [sys.executable, fname]
    print(cmd)
    r = sp.run(cmd, stdout=sp.PIPE, stderr=sp.PIPE)
    s = r.stderr.decode()
    print(s)
    assert r.returncode != 0
    assert "SIGCHLD" in s
    assert "quick exit" in s
def test_children_died2(self):
    """Check that terminating dataset workers on purpose is not an error.

    A small script is written to the cache directory and run in a fresh
    interpreter. The script takes one batch from a two-worker dataset,
    leaves the loop and calls ``dataset.terminate()``. The process is
    expected to exit with status 0.
    """
    import subprocess as sp
    import sys

    # Child script: stop after the first batch and shut the workers down
    # explicitly; the spinning items (index > 100) should never be reached
    # by the parent.
    src = """
import jittor as jt
from jittor.dataset import Dataset
import numpy as np
class YourDataset(Dataset):
    def __init__(self):
        super().__init__()
        self.set_attrs(total_len=160)
    def __getitem__(self, k):
        if k>100:
            while 1:
                pass
        return { "a":np.array([1,2,3]) }
dataset = YourDataset()
dataset.set_attrs(num_workers=2)
for d in dataset:
    break
dataset.terminate()
"""
    # Use a file name of its own so this test never overwrites the script of
    # test_children_died when tests run concurrently.
    fname = os.path.join(jt.flags.cache_path, "children_dead_test2.py")
    with open(fname, 'w') as f:
        f.write(src)
    # Pass an argv list instead of a shell string: the interpreter path or
    # the cache path may contain spaces or shell metacharacters.
    cmd = [sys.executable, fname]
    print(cmd)
    r = sp.run(cmd, stdout=sp.PIPE, stderr=sp.PIPE)
    s = r.stderr.decode()
    print(s)
    assert r.returncode == 0
if __name__ == "__main__":

View File

@ -115,9 +115,9 @@ def resize_and_crop(x, bbox, interpolation="nearest"):
x = bb[0]*jt.float(H-1)+hid*(bb[2]-bb[0])
y = bb[1]*jt.float(W-1)+wid*(bb[3]-bb[1])
if interpolation=="nearest":
return img.reindex_var([x.round(), y.round()])
return img.reindex_var([x.round_int(), y.round_int()])
if interpolation=="bilinear":
fx, fy = x.floor(), y.floor()
fx, fy = x.floor_int(), y.floor_int()
cx, cy = fx+one, fy+one
dx, dy = x-fx, y-fy
a = img.reindex_var([fx, fy])

View File

@ -50,9 +50,9 @@ def resize_and_crop(x, bbox, interpolation="nearest", out_size=[224,224]):
x = bb[0]*(H-1.0)+hid*((H-1)*1.0/(shape[1]-1))*(bb[2]-bb[0])
y = bb[1]*(W-1.0)+wid*((W-1)*1.0/(shape[2]-1))*(bb[3]-bb[1])
if interpolation=="nearest":
return img.reindex([x.round(), y.round(), cid])
return img.reindex([x.round_int(), y.round_int(), cid])
if interpolation=="bilinear":
fx, fy = x.floor(), y.floor()
fx, fy = x.floor_int(), y.floor_int()
cx, cy = fx+one, fy+one
dx, dy = x-fx, y-fy
a = img.reindex_var([fx, fy, cid])