Update misc.py

polish cub_cumsum to support ‘range’  item assignment'
This commit is contained in:
Xiang-Li Li 2022-02-22 22:34:14 +08:00 committed by GitHub
parent 8c0f66c638
commit 4383886feb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 3 additions and 3 deletions

View File

@ -683,9 +683,9 @@ def cub_cumsum(x, dim=None):
if (dim == None):
dim = -1
assert(dim >= -1 and dim < len(x.shape))
shape = x.shape
shape = list(x.shape)
if (dim != -1 and dim != len(shape) - 1):
order = range(len(shape))
order = list(range(len(shape)))
order[dim], order[-1] = order[-1], order[dim]
shape[dim], shape[-1] = shape[-1], shape[dim]
x = x.permute(order)
@ -1656,4 +1656,4 @@ class CTCLoss(jt.Module):
self.zero_infinity = zero_infinity
def execute(self, log_probs, targets, input_lengths, target_lengths):
return ctc_loss(log_probs, targets, input_lengths, target_lengths, self.blank, self.reduction, self.zero_infinity)
return ctc_loss(log_probs, targets, input_lengths, target_lengths, self.blank, self.reduction, self.zero_infinity)