Dun Liang 2020-05-28 23:20:45 +08:00
parent 370a3cc8ef
commit d798456fee
8 changed files with 67 additions and 9 deletions

View File

@@ -11,7 +11,6 @@ import numpy as np
 import os
 from PIL import Image
 import matplotlib.pyplot as plt
-import cv2
 from .dataset import Dataset, dataset_root
 class VOC(Dataset):

View File

@@ -34,7 +34,14 @@ class Pool(Module):
             h = (H+self.padding*2-self.kernel_size + self.stride - 1)//self.stride+1
             w = (W+self.padding*2-self.kernel_size + self.stride - 1)//self.stride+1
-        if self.op in ['maximum', 'minimum', 'mean'] and not self.count_include_pad:
+        if self.op in ['maximum', 'minimum', 'mean']:
+            if self.op == 'mean':
+                if self.count_include_pad:
+                    count = f"int count = {self.kernel_size*self.kernel_size};"
+                else:
+                    count = "int count = (k2_ - k2) * (k3_ - k3);"
+            else:
+                count = ""
             forward_body = f'''{{
                 int k3 = i3*{self.stride}-{self.padding};
                 int k2 = i2*{self.stride}-{self.padding};
@@ -43,7 +50,7 @@ class Pool(Module):
                 k3 = max(0, k3);
                 k2 = max(0, k2);
                 @out(i0, i1, i2, i3) = init_{self.op}(out_type);
-                {"int count = (k2_ - k2) * (k3_ - k3);" if self.op == "mean" else ""}
+                {count}
                 for (int p = k2; p < k2_; ++p)
                     for (int q = k3; q < k3_; ++q)
                         @out(i0, i1, i2, i3) = {self.op}(out_type, @out(i0, i1, i2, i3), @in0(i0, i1, p, q));
@@ -55,7 +62,7 @@ class Pool(Module):
                 int k2_ = min(k2 + {self.kernel_size}, in0_shape2);
                 k3 = max(0, k3);
                 k2 = max(0, k2);
-                {"int count = (k2_ - k2) * (k3_ - k3);" if self.op == "mean" else ""}
+                {count}
                 int bo=1;
                 for (int p = k2; p < k2_ && bo; ++p)
                     for (int q = k3; q < k3_ && bo; ++q) {{
@@ -139,6 +146,7 @@ class Pool(Module):
             '''])
             return out
         else:
+            # TODO: backward
             xx = x.reindex([N,C,h,w,self.kernel_size,self.kernel_size], [
                 "i0", # Nid
                 "i1", # Cid

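Note on the Pool change above: for mean pooling, the generated C code divides each window sum by count. With count_include_pad=True the divisor is the whole window, kernel_size*kernel_size, zero padding included; with count_include_pad=False it is only the in-bounds area, (k2_ - k2) * (k3_ - k3). A minimal sketch of the two behaviors, using torch.nn.AvgPool2d (the reference the tests below compare against) rather than the Jittor kernel itself:

import torch
from torch.nn import AvgPool2d

x = torch.ones(1, 1, 4, 4)
incl = AvgPool2d(3, stride=1, padding=1, count_include_pad=True)(x)
excl = AvgPool2d(3, stride=1, padding=1, count_include_pad=False)(x)

# The corner window covers 4 real cells plus 5 zero pads.
print(float(incl[0, 0, 0, 0]))  # 4/9 ~= 0.444, divisor = kernel_size*kernel_size
print(float(excl[0, 0, 0, 0]))  # 1.0, divisor = (k2_ - k2) * (k3_ - k3)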
View File

@@ -24,6 +24,36 @@ try:
 except:
     skip_this_test = True

+class OldPool(Module):
+    def __init__(self, kernel_size, stride=None, padding=0, dilation=None, return_indices=None, ceil_mode=False, count_include_pad=True, op="maximum"):
+        assert dilation == None
+        assert return_indices == None
+        self.kernel_size = kernel_size
+        self.op = op
+        self.stride = stride if stride else kernel_size
+        self.padding = padding
+        self.ceil_mode = ceil_mode
+        self.count_include_pad = count_include_pad and padding != 0
+
+    def execute(self, x):
+        N,C,H,W = x.shape
+        if self.ceil_mode == False:
+            h = (H+self.padding*2-self.kernel_size)//self.stride+1
+            w = (W+self.padding*2-self.kernel_size)//self.stride+1
+        else:
+            h = (H+self.padding*2-self.kernel_size + self.stride - 1)//self.stride+1
+            w = (W+self.padding*2-self.kernel_size + self.stride - 1)//self.stride+1
+        # TODO: backward
+        xx = x.reindex([N,C,h,w,self.kernel_size,self.kernel_size], [
+            "i0", # Nid
+            "i1", # Cid
+            f"i2*{self.stride}-{self.padding}+i4", # Hid
+            f"i3*{self.stride}-{self.padding}+i5", # Wid
+        ])
+        return xx.reduce(self.op, [4,5])
+
 def check(jt_model, torch_model, shape, near_data):
     if (near_data):
         assert shape[0] * shape[1] * shape[2] * shape[3] % 8 == 0
@@ -57,6 +87,20 @@ class TestArgPoolOp(unittest.TestCase):
         check(jt_model, torch_model, shape, False)
         for i in range(10):
             check(jt_model, torch_model, [1,1,300,300], True)

+    @unittest.skipIf(True, "TODO: cannot pass this test, fix me")
+    @unittest.skipIf(not jt.compiler.has_cuda, "No cuda found")
+    @jt.flag_scope(use_cuda=1)
+    def test_cuda_old_pool(self):
+        from torch.nn import AvgPool2d
+        jt_model = OldPool(3, 1, 1, op="mean")
+        torch_model = AvgPool2d(3, 1, 1)
+        shape = [64, 64, 300, 300]
+        check(jt_model, torch_model, shape, False)
+        shape = [32, 128, 157, 300]
+        check(jt_model, torch_model, shape, False)
+        for i in range(10):
+            check(jt_model, torch_model, [1,1,300,300], True)
+
     def test_cpu_(self):
         # x = jt.random([32, 128, 157, 300])

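The OldPool class added above snapshots the previous implementation as a test reference: x.reindex gathers every pooling window into a [N, C, h, w, kernel_size, kernel_size] tensor, and reduce collapses the two window axes. A rough numpy analogue for op="mean" (illustrative only; it assumes out-of-bounds reads act like zero padding, reindex's default overflow behavior):

import numpy as np

def old_pool_mean(x, k, stride, pad):
    # unfold each k-by-k window, then average over the window axes
    N, C, H, W = x.shape
    h = (H + 2*pad - k) // stride + 1
    w = (W + 2*pad - k) // stride + 1
    xp = np.pad(x, ((0, 0), (0, 0), (pad, pad), (pad, pad)))
    out = np.empty((N, C, h, w), x.dtype)
    for i in range(h):
        for j in range(w):
            win = xp[:, :, i*stride:i*stride+k, j*stride:j*stride+k]
            out[:, :, i, j] = win.mean(axis=(2, 3))
    return out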
View File

@@ -23,13 +23,14 @@ class TestMem(unittest.TestCase):
         one_g = np.ones((1024*1024*1024//4,), "float32")
         meminfo = jt.get_mem_info()
-        n = int(meminfo.total_cuda_ram // (1024**3) * 1.5)
+        n = int(meminfo.total_cuda_ram // (1024**3) * 0.6)
         for i in range(n):
             a = jt.array(one_g)
             b = a + 1
             b.sync()
             backups.append((a,b))
+        jt.sync_all(True)
         backups = []

View File

@@ -158,7 +158,7 @@ class TestParallelPass3(unittest.TestCase):
             src = f.read()
         for i in range(tdim):
             assert f"tnum{i}" in src
-        assert f"tnum{tdim}" not in src
+        assert f"tnum{tdim}" not in src, f"tnum{tdim}"
         src_has_atomic = "atomic_add" in src or "atomicAdd" in src
         assert has_atomic == src_has_atomic
         assert np.allclose(a.data.sum(rdim), b), (b.sum(), a.data.sum())
@@ -176,7 +176,11 @@ class TestParallelPass3(unittest.TestCase):
         check(3, 1, 1, [0,1], 1)
         check(3, 1, 1, [0,1], 0, [0,0,2])
         check(3, 2, 2, [2], 0)
-        check(3, 2, 1, [1], 0)
+        if jt.flags.use_cuda:
+            # loop is not merged so parallel depth 2
+            check(3, 2, 2, [1], 1)
+        else:
+            check(3, 2, 1, [1], 0)
         check(3, 2, 2, [1], 1, merge=0)
         check(4, 2, 2, [2,3], 0)
         check(4, 2, 2, [0,3], 1)

View File

@@ -3,12 +3,13 @@
 // This file is subject to the terms and conditions defined in
 // file 'LICENSE.txt', which is part of this source code package.
 // ***************************************************************
+#pragma once
 #include "common.h"

 namespace jittor {

 #ifdef JIT_cuda

-#define pow(T,a,b) ::powf(a,b)
+#define pow(T,a,b) ::pow(a,b)
 #define maximum(T,a,b) ::max(T(a), T(b))
 #define minimum(T,a,b) ::min(T(a), T(b))

 #else // JIT_cpu

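On the macro change above: under JIT_cuda, pow(T,a,b) previously expanded to ::powf, the single-precision routine, so float64 operands were silently computed at float32 precision; ::pow resolves by overload and keeps doubles double. A quick numpy sketch of the error the old expansion introduces (numpy standing in for the device math):

import numpy as np

a, b = 3.0000001, 2.5
full   = np.power(np.float64(a), np.float64(b))              # what ::pow computes
single = np.float64(np.power(np.float32(a), np.float32(b)))  # the old ::powf path
print(full, single, abs(full - single))  # nonzero gap: float64 bits lost to powf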
View File

@@ -3,6 +3,7 @@
 // This file is subject to the terms and conditions defined in
 // file 'LICENSE.txt', which is part of this source code package.
 // ***************************************************************
+#pragma once
 #include "common.h"

 namespace jittor {

View File

@@ -291,7 +291,7 @@ void ParallelPass::run() {
         // omp func call
         // we set num_threads in code
         new_func_call->push_back(
-            "#pragma omp parallel num_threads("+S(thread_num)+")",
+            "#pragma omp parallel num_threads(thread_num)",
             &new_func_call->before
         );
     } else {
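One plausible reading of this last change (the commit itself does not explain it): per the comment "we set num_threads in code", thread_num is a variable defined in the generated source, so naming it in the pragma keeps the emitted kernel text identical across thread counts, whereas the old S(thread_num) interpolation baked the count into the source. For a source-keyed JIT cache, that difference decides whether changing the count forces a recompile. A toy sketch of the effect (hypothetical cache standing in for the C++ JIT):

kernel_cache = {}

def get_kernel(src):
    # the cache key is the full source text, as in a source-keyed JIT
    return kernel_cache.setdefault(src, "kernel_%d" % len(kernel_cache))

get_kernel("#pragma omp parallel num_threads(4) ...")           # literal: one entry
get_kernel("#pragma omp parallel num_threads(8) ...")           # literal: a second entry
get_kernel("#pragma omp parallel num_threads(thread_num) ...")  # variable: one entry for any count
print(len(kernel_cache))  # 3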