mirror of https://github.com/Jittor/Jittor
loss3d: chamfer & emd
This commit is contained in:
parent
ec62338d45
commit
52246efbcc
|
@ -0,0 +1,2 @@
|
|||
from .chamfer import chamfer_loss, ChamferLoss
|
||||
from .emd import earth_mover_distance, EarthMoverDistance
|
|
@ -0,0 +1,153 @@
|
|||
# Author: Zheng-Ning Liu
|
||||
#
|
||||
# This file implements chamfer loss on both CPU and GPU.
|
||||
# The implementation does no use extra NxM matrix to store distances, and thus
|
||||
# supports large point clouds.
|
||||
|
||||
import jittor as jt
|
||||
import jittor.nn as nn
|
||||
|
||||
# CPU source for the jt.code op used by chamfer_loss.
# in0 = pc1 with shape [B, N, 3], in1 = pc2 with shape [B, M, 3].
# For every point i of in0 it scans all M points of in1 and writes into
# @out(bs, i) the index of the closest in1 point under *squared* Euclidean
# distance (argmin only; the actual distance is recomputed in Python).
# The running minimum is seeded with the distance to point 0, so no
# N x M distance matrix is ever materialized — memory stays O(N).
cpu_src = '''
    for (int bs = 0; bs < in0_shape0; ++bs)
        for (int i = 0; i < in0_shape1; ++i) {
            float min_dis = (@in0(bs, i, 0) - @in1(bs, 0, 0)) * (@in0(bs, i, 0) - @in1(bs, 0, 0)) +
                            (@in0(bs, i, 1) - @in1(bs, 0, 1)) * (@in0(bs, i, 1) - @in1(bs, 0, 1)) +
                            (@in0(bs, i, 2) - @in1(bs, 0, 2)) * (@in0(bs, i, 2) - @in1(bs, 0, 2));
            @out(bs, i) = 0;
            for (int j = 1; j < in1_shape1; ++j) {
                float dis = (@in0(bs, i, 0) - @in1(bs, j, 0)) * (@in0(bs, i, 0) - @in1(bs, j, 0)) +
                            (@in0(bs, i, 1) - @in1(bs, j, 1)) * (@in0(bs, i, 1) - @in1(bs, j, 1)) +
                            (@in0(bs, i, 2) - @in1(bs, j, 2)) * (@in0(bs, i, 2) - @in1(bs, j, 2));
                if (dis < min_dis) {
                    min_dis = dis;
                    @out(bs, i) = j;
                }
            }
        }
'''
||||
|
||||
cuda_src = '''
|
||||
__global__ void chamfer_loss_min_idx_kernel(@ARGS_DEF) {
|
||||
@PRECALC
|
||||
int bs = blockIdx.x;
|
||||
int n = in0_shape1;
|
||||
int m = in1_shape1;
|
||||
|
||||
for (int i = threadIdx.x; i < n; i += blockDim.x) {
|
||||
float min_dis = (@in0(bs, i, 0) - @in1(bs, 0, 0)) * (@in0(bs, i, 0) - @in1(bs, 0, 0)) +
|
||||
(@in0(bs, i, 1) - @in1(bs, 0, 1)) * (@in0(bs, i, 1) - @in1(bs, 0, 1)) +
|
||||
(@in0(bs, i, 2) - @in1(bs, 0, 2)) * (@in0(bs, i, 2) - @in1(bs, 0, 2));
|
||||
@out(bs, i) = 0;
|
||||
for (int j = 1; j < m; ++j) {
|
||||
float dis = (@in0(bs, i, 0) - @in1(bs, j, 0)) * (@in0(bs, i, 0) - @in1(bs, j, 0)) +
|
||||
(@in0(bs, i, 1) - @in1(bs, j, 1)) * (@in0(bs, i, 1) - @in1(bs, j, 1)) +
|
||||
(@in0(bs, i, 2) - @in1(bs, j, 2)) * (@in0(bs, i, 2) - @in1(bs, j, 2));
|
||||
if (dis < min_dis) {
|
||||
min_dis = dis;
|
||||
@out(bs, i) = j;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
chamfer_loss_min_idx_kernel<<<in0_shape0, 512>>>(@ARGS);
|
||||
'''
|
||||
|
||||
|
||||
def chamfer_loss(pc1, pc2, reduction='mean', dims='BNC', bidirectional=False):
    ''' return the chamfer loss from pc1 to pc2.

    :param pc1: input point cloud
    :type pc1: jittor array

    :param pc2: input point cloud
    :type pc2: jittor array

    :param reduction: reduction method in batches, can be 'mean', 'sum', or None. Default: 'mean'.
    :type reduction: str, optional

    :param dims: a string that represents each dimension, can be
        '[BNC]' ([batch, number of points, xyz]), or
        '[BCN]' ([batch, xyz, number of points]). Default: 'BNC'.
    :type dims: str, optional

    :param bidirectional: if True, return the sum of the losses of both
        directions (pc1 -> pc2 plus pc2 -> pc1). Default: False.
    :type bidirectional: bool, optional

    :raises ValueError: if reduction is not 'mean', 'sum', or None.

    Example:

    >>> import jittor as jt
    >>> from jittor.loss3d import chamfer_loss
    >>> jt.flags.use_cuda = True
    >>> pc1 = jt.rand([10, 100, 3], dtype=jt.float32)
    >>> pc2 = jt.rand([10, 100, 3], dtype=jt.float32)
    >>> cf = chamfer_loss(pc1, pc2, dims='BNC', bidirectional=True)
    >>> print('chamfer loss =', cf.item())
    '''
    if bidirectional:
        # each directed term recurses with bidirectional left at False
        return chamfer_loss(pc1, pc2, reduction, dims) + chamfer_loss(pc2, pc1, reduction, dims)

    assert dims in ['BNC', 'BCN']
    if dims == 'BCN':
        # normalize the layout to [batch, n_points, xyz]
        pc1, pc2 = pc1.permute(0, 2, 1), pc2.permute(0, 2, 1)

    batch_size_1, N, _ = pc1.shape
    batch_size_2, M, _ = pc2.shape
    assert batch_size_1 == batch_size_2
    batch_size = batch_size_1

    # idx[b, i] = index in pc2 of the nearest neighbour of pc1[b, i]
    # (computed by the memory-light kernels above, no NxM matrix)
    idx = jt.code([batch_size, N], 'int32', [pc1, pc2],
                  cpu_src=cpu_src,
                  cuda_src=cuda_src)

    # gather the nearest pc2 point for every pc1 point
    nearest_pts = pc2.reindex([batch_size, idx.shape[1], 3], [
        'i0',
        '@e0(i0, i1)',
        'i2'
    ], extras=[idx])

    # Euclidean (not squared) distance to the nearest neighbour
    chamfer_distance = (((pc1 - nearest_pts) ** 2).sum(dim=-1)).sqrt()
    if reduction is None:
        return chamfer_distance
    elif reduction == 'sum':
        return jt.sum(chamfer_distance)
    elif reduction == 'mean':
        return jt.mean(chamfer_distance)
    else:
        # previously an unknown reduction fell through and returned None
        raise ValueError(
            "reduction must be 'mean', 'sum', or None, got {!r}".format(reduction))
