Merge branch 'master' into doc

lzhengning 2021-02-20 17:51:14 +08:00
commit da9a4a0232
25 changed files with 245 additions and 99 deletions

View File

@ -43,8 +43,8 @@ static inline void set_shape(Var* x, const char* f, const string& format, int a,
shape[0], shape[1], shape[2], shape[3]));
}
CudnnConvBackwardWOp::CudnnConvBackwardWOp(Var* x, Var* dy, int kh, int kw, int stride, int padding, int dilation, int groups, string xformat, string wformat, string yformat)
: x(x), dy(dy), kh(kh), kw(kw), stride(stride), padding(padding), dilation(dilation), groups(groups),
CudnnConvBackwardWOp::CudnnConvBackwardWOp(Var* x, Var* dy, int kh, int kw, int strideh, int stridew, int paddingh, int paddingw, int dilationh, int dilationw, int groups, string xformat, string wformat, string yformat)
: x(x), dy(dy), kh(kh), kw(kw), strideh(strideh), stridew(stridew), paddingh(paddingh), paddingw(paddingw), dilationh(dilationh), dilationw(dilationw), groups(groups),
xformat(move(xformat)), wformat(move(wformat)), yformat(move(yformat)) {
flags.set(NodeFlags::_cuda, 1);
flags.set(NodeFlags::_cpu, 0);
@ -134,9 +134,9 @@ void CudnnConvBackwardWOp::jit_run() {
filterFormat_@WFORMAT, 4, dimW
));
int padA[] = {padding, padding};
int convstrideA[] = {stride, stride};
int dilationA[] = {dilation, dilation};
int padA[] = {paddingh, paddingw};
int convstrideA[] = {strideh, stridew};
int dilationA[] = {dilationh, dilationw};
// difference between
// CUDNN_CONVOLUTION and CUDNN_CROSS_CORRELATION
// is the kernel rc order
@ -187,7 +187,7 @@ void CudnnConvBackwardWOp::jit_run() {
jk.clear();
jk << dimX[0] << "," << dimX[1] << "," << dimX[2] << "," << dimX[3] << ",";
jk << dimW[0] << "," << dimW[1] << "," << dimW[2] << "," << dimW[3] << ",";
jk << padding << "," <<stride << "," << dilation << "," << groups << ".";
jk << paddingh << paddingw << "," <<strideh <<stridew << "," << dilationh << dilationw << "," << groups << ".";
auto iter = bwdw_algo_cache.find(jk.to_string());
if (iter!=bwdw_algo_cache.end()) algo = iter->second;
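The jit-key built above indexes the backward-w algorithm cache; after this change it encodes both the h and w values of padding, stride, and dilation, so two configurations that differ in only one spatial dimension no longer share a cached cuDNN algorithm. A rough Python rendering of the key format (illustration only, not Jittor code):

def algo_cache_key(dimX, dimW, ph, pw, sh, sw, dh, dw, groups):
    # mirrors: jk << dimX[...] << "," ... << paddingh << paddingw << "," << strideh << stridew << ...
    key = ",".join(map(str, dimX)) + ","
    key += ",".join(map(str, dimW)) + ","
    key += f"{ph}{pw},{sh}{sw},{dh}{dw},{groups}."
    return key

# algo_cache_key([8,3,224,224], [64,3,7,7], 3, 3, 2, 2, 1, 1, 1)
# -> '8,3,224,224,64,3,7,7,33,22,11,1.'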

View File

@ -14,10 +14,10 @@ namespace jittor {
struct CudnnConvBackwardWOp : Op {
Var* x, * dy, * dw;
int kh, kw, stride, padding, dilation, groups;
int kh, kw, strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups;
string xformat, wformat, yformat;
CudnnConvBackwardWOp(Var* x, Var* y, int kh, int kw, int stride, int padding, int dilation, int groups=1, string xformat="abcd", string wformat="oihw", string yformat="abcd");
CudnnConvBackwardWOp(Var* x, Var* y, int kh, int kw, int strideh, int stridew, int paddingh, int paddingw, int dilationh, int dilationw, int groups=1, string xformat="abcd", string wformat="oihw", string yformat="abcd");
const char* name() const override { return "cudnn_conv_backward_w"; }
void infer_shape() override;

View File

@ -45,8 +45,8 @@ static inline void set_shape(Var* x, const char* f, const string& format, int a,
shape[0], shape[1], shape[2], shape[3]));
}
CudnnConvBackwardXOp::CudnnConvBackwardXOp(Var* w, Var* dy, int height, int width, int stride, int padding, int dilation, int groups, string xformat, string wformat, string yformat)
: w(w), dy(dy), xh(height), xw(width), stride(stride), padding(padding), dilation(dilation), groups(groups),
CudnnConvBackwardXOp::CudnnConvBackwardXOp(Var* w, Var* dy, int height, int width, int strideh, int stridew, int paddingh, int paddingw, int dilationh, int dilationw, int groups, string xformat, string wformat, string yformat)
: w(w), dy(dy), xh(height), xw(width), strideh(strideh), stridew(stridew), paddingh(paddingh), paddingw(paddingw), dilationh(dilationh), dilationw(dilationw), groups(groups),
xformat(move(xformat)), wformat(move(wformat)), yformat(move(yformat)) {
flags.set(NodeFlags::_cuda, 1);
flags.set(NodeFlags::_cpu, 0);
@ -135,9 +135,9 @@ void CudnnConvBackwardXOp::jit_run() {
filterFormat_@WFORMAT, 4, dimW
));
int padA[] = {padding, padding};
int convstrideA[] = {stride, stride};
int dilationA[] = {dilation, dilation};
int padA[] = {paddingh, paddingw};
int convstrideA[] = {strideh, stridew};
int dilationA[] = {dilationh, dilationw};
// difference between
// CUDNN_CONVOLUTION and CUDNN_CROSS_CORRELATION
// is the kernel rc order
@ -188,7 +188,7 @@ void CudnnConvBackwardXOp::jit_run() {
jk.clear();
jk << dimX[0] << "," << dimX[1] << "," << dimX[2] << "," << dimX[3] << ",";
jk << dimW[0] << "," << dimW[1] << "," << dimW[2] << "," << dimW[3] << ",";
jk << padding << "," <<stride << "," << dilation << "," << groups << ".";
jk << paddingh << paddingw << "," <<strideh <<stridew << "," << dilationh << dilationw << "," << groups << ".";
auto iter = bwdx_algo_cache.find(jk.to_string());
if (iter!=bwdx_algo_cache.end()) algo = iter->second;

View File

@ -14,10 +14,10 @@ namespace jittor {
struct CudnnConvBackwardXOp : Op {
Var* w, * dy, * dx;
int xh, xw, stride, padding, dilation, groups;
int xh, xw, strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups;
string xformat, wformat, yformat;
CudnnConvBackwardXOp(Var* w, Var* y, int height, int width, int stride, int padding, int dilation, int groups=1, string xformat="abcd", string wformat="oihw", string yformat="abcd");
CudnnConvBackwardXOp(Var* w, Var* y, int height, int width, int strideh, int stridew, int paddingh, int paddingw, int dilationh, int dilationw, int groups=1, string xformat="abcd", string wformat="oihw", string yformat="abcd");
const char* name() const override { return "cudnn_conv_backward_x"; }
void infer_shape() override;

View File

@ -42,8 +42,8 @@ static inline void set_shape(Var* x, const char* f, const string& format, int a,
shape[0], shape[1], shape[2], shape[3]));
}
CudnnConvOp::CudnnConvOp(Var* x, Var* w, int stride, int padding, int dilation, int groups, string xformat, string wformat, string yformat)
: x(x), w(w), stride(stride), padding(padding), dilation(dilation), groups(groups),
CudnnConvOp::CudnnConvOp(Var* x, Var* w, int strideh, int stridew, int paddingh, int paddingw, int dilationh, int dilationw, int groups, string xformat, string wformat, string yformat)
: x(x), w(w), strideh(strideh), stridew(stridew), paddingh(paddingh), paddingw(paddingw), dilationh(dilationh), dilationw(dilationw), groups(groups),
xformat(move(xformat)), wformat(move(wformat)), yformat(move(yformat)) {
flags.set(NodeFlags::_cuda, 1);
flags.set(NodeFlags::_cpu, 0);
@ -60,8 +60,8 @@ void CudnnConvOp::infer_shape() {
get_shape(w, "oihw", wformat, wco, wci, wh, ww);
ASSERTop(wci * groups,==,xc);
yn = xn, yc = wco;
yh = (xh+padding*2-wh*dilation+dilation-1)/stride+1;
yw = (xw+padding*2-ww*dilation+dilation-1)/stride+1;
yh = (xh+paddingh*2-wh*dilationh+dilationh-1)/strideh+1;
yw = (xw+paddingw*2-ww*dilationw+dilationw-1)/stridew+1;
set_shape(y, "abcd", yformat, yn, yc, yh, yw);
}
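infer_shape above applies the standard convolution output-size rule independently per axis, now with separate h and w parameters. A small self-contained check of the rule (illustrative, not part of the commit):

def conv_out(size, k, stride, padding, dilation):
    # same arithmetic as infer_shape: (size + 2*padding - k*dilation + dilation - 1) // stride + 1
    return (size + padding*2 - k*dilation + dilation - 1) // stride + 1

# 224x224 input, 7x7 kernel, asymmetric stride (2,1) and padding (3,1):
assert conv_out(224, 7, 2, 3, 1) == 112   # height
assert conv_out(224, 7, 1, 1, 1) == 220   # width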
@ -135,9 +135,9 @@ void CudnnConvOp::jit_run() {
filterFormat_@WFORMAT, 4, dimW
));
int padA[] = {padding, padding};
int convstrideA[] = {stride, stride};
int dilationA[] = {dilation, dilation};
int padA[] = {paddingh, paddingw};
int convstrideA[] = {strideh, stridew};
int dilationA[] = {dilationh, dilationw};
// difference between
// CUDNN_CONVOLUTION and CUDNN_CROSS_CORRELATION
// is the kernel rc order
@ -190,7 +190,7 @@ void CudnnConvOp::jit_run() {
jk.clear();
jk << dimX[0] << "," << dimX[1] << "," << dimX[2] << "," << dimX[3] << ",";
jk << dimW[0] << "," << dimW[1] << "," << dimW[2] << "," << dimW[3] << ",";
jk << padding << "," <<stride << "," << dilation << "," << groups << ".";
jk << paddingh << paddingw << "," <<strideh <<stridew << "," << dilationh << dilationw << "," << groups << ".";
auto iter = fwd_algo_cache.find(jk.to_string());
if (iter!=fwd_algo_cache.end()) algo = iter->second;

View File

@ -11,10 +11,10 @@ namespace jittor {
struct CudnnConvOp : Op {
Var* x, * w, * y;
int stride, padding, dilation, groups;
int strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups;
string xformat, wformat, yformat;
/* CudnnConvOp: xformat abcd represents nchw */
CudnnConvOp(Var* x, Var* w, int stride, int padding, int dilation, int groups=1, string xformat="abcd", string wformat="oihw", string yformat="");
CudnnConvOp(Var* x, Var* w, int strideh, int stridew, int paddingh, int paddingw, int dilationh=1, int dilationw=1, int groups=1, string xformat="abcd", string wformat="oihw", string yformat="");
const char* name() const override { return "cudnn_conv"; }
void infer_shape() override;

View File

@ -46,8 +46,8 @@ static inline void set_shape(Var* x, const char* f, const string& format, int a,
shape[0], shape[1], shape[2], shape[3]));
}
MklConvBackwardWOp::MklConvBackwardWOp(Var* x, Var* dy, int kh, int kw, int stride, int padding, int dilation, int groups, string xformat, string wformat, string yformat)
: x(x), dy(dy), kh(kh), kw(kw), stride(stride), padding(padding), dilation(dilation), groups(groups),
MklConvBackwardWOp::MklConvBackwardWOp(Var* x, Var* dy, int kh, int kw, int strideh, int stridew, int paddingh, int paddingw, int dilationh, int dilationw, int groups, string xformat, string wformat, string yformat)
: x(x), dy(dy), kh(kh), kw(kw), strideh(strideh), stridew(stridew), paddingh(paddingh), paddingw(paddingw), dilationh(dilationh), dilationw(dilationw), groups(groups),
xformat(move(xformat)), wformat(move(wformat)), yformat(move(yformat)) {
dw = create_output(nullptr, dtype_infer(dy->ns, x->ns));
}
@ -119,10 +119,10 @@ void MklConvBackwardWOp::jit_run() {
memory::dims conv_weights_tz = groups>1
? memory::dims{groups, ch_out/groups, ch_in/groups, kh, kw}
: memory::dims{ch_out, ch_in, kh, kw};
memory::dims conv_dst_tz = {batch, ch_out, (height+padding*2-kh*dilation+dilation-1)/stride+1, (width+padding*2-kw*dilation+dilation-1)/stride+1};
memory::dims conv_strides = {stride, stride};
memory::dims conv_padding = {padding, padding};
memory::dims conv_dilation = {dilation-1, dilation-1};
memory::dims conv_dst_tz = {batch, ch_out, (height+paddingh*2-kh*dilationh+dilationh-1)/strideh+1, (width+paddingw*2-kw*dilationw+dilationw-1)/stridew+1};
memory::dims conv_strides = {strideh, stridew};
memory::dims conv_padding = {paddingh, paddingw};
memory::dims conv_dilation = {dilationh-1, dilationw-1};
if (groups>1) ASSERT(tag::@WFORMAT == tag::oihw);

View File

@ -14,10 +14,10 @@ namespace jittor {
struct MklConvBackwardWOp : Op {
Var* x, * dy, * dw;
int kh, kw, stride, padding, dilation, groups;
int kh, kw, strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups;
string xformat, wformat, yformat;
MklConvBackwardWOp(Var* x, Var* y, int kh, int kw, int stride, int padding, int dilation, int groups=1, string xformat="abcd", string wformat="oihw", string yformat="abcd");
MklConvBackwardWOp(Var* x, Var* y, int kh, int kw, int strideh, int stridew, int paddingh, int paddingw, int dilationh, int dilationw, int groups=1, string xformat="abcd", string wformat="oihw", string yformat="abcd");
const char* name() const override { return "mkl_conv_backward_w"; }
void infer_shape() override;

View File

@ -46,8 +46,8 @@ static inline void set_shape(Var* x, const char* f, const string& format, int a,
shape[0], shape[1], shape[2], shape[3]));
}
MklConvBackwardXOp::MklConvBackwardXOp(Var* w, Var* dy, int height, int width, int stride, int padding, int dilation, int groups, string xformat, string wformat, string yformat)
: w(w), dy(dy), xh(height), xw(width), stride(stride), padding(padding), dilation(dilation), groups(groups),
MklConvBackwardXOp::MklConvBackwardXOp(Var* w, Var* dy, int height, int width, int strideh, int stridew, int paddingh, int paddingw, int dilationh, int dilationw, int groups, string xformat, string wformat, string yformat)
: w(w), dy(dy), xh(height), xw(width), strideh(strideh), stridew(stridew), paddingh(paddingh), paddingw(paddingw), dilationh(dilationh), dilationw(dilationw), groups(groups),
xformat(move(xformat)), wformat(move(wformat)), yformat(move(yformat)) {
dx = create_output(nullptr, dtype_infer(dy->ns, w->ns));
}
@ -97,7 +97,8 @@ void MklConvBackwardXOp::jit_run() {
int height = dx->shape[findc("@XFORMAT",'c')];
int width = dx->shape[findc("@XFORMAT",'d')];
int ch_out = w->shape[findc("@WFORMAT",'o')];
int kernel_size = w->shape[findc("@WFORMAT",'h')];
int kernel_sizeh = w->shape[findc("@WFORMAT",'h')];
int kernel_sizew = w->shape[findc("@WFORMAT",'w')];
auto* __restrict__ conv_weights = w->ptr<Twd>();
auto* __restrict__ net_diff_dst = dy->ptr<Tyd>();
@ -114,12 +115,12 @@ void MklConvBackwardXOp::jit_run() {
memory::dims conv_src_tz = {batch, ch_in, height, width};
memory::dims conv_weights_tz = groups>1
? memory::dims{groups, ch_out/groups, ch_in/groups, kernel_size, kernel_size}
: memory::dims{ch_out, ch_in, kernel_size, kernel_size};
memory::dims conv_dst_tz = {batch, ch_out, (height+padding*2-kernel_size*dilation+dilation-1)/stride+1, (width+padding*2-kernel_size*dilation+dilation-1)/stride+1};
memory::dims conv_strides = {stride, stride};
memory::dims conv_padding = {padding, padding};
memory::dims conv_dilation = {dilation-1, dilation-1};
? memory::dims{groups, ch_out/groups, ch_in/groups, kernel_sizeh, kernel_sizew}
: memory::dims{ch_out, ch_in, kernel_sizeh, kernel_sizew};
memory::dims conv_dst_tz = {batch, ch_out, (height+paddingh*2-kernel_sizeh*dilationh+dilationh-1)/strideh+1, (width+paddingw*2-kernel_sizew*dilationw+dilationw-1)/stridew+1};
memory::dims conv_strides = {strideh, stridew};
memory::dims conv_padding = {paddingh, paddingw};
memory::dims conv_dilation = {dilationh-1, dilationw-1};
if (groups>1) ASSERT(tag::@WFORMAT == tag::oihw);

View File

@ -14,10 +14,10 @@ namespace jittor {
struct MklConvBackwardXOp : Op {
Var* w, * dy, * dx;
int xh, xw, stride, padding, dilation, groups;
int xh, xw, strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups;
string xformat, wformat, yformat;
MklConvBackwardXOp(Var* w, Var* y, int height, int width, int stride, int padding, int dilation, int groups=1, string xformat="abcd", string wformat="oihw", string yformat="abcd");
MklConvBackwardXOp(Var* w, Var* y, int height, int width, int strideh, int stridew, int paddingh, int paddingw, int dilationh, int dilationw, int groups=1, string xformat="abcd", string wformat="oihw", string yformat="abcd");
const char* name() const override { return "mkl_conv_backward_x"; }
void infer_shape() override;

View File

@ -45,8 +45,8 @@ static inline void set_shape(Var* x, const char* f, const string& format, int a,
shape[0], shape[1], shape[2], shape[3]));
}
MklConvOp::MklConvOp(Var* x, Var* w, int stride, int padding, int dilation, int groups, string xformat, string wformat, string yformat)
: x(x), w(w), stride(stride), padding(padding), dilation(dilation), groups(groups),
MklConvOp::MklConvOp(Var* x, Var* w, int strideh, int stridew, int paddingh, int paddingw, int dilationh, int dilationw, int groups, string xformat, string wformat, string yformat)
: x(x), w(w), strideh(strideh), stridew(stridew), paddingh(paddingh), paddingw(paddingw), dilationh(dilationh), dilationw(dilationw), groups(groups),
xformat(move(xformat)), wformat(move(wformat)), yformat(move(yformat)) {
y = create_output(nullptr, dtype_infer(x->ns, w->ns));
if (!this->yformat.size())
@ -61,8 +61,8 @@ void MklConvOp::infer_shape() {
get_shape(w, "oihw", wformat, wco, wci, wh, ww);
ASSERTop(wci * groups,==,xc);
yn = xn, yc = wco;
yh = (xh+padding*2-wh*dilation+dilation-1)/stride+1;
yw = (xw+padding*2-ww*dilation+dilation-1)/stride+1;
yh = (xh+paddingh*2-wh*dilationh+dilationh-1)/strideh+1;
yw = (xw+paddingw*2-ww*dilationw+dilationw-1)/stridew+1;
set_shape(y, "abcd", yformat, yn, yc, yh, yw);
}
@ -104,7 +104,7 @@ void MklConvOp::jit_run() {
using dt = memory::data_type;
if (tag::@XFORMAT==tag::nhwc && tag::@YFORMAT==tag::nhwc && tag::@WFORMAT==tag::hwio
&& stride==1 && padding==0 && dilation==1 && ws[0]==1 && ws[1]==1
&& strideh==1 && stridew==1 && paddingh==0 && paddingw==0 && dilationh==1 && dilationw==1 && ws[0]==1 && ws[1]==1
&& dt::@Tx==dt::f32 && dt::@Ty==dt::f32 && dt::@Tw==dt::f32) {
auto m = xs[0]*xs[1]*xs[2];
auto n = ws[3];
@ -133,9 +133,9 @@ void MklConvOp::jit_run() {
? memory::dims{groups, wco/groups, wci, wh, ww}
: memory::dims{wco, wci, wh, ww};
memory::dims conv1_dst_tz = {yn, yc, yh, yw};
memory::dims conv1_strides = { stride, stride };
memory::dims conv1_padding = { padding, padding };
memory::dims conv1_dilation = { dilation-1, dilation-1 };
memory::dims conv1_strides = { strideh, stridew };
memory::dims conv1_padding = { paddingh, paddingw };
memory::dims conv1_dilation = { dilationh-1, dilationw-1 };
if (groups>1) ASSERT(tag::@WFORMAT == tag::oihw);

View File

@ -14,10 +14,10 @@ namespace jittor {
struct MklConvOp : Op {
Var* x, * w, * y;
int stride, padding, dilation, groups;
int strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups;
string xformat, wformat, yformat;
/* MklConvOp: xformat abcd represents nchw */
MklConvOp(Var* x, Var* w, int stride, int padding, int dilation=1, int groups=1, string xformat="abcd", string wformat="oihw", string yformat="");
MklConvOp(Var* x, Var* w, int strideh, int stridew, int paddingh, int paddingw, int dilationh=1, int dilationw=1, int groups=1, string xformat="abcd", string wformat="oihw", string yformat="");
const char* name() const override { return "mkl_conv"; }
void infer_shape() override;

View File

@ -8,7 +8,7 @@
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.2.2.30'
__version__ = '1.2.2.32'
from . import lock
with lock.lock_scope():
ori_int = int

View File

@ -28,6 +28,7 @@ class DepthwiseConv(Function):
ow = (W+self.padding[1]*2-Kw*self.dilation[1]+self.dilation[1]-1)//self.stride[1]+1
filter_height, filter_width = Kh, Kw
self.Khw = Kh, Kw
assert oh>0 and ow>0
output = jt.code(
[N, C, oh, ow],
x.dtype,
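The new assertion rejects configurations where the dilated kernel extent exceeds the padded input, which would otherwise request a zero- or negative-sized output from jt.code. A quick illustration (not from the commit):

# a 3x3 input with a 7x7 kernel and no padding yields a non-positive extent
H, Kh, stride, padding, dilation = 3, 7, 1, 0, 1
oh = (H + padding*2 - Kh*dilation + dilation - 1) // stride + 1
# oh == -3, so the assert fires instead of building an invalid output shape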

View File

@ -189,7 +189,7 @@ def Resnet101(pretrained=False, **kwargs):
Example::
model = jittor.models.Resnet101()
x = jittor.random([10,224,224,3])
x = jittor.random([10,3,224,224])
y = model(x) # [10, 1000]
"""

View File

@ -554,6 +554,7 @@ class Conv(Module):
assert C==self.in_channels
oh = (H+self.padding[0]*2-Kh*self.dilation[0]+self.dilation[0]-1)//self.stride[0]+1
ow = (W+self.padding[1]*2-Kw*self.dilation[1]+self.dilation[1]-1)//self.stride[1]+1
assert oh>0 and ow>0
xx = x.reindex([N,self.out_channels,C,oh,ow,Kh,Kw], [
'i0', # Nid
'i2', # Cid
@ -576,6 +577,7 @@ class Conv(Module):
oc = self.out_channels
oh = (H+self.padding[0]*2-Kh*self.dilation[0]+self.dilation[0]-1)//self.stride[0]+1
ow = (W+self.padding[1]*2-Kw*self.dilation[1]+self.dilation[1]-1)//self.stride[1]+1
assert oh>0 and ow>0
xx = x.reindex([N,G,oc//G,CpG,oh,ow,Kh,Kw], [
'i0', # Nid
f'i1*{CpG}+i3', # Gid

View File

@ -55,7 +55,11 @@ RUN pip3 install torch torchvision
with open("/tmp/perf_dockerfile", 'w') as f:
f.write(dockerfile_src)
assert os.system("sudo nvidia-smi -lgc 1500") == 0
assert os.system(f"sudo docker build --tag jittor/jittor-perf{suffix} -f /tmp/perf_dockerfile .") == 0
# if the docker image is not built
if os.system(f"sudo docker image inspect jittor/jittor-perf{suffix}"):
assert os.system(f"sudo docker build --tag jittor/jittor-perf{suffix} -f /tmp/perf_dockerfile .") == 0
# run once for compile source
jt_fps = test_main("jittor", "resnet50", 1)
@ -180,7 +184,8 @@ def test(name, model_name, bs):
loss.backward()
opt.step()
else:
x.sync()
if name == "jittor":
x.sync()
sync()
for i in time_iter():
iter()

View File

@ -0,0 +1,127 @@
# ***************************************************************
# Copyright (c) 2021 Jittor. All Rights Reserved.
# Maintainers:
# Guowei Yang <471184555@qq.com>
# Meng-Hao Guo <guomenghao1997@gmail.com>
# Dun Liang <randonlang@gmail.com>.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import jittor as jt
from jittor import nn, Module
from jittor.models import inception
import numpy as np
import sys, os
import random
import math
import unittest
from jittor.test.test_reorder_tuner import simple_parser
from jittor.test.test_log import find_log_with_re
from jittor.dataset.mnist import MNIST
import jittor.transform as trans
import time
skip_this_test = False
class MnistNet(Module):
    def __init__(self):
        self.model = inception.inception_v3()
        self.layer = nn.Linear(1000,10)

    def execute(self, x):
        x = self.model(x)
        x = self.layer(x)
        return x

@unittest.skipIf(skip_this_test, "skip_this_test")
class TestInception(unittest.TestCase):
    @classmethod
    def setUpClass(self):
        # hyper-parameters
        self.batch_size = 32
        self.weight_decay = 0.0001
        self.momentum = 0.9
        self.learning_rate = 0.1
        # mnist dataset
        self.train_loader = MNIST(train=True, transform=trans.Resize(300)) \
            .set_attrs(batch_size=self.batch_size, shuffle=True)
        self.train_loader.num_workers = 4
        self.train_loader.total_len = self.batch_size * 300

    # setup random seed
    def setup_seed(self, seed):
        np.random.seed(seed)
        random.seed(seed)
        jt.seed(seed)

    @unittest.skipIf(not jt.has_cuda, "Cuda not found")
    @jt.flag_scope(use_cuda=1, use_stat_allocator=1)
    def test_inception(self):
        self.setup_seed(1)
        loss_list=[]
        acc_list=[]
        mnist_net = MnistNet()
        global prev
        prev = time.time()
        SGD = nn.SGD(mnist_net.parameters(), self.learning_rate, self.momentum, self.weight_decay)

        for batch_idx, (data, target) in enumerate(self.train_loader):
            # train step
            with jt.log_capture_scope(
                log_silent=1,
                log_v=1, log_vprefix="op.cc=100,exe=10",
            ) as logs:
                # breakpoint()
                output = mnist_net(data)
                loss = nn.cross_entropy_loss(output, target)
                SGD.step(loss)
                def callback(batch_idx, loss, output, target):
                    # print train info
                    global prev
                    pred = np.argmax(output, axis=1)
                    acc = np.mean(target==pred)
                    loss_list.append(loss[0])
                    acc_list.append(acc)
                    print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAcc: {:.6f} \tTime:{:.3f}'
                        .format(0, batch_idx, 300,1. * batch_idx / 6.0, loss[0], acc, time.time()-prev))
                    # prev = time.time()
                jt.fetch(batch_idx, loss, output, target, callback)

            log_conv = find_log_with_re(logs,
                "Jit op key (not )?found: ((mkl)|(cudnn))_conv.*")
            log_matmul = find_log_with_re(logs,
                "Jit op key (not )?found: ((mkl)|(cublas))_matmul.*")
            if batch_idx > 2:
                assert len(log_conv)==283 and len(log_matmul)==6, (len(log_conv), len(log_matmul))

            mem_used = jt.flags.stat_allocator_total_alloc_byte \
                -jt.flags.stat_allocator_total_free_byte
            # assert mem_used < 4e9, mem_used
            # TODO: why bigger?
            assert mem_used < 15.6e9, mem_used
            # example log:
            # Train Epoch: 0 [0/100 (0%)] Loss: 2.352903 Acc: 0.110000
            # Train Epoch: 0 [1/100 (1%)] Loss: 2.840830 Acc: 0.080000
            # Train Epoch: 0 [2/100 (2%)] Loss: 3.473594 Acc: 0.100000
            # Train Epoch: 0 [3/100 (3%)] Loss: 3.131615 Acc: 0.200000
            # Train Epoch: 0 [4/100 (4%)] Loss: 2.524094 Acc: 0.230000
            # Train Epoch: 0 [5/100 (5%)] Loss: 7.780025 Acc: 0.080000
            # Train Epoch: 0 [6/100 (6%)] Loss: 3.890721 Acc: 0.160000
            # Train Epoch: 0 [7/100 (7%)] Loss: 6.370137 Acc: 0.140000
            # Train Epoch: 0 [8/100 (8%)] Loss: 11.390827 Acc: 0.150000
            # Train Epoch: 0 [9/100 (9%)] Loss: 21.598564 Acc: 0.080000
            # Train Epoch: 0 [10/100 (10%)] Loss: 23.369165 Acc: 0.130000
            # Train Epoch: 0 [20/100 (20%)] Loss: 4.804510 Acc: 0.100000
            # Train Epoch: 0 [30/100 (30%)] Loss: 3.393924 Acc: 0.110000
            # Train Epoch: 0 [40/100 (40%)] Loss: 2.286762 Acc: 0.130000
            # Train Epoch: 0 [50/100 (50%)] Loss: 2.055014 Acc: 0.290000
            assert jt.core.number_of_lived_vars() < 50000, jt.core.number_of_lived_vars()

        jt.sync_all(True)
        assert np.mean(loss_list[-20:])<1
        assert np.mean(acc_list[-20:])>0.5

if __name__ == "__main__":
    unittest.main()

View File

@ -57,7 +57,7 @@ oihw = [4, 3, 5, 5]
import jittor as jt
x = jt.random(nchw)
w = jt.random(oihw)
jt.mkl_ops.mkl_conv(x, w, 1, 2).sync()
jt.mkl_ops.mkl_conv(x, w, 1, 1, 2, 2).sync()
jt.dirty_fix_pytorch_runtime_error()
@ -88,7 +88,7 @@ m(torch.rand(*nchw))
import jittor as jt
x = jt.random(nchw)
w = jt.random(oihw)
jt.mkl_ops.mkl_conv(x, w, 1, 2).sync()
jt.mkl_ops.mkl_conv(x, w, 1, 1, 2, 2).sync()
"""

View File

@ -56,7 +56,7 @@ class TestMklConvOp(unittest.TestCase):
def test_forward(self):
a = np.random.rand(1,3,224,224).astype(np.float32)
b = np.random.rand(64,3,7,7).astype(np.float32)
c = jt.mkl_ops.mkl_conv(a,b,2,3).data
c = jt.mkl_ops.mkl_conv(a,b,2,2,3,3).data
a_jt = jt.array(a)
b_jt = jt.array(b)
@ -81,7 +81,7 @@ class TestMklConvOp(unittest.TestCase):
def check(xshape, wshape, stride, pad):
a = np.random.rand(*xshape).astype(np.float32)
b = np.random.rand(*wshape).astype(np.float32)
c = jt.mkl_ops.mkl_conv(a,b,stride,pad,1,xformat="acdb",wformat="hwio").data
c = jt.mkl_ops.mkl_conv(a,b,stride,stride,pad,pad,1,1,xformat="acdb",wformat="hwio").data
a_jt = jt.array(a)
b_jt = jt.array(b)
@ -114,8 +114,8 @@ class TestMklConvOp(unittest.TestCase):
a = np.random.rand(n,c,H,W).astype(np.float32)
b = np.random.rand(o,i,h,w).astype(np.float32)
da = np.random.rand(n,o,H,W).astype(np.float32)
dx = jt.mkl_ops.mkl_conv_backward_x(b,da,H,W,1,1,1).data
dw = jt.mkl_ops.mkl_conv_backward_w(a,da,h,w,1,1,1).data
dx = jt.mkl_ops.mkl_conv_backward_x(b,da,H,W,1,1,1,1,1,1).data
dw = jt.mkl_ops.mkl_conv_backward_w(a,da,h,w,1,1,1,1,1,1).data
a_jt = jt.array(a)
b_jt = jt.array(b)
@ -160,8 +160,8 @@ class TestMklConvOp(unittest.TestCase):
a = np.random.rand(n,H,W,c).astype(np.float32)
b = np.random.rand(h,w,i,o).astype(np.float32)
da = np.random.rand(n,H,W,o).astype(np.float32)
dx = jt.mkl_ops.mkl_conv_backward_x(b,da,H,W,1,1,1,xformat="acdb",wformat="hwio",yformat="acdb").data
dw = jt.mkl_ops.mkl_conv_backward_w(a,da,h,w,1,1,1,xformat="acdb",wformat="hwio",yformat="acdb").data
dx = jt.mkl_ops.mkl_conv_backward_x(b,da,H,W,1,1,1,1,1,1,xformat="acdb",wformat="hwio",yformat="acdb").data
dw = jt.mkl_ops.mkl_conv_backward_w(a,da,h,w,1,1,1,1,1,1,xformat="acdb",wformat="hwio",yformat="acdb").data
a_jt = jt.array(a)
b_jt = jt.array(b)
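The backward ops gained the same expanded positional order, so the six trailing 1s in the calls above are strideh, stridew, paddingh, paddingw, dilationh, dilationw (groups still defaults to 1). A compact self-contained version of the same call pattern, assuming an MKL-enabled build:

import numpy as np
import jittor as jt
a  = np.random.rand(1, 3, 16, 16).astype(np.float32)   # x,  nchw
b  = np.random.rand(4, 3, 3, 3).astype(np.float32)     # w,  oihw
da = np.random.rand(1, 4, 16, 16).astype(np.float32)   # dy (stride 1, padding 1 keeps 16x16)
dx = jt.mkl_ops.mkl_conv_backward_x(b, da, 16, 16, 1, 1, 1, 1, 1, 1).data
dw = jt.mkl_ops.mkl_conv_backward_w(a, da, 3, 3, 1, 1, 1, 1, 1, 1).data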

View File

@ -148,6 +148,27 @@ class TestSetitem(unittest.TestCase):
a[1,1] = -2
assert (a[0].numpy() == [-1,2]).all(), a[0].numpy()
assert (a[1].numpy() == [3,-2]).all(), a[1].numpy()
# def test_scatter(self):
# src = jt.arange(1, 11).reshape((2, 5))
# index = jt.array([[0, 1, 2, 0]])
# print(index.shape, src.shape)
# x = jt.zeros((3, 5), dtype=src.dtype).scatter_(0, index, src)
# print(x)
# def scatter(x, dim, index, src, reduce='void'):
# shape = index.shape
# indexes = [ jt.index(shape, i) for i in range(dim) ]
# indexes.append(index)
# print(indexes)
# return x.setitem(tuple(indexes), src, reduce)
# def scatter_(x, dim, index, src, reduce='void'):
# return x.assign(x.scatter(dim, index, src, reduce))
# jt.Var.scatter = scatter
# jt.Var.scatter_ = scatter_
if __name__ == "__main__":
unittest.main()
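The disabled block above prototypes scatter on top of setitem with an index tuple built from jt.index. A plain-NumPy sketch of the semantics it targets when scattering along dim 0 (not part of the commit):

import numpy as np
src = np.arange(1, 11).reshape(2, 5)
index = np.array([[0, 1, 2, 0]])
out = np.zeros((3, 5), dtype=src.dtype)
for i in range(index.shape[0]):
    for j in range(index.shape[1]):
        out[index[i, j], j] = src[i, j]   # out[index[i][j], j] = src[i][j]
# out == [[1, 0, 0, 4, 0],
#         [0, 2, 0, 0, 0],
#         [0, 0, 3, 0, 0]]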

View File

@ -37,8 +37,7 @@ JitKey::JitKey() {
JitKey::~JitKey() {
auto buffer_end_page = get_buffer_end_page((size_t)&buffer[buffer_size-1]);
LOGvv << "un-protect page" << (void*)buffer_end_page;
ASSERT(0==
mprotect((void*)buffer_end_page, page_size, PROT_READ|PROT_WRITE|PROT_EXEC));
mprotect((void*)buffer_end_page, page_size, PROT_READ|PROT_WRITE|PROT_EXEC);
protected_page = 0;
}

View File

@ -331,20 +331,11 @@ void ConvTuner::forwardTune(FusedOp* fop) {
int padding_w = -rw[1]->as_int();
int dilation_w = rw[2]->as_int();
if (dilation_h < 1 || dilation_w < 1) continue;
if (stride_h!=stride_w || padding_h!=padding_w || dilation_h!=dilation_w) {
LOGw << "cannot relay different stride and padding between h and w"
<< stride_h << padding_h << dilation_h << stride_w << padding_w << dilation_w
<< "This may cause low performance. Please send us issue if you need it.";
continue;
}
LOGvvvv << "get stride padding and dilation" << stride_h << padding_h << dilation_h;
if (xformat == "bacd") {
LOGvvvv << "mkl not support bacd, continue";
continue;
}
int stride = stride_h;
int padding = padding_h;
int dilation = dilation_h;
Var* x = x_id == 0 ? xoi.op->output(0) : xoi.op->input(0);
Var* w = w_id == 0 ? woi.op->output(0) : woi.op->input(0);
Var* y = y_id == 0 ? yoi.op->output(0) : yoi.op->input(0);
@ -371,9 +362,9 @@ void ConvTuner::forwardTune(FusedOp* fop) {
if (!has_op(relay_conv_name))
continue;
auto make_conv = get_op_info(relay_conv_name)
.get_constructor<VarPtr, Var*, Var*, int, int, int, int, string, string, string>();
LOGvvvv << x << w << stride << padding << dilation << groups << xformat << wformat << yformat;
rvar = make_conv(x, w, stride, padding, dilation, groups, xformat, wformat, yformat);
.get_constructor<VarPtr, Var*, Var*, int, int, int, int, int, int, int, string, string, string>();
LOGvvvv << x << w << stride_h << stride_w << padding_h << padding_w << dilation_h << dilation_w << groups << xformat << wformat << yformat;
rvar = make_conv(x, w, stride_h, stride_w, padding_h, padding_w, dilation_h, dilation_w, groups, xformat, wformat, yformat);
} else
if (x_id == 0) {
relay_conv_name = fop->flags.get(NodeFlags::_cpu) ?
@ -383,9 +374,9 @@ void ConvTuner::forwardTune(FusedOp* fop) {
auto height = x->shape[xformat.find("c")];
auto width = x->shape[xformat.find("d")];
auto make_conv_x = get_op_info(relay_conv_name)
.get_constructor<VarPtr, Var*, Var*, int , int, int, int, int, int, string, string, string>();
LOGvvvv << w << y << height << width << stride << padding << dilation << groups << xformat << wformat << yformat;
rvar = make_conv_x(w, y, height, width, stride, padding, dilation, groups, xformat, wformat, yformat);
.get_constructor<VarPtr, Var*, Var*, int, int, int, int, int, int, int, int, int, string, string, string>();
LOGvvvv << w << y << height << width << stride_h << stride_w << padding_h << padding_w << dilation_h << dilation_w << groups << xformat << wformat << yformat;
rvar = make_conv_x(w, y, height, width, stride_h, stride_w, padding_h, padding_w, dilation_h, dilation_w, groups, xformat, wformat, yformat);
} else {
relay_conv_name = fop->flags.get(NodeFlags::_cpu) ?
"mkl_conv_backward_w" : "cudnn_conv_backward_w";
@ -393,10 +384,10 @@ void ConvTuner::forwardTune(FusedOp* fop) {
continue;
auto kh = w->shape[wformat.find("h")];
auto kw = w->shape[wformat.find("w")];
LOGvvvv << x << y << kh << stride << padding << dilation << groups << xformat << wformat << yformat;
LOGvvvv << x << y << kh << stride_h << stride_w << padding_h << padding_w << dilation_h << dilation_w << groups << xformat << wformat << yformat;
auto make_conv_w = get_op_info(relay_conv_name)
.get_constructor<VarPtr, Var*, Var*, int, int, int, int, int, int, string, string, string>();
rvar = make_conv_w(x, y, kh, kw, stride, padding, dilation, groups, xformat, wformat, yformat);
.get_constructor<VarPtr, Var*, Var*, int, int, int, int, int, int, int, int, int, string, string, string>();
rvar = make_conv_w(x, y, kh, kw, stride_h, stride_w, padding_h, padding_w, dilation_h, dilation_w, groups, xformat, wformat, yformat);
}
LOGvvvv << relay_conv_name << "output:" << rvar;
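With the early-continue removed above, the tuner no longer refuses to relay convolutions whose stride, padding, or dilation differ between h and w; both per-axis values are now forwarded to the mkl/cudnn relay ops. A hedged usage sketch; nn.Conv is assumed to accept PyTorch-style tuple arguments, as the nn.py code earlier in this diff indicates:

import jittor as jt
from jittor import nn
# asymmetric stride/padding: previously fell back to the generated kernel with a warning,
# now eligible for relay to mkl_conv / cudnn_conv
conv = nn.Conv(3, 16, (7, 5), stride=(2, 1), padding=(3, 2))
y = conv(jt.random([2, 3, 224, 224]))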

View File

@ -23,7 +23,6 @@ JIT_TEST(sfrl_allocator_time) {
Allocator* allocator = get_allocator();
int max_allc_num = 10000;
size_t id[max_allc_num];
void* addr[max_allc_num];
size_t temp[max_allc_num];
std::vector<TestTask> tasks;
tasks.push_back(TestTask(20000000, 1000, 1000, 400.0));
@ -35,12 +34,12 @@ JIT_TEST(sfrl_allocator_time) {
for (size_t k = 0; k < tasks[i].times1; ++k) {
for (size_t j = 0; j < tasks[i].times2; ++j) {
temp[j] = j;
addr[j] = allocator->alloc(tasks[i].size, id[j]);
allocator->alloc(tasks[i].size, id[j]);
if (j > 0)
std::swap(temp[j], temp[rand() % j]);
}
for (size_t j = 0; j < tasks[i].times2; ++j) {
allocator->free(addr[temp[j]], tasks[i].size, id[temp[j]]);
allocator->free(0, tasks[i].size, id[temp[j]]);
}
}
auto end = std::chrono::duration_cast<std::chrono::microseconds>(
@ -55,7 +54,6 @@ JIT_TEST(sfrl_allocator_share) {
Allocator* allocator = get_allocator();
int max_allc_num = 10000;
size_t id[max_allc_num];
void* addr[max_allc_num];
size_t temp[max_allc_num];
std::vector<TestTask> tasks;
tasks.push_back(TestTask(20000000, 1000, 1000, 400.0));
@ -72,13 +70,12 @@ JIT_TEST(sfrl_allocator_share) {
if (rand() % 10 != 0 && j > 0) {
id[j] = id[rand() % j];
allocator->share_with(tasks[i].size, id[j]);
addr[j] = addr[id[j]];
} else {
addr[j] = allocator->alloc(tasks[i].size, id[j]);
allocator->alloc(tasks[i].size, id[j]);
}
}
for (size_t j = 0; j < tasks[i].times2; ++j) {
allocator->free(addr[temp[j]], tasks[i].size, id[temp[j]]);
allocator->free(0, tasks[i].size, id[temp[j]]);
}
}
auto end = std::chrono::duration_cast<std::chrono::microseconds>(
@ -93,7 +90,6 @@ JIT_TEST(sfrl_allocator_share_without_size_and_ptr) {
Allocator* allocator = get_allocator();
int max_allc_num = 1000;
size_t id[max_allc_num];
void* addr[max_allc_num];
size_t temp[max_allc_num];
std::vector<TestTask> tasks;
tasks.push_back(TestTask(20000000, 100, 100, 400.0));
@ -108,9 +104,8 @@ JIT_TEST(sfrl_allocator_share_without_size_and_ptr) {
if (rand() % 10 != 0 && j > 0) {
id[j] = id[rand() % j];
allocator->share_with(0, id[j]);
addr[j] = addr[id[j]];
} else {
addr[j] = allocator->alloc(tasks[i].size, id[j]);
allocator->alloc(tasks[i].size, id[j]);
}
}
for (size_t j = 0; j < tasks[i].times2; ++j) {

View File

@ -10,6 +10,7 @@
#include <iomanip>
#include <thread>
#include <unordered_map>
#include <unistd.h>
#include "utils/log.h"
#include "utils/mwsr_list.h"
@ -184,12 +185,15 @@ bool exited = false;
size_t thread_local protected_page = 0;
int segfault_happen = 0;
string thread_local thread_name;
static int _pid = getpid();
void segfault_sigaction(int signal, siginfo_t *si, void *arg) {
if (signal == SIGINT) {
LOGe << "Caught SIGINT, exit";
if (_pid == getpid()) {
LOGe << "Caught SIGINT, quick exit";
}
exited = true;
exit(1);
std::quick_exit(1);
}
std::cerr << "Caught segfault at address " << si->si_addr << ", "
<< "thread_name: '" << thread_name << "', flush log..." << std::endl;