mirror of https://github.com/Jittor/Jittor
commit
a76c664c65
|
@ -0,0 +1,3 @@
|
|||
from .ccl_2d import ccl_2d
|
||||
from .ccl_3d import ccl_3d
|
||||
from .ccl_link import ccl_link
|
|
@ -0,0 +1,177 @@
|
|||
import jittor as jt
|
||||
|
||||
|
||||
def ccl_2d(data_2d):
|
||||
'''
|
||||
2D connected component labelling, original code from https://github.com/DanielPlayne/playne-equivalence-algorithm
|
||||
Args:
|
||||
[in]param data_2d: binary two-dimensional vector
|
||||
type data_2d: jittor array
|
||||
|
||||
Returns:
|
||||
[out]result: labeled two-dimensional vector
|
||||
|
||||
Example:
|
||||
>>> import jittor as jt
|
||||
>>> jt.flags.use_cuda = 1
|
||||
>>> import cv2
|
||||
>>> import numpy as np
|
||||
>>> img = cv2.imread('testImg.png', 0)
|
||||
>>> a = img.mean()
|
||||
>>> img[img <= a] = 0
|
||||
>>> img[img > a] = 1
|
||||
>>> img = jt.Var(img)
|
||||
|
||||
>>> result = ccl_2d(img)
|
||||
>>> print(jt.unique(result, return_counts=True, return_inverse=True)[0], jt.unique(result, return_counts=True, return_inverse=True)[2])
|
||||
>>> cv2.imwrite('testImg_result.png', result.numpy().astype(np.uint8) * 50)
|
||||
'''
|
||||
|
||||
data_2d = data_2d.astype(jt.uint32)
|
||||
cY = data_2d.shape[0]
|
||||
cX = data_2d.shape[1]
|
||||
data_2d_copy = data_2d.clone()
|
||||
changed = jt.ones([1], dtype=jt.uint32)
|
||||
data_2d = data_2d.reshape(cX * cY)
|
||||
result = jt.code(data_2d.shape,
|
||||
data_2d.dtype, [data_2d, changed],
|
||||
cuda_header='''
|
||||
@alias(g_image, in0)
|
||||
@alias(g_labels, out)
|
||||
''',
|
||||
cuda_src=r'''
|
||||
__global__ void init_labels(@ARGS_DEF, const int cX, const int cY) {
|
||||
@PRECALC
|
||||
// Calculate index
|
||||
const unsigned int ix = (blockIdx.x * blockDim.x) + threadIdx.x;
|
||||
const unsigned int iy = (blockIdx.y * blockDim.y) + threadIdx.y;
|
||||
@g_labels(iy*cX + ix) = iy*cX + ix;
|
||||
}
|
||||
|
||||
__device__ __inline__ unsigned int find_root(@ARGS_DEF, unsigned int label) {
|
||||
// Resolve Label
|
||||
unsigned int next = @g_labels(label);
|
||||
|
||||
// Follow chain
|
||||
while(label != next) {
|
||||
// Move to next
|
||||
label = next;
|
||||
next = @g_labels(label);
|
||||
}
|
||||
|
||||
// Return label
|
||||
return label;
|
||||
}
|
||||
|
||||
__global__ void resolve_labels(@ARGS_DEF, const int cX, const int cY) {
|
||||
@PRECALC
|
||||
// Calculate index
|
||||
const unsigned int id = ((blockIdx.y * blockDim.y) + threadIdx.y) * cX +
|
||||
((blockIdx.x * blockDim.x) + threadIdx.x);
|
||||
|
||||
// Check Thread Range
|
||||
if(id < cX*cY) {
|
||||
// Resolve Label
|
||||
@g_labels(id) = find_root(@ARGS, @g_labels(id));
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void label_equivalence(@ARGS_DEF, const int cX, const int cY) {
|
||||
@PRECALC
|
||||
// Calculate index
|
||||
const unsigned int ix = (blockIdx.x * blockDim.x) + threadIdx.x;
|
||||
const unsigned int iy = (blockIdx.y * blockDim.y) + threadIdx.y;
|
||||
|
||||
// Check Thread Range
|
||||
if((ix < cX) && (iy < cY)) {
|
||||
// Get image and label values
|
||||
const unsigned char cyx = @g_image( iy*cX + ix);
|
||||
|
||||
// Get neighbour labels
|
||||
const unsigned int lym1x = (iy > 0) ? @g_labels((iy-1)*cX + ix) : 0;
|
||||
const unsigned int lyxm1 = (ix > 0) ? @g_labels(iy *cX + ix-1) : 0;
|
||||
const unsigned int lyx = @g_labels(iy *cX + ix);
|
||||
const unsigned int lyxp1 = (ix < cX-1) ? @g_labels(iy *cX + ix+1) : 0;
|
||||
const unsigned int lyp1x = (iy < cY-1) ? @g_labels((iy+1)*cX + ix) : 0;
|
||||
|
||||
const unsigned int lym1xm1 = (iy > 0 && ix > 0 ) ? @g_labels((iy-1)*cX + ix-1) : 0;
|
||||
const unsigned int lym1xp1 = (iy > 0 && ix < cX-1) ? @g_labels((iy-1)*cX + ix+1) : 0;
|
||||
const unsigned int lyp1xm1 = (iy < cY-1 && ix > 0 ) ? @g_labels((iy+1)*cX + ix-1) : 0;
|
||||
const unsigned int lyp1xp1 = (iy < cY-1 && ix < cX-1) ? @g_labels((iy+1)*cX + ix+1) : 0;
|
||||
|
||||
const bool nym1x = (iy > 0) ? (cyx == (@g_image((iy-1)*cX + ix))) : false;
|
||||
const bool nyxm1 = (ix > 0) ? (cyx == (@g_image(iy *cX + ix-1))) : false;
|
||||
const bool nyxp1 = (ix < cX-1) ? (cyx == (@g_image(iy *cX + ix+1))) : false;
|
||||
const bool nyp1x = (iy > cY-1) ? (cyx == (@g_image((iy+1)*cX + ix))) : false;
|
||||
|
||||
const bool nym1xm1 = (iy > 0 && ix > 0 ) ? (cyx == (@g_image((iy-1)*cX + ix-1))) : false;
|
||||
const bool nym1xp1 = (iy > 0 && ix < cX-1) ? (cyx == (@g_image((iy-1)*cX + ix+1))) : false;
|
||||
const bool nyp1xm1 = (iy < cY-1 && ix > 0 ) ? (cyx == (@g_image((iy+1)*cX + ix-1))) : false;
|
||||
const bool nyp1xp1 = (iy < cY-1 && ix < cX-1) ? (cyx == (@g_image((iy+1)*cX + ix+1))) : false;
|
||||
|
||||
// Lowest label
|
||||
unsigned int label = lyx;
|
||||
|
||||
// Find lowest neighbouring label
|
||||
label = ((nym1x) && (lym1x < label)) ? lym1x : label;
|
||||
label = ((nyxm1) && (lyxm1 < label)) ? lyxm1 : label;
|
||||
label = ((nyxp1) && (lyxp1 < label)) ? lyxp1 : label;
|
||||
label = ((nyp1x) && (lyp1x < label)) ? lyp1x : label;
|
||||
|
||||
label = ((nym1xm1) && (lym1xm1 < label)) ? lym1xm1 : label;
|
||||
label = ((nym1xp1) && (lym1xp1 < label)) ? lym1xp1 : label;
|
||||
label = ((nyp1xm1) && (lyp1xm1 < label)) ? lyp1xm1 : label;
|
||||
label = ((nyp1xp1) && (lyp1xp1 < label)) ? lyp1xp1 : label;
|
||||
|
||||
// If labels are different, resolve them
|
||||
if(label < lyx) {
|
||||
// Update label
|
||||
// Nonatomic write may overwrite another label but on average seems to give faster results
|
||||
@g_labels(lyx) = label;
|
||||
|
||||
// Record the change
|
||||
@in1(0) = 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
''' + f'''
|
||||
dim3 block(32, 32);
|
||||
const int cX= {cX};
|
||||
const int cY= {cY};''' + '''
|
||||
dim3 grid(ceil(cX/(float)block.x), ceil(cY/(float)block.y));
|
||||
dim3 resolve_block(32, 32);
|
||||
dim3 resolve_grid(ceil(cX/(float)resolve_block.x), ceil(cY/(float)resolve_block.y));
|
||||
|
||||
// Initialise labels
|
||||
init_labels <<< grid, block >>>(@ARGS, cX, cY);
|
||||
|
||||
// Resolve the labels
|
||||
resolve_labels <<< resolve_grid, resolve_block >>>(@ARGS, cX, cY);
|
||||
|
||||
// Changed Flag
|
||||
int32 changed = 1;
|
||||
|
||||
// While labels have changed
|
||||
while(changed) {
|
||||
// Copy changed to device
|
||||
cudaMemsetAsync(in1_p, 0, 4);
|
||||
|
||||
// Label image
|
||||
label_equivalence <<< grid, block >>>(@ARGS, cX, cY);
|
||||
|
||||
// Copy changed flag to host
|
||||
cudaMemcpy(&changed, in1_p, sizeof(int32), cudaMemcpyDeviceToHost);
|
||||
|
||||
// Resolve the labels
|
||||
resolve_labels <<< resolve_grid, resolve_block>>>(@ARGS, cX, cY);
|
||||
}
|
||||
''')
|
||||
result = result.reshape((cY, cX)) * data_2d_copy
|
||||
value = jt.unique(result)
|
||||
value = value[value != 0]
|
||||
|
||||
map_result = jt.zeros((int(value.max().numpy()[0]) + 1), dtype=jt.uint32)
|
||||
map_result[value] = jt.index(value.shape)[0] + 1
|
||||
result = map_result[result]
|
||||
|
||||
return result
|
|
@ -0,0 +1,196 @@
|
|||
import jittor as jt
|
||||
|
||||
|
||||
def ccl_3d(data_3d):
|
||||
'''
|
||||
3D connected component labelling, original code from https://github.com/DanielPlayne/playne-equivalence-algorithm
|
||||
Args:
|
||||
[in]param data_3d: binary three-dimensional vector
|
||||
type data_3d: jittor array
|
||||
|
||||
Returns:
|
||||
[out]result : labeled three-dimensional vector
|
||||
|
||||
Example:
|
||||
>>> import jittor as jt
|
||||
>>> jt.flags.use_cuda = 1
|
||||
>>> data_3d = jt.zeros((10, 11, 12), dtype=jt.uint32)
|
||||
>>> data_3d[2:4, :, :] = 1
|
||||
>>> data_3d[5:7, :, :] = 1
|
||||
>>> result = ccl_3d(data_3d)
|
||||
>>> print(result[:, 0, 0])
|
||||
>>> print(
|
||||
jt.unique(result, return_counts=True, return_inverse=True)[0],
|
||||
jt.unique(result, return_counts=True, return_inverse=True)[2])
|
||||
'''
|
||||
|
||||
data_3d = data_3d.astype(jt.uint32)
|
||||
cX = data_3d.shape[0]
|
||||
cY = data_3d.shape[1]
|
||||
cZ = data_3d.shape[2]
|
||||
changed = jt.ones([1], dtype=jt.uint32)
|
||||
data_3d_copy = data_3d.copy()
|
||||
data_3d = data_3d.reshape(cX * cY * cZ)
|
||||
result = jt.code(data_3d.shape,
|
||||
data_3d.dtype, [data_3d, changed],
|
||||
cuda_header='''
|
||||
@alias(g_image, in0)
|
||||
@alias(g_labels, out)
|
||||
''',
|
||||
cuda_src=r'''
|
||||
__global__ void init_labels(@ARGS_DEF, const int cX, const int cY, const int cZ, const int pX, const int pY) {
|
||||
@PRECALC
|
||||
// Calculate index
|
||||
const unsigned int ix = (blockIdx.x * blockDim.x) + threadIdx.x;
|
||||
const unsigned int iy = (blockIdx.y * blockDim.y) + threadIdx.y;
|
||||
const unsigned int iz = (blockIdx.z * blockDim.z) + threadIdx.z;
|
||||
|
||||
if((ix < cX) && (iy < cY) && (iz < cZ)) {
|
||||
const unsigned char pzyx = @g_image(iz*pY + iy*pX + ix);
|
||||
|
||||
// Neighbour Connections
|
||||
const bool nzm1yx = (iz > 0) ? (pzyx == @g_image((iz-1)*pY + iy *pX + ix )) : false;
|
||||
const bool nzym1x = (iy > 0) ? (pzyx == @g_image( iz *pY + (iy-1)*pX + ix )) : false;
|
||||
const bool nzyxm1 = (ix > 0) ? (pzyx == @g_image( iz *pY + iy *pX + ix-1)) : false;
|
||||
|
||||
// Label
|
||||
unsigned int label;
|
||||
|
||||
// Initialise Label
|
||||
label = (nzyxm1) ? ( iz*pY + iy*pX + ix-1) : (iz*pY + iy*pX + ix);
|
||||
label = (nzym1x) ? ( iz*pY + (iy-1)*pX + ix) : label;
|
||||
label = (nzm1yx) ? ((iz-1)*pY + iy*pX + ix) : label;
|
||||
// Write to Global Memory
|
||||
@g_labels(iz*pY + iy*pX + ix) = label;
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __inline__ unsigned int find_root(@ARGS_DEF, unsigned int label) {
|
||||
// Resolve Label
|
||||
unsigned int next = @g_labels(label);
|
||||
|
||||
// Follow chain
|
||||
while(label != next) {
|
||||
// Move to next
|
||||
label = next;
|
||||
next = @g_labels(label);
|
||||
}
|
||||
|
||||
// Return label
|
||||
return label;
|
||||
}
|
||||
|
||||
__global__ void resolve_labels(@ARGS_DEF, const int cX, const int cY, const int cZ, const int pX, const int pY) {
|
||||
@PRECALC
|
||||
// Calculate index
|
||||
const unsigned int id = ((blockIdx.z * blockDim.z) + threadIdx.z) * pY +
|
||||
((blockIdx.y * blockDim.y) + threadIdx.y) * pX +
|
||||
((blockIdx.x * blockDim.x) + threadIdx.x);
|
||||
|
||||
// Check Thread Range
|
||||
if(id < cX*cY*cZ) {
|
||||
// Resolve Label
|
||||
@g_labels(id) = find_root(@ARGS, @g_labels(id));
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void label_equivalence(@ARGS_DEF, const int cX, const int cY, const int cZ, const int pX, const int pY) {
|
||||
@PRECALC
|
||||
// Calculate index
|
||||
const unsigned int ix = (blockIdx.x * blockDim.x) + threadIdx.x;
|
||||
const unsigned int iy = (blockIdx.y * blockDim.y) + threadIdx.y;
|
||||
const unsigned int iz = (blockIdx.z * blockDim.z) + threadIdx.z;
|
||||
|
||||
// Check Thread Range
|
||||
if((ix < cX) && (iy < cY) && (iz < cZ)) {
|
||||
// Get image and label values
|
||||
const unsigned char pzyx = @g_image(iz*pY + iy*pX + ix);
|
||||
|
||||
// Neighbouring indexes
|
||||
const unsigned int xm1 = ix-1;
|
||||
const unsigned int xp1 = ix+1;
|
||||
const unsigned int ym1 = iy-1;
|
||||
const unsigned int yp1 = iy+1;
|
||||
const unsigned int zm1 = iz-1;
|
||||
const unsigned int zp1 = iz+1;
|
||||
|
||||
// Get neighbour labels
|
||||
const unsigned int lzm1yx = (iz > 0) ? @g_labels(zm1*pY + iy*pX + ix) : 0;
|
||||
const unsigned int lzym1x = (iy > 0) ? @g_labels( iz*pY + ym1*pX + ix) : 0;
|
||||
const unsigned int lzyxm1 = (ix > 0) ? @g_labels( iz*pY + iy*pX + xm1) : 0;
|
||||
const unsigned int lzyx = @g_labels( iz*pY + iy*pX + ix);
|
||||
const unsigned int lzyxp1 = (ix < cX-1) ? @g_labels( iz*pY + iy*pX + xp1) : 0;
|
||||
const unsigned int lzyp1x = (iy < cY-1) ? @g_labels( iz*pY + yp1*pX + ix) : 0;
|
||||
const unsigned int lzp1yx = (iz < cZ-1) ? @g_labels(zp1*pY + iy*pX + ix) : 0;
|
||||
|
||||
const bool nzm1yx = (iz > 0) ? (pzyx == @g_image(zm1*pY + iy*pX + ix)) : false;
|
||||
const bool nzym1x = (iy > 0) ? (pzyx == @g_image( iz*pY + ym1*pX + ix)) : false;
|
||||
const bool nzyxm1 = (ix > 0) ? (pzyx == @g_image( iz*pY + iy*pX + xm1)) : false;
|
||||
const bool nzyxp1 = (ix < cX-1) ? (pzyx == @g_image( iz*pY + iy*pX + xp1)) : false;
|
||||
const bool nzyp1x = (iy < cY-1) ? (pzyx == @g_image( iz*pY + yp1*pX + ix)) : false;
|
||||
const bool nzp1yx = (iz < cZ-1) ? (pzyx == @g_image(zp1*pY + iy*pX + ix)) : false;
|
||||
|
||||
// Lowest label
|
||||
unsigned int label = lzyx;
|
||||
|
||||
// Find lowest neighbouring label
|
||||
label = ((nzm1yx) && (lzm1yx < label)) ? lzm1yx : label;
|
||||
label = ((nzym1x) && (lzym1x < label)) ? lzym1x : label;
|
||||
label = ((nzyxm1) && (lzyxm1 < label)) ? lzyxm1 : label;
|
||||
label = ((nzyxp1) && (lzyxp1 < label)) ? lzyxp1 : label;
|
||||
label = ((nzyp1x) && (lzyp1x < label)) ? lzyp1x : label;
|
||||
label = ((nzp1yx) && (lzp1yx < label)) ? lzp1yx : label;
|
||||
|
||||
// If labels are different, resolve them
|
||||
if(label < lzyx) {
|
||||
// Update label
|
||||
// Nonatomic write may overwrite another label but on average seems to give faster results
|
||||
@g_labels(lzyx) = label;
|
||||
|
||||
// Record the change
|
||||
@in1(0) = 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
''' + f'''
|
||||
dim3 block(32, 4, 4);
|
||||
const int cX= {cX};
|
||||
const int cY= {cY};
|
||||
const int cZ= {cZ};
|
||||
const int pX= cX;
|
||||
const int pY= cX*cY;''' + '''
|
||||
dim3 grid(ceil(cX/(float)block.x), ceil(cY/(float)block.y), ceil(cZ/(float)block.z));
|
||||
|
||||
// Initialise labels
|
||||
init_labels <<< grid, block >>>(@ARGS, cX, cY, cZ, pX, pY);
|
||||
|
||||
// Resolve the labels
|
||||
resolve_labels <<< grid, block >>>(@ARGS, cX, cY, cZ, pX, pY);
|
||||
|
||||
// Changed Flag
|
||||
int32 changed = 1;
|
||||
|
||||
// While labels have changed
|
||||
while(changed) {
|
||||
// Copy changed to device
|
||||
cudaMemsetAsync(in1_p, 0, 4);
|
||||
|
||||
// Label image
|
||||
label_equivalence <<< grid, block >>>(@ARGS, cX, cY, cZ, pX, pY);
|
||||
|
||||
// Copy changed flag to host
|
||||
cudaMemcpy(&changed, in1_p, sizeof(int32), cudaMemcpyDeviceToHost);
|
||||
|
||||
// Resolve the labels
|
||||
resolve_labels <<< grid, block>>>(@ARGS, cX, cY, cZ, pX, pY);
|
||||
}
|
||||
''')
|
||||
result = result.reshape((cX, cY, cZ)) * data_3d_copy
|
||||
value = jt.unique(result)
|
||||
value = value[value != 0]
|
||||
|
||||
map_result = jt.zeros((int(value.max().numpy()[0]) + 1), dtype=jt.uint32)
|
||||
map_result[value] = jt.index(value.shape)[0] + 1
|
||||
result = map_result[result]
|
||||
|
||||
return result
|
|
@ -0,0 +1,195 @@
|
|||
import jittor as jt
|
||||
|
||||
|
||||
def ccl_link(score_map, link_map, result_comp_area_thresh=6):
|
||||
"""
|
||||
Find components in score map and link them with link map, original code from https://github.com/DanielPlayne/playne-equivalence-algorithm.
|
||||
Args:
|
||||
[in]param score_map: binary two-dimensional vector
|
||||
type score_map: jittor array
|
||||
[in]param link_map: two-dimensional vector with 8 channels
|
||||
type link_map: jittor array
|
||||
[in]param result_comp_area_thresh: threshold of component area
|
||||
type result_comp_area_thresh: int
|
||||
Returns:
|
||||
[out]result: labeled two-dimensional vector
|
||||
Example:
|
||||
>>> import jittor as jt
|
||||
>>> jt.flags.use_cuda = 1
|
||||
>>> import cv2
|
||||
>>> import numpy as np
|
||||
>>> score_map = jt.Var(np.load("score_map.npy"))
|
||||
>>> link_map = jt.Var(np.load("link_map.npy"))
|
||||
>>> score_map = score_map >= 0.5
|
||||
>>> link_map = link_map >= 0.8
|
||||
>>> for i in range(8):
|
||||
>>> link_map[:, :, i] = link_map[:, :, i] & score_map
|
||||
|
||||
>>> result = ccl_link(score_map, link_map)
|
||||
>>> cv2.imwrite('pixellink.png', result.numpy().astype(np.uint8) * 50)
|
||||
"""
|
||||
score_map = score_map.astype(jt.uint32)
|
||||
link_map = link_map.astype(jt.uint32)
|
||||
cY = score_map.shape[0]
|
||||
cX = score_map.shape[1]
|
||||
changed = jt.ones([1], dtype=jt.uint32)
|
||||
score_map = score_map.reshape(cX * cY)
|
||||
result = jt.code(score_map.shape,
|
||||
score_map.dtype, [score_map, link_map, changed],
|
||||
cuda_header='''
|
||||
@alias(score_map, in0)
|
||||
@alias(link_map, in1)
|
||||
@alias(g_labels, out)
|
||||
''',
|
||||
cuda_src=r'''
|
||||
__global__ void init_labels(@ARGS_DEF, const int cX, const int cY) {
|
||||
@PRECALC
|
||||
// Calculate index
|
||||
const unsigned int ix = (blockIdx.x * blockDim.x) + threadIdx.x;
|
||||
const unsigned int iy = (blockIdx.y * blockDim.y) + threadIdx.y;
|
||||
@g_labels(iy*cX + ix) = iy*cX + ix;
|
||||
}
|
||||
|
||||
__device__ __inline__ unsigned int find_root(@ARGS_DEF, unsigned int label) {
|
||||
// Resolve Label
|
||||
unsigned int next = @g_labels(label);
|
||||
|
||||
// Follow chain
|
||||
while(label != next) {
|
||||
// Move to next
|
||||
label = next;
|
||||
next = @g_labels(label);
|
||||
}
|
||||
|
||||
// Return label
|
||||
return label;
|
||||
}
|
||||
|
||||
__global__ void resolve_labels(@ARGS_DEF, const int cX, const int cY) {
|
||||
@PRECALC
|
||||
// Calculate index
|
||||
const unsigned int id = ((blockIdx.y * blockDim.y) + threadIdx.y) * cX +
|
||||
((blockIdx.x * blockDim.x) + threadIdx.x);
|
||||
|
||||
// Check Thread Range
|
||||
if(id < cX*cY) {
|
||||
// Resolve Label
|
||||
@g_labels(id) = find_root(@ARGS, @g_labels(id));
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void label_equivalence(@ARGS_DEF, const int cX, const int cY) {
|
||||
@PRECALC
|
||||
// Calculate index
|
||||
const unsigned int ix = (blockIdx.x * blockDim.x) + threadIdx.x;
|
||||
const unsigned int iy = (blockIdx.y * blockDim.y) + threadIdx.y;
|
||||
|
||||
// Check Thread Range
|
||||
if((ix < cX) && (iy < cY)) {
|
||||
// Get image and label values
|
||||
const unsigned char cyx = @score_map( iy*cX + ix);
|
||||
|
||||
// Get neighbour labels
|
||||
const unsigned int lym1x = (iy > 0) ? @g_labels((iy-1)*cX + ix) : 0;
|
||||
const unsigned int lyxm1 = (ix > 0) ? @g_labels(iy *cX + ix-1) : 0;
|
||||
const unsigned int lyx = @g_labels(iy *cX + ix);
|
||||
const unsigned int lyxp1 = (ix < cX-1) ? @g_labels(iy *cX + ix+1) : 0;
|
||||
const unsigned int lyp1x = (iy < cY-1) ? @g_labels((iy+1)*cX + ix) : 0;
|
||||
|
||||
const unsigned int lym1xm1 = (iy > 0 && ix > 0 ) ? @g_labels((iy-1)*cX + ix-1) : 0;
|
||||
const unsigned int lym1xp1 = (iy > 0 && ix < cX-1) ? @g_labels((iy-1)*cX + ix+1) : 0;
|
||||
const unsigned int lyp1xm1 = (iy < cY-1 && ix > 0 ) ? @g_labels((iy+1)*cX + ix-1) : 0;
|
||||
const unsigned int lyp1xp1 = (iy < cY-1 && ix < cX-1) ? @g_labels((iy+1)*cX + ix+1) : 0;
|
||||
bool nym1x, nyxm1, nyxp1, nyp1x, nym1xm1, nym1xp1, nyp1xm1, nyp1xp1;
|
||||
if(cyx) {
|
||||
nym1x = (iy > 0) ? ((cyx == (@score_map((iy-1)*cX + ix))) && (@link_map(iy, ix, 6) || @link_map(iy-1, ix, 7))) : false; // up
|
||||
nyxm1 = (ix > 0) ? ((cyx == (@score_map(iy *cX + ix-1))) && (@link_map(iy, ix, 0) || @link_map(iy-1, ix-1, 3))) : false; // left
|
||||
nyxp1 = (ix < cX-1) ? ((cyx == (@score_map(iy *cX + ix+1))) && (@link_map(iy, ix, 3) || @link_map(iy, ix+1, 0))) : false; // right
|
||||
nyp1x = (iy > cY-1) ? ((cyx == (@score_map((iy+1)*cX + ix))) && (@link_map(iy, ix, 7) || @link_map(iy+1, ix, 6))) : false; // down
|
||||
|
||||
nym1xm1 = (iy > 0 && ix > 0 ) ? ((cyx == (@score_map((iy-1)*cX + ix-1))) && (@link_map(iy, ix, 2) || @link_map(iy-1, ix-1, 4))) : false; // up-left
|
||||
nym1xp1 = (iy > 0 && ix < cX-1) ? ((cyx == (@score_map((iy-1)*cX + ix+1))) && (@link_map(iy, ix, 5) || @link_map(iy-1, ix+1, 1))) : false; // up-right
|
||||
nyp1xm1 = (iy < cY-1 && ix > 0 ) ? ((cyx == (@score_map((iy+1)*cX + ix-1))) && (@link_map(iy, ix, 1) || @link_map(iy+1, ix-1, 5))) : false; // down-left
|
||||
nyp1xp1 = (iy < cY-1 && ix < cX-1) ? ((cyx == (@score_map((iy+1)*cX + ix+1))) && (@link_map(iy, ix, 4) || @link_map(iy+1, ix+1, 2))) : false; // down-right
|
||||
}
|
||||
else {
|
||||
nym1x = (iy > 0) ? (cyx == (@score_map((iy-1)*cX + ix))) : false; // up
|
||||
nyxm1 = (ix > 0) ? (cyx == (@score_map(iy *cX + ix-1))) : false; // left
|
||||
nyxp1 = (ix < cX-1) ? (cyx == (@score_map(iy *cX + ix+1))) : false; // right
|
||||
nyp1x = (iy > cY-1) ? (cyx == (@score_map((iy+1)*cX + ix))) : false; // down
|
||||
|
||||
nym1xm1 = (iy > 0 && ix > 0 ) ? (cyx == (@score_map((iy-1)*cX + ix-1))) : false; // up-left
|
||||
nym1xp1 = (iy > 0 && ix < cX-1) ? (cyx == (@score_map((iy-1)*cX + ix+1))) : false; // up-right
|
||||
nyp1xm1 = (iy < cY-1 && ix > 0 ) ? (cyx == (@score_map((iy+1)*cX + ix-1))) : false; // down-left
|
||||
nyp1xp1 = (iy < cY-1 && ix < cX-1) ? (cyx == (@score_map((iy+1)*cX + ix+1))) : false; // down-right
|
||||
}
|
||||
|
||||
// Lowest label
|
||||
unsigned int label = lyx;
|
||||
|
||||
// Find lowest neighbouring label
|
||||
label = ((nym1x) && (lym1x < label)) ? lym1x : label;
|
||||
label = ((nyxm1) && (lyxm1 < label)) ? lyxm1 : label;
|
||||
label = ((nyxp1) && (lyxp1 < label)) ? lyxp1 : label;
|
||||
label = ((nyp1x) && (lyp1x < label)) ? lyp1x : label;
|
||||
|
||||
label = ((nym1xm1) && (lym1xm1 < label)) ? lym1xm1 : label;
|
||||
label = ((nym1xp1) && (lym1xp1 < label)) ? lym1xp1 : label;
|
||||
label = ((nyp1xm1) && (lyp1xm1 < label)) ? lyp1xm1 : label;
|
||||
label = ((nyp1xp1) && (lyp1xp1 < label)) ? lyp1xp1 : label;
|
||||
|
||||
// If labels are different, resolve them
|
||||
if(label < lyx) {
|
||||
// Update label
|
||||
// Nonatomic write may overwrite another label but on average seems to give faster results
|
||||
@g_labels(lyx) = label;
|
||||
|
||||
// Record the change
|
||||
@in2(0) = 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
''' + f'''
|
||||
dim3 block(32, 32);
|
||||
const int cX= {cX};
|
||||
const int cY= {cY};''' + '''
|
||||
dim3 grid(ceil(cX/(float)block.x), ceil(cY/(float)block.y));
|
||||
dim3 resolve_block(32, 32);
|
||||
dim3 resolve_grid(ceil(cX/(float)resolve_block.x), ceil(cY/(float)resolve_block.y));
|
||||
|
||||
// Initialise labels
|
||||
init_labels <<< grid, block >>>(@ARGS, cX, cY);
|
||||
|
||||
// Resolve the labels
|
||||
resolve_labels <<< resolve_grid, resolve_block >>>(@ARGS, cX, cY);
|
||||
|
||||
// Changed Flag
|
||||
int32 changed = 1;
|
||||
|
||||
// While labels have changed
|
||||
while(changed) {
|
||||
// Copy changed to device
|
||||
cudaMemsetAsync(in2_p, 0, 4);
|
||||
|
||||
// Label image
|
||||
label_equivalence <<< grid, block >>>(@ARGS, cX, cY);
|
||||
|
||||
// Copy changed flag to host
|
||||
cudaMemcpy(&changed, in2_p, sizeof(int32), cudaMemcpyDeviceToHost);
|
||||
|
||||
// Resolve the labels
|
||||
resolve_labels <<< resolve_grid, resolve_block >>>(@ARGS, cX, cY);
|
||||
}
|
||||
''')
|
||||
|
||||
result = result.reshape((cY, cX))
|
||||
|
||||
value, _, cnt = jt.unique(result, return_inverse=True, return_counts=True)
|
||||
value = (cnt > result_comp_area_thresh) * value
|
||||
value = value[value != 0]
|
||||
|
||||
map_result = jt.zeros((int(value.max().numpy()[0]) + 1), dtype=jt.uint32)
|
||||
map_result[value] = jt.index(value.shape)[0] + 1
|
||||
result = map_result[result]
|
||||
|
||||
return result
|
Loading…
Reference in New Issue