Merge pull request #586 from fansunqi/dim

fix dim=3 error
This commit is contained in:
514flowey 2024-09-05 20:18:02 +08:00 committed by GitHub
commit 593519203b
1 changed files with 1 additions and 1 deletions

View File

@ -1965,7 +1965,7 @@ def _interpolate(img, x, y, ids, mode):
# TODO: tf_mode to another function
def resize(img, size, mode="nearest", align_corners=False, tf_mode=False):
if img.dim() != 3:
if img.dim() != 4:
raise ValueError("Input shape must be `(N, C, H, W)`!")
n, c, h, w = img.shape
H, W = size