mirror of https://github.com/Jittor/Jittor
Update misc.py
polish cub_cumsum to support ‘range’ item assignment'
This commit is contained in:
parent
8c0f66c638
commit
4383886feb
|
@ -683,9 +683,9 @@ def cub_cumsum(x, dim=None):
|
|||
if (dim == None):
|
||||
dim = -1
|
||||
assert(dim >= -1 and dim < len(x.shape))
|
||||
shape = x.shape
|
||||
shape = list(x.shape)
|
||||
if (dim != -1 and dim != len(shape) - 1):
|
||||
order = range(len(shape))
|
||||
order = list(range(len(shape)))
|
||||
order[dim], order[-1] = order[-1], order[dim]
|
||||
shape[dim], shape[-1] = shape[-1], shape[dim]
|
||||
x = x.permute(order)
|
||||
|
@ -1656,4 +1656,4 @@ class CTCLoss(jt.Module):
|
|||
self.zero_infinity = zero_infinity
|
||||
|
||||
def execute(self, log_probs, targets, input_lengths, target_lengths):
|
||||
return ctc_loss(log_probs, targets, input_lengths, target_lengths, self.blank, self.reduction, self.zero_infinity)
|
||||
return ctc_loss(log_probs, targets, input_lengths, target_lengths, self.blank, self.reduction, self.zero_infinity)
|
||||
|
|
Loading…
Reference in New Issue