update bmm

Meng-Hao 2020-07-12 17:29:10 +08:00
parent 602ad64d53
commit 395733f85a
3 changed files with 174 additions and 0 deletions

cublas_batched_matmul_op.cc

@@ -0,0 +1,128 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors:
// Meng-Hao Guo <guomenghao1997@gmail.com>
// Dun Liang <randonlang@gmail.com>.
// All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
// cublas_batched_matmul_op.cc
#include "var.h"
#include "cublas_batched_matmul_op.h"
#include "cublas_warper.h"
using namespace std;
namespace jittor {
#ifndef JIT
static auto make_cublas_batched_matmul = get_op_info("cublas_batched_matmul")
    .get_constructor<VarPtr, Var*, Var*, bool, bool>();
CublasBatchedMatmulOp::CublasBatchedMatmulOp(Var* a, Var* b, bool trans_a, bool trans_b)
    : a(a), b(b), trans_a(trans_a), trans_b(trans_b) {
    // TODO: support int8 * int8
    ASSERT(a->dtype().is_float() && b->dtype().is_float()) << "inputs of batched matmul should be float";
    // TODO: support different input types
    ASSERT(a->dtype().dsize() == b->dtype().dsize()) << "both inputs should have the same dtype size";
    c = create_output(nullptr, a->dtype());
    flags.set(NodeFlags::_cpu, 0);
    flags.set(NodeFlags::_cuda, 1);
}
VarPtr CublasBatchedMatmulOp::grad(Var* out, Var* dout, Var* v, int v_index) {
    // a [b,n,m], b [b,m,k], c [b,n,k]
    // c = op_a(a)*op_b(b), where op_x transposes iff trans_x is set
    if (v_index == 0) {
        if (trans_a)
            // da = op_b(b)*dc^T
            return make_cublas_batched_matmul(b, dout, trans_b, true);
        // da = dc*op_b(b)^T
        return make_cublas_batched_matmul(dout, b, false, !trans_b);
    } else {
        if (trans_b)
            // db = dc^T*op_a(a)
            return make_cublas_batched_matmul(dout, a, true, trans_a);
        // db = op_a(a)^T*dc
        return make_cublas_batched_matmul(a, dout, !trans_a, false);
    }
}
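// shape check for the non-transposed case used by nn.bmm:
//   dc [b,n,k] * b^T [b,k,m] -> da [b,n,m] (matches a)
//   a^T [b,m,n] * dc [b,n,k] -> db [b,m,k] (matches b)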
void CublasBatchedMatmulOp::infer_shape() {
    // TODO: reuse bmm's infer shape
    ASSERTop(a->shape.size(),==,3);
    ASSERTop(b->shape.size(),==,3);
    int batch_size = a->shape[0], n = a->shape[1], m = a->shape[2];
    int batch_size_ = b->shape[0], m_ = b->shape[1], k = b->shape[2];
    ASSERTop(batch_size,==,batch_size_);
    if (trans_a) swap(n, m);
    if (trans_b) swap(m_, k);
    ASSERTop(m,==,m_);
    c->set_shape({batch_size, n, k});
}
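// e.g. a [16,4,3] with trans_a=1 gives n=3, m=4;
//      b [16,4,5] with trans_b=0 keeps m_=4, k=5 -> c [16,3,5]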
void CublasBatchedMatmulOp::jit_prepare() {
    add_jit_define("T", a->dtype());
    add_jit_define("Trans_a", trans_a ? "T" : "N");
    add_jit_define("Trans_b", trans_b ? "T" : "N");
    add_jit_define("op", a->dtype().dsize() == 4 ? "S" : "D");
}
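// e.g. float32 with trans_a=1, trans_b=0 makes the JIT pass expand the call
// in jit_run below to cublasSgemmStridedBatched with CUBLAS_OP_N for b and
// CUBLAS_OP_T for a.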
#else // JIT
#ifdef JIT_cuda
#pragma clang diagnostic ignored "-Wtautological-compare"
void CublasBatchedMatmulOp::jit_run() {
    cublasHandle_t& handle_ = cublas_handle;
    const T alpha = 1.0f;
    const T beta = 0.0f;
    const auto& as = a->shape;
    const auto& bs = b->shape;
    auto batch_size = as[0];
    auto n = as[1];
    auto m = as[2];
    auto k = bs[2];
    if ('@Trans_a'=='T') {
        n = as[2];
        m = as[1];
    }
    if ('@Trans_b'=='T') {
        k = bs[1];
    }
    // a: [b,n,m], b: [b,m,k], c: [b,n,k]
    // call the strided-batched gemm interface
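    // cuBLAS is column-major while jittor buffers are row-major; a row-major
    // [n,m] matrix reinterpreted as column-major is its transpose. So instead
    // of C = A*B we ask cuBLAS for C^T = B^T * A^T: swap the operand order,
    // pass the dims as (k, n, m), and the column-major C^T it writes is
    // exactly C in row-major layout.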
    checkCudaErrors(cublas@op@@gemmStridedBatched(handle_,
        CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
        k, n, m, &alpha,
        b->ptr<T>(), '@Trans_b' == 'N' ? k : m, k * m,
        a->ptr<T>(), '@Trans_a' == 'N' ? m : n, n * m, &beta,
        c->ptr<T>(), k, k * n,
        batch_size));
}
#endif
#endif // JIT
} // jittor
// nn.py
// def bmm(a, b):
//     from compile_extern import cublas_ops
//     if jt.flags.use_cuda and cublas_ops:
//         return cublas_ops.cublas_batched_matmul(a, b, 0, 0)
//     assert len(a.shape) >= 2 and len(b.shape) >= 2
//     assert a.shape[-1] == b.shape[-2]
//     shape = list(a.shape) + [b.shape[-1]]
//     a = a.broadcast(shape, [len(shape)-1])
//     b = b.broadcast(shape, [len(shape)-3])
//     return (a*b).sum(len(shape)-2)
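A quick way to sanity-check the new op against the broadcast fallback; a minimal sketch, assuming a CUDA build of jittor with the cublas extension compiled, so that both dispatch paths in nn.bmm are reachable:

    import numpy as np
    import jittor as jt
    from jittor import nn

    a = jt.random((8, 3, 4))
    b = jt.random((8, 4, 5))

    jt.flags.use_cuda = 1
    c_cublas = nn.bmm(a, b).data    # cublas_batched_matmul path
    jt.flags.use_cuda = 0
    c_fallback = nn.bmm(a, b).data  # broadcast + sum fallback
    np.testing.assert_allclose(c_cublas, c_fallback, atol=1e-5)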

cublas_batched_matmul_op.h

@@ -0,0 +1,30 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors:
// Meng-Hao Guo <guomenghao1997@gmail.com>
// Dun Liang <randonlang@gmail.com>.
// All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
// cublas_batched_matmul_op.h
#pragma once
#include "op.h"
#include "ops/op_register.h"
#include "var.h"
namespace jittor {
struct CublasBatchedMatmulOp : Op {
    Var* a, * b, * c;
    bool trans_a, trans_b;
    CublasBatchedMatmulOp(Var* a, Var* b, bool trans_a, bool trans_b);
    const char* name() const override { return "cublas_batched_matmul"; }
    void infer_shape() override;
    VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
    DECLARE_jit_run;
};
} // jittor
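Since the header declares a grad override, the backward path can be smoke-tested from Python; a sketch, assuming a CUDA build, with arbitrary example shapes:

    import jittor as jt
    from jittor import nn

    jt.flags.use_cuda = 1
    a = jt.random((4, 3, 5))
    b = jt.random((4, 5, 2))
    loss = nn.bmm(a, b).sum()
    da, db = jt.grad(loss, [a, b])  # exercises CublasBatchedMatmulOp::grad
    assert list(da.shape) == [4, 3, 5] and list(db.shape) == [4, 5, 2]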

nn.py

@@ -29,6 +29,22 @@ def matmul_transpose(a, b):
    b = b.broadcast(shape)
    return (a*b).sum(len(shape)-1)
def bmm(a, b):
    '''batch matrix multiply: a [..., n, m] @ b [..., m, k] -> [..., n, k]'''
    if jt.flags.use_cuda and jt.compile_extern.cublas_ops:
        return jt.compile_extern.cublas_ops.cublas_batched_matmul(a, b, 0, 0)
    assert len(a.shape) >= 2 and len(b.shape) >= 2
    assert a.shape[-1] == b.shape[-2]
    shape = list(a.shape) + [b.shape[-1]]
    a = a.broadcast(shape, [len(shape)-1])
    b = b.broadcast(shape, [len(shape)-3])
    return (a*b).sum(len(shape)-2)
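# The fallback above lowers bmm to broadcast + sum: with shape = [..., n, m, k],
# a is broadcast along the trailing k axis and b along axis -3 (the n axis),
# so (a*b).sum(len(shape)-2) contracts the shared m axis, i.e. a matmul.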
def matmul(a, b):
    assert len(a.shape) >= 2 and len(b.shape) == 2
    assert a.shape[-1] == b.shape[-2]