||||
|
||||
|
||||
class ChamferLoss(nn.Module):
    ''' Module wrapper around :func:`chamfer_loss`.

    Stores the reduction / layout / direction configuration once and applies
    :func:`chamfer_loss` to every pair of point clouds passed to it.

    :param reduction: reduction method in batches, can be 'mean', 'sum', or None. Default: 'mean'.
    :type reduction: str, optional

    :param dims: a string that represents each dimension, can be
        '[BNC]' ([batch, number of points, xyz]), or
        '[BCN]' ([batch, xyz, number of points]). Default: 'BNC'.
    :type dims: str, optional

    :param bidirectional: whether to sum the losses of both directions. Default: False.
    :type bidirectional: bool, optional

    Example:

    >>> import jittor as jt
    >>> from jittor.loss3d import ChamferLoss
    >>> jt.flags.use_cuda = True
    >>> pc1 = jt.rand([10, 100, 3], dtype=jt.float32)
    >>> pc2 = jt.rand([10, 100, 3], dtype=jt.float32)
    >>> CF = ChamferLoss(dims='BNC', bidirectional=True)
    >>> cf = CF(pc1, pc2)
    >>> print('chamfer loss =', cf.item())
    '''

    def __init__(self, reduction='mean', dims='BNC', bidirectional=False):
        '''Record the configuration; see :func:`chamfer_loss` for details.'''
        super().__init__()
        self.reduction, self.dims, self.bidirectional = reduction, dims, bidirectional

    def execute(self, pc1, pc2):
        '''Return the configured chamfer loss between pc1 and pc2.'''
        config = (self.reduction, self.dims, self.bidirectional)
        return chamfer_loss(pc1, pc2, *config)
