add cuda_limits.h

This commit is contained in:
Dun Liang 2020-05-27 17:28:34 +08:00
parent 1e867a5b7b
commit 93ca5e9525
2 changed files with 9 additions and 1 deletions

View File

@ -67,7 +67,10 @@ class Pool(Module):
}}
}}'''
out = jt.code([N,C,h,w], x.dtype, [x],
cuda_header='#include <ops/binary_op_defs.h>',
cuda_header="""
#include <ops/binary_op_defs.h>
#include <misc/cuda_limits.h>
""",
cuda_src=f'''
__global__ static void kernel1(@ARGS_DEF) {{
@PRECALC

View File

@ -75,6 +75,11 @@ class TestArgPoolOp(unittest.TestCase):
for i in range(10):
check(jt_model, torch_model, [1,1,300,300], True)
@unittest.skipIf(not jt.compiler.has_cuda, "No cuda found")
@jt.flag_scope(use_cuda=1)
def test_cuda_avg_pool(self):
self.test_cpu_avg_pool()
def test_cpu_avg_pool(self):
from torch.nn import AvgPool2d
jt_model = Pool(2, 2, 0, op="mean", ceil_mode=True)