Transitioning from Intel MKL-DNN to oneDNN

This commit is contained in:
lzhengning 2022-03-22 15:59:42 +08:00 committed by Zheng-Ning Liu
parent 7d9cff9e24
commit de2fceab22
6 changed files with 112 additions and 112 deletions

View File

@@ -141,12 +141,12 @@ def setup_mkl():
elif platform.system() == 'Darwin':
mkl_lib_paths = [
"/usr/local/lib/libmkldnn.dylib", # x86_64
"/opt/homebrew/lib/libmkldnn.dylib", # arm64
"/usr/local/lib/libdnnl.dylib", # x86_64
"/opt/homebrew/lib/libdnnl.dylib", # arm64
]
if not any([os.path.exists(lib) for lib in mkl_lib_paths]):
raise RuntimeError("Not found onednn, please install it by the command 'brew install onednn'")
extra_flags = f" -lmkldnn "
extra_flags = f" -ldnnl "
mkl_op_dir = os.path.join(jittor_path, "extern", "mkl", "ops")
mkl_op_files = [os.path.join(mkl_op_dir, name) for name in os.listdir(mkl_op_dir)]

View File

@@ -47,9 +47,9 @@
#include <unordered_map>
#include <vector>
#include <mkldnn.hpp>
#include <dnnl.hpp>
using namespace mkldnn;
using namespace dnnl;
using namespace std;
@@ -159,8 +159,8 @@ void simple_net(int times = 100) {
if (conv1_prim_desc.src_desc() != user_src_memory.get_desc()) {
conv1_src_memory = memory(conv1_prim_desc.src_desc(), eng);
net.push_back(reorder(user_src_memory, conv1_src_memory));
net_args.push_back({ { MKLDNN_ARG_FROM, user_src_memory },
{ MKLDNN_ARG_TO, conv1_src_memory } });
net_args.push_back({ { DNNL_ARG_FROM, user_src_memory },
{ DNNL_ARG_TO, conv1_src_memory } });
}
auto conv1_weights_memory = user_weights_memory;
@@ -181,10 +181,10 @@ void simple_net(int times = 100) {
/// @snippet cpu_cnn_inference_f32.cpp Create memory for output
//[Create convolution primitive]
net.push_back(convolution_forward(conv1_prim_desc));
net_args.push_back({ { MKLDNN_ARG_SRC, conv1_src_memory },
{ MKLDNN_ARG_WEIGHTS, conv1_weights_memory },
{ MKLDNN_ARG_BIAS, conv1_user_bias_memory },
{ MKLDNN_ARG_DST, conv1_dst_memory } });
net_args.push_back({ { DNNL_ARG_SRC, conv1_src_memory },
{ DNNL_ARG_WEIGHTS, conv1_weights_memory },
{ DNNL_ARG_BIAS, conv1_user_bias_memory },
{ DNNL_ARG_DST, conv1_dst_memory } });
//[Create convolution primitive]
// AlexNet: relu1
@@ -204,8 +204,8 @@ void simple_net(int times = 100) {
auto relu1_prim_desc = eltwise_forward::primitive_desc(relu1_desc, eng);
net.push_back(eltwise_forward(relu1_prim_desc));
net_args.push_back({ { MKLDNN_ARG_SRC, conv1_dst_memory },
{ MKLDNN_ARG_DST, conv1_dst_memory } });
net_args.push_back({ { DNNL_ARG_SRC, conv1_dst_memory },
{ DNNL_ARG_DST, conv1_dst_memory } });
//[Create relu primitive]
// AlexNet: lrn1
@@ -226,8 +226,8 @@ void simple_net(int times = 100) {
auto lrn1_dst_memory = memory(lrn1_prim_desc.dst_desc(), eng);
net.push_back(lrn_forward(lrn1_prim_desc));
net_args.push_back({ { MKLDNN_ARG_SRC, conv1_dst_memory },
{ MKLDNN_ARG_DST, lrn1_dst_memory } });
net_args.push_back({ { DNNL_ARG_SRC, conv1_dst_memory },
{ DNNL_ARG_DST, lrn1_dst_memory } });
// AlexNet: pool1
// {batch, 96, 55, 55} -> {batch, 96, 27, 27}
@@ -255,8 +255,8 @@ void simple_net(int times = 100) {
auto pool1_dst_memory = memory(pool1_pd.dst_desc(), eng);
net.push_back(pooling_forward(pool1_pd));
net_args.push_back({ { MKLDNN_ARG_SRC, lrn1_dst_memory },
{ MKLDNN_ARG_DST, pool1_dst_memory } });
net_args.push_back({ { DNNL_ARG_SRC, lrn1_dst_memory },
{ DNNL_ARG_DST, pool1_dst_memory } });
//[Create pooling primitive]
// AlexNet: conv2
@@ -296,8 +296,8 @@ void simple_net(int times = 100) {
if (conv2_prim_desc.src_desc() != conv2_src_memory.get_desc()) {
conv2_src_memory = memory(conv2_prim_desc.src_desc(), eng);
net.push_back(reorder(pool1_dst_memory, conv2_src_memory));
net_args.push_back({ { MKLDNN_ARG_FROM, pool1_dst_memory },
{ MKLDNN_ARG_TO, conv2_src_memory } });
net_args.push_back({ { DNNL_ARG_FROM, pool1_dst_memory },
{ DNNL_ARG_TO, conv2_src_memory } });
}
auto conv2_weights_memory = conv2_user_weights_memory;
@@ -312,10 +312,10 @@ void simple_net(int times = 100) {
// create convolution primitive and add it to net
net.push_back(convolution_forward(conv2_prim_desc));
net_args.push_back({ { MKLDNN_ARG_SRC, conv2_src_memory },
{ MKLDNN_ARG_WEIGHTS, conv2_weights_memory },
{ MKLDNN_ARG_BIAS, conv2_user_bias_memory },
{ MKLDNN_ARG_DST, conv2_dst_memory } });
net_args.push_back({ { DNNL_ARG_SRC, conv2_src_memory },
{ DNNL_ARG_WEIGHTS, conv2_weights_memory },
{ DNNL_ARG_BIAS, conv2_user_bias_memory },
{ DNNL_ARG_DST, conv2_dst_memory } });
// AlexNet: relu2
// {batch, 256, 27, 27} -> {batch, 256, 27, 27}
@@ -328,8 +328,8 @@ void simple_net(int times = 100) {
auto relu2_prim_desc = eltwise_forward::primitive_desc(relu2_desc, eng);
net.push_back(eltwise_forward(relu2_prim_desc));
net_args.push_back({ { MKLDNN_ARG_SRC, conv2_dst_memory },
{ MKLDNN_ARG_DST, conv2_dst_memory } });
net_args.push_back({ { DNNL_ARG_SRC, conv2_dst_memory },
{ DNNL_ARG_DST, conv2_dst_memory } });
// AlexNet: lrn2
// {batch, 256, 27, 27} -> {batch, 256, 27, 27}
@@ -349,8 +349,8 @@ void simple_net(int times = 100) {
auto lrn2_dst_memory = memory(lrn2_prim_desc.dst_desc(), eng);
net.push_back(lrn_forward(lrn2_prim_desc));
net_args.push_back({ { MKLDNN_ARG_SRC, conv2_dst_memory },
{ MKLDNN_ARG_DST, lrn2_dst_memory } });
net_args.push_back({ { DNNL_ARG_SRC, conv2_dst_memory },
{ DNNL_ARG_DST, lrn2_dst_memory } });
// AlexNet: pool2
// {batch, 256, 27, 27} -> {batch, 256, 13, 13}
@@ -372,8 +372,8 @@ void simple_net(int times = 100) {
// create pooling primitive an add it to net
net.push_back(pooling_forward(pool2_pd));
net_args.push_back({ { MKLDNN_ARG_SRC, lrn2_dst_memory },
{ MKLDNN_ARG_DST, pool2_dst_memory } });
net_args.push_back({ { DNNL_ARG_SRC, lrn2_dst_memory },
{ DNNL_ARG_DST, pool2_dst_memory } });
// AlexNet: conv3
// {batch, 256, 13, 13} (x) {384, 256, 3, 3}; -> {batch, 384, 13, 13};
@@ -412,8 +412,8 @@ void simple_net(int times = 100) {
if (conv3_prim_desc.src_desc() != conv3_src_memory.get_desc()) {
conv3_src_memory = memory(conv3_prim_desc.src_desc(), eng);
net.push_back(reorder(pool2_dst_memory, conv3_src_memory));
net_args.push_back({ { MKLDNN_ARG_FROM, pool2_dst_memory },
{ MKLDNN_ARG_TO, conv3_src_memory } });
net_args.push_back({ { DNNL_ARG_FROM, pool2_dst_memory },
{ DNNL_ARG_TO, conv3_src_memory } });
}
auto conv3_weights_memory = conv3_user_weights_memory;
@@ -428,10 +428,10 @@ void simple_net(int times = 100) {
// create convolution primitive and add it to net
net.push_back(convolution_forward(conv3_prim_desc));
net_args.push_back({ { MKLDNN_ARG_SRC, conv3_src_memory },
{ MKLDNN_ARG_WEIGHTS, conv3_weights_memory },
{ MKLDNN_ARG_BIAS, conv3_user_bias_memory },
{ MKLDNN_ARG_DST, conv3_dst_memory } });
net_args.push_back({ { DNNL_ARG_SRC, conv3_src_memory },
{ DNNL_ARG_WEIGHTS, conv3_weights_memory },
{ DNNL_ARG_BIAS, conv3_user_bias_memory },
{ DNNL_ARG_DST, conv3_dst_memory } });
// AlexNet: relu3
// {batch, 384, 13, 13} -> {batch, 384, 13, 13}
@@ -444,8 +444,8 @@ void simple_net(int times = 100) {
auto relu3_prim_desc = eltwise_forward::primitive_desc(relu3_desc, eng);
net.push_back(eltwise_forward(relu3_prim_desc));
net_args.push_back({ { MKLDNN_ARG_SRC, conv3_dst_memory },
{ MKLDNN_ARG_DST, conv3_dst_memory } });
net_args.push_back({ { DNNL_ARG_SRC, conv3_dst_memory },
{ DNNL_ARG_DST, conv3_dst_memory } });
// AlexNet: conv4
// {batch, 384, 13, 13} (x) {2, 192, 192, 3, 3}; ->
@@ -485,8 +485,8 @@ void simple_net(int times = 100) {
if (conv4_prim_desc.src_desc() != conv4_src_memory.get_desc()) {
conv4_src_memory = memory(conv4_prim_desc.src_desc(), eng);
net.push_back(reorder(conv3_dst_memory, conv4_src_memory));
net_args.push_back({ { MKLDNN_ARG_FROM, conv3_dst_memory },
{ MKLDNN_ARG_TO, conv4_src_memory } });
net_args.push_back({ { DNNL_ARG_FROM, conv3_dst_memory },
{ DNNL_ARG_TO, conv4_src_memory } });
}
auto conv4_weights_memory = conv4_user_weights_memory;
@@ -501,10 +501,10 @@ void simple_net(int times = 100) {
// create convolution primitive and add it to net
net.push_back(convolution_forward(conv4_prim_desc));
net_args.push_back({ { MKLDNN_ARG_SRC, conv4_src_memory },
{ MKLDNN_ARG_WEIGHTS, conv4_weights_memory },
{ MKLDNN_ARG_BIAS, conv4_user_bias_memory },
{ MKLDNN_ARG_DST, conv4_dst_memory } });
net_args.push_back({ { DNNL_ARG_SRC, conv4_src_memory },
{ DNNL_ARG_WEIGHTS, conv4_weights_memory },
{ DNNL_ARG_BIAS, conv4_user_bias_memory },
{ DNNL_ARG_DST, conv4_dst_memory } });
// AlexNet: relu4
// {batch, 384, 13, 13} -> {batch, 384, 13, 13}
@@ -517,8 +517,8 @@ void simple_net(int times = 100) {
auto relu4_prim_desc = eltwise_forward::primitive_desc(relu4_desc, eng);
net.push_back(eltwise_forward(relu4_prim_desc));
net_args.push_back({ { MKLDNN_ARG_SRC, conv4_dst_memory },
{ MKLDNN_ARG_DST, conv4_dst_memory } });
net_args.push_back({ { DNNL_ARG_SRC, conv4_dst_memory },
{ DNNL_ARG_DST, conv4_dst_memory } });
// AlexNet: conv5
// {batch, 384, 13, 13} (x) {2, 128, 192, 3, 3}; -> {batch, 256, 13, 13};
@@ -557,8 +557,8 @@ void simple_net(int times = 100) {
if (conv5_prim_desc.src_desc() != conv5_src_memory.get_desc()) {
conv5_src_memory = memory(conv5_prim_desc.src_desc(), eng);
net.push_back(reorder(conv4_dst_memory, conv5_src_memory));
net_args.push_back({ { MKLDNN_ARG_FROM, conv4_dst_memory },
{ MKLDNN_ARG_TO, conv5_src_memory } });
net_args.push_back({ { DNNL_ARG_FROM, conv4_dst_memory },
{ DNNL_ARG_TO, conv5_src_memory } });
}
auto conv5_weights_memory = conv5_user_weights_memory;
@@ -573,10 +573,10 @@ void simple_net(int times = 100) {
// create convolution primitive and add it to net
net.push_back(convolution_forward(conv5_prim_desc));
net_args.push_back({ { MKLDNN_ARG_SRC, conv5_src_memory },
{ MKLDNN_ARG_WEIGHTS, conv5_weights_memory },
{ MKLDNN_ARG_BIAS, conv5_user_bias_memory },
{ MKLDNN_ARG_DST, conv5_dst_memory } });
net_args.push_back({ { DNNL_ARG_SRC, conv5_src_memory },
{ DNNL_ARG_WEIGHTS, conv5_weights_memory },
{ DNNL_ARG_BIAS, conv5_user_bias_memory },
{ DNNL_ARG_DST, conv5_dst_memory } });
// AlexNet: relu5
// {batch, 256, 13, 13} -> {batch, 256, 13, 13}
@@ -589,8 +589,8 @@ void simple_net(int times = 100) {
auto relu5_prim_desc = eltwise_forward::primitive_desc(relu5_desc, eng);
net.push_back(eltwise_forward(relu5_prim_desc));
net_args.push_back({ { MKLDNN_ARG_SRC, conv5_dst_memory },
{ MKLDNN_ARG_DST, conv5_dst_memory } });
net_args.push_back({ { DNNL_ARG_SRC, conv5_dst_memory },
{ DNNL_ARG_DST, conv5_dst_memory } });
// AlexNet: pool5
// {batch, 256, 13, 13} -> {batch, 256, 6, 6}
@@ -615,8 +615,8 @@ void simple_net(int times = 100) {
// create pooling primitive an add it to net
net.push_back(pooling_forward(pool5_pd));
net_args.push_back({ { MKLDNN_ARG_SRC, conv5_dst_memory },
{ MKLDNN_ARG_DST, pool5_dst_memory } });
net_args.push_back({ { DNNL_ARG_SRC, conv5_dst_memory },
{ DNNL_ARG_DST, pool5_dst_memory } });
// fc6 inner product {batch, 256, 6, 6} (x) {4096, 256, 6, 6}-> {batch,
@@ -651,8 +651,8 @@ void simple_net(int times = 100) {
if (fc6_prim_desc.src_desc() != fc6_src_memory.get_desc()) {
fc6_src_memory = memory(fc6_prim_desc.src_desc(), eng);
net.push_back(reorder(pool5_dst_memory, fc6_src_memory));
net_args.push_back({ { MKLDNN_ARG_FROM, pool5_dst_memory },
{ MKLDNN_ARG_TO, fc6_src_memory } });
net_args.push_back({ { DNNL_ARG_FROM, pool5_dst_memory },
{ DNNL_ARG_TO, fc6_src_memory } });
}
auto fc6_weights_memory = fc6_user_weights_memory;
@@ -666,10 +666,10 @@ void simple_net(int times = 100) {
// create convolution primitive and add it to net
net.push_back(inner_product_forward(fc6_prim_desc));
net_args.push_back({ { MKLDNN_ARG_SRC, fc6_src_memory },
{ MKLDNN_ARG_WEIGHTS, fc6_weights_memory },
{ MKLDNN_ARG_BIAS, fc6_user_bias_memory },
{ MKLDNN_ARG_DST, fc6_dst_memory } });
net_args.push_back({ { DNNL_ARG_SRC, fc6_src_memory },
{ DNNL_ARG_WEIGHTS, fc6_weights_memory },
{ DNNL_ARG_BIAS, fc6_user_bias_memory },
{ DNNL_ARG_DST, fc6_dst_memory } });
// fc7 inner product {batch, 4096} (x) {4096, 4096}-> {batch, 4096}
@@ -708,10 +708,10 @@ void simple_net(int times = 100) {
// create convolution primitive and add it to net
net.push_back(inner_product_forward(fc7_prim_desc));
net_args.push_back({ { MKLDNN_ARG_SRC, fc6_dst_memory },
{ MKLDNN_ARG_WEIGHTS, fc7_weights_memory },
{ MKLDNN_ARG_BIAS, fc7_user_bias_memory },
{ MKLDNN_ARG_DST, fc7_dst_memory } });
net_args.push_back({ { DNNL_ARG_SRC, fc6_dst_memory },
{ DNNL_ARG_WEIGHTS, fc7_weights_memory },
{ DNNL_ARG_BIAS, fc7_user_bias_memory },
{ DNNL_ARG_DST, fc7_dst_memory } });
// fc8 inner product {batch, 4096} (x) {1000, 4096}-> {batch, 1000}
memory::dims fc8_weights_tz = { 1000, 4096 };
@@ -750,17 +750,17 @@ void simple_net(int times = 100) {
// create convolution primitive and add it to net
net.push_back(inner_product_forward(fc8_prim_desc));
net_args.push_back({ { MKLDNN_ARG_SRC, fc7_dst_memory },
{ MKLDNN_ARG_WEIGHTS, fc8_weights_memory },
{ MKLDNN_ARG_BIAS, fc8_user_bias_memory },
{ MKLDNN_ARG_DST, fc8_dst_memory } });
net_args.push_back({ { DNNL_ARG_SRC, fc7_dst_memory },
{ DNNL_ARG_WEIGHTS, fc8_weights_memory },
{ DNNL_ARG_BIAS, fc8_user_bias_memory },
{ DNNL_ARG_DST, fc8_dst_memory } });
// create reorder between internal and user data if it is needed and
// add it to net after pooling
if (fc8_dst_memory != user_dst_memory) {
net.push_back(reorder(fc8_dst_memory, user_dst_memory));
net_args.push_back({ { MKLDNN_ARG_FROM, fc8_dst_memory },
{ MKLDNN_ARG_TO, user_dst_memory } });
net_args.push_back({ { DNNL_ARG_FROM, fc8_dst_memory },
{ DNNL_ARG_TO, user_dst_memory } });
}
/// @page cpu_cnn_inference_f32_cpp

View File

@@ -13,9 +13,9 @@
#include "var.h"
#include "mkl_conv_backward_w_op.h"
#include <mkldnn.hpp>
#include <dnnl.hpp>
using namespace mkldnn;
using namespace dnnl;
using namespace std;
namespace jittor {
@@ -143,8 +143,8 @@ void MklConvBackwardWOp::jit_run() {
if (conv_pd.src_desc() != conv_user_src_memory.get_desc()) {
conv_src_memory = memory(conv_pd.src_desc(), eng);
net_bwd.push_back(reorder(conv_user_src_memory, conv_src_memory));
net_bwd_args.push_back({{MKLDNN_ARG_FROM, conv_user_src_memory},
{MKLDNN_ARG_TO, conv_src_memory}});
net_bwd_args.push_back({{DNNL_ARG_FROM, conv_user_src_memory},
{DNNL_ARG_TO, conv_src_memory}});
}
auto conv_user_diff_dst_memory
@@ -169,8 +169,8 @@ void MklConvBackwardWOp::jit_run() {
if (conv_bwd_weights_pd.src_desc() != conv_src_memory.get_desc()) {
conv_bwd_src_memory = memory(conv_bwd_weights_pd.src_desc(), eng);
net_bwd.push_back(reorder(conv_src_memory, conv_bwd_src_memory));
net_bwd_args.push_back({{MKLDNN_ARG_FROM, conv_src_memory},
{MKLDNN_ARG_TO, conv_bwd_src_memory}});
net_bwd_args.push_back({{DNNL_ARG_FROM, conv_src_memory},
{DNNL_ARG_TO, conv_bwd_src_memory}});
}
auto conv_diff_dst_memory = conv_user_diff_dst_memory;
@@ -178,13 +178,13 @@ void MklConvBackwardWOp::jit_run() {
!= conv_user_diff_dst_memory.get_desc()) {
conv_diff_dst_memory = memory(conv_bwd_weights_pd.diff_dst_desc(), eng);
net_bwd.push_back(reorder(conv_user_diff_dst_memory, conv_diff_dst_memory));
net_bwd_args.push_back({{MKLDNN_ARG_FROM, conv_user_diff_dst_memory},
{MKLDNN_ARG_TO, conv_diff_dst_memory}});
net_bwd_args.push_back({{DNNL_ARG_FROM, conv_user_diff_dst_memory},
{DNNL_ARG_TO, conv_diff_dst_memory}});
}
net_bwd.push_back(convolution_backward_weights(conv_bwd_weights_pd));
net_bwd_args.push_back({{MKLDNN_ARG_SRC, conv_bwd_src_memory},
{MKLDNN_ARG_DIFF_DST, conv_diff_dst_memory}});
net_bwd_args.push_back({{DNNL_ARG_SRC, conv_bwd_src_memory},
{DNNL_ARG_DIFF_DST, conv_diff_dst_memory}});
auto conv_diff_weights_memory = conv_user_diff_weights_memory;
if (conv_bwd_weights_pd.diff_weights_desc()
@@ -192,15 +192,15 @@ void MklConvBackwardWOp::jit_run() {
conv_diff_weights_memory
= memory(conv_bwd_weights_pd.diff_weights_desc(), eng);
net_bwd_args.back().insert(
{MKLDNN_ARG_DIFF_WEIGHTS, conv_diff_weights_memory});
{DNNL_ARG_DIFF_WEIGHTS, conv_diff_weights_memory});
net_bwd.push_back(reorder(
conv_diff_weights_memory, conv_user_diff_weights_memory));
net_bwd_args.push_back({{MKLDNN_ARG_FROM, conv_diff_weights_memory},
{MKLDNN_ARG_TO, conv_user_diff_weights_memory}});
net_bwd_args.push_back({{DNNL_ARG_FROM, conv_diff_weights_memory},
{DNNL_ARG_TO, conv_user_diff_weights_memory}});
} else {
net_bwd_args.back().insert(
{MKLDNN_ARG_DIFF_WEIGHTS, conv_diff_weights_memory});
{DNNL_ARG_DIFF_WEIGHTS, conv_diff_weights_memory});
}
ASSERTop(net_bwd.size(),==,net_bwd_args.size());

View File

@@ -13,9 +13,9 @@
#include "var.h"
#include "mkl_conv_backward_x_op.h"
#include <mkldnn.hpp>
#include <dnnl.hpp>
using namespace mkldnn;
using namespace dnnl;
using namespace std;
namespace jittor {
@@ -142,8 +142,8 @@ void MklConvBackwardXOp::jit_run() {
conv_weights_memory = memory(conv_pd.weights_desc(), eng);
net_bwd.push_back(
reorder(conv_user_weights_memory, conv_weights_memory));
net_bwd_args.push_back({{MKLDNN_ARG_FROM, conv_user_weights_memory},
{MKLDNN_ARG_TO, conv_weights_memory}});
net_bwd_args.push_back({{DNNL_ARG_FROM, conv_user_weights_memory},
{DNNL_ARG_TO, conv_weights_memory}});
}
auto conv_user_diff_dst_memory
@@ -168,21 +168,21 @@ void MklConvBackwardXOp::jit_run() {
!= conv_user_diff_dst_memory.get_desc()) {
conv_diff_dst_memory = memory(conv_bwd_data_pd.diff_dst_desc(), eng);
net_bwd.push_back(reorder(conv_user_diff_dst_memory, conv_diff_dst_memory));
net_bwd_args.push_back({{MKLDNN_ARG_FROM, conv_user_diff_dst_memory},
{MKLDNN_ARG_TO, conv_diff_dst_memory}});
net_bwd_args.push_back({{DNNL_ARG_FROM, conv_user_diff_dst_memory},
{DNNL_ARG_TO, conv_diff_dst_memory}});
}
auto conv_bwd_weights_memory = conv_weights_memory;
if (conv_bwd_data_pd.weights_desc() != conv_weights_memory.get_desc()) {
conv_bwd_weights_memory = memory(conv_bwd_data_pd.weights_desc(), eng);
net_bwd.push_back(reorder(conv_weights_memory, conv_bwd_weights_memory));
net_bwd_args.push_back({{MKLDNN_ARG_FROM, conv_weights_memory},
{MKLDNN_ARG_TO, conv_bwd_weights_memory}});
net_bwd_args.push_back({{DNNL_ARG_FROM, conv_weights_memory},
{DNNL_ARG_TO, conv_bwd_weights_memory}});
}
net_bwd.push_back(convolution_backward_data(conv_bwd_data_pd));
net_bwd_args.push_back({{MKLDNN_ARG_WEIGHTS, conv_bwd_weights_memory},
{MKLDNN_ARG_DIFF_DST, conv_diff_dst_memory}});
net_bwd_args.push_back({{DNNL_ARG_WEIGHTS, conv_bwd_weights_memory},
{DNNL_ARG_DIFF_DST, conv_diff_dst_memory}});
auto conv_diff_src_memory = conv_user_diff_src_memory;
if (conv_bwd_data_pd.diff_src_desc()
@@ -190,15 +190,15 @@ void MklConvBackwardXOp::jit_run() {
conv_diff_src_memory
= memory(conv_bwd_data_pd.diff_src_desc(), eng);
net_bwd_args.back().insert(
{MKLDNN_ARG_DIFF_SRC, conv_diff_src_memory});
{DNNL_ARG_DIFF_SRC, conv_diff_src_memory});
net_bwd.push_back(reorder(
conv_diff_src_memory, conv_user_diff_src_memory));
net_bwd_args.push_back({{MKLDNN_ARG_FROM, conv_diff_src_memory},
{MKLDNN_ARG_TO, conv_user_diff_src_memory}});
net_bwd_args.push_back({{DNNL_ARG_FROM, conv_diff_src_memory},
{DNNL_ARG_TO, conv_user_diff_src_memory}});
} else {
net_bwd_args.back().insert(
{MKLDNN_ARG_DIFF_SRC, conv_diff_src_memory});
{DNNL_ARG_DIFF_SRC, conv_diff_src_memory});
}
ASSERTop(net_bwd.size(),==,net_bwd_args.size());

View File

@@ -7,12 +7,12 @@
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include <mkldnn.hpp>
#include <dnnl.hpp>
#include "var.h"
#include "mkl_conv_op.h"
using namespace mkldnn;
using namespace dnnl;
using namespace std;
namespace jittor {
@@ -110,7 +110,7 @@ void MklConvOp::jit_run() {
auto n = ws[3];
auto k = xs[3];
// x: [m,k], w: [k,n], y: [m,n]
ASSERTop(0,==,mkldnn_sgemm('N', 'N', m, n, k,
ASSERTop(0,==,dnnl_sgemm('N', 'N', m, n, k,
1.f, x->ptr<float32>(), k,
w->ptr<float32>(), n,
0.f, y->ptr<float32>(), n));
@@ -162,27 +162,27 @@ void MklConvOp::jit_run() {
if (conv1_prim_desc.src_desc() != user_src_memory.get_desc()) {
conv1_src_memory = memory(conv1_prim_desc.src_desc(), eng);
net.push_back(reorder(user_src_memory, conv1_src_memory));
net_args.push_back({ { MKLDNN_ARG_FROM, user_src_memory },
{ MKLDNN_ARG_TO, conv1_src_memory } });
net_args.push_back({ { DNNL_ARG_FROM, user_src_memory },
{ DNNL_ARG_TO, conv1_src_memory } });
}
auto conv1_weights_memory = user_weights_memory;
if (conv1_prim_desc.weights_desc() != user_weights_memory.get_desc()) {
conv1_weights_memory = memory(conv1_prim_desc.weights_desc(), eng);
net.push_back(reorder(user_weights_memory, conv1_weights_memory));
net_args.push_back({ { MKLDNN_ARG_FROM, user_weights_memory }, { MKLDNN_ARG_TO, conv1_weights_memory } });
net_args.push_back({ { DNNL_ARG_FROM, user_weights_memory }, { DNNL_ARG_TO, conv1_weights_memory } });
}
auto conv1_dst_memory = memory(conv1_prim_desc.dst_desc(), eng);
net.push_back(convolution_forward(conv1_prim_desc));
net_args.push_back({ { MKLDNN_ARG_SRC, conv1_src_memory },
{ MKLDNN_ARG_WEIGHTS, conv1_weights_memory },
{ MKLDNN_ARG_DST, conv1_dst_memory } });
net_args.push_back({ { DNNL_ARG_SRC, conv1_src_memory },
{ DNNL_ARG_WEIGHTS, conv1_weights_memory },
{ DNNL_ARG_DST, conv1_dst_memory } });
if (conv1_dst_memory != user_dst_memory) {
net.push_back(reorder(conv1_dst_memory, user_dst_memory));
net_args.push_back({ { MKLDNN_ARG_FROM, conv1_dst_memory },{ MKLDNN_ARG_TO, user_dst_memory } });
net_args.push_back({ { DNNL_ARG_FROM, conv1_dst_memory },{ DNNL_ARG_TO, user_dst_memory } });
}
ASSERTop(net.size(),==,net_args.size());

View File

@@ -7,12 +7,12 @@
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include <mkldnn.hpp>
#include <dnnl.hpp>
#include "var.h"
#include "mkl_matmul_op.h"
using namespace mkldnn;
using namespace dnnl;
using namespace std;
namespace jittor {
@@ -66,7 +66,7 @@ void MklMatmulOp::jit_run() {
k = bs[0];
}
// a: [n,m], b: [m,k], c: [n,k]
ASSERTop(0,==,mkldnn_sgemm('@Trans_a', '@Trans_b', n, k, m,
ASSERTop(0,==,dnnl_sgemm('@Trans_a', '@Trans_b', n, k, m,
1.f, a->ptr<T>(), '@Trans_a'=='N'? m : n,
b->ptr<T>(), '@Trans_b' == 'N' ? k : m,
0.f, c->ptr<T>(), k));