mirror of https://github.com/Jittor/Jittor
fix mkl default args
This commit is contained in:
parent
32651c5c7d
commit
fc4f54ae0b
|
@ -16,7 +16,7 @@ struct MklConvOp : Op {
|
|||
int stride, padding, dilation, groups;
|
||||
string xformat, wformat, yformat;
|
||||
/* MklConvOp: xformat abcd represents nchw */
|
||||
MklConvOp(Var* x, Var* w, int stride, int padding, int dilation, int groups=1, string xformat="abcd", string wformat="oihw", string yformat="");
|
||||
MklConvOp(Var* x, Var* w, int stride, int padding, int dilation=1, int groups=1, string xformat="abcd", string wformat="oihw", string yformat="");
|
||||
|
||||
const char* name() const override { return "mkl_conv"; }
|
||||
void infer_shape() override;
|
||||
|
|
|
@ -80,7 +80,7 @@ class TestMklConvOp(unittest.TestCase):
|
|||
def check(xshape, wshape, stride, pad):
|
||||
a = np.random.rand(*xshape).astype(np.float32)
|
||||
b = np.random.rand(*wshape).astype(np.float32)
|
||||
c = jt.mkl_ops.mkl_conv(a,b,stride,pad,1,"acdb","hwio").data
|
||||
c = jt.mkl_ops.mkl_conv(a,b,stride,pad,1,xformat="acdb",wformat="hwio").data
|
||||
|
||||
a_jt = jt.array(a)
|
||||
b_jt = jt.array(b)
|
||||
|
@ -159,8 +159,8 @@ class TestMklConvOp(unittest.TestCase):
|
|||
a = np.random.rand(n,H,W,c).astype(np.float32)
|
||||
b = np.random.rand(h,w,i,o).astype(np.float32)
|
||||
da = np.random.rand(n,H,W,o).astype(np.float32)
|
||||
dx = jt.mkl_ops.mkl_conv_backward_x(b,da,H,W,1,1,1,"acdb","hwio","acdb").data
|
||||
dw = jt.mkl_ops.mkl_conv_backward_w(a,da,h,1,1,1,"acdb","hwio","acdb").data
|
||||
dx = jt.mkl_ops.mkl_conv_backward_x(b,da,H,W,1,1,1,xformat="acdb",wformat="hwio",yformat="acdb").data
|
||||
dw = jt.mkl_ops.mkl_conv_backward_w(a,da,h,1,1,1,xformat="acdb",wformat="hwio",yformat="acdb").data
|
||||
a_jt = jt.array(a)
|
||||
b_jt = jt.array(b)
|
||||
|
||||
|
|
Loading…
Reference in New Issue