|
|
@ -0,0 +1,440 @@
|
|||
# Author: Zheng-Ning Liu
|
||||
#
|
||||
# The gpu implementation is original provided by Haoqiang Fan and Kaichun Mo,
|
||||
# <https://github.com/daerduoCarey/PyTorchEMD>.
|
||||
|
||||
import jittor as jt
|
||||
from jittor import Function
|
||||
|
||||
EMD_gpu_header = '''
|
||||
namespace jittor {
|
||||
__device__ inline out_type dist2(out_type x1, out_type y1, out_type z1,
|
||||
out_type x2, out_type y2, out_type z2) {
|
||||
return (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1);
|
||||
}
|
||||
}
|
||||
'''
|
||||
|
||||
approxmatch_gpu_src = '''
|
||||
__global__ void approxmatch_gpu_kernel(@ARGS_DEF) {
|
||||
@PRECALC
|
||||
@alias(xyz1, in0)
|
||||
@alias(xyz2, in1)
|
||||
@alias(match, out)
|
||||
|
||||
int b = in0_shape0;
|
||||
int n = in0_shape1;
|
||||
int m = in1_shape1;
|
||||
|
||||
out_type *remainL = in2_p + blockIdx.x * (n + m) * 2;
|
||||
out_type *remainR = remainL + n;
|
||||
out_type *ratioL = remainR + m;
|
||||
out_type *ratioR = ratioL + n;
|
||||
|
||||
const int Block = 1024;
|
||||
__shared__ out_type buf[Block * 4];
|
||||
|
||||
for (int i = blockIdx.x; i < b; i += gridDim.x) {
|
||||
for (int j = threadIdx.x; j < n * m; j += blockDim.x)
|
||||
match_p[i * n * m + j] = 0;
|
||||
for (int j = threadIdx.x; j < n; j += blockDim.x)
|
||||
remainL[j] = n >= m ? 1 : m / n;
|
||||
for (int j = threadIdx.x; j < m; j += blockDim.x)
|
||||
remainR[j] = n >= m ? n / m : 1;
|
||||
__syncthreads();
|
||||
|
||||
for (int j = 7; j >= -2; j--) {
|
||||
out_type level = j > -2 ? -powf(4.0f, j) : 0;
|
||||
|
||||
for (int k0 = 0; k0 < n; k0 += blockDim.x) {
|
||||
int k = k0 + threadIdx.x;
|
||||
out_type x1 = 0, y1 = 0, z1 = 0;
|
||||
if (k < n) {
|
||||
x1 = @xyz1(i, k, 0);
|
||||
y1 = @xyz1(i, k, 1);
|
||||
z1 = @xyz1(i, k, 2);
|
||||
}
|
||||
|
||||
out_type suml = 1e-9f;
|
||||
for (int l0 = 0; l0 < m; l0 += Block){
|
||||
int lend = min(m, l0 + Block) - l0;
|
||||
for (int l = threadIdx.x; l < lend; l += blockDim.x) {
|
||||
buf[l * 4 + 0] = @xyz2(i, l0 + l, 0);
|
||||
buf[l * 4 + 1] = @xyz2(i, l0 + l, 1);
|
||||
buf[l * 4 + 2] = @xyz2(i, l0 + l, 2);
|
||||
buf[l * 4 + 3] = remainR[l0 + l];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int l = 0; l < lend; l++){
|
||||
out_type x2 = buf[l * 4 + 0];
|
||||
out_type y2 = buf[l * 4 + 1];
|
||||
out_type z2 = buf[l * 4 + 2];
|
||||
out_type d = level * dist2(x1, y1, z1, x2, y2, z2);
|
||||
out_type w = __expf(d) * buf[l * 4 + 3];
|
||||
suml += w;
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
if (k < n)
|
||||
ratioL[k] = remainL[k] / suml;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int l0 = 0; l0 < m; l0 += blockDim.x){
|
||||
int l = l0 + threadIdx.x;
|
||||
out_type x2 = 0, y2 = 0, z2 = 0;
|
||||
if (l < m){
|
||||
x2 = @xyz2(i, l, 0);
|
||||
y2 = @xyz2(i, l, 1);
|
||||
z2 = @xyz2(i, l, 2);
|
||||
}
|
||||
out_type sumr = 0;
|
||||
for (int k0 = 0; k0 < n; k0 += Block){
|
||||
int kend = min(n, k0 + Block) - k0;
|
||||
for (int k = threadIdx.x; k < kend; k += blockDim.x){
|
||||
buf[k * 4 + 0] = @xyz1(i, k0 + k, 0);
|
||||
buf[k * 4 + 1] = @xyz1(i, k0 + k, 1);
|
||||
buf[k * 4 + 2] = @xyz1(i, k0 + k, 2);
|
||||
buf[k * 4 + 3] = ratioL[k0 + k];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int k = 0; k < kend; k++){
|
||||
out_type x1 = buf[k * 4 + 0];
|
||||
out_type y1 = buf[k * 4 + 1];
|
||||
out_type z1 = buf[k * 4 + 2];
|
||||
out_type d = level * dist2(x1, y1, z1, x2, y2, z2);
|
||||
out_type w = __expf(d) * buf[k * 4 + 3];
|
||||
sumr += w;
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
if (l < m){
|
||||
sumr *= remainR[l];
|
||||
out_type consumption = fminf(remainR[l] / (sumr + 1e-9f), 1.0f);
|
||||
ratioR[l] = consumption * remainR[l];
|
||||
remainR[l] = fmaxf(0.0f, remainR[l] - sumr);
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int k0 = 0; k0 < n; k0 += blockDim.x){
|
||||
int k = k0 + threadIdx.x;
|
||||
out_type x1 = 0, y1 = 0, z1 = 0;
|
||||
if (k < n){
|
||||
x1 = @xyz1(i, k, 0);
|
||||
y1 = @xyz1(i, k, 1);
|
||||
z1 = @xyz1(i, k, 2);
|
||||
}
|
||||
out_type suml = 0;
|
||||
for (int l0 = 0; l0 < m; l0 += Block){
|
||||
int lend = min(m, l0 + Block)-l0;
|
||||
for (int l = threadIdx.x; l < lend; l += blockDim.x){
|
||||
buf[l * 4 + 0] = @xyz2(i, l0 + l, 0);
|
||||
buf[l * 4 + 1] = @xyz2(i, l0 + l, 1);
|
||||
buf[l * 4 + 2] = @xyz2(i, l0 + l, 2);
|
||||
buf[l * 4 + 3] = ratioR[l0 + l];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
out_type rl = ratioL[k];
|
||||
if (k < n){
|
||||
for (int l = 0; l < lend; l++){
|
||||
out_type x2 = buf[l * 4 + 0];
|
||||
out_type y2 = buf[l * 4 + 1];
|
||||
out_type z2 = buf[l * 4 + 2];
|
||||
out_type d = level * dist2(x1, y1, z1, x2, y2, z2);
|
||||
out_type w = __expf(d) * rl * buf[l*4+3];
|
||||
@match(i, l0 + l, k) += w;
|
||||
suml += w;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
if (k < n)
|
||||
remainL[k] = fmaxf(0.0f, remainL[k] - suml);
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
approxmatch_gpu_kernel<<<32, 512>>>(@ARGS);
|
||||
'''
|
||||
|
||||
matchcost_gpu_src = '''
|
||||
__global__ void matchcost_gpu_kernel(@ARGS_DEF) {
|
||||
@PRECALC
|
||||
@alias(xyz1, in0)
|
||||
@alias(xyz2, in1)
|
||||
@alias(match, in2)
|
||||
|
||||
int b = in0_shape0;
|
||||
int n = in0_shape1;
|
||||
int m = in1_shape1;
|
||||
|
||||
const int Block = 1024;
|
||||
__shared__ out_type allsum[512];
|
||||
__shared__ out_type buf[Block * 3];
|
||||
|
||||
for (int i = blockIdx.x; i < b; i += gridDim.x) {
|
||||
out_type subsum = 0;
|
||||
for (int k0 = 0; k0 < n; k0 += blockDim.x) {
|
||||
int k = k0 + threadIdx.x;
|
||||
out_type x1 = 0, y1 = 0, z1 = 0;
|
||||
if (k < n) {
|
||||
x1 = @xyz1(i, k, 0);
|
||||
y1 = @xyz1(i, k, 1);
|
||||
z1 = @xyz1(i, k, 2);
|
||||
}
|
||||
|
||||
for (int l0 = 0; l0 < m; l0 += Block) {
|
||||
int lend = min(m, l0 + Block) - l0;
|
||||
for (int l = threadIdx.x; l < lend * 3; l += blockDim.x)
|
||||
buf[l] = xyz2_p[i * m * 3 + l0 * 3 + l];
|
||||
__syncthreads();
|
||||
|
||||
if (k < n) {
|
||||
for (int l = 0; l < lend; l++) {
|
||||
out_type x2 = buf[l * 3 + 0];
|
||||
out_type y2 = buf[l * 3 + 1];
|
||||
out_type z2 = buf[l * 3 + 2];
|
||||
out_type d = dist2(x1, y1, z1, x2, y2, z2);
|
||||
subsum += d * @match(i, l0 + l, k);
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
allsum[threadIdx.x] = subsum;
|
||||
for (int j = 1; j < blockDim.x; j <<= 1) {
|
||||
__syncthreads();
|
||||
if ((threadIdx.x & j) == 0 && threadIdx.x + j < blockDim.x) {
|
||||
allsum[threadIdx.x] += allsum[threadIdx.x + j];
|
||||
}
|
||||
}
|
||||
|
||||
if (threadIdx.x == 0)
|
||||
@out(i) = allsum[0];
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
matchcost_gpu_kernel<<<32, 512>>>(@ARGS);
|
||||
'''
|
||||
|
||||
matchcost_grad1_gpu_src = '''
|
||||
__global__ void matchcost_grad1_gpu_kernel(@ARGS_DEF) {
|
||||
@PRECALC
|
||||
@alias(grad, in0)
|
||||
@alias(xyz1, in1)
|
||||
@alias(xyz2, in2)
|
||||
@alias(match, in3)
|
||||
|
||||
int b = grad_shape0;
|
||||
int n = xyz1_shape1;
|
||||
int m = xyz2_shape1;
|
||||
|
||||
for (int i = blockIdx.x; i < b ; i += gridDim.x){
|
||||
for (int l = threadIdx.x; l < n; l += blockDim.x){
|
||||
out_type x1 = @xyz1(i, l, 0);
|
||||
out_type y1 = @xyz1(i, l, 1);
|
||||
out_type z1 = @xyz1(i, l, 2);
|
||||
out_type dx = 0, dy = 0, dz = 0;
|
||||
for (int k = 0; k < m; k++){
|
||||
out_type x2 = @xyz2(i, k, 0);
|
||||
out_type y2 = @xyz2(i, k, 1);
|
||||
out_type z2 = @xyz2(i, k, 2);
|
||||
out_type d = @match(i, k, l) * 2;
|
||||
dx += (x1 - x2) * d;
|
||||
dy += (y1 - y2) * d;
|
||||
dz += (z1 - z2) * d;
|
||||
}
|
||||
@out(i, l, 0) = dx * @grad(i);
|
||||
@out(i, l, 1) = dy * @grad(i);
|
||||
@out(i, l, 2) = dz * @grad(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
matchcost_grad1_gpu_kernel<<<32, 512>>>(@ARGS);
|
||||
'''
|
||||
|
||||
matchcost_grad2_gpu_src = '''
|
||||
__global__ void matchcost_grad2_gpu_kernel(@ARGS_DEF) {
|
||||
@PRECALC
|
||||
@alias(grad, in0)
|
||||
@alias(xyz1, in1)
|
||||
@alias(xyz2, in2)
|
||||
@alias(match, in3)
|
||||
|
||||
int b = grad_shape0;
|
||||
int n = xyz1_shape1;
|
||||
int m = xyz2_shape1;
|
||||
|
||||
__shared__ out_type sum_grad[256 * 3];
|
||||
for (int i = blockIdx.x; i < b; i += gridDim.x) {
|
||||
int kbeg = m * blockIdx.y / gridDim.y;
|
||||
int kend = m * (blockIdx.y + 1) / gridDim.y;
|
||||
for (int k = kbeg; k < kend; k++) {
|
||||
out_type x2 = @xyz2(i, k, 0);
|
||||
out_type y2 = @xyz2(i, k, 1);
|
||||
out_type z2 = @xyz2(i, k, 2);
|
||||
out_type subsumx = 0, subsumy = 0, subsumz = 0;
|
||||
for (int j = threadIdx.x; j < n; j += blockDim.x) {
|
||||
out_type x1 = x2 - @xyz1(i, j, 0);
|
||||
out_type y1 = y2 - @xyz1(i, j, 1);
|
||||
out_type z1 = z2 - @xyz1(i, j, 2);
|
||||
out_type d = @match(i, k, j) * 2;
|
||||
subsumx += x1 * d;
|
||||
subsumy += y1 * d;
|
||||
subsumz += z1 * d;
|
||||
}
|
||||
sum_grad[threadIdx.x * 3 + 0] = subsumx;
|
||||
sum_grad[threadIdx.x * 3 + 1] = subsumy;
|
||||
sum_grad[threadIdx.x * 3 + 2] = subsumz;
|
||||
|
||||
for (int j = 1; j < blockDim.x; j <<= 1) {
|
||||
__syncthreads();
|
||||
int j1 = threadIdx.x;
|
||||
int j2 = threadIdx.x + j;
|
||||
if ((j1 & j) == 0 && j2 < blockDim.x){
|
||||
sum_grad[j1 * 3 + 0] += sum_grad[j2 * 3 + 0];
|
||||
sum_grad[j1 * 3 + 1] += sum_grad[j2 * 3 + 1];
|
||||
sum_grad[j1 * 3 + 2] += sum_grad[j2 * 3 + 2];
|
||||
}
|
||||
}
|
||||
if (threadIdx.x == 0){
|
||||
@out(i, k, 0) = sum_grad[0] * @grad(i);
|
||||
@out(i, k, 1) = sum_grad[1] * @grad(i);
|
||||
@out(i, k, 2) = sum_grad[2] * @grad(i);
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
matchcost_grad2_gpu_kernel<<<dim3(32, 32), 256>>>(@ARGS);
|
||||
'''
|
||||
|
||||
class EarthMoverDistance(Function):
    ''' A loss layer that computes Earth Mover's distance from pc1 to pc2. Only supports GPU.

    :param pc1: input point cloud
    :type pc1: jittor array

    :param pc2: input point cloud
    :type pc2: jittor array

    :param reduction: reduction method in batches, can be 'mean', 'sum', or None. Default: 'mean'.
    :type reduction: str, optional

    :param dims: a string that represents each dimension, can be
        '[BNC]' ([batch, number of points, xyz]), or
        '[BCN]' ([batch, xyz, number of points]). Default: 'BNC'.
    :type dims: str, optional

    :raises ValueError: if reduction is not 'mean', 'sum', or None.

    Example:

    >>> import jittor as jt
    >>> from jittor.loss3d import EarthMoverDistance
    >>> jt.flags.use_cuda = True
    >>> pc1 = jt.rand([10, 100, 3], dtype=jt.float32)
    >>> pc2 = jt.rand([10, 100, 3], dtype=jt.float32)
    >>> EMD = EarthMoverDistance(dims='BNC')
    >>> emd = EMD(pc1, pc2)
    >>> print('EMD =', emd.item())
    '''

    def execute(self, pc1, pc2, reduction='mean', dims='BNC'):
        '''Forward pass: approximate matching + match cost, then reduction.'''
        assert dims in ['BNC', 'BCN']
        if dims == 'BCN':
            # normalize the layout to [batch, n_points, xyz]
            pc1, pc2 = pc1.permute(0, 2, 1), pc2.permute(0, 2, 1)

        batch_size_1, N, _ = pc1.shape
        batch_size_2, M, _ = pc2.shape
        assert batch_size_1 == batch_size_2
        batch_size = batch_size_1

        # scratch workspace consumed by the approxmatch kernel: per batch
        # element it holds [remainL (N) | remainR (M) | ratioL (N) | ratioR (M)]
        temp = jt.zeros([batch_size, (N + M) * 2], pc1.dtype)
        # soft assignment matrix, shape [batch, M, N]
        match = jt.code(
            shape=[batch_size, M, N],
            dtype=pc1.dtype,
            inputs=[pc1, pc2, temp],
            cuda_header=EMD_gpu_header,
            cuda_src=approxmatch_gpu_src,
        )

        # per-batch transport cost: sum of dist2 * match
        emd = jt.code(
            shape=[batch_size],
            dtype=pc1.dtype,
            inputs=[pc1, pc2, match],
            cuda_header=EMD_gpu_header,
            cuda_src=matchcost_gpu_src,
        )

        # keep what grad() needs (reduction decides the gradient scaling)
        self.saved_vars = (pc1, pc2, match, reduction)

        if reduction is None:
            return emd
        elif reduction == 'sum':
            return emd.sum()
        elif reduction == 'mean':
            return emd.mean()
        else:
            # previously an unknown reduction fell through and returned None
            raise ValueError(
                "reduction must be 'mean', 'sum', or None, got {!r}".format(reduction))

    def grad(self, grad):
        '''Backward pass: gradients of the match cost w.r.t. both clouds.'''
        pc1, pc2, match, reduction = self.saved_vars

        # map the (possibly scalar) incoming gradient back to a per-batch
        # gradient, mirroring the reduction applied in execute()
        if reduction == 'sum':
            grad = jt.ones([pc1.shape[0]]) * grad
        elif reduction == 'mean':
            grad = jt.ones([pc1.shape[0]]) * grad / pc1.shape[0]

        grad_pc1 = jt.code(
            shape=pc1.shape,
            dtype=pc1.dtype,
            inputs=[grad, pc1, pc2, match],
            cuda_src=matchcost_grad1_gpu_src,
        )

        grad_pc2 = jt.code(
            shape=pc2.shape,
            dtype=pc2.dtype,
            inputs=[grad, pc1, pc2, match],
            cuda_src=matchcost_grad2_gpu_src,
        )

        return grad_pc1, grad_pc2
