fix mkl default args

This commit is contained in:
Dun Liang 2020-05-13 15:40:20 +08:00
parent 32651c5c7d
commit fc4f54ae0b
2 changed files with 4 additions and 4 deletions

View File

@ -16,7 +16,7 @@ struct MklConvOp : Op {
int stride, padding, dilation, groups;
string xformat, wformat, yformat;
/* MklConvOp: xformat abcd represents nchw */
MklConvOp(Var* x, Var* w, int stride, int padding, int dilation, int groups=1, string xformat="abcd", string wformat="oihw", string yformat="");
MklConvOp(Var* x, Var* w, int stride, int padding, int dilation=1, int groups=1, string xformat="abcd", string wformat="oihw", string yformat="");
const char* name() const override { return "mkl_conv"; }
void infer_shape() override;

View File

@ -80,7 +80,7 @@ class TestMklConvOp(unittest.TestCase):
def check(xshape, wshape, stride, pad):
a = np.random.rand(*xshape).astype(np.float32)
b = np.random.rand(*wshape).astype(np.float32)
c = jt.mkl_ops.mkl_conv(a,b,stride,pad,1,"acdb","hwio").data
c = jt.mkl_ops.mkl_conv(a,b,stride,pad,1,xformat="acdb",wformat="hwio").data
a_jt = jt.array(a)
b_jt = jt.array(b)
@ -159,8 +159,8 @@ class TestMklConvOp(unittest.TestCase):
a = np.random.rand(n,H,W,c).astype(np.float32)
b = np.random.rand(h,w,i,o).astype(np.float32)
da = np.random.rand(n,H,W,o).astype(np.float32)
dx = jt.mkl_ops.mkl_conv_backward_x(b,da,H,W,1,1,1,"acdb","hwio","acdb").data
dw = jt.mkl_ops.mkl_conv_backward_w(a,da,h,1,1,1,"acdb","hwio","acdb").data
dx = jt.mkl_ops.mkl_conv_backward_x(b,da,H,W,1,1,1,xformat="acdb",wformat="hwio",yformat="acdb").data
dw = jt.mkl_ops.mkl_conv_backward_w(a,da,h,1,1,1,xformat="acdb",wformat="hwio",yformat="acdb").data
a_jt = jt.array(a)
b_jt = jt.array(b)