mirror of https://github.com/Jittor/Jittor
fix histc.
This commit is contained in:
parent
d0f2d50607
commit
b276a45813
|
@ -2229,10 +2229,9 @@ def histc(input, bins, min=0., max=0.):
|
|||
if bins <= 0:
|
||||
raise RuntimeError(f"bins must be > 0, but got {bins}")
|
||||
bin_length = (max - min) / bins
|
||||
histc = jt.floor((input[jt.logical_and(input >= min, input <= max)] - min) / bin_length).int().reshape(-1)
|
||||
histc = jt.floor((input[jt.logical_and(input >= min, input < max)] - min) / bin_length).int().reshape(-1)
|
||||
hist = jt.ones_like(histc).float().reindex_reduce("add", [bins,], ["@e0(i0)"], extras=[histc])
|
||||
if hist.sum() != histc.shape[0]:
|
||||
hist[-1] += 1
|
||||
hist[-1] += input[input == max].shape[0]
|
||||
return hist
|
||||
|
||||
def peek_s(x):
|
||||
|
|
Loading…
Reference in New Issue