fix mkl conv

This commit is contained in:
Dun Liang 2020-05-02 10:54:16 +08:00
parent 901034659e
commit 557a4bf9d2
3 changed files with 5 additions and 4 deletions

View File

@ -44,7 +44,7 @@ static inline void set_shape(Var* x, const char* f, const string& format, int a,
shape[0], shape[1], shape[2], shape[3]));
}
MklConvOp::MklConvOp(Var* x, Var* w, int stride, int padding, int dilation, string xformat, string wformat, string yformat)
MklConvOp::MklConvOp(Var* x, Var* w, int stride, int padding, int dilation, int groups, string xformat, string wformat, string yformat)
: x(x), w(w), stride(stride), padding(padding), dilation(dilation),
xformat(move(xformat)), wformat(move(wformat)), yformat(move(yformat)) {
y = create_output(nullptr, dtype_infer(x->ns, w->ns));

View File

@ -16,7 +16,7 @@ struct MklConvOp : Op {
int stride, padding, dilation;
string xformat, wformat, yformat;
/* MklConvOp: xformat abcd represents nchw */
MklConvOp(Var* x, Var* w, int stride, int padding, int dilation=1, string xformat="abcd", string wformat="oihw", string yformat="");
MklConvOp(Var* x, Var* w, int stride, int padding, int dilation, int groups, string xformat="abcd", string wformat="oihw", string yformat="");
const char* name() const override { return "mkl_conv"; }
void infer_shape() override;

View File

@ -16,8 +16,9 @@ with lock.lock_scope():
from jittor_core import *
from jittor_core.ops import *
from . import compile_extern
from .compile_extern import mkl_ops, mpi, mpi_ops, \
cudnn, curand, cublas
from .compile_extern import mkl_ops, mpi, mpi_ops
if has_cuda:
from .compile_extern import cudnn, curand, cublas
import contextlib
import numpy as np