mirror of https://github.com/Jittor/Jittor
add cuda limit.h
This commit is contained in:
parent
1e867a5b7b
commit
93ca5e9525
|
@ -67,7 +67,10 @@ class Pool(Module):
|
|||
}}
|
||||
}}'''
|
||||
out = jt.code([N,C,h,w], x.dtype, [x],
|
||||
cuda_header='#include <ops/binary_op_defs.h>',
|
||||
cuda_header="""
|
||||
#include <ops/binary_op_defs.h>
|
||||
#include <misc/cuda_limits.h>
|
||||
""",
|
||||
cuda_src=f'''
|
||||
__global__ static void kernel1(@ARGS_DEF) {{
|
||||
@PRECALC
|
||||
|
|
|
@ -75,6 +75,11 @@ class TestArgPoolOp(unittest.TestCase):
|
|||
for i in range(10):
|
||||
check(jt_model, torch_model, [1,1,300,300], True)
|
||||
|
||||
@unittest.skipIf(not jt.compiler.has_cuda, "No cuda found")
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
def test_cuda_avg_pool(self):
|
||||
self.test_cpu_avg_pool()
|
||||
|
||||
def test_cpu_avg_pool(self):
|
||||
from torch.nn import AvgPool2d
|
||||
jt_model = Pool(2, 2, 0, op="mean", ceil_mode=True)
|
||||
|
|
Loading…
Reference in New Issue