mirror of https://github.com/Jittor/Jittor
fix # Kh, Kw, Kc
This commit is contained in:
parent
41020767af
commit
b9a2a4848b
|
@ -102,7 +102,7 @@ def conv(x, w):
|
|||
])
|
||||
ww = w.broadcast_var(xx)
|
||||
yy = xx*ww
|
||||
y = yy.sum([3,4,5]) # Kh, Kw, Kc
|
||||
y = yy.sum([3,4,5]) # Kh, Kw, c
|
||||
return y
|
||||
|
||||
# Let's disable tuner. This will cause jittor not to use mkl for convolution
|
||||
|
@ -150,7 +150,7 @@ xx = x.reindex([N,H-Kh+1,W-Kw+1,Kh,Kw,C,Kc], [
|
|||
])
|
||||
ww = w.broadcast_var(xx)
|
||||
yy = xx*ww
|
||||
y = yy.sum([3,4,5]) # Kh, Kw, Kc
|
||||
y = yy.sum([3,4,5]) # Kh, Kw, C
|
||||
```
|
||||
|
||||
**After expansion:**
|
||||
|
|
|
@ -21,7 +21,7 @@ def conv(x, w):
|
|||
])
|
||||
ww = w.broadcast_var(xx)
|
||||
yy = xx*ww
|
||||
y = yy.sum([3,4,5]) # Kh, Kw, Kc
|
||||
y = yy.sum([3,4,5]) # Kh, Kw, C
|
||||
return y, yy
|
||||
|
||||
def conv_naive(x, w):
|
||||
|
@ -52,7 +52,7 @@ def conv_transpose(x, w):
|
|||
], 0, ['(i1-i3)%2', '(i2-i4)%2'])
|
||||
ww = w.broadcast_var(xx)
|
||||
yy = xx*ww
|
||||
y = yy.sum([3,4,5]) # Kh, Kw, Kc
|
||||
y = yy.sum([3,4,5]) # Kh, Kw, C
|
||||
return y, yy
|
||||
|
||||
def conv_transpose_naive(x, w):
|
||||
|
|
|
@ -86,7 +86,7 @@ struct ReindexOp : Op {
|
|||
])
|
||||
ww = w.broadcast_var(xx)
|
||||
yy = xx*ww
|
||||
y = yy.sum([3,4,5]) # Kh, Kw, Kc
|
||||
y = yy.sum([3,4,5]) # Kh, Kw, C
|
||||
return y, yy
|
||||
```
|
||||
*/
|
||||
|
|
Loading…
Reference in New Issue