diff --git a/README.cn.md b/README.cn.md index ee37597c..bb1bcd4a 100644 --- a/README.cn.md +++ b/README.cn.md @@ -1,5 +1,6 @@ # Jittor: 即时编译深度学习框架 +![Jittor Logo](doc/logo.png) [快速开始](#快速开始) | [安装](#安装) | [教程](#教程) diff --git a/README.md b/README.md index 456b677a..06b4b5e5 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,7 @@ # Jittor: a Just-in-time(JIT) deep learning framework +![Jittor Logo](doc/logo.png) + [Quickstart](#quickstart) | [Install](#install) | [Tutorial](#tutorial) | [Chinese](./README.cn.md) diff --git a/README.src.md b/README.src.md index 73c9bf83..4ab7614c 100644 --- a/README.src.md +++ b/README.src.md @@ -1,6 +1,8 @@ # Jittor: a Just-in-time(JIT) deep learning framework # Jittor: 即时编译深度学习框架 +![Jittor Logo](doc/logo.png) + [Quickstart](#quickstart) | [Install](#install) | [Tutorial](#tutorial) | [Chinese](./README.cn.md) [快速开始](#快速开始) | [安装](#安装) | [教程](#教程) diff --git a/doc/logo.png b/doc/logo.png new file mode 100644 index 00000000..7bbc7488 Binary files /dev/null and b/doc/logo.png differ diff --git a/python/jittor/extern/cuda/cudnn/inc/cudnn_rnn_descriptor.h b/python/jittor/extern/cuda/cudnn/inc/cudnn_rnn_descriptor.h new file mode 100644 index 00000000..977aafb9 --- /dev/null +++ b/python/jittor/extern/cuda/cudnn/inc/cudnn_rnn_descriptor.h @@ -0,0 +1,154 @@ +// *************************************************************** +// Copyright (c) 2021 Jittor. All Rights Reserved. +// Maintainers: +// Zheng-Ning Liu +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" +#include "cudnn_warper.h" +#include "executor.h" +#include "init.h" + + +namespace jittor { + +static inline cudnnRNNMode_t rnn_string_to_rnn_mode(string mode) { + if (mode == "relu") + return CUDNN_RNN_RELU; + if (mode == "tanh") + return CUDNN_RNN_TANH; + if (mode == "lstm") + return CUDNN_LSTM; + ASSERT(mode == "gru") << "rnn mode must be relu, tanh, lstm, or gru, but got " << mode; + return CUDNN_GRU; +} + +static inline int rnn_string_to_num_linear_layers(string mode) { + if (mode == "relu") + return 2; + if (mode == "tanh") + return 2; + if (mode == "lstm") + return 8; + ASSERT(mode == "gru") << "mode must be relu, tanh, lstm, or gru, but got " << mode; + return 6; +} + +/** A wrapper for CUDNN dropout descriptor + */ +struct DropoutDescriptor { + cudnnDropoutDescriptor_t desc; + size_t stateSize, stateAllocation; + float dropout; + void *stateSpace; + + DropoutDescriptor(cudnnHandle_t handle, float dropout) + : dropout(dropout), stateSpace(nullptr) { + checkCudaErrors(cudnnCreateDropoutDescriptor(&desc)); + if (dropout > 0) { + checkCudaErrors(cudnnDropoutGetStatesSize(handle, &stateSize)); + stateSpace = exe.temp_allocator->alloc(stateSize, stateAllocation); + checkCudaErrors(cudnnSetDropoutDescriptor( + desc, + cudnn_handle, + dropout, + stateSpace, + stateSize, + get_seed() + )); + } else { + checkCudaErrors(cudnnSetDropoutDescriptor( + desc, handle, 0, nullptr, 0, 0 + )); + } + } + ~DropoutDescriptor() { + checkCudaErrors(cudnnDestroyDropoutDescriptor(desc)); + if (stateSpace) + exe.temp_allocator->free(stateSpace, stateSize, stateAllocation); + } +}; + +/** A wrapper for CUDNN RNN descriptor + */ +struct RnnDescriptor { + cudnnHandle_t handle; + cudnnRNNDescriptor_t desc; + DropoutDescriptor dropoutDesc; + + RnnDescriptor(cudnnHandle_t handle, string mode, int hidden_size, int num_layers, + float dropout, bool 
bidirectional) : handle(handle), dropoutDesc(handle, dropout) { + checkCudaErrors(cudnnCreateRNNDescriptor(&desc)); + checkCudaErrors(cudnnSetRNNDescriptor_v6( + handle, + desc, + hidden_size, + num_layers, + dropoutDesc.desc, + CUDNN_LINEAR_INPUT, + bidirectional ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL, + rnn_string_to_rnn_mode(mode), + CUDNN_RNN_ALGO_STANDARD, + CUDNN_DATA_FLOAT + )); + } + + ~RnnDescriptor() { + checkCudaErrors(cudnnDestroyRNNDescriptor(desc)); + } + + size_t weight_space_size(const cudnnTensorDescriptor_t &xDesc) { + size_t size; + checkCudaErrors(cudnnGetRNNParamsSize( + handle, desc, xDesc, &size, CUDNN_DATA_FLOAT + )); + return size; + } + + size_t work_space_size(const cudnnTensorDescriptor_t *xDesc, int seq_length) { + size_t size; + checkCudaErrors(cudnnGetRNNWorkspaceSize( + handle, desc, seq_length, xDesc, &size + )); + return size; + } + + size_t reserve_space_size(const cudnnTensorDescriptor_t *xDesc, int seq_length) { + size_t size; + checkCudaErrors(cudnnGetRNNTrainingReserveSize( + handle, desc, seq_length, xDesc, &size + )); + return size; + } +}; + +/** + */ +struct RnnWeightDescriptor { + cudnnFilterDescriptor_t desc; + size_t size; + RnnWeightDescriptor(size_t size) : size(size) { + int dimW[3] = {(int) (size / sizeof(float)), 1, 1}; + checkCudaErrors(cudnnCreateFilterDescriptor(&desc)); + checkCudaErrors(cudnnSetFilterNdDescriptor(desc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, 3, dimW)); + } + ~RnnWeightDescriptor() { + cudnnDestroyFilterDescriptor(desc); + } +}; + +/** + Returns offsets of RNN linear parameters in a flatten array. + + Returns + ======= + list: [total size, param #1 offset, param #2 offset, ...] + + TODO: support cudnn rnn-v8; support proj_size + */ +// @pyjt(cudnn_rnn_weight_offset) +vector cudnn_rnn_weight_offset(string mode, int input_size, int hidden_size, int num_layers, int proj_size, bool bias, bool bidirectional); + +} // jittor \ No newline at end of file diff --git a/python/jittor/extern/cuda/cudnn/ops/cudnn_rnn_backward_x_op.cc b/python/jittor/extern/cuda/cudnn/ops/cudnn_rnn_backward_x_op.cc new file mode 100644 index 00000000..347cf5bf --- /dev/null +++ b/python/jittor/extern/cuda/cudnn/ops/cudnn_rnn_backward_x_op.cc @@ -0,0 +1,195 @@ +// *************************************************************** +// Copyright (c) 2021 Jittor. All Rights Reserved. +// Maintainers: +// Zheng-Ning Liu +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. 
+// *************************************************************** +#include "var.h" +#include "cudnn_rnn_descriptor.h" +#include "cudnn_rnn_backward_x_op.h" +#include "cudnn_warper.h" +#include "executor.h" +#include "ops/op_register.h" + +namespace jittor { + +#pragma GCC diagnostic ignored "-Wunused-variable" + +#ifndef JIT + +CudnnRnnBackwardXOp::CudnnRnnBackwardXOp(Var *x, Var* hx, Var* cx, Var* y, Var* dy, Var* dhy, Var* dcy, Var* w, Var* reservation, + string mode, int input_size, int hidden_size, int num_layers, int proj_size, + double dropout, bool bias, bool bidirectional) + : x(x), hx(hx), cx(cx), y(y), dy(dy), dhy(dhy), dcy(dcy), w(w), reservation(reservation), + mode(mode), input_size(input_size), hidden_size(hidden_size), num_layers(num_layers), + proj_size(proj_size), dropout(dropout), bias(bias), bidirectional(bidirectional) { + + flags.set(NodeFlags::_cuda, 1); + flags.set(NodeFlags::_cpu, 0); + + ASSERTop(mode,==,"lstm"); + ASSERTop(proj_size,==,0); + init_rnn(); +} + +CudnnRnnBackwardXOp::CudnnRnnBackwardXOp(Var* x, Var* hx, Var* y, Var* dy, Var* dhy, Var* w, Var* reservation, + string mode, int input_size, int hidden_size, int num_layers, int proj_size, + double dropout, bool bias, bool bidirectional) + : x(x), hx(hx), cx(nullptr), y(y), dy(dy), dhy(dhy), dcy(nullptr), w(w), reservation(reservation), + mode(mode), input_size(input_size), hidden_size(hidden_size), num_layers(num_layers), + proj_size(proj_size), dropout(dropout), bias(bias), bidirectional(bidirectional) { + + flags.set(NodeFlags::_cuda, 1); + flags.set(NodeFlags::_cpu, 0); + + ASSERTop(mode,!=,"lstm"); + ASSERTop(proj_size,==,0); + init_rnn(); +} + +void CudnnRnnBackwardXOp::init_rnn() { + dx = create_output(nullptr, ns_float32); + dhx = create_output(nullptr, ns_float32); + + if (mode == "lstm") + dcx = create_output(nullptr, ns_float32); + else + dcx = nullptr; + + dw = create_output(nullptr, dtype_infer(x->ns, y->ns)); + + seq_length = y->shape[0]; + batch_size = y->shape[1]; +} + +void CudnnRnnBackwardXOp::infer_shape() { + dx->set_shape(NanoVector(seq_length, batch_size, input_size)); + + int num_directions = 1 + bidirectional; + if (proj_size > 0) + dhx->set_shape(NanoVector(num_layers * num_directions, batch_size, proj_size)); + else + dhx->set_shape(NanoVector(num_layers * num_directions, batch_size, hidden_size)); + + if (dcx) + dcx->set_shape(NanoVector(num_layers * num_directions, batch_size, hidden_size)); + + dw->set_shape(w->shape); +} + +void CudnnRnnBackwardXOp::jit_prepare(JK& jk) { + jk << _CS("[Tx:") << hx->dtype(); + jk << _CS("][Ty:") << y->dtype(); + jk << _CS("][Tw:") << w->dtype(); + jk << ']'; +} + +#else // JIT +#ifdef JIT_cuda + +template __inline__ cudnnDataType_t getDataType(); +template <> __inline__ cudnnDataType_t getDataType() { return CUDNN_DATA_HALF; } +template <> __inline__ cudnnDataType_t getDataType() { return CUDNN_DATA_FLOAT; } + +void CudnnRnnBackwardXOp::jit_run() { + int num_directions = 1 + bidirectional; + + int in_dims[3] = {batch_size, input_size, 1}; + int out_dims[3] = {batch_size, hidden_size * num_directions, 1}; + int in_strides[3] = {in_dims[1] * in_dims[2], in_dims[2], 1}; + int out_strides[3] = {out_dims[1] * out_dims[2], out_dims[2], 1}; + int hidden_dims[3] = {num_layers * num_directions, batch_size, hidden_size}; + int hidden_strides[3] = {hidden_dims[1] * hidden_dims[2], hidden_dims[2], 1}; + + vector xDesc(seq_length), dxDesc(seq_length); + vector yDesc(seq_length), dyDesc(seq_length); + + for (int i = 0; i < seq_length; ++i) { + 
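+        // cuDNN's v6 RNN API takes one tensor descriptor per time step, so build
+        // per-step descriptors here: x/dx use (batch, input_size) and y/dy use
+        // (batch, hidden_size * num_directions).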
checkCudaErrors(cudnnCreateTensorDescriptor(&xDesc[i])); + checkCudaErrors(cudnnCreateTensorDescriptor(&dxDesc[i])); + checkCudaErrors(cudnnCreateTensorDescriptor(&yDesc[i])); + checkCudaErrors(cudnnCreateTensorDescriptor(&dyDesc[i])); + checkCudaErrors(cudnnSetTensorNdDescriptor(xDesc[i], getDataType(), 3, in_dims, in_strides)); + checkCudaErrors(cudnnSetTensorNdDescriptor(dxDesc[i], getDataType(), 3, in_dims, in_strides)); + checkCudaErrors(cudnnSetTensorNdDescriptor(yDesc[i], getDataType(), 3, out_dims, out_strides)); + checkCudaErrors(cudnnSetTensorNdDescriptor(dyDesc[i], getDataType(), 3, out_dims, out_strides)); + } + + cudnnTensorDescriptor_t dhyDesc, dcyDesc; + cudnnTensorDescriptor_t hxDesc, cxDesc, dhxDesc, dcxDesc; + checkCudaErrors(cudnnCreateTensorDescriptor(&hxDesc)); + checkCudaErrors(cudnnCreateTensorDescriptor(&cxDesc)); + checkCudaErrors(cudnnCreateTensorDescriptor(&dhxDesc)); + checkCudaErrors(cudnnCreateTensorDescriptor(&dcxDesc)); + checkCudaErrors(cudnnCreateTensorDescriptor(&dhyDesc)); + checkCudaErrors(cudnnCreateTensorDescriptor(&dcyDesc)); + checkCudaErrors(cudnnSetTensorNdDescriptor(hxDesc, getDataType(), 3, hidden_dims, hidden_strides)); + checkCudaErrors(cudnnSetTensorNdDescriptor(cxDesc, getDataType(), 3, hidden_dims, hidden_strides)); + checkCudaErrors(cudnnSetTensorNdDescriptor(dhxDesc, getDataType(), 3, hidden_dims, hidden_strides)); + checkCudaErrors(cudnnSetTensorNdDescriptor(dcxDesc, getDataType(), 3, hidden_dims, hidden_strides)); + checkCudaErrors(cudnnSetTensorNdDescriptor(dhyDesc, getDataType(), 3, hidden_dims, hidden_strides)); + checkCudaErrors(cudnnSetTensorNdDescriptor(dcyDesc, getDataType(), 3, hidden_dims, hidden_strides)); + + RnnWeightDescriptor w_desc(w->size); + RnnDescriptor rnn_desc(cudnn_handle, mode, hidden_size, num_layers, dropout, bidirectional); + + void *work_space; + size_t work_space_size = rnn_desc.work_space_size(dxDesc.data(), seq_length); + size_t work_space_allocation; + if (work_space_size > 0) + work_space = exe.temp_allocator->alloc(work_space_size, work_space_allocation); + + size_t reserveSpaceSize = reservation->size; + + checkCudaErrors(cudnnRNNBackwardData( + cudnn_handle, rnn_desc.desc, + seq_length, + yDesc.data(), y->ptr(), + dyDesc.data(), dy->ptr(), + dhyDesc, dhy->ptr(), + dcyDesc, mode == "lstm" ? dcy->ptr(): nullptr, + w_desc.desc, w->ptr(), + hxDesc, hx->ptr(), + cxDesc, mode == "lstm" ? cx->ptr() : nullptr, + dxDesc.data(), dx->ptr(), + dhxDesc, dhx->ptr(), + dcxDesc, mode == "lstm" ? 
dcx->ptr() : nullptr, + work_space, work_space_size, + reservation->ptr(), reservation->size + )); + + checkCudaErrors(cudaMemset(dw->ptr(), 0, dw->size)); + + checkCudaErrors(cudnnRNNBackwardWeights( + cudnn_handle, rnn_desc.desc, + seq_length, + xDesc.data(), x->ptr(), + hxDesc, hx->ptr(), + yDesc.data(), y->ptr(), + work_space, work_space_size, + w_desc.desc, dw->ptr(), + reservation->ptr(), reservation->size + )); + + for (int i = 0; i < seq_length; ++i) { + checkCudaErrors(cudnnDestroyTensorDescriptor(xDesc[i])); + checkCudaErrors(cudnnDestroyTensorDescriptor(dxDesc[i])); + checkCudaErrors(cudnnDestroyTensorDescriptor(yDesc[i])); + checkCudaErrors(cudnnDestroyTensorDescriptor(dyDesc[i])); + } + + checkCudaErrors(cudnnDestroyTensorDescriptor(dhyDesc)); + checkCudaErrors(cudnnDestroyTensorDescriptor(dcyDesc)); + checkCudaErrors(cudnnDestroyTensorDescriptor(hxDesc)); + checkCudaErrors(cudnnDestroyTensorDescriptor(cxDesc)); + checkCudaErrors(cudnnDestroyTensorDescriptor(dhxDesc)); + checkCudaErrors(cudnnDestroyTensorDescriptor(dcxDesc)); + + if (work_space) + exe.temp_allocator->free(work_space, work_space_size, work_space_allocation); +} + +#endif +#endif // JIT +} diff --git a/python/jittor/extern/cuda/cudnn/ops/cudnn_rnn_backward_x_op.h b/python/jittor/extern/cuda/cudnn/ops/cudnn_rnn_backward_x_op.h new file mode 100644 index 00000000..b4064131 --- /dev/null +++ b/python/jittor/extern/cuda/cudnn/ops/cudnn_rnn_backward_x_op.h @@ -0,0 +1,38 @@ +// *************************************************************** +// Copyright (c) 2021 Jittor. All Rights Reserved. +// Maintainers: +// Zheng-Ning Liu +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" + +namespace jittor { + +struct CudnnRnnBackwardXOp : Op { + Var* x, * hx, * cx; + Var* y, * dy, * dhy, * dcy; + Var* w; + Var* dx, * dhx, * dcx, * dw; + Var* reservation; + string mode; + int input_size, hidden_size, num_layers, proj_size, batch_size; + int seq_length; + float dropout; + bool bias, bidirectional; + + // @attrs(multiple_outputs) + CudnnRnnBackwardXOp(Var* x, Var* hx, Var* cx, Var* y, Var* dy, Var* dhy, Var* dcy, Var* w, Var* reservation, string mode, int input_size, int hidden_size, int num_layers, int proj_size, double dropout, bool bias, bool bidirectional); + + // @attrs(multiple_outputs) + CudnnRnnBackwardXOp(Var* x, Var* hx, Var* y, Var* dy, Var* dhy, Var* w, Var* reservation, string mode, int input_size, int hidden_size, int num_layers, int proj_size, double dropout, bool bias, bool bidirectional); + + void init_rnn(); + + const char* name() const override { return "cudnn_rnn_backward_x"; } + void infer_shape() override; + DECLARE_jit_run; +}; + +} // jittor diff --git a/python/jittor/extern/cuda/cudnn/ops/cudnn_rnn_op.cc b/python/jittor/extern/cuda/cudnn/ops/cudnn_rnn_op.cc new file mode 100644 index 00000000..6772a6f8 --- /dev/null +++ b/python/jittor/extern/cuda/cudnn/ops/cudnn_rnn_op.cc @@ -0,0 +1,227 @@ +// *************************************************************** +// Copyright (c) 2021 Jittor. All Rights Reserved. +// Maintainers: +// Zheng-Ning Liu +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. 
+// *************************************************************** +#include "var.h" +#include "cudnn_rnn_descriptor.h" +#include "cudnn_rnn_op.h" +#include "cudnn_warper.h" +#include "executor.h" +#include "ops/op_register.h" + +using namespace std; + +namespace jittor { + +#pragma GCC diagnostic ignored "-Wunused-variable" + +#ifndef JIT + +CudnnRnnOp::CudnnRnnOp(Var* x, Var* hx, Var* cx, Var* w, + string mode, int input_size, int hidden_size, int num_layers, int proj_size, + double dropout, bool bias, bool bidirectional, bool is_train) + : x(x), hx(hx), cx(cx), w(w), mode(mode), input_size(input_size), hidden_size(hidden_size), + num_layers(num_layers), proj_size(proj_size), dropout(dropout), bias(bias), + bidirectional(bidirectional), is_train(is_train) { + flags.set(NodeFlags::_cuda, 1); + flags.set(NodeFlags::_cpu, 0); + flags.set(NodeFlags::_grads, 1); + + ASSERTop(mode,==,"lstm"); + ASSERTop(proj_size,==,0); + init_rnn(); +} + +CudnnRnnOp::CudnnRnnOp(Var* x, Var* hx, Var* w, + string mode, int input_size, int hidden_size, int num_layers, int proj_size, + double dropout, bool bias, bool bidirectional, bool is_train) + : x(x), hx(hx), cx(nullptr), w(w), mode(mode), input_size(input_size), hidden_size(hidden_size), + num_layers(num_layers), proj_size(proj_size), dropout(dropout), bias(bias), + bidirectional(bidirectional), is_train(is_train) { + flags.set(NodeFlags::_cuda, 1); + flags.set(NodeFlags::_cpu, 0); + flags.set(NodeFlags::_grads, 1); + + ASSERTop(mode,!=,"lstm"); + ASSERTop(proj_size,==,0); + init_rnn(); +} + +void CudnnRnnOp::init_rnn() { + y = create_output(nullptr, dtype_infer(x->ns, w->ns)); + hy = create_output(nullptr, dtype_infer(x->ns, w->ns)); + if (mode == "lstm") + cy = create_output(nullptr, dtype_infer(x->ns, w->ns)); + else + cy = nullptr; + + if (is_train) + reservation = create_output(nullptr, ns_float32); + else + reservation = nullptr; + + seq_length = x->shape[0]; + batch_size = x->shape[1]; +} + +void CudnnRnnOp::infer_shape() { + ASSERTop(x->shape.size(),==,3); + ASSERTop(x->shape[2],==,input_size); + + int num_directions = 1 + bidirectional; + + y->set_shape(NanoVector(seq_length, batch_size, hidden_size * num_directions)); + + if (proj_size > 0) + hy->set_shape(NanoVector(num_layers * num_directions, batch_size, proj_size)); + else + hy->set_shape(NanoVector(num_layers * num_directions, batch_size, hidden_size)); + + if (cy) + cy->set_shape(NanoVector(num_layers * num_directions, batch_size, hidden_size)); + + if (reservation) { + int in_dims[3] = {batch_size, input_size, 1}; + int in_strides[3] = {in_dims[1] * in_dims[2], in_dims[2], 1}; + + vector xDesc(seq_length); + RnnDescriptor rnn_desc(cudnn_handle, mode, hidden_size, num_layers, dropout, bidirectional); + for (int i = 0; i < seq_length; ++i) { + checkCudaErrors(cudnnCreateTensorDescriptor(&xDesc[i])); + checkCudaErrors(cudnnSetTensorNdDescriptor(xDesc[i], CUDNN_DATA_FLOAT, 3, in_dims, in_strides)); + } + reservation->set_shape(rnn_desc.reserve_space_size(xDesc.data(), seq_length)); + } +} + +void CudnnRnnOp::jit_prepare(JK& jk) { + jk << _CS("[Tx:") << x->dtype(); + jk << _CS("][Ty:") << y->dtype(); + jk << _CS("][Tw:") << w->dtype(); + jk << ']'; +} + +static auto make_backwardx_with_cx = get_op_info("cudnn_rnn_backward_x") + .get_constructor, Var*, Var*, Var*, Var*, Var*, Var*, Var*, Var*, Var*, string, int, int, int, int, double, bool, bool>(); +static auto make_backwardx_without_cx = get_op_info("cudnn_rnn_backward_x") + .get_constructor, Var*, Var*, Var*, Var*, Var*, Var*, Var*, string, 
int, int, int, int, double, bool, bool>(); + +void CudnnRnnOp::grads(Var** dout, VarPtr* dins) { + Var *dy = dout[0]; + Var *dhy = dout[1]; + Var *dcy = cx ? dout[2] : nullptr; + + vector dInput; + if (cx) + dInput = make_backwardx_with_cx(x, hx, cx, y, dy, dhy, dcy, w, reservation, mode, input_size, hidden_size, num_layers, proj_size, dropout, bias, bidirectional); + else + dInput = make_backwardx_without_cx(x, hx, y, dy, dhy, w, reservation, mode, input_size, hidden_size, num_layers, proj_size, dropout, bias, bidirectional); + + for (int i = 0; i < 3 + (cx != nullptr); ++i) + dins[i] = move(dInput[i]); +} + +#else // JIT +#ifdef JIT_cuda + +#pragma clang diagnostic ignored "-Wtautological-compare" + +template __inline__ cudnnDataType_t getDataType(); +template <> __inline__ cudnnDataType_t getDataType() { return CUDNN_DATA_HALF; } +template <> __inline__ cudnnDataType_t getDataType() { return CUDNN_DATA_FLOAT; } + +void CudnnRnnOp::jit_run() { + int num_directions = bidirectional + 1; + int num_linear_layers = rnn_string_to_num_linear_layers(mode); + + int in_dims[3] = {batch_size, input_size, 1}; + int out_dims[3] = {batch_size, hidden_size * num_directions, 1}; + int in_strides[3] = {in_dims[1] * in_dims[2], in_dims[2], 1}; + int out_strides[3] = {out_dims[1] * out_dims[2], out_dims[2], 1}; + int hidden_dims[3] = {num_layers * num_directions, batch_size, hidden_size}; + int hidden_strides[3] = {hidden_dims[1] * hidden_dims[2], hidden_dims[2], 1}; + + vector xDesc(seq_length); + vector yDesc(seq_length); + cudnnTensorDescriptor_t hxDesc, cxDesc; + cudnnTensorDescriptor_t hyDesc, cyDesc; + + for (int i = 0; i < seq_length; ++i) { + checkCudaErrors(cudnnCreateTensorDescriptor(&xDesc[i])); + checkCudaErrors(cudnnCreateTensorDescriptor(&yDesc[i])); + checkCudaErrors(cudnnSetTensorNdDescriptor(xDesc[i], getDataType(), 3, in_dims, in_strides)); + checkCudaErrors(cudnnSetTensorNdDescriptor(yDesc[i], getDataType(), 3, out_dims, out_strides)); + } + + checkCudaErrors(cudnnCreateTensorDescriptor(&hxDesc)); + checkCudaErrors(cudnnCreateTensorDescriptor(&cxDesc)); + checkCudaErrors(cudnnCreateTensorDescriptor(&hyDesc)); + checkCudaErrors(cudnnCreateTensorDescriptor(&cyDesc)); + + checkCudaErrors(cudnnSetTensorNdDescriptor(hxDesc, getDataType(), 3, hidden_dims, hidden_strides)); + checkCudaErrors(cudnnSetTensorNdDescriptor(cxDesc, getDataType(), 3, hidden_dims, hidden_strides)); + + checkCudaErrors(cudnnSetTensorNdDescriptor(hyDesc, getDataType(), 3, hidden_dims, hidden_strides)); + checkCudaErrors(cudnnSetTensorNdDescriptor(cyDesc, getDataType(), 3, hidden_dims, hidden_strides)); + + RnnDescriptor rnn_desc(cudnn_handle, mode, hidden_size, num_layers, dropout, bidirectional); + + void *work_space; + size_t work_space_size = rnn_desc.work_space_size(xDesc.data(), seq_length); + size_t work_space_allocation; + if (work_space_size > 0) + work_space = exe.temp_allocator->alloc(work_space_size, work_space_allocation); + + RnnWeightDescriptor w_desc(w->size); + + if (is_train) { + checkCudaErrors(cudnnRNNForwardTraining( + cudnn_handle, rnn_desc.desc, + seq_length, + xDesc.data(), x->ptr(), + hxDesc, hx->ptr(), + cxDesc, mode == "lstm" ? cx->ptr() : nullptr, + w_desc.desc, w->ptr(), + yDesc.data(), y->ptr(), + hyDesc, hy->ptr(), + cyDesc, mode == "lstm" ? 
cy->ptr() : nullptr, + work_space, work_space_size, + reservation->ptr(), reservation->size + )); + } else { + checkCudaErrors(cudnnRNNForwardInference( + cudnn_handle, rnn_desc.desc, + seq_length, + xDesc.data(), x->ptr(), + hxDesc, hx->ptr(), + cxDesc, mode == "lstm" ? cx->ptr() : nullptr, + w_desc.desc, w->ptr(), + yDesc.data(), y->ptr(), + hyDesc, hy->ptr(), + cyDesc, mode == "lstm" ? cy->ptr() : nullptr, + work_space, work_space_size + )); + } + + for (int i = 0; i < seq_length; i++) { + checkCudaErrors(cudnnDestroyTensorDescriptor(xDesc[i])); + checkCudaErrors(cudnnDestroyTensorDescriptor(yDesc[i])); + } + + checkCudaErrors(cudnnDestroyTensorDescriptor(hxDesc)); + checkCudaErrors(cudnnDestroyTensorDescriptor(cxDesc)); + checkCudaErrors(cudnnDestroyTensorDescriptor(hyDesc)); + checkCudaErrors(cudnnDestroyTensorDescriptor(cyDesc)); + + if (work_space) + exe.temp_allocator->free(work_space, work_space_size, work_space_allocation); +} + +#endif +#endif // JIT + +} // jittor + diff --git a/python/jittor/extern/cuda/cudnn/ops/cudnn_rnn_op.h b/python/jittor/extern/cuda/cudnn/ops/cudnn_rnn_op.h new file mode 100644 index 00000000..709a1436 --- /dev/null +++ b/python/jittor/extern/cuda/cudnn/ops/cudnn_rnn_op.h @@ -0,0 +1,36 @@ +// *************************************************************** +// Copyright (c) 2021 Jittor. All Rights Reserved. +// Maintainers: +// Zheng-Ning Liu +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" + +namespace jittor { + +struct CudnnRnnOp : Op { + Var* x, * hx, * cx, * y, * hy, * cy; + Var* w; + Var* reservation; + string mode; + int input_size, hidden_size, num_layers, proj_size; + int seq_length, batch_size; + float dropout; + bool bias, bidirectional, is_train; + + // @attrs(multiple_outputs) + CudnnRnnOp(Var* x, Var* hx, Var* cx, Var* w, string mode, int input_size, int hidden_size, int num_layers, int proj_size, double dropout, bool batch_first, bool bias, bool bidirectional); + // @attrs(multiple_outputs) + CudnnRnnOp(Var* x, Var* hx, Var* w, string mode, int input_size, int hidden_size, int num_layers, int proj_size, double dropout, bool batch_first, bool bias, bool bidirectional); + + void init_rnn(); + + const char* name() const override { return "cudnn_rnn"; } + void grads(Var** douts, VarPtr* dins) override; + void infer_shape() override; + DECLARE_jit_run; +}; + +} // jittor diff --git a/python/jittor/extern/cuda/cudnn/src/cudnn_rnn_descriptor.cc b/python/jittor/extern/cuda/cudnn/src/cudnn_rnn_descriptor.cc new file mode 100644 index 00000000..b4f4518b --- /dev/null +++ b/python/jittor/extern/cuda/cudnn/src/cudnn_rnn_descriptor.cc @@ -0,0 +1,74 @@ +// *************************************************************** +// Copyright (c) 2021 Jittor. All Rights Reserved. +// Maintainers: +// Zheng-Ning Liu +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "cudnn_rnn_descriptor.h" + +namespace jittor { + +vector cudnn_rnn_weight_offset(string mode, int input_size, int hidden_size, int num_layers, int proj_size, bool bias, bool bidirectional) { + // A pseudo mini-batch for fetching weight space size. 
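+    // Offsets are recovered by passing a null weight base to
+    // cudnnGetRNNLinLayerMatrixParams / cudnnGetRNNLinLayerBiasParams: the returned
+    // parameter pointer then equals its element offset within the flat weight buffer.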
+ int dimX[] = {1, input_size, 1}; + int strideX[] = {input_size, 1, 1}; + cudnnTensorDescriptor_t xDesc; + checkCudaErrors(cudnnCreateTensorDescriptor(&xDesc)); + checkCudaErrors(cudnnSetTensorNdDescriptor(xDesc, CUDNN_DATA_FLOAT, 3, dimX, strideX)); + + RnnDescriptor rnn_desc = RnnDescriptor(cudnn_handle, mode, hidden_size, num_layers, 0, bidirectional); + int weightSpaceSize = rnn_desc.weight_space_size(xDesc); + RnnWeightDescriptor w_desc(weightSpaceSize); + + vector weight_offsets; + weight_offsets.push_back(weightSpaceSize / sizeof(float)); + + int num_directions = bidirectional + 1; + int num_linear_layers = rnn_string_to_num_linear_layers(mode); + + for (int layer = 0; layer < num_layers * num_directions; layer++) { + for (int linLayerID = 0; linLayerID < num_linear_layers; linLayerID++) { + cudnnFilterDescriptor_t linLayerMatDesc; + cudnnFilterDescriptor_t linLayerBiasDesc; + float *linLayerMat = nullptr; + float *linLayerBias = nullptr; + + checkCudaErrors(cudnnCreateFilterDescriptor(&linLayerMatDesc)); + checkCudaErrors(cudnnCreateFilterDescriptor(&linLayerBiasDesc)); + + checkCudaErrors(cudnnGetRNNLinLayerMatrixParams( + cudnn_handle, rnn_desc.desc, + layer, + xDesc, + w_desc.desc, + nullptr, + linLayerID, + linLayerMatDesc, + (void **) &linLayerMat + )); + weight_offsets.push_back(linLayerMat - (float *) nullptr); + + if (bias) { + checkCudaErrors(cudnnGetRNNLinLayerBiasParams( + cudnn_handle, rnn_desc.desc, + layer, + xDesc, + w_desc.desc, + nullptr, + linLayerID, + linLayerBiasDesc, + (void **) &linLayerBias + )); + weight_offsets.push_back(linLayerBias - (float *) nullptr); + } + } + } + + checkCudaErrors(cudnnDestroyTensorDescriptor(xDesc)); + + return weight_offsets; +} + + +} // jittor diff --git a/python/jittor/misc.py b/python/jittor/misc.py index 3b85d134..5daf1dc8 100644 --- a/python/jittor/misc.py +++ b/python/jittor/misc.py @@ -214,7 +214,8 @@ def t(x): return x.transpose(*pose) jt.Var.t = t -def median(x,dim=None,keepdim=False): +def median(x,dim=None,keepdim=False, keepdims=False): + keepdim = keepdim or keepdims if dim is None: x = x.reshape(-1) dim=0 @@ -637,7 +638,8 @@ def topk(input, k, dim=None, largest=True, sorted=True): jt.Var.topk = topk -def kthvalue(input, k, dim=None, keepdim=False): +def kthvalue(input, k, dim=None, keepdim=False, keepdims=False): + keepdim = keepdim or keepdims if dim is None: dim = -1 if dim<0: diff --git a/python/jittor/nn.py b/python/jittor/nn.py index bdfe04d6..86d9fffc 100644 --- a/python/jittor/nn.py +++ b/python/jittor/nn.py @@ -12,8 +12,9 @@ # file 'LICENSE.txt', which is part of this source code package. # *************************************************************** from abc import abstractmethod +from sys import breakpointhook import jittor as jt -from jittor import init, Module +from jittor import flatten, init, Module import numpy as np import collections import math @@ -1881,6 +1882,7 @@ class RNNCell(jt.Module): return y + class GRUCell(jt.Module): def __init__(self, input_size, hidden_size, bias=True): ''' A gated recurrent unit (GRU) cell. 
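For orientation, here is a minimal sketch (not part of the patch) of how the offset array produced by the `cudnn_rnn_weight_offset` helper above is meant to be consumed, mirroring the `_cudnn_flatten_weights` method added below; the concrete sizes are illustrative assumptions.

```python
# Hypothetical illustration of the offset layout (assumes a CUDA build with cuDNN).
import jittor as jt

jt.flags.use_cuda = 1
offsets = jt.cudnn.cudnn_rnn_weight_offset(
    "lstm",   # mode: one of relu / tanh / lstm / gru
    100,      # input_size
    200,      # hidden_size
    1,        # num_layers
    0,        # proj_size (must currently be 0)
    True,     # bias
    False,    # bidirectional
)
total_size, param_offsets = offsets[0], offsets[1:]
# offsets[0] is the flat weight length in floats; the remaining entries give the
# element offset of each per-gate weight (and bias, when bias=True), in the order
# cuDNN reports them: input-hidden parameters first, then hidden-hidden, per layer.
ft_weight = jt.zeros(total_size, dtype=jt.float32)  # buffer later passed to the cudnn_rnn op
```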
@@ -1941,7 +1943,7 @@ class RNNBase(Module): def __init__(self, mode: str, input_size: int, hidden_size: int, num_layers: int = 1, bias: bool = True, batch_first: bool = False, dropout: float = 0, bidirectional: bool = False, - proj_size: int = 0) -> None: + proj_size: int = 0, nonlinearity: str = None) -> None: super().__init__() self.mode = mode @@ -1953,6 +1955,7 @@ class RNNBase(Module): self.dropout = dropout self.bidirectional = bidirectional self.proj_size = proj_size + self.nonlinearity = nonlinearity if mode == 'LSTM': gate_size = 4 * hidden_size @@ -1994,6 +1997,56 @@ class RNNBase(Module): build_unit(f'bias_ih_l{layer}', gate_size) build_unit(f'bias_hh_l{layer}', gate_size) + def _cudnn_flatten_weights(self, cudnn_mode): + def copy_to_flatten_weight(param_name, offset_idx, num_gates): + def copy_to(param_name, offset_idx, idx): + cur_offset = self._cudnn_weight_offset[offset_idx] + param = getattr(self, param_name) + param = param[self.hidden_size * idx: self.hidden_size * (idx + 1)] + ft_weight[cur_offset:cur_offset + param.numel()] = param.flatten() + + if self.bias: + for idx in range(num_gates): + copy_to('weight' + param_name, offset_idx + idx * 2, idx) + copy_to('bias' + param_name, offset_idx + idx * 2 + 1, idx) + return num_gates * 2 + else: + for idx in range(num_gates): + copy_to('weight' + param_name, offset_idx + idx, idx) + return num_gates + + if jt.flags.use_cuda and jt.cudnn: + if getattr(self, '_cudnn_weight_size', None) is None: + offset_array = jt.cudnn.cudnn_rnn_weight_offset( + cudnn_mode, + self.input_size, + self.hidden_size, + self.num_layers, + self.proj_size, + self.bias, + self.bidirectional + ) + self._cudnn_weight_size = offset_array[0] + self._cudnn_weight_offset = offset_array[1:] + + num_gates = { + "RNN": 1, "LSTM": 4, "GRU": 3 + }[self.mode] + ft_weight = jt.zeros(self._cudnn_weight_size, dtype=jt.float32) + + cnt = 0 + for layer in range(self.num_layers): + suffix = '' + cnt += copy_to_flatten_weight(f'_ih_l{layer}' + suffix, cnt, num_gates) + cnt += copy_to_flatten_weight(f'_hh_l{layer}' + suffix, cnt, num_gates) + if self.bidirectional: + suffix = '_reverse' + cnt += copy_to_flatten_weight(f'_ih_l{layer}' + suffix, cnt, num_gates) + cnt += copy_to_flatten_weight(f'_hh_l{layer}' + suffix, cnt, num_gates) + return ft_weight + else: + raise RuntimeError("Not Cudnn found") + @abstractmethod def call_rnn_cell(self, input, hidden, suffix): pass @@ -2013,87 +2066,116 @@ class RNNBase(Module): return output, hidden - def execute(self, input, hx): + def _execute_cudnn_rnn(self, input, hx): + cudnn_mode = { + ('RNN', 'tanh'): 'tanh', + ('RNN', 'relu'): 'relu', + ('LSTM', None): 'lstm', + ('GRU', None): 'gru' + }[(self.mode, self.nonlinearity)] + ft_weight = self._cudnn_flatten_weights(cudnn_mode) + + if self.mode == 'LSTM': + ret = jt.cudnn.ops.cudnn_rnn(input, hx[0], hx[1], ft_weight, + cudnn_mode, self.input_size, self.hidden_size, self.num_layers, 0, + self.dropout, self.bias, self.bidirectional, self.is_training() + ) + return ret[0], (ret[1], ret[2]) + else: + ret = jt.cudnn.ops.cudnn_rnn(input, hx, ft_weight, + cudnn_mode, self.input_size, self.hidden_size, self.num_layers, 0, + self.dropout, self.bias, self.bidirectional, self.is_training() + ) + return ret[0], ret[1] + + def execute(self, input, hx=None): if self.batch_first: input = input.permute(1, 0, 2) num_directions = 2 if self.bidirectional else 1 if hx is None: - hx = self.default_init_state() + if self.mode in ['RNN', 'GRU']: + hx = jt.zeros((num_directions * self.num_layers, 
input.shape[1], self.hidden_size), dtype=input.dtype) + elif self.mode == 'LSTM': + hx = (jt.zeros((num_directions * self.num_layers, input.shape[1], self.hidden_size), dtype=input.dtype), + jt.zeros((num_directions * self.num_layers, input.shape[1], self.hidden_size), dtype=input.dtype)) - hidden_n = [] + if jt.flags.use_cuda and jt.cudnn and self.proj_size == 0: + return self._execute_cudnn_rnn(input, hx) + else: + hidden_n = [] - for l in range(self.num_layers): - output = [] + for l in range(self.num_layers): + output = [] - if isinstance(hx, tuple): - hidden = [h[l * num_directions] for h in hx] - else: - hidden = hx[l * num_directions] - - output, _hidden = self.call_rnn_sequence(input, hidden, f'l{l}') - hidden_n.append(_hidden) - - if self.bidirectional: if isinstance(hx, tuple): - hidden = [h[l * num_directions + 1] for h in hx] + hidden = [h[l * num_directions] for h in hx] else: - hidden = hx[l * num_directions + 1] + hidden = hx[l * num_directions] - output_b, _hidden = self.call_rnn_sequence(input, hidden, f'l{l}_reverse') - output = jt.concat([output, output_b], dim=-1) + output, _hidden = self.call_rnn_sequence(input, hidden, f'l{l}') hidden_n.append(_hidden) - if self.dropout > 0: - input = dropout(output, p=self.dropout) + if self.bidirectional: + if isinstance(hx, tuple): + hidden = [h[l * num_directions + 1] for h in hx] + else: + hidden = hx[l * num_directions + 1] + + output_b, _hidden = self.call_rnn_sequence(input, hidden, f'l{l}_reverse') + output = jt.concat([output, output_b], dim=-1) + hidden_n.append(_hidden) + + if self.dropout > 0: + input = dropout(output, p=self.dropout) + else: + input = output + + if isinstance(hx, tuple): + hidden_n = tuple(jt.stack(hn, dim=0) for hn in zip(*hidden_n)) else: - input = output + hidden_n = jt.stack(hidden_n, dim=0) - if isinstance(hx, tuple): - hidden_n = tuple(jt.stack(hn, dim=0) for hn in zip(*hidden_n)) - else: - hidden_n = jt.stack(hidden_n, dim=0) - - return output, hidden_n + return output, hidden_n class RNN(RNNBase): + ''' Applies a multi-layer Elman RNN with tanh ReLU non-linearity to an input sequence. + + :param input_size: The number of expected features in the input. + :type input_size: int + + :param hidden_size: The number of features in the hidden state. + :type hidden_size: int + + :param num_layers: Number of recurrent layers. E.g., setting num_layers=2 would mean stacking two RNNs together to form a stacked RNN, with the second RNN taking in outputs of the first RNN and computing the final results. Default: 1 + :type num_layers: int, optinal + + :param nonlinearity: The non-linearity to use. Can be either 'tanh' or 'relu'. Default: 'tanh' + :type nonlinearity: str, optional + + :param bias: If False, then the layer does not use bias weights b_ih and b_hh. Default: True. + :type bias: bool, optional + + :param batch_first: If True, then the input and output tensors are provided as (batch, seq, feature). Default: False + :type bias: bool, optional + + :param dropout: If non-zero, introduces a Dropout layer on the outputs of each RNN layer except the last layer, with dropout probability equal to dropout. Default: 0 + :type dropout: float, optional + + :param bidirectional: If True, becomes a bidirectional RNN. 
Default: False + :type bidirectional: bool, optional + + Example: + >>> rnn = nn.RNN(10, 20, 2) + >>> input = jt.randn(5, 3, 10) + >>> h0 = jt.randn(2, 3, 20) + >>> output, hn = rnn(input, h0) + ''' def __init__(self, input_size: int, hidden_size: int, num_layers: int = 1, nonlinearity: str = 'tanh', bias: bool = True, batch_first: bool = False, dropout: float = 0, bidirectional: bool = False) -> None: - ''' Applies a multi-layer Elman RNN with tanh ReLU non-linearity to an input sequence. - - :param input_size: The number of expected features in the input. - :type input_size: int - - :param hidden_size: The number of features in the hidden state. - :type hidden_size: int - - :param num_layers: Number of recurrent layers. E.g., setting num_layers=2 would mean stacking two RNNs together to form a stacked RNN, with the second RNN taking in outputs of the first RNN and computing the final results. Default: 1 - :type num_layers: int, optinal - - :param nonlinearity: The non-linearity to use. Can be either 'tanh' or 'relu'. Default: 'tanh' - :type nonlinearity: str, optional - - :param bias: If False, then the layer does not use bias weights b_ih and b_hh. Default: True. - :type bias: bool, optional - - :param batch_first: If True, then the input and output tensors are provided as (batch, seq, feature). Default: False - :type bias: bool, optional - - :param dropout: If non-zero, introduces a Dropout layer on the outputs of each RNN layer except the last layer, with dropout probability equal to dropout. Default: 0 - :type dropout: float, optional - - :param bidirectional: If True, becomes a bidirectional RNN. Default: False - :type bidirectional: bool, optional - - Example: - >>> rnn = nn.RNN(10, 20, 2) - >>> input = jt.randn(5, 3, 10) - >>> h0 = jt.randn(2, 3, 20) - >>> output, hn = rnn(input, h0) - ''' super().__init__('RNN', input_size, hidden_size, num_layers=num_layers, bias=bias, batch_first=batch_first, dropout=dropout, bidirectional=bidirectional) @@ -2112,47 +2194,48 @@ class RNN(RNNBase): if self.nonlinearity == 'tanh': h = jt.tanh(y) else: - h = jt.relu(y) + h = jt.nn.relu(y) return h, h class LSTM(RNNBase): + ''' Applies a multi-layer long short-term memory (LSTM) RNN to an input sequence. + + :param input_size: The number of expected features in the input. + :type input_size: int + + :param hidden_size: The number of features in the hidden state. + :type hidden_size: int + + :param num_layers: Number of recurrent layers. E.g., setting num_layers=2 would mean stacking two LSTMs together to form a stacked LSTM, with the second LSTM taking in outputs of the first LSTM and computing the final results. Default: 1 + :type num_layers: int, optinal + + :param bias: If False, then the layer does not use bias weights b_ih and b_hh. Default: True. + :type bias: bool, optional + + :param batch_first: If True, then the input and output tensors are provided as (batch, seq, feature). Default: False + :type bias: bool, optional + + :param dropout: If non-zero, introduces a Dropout layer on the outputs of each LSTM layer except the last layer, with dropout probability equal to dropout. Default: 0 + :type dropout: float, optional + + :param bidirectional: If True, becomes a bidirectional LSTM. Default: False + :type bidirectional: bool, optional + + :param proj_size: If > 0, will use LSTM with projections of corresponding size. 
Default: 0 + :type proj_size: int, optional + + Example: + >>> rnn = nn.LSTM(10, 20, 2) + >>> input = jt.randn(5, 3, 10) + >>> h0 = jt.randn(2, 3, 20) + >>> c0 = jt.randn(2, 3, 20) + >>> output, (hn, cn) = rnn(input, (h0, c0)) + ''' + def __init__(self, input_size, hidden_size, num_layers=1, bias=True, batch_first=False, dropout=0, bidirectional=False, proj_size=0): - ''' Applies a multi-layer long short-term memory (LSTM) RNN to an input sequence. - - :param input_size: The number of expected features in the input. - :type input_size: int - - :param hidden_size: The number of features in the hidden state. - :type hidden_size: int - - :param num_layers: Number of recurrent layers. E.g., setting num_layers=2 would mean stacking two LSTMs together to form a stacked LSTM, with the second LSTM taking in outputs of the first LSTM and computing the final results. Default: 1 - :type num_layers: int, optinal - - :param bias: If False, then the layer does not use bias weights b_ih and b_hh. Default: True. - :type bias: bool, optional - - :param batch_first: If True, then the input and output tensors are provided as (batch, seq, feature). Default: False - :type bias: bool, optional - - :param dropout: If non-zero, introduces a Dropout layer on the outputs of each LSTM layer except the last layer, with dropout probability equal to dropout. Default: 0 - :type dropout: float, optional - - :param bidirectional: If True, becomes a bidirectional LSTM. Default: False - :type bidirectional: bool, optional - - :param proj_size: If > 0, will use LSTM with projections of corresponding size. Default: 0 - :type proj_size: int, optional - - Example: - >>> rnn = nn.LSTM(10, 20, 2) - >>> input = jt.randn(5, 3, 10) - >>> h0 = jt.randn(2, 3, 20) - >>> c0 = jt.randn(2, 3, 20) - >>> output, (hn, cn) = rnn(input, (h0, c0)) - ''' super().__init__('LSTM', input_size, hidden_size, num_layers=num_layers, bias=bias, batch_first=batch_first, dropout=dropout, bidirectional=bidirectional, proj_size=proj_size) @@ -2179,38 +2262,39 @@ class LSTM(RNNBase): class GRU(RNNBase): + ''' Applies a multi-layer gated recurrent unit (GRU) RNN to an input sequence. + + :param input_size: The number of expected features in the input. + :type input_size: int + + :param hidden_size: The number of features in the hidden state. + :type hidden_size: int + + :param num_layers: Number of recurrent layers. E.g., setting num_layers=2 would mean stacking two GRUs together to form a stacked GRU, with the second GRU taking in outputs of the first GRU and computing the final results. Default: 1 + :type num_layers: int, optinal + + :param bias: If False, then the layer does not use bias weights b_ih and b_hh. Default: True. + :type bias: bool, optional + + :param batch_first: If True, then the input and output tensors are provided as (batch, seq, feature). Default: False + :type bias: bool, optional + + :param dropout: If non-zero, introduces a Dropout layer on the outputs of each GRU layer except the last layer, with dropout probability equal to dropout. Default: 0 + :type dropout: float, optional + + :param bidirectional: If True, becomes a bidirectional GRU. 
Default: False + :type bidirectional: bool, optional + + Example: + >>> rnn = nn.GRU(10, 20, 2) + >>> input = jt.randn(5, 3, 10) + >>> h0 = jt.randn(2, 3, 20) + >>> output, hn = rnn(input, h0) + ''' + def __init__(self, input_size: int, hidden_size: int, num_layers: int = 1, bias: bool = True, batch_first: bool = False, dropout: float = 0, bidirectional: bool = False) -> None: - ''' Applies a multi-layer gated recurrent unit (GRU) RNN to an input sequence. - - :param input_size: The number of expected features in the input. - :type input_size: int - - :param hidden_size: The number of features in the hidden state. - :type hidden_size: int - - :param num_layers: Number of recurrent layers. E.g., setting num_layers=2 would mean stacking two GRUs together to form a stacked GRU, with the second GRU taking in outputs of the first GRU and computing the final results. Default: 1 - :type num_layers: int, optinal - - :param bias: If False, then the layer does not use bias weights b_ih and b_hh. Default: True. - :type bias: bool, optional - - :param batch_first: If True, then the input and output tensors are provided as (batch, seq, feature). Default: False - :type bias: bool, optional - - :param dropout: If non-zero, introduces a Dropout layer on the outputs of each GRU layer except the last layer, with dropout probability equal to dropout. Default: 0 - :type dropout: float, optional - - :param bidirectional: If True, becomes a bidirectional GRU. Default: False - :type bidirectional: bool, optional - - Example: - >>> rnn = nn.GRU(10, 20, 2) - >>> input = jt.randn(5, 3, 10) - >>> h0 = jt.randn(2, 3, 20) - >>> output, hn = rnn(input, h0) - ''' super().__init__('GRU', input_size, hidden_size, num_layers=num_layers, bias=bias, batch_first=batch_first, dropout=dropout, bidirectional=bidirectional) diff --git a/python/jittor/src/ops/reduce_op.cc b/python/jittor/src/ops/reduce_op.cc index 703c9f50..dc056af5 100644 --- a/python/jittor/src/ops/reduce_op.cc +++ b/python/jittor/src/ops/reduce_op.cc @@ -36,7 +36,7 @@ unordered_set reduce_ops = { * [in] dim: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). - * [in] keepdim: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. ---------------- @@ -67,7 +67,7 @@ unordered_set reduce_ops = { * [in] dim: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). - * [in] keepdim: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. ---------------- @@ -98,7 +98,7 @@ unordered_set reduce_ops = { * [in] dim: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). - * [in] keepdim: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. ---------------- @@ -129,7 +129,7 @@ unordered_set reduce_ops = { * [in] dim: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). - * [in] keepdim: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. 
---------------- @@ -160,7 +160,7 @@ unordered_set reduce_ops = { * [in] dim: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). - * [in] keepdim: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. ---------------- @@ -191,7 +191,7 @@ unordered_set reduce_ops = { * [in] dim: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). - * [in] keepdim: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. ---------------- @@ -226,7 +226,7 @@ unordered_set reduce_ops = { * [in] dim: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). - * [in] keepdim: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. ---------------- diff --git a/python/jittor/src/ops/reindex_reduce_op.h b/python/jittor/src/ops/reindex_reduce_op.h index 37d7e376..41d29197 100644 --- a/python/jittor/src/ops/reindex_reduce_op.h +++ b/python/jittor/src/ops/reindex_reduce_op.h @@ -56,7 +56,7 @@ struct ReindexReduceOp : Op { * [in] shape: the output shape, a integer array - * [in] indexes: array of c++ style integer expression, its length should be the same with length of shape, some buildin variables it can use are:: + * [in] indexes: array of c++ style integer expression, its length should be the same with length of output shape, some buildin variables it can use are:: XDIM, xshape0, ..., xshapem, xstride0, ..., xstridem YDIM, yshape0, ..., yshapen, ystride0, ..., ystriden diff --git a/python/jittor/test/test_rnn.py b/python/jittor/test/test_rnn.py index 51e03668..9cc37fb6 100644 --- a/python/jittor/test/test_rnn.py +++ b/python/jittor/test/test_rnn.py @@ -7,15 +7,8 @@ # file 'LICENSE.txt', which is part of this source code package. 
# *************************************************************** import unittest -import jittor as jt -import jittor.nn as nn -import numpy as np - - -skip_this_test = False - +from unittest.case import skipIf try: - jt.dirty_fix_pytorch_runtime_error() import torch import torch.nn as tnn except: @@ -23,72 +16,49 @@ except: tnn = None skip_this_test = True +import jittor as jt +import jittor.nn as nn +import numpy as np -def check_equal_1(t_rnn, j_rnn, input, h0): +skip_this_test = False + +def check_equal_1(t_rnn, j_rnn, input, h0, dev=None): j_rnn.load_state_dict(t_rnn.state_dict()) - t_output, th = t_rnn(torch.from_numpy(input), torch.from_numpy(h0)) + if dev: + t_output, th = t_rnn(torch.from_numpy(input).to(dev), torch.from_numpy(h0).to(dev)) + + else: + t_output, th = t_rnn(torch.from_numpy(input), torch.from_numpy(h0)) + t_output = t_output.detach().cpu().numpy() + th = th.detach().cpu().numpy() j_output, jh = j_rnn(jt.float32(input), jt.float32(h0)) + j_output, jh = j_output.data, jh.data - assert np.allclose(t_output.detach().numpy(), j_output.data, rtol=1e-03, atol=1e-06) - assert np.allclose(th.detach().numpy(), jh.data, rtol=1e-03, atol=1e-06) + np.testing.assert_allclose(t_output, j_output.data, rtol=1e-03, atol=1e-06) + np.testing.assert_allclose(th, jh.data, rtol=1e-03, atol=1e-06) - -def check_equal_2(t_rnn, j_rnn, input, h0, c0): +def check_equal_2(t_rnn, j_rnn, input, h0, c0, dev=None): j_rnn.load_state_dict(t_rnn.state_dict()) - t_output, (th, tc) = t_rnn(torch.from_numpy(input), - (torch.from_numpy(h0), torch.from_numpy(c0))) + if dev: + t_output, (th, tc) = t_rnn(torch.from_numpy(input).to(dev), + (torch.from_numpy(h0).to(dev), torch.from_numpy(c0).to(dev))) + else: + t_output, (th, tc) = t_rnn(torch.from_numpy(input).to(dev), + (torch.from_numpy(h0), torch.from_numpy(c0))) j_output, (jh, jc) = j_rnn(jt.float32(input), (jt.float32(h0), jt.float32(c0))) - assert np.allclose(t_output.detach().numpy(), j_output.data, rtol=1e-03, atol=1e-06) - assert np.allclose(th.detach().numpy(), jh.data, rtol=1e-03, atol=1e-06) - assert np.allclose(tc.detach().numpy(), jc.data, rtol=1e-03, atol=1e-06) + np.testing.assert_allclose(t_output.detach().cpu().numpy(), j_output.data, rtol=1e-03, atol=1e-06) + np.testing.assert_allclose(th.detach().cpu().numpy(), jh.data, rtol=1e-03, atol=1e-06) + np.testing.assert_allclose(tc.detach().cpu().numpy(), jc.data, rtol=1e-03, atol=1e-06) @unittest.skipIf(skip_this_test, "No Torch found") class TestRNN(unittest.TestCase): - def test_rnn(self): - h0 = np.random.rand(1, 24, 200).astype(np.float32) - input = np.random.rand(32, 24, 100).astype(np.float32) - - t_rnn = tnn.RNN(100, 200) - j_rnn = nn.RNN(100, 200) - check_equal_1(t_rnn, j_rnn, input, h0) - - h0 = np.random.rand(4, 4, 200).astype(np.float32) - input = np.random.rand(5, 4, 100).astype(np.float32) - - t_rnn = tnn.RNN(100, 200, num_layers=4) - j_rnn = nn.RNN(100, 200, num_layers=4) - check_equal_1(t_rnn, j_rnn, input, h0) - - h0 = np.random.rand(2, 1, 200).astype(np.float32) - input = np.random.rand(5, 1, 100).astype(np.float32) - - t_rnn = tnn.RNN(100, 200, bidirectional=True) - j_rnn = nn.RNN(100, 200, bidirectional=True) - check_equal_1(t_rnn, j_rnn, input, h0) - - h0 = np.random.rand(4, 4, 200).astype(np.float32) - input = np.random.rand(5, 4, 100).astype(np.float32) - - t_rnn = tnn.RNN(100, 200, num_layers=2, bidirectional=True, bias=False) - j_rnn = nn.RNN(100, 200, num_layers=2, bidirectional=True, bias=False) - check_equal_1(t_rnn, j_rnn, input, h0) - - h0 = np.random.rand(2, 4, 
200).astype(np.float32) - input = np.random.rand(5, 4, 100).astype(np.float32) - - t_rnn = tnn.RNN(100, 200, num_layers=2, dropout=0.5, bias=False) - j_rnn = nn.RNN(100, 200, num_layers=2, dropout=0.5, bias=False) - t_rnn.eval() - j_rnn.eval() - check_equal_1(t_rnn, j_rnn, input, h0) - def test_lstm_cell(self): np_h0 = torch.randn(3, 20).numpy() np_c0 = torch.randn(3, 20).numpy() @@ -173,7 +143,49 @@ class TestRNN(unittest.TestCase): j_output = j_output.data assert np.allclose(t_output, j_output, rtol=1e-03, atol=1e-06) - def test_lstm(self): + def test_basic_rnn(self): + h0 = np.random.rand(1, 24, 200).astype(np.float32) + input = np.random.rand(32, 24, 100).astype(np.float32) + + t_rnn = tnn.RNN(100, 200) + j_rnn = nn.RNN(100, 200) + check_equal_1(t_rnn, j_rnn, input, h0) + + def test_multilayer_rnn(self): + h0 = np.random.rand(4, 4, 200).astype(np.float32) + input = np.random.rand(5, 4, 100).astype(np.float32) + + t_rnn = tnn.RNN(100, 200, num_layers=4) + j_rnn = nn.RNN(100, 200, num_layers=4) + check_equal_1(t_rnn, j_rnn, input, h0) + + def test_bidirectional_rnn(self): + h0 = np.random.rand(2, 1, 200).astype(np.float32) + input = np.random.rand(5, 1, 100).astype(np.float32) + + t_rnn = tnn.RNN(100, 200, bidirectional=True) + j_rnn = nn.RNN(100, 200, bidirectional=True) + check_equal_1(t_rnn, j_rnn, input, h0) + + def test_no_bias_rnn(self): + h0 = np.random.rand(4, 4, 200).astype(np.float32) + input = np.random.rand(5, 4, 100).astype(np.float32) + + t_rnn = tnn.RNN(100, 200, num_layers=2, bidirectional=True, bias=False) + j_rnn = nn.RNN(100, 200, num_layers=2, bidirectional=True, bias=False) + check_equal_1(t_rnn, j_rnn, input, h0) + + def test_dropout_rnn(self): + h0 = np.random.rand(2, 4, 200).astype(np.float32) + input = np.random.rand(5, 4, 100).astype(np.float32) + + t_rnn = tnn.RNN(100, 200, num_layers=2, dropout=0.5, bias=False) + j_rnn = nn.RNN(100, 200, num_layers=2, dropout=0.5, bias=False) + t_rnn.eval() + j_rnn.eval() + check_equal_1(t_rnn, j_rnn, input, h0) + + def test_basic_lstm(self): h0 = np.random.rand(1, 24, 200).astype(np.float32) c0 = np.random.rand(1, 24, 200).astype(np.float32) input = np.random.rand(32, 24, 100).astype(np.float32) @@ -182,6 +194,7 @@ class TestRNN(unittest.TestCase): j_rnn = nn.LSTM(100, 200) check_equal_2(t_rnn, j_rnn, input, h0, c0) + def test_projection_lstm(self): proj_size = 13 h0 = np.random.rand(1, 24, proj_size).astype(np.float32) c0 = np.random.rand(1, 24, 200).astype(np.float32) @@ -190,6 +203,7 @@ class TestRNN(unittest.TestCase): j_rnn = nn.LSTM(100, 200, proj_size=proj_size) check_equal_2(t_rnn, j_rnn, input, h0, c0) + def test_multilayer_lstm(self): h0 = np.random.rand(4, 4, 200).astype(np.float32) c0 = np.random.rand(4, 4, 200).astype(np.float32) input = np.random.rand(5, 4, 100).astype(np.float32) @@ -198,6 +212,8 @@ class TestRNN(unittest.TestCase): j_rnn = nn.LSTM(100, 200, num_layers=4) check_equal_2(t_rnn, j_rnn, input, h0, c0) + def test_multilayer_projection_lstm(self): + proj_size = 8 h0 = np.random.rand(2, 4, proj_size).astype(np.float32) c0 = np.random.rand(2, 4, 20).astype(np.float32) input = np.random.rand(5, 4, 10).astype(np.float32) @@ -206,6 +222,7 @@ class TestRNN(unittest.TestCase): j_rnn = nn.LSTM(10, 20, num_layers=2, proj_size=proj_size) check_equal_2(t_rnn, j_rnn, input, h0, c0) + def test_bidirectional_lstm(self): h0 = np.random.rand(2, 1, 200).astype(np.float32) c0 = np.random.rand(2, 1, 200).astype(np.float32) input = np.random.rand(5, 1, 100).astype(np.float32) @@ -214,7 +231,8 @@ class 
TestRNN(unittest.TestCase): j_rnn = nn.LSTM(100, 200, bidirectional=True) check_equal_2(t_rnn, j_rnn, input, h0, c0) - proj_size = 13 + def test_bidirectional_projection_lstm(self): + proj_size = 10 h0 = np.random.rand(2, 4, proj_size).astype(np.float32) c0 = np.random.rand(2, 4, 200).astype(np.float32) input = np.random.rand(5, 4, 100).astype(np.float32) @@ -223,6 +241,7 @@ class TestRNN(unittest.TestCase): j_rnn = nn.LSTM(100, 200, bidirectional=True, proj_size=proj_size) check_equal_2(t_rnn, j_rnn, input, h0, c0) + def test_multilayer_bidirectional_projection_lstm(self): h0 = np.random.rand(4, 4, 200).astype(np.float32) c0 = np.random.rand(4, 4, 200).astype(np.float32) input = np.random.rand(5, 4, 100).astype(np.float32) @@ -231,7 +250,7 @@ class TestRNN(unittest.TestCase): j_rnn = nn.LSTM(100, 200, num_layers=2, bidirectional=True, bias=False) check_equal_2(t_rnn, j_rnn, input, h0, c0) - + def test_dropout_lstm(self): h0 = np.random.rand(2, 4, 200).astype(np.float32) c0 = np.random.rand(2, 4, 200).astype(np.float32) input = np.random.rand(5, 4, 100).astype(np.float32) @@ -242,7 +261,7 @@ class TestRNN(unittest.TestCase): j_rnn.eval() check_equal_2(t_rnn, j_rnn, input, h0, c0) - def test_gru(self): + def test_basic_gru(self): h0 = np.random.rand(1, 24, 200).astype(np.float32) input = np.random.rand(32, 24, 100).astype(np.float32) @@ -250,6 +269,7 @@ class TestRNN(unittest.TestCase): j_rnn = nn.GRU(100, 200) check_equal_1(t_rnn, j_rnn, input, h0) + def test_multilayer_gru(self): h0 = np.random.rand(4, 4, 200).astype(np.float32) input = np.random.rand(5, 4, 100).astype(np.float32) @@ -257,6 +277,7 @@ class TestRNN(unittest.TestCase): j_rnn = nn.GRU(100, 200, num_layers=4) check_equal_1(t_rnn, j_rnn, input, h0) + def test_bidirectional_gru(self): h0 = np.random.rand(2, 1, 200).astype(np.float32) input = np.random.rand(5, 1, 100).astype(np.float32) @@ -264,6 +285,7 @@ class TestRNN(unittest.TestCase): j_rnn = nn.GRU(100, 200, bidirectional=True) check_equal_1(t_rnn, j_rnn, input, h0) + def test_multilayer_bidirectional_gru(self): h0 = np.random.rand(4, 4, 200).astype(np.float32) input = np.random.rand(5, 4, 100).astype(np.float32) @@ -271,6 +293,7 @@ class TestRNN(unittest.TestCase): j_rnn = nn.GRU(100, 200, num_layers=2, bidirectional=True, bias=False) check_equal_1(t_rnn, j_rnn, input, h0) + def test_multilayer_dropout_gru(self): h0 = np.random.rand(2, 4, 200).astype(np.float32) input = np.random.rand(5, 4, 100).astype(np.float32) @@ -280,6 +303,330 @@ class TestRNN(unittest.TestCase): j_rnn.eval() check_equal_1(t_rnn, j_rnn, input, h0) + def test_rnn_default_hx(self): + input = np.random.rand(32, 24, 12).astype(np.float32) + h0 = np.zeros((1, 24, 24)).astype(np.float32) + + t_rnn = tnn.RNN(12, 24) + j_rnn = nn.RNN(12, 24) + j_rnn.load_state_dict(t_rnn.state_dict()) + t_output, th = t_rnn(torch.from_numpy(input)) + j_output, jh = j_rnn(jt.array(input)) + + np.testing.assert_allclose(t_output.detach().cpu().numpy(), j_output.data, rtol=1e-03, atol=1e-06) + np.testing.assert_allclose(th.detach().cpu().numpy(), jh.data, rtol=1e-03, atol=1e-06) + + def test_lstm_default_hx(self): + input = np.random.rand(32, 24, 10).astype(np.float32) + t_rnn = tnn.LSTM(10, 20, num_layers=2, bidirectional=True) + j_rnn = nn.LSTM(10, 20, num_layers=2, bidirectional=True) + j_rnn.load_state_dict(t_rnn.state_dict()) + t_output, (th, tc) = t_rnn(torch.from_numpy(input)) + j_output, (jh, jc) = j_rnn(jt.array(input)) + np.testing.assert_allclose(t_output.detach().cpu().numpy(), j_output.data, rtol=1e-03, 
+        np.testing.assert_allclose(th.detach().cpu().numpy(), jh.data, rtol=1e-03, atol=1e-06)
+        np.testing.assert_allclose(tc.detach().cpu().numpy(), jc.data, rtol=1e-03, atol=1e-06)
+
+    @unittest.skipIf(not jt.has_cuda, "No Cuda found")
+    @jt.flag_scope(use_cuda=1)
+    def test_cudnn_rnn(self):
+        dev = torch.device('cuda:0')
+        t_rnn = tnn.RNN(100, 200, nonlinearity='relu').to(dev)
+
+        j_rnn = nn.RNN(100, 200, nonlinearity='relu')
+        j_rnn.train()
+        j_rnn.load_state_dict(t_rnn.state_dict())
+
+        h0 = np.random.rand(1, 24, 200).astype(np.float32)
+        input = np.random.rand(32, 24, 100).astype(np.float32)
+
+        t_output, th = t_rnn(torch.from_numpy(input).to(dev),
+                             torch.from_numpy(h0).to(dev))
+
+        j_output, jh = j_rnn(jt.array(input), jt.array(h0))
+
+        np.testing.assert_allclose(j_output.data, t_output.detach().cpu().numpy())
+        np.testing.assert_allclose(jh.data, th.detach().cpu().numpy())
+
+    @unittest.skipIf(not jt.has_cuda, "No Cuda found")
+    @jt.flag_scope(use_cuda=1)
+    def test_cudnn_rnn_train_loose_tol(self):
+        dev = torch.device('cuda:0')
+        t_rnn = tnn.RNN(32, 64, nonlinearity='relu').to(dev)
+        t_optim = torch.optim.SGD(t_rnn.parameters(), lr=1e-3, momentum=0.9)
+
+        j_rnn = nn.RNN(32, 64, nonlinearity='relu')
+        j_rnn.load_state_dict(t_rnn.state_dict())
+        j_optim = nn.SGD(j_rnn.parameters(), lr=1e-3, momentum=0.9)
+
+        h0 = np.random.rand(1, 4, 64).astype(np.float32)
+        input = np.random.rand(12, 4, 32).astype(np.float32)
+
+        for _ in range(10):
+            t_optim.zero_grad()
+            t_output, th = t_rnn(torch.from_numpy(input).to(dev), torch.from_numpy(h0).to(dev))
+            t_loss = (t_output ** 2).sum() + (th ** 2).sum()
+            t_loss.backward()
+            t_optim.step()
+
+            j_input, jh = jt.array(input), jt.array(h0)
+            j_output, jh = j_rnn(j_input, jh)
+            j_loss = (j_output ** 2).sum() + (jh ** 2).sum()
+            j_optim.step(j_loss)
+
+        np.testing.assert_allclose(t_loss.item(), j_loss.item(), rtol=1e-2)
+        np.testing.assert_allclose(t_rnn.bias_hh_l0.detach().cpu().numpy(), j_rnn.bias_hh_l0.data, atol=1e-3, rtol=1e-2)
+
+    @unittest.skipIf(not jt.has_cuda, "No Cuda found")
+    @unittest.skipIf(not jt.cudnn, "No Cudnn found")
+    @jt.flag_scope(use_cuda=1)
+    def test_basic_cudnn_rnn(self):
+        dev = torch.device('cuda:0')
+        t_rnn = tnn.RNN(100, 200, nonlinearity='relu').to(dev)
+        j_rnn = nn.RNN(100, 200, nonlinearity='relu')
+
+        h0 = np.random.rand(1, 24, 200).astype(np.float32)
+        input = np.random.rand(32, 24, 100).astype(np.float32)
+        check_equal_1(t_rnn, j_rnn, input, h0, dev)
+
+    @unittest.skipIf(not jt.has_cuda, "No Cuda found")
+    @unittest.skipIf(not jt.cudnn, "No Cudnn found")
+    @jt.flag_scope(use_cuda=1)
+    def test_multilayer_cudnn_rnn(self):
+        dev = torch.device('cuda:0')
+        t_rnn = tnn.RNN(100, 200, num_layers=4, nonlinearity='tanh').to(dev)
+        j_rnn = nn.RNN(100, 200, num_layers=4, nonlinearity='tanh')
+
+        h0 = np.random.rand(4, 8, 200).astype(np.float32)
+        input = np.random.rand(5, 8, 100).astype(np.float32)
+        check_equal_1(t_rnn, j_rnn, input, h0, dev)
+
+    @unittest.skipIf(not jt.has_cuda, "No Cuda found")
+    @unittest.skipIf(not jt.cudnn, "No Cudnn found")
+    @jt.flag_scope(use_cuda=1)
+    def test_bidirectional_cudnn_rnn(self):
+        dev = torch.device('cuda:0')
+        t_rnn = tnn.RNN(100, 200, bidirectional=True, nonlinearity='tanh').to(dev)
+        j_rnn = nn.RNN(100, 200, bidirectional=True, nonlinearity='tanh')
+
+        h0 = np.random.rand(2, 8, 200).astype(np.float32)
+        input = np.random.rand(5, 8, 100).astype(np.float32)
+        check_equal_1(t_rnn, j_rnn, input, h0, dev)
+
+    @unittest.skipIf(not jt.has_cuda, "No Cuda found")
+    @unittest.skipIf(not jt.cudnn, "No Cudnn found")
found") + @jt.flag_scope(use_cuda=1) + def test_no_bias_cudnn_rnn(self): + dev = torch.device('cuda:0') + t_rnn = tnn.RNN(100, 200, bidirectional=True, bias=False, nonlinearity='tanh').to(dev) + j_rnn = nn.RNN(100, 200, bidirectional=True, bias=False, nonlinearity='tanh') + + h0 = np.random.rand(2, 8, 200).astype(np.float32) + input = np.random.rand(5, 8, 100).astype(np.float32) + check_equal_1(t_rnn, j_rnn, input, h0, dev) + + @unittest.skipIf(not jt.has_cuda, "No Cuda found") + @unittest.skipIf(not jt.cudnn, "No Cudnn found") + @jt.flag_scope(use_cuda=1) + def test_dropout_cudnn_rnn(self): + dev = torch.device('cuda:0') + t_rnn = tnn.RNN(100, 200, num_layers=2, dropout=0.5, nonlinearity='tanh').to(dev) + j_rnn = nn.RNN(100, 200, num_layers=2, dropout=0.5, nonlinearity='tanh') + t_rnn.eval() + j_rnn.eval() + + h0 = np.random.rand(2, 8, 200).astype(np.float32) + input = np.random.rand(5, 8, 100).astype(np.float32) + check_equal_1(t_rnn, j_rnn, input, h0, dev) + + @unittest.skipIf(not jt.has_cuda, "No Cuda found") + @unittest.skipIf(not jt.cudnn, "No Cudnn found") + @jt.flag_scope(use_cuda=1) + def test_basic_lstm_rnn(self): + dev = torch.device('cuda:0') + t_rnn = tnn.LSTM(100, 200).to(dev) + j_rnn = nn.LSTM(100, 200) + + h0 = np.random.rand(1, 24, 200).astype(np.float32) + c0 = np.random.rand(1, 24, 200).astype(np.float32) + input = np.random.rand(32, 24, 100).astype(np.float32) + check_equal_2(t_rnn, j_rnn, input, h0, c0, dev) + + @unittest.skipIf(not jt.has_cuda, "No Cuda found") + @unittest.skipIf(not jt.cudnn, "No Cudnn found") + @jt.flag_scope(use_cuda=1) + def test_cudnn_rnn_train(self): + dev = torch.device('cuda:0') + t_rnn = tnn.RNN(32, 64, nonlinearity='relu').to(dev) + t_optim = torch.optim.SGD(t_rnn.parameters(), lr=1e-3, momentum=0.9) + + j_rnn = nn.RNN(32, 64, nonlinearity='relu') + j_rnn.load_state_dict(t_rnn.state_dict()) + j_optim = nn.SGD(j_rnn.parameters(), lr=1e-3, momentum=0.9) + + h0 = np.random.rand(1, 4, 64).astype(np.float32) + input = np.random.rand(12, 4, 32).astype(np.float32) + + for _ in range(10): + t_optim.zero_grad() + t_output, th = t_rnn(torch.from_numpy(input).to(dev), torch.from_numpy(h0).to(dev)) + t_loss = (t_output ** 2).sum() + (th ** 2).sum() + t_loss.backward() + t_optim.step() + + j_input, jh = jt.array(input), jt.array(h0) + j_output, jh = j_rnn(j_input, jh) + j_loss = (j_output ** 2).sum() + (jh ** 2).sum() + j_optim.step(j_loss) + + np.testing.assert_allclose(t_loss.item(), j_loss.item(), rtol=1e-4) + np.testing.assert_allclose(t_rnn.bias_hh_l0.detach().cpu().numpy(), j_rnn.bias_hh_l0.data, atol=1e-4, rtol=1e-4) + + @unittest.skipIf(not jt.has_cuda, "No Cuda found") + @unittest.skipIf(not jt.cudnn, "No Cudnn found") + @jt.flag_scope(use_cuda=1) + def test_cudnn_gru_train(self): + dev = torch.device('cuda:0') + t_rnn = tnn.GRU(32, 64).to(dev) + t_optim = torch.optim.SGD(t_rnn.parameters(), lr=1e-3, momentum=0.9) + + j_rnn = nn.GRU(32, 64) + j_rnn.load_state_dict(t_rnn.state_dict()) + j_optim = nn.SGD(j_rnn.parameters(), lr=1e-3, momentum=0.9) + + h0 = np.random.rand(1, 4, 64).astype(np.float32) + input = np.random.rand(12, 4, 32).astype(np.float32) + + for _ in range(10): + t_optim.zero_grad() + t_output, th = t_rnn(torch.from_numpy(input).to(dev), torch.from_numpy(h0).to(dev)) + t_loss = (t_output ** 2).sum() + (th ** 2).sum() + t_loss.backward() + t_optim.step() + + j_input, jh = jt.array(input), jt.array(h0) + j_output, jh = j_rnn(j_input, jh) + j_loss = (j_output ** 2).sum() + (jh ** 2).sum() + j_optim.step(j_loss) + + 
np.testing.assert_allclose(t_loss.item(), j_loss.item(), rtol=1e-4) + np.testing.assert_allclose(t_rnn.bias_hh_l0.detach().cpu().numpy(), j_rnn.bias_hh_l0.data, atol=1e-4, rtol=1e-4) + + @unittest.skipIf(not jt.has_cuda, "No Cuda found") + @unittest.skipIf(not jt.cudnn, "No Cudnn found") + @jt.flag_scope(use_cuda=1) + def test_cudnn_lstm_train(self): + dev = torch.device('cuda:0') + t_rnn = tnn.LSTM(32, 64).to(dev) + t_optim = torch.optim.SGD(t_rnn.parameters(), lr=1e-3, momentum=0.9) + + j_rnn = nn.LSTM(32, 64) + j_rnn.load_state_dict(t_rnn.state_dict()) + j_optim = nn.SGD(j_rnn.parameters(), lr=1e-3, momentum=0.9) + + h0 = np.random.rand(1, 4, 64).astype(np.float32) + c0 = np.random.rand(1, 4, 64).astype(np.float32) + input = np.random.rand(12, 4, 32).astype(np.float32) + + for _ in range(10): + t_optim.zero_grad() + t_output, (th, tc) = t_rnn(torch.from_numpy(input).to(dev), + (torch.from_numpy(h0).to(dev), torch.from_numpy(c0).to(dev))) + t_loss = (t_output ** 2).sum() + (th ** 2).sum() + (tc ** 2).sum() + t_loss.backward() + t_optim.step() + + j_input, jh0, jc0 = jt.array(input), jt.array(h0), jt.array(c0) + j_output, (jh, jc) = j_rnn(j_input, (jh0, jc0)) + j_loss = (j_output ** 2).sum() + (jh ** 2).sum() + (jc ** 2).sum() + j_optim.step(j_loss) + + np.testing.assert_allclose(t_loss.item(), j_loss.item(), rtol=1e-4) + np.testing.assert_allclose(t_rnn.bias_hh_l0.detach().cpu().numpy(), j_rnn.bias_hh_l0.data, atol=1e-4, rtol=1e-4) + + @unittest.skipIf(not jt.has_cuda, "No Cuda found") + @unittest.skipIf(not jt.cudnn, "No Cudnn found") + @jt.flag_scope(use_cuda=1) + def test_multilayer_bidirectional_cudnn_lstm_train(self): + dev = torch.device('cuda:0') + t_rnn = tnn.LSTM(32, 64, num_layers=4, bidirectional=True).to(dev) + t_optim = torch.optim.SGD(t_rnn.parameters(), lr=1e-3, momentum=0.9) + + j_rnn = nn.LSTM(32, 64, num_layers=4, bidirectional=True) + j_rnn.load_state_dict(t_rnn.state_dict()) + j_optim = nn.SGD(j_rnn.parameters(), lr=1e-3, momentum=0.9) + + h0 = np.random.rand(8, 4, 64).astype(np.float32) + c0 = np.random.rand(8, 4, 64).astype(np.float32) + input = np.random.rand(12, 4, 32).astype(np.float32) + + for _ in range(10): + t_optim.zero_grad() + t_output, (th, tc) = t_rnn(torch.from_numpy(input).to(dev), + (torch.from_numpy(h0).to(dev), torch.from_numpy(c0).to(dev))) + t_loss = (t_output ** 2).sum() + (th ** 2).sum() + (tc ** 2).sum() + t_loss.backward() + t_optim.step() + + j_input, jh0, jc0 = jt.array(input), jt.array(h0), jt.array(c0) + j_output, (jh, jc) = j_rnn(j_input, (jh0, jc0)) + j_loss = (j_output ** 2).sum() + (jh ** 2).sum() + (jc ** 2).sum() + j_optim.step(j_loss) + + np.testing.assert_allclose(t_loss.item(), j_loss.item(), rtol=1e-4) + np.testing.assert_allclose(t_rnn.bias_hh_l0.detach().cpu().numpy(), j_rnn.bias_hh_l0.data, atol=1e-4, rtol=1e-4) + + @unittest.skipIf(not jt.has_cuda, "No Cuda found") + @jt.flag_scope(use_cuda=1) + def test_cudnn_rnn_speed(self): + from time import time + iters = 100 + + h0 = np.random.rand(1, 128, 256).astype(np.float32) + input = np.random.rand(128, 128, 128).astype(np.float32) + + dev = torch.device('cuda:0') + t_rnn = tnn.RNN(128, 256, nonlinearity='relu').to(dev) + t_optim = torch.optim.SGD(t_rnn.parameters(), lr=1e-3, momentum=0.9) + + t_input = torch.from_numpy(input).to(dev) + t_h0 = torch.from_numpy(h0).to(dev) + + start_time = time() + for i in range(iters): + t_optim.zero_grad() + t_output, th = t_rnn(t_input, t_h0) + t_loss = (t_output ** 2).sum() + (th ** 2).sum() + t_loss.backward() + t_optim.step() + print('torch 
time = ', time() - start_time) + + j_rnn = nn.RNN(128, 256, nonlinearity='relu') + j_rnn.load_state_dict(t_rnn.state_dict()) + j_optim = nn.SGD(j_rnn.parameters(), lr=1e-3, momentum=0.9) + j_input, j_h0 = jt.array(input), jt.array(h0) + + start_time = time() + for i in range(iters): + j_output, jh = j_rnn(j_input, j_h0) + j_loss = (j_output ** 2).sum() + (jh ** 2).sum() + j_optim.step(j_loss) + jt.sync_all(True) + print('jittor Cudnn time = ', time() - start_time) + + jt_cudnn, jt.cudnn = jt.cudnn, None + j_rnn = nn.RNN(128, 256, nonlinearity='relu') + j_rnn.load_state_dict(t_rnn.state_dict()) + j_optim = nn.SGD(j_rnn.parameters(), lr=1e-3, momentum=0.9) + start_time = time() + for i in range(iters): + j_output, jh = j_rnn(j_input, j_h0) + j_loss = (j_output ** 2).sum() + (jh ** 2).sum() + j_optim.step(j_loss) + jt.sync_all(True) + print('jittor native time = ', time() - start_time) + jt.cudnn = jt_cudnn + if __name__ == "__main__": unittest.main() \ No newline at end of file
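
Note (not part of the patch): the tests above call `check_equal_1` and `check_equal_2` helpers that are defined earlier in the test file, outside this diff. The following is only a minimal sketch of what `check_equal_1` presumably does, inferred from its call sites and from the inline comparison in `test_rnn_default_hx`; the exact signature, the optional `dev` argument handling, and the tolerances are assumptions, not the actual implementation.

# Hypothetical sketch of the comparison helper used by the tests above.
import numpy as np
import torch
import jittor as jt

def check_equal_1(t_rnn, j_rnn, input, h0, dev=None):
    # Give both implementations identical weights before comparing.
    j_rnn.load_state_dict(t_rnn.state_dict())

    t_input, t_h0 = torch.from_numpy(input), torch.from_numpy(h0)
    if dev is not None:  # the CUDA tests pass a torch device here (assumption)
        t_input, t_h0 = t_input.to(dev), t_h0.to(dev)

    # Run the same sequence through torch and jittor and compare outputs.
    t_output, th = t_rnn(t_input, t_h0)
    j_output, jh = j_rnn(jt.array(input), jt.array(h0))

    np.testing.assert_allclose(t_output.detach().cpu().numpy(), j_output.data, rtol=1e-03, atol=1e-06)
    np.testing.assert_allclose(th.detach().cpu().numpy(), jh.data, rtol=1e-03, atol=1e-06)

`check_equal_2` presumably follows the same pattern for LSTM modules, additionally passing the cell state `c0` and comparing the `(h, c)` tuple returned by both frameworks.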