mirror of https://github.com/Jittor/Jittor
Merge branch 'master' of github.com:Jittor/jittor into win_cuda
This commit is contained in:
commit
a78c3b4f12
|
@ -1,5 +1,6 @@
|
|||
# Jittor: 即时编译深度学习框架
|
||||
|
||||

|
||||
|
||||
[快速开始](#快速开始) | [安装](#安装) | [教程](#教程)
|
||||
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
# Jittor: a Just-in-time(JIT) deep learning framework
|
||||
|
||||

|
||||
|
||||
[Quickstart](#quickstart) | [Install](#install) | [Tutorial](#tutorial) | [Chinese](./README.cn.md)
|
||||
|
||||
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
# Jittor: a Just-in-time(JIT) deep learning framework
|
||||
# Jittor: 即时编译深度学习框架
|
||||
|
||||

|
||||
|
||||
[Quickstart](#quickstart) | [Install](#install) | [Tutorial](#tutorial) | [Chinese](./README.cn.md)
|
||||
|
||||
[快速开始](#快速开始) | [安装](#安装) | [教程](#教程)
|
||||
|
|
Binary file not shown.
After Width: | Height: | Size: 30 KiB |
|
@ -0,0 +1,154 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2021 Jittor. All Rights Reserved.
|
||||
// Maintainers:
|
||||
// Zheng-Ning Liu <lzhengning@gmail.com>
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include "op.h"
|
||||
#include "cudnn_warper.h"
|
||||
#include "executor.h"
|
||||
#include "init.h"
|
||||
|
||||
|
||||
namespace jittor {
|
||||
|
||||
static inline cudnnRNNMode_t rnn_string_to_rnn_mode(string mode) {
|
||||
if (mode == "relu")
|
||||
return CUDNN_RNN_RELU;
|
||||
if (mode == "tanh")
|
||||
return CUDNN_RNN_TANH;
|
||||
if (mode == "lstm")
|
||||
return CUDNN_LSTM;
|
||||
ASSERT(mode == "gru") << "rnn mode must be relu, tanh, lstm, or gru, but got " << mode;
|
||||
return CUDNN_GRU;
|
||||
}
|
||||
|
||||
static inline int rnn_string_to_num_linear_layers(string mode) {
|
||||
if (mode == "relu")
|
||||
return 2;
|
||||
if (mode == "tanh")
|
||||
return 2;
|
||||
if (mode == "lstm")
|
||||
return 8;
|
||||
ASSERT(mode == "gru") << "mode must be relu, tanh, lstm, or gru, but got " << mode;
|
||||
return 6;
|
||||
}
|
||||
|
||||
/** A wrapper for CUDNN dropout descriptor
|
||||
*/
|
||||
struct DropoutDescriptor {
|
||||
cudnnDropoutDescriptor_t desc;
|
||||
size_t stateSize, stateAllocation;
|
||||
float dropout;
|
||||
void *stateSpace;
|
||||
|
||||
DropoutDescriptor(cudnnHandle_t handle, float dropout)
|
||||
: dropout(dropout), stateSpace(nullptr) {
|
||||
checkCudaErrors(cudnnCreateDropoutDescriptor(&desc));
|
||||
if (dropout > 0) {
|
||||
checkCudaErrors(cudnnDropoutGetStatesSize(handle, &stateSize));
|
||||
stateSpace = exe.temp_allocator->alloc(stateSize, stateAllocation);
|
||||
checkCudaErrors(cudnnSetDropoutDescriptor(
|
||||
desc,
|
||||
cudnn_handle,
|
||||
dropout,
|
||||
stateSpace,
|
||||
stateSize,
|
||||
get_seed()
|
||||
));
|
||||
} else {
|
||||
checkCudaErrors(cudnnSetDropoutDescriptor(
|
||||
desc, handle, 0, nullptr, 0, 0
|
||||
));
|
||||
}
|
||||
}
|
||||
~DropoutDescriptor() {
|
||||
checkCudaErrors(cudnnDestroyDropoutDescriptor(desc));
|
||||
if (stateSpace)
|
||||
exe.temp_allocator->free(stateSpace, stateSize, stateAllocation);
|
||||
}
|
||||
};
|
||||
|
||||
/** A wrapper for CUDNN RNN descriptor
|
||||
*/
|
||||
struct RnnDescriptor {
|
||||
cudnnHandle_t handle;
|
||||
cudnnRNNDescriptor_t desc;
|
||||
DropoutDescriptor dropoutDesc;
|
||||
|
||||
RnnDescriptor(cudnnHandle_t handle, string mode, int hidden_size, int num_layers,
|
||||
float dropout, bool bidirectional) : handle(handle), dropoutDesc(handle, dropout) {
|
||||
checkCudaErrors(cudnnCreateRNNDescriptor(&desc));
|
||||
checkCudaErrors(cudnnSetRNNDescriptor_v6(
|
||||
handle,
|
||||
desc,
|
||||
hidden_size,
|
||||
num_layers,
|
||||
dropoutDesc.desc,
|
||||
CUDNN_LINEAR_INPUT,
|
||||
bidirectional ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL,
|
||||
rnn_string_to_rnn_mode(mode),
|
||||
CUDNN_RNN_ALGO_STANDARD,
|
||||
CUDNN_DATA_FLOAT
|
||||
));
|
||||
}
|
||||
|
||||
~RnnDescriptor() {
|
||||
checkCudaErrors(cudnnDestroyRNNDescriptor(desc));
|
||||
}
|
||||
|
||||
size_t weight_space_size(const cudnnTensorDescriptor_t &xDesc) {
|
||||
size_t size;
|
||||
checkCudaErrors(cudnnGetRNNParamsSize(
|
||||
handle, desc, xDesc, &size, CUDNN_DATA_FLOAT
|
||||
));
|
||||
return size;
|
||||
}
|
||||
|
||||
size_t work_space_size(const cudnnTensorDescriptor_t *xDesc, int seq_length) {
|
||||
size_t size;
|
||||
checkCudaErrors(cudnnGetRNNWorkspaceSize(
|
||||
handle, desc, seq_length, xDesc, &size
|
||||
));
|
||||
return size;
|
||||
}
|
||||
|
||||
size_t reserve_space_size(const cudnnTensorDescriptor_t *xDesc, int seq_length) {
|
||||
size_t size;
|
||||
checkCudaErrors(cudnnGetRNNTrainingReserveSize(
|
||||
handle, desc, seq_length, xDesc, &size
|
||||
));
|
||||
return size;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
*/
|
||||
struct RnnWeightDescriptor {
|
||||
cudnnFilterDescriptor_t desc;
|
||||
size_t size;
|
||||
RnnWeightDescriptor(size_t size) : size(size) {
|
||||
int dimW[3] = {(int) (size / sizeof(float)), 1, 1};
|
||||
checkCudaErrors(cudnnCreateFilterDescriptor(&desc));
|
||||
checkCudaErrors(cudnnSetFilterNdDescriptor(desc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, 3, dimW));
|
||||
}
|
||||
~RnnWeightDescriptor() {
|
||||
cudnnDestroyFilterDescriptor(desc);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
Returns offsets of RNN linear parameters in a flatten array.
|
||||
|
||||
Returns
|
||||
=======
|
||||
list: [total size, param #1 offset, param #2 offset, ...]
|
||||
|
||||
TODO: support cudnn rnn-v8; support proj_size
|
||||
*/
|
||||
// @pyjt(cudnn_rnn_weight_offset)
|
||||
vector<int32_t> cudnn_rnn_weight_offset(string mode, int input_size, int hidden_size, int num_layers, int proj_size, bool bias, bool bidirectional);
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,195 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2021 Jittor. All Rights Reserved.
|
||||
// Maintainers:
|
||||
// Zheng-Ning Liu <lzhengning@gmail.com>
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include "var.h"
|
||||
#include "cudnn_rnn_descriptor.h"
|
||||
#include "cudnn_rnn_backward_x_op.h"
|
||||
#include "cudnn_warper.h"
|
||||
#include "executor.h"
|
||||
#include "ops/op_register.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
#pragma GCC diagnostic ignored "-Wunused-variable"
|
||||
|
||||
#ifndef JIT
|
||||
|
||||
CudnnRnnBackwardXOp::CudnnRnnBackwardXOp(Var *x, Var* hx, Var* cx, Var* y, Var* dy, Var* dhy, Var* dcy, Var* w, Var* reservation,
|
||||
string mode, int input_size, int hidden_size, int num_layers, int proj_size,
|
||||
double dropout, bool bias, bool bidirectional)
|
||||
: x(x), hx(hx), cx(cx), y(y), dy(dy), dhy(dhy), dcy(dcy), w(w), reservation(reservation),
|
||||
mode(mode), input_size(input_size), hidden_size(hidden_size), num_layers(num_layers),
|
||||
proj_size(proj_size), dropout(dropout), bias(bias), bidirectional(bidirectional) {
|
||||
|
||||
flags.set(NodeFlags::_cuda, 1);
|
||||
flags.set(NodeFlags::_cpu, 0);
|
||||
|
||||
ASSERTop(mode,==,"lstm");
|
||||
ASSERTop(proj_size,==,0);
|
||||
init_rnn();
|
||||
}
|
||||
|
||||
CudnnRnnBackwardXOp::CudnnRnnBackwardXOp(Var* x, Var* hx, Var* y, Var* dy, Var* dhy, Var* w, Var* reservation,
|
||||
string mode, int input_size, int hidden_size, int num_layers, int proj_size,
|
||||
double dropout, bool bias, bool bidirectional)
|
||||
: x(x), hx(hx), cx(nullptr), y(y), dy(dy), dhy(dhy), dcy(nullptr), w(w), reservation(reservation),
|
||||
mode(mode), input_size(input_size), hidden_size(hidden_size), num_layers(num_layers),
|
||||
proj_size(proj_size), dropout(dropout), bias(bias), bidirectional(bidirectional) {
|
||||
|
||||
flags.set(NodeFlags::_cuda, 1);
|
||||
flags.set(NodeFlags::_cpu, 0);
|
||||
|
||||
ASSERTop(mode,!=,"lstm");
|
||||
ASSERTop(proj_size,==,0);
|
||||
init_rnn();
|
||||
}
|
||||
|
||||
void CudnnRnnBackwardXOp::init_rnn() {
|
||||
dx = create_output(nullptr, ns_float32);
|
||||
dhx = create_output(nullptr, ns_float32);
|
||||
|
||||
if (mode == "lstm")
|
||||
dcx = create_output(nullptr, ns_float32);
|
||||
else
|
||||
dcx = nullptr;
|
||||
|
||||
dw = create_output(nullptr, dtype_infer(x->ns, y->ns));
|
||||
|
||||
seq_length = y->shape[0];
|
||||
batch_size = y->shape[1];
|
||||
}
|
||||
|
||||
void CudnnRnnBackwardXOp::infer_shape() {
|
||||
dx->set_shape(NanoVector(seq_length, batch_size, input_size));
|
||||
|
||||
int num_directions = 1 + bidirectional;
|
||||
if (proj_size > 0)
|
||||
dhx->set_shape(NanoVector(num_layers * num_directions, batch_size, proj_size));
|
||||
else
|
||||
dhx->set_shape(NanoVector(num_layers * num_directions, batch_size, hidden_size));
|
||||
|
||||
if (dcx)
|
||||
dcx->set_shape(NanoVector(num_layers * num_directions, batch_size, hidden_size));
|
||||
|
||||
dw->set_shape(w->shape);
|
||||
}
|
||||
|
||||
void CudnnRnnBackwardXOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Tx:") << hx->dtype();
|
||||
jk << _CS("][Ty:") << y->dtype();
|
||||
jk << _CS("][Tw:") << w->dtype();
|
||||
jk << ']';
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
#ifdef JIT_cuda
|
||||
|
||||
template <typename T_ELEM> __inline__ cudnnDataType_t getDataType();
|
||||
template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; }
|
||||
template <> __inline__ cudnnDataType_t getDataType<float>() { return CUDNN_DATA_FLOAT; }
|
||||
|
||||
void CudnnRnnBackwardXOp::jit_run() {
|
||||
int num_directions = 1 + bidirectional;
|
||||
|
||||
int in_dims[3] = {batch_size, input_size, 1};
|
||||
int out_dims[3] = {batch_size, hidden_size * num_directions, 1};
|
||||
int in_strides[3] = {in_dims[1] * in_dims[2], in_dims[2], 1};
|
||||
int out_strides[3] = {out_dims[1] * out_dims[2], out_dims[2], 1};
|
||||
int hidden_dims[3] = {num_layers * num_directions, batch_size, hidden_size};
|
||||
int hidden_strides[3] = {hidden_dims[1] * hidden_dims[2], hidden_dims[2], 1};
|
||||
|
||||
vector<cudnnTensorDescriptor_t> xDesc(seq_length), dxDesc(seq_length);
|
||||
vector<cudnnTensorDescriptor_t> yDesc(seq_length), dyDesc(seq_length);
|
||||
|
||||
for (int i = 0; i < seq_length; ++i) {
|
||||
checkCudaErrors(cudnnCreateTensorDescriptor(&xDesc[i]));
|
||||
checkCudaErrors(cudnnCreateTensorDescriptor(&dxDesc[i]));
|
||||
checkCudaErrors(cudnnCreateTensorDescriptor(&yDesc[i]));
|
||||
checkCudaErrors(cudnnCreateTensorDescriptor(&dyDesc[i]));
|
||||
checkCudaErrors(cudnnSetTensorNdDescriptor(xDesc[i], getDataType<Ty>(), 3, in_dims, in_strides));
|
||||
checkCudaErrors(cudnnSetTensorNdDescriptor(dxDesc[i], getDataType<Ty>(), 3, in_dims, in_strides));
|
||||
checkCudaErrors(cudnnSetTensorNdDescriptor(yDesc[i], getDataType<Ty>(), 3, out_dims, out_strides));
|
||||
checkCudaErrors(cudnnSetTensorNdDescriptor(dyDesc[i], getDataType<Ty>(), 3, out_dims, out_strides));
|
||||
}
|
||||
|
||||
cudnnTensorDescriptor_t dhyDesc, dcyDesc;
|
||||
cudnnTensorDescriptor_t hxDesc, cxDesc, dhxDesc, dcxDesc;
|
||||
checkCudaErrors(cudnnCreateTensorDescriptor(&hxDesc));
|
||||
checkCudaErrors(cudnnCreateTensorDescriptor(&cxDesc));
|
||||
checkCudaErrors(cudnnCreateTensorDescriptor(&dhxDesc));
|
||||
checkCudaErrors(cudnnCreateTensorDescriptor(&dcxDesc));
|
||||
checkCudaErrors(cudnnCreateTensorDescriptor(&dhyDesc));
|
||||
checkCudaErrors(cudnnCreateTensorDescriptor(&dcyDesc));
|
||||
checkCudaErrors(cudnnSetTensorNdDescriptor(hxDesc, getDataType<Tx>(), 3, hidden_dims, hidden_strides));
|
||||
checkCudaErrors(cudnnSetTensorNdDescriptor(cxDesc, getDataType<Tx>(), 3, hidden_dims, hidden_strides));
|
||||
checkCudaErrors(cudnnSetTensorNdDescriptor(dhxDesc, getDataType<Tx>(), 3, hidden_dims, hidden_strides));
|
||||
checkCudaErrors(cudnnSetTensorNdDescriptor(dcxDesc, getDataType<Tx>(), 3, hidden_dims, hidden_strides));
|
||||
checkCudaErrors(cudnnSetTensorNdDescriptor(dhyDesc, getDataType<Tx>(), 3, hidden_dims, hidden_strides));
|
||||
checkCudaErrors(cudnnSetTensorNdDescriptor(dcyDesc, getDataType<Tx>(), 3, hidden_dims, hidden_strides));
|
||||
|
||||
RnnWeightDescriptor w_desc(w->size);
|
||||
RnnDescriptor rnn_desc(cudnn_handle, mode, hidden_size, num_layers, dropout, bidirectional);
|
||||
|
||||
void *work_space;
|
||||
size_t work_space_size = rnn_desc.work_space_size(dxDesc.data(), seq_length);
|
||||
size_t work_space_allocation;
|
||||
if (work_space_size > 0)
|
||||
work_space = exe.temp_allocator->alloc(work_space_size, work_space_allocation);
|
||||
|
||||
size_t reserveSpaceSize = reservation->size;
|
||||
|
||||
checkCudaErrors(cudnnRNNBackwardData(
|
||||
cudnn_handle, rnn_desc.desc,
|
||||
seq_length,
|
||||
yDesc.data(), y->ptr<Ty>(),
|
||||
dyDesc.data(), dy->ptr<Ty>(),
|
||||
dhyDesc, dhy->ptr<Ty>(),
|
||||
dcyDesc, mode == "lstm" ? dcy->ptr<Ty>(): nullptr,
|
||||
w_desc.desc, w->ptr<Tw>(),
|
||||
hxDesc, hx->ptr<Tx>(),
|
||||
cxDesc, mode == "lstm" ? cx->ptr<Tx>() : nullptr,
|
||||
dxDesc.data(), dx->ptr<Tx>(),
|
||||
dhxDesc, dhx->ptr<Tx>(),
|
||||
dcxDesc, mode == "lstm" ? dcx->ptr<Tx>() : nullptr,
|
||||
work_space, work_space_size,
|
||||
reservation->ptr<Tx>(), reservation->size
|
||||
));
|
||||
|
||||
checkCudaErrors(cudaMemset(dw->ptr<Tw>(), 0, dw->size));
|
||||
|
||||
checkCudaErrors(cudnnRNNBackwardWeights(
|
||||
cudnn_handle, rnn_desc.desc,
|
||||
seq_length,
|
||||
xDesc.data(), x->ptr<Tx>(),
|
||||
hxDesc, hx->ptr<Tx>(),
|
||||
yDesc.data(), y->ptr<Ty>(),
|
||||
work_space, work_space_size,
|
||||
w_desc.desc, dw->ptr<Tw>(),
|
||||
reservation->ptr<Tx>(), reservation->size
|
||||
));
|
||||
|
||||
for (int i = 0; i < seq_length; ++i) {
|
||||
checkCudaErrors(cudnnDestroyTensorDescriptor(xDesc[i]));
|
||||
checkCudaErrors(cudnnDestroyTensorDescriptor(dxDesc[i]));
|
||||
checkCudaErrors(cudnnDestroyTensorDescriptor(yDesc[i]));
|
||||
checkCudaErrors(cudnnDestroyTensorDescriptor(dyDesc[i]));
|
||||
}
|
||||
|
||||
checkCudaErrors(cudnnDestroyTensorDescriptor(dhyDesc));
|
||||
checkCudaErrors(cudnnDestroyTensorDescriptor(dcyDesc));
|
||||
checkCudaErrors(cudnnDestroyTensorDescriptor(hxDesc));
|
||||
checkCudaErrors(cudnnDestroyTensorDescriptor(cxDesc));
|
||||
checkCudaErrors(cudnnDestroyTensorDescriptor(dhxDesc));
|
||||
checkCudaErrors(cudnnDestroyTensorDescriptor(dcxDesc));
|
||||
|
||||
if (work_space)
|
||||
exe.temp_allocator->free(work_space, work_space_size, work_space_allocation);
|
||||
}
|
||||
|
||||
#endif
|
||||
#endif // JIT
|
||||
}
|
|
@ -0,0 +1,38 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2021 Jittor. All Rights Reserved.
|
||||
// Maintainers:
|
||||
// Zheng-Ning Liu <lzhengning@gmail.com>
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include "op.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
struct CudnnRnnBackwardXOp : Op {
|
||||
Var* x, * hx, * cx;
|
||||
Var* y, * dy, * dhy, * dcy;
|
||||
Var* w;
|
||||
Var* dx, * dhx, * dcx, * dw;
|
||||
Var* reservation;
|
||||
string mode;
|
||||
int input_size, hidden_size, num_layers, proj_size, batch_size;
|
||||
int seq_length;
|
||||
float dropout;
|
||||
bool bias, bidirectional;
|
||||
|
||||
// @attrs(multiple_outputs)
|
||||
CudnnRnnBackwardXOp(Var* x, Var* hx, Var* cx, Var* y, Var* dy, Var* dhy, Var* dcy, Var* w, Var* reservation, string mode, int input_size, int hidden_size, int num_layers, int proj_size, double dropout, bool bias, bool bidirectional);
|
||||
|
||||
// @attrs(multiple_outputs)
|
||||
CudnnRnnBackwardXOp(Var* x, Var* hx, Var* y, Var* dy, Var* dhy, Var* w, Var* reservation, string mode, int input_size, int hidden_size, int num_layers, int proj_size, double dropout, bool bias, bool bidirectional);
|
||||
|
||||
void init_rnn();
|
||||
|
||||
const char* name() const override { return "cudnn_rnn_backward_x"; }
|
||||
void infer_shape() override;
|
||||
DECLARE_jit_run;
|
||||
};
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,227 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2021 Jittor. All Rights Reserved.
|
||||
// Maintainers:
|
||||
// Zheng-Ning Liu <lzhengning@gmail.com>
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include "var.h"
|
||||
#include "cudnn_rnn_descriptor.h"
|
||||
#include "cudnn_rnn_op.h"
|
||||
#include "cudnn_warper.h"
|
||||
#include "executor.h"
|
||||
#include "ops/op_register.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace jittor {
|
||||
|
||||
#pragma GCC diagnostic ignored "-Wunused-variable"
|
||||
|
||||
#ifndef JIT
|
||||
|
||||
CudnnRnnOp::CudnnRnnOp(Var* x, Var* hx, Var* cx, Var* w,
|
||||
string mode, int input_size, int hidden_size, int num_layers, int proj_size,
|
||||
double dropout, bool bias, bool bidirectional, bool is_train)
|
||||
: x(x), hx(hx), cx(cx), w(w), mode(mode), input_size(input_size), hidden_size(hidden_size),
|
||||
num_layers(num_layers), proj_size(proj_size), dropout(dropout), bias(bias),
|
||||
bidirectional(bidirectional), is_train(is_train) {
|
||||
flags.set(NodeFlags::_cuda, 1);
|
||||
flags.set(NodeFlags::_cpu, 0);
|
||||
flags.set(NodeFlags::_grads, 1);
|
||||
|
||||
ASSERTop(mode,==,"lstm");
|
||||
ASSERTop(proj_size,==,0);
|
||||
init_rnn();
|
||||
}
|
||||
|
||||
CudnnRnnOp::CudnnRnnOp(Var* x, Var* hx, Var* w,
|
||||
string mode, int input_size, int hidden_size, int num_layers, int proj_size,
|
||||
double dropout, bool bias, bool bidirectional, bool is_train)
|
||||
: x(x), hx(hx), cx(nullptr), w(w), mode(mode), input_size(input_size), hidden_size(hidden_size),
|
||||
num_layers(num_layers), proj_size(proj_size), dropout(dropout), bias(bias),
|
||||
bidirectional(bidirectional), is_train(is_train) {
|
||||
flags.set(NodeFlags::_cuda, 1);
|
||||
flags.set(NodeFlags::_cpu, 0);
|
||||
flags.set(NodeFlags::_grads, 1);
|
||||
|
||||
ASSERTop(mode,!=,"lstm");
|
||||
ASSERTop(proj_size,==,0);
|
||||
init_rnn();
|
||||
}
|
||||
|
||||
void CudnnRnnOp::init_rnn() {
|
||||
y = create_output(nullptr, dtype_infer(x->ns, w->ns));
|
||||
hy = create_output(nullptr, dtype_infer(x->ns, w->ns));
|
||||
if (mode == "lstm")
|
||||
cy = create_output(nullptr, dtype_infer(x->ns, w->ns));
|
||||
else
|
||||
cy = nullptr;
|
||||
|
||||
if (is_train)
|
||||
reservation = create_output(nullptr, ns_float32);
|
||||
else
|
||||
reservation = nullptr;
|
||||
|
||||
seq_length = x->shape[0];
|
||||
batch_size = x->shape[1];
|
||||
}
|
||||
|
||||
void CudnnRnnOp::infer_shape() {
|
||||
ASSERTop(x->shape.size(),==,3);
|
||||
ASSERTop(x->shape[2],==,input_size);
|
||||
|
||||
int num_directions = 1 + bidirectional;
|
||||
|
||||
y->set_shape(NanoVector(seq_length, batch_size, hidden_size * num_directions));
|
||||
|
||||
if (proj_size > 0)
|
||||
hy->set_shape(NanoVector(num_layers * num_directions, batch_size, proj_size));
|
||||
else
|
||||
hy->set_shape(NanoVector(num_layers * num_directions, batch_size, hidden_size));
|
||||
|
||||
if (cy)
|
||||
cy->set_shape(NanoVector(num_layers * num_directions, batch_size, hidden_size));
|
||||
|
||||
if (reservation) {
|
||||
int in_dims[3] = {batch_size, input_size, 1};
|
||||
int in_strides[3] = {in_dims[1] * in_dims[2], in_dims[2], 1};
|
||||
|
||||
vector<cudnnTensorDescriptor_t> xDesc(seq_length);
|
||||
RnnDescriptor rnn_desc(cudnn_handle, mode, hidden_size, num_layers, dropout, bidirectional);
|
||||
for (int i = 0; i < seq_length; ++i) {
|
||||
checkCudaErrors(cudnnCreateTensorDescriptor(&xDesc[i]));
|
||||
checkCudaErrors(cudnnSetTensorNdDescriptor(xDesc[i], CUDNN_DATA_FLOAT, 3, in_dims, in_strides));
|
||||
}
|
||||
reservation->set_shape(rnn_desc.reserve_space_size(xDesc.data(), seq_length));
|
||||
}
|
||||
}
|
||||
|
||||
void CudnnRnnOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Tx:") << x->dtype();
|
||||
jk << _CS("][Ty:") << y->dtype();
|
||||
jk << _CS("][Tw:") << w->dtype();
|
||||
jk << ']';
|
||||
}
|
||||
|
||||
static auto make_backwardx_with_cx = get_op_info("cudnn_rnn_backward_x")
|
||||
.get_constructor<vector<VarPtr>, Var*, Var*, Var*, Var*, Var*, Var*, Var*, Var*, Var*, string, int, int, int, int, double, bool, bool>();
|
||||
static auto make_backwardx_without_cx = get_op_info("cudnn_rnn_backward_x")
|
||||
.get_constructor<vector<VarPtr>, Var*, Var*, Var*, Var*, Var*, Var*, Var*, string, int, int, int, int, double, bool, bool>();
|
||||
|
||||
void CudnnRnnOp::grads(Var** dout, VarPtr* dins) {
|
||||
Var *dy = dout[0];
|
||||
Var *dhy = dout[1];
|
||||
Var *dcy = cx ? dout[2] : nullptr;
|
||||
|
||||
vector<VarPtr> dInput;
|
||||
if (cx)
|
||||
dInput = make_backwardx_with_cx(x, hx, cx, y, dy, dhy, dcy, w, reservation, mode, input_size, hidden_size, num_layers, proj_size, dropout, bias, bidirectional);
|
||||
else
|
||||
dInput = make_backwardx_without_cx(x, hx, y, dy, dhy, w, reservation, mode, input_size, hidden_size, num_layers, proj_size, dropout, bias, bidirectional);
|
||||
|
||||
for (int i = 0; i < 3 + (cx != nullptr); ++i)
|
||||
dins[i] = move(dInput[i]);
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
#ifdef JIT_cuda
|
||||
|
||||
#pragma clang diagnostic ignored "-Wtautological-compare"
|
||||
|
||||
template <typename T_ELEM> __inline__ cudnnDataType_t getDataType();
|
||||
template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; }
|
||||
template <> __inline__ cudnnDataType_t getDataType<float>() { return CUDNN_DATA_FLOAT; }
|
||||
|
||||
void CudnnRnnOp::jit_run() {
|
||||
int num_directions = bidirectional + 1;
|
||||
int num_linear_layers = rnn_string_to_num_linear_layers(mode);
|
||||
|
||||
int in_dims[3] = {batch_size, input_size, 1};
|
||||
int out_dims[3] = {batch_size, hidden_size * num_directions, 1};
|
||||
int in_strides[3] = {in_dims[1] * in_dims[2], in_dims[2], 1};
|
||||
int out_strides[3] = {out_dims[1] * out_dims[2], out_dims[2], 1};
|
||||
int hidden_dims[3] = {num_layers * num_directions, batch_size, hidden_size};
|
||||
int hidden_strides[3] = {hidden_dims[1] * hidden_dims[2], hidden_dims[2], 1};
|
||||
|
||||
vector<cudnnTensorDescriptor_t> xDesc(seq_length);
|
||||
vector<cudnnTensorDescriptor_t> yDesc(seq_length);
|
||||
cudnnTensorDescriptor_t hxDesc, cxDesc;
|
||||
cudnnTensorDescriptor_t hyDesc, cyDesc;
|
||||
|
||||
for (int i = 0; i < seq_length; ++i) {
|
||||
checkCudaErrors(cudnnCreateTensorDescriptor(&xDesc[i]));
|
||||
checkCudaErrors(cudnnCreateTensorDescriptor(&yDesc[i]));
|
||||
checkCudaErrors(cudnnSetTensorNdDescriptor(xDesc[i], getDataType<Tx>(), 3, in_dims, in_strides));
|
||||
checkCudaErrors(cudnnSetTensorNdDescriptor(yDesc[i], getDataType<Ty>(), 3, out_dims, out_strides));
|
||||
}
|
||||
|
||||
checkCudaErrors(cudnnCreateTensorDescriptor(&hxDesc));
|
||||
checkCudaErrors(cudnnCreateTensorDescriptor(&cxDesc));
|
||||
checkCudaErrors(cudnnCreateTensorDescriptor(&hyDesc));
|
||||
checkCudaErrors(cudnnCreateTensorDescriptor(&cyDesc));
|
||||
|
||||
checkCudaErrors(cudnnSetTensorNdDescriptor(hxDesc, getDataType<Tx>(), 3, hidden_dims, hidden_strides));
|
||||
checkCudaErrors(cudnnSetTensorNdDescriptor(cxDesc, getDataType<Tx>(), 3, hidden_dims, hidden_strides));
|
||||
|
||||
checkCudaErrors(cudnnSetTensorNdDescriptor(hyDesc, getDataType<Tx>(), 3, hidden_dims, hidden_strides));
|
||||
checkCudaErrors(cudnnSetTensorNdDescriptor(cyDesc, getDataType<Tx>(), 3, hidden_dims, hidden_strides));
|
||||
|
||||
RnnDescriptor rnn_desc(cudnn_handle, mode, hidden_size, num_layers, dropout, bidirectional);
|
||||
|
||||
void *work_space;
|
||||
size_t work_space_size = rnn_desc.work_space_size(xDesc.data(), seq_length);
|
||||
size_t work_space_allocation;
|
||||
if (work_space_size > 0)
|
||||
work_space = exe.temp_allocator->alloc(work_space_size, work_space_allocation);
|
||||
|
||||
RnnWeightDescriptor w_desc(w->size);
|
||||
|
||||
if (is_train) {
|
||||
checkCudaErrors(cudnnRNNForwardTraining(
|
||||
cudnn_handle, rnn_desc.desc,
|
||||
seq_length,
|
||||
xDesc.data(), x->ptr<Tx>(),
|
||||
hxDesc, hx->ptr<Tx>(),
|
||||
cxDesc, mode == "lstm" ? cx->ptr<Tx>() : nullptr,
|
||||
w_desc.desc, w->ptr<Tw>(),
|
||||
yDesc.data(), y->ptr<Ty>(),
|
||||
hyDesc, hy->ptr<Ty>(),
|
||||
cyDesc, mode == "lstm" ? cy->ptr<Ty>() : nullptr,
|
||||
work_space, work_space_size,
|
||||
reservation->ptr<Tx>(), reservation->size
|
||||
));
|
||||
} else {
|
||||
checkCudaErrors(cudnnRNNForwardInference(
|
||||
cudnn_handle, rnn_desc.desc,
|
||||
seq_length,
|
||||
xDesc.data(), x->ptr<Tx>(),
|
||||
hxDesc, hx->ptr<Tx>(),
|
||||
cxDesc, mode == "lstm" ? cx->ptr<Tx>() : nullptr,
|
||||
w_desc.desc, w->ptr<Tw>(),
|
||||
yDesc.data(), y->ptr<Ty>(),
|
||||
hyDesc, hy->ptr<Ty>(),
|
||||
cyDesc, mode == "lstm" ? cy->ptr<Ty>() : nullptr,
|
||||
work_space, work_space_size
|
||||
));
|
||||
}
|
||||
|
||||
for (int i = 0; i < seq_length; i++) {
|
||||
checkCudaErrors(cudnnDestroyTensorDescriptor(xDesc[i]));
|
||||
checkCudaErrors(cudnnDestroyTensorDescriptor(yDesc[i]));
|
||||
}
|
||||
|
||||
checkCudaErrors(cudnnDestroyTensorDescriptor(hxDesc));
|
||||
checkCudaErrors(cudnnDestroyTensorDescriptor(cxDesc));
|
||||
checkCudaErrors(cudnnDestroyTensorDescriptor(hyDesc));
|
||||
checkCudaErrors(cudnnDestroyTensorDescriptor(cyDesc));
|
||||
|
||||
if (work_space)
|
||||
exe.temp_allocator->free(work_space, work_space_size, work_space_allocation);
|
||||
}
|
||||
|
||||
#endif
|
||||
#endif // JIT
|
||||
|
||||
} // jittor
|
||||
|
|
@ -0,0 +1,36 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2021 Jittor. All Rights Reserved.
|
||||
// Maintainers:
|
||||
// Zheng-Ning Liu <lzhengning@gmail.com>
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include "op.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
struct CudnnRnnOp : Op {
|
||||
Var* x, * hx, * cx, * y, * hy, * cy;
|
||||
Var* w;
|
||||
Var* reservation;
|
||||
string mode;
|
||||
int input_size, hidden_size, num_layers, proj_size;
|
||||
int seq_length, batch_size;
|
||||
float dropout;
|
||||
bool bias, bidirectional, is_train;
|
||||
|
||||
// @attrs(multiple_outputs)
|
||||
CudnnRnnOp(Var* x, Var* hx, Var* cx, Var* w, string mode, int input_size, int hidden_size, int num_layers, int proj_size, double dropout, bool batch_first, bool bias, bool bidirectional);
|
||||
// @attrs(multiple_outputs)
|
||||
CudnnRnnOp(Var* x, Var* hx, Var* w, string mode, int input_size, int hidden_size, int num_layers, int proj_size, double dropout, bool batch_first, bool bias, bool bidirectional);
|
||||
|
||||
void init_rnn();
|
||||
|
||||
const char* name() const override { return "cudnn_rnn"; }
|
||||
void grads(Var** douts, VarPtr* dins) override;
|
||||
void infer_shape() override;
|
||||
DECLARE_jit_run;
|
||||
};
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,74 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2021 Jittor. All Rights Reserved.
|
||||
// Maintainers:
|
||||
// Zheng-Ning Liu <lzhengning@gmail.com>
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include "cudnn_rnn_descriptor.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
vector<int32_t> cudnn_rnn_weight_offset(string mode, int input_size, int hidden_size, int num_layers, int proj_size, bool bias, bool bidirectional) {
|
||||
// A pseudo mini-batch for fetching weight space size.
|
||||
int dimX[] = {1, input_size, 1};
|
||||
int strideX[] = {input_size, 1, 1};
|
||||
cudnnTensorDescriptor_t xDesc;
|
||||
checkCudaErrors(cudnnCreateTensorDescriptor(&xDesc));
|
||||
checkCudaErrors(cudnnSetTensorNdDescriptor(xDesc, CUDNN_DATA_FLOAT, 3, dimX, strideX));
|
||||
|
||||
RnnDescriptor rnn_desc = RnnDescriptor(cudnn_handle, mode, hidden_size, num_layers, 0, bidirectional);
|
||||
int weightSpaceSize = rnn_desc.weight_space_size(xDesc);
|
||||
RnnWeightDescriptor w_desc(weightSpaceSize);
|
||||
|
||||
vector<int> weight_offsets;
|
||||
weight_offsets.push_back(weightSpaceSize / sizeof(float));
|
||||
|
||||
int num_directions = bidirectional + 1;
|
||||
int num_linear_layers = rnn_string_to_num_linear_layers(mode);
|
||||
|
||||
for (int layer = 0; layer < num_layers * num_directions; layer++) {
|
||||
for (int linLayerID = 0; linLayerID < num_linear_layers; linLayerID++) {
|
||||
cudnnFilterDescriptor_t linLayerMatDesc;
|
||||
cudnnFilterDescriptor_t linLayerBiasDesc;
|
||||
float *linLayerMat = nullptr;
|
||||
float *linLayerBias = nullptr;
|
||||
|
||||
checkCudaErrors(cudnnCreateFilterDescriptor(&linLayerMatDesc));
|
||||
checkCudaErrors(cudnnCreateFilterDescriptor(&linLayerBiasDesc));
|
||||
|
||||
checkCudaErrors(cudnnGetRNNLinLayerMatrixParams(
|
||||
cudnn_handle, rnn_desc.desc,
|
||||
layer,
|
||||
xDesc,
|
||||
w_desc.desc,
|
||||
nullptr,
|
||||
linLayerID,
|
||||
linLayerMatDesc,
|
||||
(void **) &linLayerMat
|
||||
));
|
||||
weight_offsets.push_back(linLayerMat - (float *) nullptr);
|
||||
|
||||
if (bias) {
|
||||
checkCudaErrors(cudnnGetRNNLinLayerBiasParams(
|
||||
cudnn_handle, rnn_desc.desc,
|
||||
layer,
|
||||
xDesc,
|
||||
w_desc.desc,
|
||||
nullptr,
|
||||
linLayerID,
|
||||
linLayerBiasDesc,
|
||||
(void **) &linLayerBias
|
||||
));
|
||||
weight_offsets.push_back(linLayerBias - (float *) nullptr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
checkCudaErrors(cudnnDestroyTensorDescriptor(xDesc));
|
||||
|
||||
return weight_offsets;
|
||||
}
|
||||
|
||||
|
||||
} // jittor
|
|
@ -214,7 +214,8 @@ def t(x):
|
|||
return x.transpose(*pose)
|
||||
jt.Var.t = t
|
||||
|
||||
def median(x,dim=None,keepdim=False):
|
||||
def median(x,dim=None,keepdim=False, keepdims=False):
|
||||
keepdim = keepdim or keepdims
|
||||
if dim is None:
|
||||
x = x.reshape(-1)
|
||||
dim=0
|
||||
|
@ -637,7 +638,8 @@ def topk(input, k, dim=None, largest=True, sorted=True):
|
|||
|
||||
jt.Var.topk = topk
|
||||
|
||||
def kthvalue(input, k, dim=None, keepdim=False):
|
||||
def kthvalue(input, k, dim=None, keepdim=False, keepdims=False):
|
||||
keepdim = keepdim or keepdims
|
||||
if dim is None:
|
||||
dim = -1
|
||||
if dim<0:
|
||||
|
|
|
@ -12,8 +12,9 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
from abc import abstractmethod
|
||||
from sys import breakpointhook
|
||||
import jittor as jt
|
||||
from jittor import init, Module
|
||||
from jittor import flatten, init, Module
|
||||
import numpy as np
|
||||
import collections
|
||||
import math
|
||||
|
@ -1881,6 +1882,7 @@ class RNNCell(jt.Module):
|
|||
|
||||
return y
|
||||
|
||||
|
||||
class GRUCell(jt.Module):
|
||||
def __init__(self, input_size, hidden_size, bias=True):
|
||||
''' A gated recurrent unit (GRU) cell.
|
||||
|
@ -1941,7 +1943,7 @@ class RNNBase(Module):
|
|||
def __init__(self, mode: str, input_size: int, hidden_size: int,
|
||||
num_layers: int = 1, bias: bool = True, batch_first: bool = False,
|
||||
dropout: float = 0, bidirectional: bool = False,
|
||||
proj_size: int = 0) -> None:
|
||||
proj_size: int = 0, nonlinearity: str = None) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.mode = mode
|
||||
|
@ -1953,6 +1955,7 @@ class RNNBase(Module):
|
|||
self.dropout = dropout
|
||||
self.bidirectional = bidirectional
|
||||
self.proj_size = proj_size
|
||||
self.nonlinearity = nonlinearity
|
||||
|
||||
if mode == 'LSTM':
|
||||
gate_size = 4 * hidden_size
|
||||
|
@ -1994,6 +1997,56 @@ class RNNBase(Module):
|
|||
build_unit(f'bias_ih_l{layer}', gate_size)
|
||||
build_unit(f'bias_hh_l{layer}', gate_size)
|
||||
|
||||
def _cudnn_flatten_weights(self, cudnn_mode):
|
||||
def copy_to_flatten_weight(param_name, offset_idx, num_gates):
|
||||
def copy_to(param_name, offset_idx, idx):
|
||||
cur_offset = self._cudnn_weight_offset[offset_idx]
|
||||
param = getattr(self, param_name)
|
||||
param = param[self.hidden_size * idx: self.hidden_size * (idx + 1)]
|
||||
ft_weight[cur_offset:cur_offset + param.numel()] = param.flatten()
|
||||
|
||||
if self.bias:
|
||||
for idx in range(num_gates):
|
||||
copy_to('weight' + param_name, offset_idx + idx * 2, idx)
|
||||
copy_to('bias' + param_name, offset_idx + idx * 2 + 1, idx)
|
||||
return num_gates * 2
|
||||
else:
|
||||
for idx in range(num_gates):
|
||||
copy_to('weight' + param_name, offset_idx + idx, idx)
|
||||
return num_gates
|
||||
|
||||
if jt.flags.use_cuda and jt.cudnn:
|
||||
if getattr(self, '_cudnn_weight_size', None) is None:
|
||||
offset_array = jt.cudnn.cudnn_rnn_weight_offset(
|
||||
cudnn_mode,
|
||||
self.input_size,
|
||||
self.hidden_size,
|
||||
self.num_layers,
|
||||
self.proj_size,
|
||||
self.bias,
|
||||
self.bidirectional
|
||||
)
|
||||
self._cudnn_weight_size = offset_array[0]
|
||||
self._cudnn_weight_offset = offset_array[1:]
|
||||
|
||||
num_gates = {
|
||||
"RNN": 1, "LSTM": 4, "GRU": 3
|
||||
}[self.mode]
|
||||
ft_weight = jt.zeros(self._cudnn_weight_size, dtype=jt.float32)
|
||||
|
||||
cnt = 0
|
||||
for layer in range(self.num_layers):
|
||||
suffix = ''
|
||||
cnt += copy_to_flatten_weight(f'_ih_l{layer}' + suffix, cnt, num_gates)
|
||||
cnt += copy_to_flatten_weight(f'_hh_l{layer}' + suffix, cnt, num_gates)
|
||||
if self.bidirectional:
|
||||
suffix = '_reverse'
|
||||
cnt += copy_to_flatten_weight(f'_ih_l{layer}' + suffix, cnt, num_gates)
|
||||
cnt += copy_to_flatten_weight(f'_hh_l{layer}' + suffix, cnt, num_gates)
|
||||
return ft_weight
|
||||
else:
|
||||
raise RuntimeError("Not Cudnn found")
|
||||
|
||||
@abstractmethod
|
||||
def call_rnn_cell(self, input, hidden, suffix):
|
||||
pass
|
||||
|
@ -2013,15 +2066,44 @@ class RNNBase(Module):
|
|||
|
||||
return output, hidden
|
||||
|
||||
def execute(self, input, hx):
|
||||
def _execute_cudnn_rnn(self, input, hx):
|
||||
cudnn_mode = {
|
||||
('RNN', 'tanh'): 'tanh',
|
||||
('RNN', 'relu'): 'relu',
|
||||
('LSTM', None): 'lstm',
|
||||
('GRU', None): 'gru'
|
||||
}[(self.mode, self.nonlinearity)]
|
||||
ft_weight = self._cudnn_flatten_weights(cudnn_mode)
|
||||
|
||||
if self.mode == 'LSTM':
|
||||
ret = jt.cudnn.ops.cudnn_rnn(input, hx[0], hx[1], ft_weight,
|
||||
cudnn_mode, self.input_size, self.hidden_size, self.num_layers, 0,
|
||||
self.dropout, self.bias, self.bidirectional, self.is_training()
|
||||
)
|
||||
return ret[0], (ret[1], ret[2])
|
||||
else:
|
||||
ret = jt.cudnn.ops.cudnn_rnn(input, hx, ft_weight,
|
||||
cudnn_mode, self.input_size, self.hidden_size, self.num_layers, 0,
|
||||
self.dropout, self.bias, self.bidirectional, self.is_training()
|
||||
)
|
||||
return ret[0], ret[1]
|
||||
|
||||
def execute(self, input, hx=None):
|
||||
if self.batch_first:
|
||||
input = input.permute(1, 0, 2)
|
||||
|
||||
num_directions = 2 if self.bidirectional else 1
|
||||
|
||||
if hx is None:
|
||||
hx = self.default_init_state()
|
||||
if self.mode in ['RNN', 'GRU']:
|
||||
hx = jt.zeros((num_directions * self.num_layers, input.shape[1], self.hidden_size), dtype=input.dtype)
|
||||
elif self.mode == 'LSTM':
|
||||
hx = (jt.zeros((num_directions * self.num_layers, input.shape[1], self.hidden_size), dtype=input.dtype),
|
||||
jt.zeros((num_directions * self.num_layers, input.shape[1], self.hidden_size), dtype=input.dtype))
|
||||
|
||||
if jt.flags.use_cuda and jt.cudnn and self.proj_size == 0:
|
||||
return self._execute_cudnn_rnn(input, hx)
|
||||
else:
|
||||
hidden_n = []
|
||||
|
||||
for l in range(self.num_layers):
|
||||
|
@ -2059,9 +2141,6 @@ class RNNBase(Module):
|
|||
|
||||
|
||||
class RNN(RNNBase):
|
||||
def __init__(self, input_size: int, hidden_size: int, num_layers: int = 1,
|
||||
nonlinearity: str = 'tanh', bias: bool = True, batch_first: bool = False,
|
||||
dropout: float = 0, bidirectional: bool = False) -> None:
|
||||
''' Applies a multi-layer Elman RNN with tanh ReLU non-linearity to an input sequence.
|
||||
|
||||
:param input_size: The number of expected features in the input.
|
||||
|
@ -2094,6 +2173,9 @@ class RNN(RNNBase):
|
|||
>>> h0 = jt.randn(2, 3, 20)
|
||||
>>> output, hn = rnn(input, h0)
|
||||
'''
|
||||
def __init__(self, input_size: int, hidden_size: int, num_layers: int = 1,
|
||||
nonlinearity: str = 'tanh', bias: bool = True, batch_first: bool = False,
|
||||
dropout: float = 0, bidirectional: bool = False) -> None:
|
||||
super().__init__('RNN', input_size, hidden_size, num_layers=num_layers,
|
||||
bias=bias, batch_first=batch_first, dropout=dropout,
|
||||
bidirectional=bidirectional)
|
||||
|
@ -2112,14 +2194,12 @@ class RNN(RNNBase):
|
|||
if self.nonlinearity == 'tanh':
|
||||
h = jt.tanh(y)
|
||||
else:
|
||||
h = jt.relu(y)
|
||||
h = jt.nn.relu(y)
|
||||
|
||||
return h, h
|
||||
|
||||
|
||||
class LSTM(RNNBase):
|
||||
def __init__(self, input_size, hidden_size, num_layers=1, bias=True,
|
||||
batch_first=False, dropout=0, bidirectional=False, proj_size=0):
|
||||
''' Applies a multi-layer long short-term memory (LSTM) RNN to an input sequence.
|
||||
|
||||
:param input_size: The number of expected features in the input.
|
||||
|
@ -2153,6 +2233,9 @@ class LSTM(RNNBase):
|
|||
>>> c0 = jt.randn(2, 3, 20)
|
||||
>>> output, (hn, cn) = rnn(input, (h0, c0))
|
||||
'''
|
||||
|
||||
def __init__(self, input_size, hidden_size, num_layers=1, bias=True,
|
||||
batch_first=False, dropout=0, bidirectional=False, proj_size=0):
|
||||
super().__init__('LSTM', input_size, hidden_size, num_layers=num_layers,
|
||||
bias=bias, batch_first=batch_first, dropout=dropout,
|
||||
bidirectional=bidirectional, proj_size=proj_size)
|
||||
|
@ -2179,9 +2262,6 @@ class LSTM(RNNBase):
|
|||
|
||||
|
||||
class GRU(RNNBase):
|
||||
def __init__(self, input_size: int, hidden_size: int, num_layers: int = 1,
|
||||
bias: bool = True, batch_first: bool = False, dropout: float = 0,
|
||||
bidirectional: bool = False) -> None:
|
||||
''' Applies a multi-layer gated recurrent unit (GRU) RNN to an input sequence.
|
||||
|
||||
:param input_size: The number of expected features in the input.
|
||||
|
@ -2211,6 +2291,10 @@ class GRU(RNNBase):
|
|||
>>> h0 = jt.randn(2, 3, 20)
|
||||
>>> output, hn = rnn(input, h0)
|
||||
'''
|
||||
|
||||
def __init__(self, input_size: int, hidden_size: int, num_layers: int = 1,
|
||||
bias: bool = True, batch_first: bool = False, dropout: float = 0,
|
||||
bidirectional: bool = False) -> None:
|
||||
super().__init__('GRU', input_size, hidden_size, num_layers=num_layers,
|
||||
bias=bias, batch_first=batch_first, dropout=dropout,
|
||||
bidirectional=bidirectional)
|
||||
|
|
|
@ -36,7 +36,7 @@ unordered_set<string> reduce_ops = {
|
|||
|
||||
* [in] dim: int or tuples of ints (optional). If specified, reduce along the given the dimension(s).
|
||||
|
||||
* [in] keepdim: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False.
|
||||
* [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False.
|
||||
|
||||
----------------
|
||||
|
||||
|
@ -67,7 +67,7 @@ unordered_set<string> reduce_ops = {
|
|||
|
||||
* [in] dim: int or tuples of ints (optional). If specified, reduce along the given the dimension(s).
|
||||
|
||||
* [in] keepdim: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False.
|
||||
* [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False.
|
||||
|
||||
----------------
|
||||
|
||||
|
@ -98,7 +98,7 @@ unordered_set<string> reduce_ops = {
|
|||
|
||||
* [in] dim: int or tuples of ints (optional). If specified, reduce along the given the dimension(s).
|
||||
|
||||
* [in] keepdim: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False.
|
||||
* [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False.
|
||||
|
||||
----------------
|
||||
|
||||
|
@ -129,7 +129,7 @@ unordered_set<string> reduce_ops = {
|
|||
|
||||
* [in] dim: int or tuples of ints (optional). If specified, reduce along the given the dimension(s).
|
||||
|
||||
* [in] keepdim: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False.
|
||||
* [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False.
|
||||
|
||||
----------------
|
||||
|
||||
|
@ -160,7 +160,7 @@ unordered_set<string> reduce_ops = {
|
|||
|
||||
* [in] dim: int or tuples of ints (optional). If specified, reduce along the given the dimension(s).
|
||||
|
||||
* [in] keepdim: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False.
|
||||
* [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False.
|
||||
|
||||
----------------
|
||||
|
||||
|
@ -191,7 +191,7 @@ unordered_set<string> reduce_ops = {
|
|||
|
||||
* [in] dim: int or tuples of ints (optional). If specified, reduce along the given the dimension(s).
|
||||
|
||||
* [in] keepdim: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False.
|
||||
* [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False.
|
||||
|
||||
----------------
|
||||
|
||||
|
@ -226,7 +226,7 @@ unordered_set<string> reduce_ops = {
|
|||
|
||||
* [in] dim: int or tuples of ints (optional). If specified, reduce along the given the dimension(s).
|
||||
|
||||
* [in] keepdim: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False.
|
||||
* [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False.
|
||||
|
||||
----------------
|
||||
|
||||
|
|
|
@ -56,7 +56,7 @@ struct ReindexReduceOp : Op {
|
|||
|
||||
* [in] shape: the output shape, a integer array
|
||||
|
||||
* [in] indexes: array of c++ style integer expression, its length should be the same with length of shape, some buildin variables it can use are::
|
||||
* [in] indexes: array of c++ style integer expression, its length should be the same with length of output shape, some buildin variables it can use are::
|
||||
|
||||
XDIM, xshape0, ..., xshapem, xstride0, ..., xstridem
|
||||
YDIM, yshape0, ..., yshapen, ystride0, ..., ystriden
|
||||
|
|
|
@ -7,15 +7,8 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import unittest
|
||||
import jittor as jt
|
||||
import jittor.nn as nn
|
||||
import numpy as np
|
||||
|
||||
|
||||
skip_this_test = False
|
||||
|
||||
from unittest.case import skipIf
|
||||
try:
|
||||
jt.dirty_fix_pytorch_runtime_error()
|
||||
import torch
|
||||
import torch.nn as tnn
|
||||
except:
|
||||
|
@ -23,72 +16,49 @@ except:
|
|||
tnn = None
|
||||
skip_this_test = True
|
||||
|
||||
import jittor as jt
|
||||
import jittor.nn as nn
|
||||
import numpy as np
|
||||
|
||||
def check_equal_1(t_rnn, j_rnn, input, h0):
|
||||
skip_this_test = False
|
||||
|
||||
def check_equal_1(t_rnn, j_rnn, input, h0, dev=None):
|
||||
j_rnn.load_state_dict(t_rnn.state_dict())
|
||||
|
||||
if dev:
|
||||
t_output, th = t_rnn(torch.from_numpy(input).to(dev), torch.from_numpy(h0).to(dev))
|
||||
|
||||
else:
|
||||
t_output, th = t_rnn(torch.from_numpy(input), torch.from_numpy(h0))
|
||||
t_output = t_output.detach().cpu().numpy()
|
||||
th = th.detach().cpu().numpy()
|
||||
|
||||
j_output, jh = j_rnn(jt.float32(input), jt.float32(h0))
|
||||
j_output, jh = j_output.data, jh.data
|
||||
|
||||
assert np.allclose(t_output.detach().numpy(), j_output.data, rtol=1e-03, atol=1e-06)
|
||||
assert np.allclose(th.detach().numpy(), jh.data, rtol=1e-03, atol=1e-06)
|
||||
np.testing.assert_allclose(t_output, j_output.data, rtol=1e-03, atol=1e-06)
|
||||
np.testing.assert_allclose(th, jh.data, rtol=1e-03, atol=1e-06)
|
||||
|
||||
|
||||
def check_equal_2(t_rnn, j_rnn, input, h0, c0):
|
||||
def check_equal_2(t_rnn, j_rnn, input, h0, c0, dev=None):
|
||||
j_rnn.load_state_dict(t_rnn.state_dict())
|
||||
|
||||
t_output, (th, tc) = t_rnn(torch.from_numpy(input),
|
||||
if dev:
|
||||
t_output, (th, tc) = t_rnn(torch.from_numpy(input).to(dev),
|
||||
(torch.from_numpy(h0).to(dev), torch.from_numpy(c0).to(dev)))
|
||||
else:
|
||||
t_output, (th, tc) = t_rnn(torch.from_numpy(input).to(dev),
|
||||
(torch.from_numpy(h0), torch.from_numpy(c0)))
|
||||
|
||||
j_output, (jh, jc) = j_rnn(jt.float32(input),
|
||||
(jt.float32(h0), jt.float32(c0)))
|
||||
|
||||
assert np.allclose(t_output.detach().numpy(), j_output.data, rtol=1e-03, atol=1e-06)
|
||||
assert np.allclose(th.detach().numpy(), jh.data, rtol=1e-03, atol=1e-06)
|
||||
assert np.allclose(tc.detach().numpy(), jc.data, rtol=1e-03, atol=1e-06)
|
||||
np.testing.assert_allclose(t_output.detach().cpu().numpy(), j_output.data, rtol=1e-03, atol=1e-06)
|
||||
np.testing.assert_allclose(th.detach().cpu().numpy(), jh.data, rtol=1e-03, atol=1e-06)
|
||||
np.testing.assert_allclose(tc.detach().cpu().numpy(), jc.data, rtol=1e-03, atol=1e-06)
|
||||
|
||||
|
||||
@unittest.skipIf(skip_this_test, "No Torch found")
|
||||
class TestRNN(unittest.TestCase):
|
||||
def test_rnn(self):
|
||||
h0 = np.random.rand(1, 24, 200).astype(np.float32)
|
||||
input = np.random.rand(32, 24, 100).astype(np.float32)
|
||||
|
||||
t_rnn = tnn.RNN(100, 200)
|
||||
j_rnn = nn.RNN(100, 200)
|
||||
check_equal_1(t_rnn, j_rnn, input, h0)
|
||||
|
||||
h0 = np.random.rand(4, 4, 200).astype(np.float32)
|
||||
input = np.random.rand(5, 4, 100).astype(np.float32)
|
||||
|
||||
t_rnn = tnn.RNN(100, 200, num_layers=4)
|
||||
j_rnn = nn.RNN(100, 200, num_layers=4)
|
||||
check_equal_1(t_rnn, j_rnn, input, h0)
|
||||
|
||||
h0 = np.random.rand(2, 1, 200).astype(np.float32)
|
||||
input = np.random.rand(5, 1, 100).astype(np.float32)
|
||||
|
||||
t_rnn = tnn.RNN(100, 200, bidirectional=True)
|
||||
j_rnn = nn.RNN(100, 200, bidirectional=True)
|
||||
check_equal_1(t_rnn, j_rnn, input, h0)
|
||||
|
||||
h0 = np.random.rand(4, 4, 200).astype(np.float32)
|
||||
input = np.random.rand(5, 4, 100).astype(np.float32)
|
||||
|
||||
t_rnn = tnn.RNN(100, 200, num_layers=2, bidirectional=True, bias=False)
|
||||
j_rnn = nn.RNN(100, 200, num_layers=2, bidirectional=True, bias=False)
|
||||
check_equal_1(t_rnn, j_rnn, input, h0)
|
||||
|
||||
h0 = np.random.rand(2, 4, 200).astype(np.float32)
|
||||
input = np.random.rand(5, 4, 100).astype(np.float32)
|
||||
|
||||
t_rnn = tnn.RNN(100, 200, num_layers=2, dropout=0.5, bias=False)
|
||||
j_rnn = nn.RNN(100, 200, num_layers=2, dropout=0.5, bias=False)
|
||||
t_rnn.eval()
|
||||
j_rnn.eval()
|
||||
check_equal_1(t_rnn, j_rnn, input, h0)
|
||||
|
||||
def test_lstm_cell(self):
|
||||
np_h0 = torch.randn(3, 20).numpy()
|
||||
np_c0 = torch.randn(3, 20).numpy()
|
||||
|
@ -173,7 +143,49 @@ class TestRNN(unittest.TestCase):
|
|||
j_output = j_output.data
|
||||
assert np.allclose(t_output, j_output, rtol=1e-03, atol=1e-06)
|
||||
|
||||
def test_lstm(self):
|
||||
def test_basic_rnn(self):
|
||||
h0 = np.random.rand(1, 24, 200).astype(np.float32)
|
||||
input = np.random.rand(32, 24, 100).astype(np.float32)
|
||||
|
||||
t_rnn = tnn.RNN(100, 200)
|
||||
j_rnn = nn.RNN(100, 200)
|
||||
check_equal_1(t_rnn, j_rnn, input, h0)
|
||||
|
||||
def test_multilayer_rnn(self):
|
||||
h0 = np.random.rand(4, 4, 200).astype(np.float32)
|
||||
input = np.random.rand(5, 4, 100).astype(np.float32)
|
||||
|
||||
t_rnn = tnn.RNN(100, 200, num_layers=4)
|
||||
j_rnn = nn.RNN(100, 200, num_layers=4)
|
||||
check_equal_1(t_rnn, j_rnn, input, h0)
|
||||
|
||||
def test_bidirectional_rnn(self):
|
||||
h0 = np.random.rand(2, 1, 200).astype(np.float32)
|
||||
input = np.random.rand(5, 1, 100).astype(np.float32)
|
||||
|
||||
t_rnn = tnn.RNN(100, 200, bidirectional=True)
|
||||
j_rnn = nn.RNN(100, 200, bidirectional=True)
|
||||
check_equal_1(t_rnn, j_rnn, input, h0)
|
||||
|
||||
def test_no_bias_rnn(self):
|
||||
h0 = np.random.rand(4, 4, 200).astype(np.float32)
|
||||
input = np.random.rand(5, 4, 100).astype(np.float32)
|
||||
|
||||
t_rnn = tnn.RNN(100, 200, num_layers=2, bidirectional=True, bias=False)
|
||||
j_rnn = nn.RNN(100, 200, num_layers=2, bidirectional=True, bias=False)
|
||||
check_equal_1(t_rnn, j_rnn, input, h0)
|
||||
|
||||
def test_dropout_rnn(self):
|
||||
h0 = np.random.rand(2, 4, 200).astype(np.float32)
|
||||
input = np.random.rand(5, 4, 100).astype(np.float32)
|
||||
|
||||
t_rnn = tnn.RNN(100, 200, num_layers=2, dropout=0.5, bias=False)
|
||||
j_rnn = nn.RNN(100, 200, num_layers=2, dropout=0.5, bias=False)
|
||||
t_rnn.eval()
|
||||
j_rnn.eval()
|
||||
check_equal_1(t_rnn, j_rnn, input, h0)
|
||||
|
||||
def test_basic_lstm(self):
|
||||
h0 = np.random.rand(1, 24, 200).astype(np.float32)
|
||||
c0 = np.random.rand(1, 24, 200).astype(np.float32)
|
||||
input = np.random.rand(32, 24, 100).astype(np.float32)
|
||||
|
@ -182,6 +194,7 @@ class TestRNN(unittest.TestCase):
|
|||
j_rnn = nn.LSTM(100, 200)
|
||||
check_equal_2(t_rnn, j_rnn, input, h0, c0)
|
||||
|
||||
def test_projection_lstm(self):
|
||||
proj_size = 13
|
||||
h0 = np.random.rand(1, 24, proj_size).astype(np.float32)
|
||||
c0 = np.random.rand(1, 24, 200).astype(np.float32)
|
||||
|
@ -190,6 +203,7 @@ class TestRNN(unittest.TestCase):
|
|||
j_rnn = nn.LSTM(100, 200, proj_size=proj_size)
|
||||
check_equal_2(t_rnn, j_rnn, input, h0, c0)
|
||||
|
||||
def test_multilayer_lstm(self):
|
||||
h0 = np.random.rand(4, 4, 200).astype(np.float32)
|
||||
c0 = np.random.rand(4, 4, 200).astype(np.float32)
|
||||
input = np.random.rand(5, 4, 100).astype(np.float32)
|
||||
|
@ -198,6 +212,8 @@ class TestRNN(unittest.TestCase):
|
|||
j_rnn = nn.LSTM(100, 200, num_layers=4)
|
||||
check_equal_2(t_rnn, j_rnn, input, h0, c0)
|
||||
|
||||
def test_multilayer_projection_lstm(self):
|
||||
proj_size = 8
|
||||
h0 = np.random.rand(2, 4, proj_size).astype(np.float32)
|
||||
c0 = np.random.rand(2, 4, 20).astype(np.float32)
|
||||
input = np.random.rand(5, 4, 10).astype(np.float32)
|
||||
|
@ -206,6 +222,7 @@ class TestRNN(unittest.TestCase):
|
|||
j_rnn = nn.LSTM(10, 20, num_layers=2, proj_size=proj_size)
|
||||
check_equal_2(t_rnn, j_rnn, input, h0, c0)
|
||||
|
||||
def test_bidirectional_lstm(self):
|
||||
h0 = np.random.rand(2, 1, 200).astype(np.float32)
|
||||
c0 = np.random.rand(2, 1, 200).astype(np.float32)
|
||||
input = np.random.rand(5, 1, 100).astype(np.float32)
|
||||
|
@ -214,7 +231,8 @@ class TestRNN(unittest.TestCase):
|
|||
j_rnn = nn.LSTM(100, 200, bidirectional=True)
|
||||
check_equal_2(t_rnn, j_rnn, input, h0, c0)
|
||||
|
||||
proj_size = 13
|
||||
def test_bidirectional_projection_lstm(self):
|
||||
proj_size = 10
|
||||
h0 = np.random.rand(2, 4, proj_size).astype(np.float32)
|
||||
c0 = np.random.rand(2, 4, 200).astype(np.float32)
|
||||
input = np.random.rand(5, 4, 100).astype(np.float32)
|
||||
|
@ -223,6 +241,7 @@ class TestRNN(unittest.TestCase):
|
|||
j_rnn = nn.LSTM(100, 200, bidirectional=True, proj_size=proj_size)
|
||||
check_equal_2(t_rnn, j_rnn, input, h0, c0)
|
||||
|
||||
def test_multilayer_bidirectional_projection_lstm(self):
|
||||
h0 = np.random.rand(4, 4, 200).astype(np.float32)
|
||||
c0 = np.random.rand(4, 4, 200).astype(np.float32)
|
||||
input = np.random.rand(5, 4, 100).astype(np.float32)
|
||||
|
@ -231,7 +250,7 @@ class TestRNN(unittest.TestCase):
|
|||
j_rnn = nn.LSTM(100, 200, num_layers=2, bidirectional=True, bias=False)
|
||||
check_equal_2(t_rnn, j_rnn, input, h0, c0)
|
||||
|
||||
|
||||
def test_dropout_lstm(self):
|
||||
h0 = np.random.rand(2, 4, 200).astype(np.float32)
|
||||
c0 = np.random.rand(2, 4, 200).astype(np.float32)
|
||||
input = np.random.rand(5, 4, 100).astype(np.float32)
|
||||
|
@ -242,7 +261,7 @@ class TestRNN(unittest.TestCase):
|
|||
j_rnn.eval()
|
||||
check_equal_2(t_rnn, j_rnn, input, h0, c0)
|
||||
|
||||
def test_gru(self):
|
||||
def test_basic_gru(self):
|
||||
h0 = np.random.rand(1, 24, 200).astype(np.float32)
|
||||
input = np.random.rand(32, 24, 100).astype(np.float32)
|
||||
|
||||
|
@ -250,6 +269,7 @@ class TestRNN(unittest.TestCase):
|
|||
j_rnn = nn.GRU(100, 200)
|
||||
check_equal_1(t_rnn, j_rnn, input, h0)
|
||||
|
||||
def test_multilayer_gru(self):
|
||||
h0 = np.random.rand(4, 4, 200).astype(np.float32)
|
||||
input = np.random.rand(5, 4, 100).astype(np.float32)
|
||||
|
||||
|
@ -257,6 +277,7 @@ class TestRNN(unittest.TestCase):
|
|||
j_rnn = nn.GRU(100, 200, num_layers=4)
|
||||
check_equal_1(t_rnn, j_rnn, input, h0)
|
||||
|
||||
def test_bidirectional_gru(self):
|
||||
h0 = np.random.rand(2, 1, 200).astype(np.float32)
|
||||
input = np.random.rand(5, 1, 100).astype(np.float32)
|
||||
|
||||
|
@ -264,6 +285,7 @@ class TestRNN(unittest.TestCase):
|
|||
j_rnn = nn.GRU(100, 200, bidirectional=True)
|
||||
check_equal_1(t_rnn, j_rnn, input, h0)
|
||||
|
||||
def test_multilayer_bidirectional_gru(self):
|
||||
h0 = np.random.rand(4, 4, 200).astype(np.float32)
|
||||
input = np.random.rand(5, 4, 100).astype(np.float32)
|
||||
|
||||
|
@ -271,6 +293,7 @@ class TestRNN(unittest.TestCase):
|
|||
j_rnn = nn.GRU(100, 200, num_layers=2, bidirectional=True, bias=False)
|
||||
check_equal_1(t_rnn, j_rnn, input, h0)
|
||||
|
||||
def test_multilayer_dropout_gru(self):
|
||||
h0 = np.random.rand(2, 4, 200).astype(np.float32)
|
||||
input = np.random.rand(5, 4, 100).astype(np.float32)
|
||||
|
||||
|
@ -280,6 +303,330 @@ class TestRNN(unittest.TestCase):
|
|||
j_rnn.eval()
|
||||
check_equal_1(t_rnn, j_rnn, input, h0)
|
||||
|
||||
def test_rnn_default_hx(self):
|
||||
input = np.random.rand(32, 24, 12).astype(np.float32)
|
||||
h0 = np.zeros((1, 24, 24)).astype(np.float32)
|
||||
|
||||
t_rnn = tnn.RNN(12, 24)
|
||||
j_rnn = nn.RNN(12, 24)
|
||||
j_rnn.load_state_dict(t_rnn.state_dict())
|
||||
t_output, th = t_rnn(torch.from_numpy(input))
|
||||
j_output, jh = j_rnn(jt.array(input))
|
||||
|
||||
np.testing.assert_allclose(t_output.detach().cpu().numpy(), j_output.data, rtol=1e-03, atol=1e-06)
|
||||
np.testing.assert_allclose(th.detach().cpu().numpy(), jh.data, rtol=1e-03, atol=1e-06)
|
||||
|
||||
def test_lstm_default_hx(self):
|
||||
input = np.random.rand(32, 24, 10).astype(np.float32)
|
||||
t_rnn = tnn.LSTM(10, 20, num_layers=2, bidirectional=True)
|
||||
j_rnn = nn.LSTM(10, 20, num_layers=2, bidirectional=True)
|
||||
j_rnn.load_state_dict(t_rnn.state_dict())
|
||||
t_output, (th, tc) = t_rnn(torch.from_numpy(input))
|
||||
j_output, (jh, jc) = j_rnn(jt.array(input))
|
||||
np.testing.assert_allclose(t_output.detach().cpu().numpy(), j_output.data, rtol=1e-03, atol=1e-06)
|
||||
np.testing.assert_allclose(th.detach().cpu().numpy(), jh.data, rtol=1e-03, atol=1e-06)
|
||||
np.testing.assert_allclose(tc.detach().cpu().numpy(), jc.data, rtol=1e-03, atol=1e-06)
|
||||
|
||||
@skipIf(not jt.has_cuda, "No Cuda found")
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
def test_cudnn_rnn(self):
|
||||
dev = torch.device('cuda:0')
|
||||
t_rnn = tnn.RNN(100, 200, nonlinearity='relu').to(dev)
|
||||
|
||||
j_rnn = nn.RNN(100, 200, nonlinearity='relu')
|
||||
j_rnn.train()
|
||||
j_rnn.load_state_dict(t_rnn.state_dict())
|
||||
|
||||
h0 = np.random.rand(1, 24, 200).astype(np.float32)
|
||||
input = np.random.rand(32, 24, 100).astype(np.float32)
|
||||
|
||||
t_output, th = t_rnn(torch.from_numpy(input).to(dev),
|
||||
torch.from_numpy(h0).to(dev))
|
||||
|
||||
j_output, jh = j_rnn(jt.array(input), jt.array(h0))
|
||||
|
||||
np.testing.assert_allclose(j_output.data, t_output.detach().cpu().numpy())
|
||||
np.testing.assert_allclose(jh.data, th.detach().cpu().numpy())
|
||||
|
||||
@skipIf(not jt.has_cuda, "No Cuda found")
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
def test_cudnn_rnn_train(self):
|
||||
dev = torch.device('cuda:0')
|
||||
t_rnn = tnn.RNN(32, 64, nonlinearity='relu').to(dev)
|
||||
t_optim = torch.optim.SGD(t_rnn.parameters(), lr=1e-3, momentum=0.9)
|
||||
|
||||
j_rnn = nn.RNN(32, 64, nonlinearity='relu')
|
||||
j_rnn.load_state_dict(t_rnn.state_dict())
|
||||
j_optim = nn.SGD(j_rnn.parameters(), lr=1e-3, momentum=0.9)
|
||||
|
||||
h0 = np.random.rand(1, 4, 64).astype(np.float32)
|
||||
input = np.random.rand(12, 4, 32).astype(np.float32)
|
||||
|
||||
for _ in range(10):
|
||||
t_optim.zero_grad()
|
||||
t_output, th = t_rnn(torch.from_numpy(input).to(dev), torch.from_numpy(h0).to(dev))
|
||||
t_loss = (t_output ** 2).sum() + (th ** 2).sum()
|
||||
t_loss.backward()
|
||||
t_optim.step()
|
||||
|
||||
j_input, jh = jt.array(input), jt.array(h0)
|
||||
j_output, jh = j_rnn(j_input, jh)
|
||||
j_loss = (j_output ** 2).sum() + (jh ** 2).sum()
|
||||
j_optim.step(j_loss)
|
||||
|
||||
np.testing.assert_allclose(t_loss.item(), j_loss.item(), rtol=1e-2)
|
||||
np.testing.assert_allclose(t_rnn.bias_hh_l0.detach().cpu().numpy(), j_rnn.bias_hh_l0.data, atol=1e-3, rtol=1e-2)
|
||||
|
||||
@unittest.skipIf(not jt.has_cuda, "No Cuda found")
|
||||
@unittest.skipIf(not jt.cudnn, "No Cudnn found")
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
def test_basic_cudnn_rnn(self):
|
||||
dev = torch.device('cuda:0')
|
||||
t_rnn = tnn.RNN(100, 200, nonlinearity='relu').to(dev)
|
||||
j_rnn = nn.RNN(100, 200, nonlinearity='relu')
|
||||
|
||||
h0 = np.random.rand(1, 24, 200).astype(np.float32)
|
||||
input = np.random.rand(32, 24, 100).astype(np.float32)
|
||||
check_equal_1(t_rnn, j_rnn, input, h0, dev)
|
||||
|
||||
@unittest.skipIf(not jt.has_cuda, "No Cuda found")
|
||||
@unittest.skipIf(not jt.cudnn, "No Cudnn found")
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
def test_multilayer_cudnn_rnn(self):
|
||||
dev = torch.device('cuda:0')
|
||||
t_rnn = tnn.RNN(100, 200, num_layers=4, nonlinearity='tanh').to(dev)
|
||||
j_rnn = nn.RNN(100, 200, num_layers=4, nonlinearity='tanh')
|
||||
|
||||
h0 = np.random.rand(4, 8, 200).astype(np.float32)
|
||||
input = np.random.rand(5, 8, 100).astype(np.float32)
|
||||
check_equal_1(t_rnn, j_rnn, input, h0, dev)
|
||||
|
||||
@unittest.skipIf(not jt.has_cuda, "No Cuda found")
|
||||
@unittest.skipIf(not jt.cudnn, "No Cudnn found")
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
def test_bidirectional_cudnn_rnn(self):
|
||||
dev = torch.device('cuda:0')
|
||||
t_rnn = tnn.RNN(100, 200, bidirectional=True, nonlinearity='tanh').to(dev)
|
||||
j_rnn = nn.RNN(100, 200, bidirectional=True, nonlinearity='tanh')
|
||||
|
||||
h0 = np.random.rand(2, 8, 200).astype(np.float32)
|
||||
input = np.random.rand(5, 8, 100).astype(np.float32)
|
||||
check_equal_1(t_rnn, j_rnn, input, h0, dev)
|
||||
|
||||
@unittest.skipIf(not jt.has_cuda, "No Cuda found")
|
||||
@unittest.skipIf(not jt.cudnn, "No Cudnn found")
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
def test_no_bias_cudnn_rnn(self):
|
||||
dev = torch.device('cuda:0')
|
||||
t_rnn = tnn.RNN(100, 200, bidirectional=True, bias=False, nonlinearity='tanh').to(dev)
|
||||
j_rnn = nn.RNN(100, 200, bidirectional=True, bias=False, nonlinearity='tanh')
|
||||
|
||||
h0 = np.random.rand(2, 8, 200).astype(np.float32)
|
||||
input = np.random.rand(5, 8, 100).astype(np.float32)
|
||||
check_equal_1(t_rnn, j_rnn, input, h0, dev)
|
||||
|
||||
@unittest.skipIf(not jt.has_cuda, "No Cuda found")
|
||||
@unittest.skipIf(not jt.cudnn, "No Cudnn found")
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
def test_dropout_cudnn_rnn(self):
|
||||
dev = torch.device('cuda:0')
|
||||
t_rnn = tnn.RNN(100, 200, num_layers=2, dropout=0.5, nonlinearity='tanh').to(dev)
|
||||
j_rnn = nn.RNN(100, 200, num_layers=2, dropout=0.5, nonlinearity='tanh')
|
||||
t_rnn.eval()
|
||||
j_rnn.eval()
|
||||
|
||||
h0 = np.random.rand(2, 8, 200).astype(np.float32)
|
||||
input = np.random.rand(5, 8, 100).astype(np.float32)
|
||||
check_equal_1(t_rnn, j_rnn, input, h0, dev)
|
||||
|
||||
@unittest.skipIf(not jt.has_cuda, "No Cuda found")
|
||||
@unittest.skipIf(not jt.cudnn, "No Cudnn found")
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
def test_basic_lstm_rnn(self):
|
||||
dev = torch.device('cuda:0')
|
||||
t_rnn = tnn.LSTM(100, 200).to(dev)
|
||||
j_rnn = nn.LSTM(100, 200)
|
||||
|
||||
h0 = np.random.rand(1, 24, 200).astype(np.float32)
|
||||
c0 = np.random.rand(1, 24, 200).astype(np.float32)
|
||||
input = np.random.rand(32, 24, 100).astype(np.float32)
|
||||
check_equal_2(t_rnn, j_rnn, input, h0, c0, dev)
|
||||
|
||||
@unittest.skipIf(not jt.has_cuda, "No Cuda found")
|
||||
@unittest.skipIf(not jt.cudnn, "No Cudnn found")
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
def test_cudnn_rnn_train(self):
|
||||
dev = torch.device('cuda:0')
|
||||
t_rnn = tnn.RNN(32, 64, nonlinearity='relu').to(dev)
|
||||
t_optim = torch.optim.SGD(t_rnn.parameters(), lr=1e-3, momentum=0.9)
|
||||
|
||||
j_rnn = nn.RNN(32, 64, nonlinearity='relu')
|
||||
j_rnn.load_state_dict(t_rnn.state_dict())
|
||||
j_optim = nn.SGD(j_rnn.parameters(), lr=1e-3, momentum=0.9)
|
||||
|
||||
h0 = np.random.rand(1, 4, 64).astype(np.float32)
|
||||
input = np.random.rand(12, 4, 32).astype(np.float32)
|
||||
|
||||
for _ in range(10):
|
||||
t_optim.zero_grad()
|
||||
t_output, th = t_rnn(torch.from_numpy(input).to(dev), torch.from_numpy(h0).to(dev))
|
||||
t_loss = (t_output ** 2).sum() + (th ** 2).sum()
|
||||
t_loss.backward()
|
||||
t_optim.step()
|
||||
|
||||
j_input, jh = jt.array(input), jt.array(h0)
|
||||
j_output, jh = j_rnn(j_input, jh)
|
||||
j_loss = (j_output ** 2).sum() + (jh ** 2).sum()
|
||||
j_optim.step(j_loss)
|
||||
|
||||
np.testing.assert_allclose(t_loss.item(), j_loss.item(), rtol=1e-4)
|
||||
np.testing.assert_allclose(t_rnn.bias_hh_l0.detach().cpu().numpy(), j_rnn.bias_hh_l0.data, atol=1e-4, rtol=1e-4)
|
||||
|
||||
@unittest.skipIf(not jt.has_cuda, "No Cuda found")
|
||||
@unittest.skipIf(not jt.cudnn, "No Cudnn found")
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
def test_cudnn_gru_train(self):
|
||||
dev = torch.device('cuda:0')
|
||||
t_rnn = tnn.GRU(32, 64).to(dev)
|
||||
t_optim = torch.optim.SGD(t_rnn.parameters(), lr=1e-3, momentum=0.9)
|
||||
|
||||
j_rnn = nn.GRU(32, 64)
|
||||
j_rnn.load_state_dict(t_rnn.state_dict())
|
||||
j_optim = nn.SGD(j_rnn.parameters(), lr=1e-3, momentum=0.9)
|
||||
|
||||
h0 = np.random.rand(1, 4, 64).astype(np.float32)
|
||||
input = np.random.rand(12, 4, 32).astype(np.float32)
|
||||
|
||||
for _ in range(10):
|
||||
t_optim.zero_grad()
|
||||
t_output, th = t_rnn(torch.from_numpy(input).to(dev), torch.from_numpy(h0).to(dev))
|
||||
t_loss = (t_output ** 2).sum() + (th ** 2).sum()
|
||||
t_loss.backward()
|
||||
t_optim.step()
|
||||
|
||||
j_input, jh = jt.array(input), jt.array(h0)
|
||||
j_output, jh = j_rnn(j_input, jh)
|
||||
j_loss = (j_output ** 2).sum() + (jh ** 2).sum()
|
||||
j_optim.step(j_loss)
|
||||
|
||||
np.testing.assert_allclose(t_loss.item(), j_loss.item(), rtol=1e-4)
|
||||
np.testing.assert_allclose(t_rnn.bias_hh_l0.detach().cpu().numpy(), j_rnn.bias_hh_l0.data, atol=1e-4, rtol=1e-4)
|
||||
|
||||
@unittest.skipIf(not jt.has_cuda, "No Cuda found")
|
||||
@unittest.skipIf(not jt.cudnn, "No Cudnn found")
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
def test_cudnn_lstm_train(self):
|
||||
dev = torch.device('cuda:0')
|
||||
t_rnn = tnn.LSTM(32, 64).to(dev)
|
||||
t_optim = torch.optim.SGD(t_rnn.parameters(), lr=1e-3, momentum=0.9)
|
||||
|
||||
j_rnn = nn.LSTM(32, 64)
|
||||
j_rnn.load_state_dict(t_rnn.state_dict())
|
||||
j_optim = nn.SGD(j_rnn.parameters(), lr=1e-3, momentum=0.9)
|
||||
|
||||
h0 = np.random.rand(1, 4, 64).astype(np.float32)
|
||||
c0 = np.random.rand(1, 4, 64).astype(np.float32)
|
||||
input = np.random.rand(12, 4, 32).astype(np.float32)
|
||||
|
||||
for _ in range(10):
|
||||
t_optim.zero_grad()
|
||||
t_output, (th, tc) = t_rnn(torch.from_numpy(input).to(dev),
|
||||
(torch.from_numpy(h0).to(dev), torch.from_numpy(c0).to(dev)))
|
||||
t_loss = (t_output ** 2).sum() + (th ** 2).sum() + (tc ** 2).sum()
|
||||
t_loss.backward()
|
||||
t_optim.step()
|
||||
|
||||
j_input, jh0, jc0 = jt.array(input), jt.array(h0), jt.array(c0)
|
||||
j_output, (jh, jc) = j_rnn(j_input, (jh0, jc0))
|
||||
j_loss = (j_output ** 2).sum() + (jh ** 2).sum() + (jc ** 2).sum()
|
||||
j_optim.step(j_loss)
|
||||
|
||||
np.testing.assert_allclose(t_loss.item(), j_loss.item(), rtol=1e-4)
|
||||
np.testing.assert_allclose(t_rnn.bias_hh_l0.detach().cpu().numpy(), j_rnn.bias_hh_l0.data, atol=1e-4, rtol=1e-4)
|
||||
|
||||
@unittest.skipIf(not jt.has_cuda, "No Cuda found")
|
||||
@unittest.skipIf(not jt.cudnn, "No Cudnn found")
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
def test_multilayer_bidirectional_cudnn_lstm_train(self):
|
||||
dev = torch.device('cuda:0')
|
||||
t_rnn = tnn.LSTM(32, 64, num_layers=4, bidirectional=True).to(dev)
|
||||
t_optim = torch.optim.SGD(t_rnn.parameters(), lr=1e-3, momentum=0.9)
|
||||
|
||||
j_rnn = nn.LSTM(32, 64, num_layers=4, bidirectional=True)
|
||||
j_rnn.load_state_dict(t_rnn.state_dict())
|
||||
j_optim = nn.SGD(j_rnn.parameters(), lr=1e-3, momentum=0.9)
|
||||
|
||||
h0 = np.random.rand(8, 4, 64).astype(np.float32)
|
||||
c0 = np.random.rand(8, 4, 64).astype(np.float32)
|
||||
input = np.random.rand(12, 4, 32).astype(np.float32)
|
||||
|
||||
for _ in range(10):
|
||||
t_optim.zero_grad()
|
||||
t_output, (th, tc) = t_rnn(torch.from_numpy(input).to(dev),
|
||||
(torch.from_numpy(h0).to(dev), torch.from_numpy(c0).to(dev)))
|
||||
t_loss = (t_output ** 2).sum() + (th ** 2).sum() + (tc ** 2).sum()
|
||||
t_loss.backward()
|
||||
t_optim.step()
|
||||
|
||||
j_input, jh0, jc0 = jt.array(input), jt.array(h0), jt.array(c0)
|
||||
j_output, (jh, jc) = j_rnn(j_input, (jh0, jc0))
|
||||
j_loss = (j_output ** 2).sum() + (jh ** 2).sum() + (jc ** 2).sum()
|
||||
j_optim.step(j_loss)
|
||||
|
||||
np.testing.assert_allclose(t_loss.item(), j_loss.item(), rtol=1e-4)
|
||||
np.testing.assert_allclose(t_rnn.bias_hh_l0.detach().cpu().numpy(), j_rnn.bias_hh_l0.data, atol=1e-4, rtol=1e-4)
|
||||
|
||||
@unittest.skipIf(not jt.has_cuda, "No Cuda found")
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
def test_cudnn_rnn_speed(self):
|
||||
from time import time
|
||||
iters = 100
|
||||
|
||||
h0 = np.random.rand(1, 128, 256).astype(np.float32)
|
||||
input = np.random.rand(128, 128, 128).astype(np.float32)
|
||||
|
||||
dev = torch.device('cuda:0')
|
||||
t_rnn = tnn.RNN(128, 256, nonlinearity='relu').to(dev)
|
||||
t_optim = torch.optim.SGD(t_rnn.parameters(), lr=1e-3, momentum=0.9)
|
||||
|
||||
t_input = torch.from_numpy(input).to(dev)
|
||||
t_h0 = torch.from_numpy(h0).to(dev)
|
||||
|
||||
start_time = time()
|
||||
for i in range(iters):
|
||||
t_optim.zero_grad()
|
||||
t_output, th = t_rnn(t_input, t_h0)
|
||||
t_loss = (t_output ** 2).sum() + (th ** 2).sum()
|
||||
t_loss.backward()
|
||||
t_optim.step()
|
||||
print('torch time = ', time() - start_time)
|
||||
|
||||
j_rnn = nn.RNN(128, 256, nonlinearity='relu')
|
||||
j_rnn.load_state_dict(t_rnn.state_dict())
|
||||
j_optim = nn.SGD(j_rnn.parameters(), lr=1e-3, momentum=0.9)
|
||||
j_input, j_h0 = jt.array(input), jt.array(h0)
|
||||
|
||||
start_time = time()
|
||||
for i in range(iters):
|
||||
j_output, jh = j_rnn(j_input, j_h0)
|
||||
j_loss = (j_output ** 2).sum() + (jh ** 2).sum()
|
||||
j_optim.step(j_loss)
|
||||
jt.sync_all(True)
|
||||
print('jittor Cudnn time = ', time() - start_time)
|
||||
|
||||
jt_cudnn, jt.cudnn = jt.cudnn, None
|
||||
j_rnn = nn.RNN(128, 256, nonlinearity='relu')
|
||||
j_rnn.load_state_dict(t_rnn.state_dict())
|
||||
j_optim = nn.SGD(j_rnn.parameters(), lr=1e-3, momentum=0.9)
|
||||
start_time = time()
|
||||
for i in range(iters):
|
||||
j_output, jh = j_rnn(j_input, j_h0)
|
||||
j_loss = (j_output ** 2).sum() + (jh ** 2).sum()
|
||||
j_optim.step(j_loss)
|
||||
jt.sync_all(True)
|
||||
print('jittor native time = ', time() - start_time)
|
||||
jt.cudnn = jt_cudnn
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in New Issue