mirror of https://github.com/Jittor/Jittor
Add interpolate area support.
This commit is contained in:
parent
05ed6c7e34
commit
56255578f9
|
@ -1769,6 +1769,20 @@ def resize(img, size, mode="nearest", align_corners=False, tf_mode=False):
|
|||
elif mode == 'nearest':
|
||||
x = hid * (h / H)
|
||||
y = wid * (w / W)
|
||||
elif mode == "area":
|
||||
stride = (h // H, w // W)
|
||||
assert stride[0] > 0 and stride[1] > 0
|
||||
x, y = jt.meshgrid(jt.arange(0, H, 1), jt.arange(0, W, 1))
|
||||
startH = jt.floor(x*h/H).int32()
|
||||
endH = jt.ceil((x+1)*h/H).int32()
|
||||
maxH = int(jt.max(endH - startH).data)
|
||||
startW = jt.floor(y*w/W).int32()
|
||||
endW = jt.ceil((y+1)*w/W).int32()
|
||||
maxW = int(jt.max(endW - startW).data)
|
||||
pixel_count = (endH - startH) * (endW - startW)
|
||||
adaptive_output = img.reindex([img.shape[0], img.shape[1], H, W, maxH, maxW], ["i0", "i1", "@e0(i2, i3) + i4", "@e2(i2, i3) + i5"], extras=[startH, endH, startW, endW], overflow_conditions=["i4 >= @e1(i2, i3) - @e0(i2, i3)", "i5 >= @e3(i2, i3) - @e2(i2, i3)"], overflow_value=0)
|
||||
adaptive_output = adaptive_output.reduce("sum", [4,5]) / pixel_count[None, None, ...]
|
||||
return adaptive_output
|
||||
else:
|
||||
if (tf_mode):
|
||||
x = hid * (h / H)
|
||||
|
|
Loading…
Reference in New Issue