mirror of https://github.com/Jittor/Jittor
Merge branch 'fix_meta_op_example' of https://github.com/xmyqsh/jittor
This commit is contained in:
commit
fb4188fc83
|
@ -159,7 +159,7 @@ y = yy.sum([3,4,5]) # Kh, Kw, Kc
|
|||
|
||||
```
|
||||
py
|
||||
shape = [N,H+Kh-1,W+Kw-1,Kh,Kw,C,Kc]
|
||||
shape = [N,H-Kh+1,W-Kw+1,Kh,Kw,C,Kc]
|
||||
# expansion of x.reindex
|
||||
xx = np.zeros(shape, x.dtype)
|
||||
for i0 in range(shape[0]):
|
||||
|
@ -204,7 +204,7 @@ for i0 in range(shape[0]):
|
|||
for i4 in range(shape[4]):
|
||||
for i5 in range(shape[5]):
|
||||
for i6 in range(shape[6]):
|
||||
y[i0,i1,i2,i6] += x[i0,i1+i3,i2+i4,i5] * w[i3,i4,i5,i6]
|
||||
y[i0,i1,i2,i6] += yy[i0,i1,i2,i3,i4,i5,i6]
|
||||
```
|
||||
|
||||
**After loop fusion:**
|
||||
|
@ -223,7 +223,7 @@ for i0 in range(shape[0]):
|
|||
for i5 in range(shape[5]):
|
||||
for i6 in range(shape[6]):
|
||||
if not is_overflow(i0,i1,i2,i3,i4,i5,i6):
|
||||
y[i0,i1,i2,i6] += x[i0,i1+i3,i2+i4,i5]
|
||||
y[i0,i1,i2,i6] += x[i0,i1+i3,i2+i4,i5] * w[i3,i4,i5,i6]
|
||||
```
|
||||
|
||||
This is the trick of meta-operators: they can fuse multiple operators into one complicated operation, including many variations of convolution (e.g. group conv, separate conv, ...).
|
||||
|
@ -253,4 +253,4 @@ Even faster than the previous implementation! From the output we can look at the
|
|||
|
||||
在这个教程中,Jittor简单演示了元算子的使用,并不是真正的性能测试,所以使用了比较小的数据规模进行测试,如果需要性能测试,请打开`jt.flags.enable_tuner = 1`,会启动使用专门的硬件库加速。
|
||||
|
||||
In this tutorial, Jittor simply demonstrated the use of meta-operators, which is not a performance test. If you need a performance test, `jt.flags.enable_tuner = 1` will try to use the dedicated hardware library.
|
||||
In this tutorial, Jittor simply demonstrated the use of meta-operators, which is not a performance test. If you need a performance test, `jt.flags.enable_tuner = 1` will try to use the dedicated hardware library.
|
||||
|
|
|
@ -78,10 +78,10 @@ struct ReindexOp : Op {
|
|||
N,H,W,C = x.shape
|
||||
Kh, Kw, _C, Kc = w.shape
|
||||
assert C==_C
|
||||
xx = x.reindex([N,H+Kh-1,W+Kw-1,Kh,Kw,C,Kc], [
|
||||
xx = x.reindex([N,H-Kh+1,W-Kw+1,Kh,Kw,C,Kc], [
|
||||
'i0', # Nid
|
||||
'i1-i3', # Hid+Khid
|
||||
'i2-i4', # Wid+KWid
|
||||
'i1+i3', # Hid+Khid
|
||||
'i2+i4', # Wid+KWid
|
||||
'i5', # Cid
|
||||
])
|
||||
ww = w.broadcast_var(xx)
|
||||
|
@ -104,4 +104,4 @@ struct ReindexOp : Op {
|
|||
DECLARE_jit_run;
|
||||
};
|
||||
|
||||
} // jittor
|
||||
} // jittor
|
||||
|
|
Loading…
Reference in New Issue