polish auto_parallel and upsample

This commit is contained in:
Dun Liang 2021-03-11 16:05:34 +08:00
parent 8978e7dcd1
commit 36e7acfc1f
4 changed files with 7 additions and 4 deletions

View File

@ -8,7 +8,7 @@
# This file is subject to the terms and conditions defined in # This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package. # file 'LICENSE.txt', which is part of this source code package.
# *************************************************************** # ***************************************************************
__version__ = '1.2.2.43' __version__ = '1.2.2.44'
from . import lock from . import lock
with lock.lock_scope(): with lock.lock_scope():
ori_int = int ori_int = int

View File

@ -271,6 +271,7 @@ def setup_cutt():
def install_nccl(root_folder): def install_nccl(root_folder):
url = "https://github.com/NVIDIA/nccl/archive/v2.8.4-1.tar.gz" url = "https://github.com/NVIDIA/nccl/archive/v2.8.4-1.tar.gz"
url = "https://codeload.github.com/NVIDIA/nccl/tar.gz/v2.8.4-1"
filename = "nccl.tgz" filename = "nccl.tgz"
fullname = os.path.join(root_folder, filename) fullname = os.path.join(root_folder, filename)

View File

@ -897,7 +897,7 @@ def auto_parallel(n, src, **kw):
tid_def += f"\nauto tnum{i} = 1<<tn{i};" tid_def += f"\nauto tnum{i} = 1<<tn{i};"
tid_def += f"\ntid = tid>>tn{i};" tid_def += f"\ntid = tid>>tn{i};"
for i in range(n): for i in range(n):
tid_loop += f"\nfor (int i{i}=tid{i}; i{i}<{pnargs2[i]}; i{i}+=tn{i})" tid_loop += f"\nfor (int i{i}=tid{i}; i{i}<{pnargs2[i]}; i{i}+=tnum{i})"
call_args.append(pnargs2[i]) call_args.append(pnargs2[i])
call_args.append(f"i{i}") call_args.append(f"i{i}")
call_args += oargs2 call_args += oargs2

View File

@ -1053,8 +1053,10 @@ def upsample(img, size, mode="nearest", align_corners=False):
x = (hid + 0.5) * (h / H) - 0.5 x = (hid + 0.5) * (h / H) - 0.5
y = (wid + 0.5) * (w / W) - 0.5 y = (wid + 0.5) * (w / W) - 0.5
else: else:
x = hid * (h / H) x = hid * (h / H) + (h / H * 0.5 - 0.5)
y = wid * (w / W) if H > h: x = x.clamp(0, h - 1)
y = wid * (w / W) + (w / W * 0.5 - 0.5)
if W > w: y = y.clamp(0, w - 1)
return _interpolate(img, x, y, (nid, cid), mode) return _interpolate(img, x, y, (nid, cid), mode)