|
||||
|
||||
|
||||
def earth_mover_distance(pc1, pc2, reduction='mean', dims='BNC'):
    ''' Earth Mover's distance from pc1 to pc2 (functional form). Only supports GPU.

    Thin wrapper that forwards to :class:`EarthMoverDistance`.

    :param pc1: input point cloud
    :type pc1: jittor array

    :param pc2: input point cloud
    :type pc2: jittor array

    :param reduction: reduction method in batches, can be 'mean', 'sum', or None. Default: 'mean'.
    :type reduction: str, optional

    :param dims: a string that represents each dimension, can be
        '[BNC]' ([batch, number of points, xyz]), or
        '[BCN]' ([batch, xyz, number of points]). Default: 'BNC'.
    :type dims: str, optional

    Example:

    >>> import jittor as jt
    >>> from jittor.loss3d import earth_mover_distance
    >>> jt.flags.use_cuda = True
    >>> pc1 = jt.rand([10, 100, 3], dtype=jt.float32)
    >>> pc2 = jt.rand([10, 100, 3], dtype=jt.float32)
    >>> emd = earth_mover_distance(pc1, pc2, dims='BNC')
    >>> print('EMD =', emd.item())
    '''
    result = EarthMoverDistance.apply(pc1, pc2, reduction, dims)
    return result
