add ctcloss

Dun Liang 2021-09-08 17:45:17 +08:00
parent 057dd95658
commit f807a28e6b
4 changed files with 431 additions and 2 deletions

View File

@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.2.3.99'
__version__ = '1.2.3.100'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int

View File

@ -1311,3 +1311,348 @@ jt.Var.roll = roll
def safe_log(x):
return jt.safe_clip(x, 1e-30, 1e30).log()
jt.Var.safe_log = safe_log
class _CTCLossFunction(jt.Function):
def execute(self, log_probs, targets, input_lengths, target_lengths, blank=0, zero_infinity=False):
self.blank = blank
T, N, C = log_probs.shape
_N, S = targets.shape
assert _N == N
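# Forward (alpha) recursion over the extended target sequence
# [blank, y_1, blank, y_2, ..., y_S, blank] of length 2S+1:
# log_alpha[t, n, k] is the log-probability of all alignments of the
# first t+1 frames of sample n that end at extended-target position k.
# Entries start at -1e30, which stands in for log(0).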
log_alpha = jt.full([T,N,S*2+1], -1e30)
result = jt.empty((N,))
jt.code([log_probs, targets, input_lengths, target_lengths], [log_alpha, result], cpu_src=f"""
constexpr int blank = {blank};
for (int i=0; i<in0_shape1; i++) {{
int input_len = @in2(i);
int target_len = @in3(i);
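// t = 0: a valid alignment can only start with the leading blank
// or with the first target symbol.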
@out0(0,i,0) = @in0(0,i,blank);
if (target_len)
@out0(0,i,1) = @in0(0,i,@in1(i,0));
for (int j=1; j<input_len; j++)
for (int k=0; k<target_len*2+1; k++) {{
int target = k%2 ? @in1(i,k/2) : blank;
int target_2 = target;
if (k>1 && k%2) target_2 = @in1(i,k/2-1);
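// Three allowed transitions into position k: stay at k (l1),
// advance from k-1 (l2), or skip from k-2 (l3, only when the
// skipped label differs from the current one).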
out_type l1 = @out0(j-1,i,k);
out_type l2 = -1e30;
if (k>0) l2 = @out0(j-1,i,k-1);
out_type l3 = -1e30;
if (k>1 && target_2 != target)
l3 = @out0(j-1,i,k-2);
out_type m = std::max(l1, std::max(l2, l3));
@out0(j,i,k) = std::log(
std::exp(l1-m) +
std::exp(l2-m) +
std::exp(l3-m)
) + m + @in0(j,i,target);
}}
if (input_len==0)
@out1(i) = @out0(0,i,0);
else {{
out_type l1 = @out0(input_len-1, i, target_len*2);
out_type l2 = -1e30;
if (target_len)
l2 = @out0(input_len-1, i, target_len*2-1);
out_type m = std::max(l1, l2);
out_type log_likelihood = std::log(std::exp(l1-m)+std::exp(l2-m))+m;
@out1(i) = -log_likelihood;
}}
}}
""", cuda_src=f"""
__global__ void kernel(@ARGS_DEF) {{
@PRECALC;
constexpr int blank = {blank};
for (int i=blockIdx.x; i<in0_shape1; i+=gridDim.x) {{
int input_len = @in2(i);
int target_len = @in3(i);
@out0(0,i,0) = @in0(0,i,blank);
if (target_len)
@out0(0,i,1) = @in0(0,i,@in1(i,0));
for (int j=1; j<input_len; j++)
for (int k=threadIdx.x; k-threadIdx.x<target_len*2+1; k+=blockDim.x) {{
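// The loop bound depends on k-threadIdx.x, which is identical for every
// thread in the block, so all threads run the same number of iterations
// and the __syncthreads() below is reached uniformly; threads whose k is
// out of range skip the body after the barrier.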
__syncthreads();
if (k>=target_len*2+1)
continue;
int target = k%2 ? @in1(i,k/2) : blank;
int target_2 = target;
if (k>1 && k%2) target_2 = @in1(i,k/2-1);
out_type l1 = @out0(j-1,i,k);
out_type l2 = -1e30;
if (k>0) l2 = @out0(j-1,i,k-1);
out_type l3 = -1e30;
if (k>1 && target_2 != target)
l3 = @out0(j-1,i,k-2);
out_type m = ::max(l1, ::max(l2, l3));
@out0(j,i,k) = ::log(
::exp(l1-m) +
::exp(l2-m) +
::exp(l3-m)
) + m + @in0(j,i,target);
}}
__syncthreads();
if (input_len==0)
@out1(i) = @out0(0,i,0);
else {{
out_type l1 = @out0(input_len-1, i, target_len*2);
out_type l2 = -1e30;
if (target_len)
l2 = @out0(input_len-1, i, target_len*2-1);
out_type m = ::max(l1, l2);
out_type log_likelihood = ::log(::exp(l1-m)+::exp(l2-m))+m;
@out1(i) = -log_likelihood;
}}
}}
}}
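// Launch one block per batch element and one thread per extended-target
// position, each capped at 1024.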
kernel<<<std::min(in0_shape1, 1024), std::min(in1_shape1*2+1, 1024)>>>(@ARGS);
""")
self.saved_var = [log_probs, targets, input_lengths, target_lengths, log_alpha, result]
return result
def grad(self, dout):
blank = self.blank
inputs = self.saved_var + [dout]
dlog_probs = jt.zeros_like(inputs[0])
dlog_alpha = jt.zeros_like(inputs[4])
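# Backward pass: replay the forward recursion in reverse, accumulating
# d(loss)/d(log_alpha) into dlog_alpha (out1) and d(loss)/d(log_probs)
# into dlog_probs (out0).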
jt.code(inputs, [dlog_probs, dlog_alpha], cpu_src=f"""
constexpr int blank = {blank};
for (int i=0; i<in0_shape1; i++) {{
int input_len = @in2(i);
int target_len = @in3(i);
if (input_len==0)
// write out1 --> read in6
// out1(i) = out0(0,i,0);
@out1(0,i,0) = @in6(i);
else {{
out_type l1 = @in4(input_len-1, i, target_len*2);
out_type l2 = -1e30;
if (target_len)
l2 = @in4(input_len-1, i, target_len*2-1);
out_type m = std::max(l1, l2);
// out_type log_likelihood = std::log(std::exp(l1-m)+std::exp(l2-m))+m;
// out1(i) = -log_likelihood;
out_type l1_exp = std::exp(l1-m);
out_type l2_exp = std::exp(l2-m);
out_type sumexp = l1_exp + l2_exp;
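// loss(i) = -logsumexp(l1, l2), so the incoming gradient -@in6(i) is
// split between l1 and l2 with softmax weights exp(l-m)/sumexp.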
out_type dlog_likelihood = -@in6(i);
out_type dl1 = dlog_likelihood * l1_exp / sumexp;
out_type dl2 = dlog_likelihood * l2_exp / sumexp;
@out1(input_len-1, i, target_len*2) = dl1;
if (target_len)
@out1(input_len-1, i, target_len*2-1) = dl2;
}}
for (int j=input_len-1; j>0; j--)
for (int k=0; k<target_len*2+1; k++) {{
int target = k%2 ? @in1(i,k/2) : blank;
int target_2 = target;
if (k>1 && k%2) target_2 = @in1(i,k/2-1);
out_type l1 = @in4(j-1,i,k);
out_type l2 = -1e30;
if (k>0) l2 = @in4(j-1,i,k-1);
out_type l3 = -1e30;
if (k>1 && target_2 != target)
l3 = @in4(j-1,i,k-2);
out_type m = std::max(l1, std::max(l2, l3));
out_type l1_exp = std::exp(l1-m);
out_type l2_exp = std::exp(l2-m);
out_type l3_exp = std::exp(l3-m);
out_type sumexp = l1_exp + l2_exp + l3_exp;
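// dalpha flows back in full to log_probs at (j, i, target) and to the
// up-to-three predecessor alpha entries in proportion to their
// softmax shares.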
out_type dalpha = @out1(j,i,k);
@out0(j,i,target) += dalpha;
@out1(j-1,i,k) += dalpha * l1_exp / sumexp;
if (k>0)
@out1(j-1,i,k-1) += dalpha * l2_exp / sumexp;
if (k>1 && target_2 != target)
@out1(j-1,i,k-2) += dalpha * l3_exp / sumexp;
}}
// read in0 -> write out0
// write out0 ->read out1
// out0(0,i,0) = in0(0,i,blank);
@out0(0,i,blank) += @out1(0,i,0);
if (target_len)
@out0(0,i,@in1(i,0)) += @out1(0,i,1);
}}
""", cuda_src=f"""
__global__ void kernel(@ARGS_DEF) {{
@PRECALC;
constexpr int blank = {blank};
for (int i=blockIdx.x; i<in0_shape1; i+=gridDim.x) {{
int input_len = @in2(i);
int target_len = @in3(i);
if (input_len==0)
// write out1 --> read in6
// out1(i) = out0(0,i,0);
@out1(0,i,0) = @in6(i);
else {{
out_type l1 = @in4(input_len-1, i, target_len*2);
out_type l2 = -1e30;
if (target_len)
l2 = @in4(input_len-1, i, target_len*2-1);
out_type m = ::max(l1, l2);
// out_type log_likelihood = ::log(::exp(l1-m)+::exp(l2-m))+m;
// out1(i) = -log_likelihood;
out_type l1_exp = ::exp(l1-m);
out_type l2_exp = ::exp(l2-m);
out_type sumexp = l1_exp + l2_exp;
out_type dlog_likelihood = -@in6(i);
out_type dl1 = dlog_likelihood * l1_exp / sumexp;
out_type dl2 = dlog_likelihood * l2_exp / sumexp;
@out1(input_len-1, i, target_len*2) = dl1;
if (target_len)
@out1(input_len-1, i, target_len*2-1) = dl2;
}}
for (int j=input_len-1; j>0; j--)
for (int k=threadIdx.x; k-threadIdx.x<target_len*2+1; k+=blockDim.x) {{
__syncthreads();
if (k>=target_len*2+1)
continue;
int target = k%2 ? @in1(i,k/2) : blank;
int target_2 = target;
if (k>1 && k%2) target_2 = @in1(i,k/2-1);
out_type l1 = @in4(j-1,i,k);
out_type l2 = -1e30;
if (k>0) l2 = @in4(j-1,i,k-1);
out_type l3 = -1e30;
if (k>1 && target_2 != target)
l3 = @in4(j-1,i,k-2);
out_type m = ::max(l1, ::max(l2, l3));
out_type l1_exp = ::exp(l1-m);
out_type l2_exp = ::exp(l2-m);
out_type l3_exp = ::exp(l3-m);
out_type sumexp = l1_exp + l2_exp + l3_exp;
out_type dalpha = @out1(j,i,k);
atomicAdd(&@out0(j,i,target), dalpha);
atomicAdd(&@out1(j-1,i,k), dalpha * l1_exp / sumexp);
if (k>0)
atomicAdd(&@out1(j-1,i,k-1), dalpha * l2_exp / sumexp);
if (k>1 && target_2 != target)
atomicAdd(&@out1(j-1,i,k-2), dalpha * l3_exp / sumexp);
}}
// read in0 -> write out0
// write out0 ->read out1
// out0(0,i,0) = in0(0,i,blank);
__syncthreads();
if (threadIdx.x==0) {{
@out0(0,i,blank) += @out1(0,i,0);
if (target_len)
@out0(0,i,@in1(i,0)) += @out1(0,i,1);
}}
}}
}}
kernel<<<std::min(in0_shape1, 1024), std::min(in1_shape1*2+1, 1024)>>>(@ARGS);
""")
return (dlog_probs,)
def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0, reduction='mean', zero_infinity=False):
'''The Connectionist Temporal Classification loss.
Reference:
A. Graves et al.: Connectionist Temporal Classification:
Labelling Unsegmented Sequence Data with Recurrent Neural Networks:
https://www.cs.toronto.edu/~graves/icml_2006.pdf
Input:
log_probs: shape is [T, N, C]; T is the input sequence length, N is the batch size, C is the number of classes (including blank).
targets: shape is [N, S]; N is the batch size, S is the maximum target sequence length; elements should be in [0, C).
input_lengths: shape is [N]; the length of each input sequence; elements should be in [0, T].
target_lengths: shape is [N]; the length of each target sequence; elements should be in [0, S].
blank (int, default 0): blank label index
reduction (string): how to reduce the per-sample losses;
'none' returns an (N,) array,
'mean' or 'sum' returns a scalar
zero_infinity (bool, default False):
zero infinite losses and the associated gradients
Example:
import jittor as jt
T = 50 # Input sequence length
C = 20 # Number of classes (including blank)
N = 16 # Batch size
S = 30 # Target sequence length of longest target in batch (padding length)
S_min = 10 # Minimum target length, for demonstration purposes
input = jt.randn(T, N, C).log_softmax(2)
# Initialize random batch of targets (0 = blank, 1:C = classes)
target = jt.randint(low=1, high=C, shape=(N, S), dtype=jt.int)
input_lengths = jt.full((N,), T, dtype=jt.int)
target_lengths = jt.randint(low=S_min, high=S+1, shape=(N,), dtype=jt.int)
loss = jt.ctc_loss(input, target, input_lengths, target_lengths)
dinput = jt.grad(loss, input)
'''
result = _CTCLossFunction.apply(log_probs, targets, input_lengths, target_lengths, blank, zero_infinity)
if reduction=="mean":
return result.mean()
elif reduction=="sum":
return result.sum()
assert reduction=="none"
return result
class CTCLoss(jt.Module):
'''The Connectionist Temporal Classification loss.
Reference:
A. Graves et al.: Connectionist Temporal Classification:
Labelling Unsegmented Sequence Data with Recurrent Neural Networks:
https://www.cs.toronto.edu/~graves/icml_2006.pdf
Args:
blank (int, default 0): blank label index
reduction (string): how to reduce the per-sample losses;
'none' returns an (N,) array,
'mean' or 'sum' returns a scalar
zero_infinity (bool, default False):
zero infinite losses and the associated gradients
Input:
log_probs: shape is [T, N, C]; T is the input sequence length, N is the batch size, C is the number of classes (including blank).
targets: shape is [N, S]; N is the batch size, S is the maximum target sequence length; elements should be in [0, C).
input_lengths: shape is [N]; the length of each input sequence; elements should be in [0, T].
target_lengths: shape is [N]; the length of each target sequence; elements should be in [0, S].
Example:
import jittor as jt
T = 50 # Input sequence length
C = 20 # Number of classes (including blank)
N = 16 # Batch size
S = 30 # Target sequence length of longest target in batch (padding length)
S_min = 10 # Minimum target length, for demonstration purposes
input = jt.randn(T, N, C).log_softmax(2)
# Initialize random batch of targets (0 = blank, 1:C = classes)
target = jt.randint(low=1, high=C, shape=(N, S), dtype=jt.int)
input_lengths = jt.full((N,), T, dtype=jt.int)
target_lengths = jt.randint(low=S_min, high=S+1, shape=(N,), dtype=jt.int)
ctc_loss = jt.CTCLoss()
loss = ctc_loss(input, target, input_lengths, target_lengths)
dinput = jt.grad(loss, input)
'''
def __init__(self, blank=0, reduction='mean', zero_infinity=False):
self.blank = blank
self.reduction = reduction
self.zero_infinity = zero_infinity
def execute(self, log_probs, targets, input_lengths, target_lengths):
return ctc_loss(log_probs, targets, input_lengths, target_lengths, self.blank, self.reduction, self.zero_infinity)

View File

@ -141,7 +141,7 @@ Example::
an = a.shape[ai] if ai>=0 else 1
bn = b.shape[bi] if bi>=0 else 1
if an!=1 and bn!=1:
assert an == bn, f"dimension not match, a.shape:{a.shape}, b.shape:{a.shape}"
assert an == bn, f"dimension not match, a.shape:{a.shape}, b.shape:{b.shape}"
cn = max(an, bn)
shape.append(cn)
shape.extend([n, m, k])
@ -341,16 +341,20 @@ def softmax(x, dim = None):
x = (x-x.max(dim, keepdims=True)).exp()
ret = x / x.sum(dim, keepdims=True)
return ret
jt.Var.softmax = softmax
def log_softmax(x,dim=None):
x = softmax(x,dim=dim)
return jt.log(x)
jt.Var.log_softmax = log_softmax
def log_sigmoid(x):
return jt.log(jt.sigmoid(x))
jt.Var.log_sigmoid = log_sigmoid
def logsumexp(x, dim, keepdim=False):
return x.exp().sum(dim, keepdim).log()
jt.Var.logsumexp = logsumexp
class Identity(Module):
def __init__(self, *args, **kwargs):

View File

@ -159,6 +159,86 @@ class TestPad(unittest.TestCase):
out.detach().numpy(), output.data,
atol=1e-4)
def test_ctc_loss(self):
def check(T,C,N,S,S_min):
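# Compare jt.ctc_loss against torch.nn.CTCLoss on identical data;
# gradients are computed on both sides as a smoke test.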
jt.set_global_seed(0)
# Initialize a random batch of input vectors of size (T, N, C)
input = jt.randn(T, N, C).log_softmax(2)
# input = -jt.ones((T, N, C))
# input[0,0,1] += 0.01
# Initialize random batch of targets (0 = blank, 1:C = classes)
target = jt.randint(low=1, high=C, shape=(N, S), dtype=jt.int)
_input_jt = input
input_lengths = jt.full((N,), T, dtype=jt.int)
target_lengths = jt.randint(low=S_min, high=S+1, shape=(N,), dtype=jt.int)
# ctc_loss = nn.CTCLoss()
loss = jt.ctc_loss(input, target, input_lengths, target_lengths, reduction='none')
_loss_jt = loss
loss_jt = loss.numpy()
input = torch.Tensor(input.numpy()).detach().requires_grad_()
input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
target_lengths = torch.LongTensor(target_lengths.numpy())
input_lengths = torch.LongTensor(input_lengths.numpy())
target = torch.LongTensor(target.numpy())
loss = tnn.CTCLoss(reduction='none')(input, target, input_lengths, target_lengths)
np.testing.assert_allclose(loss.detach().numpy(), loss_jt, rtol=1e-5, atol=1e-5)
dinput_jt = jt.grad(_loss_jt, _input_jt)
dinput_jt.sync()
loss.sum().backward()
# print(input.grad)
# print(dinput_jt)
# print(loss)
def check_gpu_with_cpu(T,C,N,S,S_min):
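# Run the same case on CPU and on CUDA and check that the losses and
# gradients agree.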
jt.set_global_seed(1)
# Initialize a random batch of input vectors of size (T, N, C)
input = jt.randn(T, N, C).log_softmax(2)
# input = -jt.ones((T, N, C))
# input[0,0,1] += 0.01
# Initialize random batch of targets (0 = blank, 1:C = classes)
target = jt.randint(low=1, high=C, shape=(N, S), dtype=jt.int)
_input_jt = input
input_lengths = jt.full((N,), T, dtype=jt.int)
target_lengths = jt.randint(low=S_min, high=S+1, shape=(N,), dtype=jt.int)
# ctc_loss = nn.CTCLoss()
loss = jt.ctc_loss(input, target, input_lengths, target_lengths, reduction='none')
_loss_jt = loss
loss_jt = loss.numpy()
dinput_jt = jt.grad(_loss_jt, _input_jt)
dinput_jt.sync()
with jt.flag_scope(use_cuda=1):
input = input.copy()
target = target.copy()
input_lengths = input_lengths.copy()
target_lengths = target_lengths.copy()
loss = jt.ctc_loss(input, target, input_lengths, target_lengths, reduction='none')
grad = jt.grad(loss, input)
np.testing.assert_allclose(_loss_jt.numpy(), loss.numpy(), atol=1e-5, rtol=1e-5)
np.testing.assert_allclose(dinput_jt.numpy(), grad.numpy(), atol=1e-5, rtol=1e-5)
check(2,2,1,1,1)
check(50,20,16,30,10)
if jt.has_cuda:
with jt.flag_scope(use_cuda=1):
check(2,2,1,1,1)
check(50,20,16,30,10)
check_gpu_with_cpu(50,20,16,30,10)
class TestOther(unittest.TestCase):
def test_save(self):
pp = [1,2,jt.array([1,2,3]), {"a":[1,2,3], "b":jt.array([1,2,3])}]