mirror of https://github.com/Jittor/Jittor
fix mkl conv
This commit is contained in:
parent
901034659e
commit
557a4bf9d2
|
@ -44,7 +44,7 @@ static inline void set_shape(Var* x, const char* f, const string& format, int a,
|
|||
shape[0], shape[1], shape[2], shape[3]));
|
||||
}
|
||||
|
||||
MklConvOp::MklConvOp(Var* x, Var* w, int stride, int padding, int dilation, string xformat, string wformat, string yformat)
|
||||
MklConvOp::MklConvOp(Var* x, Var* w, int stride, int padding, int dilation, int groups, string xformat, string wformat, string yformat)
|
||||
: x(x), w(w), stride(stride), padding(padding), dilation(dilation),
|
||||
xformat(move(xformat)), wformat(move(wformat)), yformat(move(yformat)) {
|
||||
y = create_output(nullptr, dtype_infer(x->ns, w->ns));
|
||||
|
|
|
@ -16,7 +16,7 @@ struct MklConvOp : Op {
|
|||
int stride, padding, dilation;
|
||||
string xformat, wformat, yformat;
|
||||
/* MklConvOp: xformat abcd represents nchw */
|
||||
MklConvOp(Var* x, Var* w, int stride, int padding, int dilation=1, string xformat="abcd", string wformat="oihw", string yformat="");
|
||||
MklConvOp(Var* x, Var* w, int stride, int padding, int dilation, int groups, string xformat="abcd", string wformat="oihw", string yformat="");
|
||||
|
||||
const char* name() const override { return "mkl_conv"; }
|
||||
void infer_shape() override;
|
||||
|
|
|
@ -16,8 +16,9 @@ with lock.lock_scope():
|
|||
from jittor_core import *
|
||||
from jittor_core.ops import *
|
||||
from . import compile_extern
|
||||
from .compile_extern import mkl_ops, mpi, mpi_ops, \
|
||||
cudnn, curand, cublas
|
||||
from .compile_extern import mkl_ops, mpi, mpi_ops
|
||||
if has_cuda:
|
||||
from .compile_extern import cudnn, curand, cublas
|
||||
|
||||
import contextlib
|
||||
import numpy as np
|
||||
|
|
Loading…
Reference in New Issue