mirror of https://github.com/Jittor/Jittor
floor keep type and better mem_info
This commit is contained in:
parent
08eeb67de5
commit
d482161be0
|
@ -9,7 +9,7 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
__version__ = '1.2.3.84'
|
||||
__version__ = '1.2.3.85'
|
||||
from jittor_utils import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
@ -562,7 +562,7 @@ def randint(low, high=None, shape=(1,), dtype="int32") -> Var:
|
|||
'''
|
||||
if high is None: low, high = 0, low
|
||||
v = (jt.random(shape) * (high - low) + low).clamp(low, high-0.5)
|
||||
v = jt.floor(v)
|
||||
v = jt.floor_int(v)
|
||||
return v.astype(dtype)
|
||||
|
||||
def randint_like(x, low, high=None) -> Var:
|
||||
|
|
|
@ -26,6 +26,7 @@ dataset_root = os.path.join(pathlib.Path.home(), ".cache", "jittor", "dataset")
|
|||
mp_log_v = os.environ.get("mp_log_v", 0)
|
||||
mpi = jt.mpi
|
||||
img_open_hook = HookTimer(Image, "open")
|
||||
CHECK_MEMORY = int(os.environ.get("CHECK_MEMORY", "0"))
|
||||
|
||||
class Worker:
|
||||
def __init__(self, target, args, buffer_size, keep_numpy_array=False):
|
||||
|
@ -513,8 +514,12 @@ Example::
|
|||
now = time.time()
|
||||
self.batch_time = now - start
|
||||
start = now
|
||||
|
||||
if CHECK_MEMORY and self.batch_id % CHECK_MEMORY == 0:
|
||||
jt.display_memory_info()
|
||||
else:
|
||||
for _ in self._epochs():
|
||||
self.batch_id = 0
|
||||
batch_data = []
|
||||
for idx in index_list:
|
||||
batch_data.append(self[int(idx)])
|
||||
|
@ -522,12 +527,16 @@ Example::
|
|||
batch_data = self.collate_batch(batch_data)
|
||||
batch_data = self.to_jittor(batch_data)
|
||||
yield batch_data
|
||||
self.batch_id += 1
|
||||
if CHECK_MEMORY and self.batch_id % CHECK_MEMORY == 0:
|
||||
jt.display_memory_info()
|
||||
batch_data = []
|
||||
|
||||
# depend on drop_last
|
||||
if not self.drop_last and len(batch_data) > 0:
|
||||
batch_data = self.collate_batch(batch_data)
|
||||
batch_data = self.to_jittor(batch_data)
|
||||
self.batch_id += 1
|
||||
yield batch_data
|
||||
|
||||
|
||||
|
|
|
@ -129,7 +129,7 @@ class Geometric:
|
|||
|
||||
def sample(self, sample_shape):
|
||||
u = jt.rand(sample_shape)
|
||||
return (jt.safe_log(u) / (jt.safe_log(-self.probs+1))).floor()
|
||||
return (jt.safe_log(u) / (jt.safe_log(-self.probs+1))).floor_int()
|
||||
|
||||
def log_prob(self, x):
|
||||
return x*jt.safe_log(-self.prob+1)+jt.safe_log(self.prob)
|
||||
|
|
|
@ -1248,9 +1248,9 @@ def _bicubic(x, a, func):
|
|||
|
||||
def _interpolate(img, x, y, ids, mode):
|
||||
if mode == "nearest":
|
||||
return img.reindex([*ids, x.floor(), y.floor()])
|
||||
return img.reindex([*ids, x.floor_int(), y.floor_int()])
|
||||
if mode == "bilinear":
|
||||
fx, fy = x.floor(), y.floor()
|
||||
fx, fy = x.floor_int(), y.floor_int()
|
||||
cx, cy = fx + 1, fy + 1
|
||||
dx, dy = x - fx, y - fy
|
||||
a = img.reindex_var([*ids, fx, fy])
|
||||
|
@ -1264,7 +1264,7 @@ def _interpolate(img, x, y, ids, mode):
|
|||
return o
|
||||
if mode=="bicubic": # ugly ver.
|
||||
n,c,h,w = img.shape
|
||||
fx, fy = x.floor(), y.floor()
|
||||
fx, fy = x.floor_int(), y.floor_int()
|
||||
dix, diy = x - fx, y - fy
|
||||
ax, ay = _bicubic(dix+1,-0.75,2), _bicubic(diy+1,-0.75,2)
|
||||
bx, by = _bicubic(dix,-0.75,1), _bicubic(diy,-0.75,1)
|
||||
|
@ -1434,7 +1434,7 @@ def reflect_coordinates(x,twice_low,twice_high):
|
|||
x = (x - m).abs()
|
||||
#`fmod` returns same sign as `in`, which is positive after the `fabs` above.
|
||||
extra = x.mod(span)
|
||||
flips = (x / span).floor()
|
||||
flips = (x / span).floor_int()
|
||||
result1 = extra+m
|
||||
result2 = span-extra+m
|
||||
con = flips%2==0
|
||||
|
@ -1486,9 +1486,9 @@ def grid_sampler_3d(X,grid,mode,padding_mode,align_corners):
|
|||
zid = z.reindex(shape,['i0','i2','i3','i4'])
|
||||
|
||||
if mode=='nearest':
|
||||
return X.reindex([nid,cid,zid.round(),yid.round(),xid.round()])
|
||||
return X.reindex([nid,cid,zid.round_int(),yid.round_int(),xid.round_int()])
|
||||
elif mode=='bilinear':
|
||||
fx,fy,fz = xid.floor(),yid.floor(),zid.floor()
|
||||
fx,fy,fz = xid.floor_int(),yid.floor_int(),zid.floor_int()
|
||||
cx,cy,cz = fx+1,fy+1,fz+1
|
||||
dx,dy,dz = xid-fx,yid-fy,zid-fz
|
||||
dnx,dny,dnz = cx-xid,cy-yid,cz-zid
|
||||
|
@ -1523,10 +1523,10 @@ def grid_sampler_2d(X,grid,mode,padding_mode,align_corners):
|
|||
yid = y.reindex(shape,['i0','i2','i3'])
|
||||
|
||||
if mode=='nearest':
|
||||
return X.reindex([nid,cid,yid.round(),xid.round()])
|
||||
return X.reindex([nid,cid,yid.round_int(),xid.round_int()])
|
||||
elif mode=='bilinear':
|
||||
#xid,yid = (xid+0.00001),(yid+0.00001)
|
||||
fx,fy = (xid).floor(),(yid).floor()
|
||||
fx,fy = (xid).floor_int(),(yid).floor_int()
|
||||
cx,cy = fx+1,fy+1
|
||||
dx,dy = xid-fx,yid-fy
|
||||
dnx,dny = cx-xid,cy-yid
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
namespace jittor {
|
||||
|
||||
DEFINE_FLAG(int, use_temp_allocator, 1, "Enable temp allocator");
|
||||
vector<TempAllocator*> TempAllocator::temp_allocators;
|
||||
|
||||
TempAllocator::~TempAllocator() {
|
||||
while (!cached_blocks.empty()) {
|
||||
|
|
|
@ -24,6 +24,7 @@ struct TempCachingBlock {
|
|||
struct TempAllocator : Allocator {
|
||||
static const size_t ALIGN_SIZE = 512;
|
||||
static const size_t ID_LIMIT = 1 << 18;
|
||||
static vector<TempAllocator*> temp_allocators;
|
||||
Allocator* underlying;
|
||||
size_t cache_blocks_limit, used_memory, unused_memory;
|
||||
std::map<unsigned long long, TempCachingBlock*> cached_blocks;
|
||||
|
@ -33,6 +34,7 @@ struct TempAllocator : Allocator {
|
|||
|
||||
|
||||
inline TempAllocator(size_t cache_blocks_limit=2) : cache_blocks_limit(cache_blocks_limit), used_memory(0), unused_memory(0), tot_block_id(0), occupied_id_mapper(new TempCachingBlock*[ID_LIMIT]) {
|
||||
temp_allocators.push_back(this);
|
||||
}
|
||||
inline TempAllocator(Allocator* underlying, size_t cache_blocks_limit=2) : TempAllocator(cache_blocks_limit) {
|
||||
setup(underlying);
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
#elif defined(__APPLE__)
|
||||
#include <sys/sysctl.h>
|
||||
#endif
|
||||
#include <unistd.h>
|
||||
|
||||
#include "var.h"
|
||||
#include "op.h"
|
||||
|
@ -105,16 +106,29 @@ void display_memory_info(const char* fileline, bool dump_var, bool red_color) {
|
|||
>> "(" >> std::setprecision(p) >> a->unused_memory*100.0 / total >> "%)"
|
||||
<< "total:" << FloatOutput{(double)total, " KMG", 1024, "B"} >> "\n";
|
||||
}
|
||||
if (use_temp_allocator && exe.temp_allocator) {
|
||||
for (auto& a : TempAllocator::temp_allocators) {
|
||||
auto total = a->used_memory + a->unused_memory;
|
||||
all_total += total;
|
||||
a->is_cuda() ? gpu_total += total : cpu_total += total;
|
||||
log << "name:" << a->name() << "is_cuda:" << a->is_cuda()
|
||||
<< "used:" << FloatOutput{(double)a->used_memory, " KMG", 1024, "B"}
|
||||
>> "(" >> std::setprecision(p) >> a->used_memory*100.0 / total >> "%)"
|
||||
<< "unused:" << FloatOutput{(double)a->unused_memory, " KMG", 1024, "B"}
|
||||
>> "(" >> std::setprecision(p) >> a->unused_memory*100.0 / total >> "%)"
|
||||
<< "total:" << FloatOutput{(double)total, " KMG", 1024, "B"} >> "\n";
|
||||
}
|
||||
}
|
||||
log << "cpu&gpu:" << FloatOutput{(double)all_total, " KMG", 1024, "B"}
|
||||
<< "gpu:" << FloatOutput{(double)gpu_total, " KMG", 1024, "B"}
|
||||
<< "cpu:" << FloatOutput{(double)cpu_total, " KMG", 1024, "B"} >> '\n';
|
||||
if (use_temp_allocator && exe.temp_allocator) {
|
||||
TempAllocator* temp_allocator = (TempAllocator*)exe.temp_allocator;
|
||||
log << "\nname:" << temp_allocator->name() << "\n";
|
||||
log << "used_memory:" << FloatOutput{(double)temp_allocator->used_memory, " KMG", 1024, "B"} << "\n";
|
||||
log << "unused_memory:" << FloatOutput{(double)temp_allocator->unused_memory, " KMG", 1024, "B"} << "\n";
|
||||
|
||||
}
|
||||
auto cpu_free = get_avphys_pages() * sysconf(_SC_PAGESIZE);
|
||||
size_t gpu_free = 0, _gpu_total = 0;
|
||||
#ifdef HAS_CUDA
|
||||
cudaMemGetInfo(&gpu_free, &_gpu_total);
|
||||
#endif
|
||||
log << "free: cpu(">>FloatOutput{(double)cpu_free, " KMG", 1024, "B"}
|
||||
>> ") gpu(">>FloatOutput{(double)gpu_free, " KMG", 1024, "B"} >> ")\n";
|
||||
if (dump_var) {
|
||||
vector<Node*> queue;
|
||||
unordered_set<Node*> visited;
|
||||
|
@ -156,6 +170,12 @@ void display_memory_info(const char* fileline, bool dump_var, bool red_color) {
|
|||
log.end();
|
||||
}
|
||||
|
||||
extern vector<void(*)()> sigquit_callback;
|
||||
|
||||
void meminfo_callback() {
|
||||
display_memory_info();
|
||||
}
|
||||
|
||||
MemInfo::MemInfo() {
|
||||
#if defined(__linux__)
|
||||
struct sysinfo info = {0};
|
||||
|
@ -174,6 +194,7 @@ MemInfo::MemInfo() {
|
|||
cudaGetDeviceProperties(&prop, 0);
|
||||
total_cuda_ram = prop.totalGlobalMem;
|
||||
#endif
|
||||
sigquit_callback.push_back(&meminfo_callback);
|
||||
}
|
||||
|
||||
MemInfo mem_info;
|
||||
|
|
|
@ -58,6 +58,9 @@ static unordered_set<string> unary_ops = {
|
|||
"round",
|
||||
"floor",
|
||||
"ceil",
|
||||
"round_int",
|
||||
"floor_int",
|
||||
"ceil_int",
|
||||
"cast",
|
||||
"sin",
|
||||
"asin",
|
||||
|
@ -81,9 +84,9 @@ static unordered_set<string> unary_float_ops = {
|
|||
"sqrt",
|
||||
};
|
||||
static unordered_set<string> unary_int_ops = {
|
||||
"round",
|
||||
"floor",
|
||||
"ceil",
|
||||
"round_int",
|
||||
"floor_int",
|
||||
"ceil_int",
|
||||
};
|
||||
|
||||
static unordered_set<string> binary_ops = {
|
||||
|
|
|
@ -62,6 +62,9 @@ constexpr int ns_max_len = 16;
|
|||
m(round) \
|
||||
m(floor) \
|
||||
m(ceil) \
|
||||
m(round_int) \
|
||||
m(floor_int) \
|
||||
m(ceil_int) \
|
||||
m(cast) \
|
||||
\
|
||||
m(sin) \
|
||||
|
|
|
@ -168,7 +168,7 @@ static unordered_set<string> unary_ops = {
|
|||
>>> a
|
||||
jt.Var([ 2.101595 0.33055413 -0.44147047 -0.7720668 ], dtype=float32)
|
||||
>>> jt.round(a)
|
||||
jt.Var([ 2 0 0 -1], dtype=int32)
|
||||
jt.Var([ 2.0 0.0 0.0 -1.0], dtype=float32)
|
||||
*/
|
||||
"round",
|
||||
|
||||
|
@ -185,7 +185,7 @@ static unordered_set<string> unary_ops = {
|
|||
>>> a
|
||||
jt.Var([-1.0339162 -0.7259972 -0.9220003 -0.8449701], dtype=float32)
|
||||
>>> jt.floor(a)
|
||||
jt.Var([-2 -1 -1 -1], dtype=int32)
|
||||
jt.Var([-2.0 -1.0 -1.0 -1.0], dtype=float32)
|
||||
*/
|
||||
"floor",
|
||||
|
||||
|
@ -203,10 +203,63 @@ static unordered_set<string> unary_ops = {
|
|||
>>> a
|
||||
jt.Var([-1.0339162 -0.7259972 -0.9220003 -0.8449701], dtype=float32)
|
||||
>>> jt.ceil(a)
|
||||
jt.Var([-1 0 0 0], dtype=int32)
|
||||
jt.Var([-1.0 0.0 0.0 0.0], dtype=float32)
|
||||
*/
|
||||
"ceil",
|
||||
|
||||
/**
|
||||
Returns the closest integer of the input ``x``.
|
||||
|
||||
----------------
|
||||
|
||||
* [in] x: the input jt.Var.
|
||||
|
||||
----------------
|
||||
|
||||
Example-1::
|
||||
>>> a = jt.randn(4)
|
||||
>>> a
|
||||
jt.Var([ 2.101595 0.33055413 -0.44147047 -0.7720668 ], dtype=float32)
|
||||
>>> jt.round_int(a)
|
||||
jt.Var([ 2 0 0 -1], dtype=int32)
|
||||
*/
|
||||
"round_int",
|
||||
|
||||
/**
|
||||
Returns the largest integer less than or equal to the input ``x``.
|
||||
|
||||
----------------
|
||||
|
||||
* [in] x: the input jt.Var.
|
||||
|
||||
----------------
|
||||
Example-1::
|
||||
>>> a = jt.randn(4)
|
||||
>>> a
|
||||
jt.Var([-1.0339162 -0.7259972 -0.9220003 -0.8449701], dtype=float32)
|
||||
>>> jt.floor_int(a)
|
||||
jt.Var([-2 -1 -1 -1], dtype=int32)
|
||||
*/
|
||||
"floor_int",
|
||||
|
||||
/**
|
||||
Returns the smallest integer greater than or equal to the input ``x``.
|
||||
|
||||
----------------
|
||||
|
||||
* [in] x: the input jt.Var.
|
||||
|
||||
----------------
|
||||
|
||||
Example-1::
|
||||
>>> a = jt.randn(4)
|
||||
>>> a
|
||||
jt.Var([-1.0339162 -0.7259972 -0.9220003 -0.8449701], dtype=float32)
|
||||
>>> jt.ceil_int(a)
|
||||
jt.Var([-1 0 0 0], dtype=int32)
|
||||
*/
|
||||
"ceil_int",
|
||||
|
||||
/**
|
||||
Returns the sine of the input ``x``.
|
||||
|
||||
|
|
|
@ -21,6 +21,9 @@ namespace jittor {
|
|||
#define round(T,x) ((T) ::roundf((x)))
|
||||
#define floor(T,x) ((T) ::floorf((x)))
|
||||
#define ceil(T,x) ((T) ::ceilf((x)))
|
||||
#define round_int(T,x) ((T) ::roundf((x)))
|
||||
#define floor_int(T,x) ((T) ::floorf((x)))
|
||||
#define ceil_int(T,x) ((T) ::ceilf((x)))
|
||||
|
||||
#define sin(T,x) ((T) ::sinf((x)))
|
||||
#define asin(T,x) ((T) ::asinf((x)))
|
||||
|
@ -49,6 +52,9 @@ namespace jittor {
|
|||
#define round(T,x) ((T)std::round((x)))
|
||||
#define floor(T,x) ((T)std::floor((x)))
|
||||
#define ceil(T,x) ((T)std::ceil((x)))
|
||||
#define round_int(T,x) ((T)std::round((x)))
|
||||
#define floor_int(T,x) ((T)std::floor((x)))
|
||||
#define ceil_int(T,x) ((T)std::ceil((x)))
|
||||
|
||||
#define sin(T,x) ((T) std::sin((x)))
|
||||
#define asin(T,x) ((T) std::asin((x)))
|
||||
|
|
|
@ -188,17 +188,52 @@ int segfault_happen = 0;
|
|||
string thread_local thread_name;
|
||||
static int _pid = getpid();
|
||||
|
||||
static inline void do_exit() {
|
||||
#ifdef __APPLE__
|
||||
_Exit(1);
|
||||
#else
|
||||
std::quick_exit(1);
|
||||
#endif
|
||||
}
|
||||
|
||||
vector<void(*)()> sigquit_callback;
|
||||
int64 last_q_time;
|
||||
|
||||
void segfault_sigaction(int signal, siginfo_t *si, void *arg) {
|
||||
if (signal == SIGQUIT) {
|
||||
if (_pid == getpid()) {
|
||||
std::cerr << "Caught SIGQUIT" << std::endl;
|
||||
int64 now = clock();
|
||||
if (now > last_q_time && last_q_time+CLOCKS_PER_SEC/10 > now) {
|
||||
last_q_time = now;
|
||||
std::cerr << "GDB attach..." << std::endl;
|
||||
breakpoint();
|
||||
} else {
|
||||
last_q_time = now;
|
||||
for (auto f : sigquit_callback)
|
||||
f();
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
if (signal == SIGCHLD) {
|
||||
if (si->si_code != CLD_EXITED && si->si_status != SIGTERM) {
|
||||
LOGe << "Caught SIGCHLD"
|
||||
<< "si_errno:" << si->si_errno
|
||||
<< "si_code:" << si->si_code
|
||||
<< "si_status:" << si->si_status
|
||||
<< ", quick exit";
|
||||
exited = true;
|
||||
do_exit();
|
||||
}
|
||||
return;
|
||||
}
|
||||
if (signal == SIGINT) {
|
||||
if (_pid == getpid()) {
|
||||
LOGe << "Caught SIGINT, quick exit";
|
||||
}
|
||||
exited = true;
|
||||
#ifdef __APPLE__
|
||||
_Exit(1);
|
||||
#else
|
||||
std::quick_exit(1);
|
||||
#endif
|
||||
do_exit();
|
||||
}
|
||||
std::cerr << "Caught segfault at address " << si->si_addr << ", "
|
||||
<< "thread_name: '" << thread_name << "', flush log..." << std::endl;
|
||||
|
@ -237,6 +272,7 @@ int register_sigaction() {
|
|||
sigaction(SIGSTOP, &sa, NULL);
|
||||
sigaction(SIGFPE, &sa, NULL);
|
||||
sigaction(SIGINT, &sa, NULL);
|
||||
sigaction(SIGCHLD, &sa, NULL);
|
||||
sigaction(SIGILL, &sa, NULL);
|
||||
sigaction(SIGBUS, &sa, NULL);
|
||||
sigaction(SIGQUIT, &sa, NULL);
|
||||
|
|
|
@ -214,6 +214,81 @@ class TestDatasetSeed(unittest.TestCase):
|
|||
assert x[i] == a
|
||||
assert y[i] == b
|
||||
assert z[i] == c
|
||||
|
||||
def test_children_died(self):
|
||||
src = """
|
||||
import jittor as jt
|
||||
from jittor.dataset import Dataset
|
||||
import numpy as np
|
||||
|
||||
class YourDataset(Dataset):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.set_attrs(total_len=160)
|
||||
|
||||
def __getitem__(self, k):
|
||||
if k>100:
|
||||
while 1:
|
||||
pass
|
||||
return { "a":np.array([1,2,3]) }
|
||||
|
||||
dataset = YourDataset()
|
||||
dataset.set_attrs(num_workers=2)
|
||||
|
||||
for d in dataset:
|
||||
dataset.workers[0].p.kill()
|
||||
pass
|
||||
"""
|
||||
fname = os.path.join(jt.flags.cache_path, "children_dead_test.py")
|
||||
with open(fname, 'w') as f:
|
||||
f.write(src)
|
||||
import subprocess as sp
|
||||
import sys
|
||||
cmd = sys.executable + " " + fname
|
||||
print(cmd)
|
||||
r = sp.run(cmd, shell=True, stdout=sp.PIPE, stderr=sp.PIPE)
|
||||
s = r.stderr.decode()
|
||||
print(s)
|
||||
assert r.returncode != 0
|
||||
assert "SIGCHLD" in s
|
||||
assert "quick exit" in s
|
||||
|
||||
|
||||
def test_children_died2(self):
|
||||
src = """
|
||||
import jittor as jt
|
||||
from jittor.dataset import Dataset
|
||||
import numpy as np
|
||||
|
||||
class YourDataset(Dataset):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.set_attrs(total_len=160)
|
||||
|
||||
def __getitem__(self, k):
|
||||
if k>100:
|
||||
while 1:
|
||||
pass
|
||||
return { "a":np.array([1,2,3]) }
|
||||
|
||||
dataset = YourDataset()
|
||||
dataset.set_attrs(num_workers=2)
|
||||
|
||||
for d in dataset:
|
||||
break
|
||||
dataset.terminate()
|
||||
"""
|
||||
fname = os.path.join(jt.flags.cache_path, "children_dead_test.py")
|
||||
with open(fname, 'w') as f:
|
||||
f.write(src)
|
||||
import subprocess as sp
|
||||
import sys
|
||||
cmd = sys.executable + " " + fname
|
||||
print(cmd)
|
||||
r = sp.run(cmd, shell=True, stdout=sp.PIPE, stderr=sp.PIPE)
|
||||
s = r.stderr.decode()
|
||||
print(s)
|
||||
assert r.returncode == 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -115,9 +115,9 @@ def resize_and_crop(x, bbox, interpolation="nearest"):
|
|||
x = bb[0]*jt.float(H-1)+hid*(bb[2]-bb[0])
|
||||
y = bb[1]*jt.float(W-1)+wid*(bb[3]-bb[1])
|
||||
if interpolation=="nearest":
|
||||
return img.reindex_var([x.round(), y.round()])
|
||||
return img.reindex_var([x.round_int(), y.round_int()])
|
||||
if interpolation=="bilinear":
|
||||
fx, fy = x.floor(), y.floor()
|
||||
fx, fy = x.floor_int(), y.floor_int()
|
||||
cx, cy = fx+one, fy+one
|
||||
dx, dy = x-fx, y-fy
|
||||
a = img.reindex_var([fx, fy])
|
||||
|
|
|
@ -50,9 +50,9 @@ def resize_and_crop(x, bbox, interpolation="nearest", out_size=[224,224]):
|
|||
x = bb[0]*(H-1.0)+hid*((H-1)*1.0/(shape[1]-1))*(bb[2]-bb[0])
|
||||
y = bb[1]*(W-1.0)+wid*((W-1)*1.0/(shape[2]-1))*(bb[3]-bb[1])
|
||||
if interpolation=="nearest":
|
||||
return img.reindex([x.round(), y.round(), cid])
|
||||
return img.reindex([x.round_int(), y.round_int(), cid])
|
||||
if interpolation=="bilinear":
|
||||
fx, fy = x.floor(), y.floor()
|
||||
fx, fy = x.floor_int(), y.floor_int()
|
||||
cx, cy = fx+one, fy+one
|
||||
dx, dy = x-fx, y-fy
|
||||
a = img.reindex_var([fx, fy, cid])
|
||||
|
|
Loading…
Reference in New Issue