mirror of https://github.com/Jittor/Jittor
feats: defaults to keep the first dim in nn.Flatten
This commit is contained in:
parent
601101ea44
commit
b2fb32aa52
|
@ -780,7 +780,23 @@ LeakyReLU = Leaky_relu
|
|||
ReLU6 = jt.make_module(relu6)
|
||||
Softmax = jt.make_module(softmax, 2)
|
||||
GELU = jt.make_module(gelu)
|
||||
Flatten = jt.make_module(jt.flatten)
|
||||
|
||||
class Flatten(Module):
|
||||
''' Flattens the contiguous range of dimensions in a Var.
|
||||
|
||||
:param start_dim: the first dimension to be flattened. Defaults: 1.
|
||||
:type start_dim: int
|
||||
|
||||
:param end_dim: the last dimension to be flattened. Defaults: -1.
|
||||
:type end_dim: int
|
||||
'''
|
||||
def __init__(self, start_dim=1, end_dim=-1):
|
||||
self.start_dim = start_dim
|
||||
self.end_dim = end_dim
|
||||
|
||||
def execute(self, x) -> jt.Var:
|
||||
return x.flatten(self.start_dim, self.end_dim)
|
||||
|
||||
|
||||
from jittor.depthwise_conv import DepthwiseConv
|
||||
|
||||
|
|
Loading…
Reference in New Issue