|
|
@ -0,0 +1,88 @@
|
|||
import unittest
import numpy as np

# The PyTorch cross-check needs both torch and the reference CUDA EMD
# implementation (https://github.com/daerduoCarey/PyTorchEMD).
# Initialize the flag to False unconditionally: the original assigned it only
# inside the except branch, so a *successful* import left `skip_this_test`
# undefined and the @unittest.skipIf decorator raised NameError.
skip_this_test = False
try:
    import torch
    from emd import earth_mover_distance as TEMD
except Exception:
    # best-effort optional dependency; any import failure just skips the test
    skip_this_test = True

import jittor as jt
from jittor.loss3d import chamfer_loss
from jittor.loss3d import earth_mover_distance
||||
class TestLoss3d(unittest.TestCase):
    '''Checks chamfer_loss against a brute-force numpy reference (CPU and
    CUDA backends) and the EMD op against the PyTorchEMD reference, when
    that optional dependency is installed.'''

    def test_chamfer(self):
        # chamfer_loss with dims='BNC' must match a dense numpy reference.
        def test():
            pc1 = np.random.randn(10, 100, 3).astype(np.float32)
            pc2 = np.random.randn(10, 100, 3).astype(np.float32)

            Jpc1 = jt.array(pc1)
            Jpc2 = jt.array(pc2)
            Jcf = chamfer_loss(Jpc1, Jpc2, dims='BNC')

            # brute-force reference: full [B, N, M] pairwise distance matrix,
            # nearest neighbour along the pc2 axis, then mean over the batch
            ppc1 = np.repeat(pc1[:, :, None, :], 100, axis=2)
            ppc2 = np.repeat(pc2[:, None, :, :], 100, axis=1)
            ncf = np.sqrt(((ppc1 - ppc2) ** 2).sum(axis=-1)).min(axis=-1)
            ncf = ncf.mean()

            self.assertTrue(np.allclose(ncf, Jcf.item()))

        # run on both backends
        jt.flags.use_cuda = False
        test()
        jt.flags.use_cuda = True
        test()

    def test_chamfer_dims(self):
        # same check, but inputs transposed to the 'BCN' layout
        def test():
            pc1 = np.random.randn(10, 100, 3).astype(np.float32)
            pc2 = np.random.randn(10, 100, 3).astype(np.float32)

            Jpc1 = jt.array(pc1.transpose([0, 2, 1]))
            Jpc2 = jt.array(pc2.transpose([0, 2, 1]))
            Jcf = chamfer_loss(Jpc1, Jpc2, dims='BCN')

            # reference is computed on the untransposed BNC arrays
            ppc1 = np.repeat(pc1[:, :, None, :], 100, axis=2)
            ppc2 = np.repeat(pc2[:, None, :, :], 100, axis=1)
            ncf = np.sqrt(((ppc1 - ppc2) ** 2).sum(axis=-1)).min(axis=-1)
            ncf = ncf.mean()

            self.assertTrue(np.allclose(ncf, Jcf.item()))

        jt.flags.use_cuda = False
        test()
        jt.flags.use_cuda = True
        test()

    # message typo fixed: "Pyorch" -> "PyTorch"
    @unittest.skipIf(skip_this_test, "No PyTorch_EMD found")
    def test_emd_torch(self):
        # forward value and both gradients must agree with the PyTorchEMD
        # reference implementation (different point counts: 100 vs 50)
        jt.flags.use_cuda = True

        pc1 = np.random.randn(10, 100, 3).astype(np.float32)
        pc2 = np.random.randn(10, 50, 3).astype(np.float32)

        Tpc1 = torch.from_numpy(pc1).cuda()
        Tpc2 = torch.from_numpy(pc2).cuda()
        Tpc1.requires_grad = True
        Tpc2.requires_grad = True
        Temdcost = TEMD(Tpc1, Tpc2, transpose=False)
        Temd = Temdcost.mean()

        Jpc1 = jt.array(pc1)
        Jpc2 = jt.array(pc2)
        Jemd = earth_mover_distance(Jpc1, Jpc2, dims='BNC')

        Temd.backward()
        Tgrad1 = Tpc1.grad.cpu().numpy()
        Tgrad2 = Tpc2.grad.cpu().numpy()

        Jgrad1, Jgrad2 = jt.grad(Jemd, [Jpc1, Jpc2])

        self.assertTrue(np.allclose(Temd.item(), Jemd.item()), Temd.item() - Jemd.item())
        self.assertTrue(np.allclose(Tgrad1, Jgrad1.data, atol=1e-4), np.abs(Tgrad1 - Jgrad1.data).max())
        self.assertTrue(np.allclose(Tgrad2, Jgrad2.data, atol=1e-4), np.abs(Tgrad2 - Jgrad2.data).max())
|
||||
|
||||
|
||||
# allow running this test module directly (python <file>) as well as via pytest
if __name__ == '__main__':
    unittest.main()
|
Loading…
Reference in New Issue