mirror of https://github.com/Jittor/Jittor
add ctcloss
This commit is contained in:
parent 057dd95658
commit f807a28e6b

@@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************

-__version__ = '1.2.3.99'
+__version__ = '1.2.3.100'
from jittor_utils import lock
with lock.lock_scope():
    ori_int = int

@@ -1311,3 +1311,348 @@ jt.Var.roll = roll
def safe_log(x):
    return jt.safe_clip(x, 1e-30, 1e30).log()
jt.Var.safe_log = safe_log

class _CTCLossFunction(jt.Function):
    def execute(self, log_probs, targets, input_lengths, target_lengths, blank=0, zero_infinity=False):
        self.blank = blank
        T, N, C = log_probs.shape
        _N, S = targets.shape
        assert _N == N
        # log_alpha[t, n, k] is the CTC forward variable over the extended
        # target (blanks interleaved, length 2*S+1), initialised to log(0).
        log_alpha = jt.full([T,N,S*2+1], -1e30)
        result = jt.empty((N,))
        jt.code([log_probs, targets, input_lengths, target_lengths], [log_alpha, result], cpu_src=f"""
            constexpr int blank = {blank};
            for (int i=0; i<in0_shape1; i++) {{
                int input_len = @in2(i);
                int target_len = @in3(i);
                // t = 0: alpha starts at the leading blank and (if any) the first label.
                @out0(0,i,0) = @in0(0,i,blank);
                if (target_len)
                    @out0(0,i,1) = @in0(0,i,@in1(i,0));
                for (int j=1; j<input_len; j++)
                    for (int k=0; k<target_len*2+1; k++) {{
                        int target = k%2 ? @in1(i,k/2) : blank;
                        int target_2 = target;
                        if (k>1 && k%2) target_2 = @in1(i,k/2-1);
                        // CTC recursion: logsumexp over predecessors k, k-1 and,
                        // when the label differs, k-2, plus the emission log-prob.
                        out_type l1 = @out0(j-1,i,k);
                        out_type l2 = -1e30;
                        if (k>0) l2 = @out0(j-1,i,k-1);
                        out_type l3 = -1e30;
                        if (k>1 && target_2 != target)
                            l3 = @out0(j-1,i,k-2);
                        out_type m = std::max(l1, std::max(l2, l3));
                        @out0(j,i,k) = std::log(
                            std::exp(l1-m) +
                            std::exp(l2-m) +
                            std::exp(l3-m)
                        ) + m + @in0(j,i,target);
                    }}
                if (input_len==0)
                    @out1(i) = @out0(0,i,0);
                else {{
                    // Loss is -log of the sum of the two terminal alphas
                    // (ending in the final label or the trailing blank).
                    out_type l1 = @out0(input_len-1, i, target_len*2);
                    out_type l2 = -1e30;
                    if (target_len)
                        l2 = @out0(input_len-1, i, target_len*2-1);
                    out_type m = std::max(l1, l2);
                    out_type log_likelihood = std::log(std::exp(l1-m)+std::exp(l2-m))+m;
                    @out1(i) = -log_likelihood;
                }}
            }}
""", cuda_src=f"""
|
||||
__global__ void kernel(@ARGS_DEF) {{
|
||||
@PRECALC;
|
||||
constexpr int blank = {blank};
|
||||
for (int i=blockIdx.x; i<in0_shape1; i+=gridDim.x) {{
|
||||
int input_len = @in2(i);
|
||||
int target_len = @in3(i);
|
||||
@out0(0,i,0) = @in0(0,i,blank);
|
||||
if (target_len)
|
||||
@out0(0,i,1) = @in0(0,i,@in1(i,0));
|
||||
for (int j=1; j<input_len; j++)
|
||||
for (int k=threadIdx.x; k-threadIdx.x<target_len*2+1; k+=blockDim.x) {{
|
||||
__syncthreads();
|
||||
if (k>=target_len*2+1)
|
||||
continue;
|
||||
int target = k%2 ? @in1(i,k/2) : blank;
|
||||
int target_2 = target;
|
||||
if (k>1 && k%2) target_2 = @in1(i,k/2-1);
|
||||
out_type l1 = @out0(j-1,i,k);
|
||||
out_type l2 = -1e30;
|
||||
if (k>0) l2 = @out0(j-1,i,k-1);
|
||||
out_type l3 = -1e30;
|
||||
if (k>1 && target_2 != target)
|
||||
l3 = @out0(j-1,i,k-2);
|
||||
out_type m = ::max(l1, ::max(l2, l3));
|
||||
@out0(j,i,k) = ::log(
|
||||
::exp(l1-m) +
|
||||
::exp(l2-m) +
|
||||
::exp(l3-m)
|
||||
) + m + @in0(j,i,target);
|
||||
}}
|
||||
__syncthreads();
|
||||
if (input_len==0)
|
||||
@out1(i) = @out0(0,i,0);
|
||||
else {{
|
||||
out_type l1 = @out0(input_len-1, i, target_len*2);
|
||||
out_type l2 = -1e30;
|
||||
if (target_len)
|
||||
l2 = @out0(input_len-1, i, target_len*2-1);
|
||||
out_type m = ::max(l1, l2);
|
||||
out_type log_likelihood = ::log(::exp(l1-m)+::exp(l2-m))+m;
|
||||
@out1(i) = -log_likelihood;
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
kernel<<<std::min(in0_shape1, 1024), std::min(in1_shape1*2+1, 1024)>>>(@ARGS);
|
||||
""")
|
||||
self.saved_var = [log_probs, targets, input_lengths, target_lengths, log_alpha, result]
|
||||
return result
|
||||
|
||||
    def grad(self, dout):
        blank = self.blank
        inputs = self.saved_var + [dout]
        dlog_probs = jt.zeros_like(inputs[0])
        dlog_alpha = jt.zeros_like(inputs[4])
        # Run the forward recursion in reverse: out1 (dlog_alpha) accumulates the
        # gradient w.r.t. log_alpha, out0 (dlog_probs) the gradient w.r.t. log_probs.
        jt.code(inputs, [dlog_probs, dlog_alpha], cpu_src=f"""
            constexpr int blank = {blank};
            for (int i=0; i<in0_shape1; i++) {{
                int input_len = @in2(i);
                int target_len = @in3(i);
                if (input_len==0)
                    // write out1 --> read in6
                    // out1(i) = out0(0,i,0);
                    @out1(0,i,0) = @in6(i);
                else {{
                    out_type l1 = @in4(input_len-1, i, target_len*2);
                    out_type l2 = -1e30;
                    if (target_len)
                        l2 = @in4(input_len-1, i, target_len*2-1);
                    out_type m = std::max(l1, l2);
                    // out_type log_likelihood = std::log(std::exp(l1-m)+std::exp(l2-m))+m;
                    // out1(i) = -log_likelihood;
                    out_type l1_exp = std::exp(l1-m);
                    out_type l2_exp = std::exp(l2-m);
                    out_type sumexp = l1_exp + l2_exp;

                    out_type dlog_likelihood = -@in6(i);
                    out_type dl1 = dlog_likelihood * l1_exp / sumexp;
                    out_type dl2 = dlog_likelihood * l2_exp / sumexp;

                    @out1(input_len-1, i, target_len*2) = dl1;
                    if (target_len)
                        @out1(input_len-1, i, target_len*2-1) = dl2;
                }}
                for (int j=input_len-1; j>0; j--)
                    for (int k=0; k<target_len*2+1; k++) {{
                        int target = k%2 ? @in1(i,k/2) : blank;
                        int target_2 = target;
                        if (k>1 && k%2) target_2 = @in1(i,k/2-1);
                        out_type l1 = @in4(j-1,i,k);
                        out_type l2 = -1e30;
                        if (k>0) l2 = @in4(j-1,i,k-1);
                        out_type l3 = -1e30;
                        if (k>1 && target_2 != target)
                            l3 = @in4(j-1,i,k-2);
                        out_type m = std::max(l1, std::max(l2, l3));
                        out_type l1_exp = std::exp(l1-m);
                        out_type l2_exp = std::exp(l2-m);
                        out_type l3_exp = std::exp(l3-m);
                        out_type sumexp = l1_exp + l2_exp + l3_exp;
                        out_type dalpha = @out1(j,i,k);

                        @out0(j,i,target) += dalpha;

                        @out1(j-1,i,k) += dalpha * l1_exp / sumexp;
                        if (k>0)
                            @out1(j-1,i,k-1) += dalpha * l2_exp / sumexp;
                        if (k>1 && target_2 != target)
                            @out1(j-1,i,k-2) += dalpha * l3_exp / sumexp;
                    }}
                // read in0 -> write out0
                // write out0 -> read out1
                // out0(0,i,0) = in0(0,i,blank);
                @out0(0,i,blank) += @out1(0,i,0);
                if (target_len)
                    @out0(0,i,@in1(i,0)) += @out1(0,i,1);
            }}
""", cuda_src=f"""
|
||||
__global__ void kernel(@ARGS_DEF) {{
|
||||
@PRECALC;
|
||||
constexpr int blank = {blank};
|
||||
for (int i=blockIdx.x; i<in0_shape1; i+=gridDim.x) {{
|
||||
int input_len = @in2(i);
|
||||
int target_len = @in3(i);
|
||||
if (input_len==0)
|
||||
// write out1 --> read in6
|
||||
// out1(i) = out0(0,i,0);
|
||||
@out1(0,i,0) = @in6(i);
|
||||
else {{
|
||||
out_type l1 = @in4(input_len-1, i, target_len*2);
|
||||
out_type l2 = -1e30;
|
||||
if (target_len)
|
||||
l2 = @in4(input_len-1, i, target_len*2-1);
|
||||
out_type m = ::max(l1, l2);
|
||||
// out_type log_likelihood = ::log(::exp(l1-m)+::exp(l2-m))+m;
|
||||
// out1(i) = -log_likelihood;
|
||||
out_type l1_exp = ::exp(l1-m);
|
||||
out_type l2_exp = ::exp(l2-m);
|
||||
out_type sumexp = l1_exp + l2_exp;
|
||||
|
||||
out_type dlog_likelihood = -@in6(i);
|
||||
out_type dl1 = dlog_likelihood * l1_exp / sumexp;
|
||||
out_type dl2 = dlog_likelihood * l2_exp / sumexp;
|
||||
|
||||
@out1(input_len-1, i, target_len*2) = dl1;
|
||||
if (target_len)
|
||||
@out1(input_len-1, i, target_len*2-1) = dl2;
|
||||
}}
|
||||
for (int j=input_len-1; j>0; j--)
|
||||
for (int k=threadIdx.x; k-threadIdx.x<target_len*2+1; k+=blockDim.x) {{
|
||||
__syncthreads();
|
||||
if (k>=target_len*2+1)
|
||||
continue;
|
||||
int target = k%2 ? @in1(i,k/2) : blank;
|
||||
int target_2 = target;
|
||||
if (k>1 && k%2) target_2 = @in1(i,k/2-1);
|
||||
out_type l1 = @in4(j-1,i,k);
|
||||
out_type l2 = -1e30;
|
||||
if (k>0) l2 = @in4(j-1,i,k-1);
|
||||
out_type l3 = -1e30;
|
||||
if (k>1 && target_2 != target)
|
||||
l3 = @in4(j-1,i,k-2);
|
||||
out_type m = ::max(l1, ::max(l2, l3));
|
||||
out_type l1_exp = ::exp(l1-m);
|
||||
out_type l2_exp = ::exp(l2-m);
|
||||
out_type l3_exp = ::exp(l3-m);
|
||||
out_type sumexp = l1_exp + l2_exp + l3_exp;
|
||||
out_type dalpha = @out1(j,i,k);
|
||||
|
||||
atomicAdd(&@out0(j,i,target), dalpha);
|
||||
|
||||
atomicAdd(&@out1(j-1,i,k), dalpha * l1_exp / sumexp);
|
||||
if (k>0)
|
||||
atomicAdd(&@out1(j-1,i,k-1), dalpha * l2_exp / sumexp);
|
||||
if (k>1 && target_2 != target)
|
||||
atomicAdd(&@out1(j-1,i,k-2), dalpha * l3_exp / sumexp);
|
||||
}}
|
||||
// read in0 -> white out0
|
||||
// write out0 ->read out1
|
||||
// out0(0,i,0) = in0(0,i,blank);
|
||||
__syncthreads();
|
||||
if (threadIdx.x==0) {{
|
||||
@out0(0,i,blank) += @out1(0,i,0);
|
||||
if (target_len)
|
||||
@out0(0,i,@in1(i,0)) += @out1(0,i,1);
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
kernel<<<std::min(in0_shape1, 1024), std::min(in1_shape1*2+1, 1024)>>>(@ARGS);
|
||||
""")
|
||||
return (dlog_probs,)
|
||||
|
||||
|
||||
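# Editorial note (not part of the original commit): the backward kernels above
# distribute the incoming gradient `dalpha` over l1/l2/l3 with the weights
# exp(l_i - m) / sumexp. Those weights are exactly the gradient of the
# log-sum-exp computed by the forward pass, d/dl_i log(sum_j exp(l_j)) = softmax(l)_i.
# The helper below is an illustrative sketch of that identity; it assumes numpy
# and is never called by the operator itself.
def _check_logsumexp_grad(l=(0.3, -1.2, 2.0), eps=1e-6, tol=1e-5):
    import numpy as np
    l = np.asarray(l, dtype=np.float64)
    soft = np.exp(l - l.max())
    soft /= soft.sum()                      # analytic gradient: softmax(l)
    for i in range(l.size):
        e = np.zeros_like(l)
        e[i] = eps
        # central difference of logsumexp along coordinate i
        num = (np.log(np.exp(l + e).sum()) - np.log(np.exp(l - e).sum())) / (2 * eps)
        assert abs(num - soft[i]) < tol
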
def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0, reduction='mean', zero_infinity=False):
    '''The Connectionist Temporal Classification loss.

    Reference:
        A. Graves et al.: Connectionist Temporal Classification:
        Labelling Unsegmented Sequence Data with Recurrent Neural Networks:
        https://www.cs.toronto.edu/~graves/icml_2006.pdf

    Input:

        log_probs: shape is [T, N, C], where T is the sequence length, N is the batch size and C is the number of classes.
        targets: shape is [N, S], where S is the target sequence length; elements should be in [0, C).
        input_lengths: shape is [N], the length of each input sequence; elements should be in [0, T].
        target_lengths: shape is [N], the length of each target sequence; elements should be in [0, S].
        blank (int, default 0): blank label index.
        reduction (string): how to reduce the batch of losses;
            'none' returns an (N,) array,
            'mean' or 'sum' return a single scalar.
        zero_infinity (bool, default False):
            zero infinite losses and the associated gradients.

    Example:

        import jittor as jt
        T = 50      # Input sequence length
        C = 20      # Number of classes (including blank)
        N = 16      # Batch size
        S = 30      # Target sequence length of longest target in batch (padding length)
        S_min = 10  # Minimum target length, for demonstration purposes

        input = jt.randn(T, N, C).log_softmax(2)
        # Initialize random batch of targets (0 = blank, 1:C = classes)
        target = jt.randint(low=1, high=C, shape=(N, S), dtype=jt.int)

        input_lengths = jt.full((N,), T, dtype=jt.int)
        target_lengths = jt.randint(low=S_min, high=S+1, shape=(N,), dtype=jt.int)
        loss = jt.ctc_loss(input, target, input_lengths, target_lengths)

        dinput = jt.grad(loss, input)

    '''
    result = _CTCLossFunction.apply(log_probs, targets, input_lengths, target_lengths, blank, zero_infinity)
    if reduction=="mean":
        return result.mean()
    elif reduction=="sum":
        return result.sum()
    assert reduction=="none"
    return result

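# Editorial sketch (not part of the original commit): a plain-Python
# transcription of the forward recursion that the kernels above implement,
# handy for cross-checking a single full-length example against jt.ctc_loss
# with reduction='none'. The helper name is illustrative only; it expects
# log_probs as a T x C list of log-probabilities for one sample and target as
# a list of label indices (no blanks), and assumes at least one time step.
import math

def _ctc_forward_reference(log_probs, target, blank=0):
    # Extended target: blank, l1, blank, l2, ..., blank (length 2*len(target)+1).
    ext = [blank]
    for t in target:
        ext += [t, blank]
    NEG = -1e30
    alpha = [NEG] * len(ext)
    # t = 0: start in the leading blank or (if any) the first label.
    alpha[0] = log_probs[0][blank]
    if target:
        alpha[1] = log_probs[0][target[0]]
    for j in range(1, len(log_probs)):
        new = [NEG] * len(ext)
        for k in range(len(ext)):
            # Allowed predecessors: k, k-1, and k-2 when the label differs.
            cands = [alpha[k]]
            if k > 0:
                cands.append(alpha[k-1])
            if k > 1 and ext[k] != blank and ext[k] != ext[k-2]:
                cands.append(alpha[k-2])
            m = max(cands)
            new[k] = m + math.log(sum(math.exp(c - m) for c in cands)) + log_probs[j][ext[k]]
        alpha = new
    if not target:
        return -alpha[0]
    # Loss is -log of the probability mass ending in the last label or the trailing blank.
    m = max(alpha[-1], alpha[-2])
    return -(m + math.log(math.exp(alpha[-1] - m) + math.exp(alpha[-2] - m)))
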
class CTCLoss(jt.Module):
    '''The Connectionist Temporal Classification loss.

    Reference:
        A. Graves et al.: Connectionist Temporal Classification:
        Labelling Unsegmented Sequence Data with Recurrent Neural Networks:
        https://www.cs.toronto.edu/~graves/icml_2006.pdf

    Args:

        blank (int, default 0): blank label index.
        reduction (string): how to reduce the batch of losses;
            'none' returns an (N,) array,
            'mean' or 'sum' return a single scalar.
        zero_infinity (bool, default False):
            zero infinite losses and the associated gradients.

    Input:

        log_probs: shape is [T, N, C], where T is the sequence length, N is the batch size and C is the number of classes.
        targets: shape is [N, S], where S is the target sequence length; elements should be in [0, C).
        input_lengths: shape is [N], the length of each input sequence; elements should be in [0, T].
        target_lengths: shape is [N], the length of each target sequence; elements should be in [0, S].

    Example:

        import jittor as jt
        T = 50      # Input sequence length
        C = 20      # Number of classes (including blank)
        N = 16      # Batch size
        S = 30      # Target sequence length of longest target in batch (padding length)
        S_min = 10  # Minimum target length, for demonstration purposes

        input = jt.randn(T, N, C).log_softmax(2)
        # Initialize random batch of targets (0 = blank, 1:C = classes)
        target = jt.randint(low=1, high=C, shape=(N, S), dtype=jt.int)

        input_lengths = jt.full((N,), T, dtype=jt.int)
        target_lengths = jt.randint(low=S_min, high=S+1, shape=(N,), dtype=jt.int)
        ctc_loss = jt.CTCLoss()
        loss = ctc_loss(input, target, input_lengths, target_lengths)

        dinput = jt.grad(loss, input)

    '''
    def __init__(self, blank=0, reduction='mean', zero_infinity=False):
        self.blank = blank
        self.reduction = reduction
        self.zero_infinity = zero_infinity

    def execute(self, log_probs, targets, input_lengths, target_lengths):
        return ctc_loss(log_probs, targets, input_lengths, target_lengths, self.blank, self.reduction, self.zero_infinity)

@@ -141,7 +141,7 @@ Example::
        an = a.shape[ai] if ai>=0 else 1
        bn = b.shape[bi] if bi>=0 else 1
        if an!=1 and bn!=1:
-           assert an == bn, f"dimension not match, a.shape:{a.shape}, b.shape:{a.shape}"
+           assert an == bn, f"dimension not match, a.shape:{a.shape}, b.shape:{b.shape}"
        cn = max(an, bn)
        shape.append(cn)
    shape.extend([n, m, k])

@@ -341,16 +341,20 @@ def softmax(x, dim = None):
    x = (x-x.max(dim, keepdims=True)).exp()
    ret = x / x.sum(dim, keepdims=True)
    return ret
jt.Var.softmax = softmax

def log_softmax(x,dim=None):
    x = softmax(x,dim=dim)
    return jt.log(x)
jt.Var.log_softmax = log_softmax

def log_sigmoid(x):
    return jt.log(jt.sigmoid(x))
jt.Var.log_sigmoid = log_sigmoid

def logsumexp(x, dim, keepdim=False):
    return x.exp().sum(dim, keepdim).log()
jt.Var.logsumexp = logsumexp

class Identity(Module):
    def __init__(self, *args, **kwargs):

@@ -159,6 +159,86 @@ class TestPad(unittest.TestCase):
            out.detach().numpy(), output.data,
            atol=1e-4)

    def test_ctc_loss(self):
        def check(T,C,N,S,S_min):
            jt.set_global_seed(0)

            # Initialize random batch of input vectors, for *size = (T,N,C)
            input = jt.randn(T, N, C).log_softmax(2)
            # input = -jt.ones((T, N, C))
            # input[0,0,1] += 0.01

            # Initialize random batch of targets (0 = blank, 1:C = classes)
            target = jt.randint(low=1, high=C, shape=(N, S), dtype=jt.int)
            _input_jt = input

            input_lengths = jt.full((N,), T, dtype=jt.int)
            target_lengths = jt.randint(low=S_min, high=S+1, shape=(N,), dtype=jt.int)
            # ctc_loss = nn.CTCLoss()
            loss = jt.ctc_loss(input, target, input_lengths, target_lengths, reduction='none')
            _loss_jt = loss

            loss_jt = loss.numpy()

            input = torch.Tensor(input.numpy()).detach().requires_grad_()
            input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
            target_lengths = torch.LongTensor(target_lengths.numpy())
            input_lengths = torch.LongTensor(input_lengths.numpy())
            target = torch.LongTensor(target.numpy())
            loss = tnn.CTCLoss(reduction='none')(input, target, input_lengths, target_lengths)
            np.testing.assert_allclose(loss.detach().numpy(), loss_jt, rtol=1e-5, atol=1e-5)

            dinput_jt = jt.grad(_loss_jt, _input_jt)
            dinput_jt.sync()

            loss.sum().backward()
            # print(input.grad)
            # print(dinput_jt)
            # print(loss)

        def check_gpu_with_cpu(T,C,N,S,S_min):
            jt.set_global_seed(1)

            # Initialize random batch of input vectors, for *size = (T,N,C)
            input = jt.randn(T, N, C).log_softmax(2)
            # input = -jt.ones((T, N, C))
            # input[0,0,1] += 0.01

            # Initialize random batch of targets (0 = blank, 1:C = classes)
            target = jt.randint(low=1, high=C, shape=(N, S), dtype=jt.int)
            _input_jt = input

            input_lengths = jt.full((N,), T, dtype=jt.int)
            target_lengths = jt.randint(low=S_min, high=S+1, shape=(N,), dtype=jt.int)
            # ctc_loss = nn.CTCLoss()
            loss = jt.ctc_loss(input, target, input_lengths, target_lengths, reduction='none')
            _loss_jt = loss

            loss_jt = loss.numpy()

            dinput_jt = jt.grad(_loss_jt, _input_jt)
            dinput_jt.sync()

            with jt.flag_scope(use_cuda=1):
                input = input.copy()
                target = target.copy()
                input_lengths = input_lengths.copy()
                target_lengths = target_lengths.copy()
                loss = jt.ctc_loss(input, target, input_lengths, target_lengths, reduction='none')
                grad = jt.grad(loss, input)
                np.testing.assert_allclose(_loss_jt.numpy(), loss.numpy(), atol=1e-5, rtol=1e-5)
                np.testing.assert_allclose(dinput_jt.numpy(), grad.numpy(), atol=1e-5, rtol=1e-5)

        check(2,2,1,1,1)
        check(50,20,16,30,10)

        if jt.has_cuda:
            with jt.flag_scope(use_cuda=1):
                check(2,2,1,1,1)
                check(50,20,16,30,10)
                check_gpu_with_cpu(50,20,16,30,10)

class TestOther(unittest.TestCase):
    def test_save(self):
        pp = [1,2,jt.array([1,2,3]), {"a":[1,2,3], "b":jt.array([1,2,3])}]