mirror of https://github.com/Jittor/Jittor
jit_prepare for array op
This commit is contained in:
parent 6bc4a858b2
commit 8d64d98a35
@@ -7,7 +7,7 @@
 # This file is subject to the terms and conditions defined in
 # file 'LICENSE.txt', which is part of this source code package.
 # ***************************************************************
-__version__ = '1.2.0.0'
+__version__ = '1.2.0.1'
 from . import lock
 with lock.lock_scope():
     from . import compiler
@@ -342,62 +342,45 @@ class BatchNorm(Module):

 class BatchNorm1d(Module):
     def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=None, is_train=True, sync=True):
-        assert affine == None
         self.sync = sync
         self.num_features = num_features
         self.is_train = is_train
         self.eps = eps
         self.momentum = momentum
-        self.weight = init.constant((num_features,), "float32", 1.0)
-        self.bias = init.constant((num_features,), "float32", 0.0)
+        self.affine = affine
+        if affine:
+            self.weight = init.constant((num_features,), "float32", 1.0)
+            self.bias = init.constant((num_features,), "float32", 0.0)
         self.running_mean = init.constant((num_features,), "float32", 0.0).stop_grad()
         self.running_var = init.constant((num_features,), "float32", 1.0).stop_grad()

     def execute(self, x):
         if len(x.shape) == 3:
-            if self.is_train:
-                xmean = jt.mean(x, dims=[0, 2], keepdims=1)
-                x2mean = jt.mean(x*x, dims=[0, 2], keepdims=1)
+            dims = [0, 2]
+        else:
+            dims = [0]
+        if self.is_train:
+            xmean = jt.mean(x, dims=dims, keepdims=1)
+            x2mean = jt.mean(x*x, dims=dims, keepdims=1)

-                if self.sync and jt.in_mpi:
-                    xmean = xmean.mpi_all_reduce("mean")
-                    x2mean = x2mean.mpi_all_reduce("mean")
+            if self.sync and jt.in_mpi:
+                xmean = xmean.mpi_all_reduce("mean")
+                x2mean = x2mean.mpi_all_reduce("mean")

-                xvar = x2mean-xmean*xmean
-                norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)
-                self.running_mean.update(self.running_mean +
-                    (xmean.sum([0, 2])-self.running_mean)*self.momentum)
-                self.running_var.update(self.running_var +
-                    (xvar.sum([0, 2])-self.running_var)*self.momentum)
-            else:
-                running_mean = self.running_mean.broadcast(x, [0, 2])
-                running_var = self.running_var.broadcast(x, [0, 2])
-                norm_x = (x-running_mean)/jt.sqrt(running_var+self.eps)
-            w = self.weight.broadcast(x, [0, 2])
-            b = self.bias.broadcast(x, [0, 2])
-        else:
-            if self.is_train:
-                xmean = jt.mean(x, dims=[0], keepdims=1)
-                x2mean = jt.mean(x*x, dims=[0], keepdims=1)
-
-                if self.sync and jt.in_mpi:
-                    xmean = xmean.mpi_all_reduce("mean")
-                    x2mean = x2mean.mpi_all_reduce("mean")
-
-                xvar = x2mean-xmean*xmean
-                norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)
-                self.running_mean.update(self.running_mean +
-                    (xmean.sum([0])-self.running_mean)*self.momentum)
-                self.running_var.update(self.running_var +
-                    (xvar.sum([0])-self.running_var)*self.momentum)
-            else:
-                running_mean = self.running_mean.broadcast(x, [0])
-                running_var = self.running_var.broadcast(x, [0])
-                norm_x = (x-running_mean)/jt.sqrt(running_var+self.eps)
-            w = self.weight.broadcast(x, [0])
-            b = self.bias.broadcast(x, [0])
+            xvar = x2mean-xmean*xmean
+            norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)
+            self.running_mean.update(self.running_mean +
+                (xmean.sum(dims)-self.running_mean)*self.momentum)
+            self.running_var.update(self.running_var +
+                (xvar.sum(dims)-self.running_var)*self.momentum)
+        else:
+            running_mean = self.running_mean.broadcast(x, dims)
+            running_var = self.running_var.broadcast(x, dims)
+            norm_x = (x-running_mean)/jt.sqrt(running_var+self.eps)
+        if not self.affine:
+            return norm_x
+        w = self.weight.broadcast(x, dims)
+        b = self.bias.broadcast(x, dims)
         return norm_x * w + b

 class InstanceNorm2d(Module):
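Note: the BatchNorm1d change above collapses the separate (N, C, L) and (N, C) code paths into one dims-based path and gates weight/bias behind the affine flag. A minimal usage sketch, not part of this commit; the jittor.nn import path and the input shapes are assumptions for illustration:

import jittor as jt
from jittor import nn

bn = nn.BatchNorm1d(16)        # affine defaults to None, so only normalization is applied
x2 = jt.random((8, 16))        # (N, C) input: reduced over dims=[0]
x3 = jt.random((8, 16, 32))    # (N, C, L) input: reduced over dims=[0, 2]
y2, y3 = bn(x2), bn(x3)
print(y2.shape, y3.shape)      # output shapes match the inputs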
@@ -6,6 +6,7 @@
 if __name__ == "__main__":
     import unittest, os
+    unittest.TestLoader.sortTestMethodsUsing = None

     suffix = "__main__.py"
     assert __file__.endswith(suffix)
@@ -22,17 +23,19 @@ if __name__ == "__main__":
     suite = unittest.TestSuite()

     for _, test_file in enumerate(test_files):
+        test_name = test_file.split(".")[0]
+        tests = unittest.defaultTestLoader.loadTestsFromName(
+            "jittor.test."+test_name)
+
         if not test_file.startswith("test_"):
             continue
         if _ < skip_l or _ > skip_r:
             continue
-        test_name = test_file.split(".")[0]
         if test_only and test_name not in test_only:
             continue

         print("Add Test", _, test_name)
-        suite.addTest(unittest.defaultTestLoader.loadTestsFromName(
-            "jittor.test."+test_name))
+        suite.addTest(tests)

     result = unittest.TextTestRunner(verbosity=3).run(suite)
     if len(result.errors) or len(result.failures):
@@ -9,10 +9,18 @@ import numpy as np
 from jittor.nn import affine_grid,grid_sample


+skip_this_test = False
+
+try:
+    jt.dirty_fix_pytorch_runtime_error()
+    import torch.nn.functional as F
+    import torch
+except:
+    skip_this_test = True
+
+@unittest.skipIf(skip_this_test, "No Torch found")
 class TestAffineGrid(unittest.TestCase):
     def test_affine_grid_2d(self):
-        import torch.nn.functional as F
-        import torch
         N = 8
         C = 3
         H = 256
@@ -37,8 +45,6 @@ class TestAffineGrid(unittest.TestCase):


     def test_affine_grid_3d(self):
-        import torch.nn.functional as F
-        import torch
         N = 8
         C = 3
         D = 64
@@ -51,6 +51,8 @@ class TestConcatOp(unittest.TestCase):
         check([jt.array(np.array(range(5))).reshape((5,1)), jt.array(np.array(range(1))).reshape((1,1))])
         print('concat success...')

+
+    @unittest.skipIf(not jt.has_cuda, "No CUDA found")
     @jt.flag_scope(use_cuda = 1)
     def test_concat_perf(self):
         def check(dim, size, backward=False):
@@ -106,6 +108,7 @@ class TestConcatOp(unittest.TestCase):

     '''

+    @unittest.skipIf(not jt.has_cuda, "No CUDA found")
     @jt.flag_scope(use_cuda = 1)
     def test_concat2_perf(self):
         def check(dim, size, backward=False):
@@ -20,7 +20,7 @@ class TestContrib(unittest.TestCase):
                 arr2.append(jt.array(a))
             x = np.concatenate(tuple(arr1), dim)
             y = jt.contrib.concat(arr2, dim)
-            assert (x==y.data).all()
+            assert (x==y.data).all(), (x, y.data, arr1, arr2)
         check([2,3,4], 0, 2)
         check([2,3,4], 1, 3)
         check([2,3,4], 2, 4)
@@ -45,8 +45,8 @@ def test(shape, op1, op2):
     with jt.log_capture_scope(log_v=0, log_vprefix="fused_op.cc=100") as logs:
         d__ = d.data
     logs = find_log_with_re(logs,
-        "Jit (fused )?op key (not )?found: \[opkey0:array\]\[opkey1")
-    assert(len(logs)==1)
+        "Jit (fused )?op key (not )?found: \[opkey0:array\[T:float32")
+    assert(len(logs)==1), logs

     a_ = a.data
     b_ = b.data
@@ -114,7 +114,7 @@ class TestMklConvOp(unittest.TestCase):
         b = np.random.rand(o,i,h,w).astype(np.float32)
         da = np.random.rand(n,o,H,W).astype(np.float32)
         dx = jt.mkl_ops.mkl_conv_backward_x(b,da,H,W,1,1,1).data
-        dw = jt.mkl_ops.mkl_conv_backward_w(a,da,h,1,1,1).data
+        dw = jt.mkl_ops.mkl_conv_backward_w(a,da,h,w,1,1,1).data
         a_jt = jt.array(a)
         b_jt = jt.array(b)
@@ -160,7 +160,7 @@ class TestMklConvOp(unittest.TestCase):
         b = np.random.rand(h,w,i,o).astype(np.float32)
         da = np.random.rand(n,H,W,o).astype(np.float32)
         dx = jt.mkl_ops.mkl_conv_backward_x(b,da,H,W,1,1,1,xformat="acdb",wformat="hwio",yformat="acdb").data
-        dw = jt.mkl_ops.mkl_conv_backward_w(a,da,h,1,1,1,xformat="acdb",wformat="hwio",yformat="acdb").data
+        dw = jt.mkl_ops.mkl_conv_backward_w(a,da,h,w,1,1,1,xformat="acdb",wformat="hwio",yformat="acdb").data
         a_jt = jt.array(a)
         b_jt = jt.array(b)
@@ -26,10 +26,10 @@ class FakeMpiBatchNorm(nn.Module):
         self.is_train = is_train
         self.eps = eps
         self.momentum = momentum
-        self.running_mean = init.constant((num_features,), "float32", 0.0).stop_grad()
-        self.running_var = init.constant((num_features,), "float32", 1.0).stop_grad()
         self.weight = init.constant((num_features,), "float32", 1.0)
         self.bias = init.constant((num_features,), "float32", 0.0)
+        self.running_mean = init.constant((num_features,), "float32", 0.0).stop_grad()
+        self.running_var = init.constant((num_features,), "float32", 1.0).stop_grad()

     def execute(self, x, global_x):
         if self.is_train:
@@ -49,7 +49,7 @@ class TestSlice(unittest.TestCase):
         # print(slices)
         x = jt.random(shape)

-        with jt.log_capture_scope(log_vprefix="getitem=1000") as logs:
+        with jt.log_capture_scope(log_vprefix="getitem=999") as logs:
             a = x.getitem(slices)
             a.sync()
             b = x.data[slices]
@@ -74,6 +74,11 @@ ArrayOp::ArrayOp(ArrayArgs&& args) {
     std::memcpy(allocation.ptr, args.ptr, output->size);
 }

+void ArrayOp::jit_prepare() {
+    if (output->flags.get(NodeFlags::_force_fuse))
+        add_jit_define("T", output->dtype());
+}
+
 void ArrayOp::run() {
     #ifdef HAS_CUDA
     if (allocation.allocator == &cuda_dual_allocator) {
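Note: the new ArrayOp::jit_prepare above adds a "T" (dtype) entry to the JIT key only when the array op's output is flagged _force_fuse, which is what the updated regex in the fused-op test earlier in this commit (opkey0:array[T:float32) now expects. A rough way to watch for the fused key from Python, mirroring the test's log capture; this snippet is an illustration only, and whether the array op actually fuses is up to the scheduler:

import jittor as jt

with jt.log_capture_scope(log_v=0, log_vprefix="fused_op.cc=100") as logs:
    a = jt.array([1.0])              # small array op; may be force-fused into the following binary op (assumption)
    b = (jt.random((10,)) + a).data  # force execution of the (possibly fused) graph
print(len(logs), "fused_op log entries captured")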
@@ -28,6 +28,7 @@ struct ArrayOp : Op {

     const char* name() const override { return "array"; }
     void run() override;
+    void jit_prepare() override;
 };

 } // jittor
@@ -358,7 +358,7 @@ void GetitemOp::infer_shape() {
     this->i_to_o = i_to_o.to_nano_vector();
     this->o_shape = o_shape.to_nano_vector();

-    LOGvvvv << "\ni_to_vs:" << i_to_vs
+    LOGV(999) << "\ni_to_vs:" << i_to_vs
         << "\ni_to_o:" << i_to_o
         << "\no_shape:" << o_shape;
 }
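Note: the old slice test captured this shape log with log_vprefix="getitem=1000" (the LOGvvvv level); after the switch to LOGV(999) above, the matching test change earlier in this commit only needs verbosity 999. A rough sketch of capturing it from Python, illustration only and not part of the commit:

import jittor as jt

with jt.log_capture_scope(log_vprefix="getitem=999") as logs:
    x = jt.random((4, 5))
    y = x.getitem((slice(1, 3), slice(None)))
    y.sync()
print(len(logs), "getitem log entries captured")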
@@ -49,11 +49,11 @@ JIT_TEST(kernel_ir) {
     })", true
     );
     string code = R"(//
-// scope: main(1),
+// scope: <cmath>(1), aaa(1), main(1),

-// C macro code:"#include <cmath>"
+// C macro code:"#include <cmath>" lvalue:"<cmath>"
 #include <cmath>
-// C macro code:"#define aaa bbb"
+// C macro code:"#define aaa bbb" lvalue:"aaa" rvalue:" bbb"
 #define aaa bbb
 // C code:"using namespace std;" raw:"1"
 using namespace std;