feats: polish rnnbase

This commit is contained in:
lzhengning 2021-04-29 15:05:39 +08:00
parent 026dfb8fa2
commit b0669c11c0
1 changed files with 5 additions and 0 deletions

View File

@ -11,6 +11,7 @@
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
from abc import abstractmethod
import jittor as jt
from jittor import init, Module
import numpy as np
@ -1537,6 +1538,10 @@ class RNNBase(Module):
build_unit(f'bias_ih_l{layer}', gate_size)
build_unit(f'bias_hh_l{layer}', gate_size)
@abstractmethod
def call_rnn_cell(self, input, hidden, suffix):
pass
def call_rnn_sequence(self, input, hidden, suffix):
if 'reverse' in suffix:
input = input[::-1]