feats: defaults to keep the first dim in nn.Flatten

This commit is contained in:
lzhengning 2022-09-29 22:14:29 +08:00
parent 601101ea44
commit b2fb32aa52
1 changed files with 17 additions and 1 deletions

View File

@ -780,7 +780,23 @@ LeakyReLU = Leaky_relu
ReLU6 = jt.make_module(relu6)
Softmax = jt.make_module(softmax, 2)
GELU = jt.make_module(gelu)
Flatten = jt.make_module(jt.flatten)
class Flatten(Module):
''' Flattens the contiguous range of dimensions in a Var.
:param start_dim: the first dimension to be flattened. Defaults: 1.
:type start_dim: int
:param end_dim: the last dimension to be flattened. Defaults: -1.
:type end_dim: int
'''
def __init__(self, start_dim=1, end_dim=-1):
self.start_dim = start_dim
self.end_dim = end_dim
def execute(self, x) -> jt.Var:
return x.flatten(self.start_dim, self.end_dim)
from jittor.depthwise_conv import DepthwiseConv