mirror of https://github.com/Jittor/Jittor
feats: polish rnnbase
This commit is contained in:
parent
026dfb8fa2
commit
b0669c11c0
|
@ -11,6 +11,7 @@
|
|||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
from abc import abstractmethod
|
||||
import jittor as jt
|
||||
from jittor import init, Module
|
||||
import numpy as np
|
||||
|
@ -1537,6 +1538,10 @@ class RNNBase(Module):
|
|||
build_unit(f'bias_ih_l{layer}', gate_size)
|
||||
build_unit(f'bias_hh_l{layer}', gate_size)
|
||||
|
||||
@abstractmethod
|
||||
def call_rnn_cell(self, input, hidden, suffix):
|
||||
pass
|
||||
|
||||
def call_rnn_sequence(self, input, hidden, suffix):
|
||||
if 'reverse' in suffix:
|
||||
input = input[::-1]
|
||||
|
|
Loading…
Reference in New Issue