mirror of https://github.com/Jittor/Jittor
Merge branch 'gmh' of https://github.com/Jittor/jittor
commit c49ebb1327
@@ -0,0 +1,111 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors:
// Meng-Hao Guo <guomenghao1997@gmail.com>
// Dun Liang <randonlang@gmail.com>.
// All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************

// cublas_batched_matmul_op.cc
#include "var.h"
#include "cublas_batched_matmul_op.h"
#include "cublas_warper.h"

using namespace std;

namespace jittor {

#ifndef JIT

static auto make_cublas_batched_matmul = get_op_info("cublas_batched_matmul")
    .get_constructor<VarPtr, Var*, Var*, bool, bool>();

CublasBatchedMatmulOp::CublasBatchedMatmulOp(Var* a, Var* b, bool trans_a, bool trans_b)
    : a(a), b(b), trans_a(trans_a), trans_b(trans_b) {
    // TODO: support int8 * int8
    ASSERT(a->dtype().is_float() && b->dtype().is_float()) << "inputs of batched matmul should be float";
    // TODO: support different input types
    ASSERT(a->dtype().dsize() == b->dtype().dsize()) << "types of the two inputs should be the same";
    c = create_output(nullptr, a->dtype());
    // this op only has a CUDA implementation
    flags.set(NodeFlags::_cpu, 0);
    flags.set(NodeFlags::_cuda, 1);
}

VarPtr CublasBatchedMatmulOp::grad(Var* out, Var* dout, Var* v, int v_index) {
    // c = op_a(a) * op_b(b), with a [b,n,m], b [b,m,k], c [b,n,k]
    // when an input is not transposed:
    //   da = dc * op_b(b)^T,  db = op_a(a)^T * dc
    // when it is transposed, its gradient is transposed as well
    if (v_index == 0) {
        if (trans_a)
            // da = (dc * op_b(b)^T)^T = op_b(b) * dc^T
            return make_cublas_batched_matmul(b, dout, trans_b, 1);
        return make_cublas_batched_matmul(dout, b, 0, trans_b^1);
    } else {
        if (trans_b)
            // db = (op_a(a)^T * dc)^T = dc^T * op_a(a)
            return make_cublas_batched_matmul(dout, a, 1, trans_a);
        return make_cublas_batched_matmul(a, dout, trans_a^1, 0);
    }
}

void CublasBatchedMatmulOp::infer_shape() {
    ASSERTop(a->shape.size(),==,3);
    ASSERTop(b->shape.size(),==,3);

    int batch_size = a->shape[0], n = a->shape[1], m = a->shape[2];
    int batch_size_ = b->shape[0], m_ = b->shape[1], k = b->shape[2];

    ASSERTop(batch_size,==,batch_size_);
    if (trans_a) {
        swap(n, m);
    }
    if (trans_b) {
        swap(m_, k);
    }
    ASSERTop(m,==,m_);

    c->set_shape({batch_size, n, k});
}

void CublasBatchedMatmulOp::jit_prepare() {
    add_jit_define("T", a->dtype());
    add_jit_define("Trans_a", trans_a ? "T" : "N");
    add_jit_define("Trans_b", trans_b ? "T" : "N");
    // select the cublasSgemm / cublasDgemm variant by element size
    add_jit_define("op", a->dtype().dsize() == 4 ? "S" : "D");
}

#else // JIT
#ifdef JIT_cuda
// '@Trans_a'=='T' is resolved at JIT time; silence constant-compare warnings
#pragma clang diagnostic ignored "-Wtautological-compare"
void CublasBatchedMatmulOp::jit_run() {
    cublasHandle_t& handle_ = cublas_handle;
    const T alpha = 1.0f;
    const T beta = 0.0f;

    const auto& as = a->shape;
    const auto& bs = b->shape;
    auto batch_size = as[0];
    auto n = as[1];
    auto m = as[2];
    auto k = bs[2];
    if ('@Trans_a'=='T') {
        n = as[2];
        m = as[1];
    }
    if ('@Trans_b'=='T') {
        k = bs[1];
    }
    // a: [b,n,m], b: [b,m,k], c: [b,n,k]
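    // cuBLAS expects column-major storage while jittor vars are row-major;
    // passing the operands in swapped order computes c^T = op(b)^T * op(a)^T
    // in column-major form, which is exactly row-major c, so no extra data
    // movement is needed.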
    checkCudaErrors(cublas@op@@gemmStridedBatched(handle_,
        CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
        k, n, m, &alpha,
        b->ptr<T>(), '@Trans_b' == 'N' ? k : m, k * m,
        a->ptr<T>(), '@Trans_a' == 'N' ? m : n, n * m, &beta,
        c->ptr<T>(), k, k * n,
        batch_size));
}
#endif
#endif // JIT

} // jittor
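
The gradient rules in CublasBatchedMatmulOp::grad are easy to check numerically. A standalone NumPy sketch (illustrative only, not part of this commit; op() is a stand-in for the trans_a/trans_b handling) verifies all four transpose cases through the linear identity <da, v> = <g, dc(v)>:

import numpy as np

rng = np.random.default_rng(0)

def op(x, t):
    # optional per-batch transpose, mirroring the trans_a / trans_b flags
    return x.transpose(0, 2, 1) if t else x

batch, n, m, k = 10, 4, 5, 6
for ta in (0, 1):
    for tb in (0, 1):
        a = op(rng.standard_normal((batch, n, m)), ta)  # stored layout
        b = op(rng.standard_normal((batch, m, k)), tb)
        g = rng.standard_normal((batch, n, k))          # upstream dL/dc
        # the four cases emitted by grad()
        da = op(b, tb) @ op(g, 1) if ta else g @ op(b, tb ^ 1)
        db = op(g, 1) @ op(a, ta) if tb else op(a, ta ^ 1) @ g
        # <da, v> must equal <g, dc(v)> for any direction v, likewise for b
        v = rng.standard_normal(a.shape)
        assert np.allclose((da * v).sum(), (g * (op(v, ta) @ op(b, tb))).sum())
        w = rng.standard_normal(b.shape)
        assert np.allclose((db * w).sum(), (g * (op(a, ta) @ op(w, tb))).sum())
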
@@ -0,0 +1,30 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors:
// Meng-Hao Guo <guomenghao1997@gmail.com>
// Dun Liang <randonlang@gmail.com>.
// All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************

// cublas_batched_matmul_op.h
#pragma once
#include "op.h"
#include "ops/op_register.h"
#include "var.h"

namespace jittor {

struct CublasBatchedMatmulOp : Op {
    Var* a, * b, * c;
    bool trans_a, trans_b;
    CublasBatchedMatmulOp(Var* a, Var* b, bool trans_a, bool trans_b);

    const char* name() const override { return "cublas_batched_matmul"; }
    void infer_shape() override;
    VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
    DECLARE_jit_run;
};

} // jittor
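
For reference, once the extern op is compiled it can be called directly from Python. A minimal sketch, assuming CUDA and the cublas extern ops are available (shapes are illustrative):

import jittor as jt

jt.flags.use_cuda = 1
a = jt.random((10, 4, 5))   # [batch, n, m]
b = jt.random((10, 6, 5))   # [batch, k, m], to be transposed
# trans_a=0, trans_b=1 computes a @ b^T per batch -> [10, 4, 6]
c = jt.compile_extern.cublas_ops.cublas_batched_matmul(a, b, 0, 1)
print(c.shape)
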
@@ -23,13 +23,13 @@ def matmul_transpose(a, b):
    '''
    assert len(a.shape) >= 2 and len(b.shape) == 2
    assert a.shape[-1] == b.shape[-1]

    if jt.flags.use_cuda and len(a.shape) == 3:
        # the batched cublas op needs two 3-D inputs, so broadcast the 2-D
        # weight across the batch and fold the transpose into trans_b
        return jt.compile_extern.cublas_ops.cublas_batched_matmul(
            a, b.broadcast([a.shape[0]] + list(b.shape)), 0, 1)
    shape = list(a.shape)[:-1] + list(b.shape)
    a = a.broadcast(shape, [len(shape)-2])
    b = b.broadcast(shape)
    return (a*b).sum(len(shape)-1)


def bmm(a, b):
    ''' batch matrix multiply,
    shape of input a is [batch, n, m],
@@ -46,11 +46,11 @@ Example::
        a = jt.random((batch, n, m))
        b = jt.random((batch, m, k))
        c = nn.bmm(a, b)

    '''
    assert len(a.shape) == 3 and len(b.shape) == 3
    assert a.shape[-1] == b.shape[-2]

    if jt.flags.use_cuda:
        return jt.compile_extern.cublas_ops.cublas_batched_matmul(a, b, 0, 0)
    shape = list(a.shape) + [b.shape[-1]]
    a = a.broadcast(shape, [len(shape)-1])
    b = b.broadcast(shape, [len(shape)-3])
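
The non-CUDA branch implements matmul as broadcast-multiply-reduce. A NumPy sketch of the same idea (illustrative, not part of the commit): expand a [batch, n, m] and b [batch, m, k] to a common [batch, n, m, k] shape, multiply elementwise, and sum over the shared m axis:

import numpy as np

batch, n, m, k = 10, 4, 5, 6
a = np.random.rand(batch, n, m)
b = np.random.rand(batch, m, k)
# insert a broadcast axis where each operand lacks one, then reduce over m
c = (a[:, :, :, None] * b[:, None, :, :]).sum(axis=2)
assert np.allclose(c, a @ b)
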
@@ -0,0 +1,44 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Meng-Hao Guo <guomenghao1997@gmail.com>
# Dun Liang <randonlang@gmail.com>.
#
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import jittor as jt
from jittor import nn
import unittest
import numpy as np

class TestBMM(unittest.TestCase):
    @unittest.skipIf(not jt.has_cuda, "No cuda found")
    def test_bmm_cuda(self):
        def check(batch, n, m, k):
            def calc(use_cuda, a, b, mask):
                jt.flags.use_cuda = use_cuda
                a = jt.array(a)
                b = jt.array(b)
                mask = jt.array(mask)
                c = nn.bmm(a, b)
                da, db = jt.grad(c*mask, [a, b])
                return c.data, da.data, db.data
            mask = np.random.rand(batch, n, k).astype("float32")
            a = np.random.rand(batch, n, m).astype("float32")
            b = np.random.rand(batch, m, k).astype("float32")
            # compare the CPU fallback against the cublas kernel,
            # for both the output and the two input gradients
            a1,a2,a3 = calc(0, a, b, mask)
            b1,b2,b3 = calc(1, a, b, mask)
            assert np.allclose(a1, b1)
            assert np.allclose(a2, b2)
            assert np.allclose(a3, b3)
        check(10,3,4,5)
        check(10,8,8,8)
        check(10,8,1,8)
        check(10,8,8,1)
        check(10,1,8,8)
        check(1,7,8,8)


if __name__ == "__main__":
    unittest.main()