mirror of https://github.com/Jittor/Jittor
change GroupNorm
This commit is contained in:
parent
80cebbb389
commit
d66606a205
|
@ -359,16 +359,14 @@ class GroupNorm(Module):
|
|||
N,C,H,W = x.shape
|
||||
assert C == self.num_channels
|
||||
assert C % self.num_groups == 0
|
||||
x_ = x.reindex([N, int(C/self.num_groups), self.num_groups, H, W], [
|
||||
"i0", f"i2*{C/self.num_groups}+i1", "i3", "i4"
|
||||
])
|
||||
xmean = jt.mean(x_, dims=[1,3,4], keepdims=1).reindex(x.shape, ["i0", "0", f"i1/({C}/{self.num_groups})","0", "0"])
|
||||
x2mean = jt.mean(x_*x_, dims=[1,3,4], keepdims=1).reindex(x.shape, ["i0", "0", f"i1/({C}/{self.num_groups})","0", "0"])
|
||||
x = x.reshape((N, self.num_groups, int(C/self.num_groups), H*W))
|
||||
xmean = jt.mean(x, dims=[2,3], keepdims=1)
|
||||
x2mean = jt.mean(x*x, dims=[2,3], keepdims=1)
|
||||
xvar = jt.maximum(x2mean-xmean*xmean, 0)
|
||||
norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)
|
||||
w = self.weight.broadcast(x, [0,2,3])
|
||||
b = self.bias.broadcast(x, [0,2,3])
|
||||
return norm_x * w + b
|
||||
w = self.weight.reshape((1,self.num_groups,C//self.num_groups,1))
|
||||
b = self.bias.reshape((1,self.num_groups,C//self.num_groups,1))
|
||||
return (norm_x * w + b).reshape((N,C,H,W))
|
||||
|
||||
Relu = jt.make_module(relu)
|
||||
ReLU = Relu
|
||||
|
|
|
@ -141,6 +141,18 @@ pjmap = {
|
|||
'extras': {'affine': 'None'},
|
||||
'delete': ['track_running_stats'],
|
||||
},
|
||||
'GroupNorm': {
|
||||
'pytorch': {
|
||||
'args': "num_groups, num_channels, eps=1e-05, momentum=0.1, affine=True"
|
||||
},
|
||||
'jittor': {
|
||||
'module': 'nn',
|
||||
'name': 'GroupNorm',
|
||||
'args': 'num_groups, num_channels, eps=1e-05, affine=None, is_train=True',
|
||||
},
|
||||
'links': {},
|
||||
'extras': {'affine': 'None'},
|
||||
},
|
||||
'Dropout2d': {
|
||||
'pytorch': {
|
||||
'args': 'p=0.5, inplace=False',
|
||||
|
@ -349,13 +361,12 @@ unsupport_ops = [
|
|||
'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d', 'AdaptiveAvgPool1d', 'AdaptiveAvgPool3d',
|
||||
'ReflectionPad1d', 'ReplicationPad1d', 'ReplicationPad3d', 'ConstantPad1d', 'ConstantPad3d',
|
||||
'ELU', 'Hardshrink', 'Hardtanh', 'LogSigmoid', 'MultiheadAttention',
|
||||
'RReLU', 'SELU', 'CELU', 'GELU', 'Softplus', 'Softshrink', 'Softsign', 'Tanhshrink',
|
||||
'RReLU', 'SELU', 'CELU', 'GELU', 'Softshrink', 'Softsign', 'Tanhshrink',
|
||||
'Threshold', 'Softmin', 'Softmax2d', 'LogSoftmax', 'AdaptiveLogSoftmaxWithLoss',
|
||||
'BatchNorm3d', 'GroupNorm', 'SyncBatchNorm', 'InstanceNorm1d', 'InstanceNorm3d', 'LocalResponseNorm',
|
||||
'BatchNorm3d', 'SyncBatchNorm', 'InstanceNorm1d', 'InstanceNorm3d', 'LocalResponseNorm',
|
||||
'RNNBase', 'RNN', 'LSTM', 'GRU', 'RNNCell', 'LSTMCell', 'GRUCell', 'Transformer', 'TransformerEncoder',
|
||||
'TransformerDecoder', 'TransformerEncoderLayer', 'TransformerDecoderLayer', 'Identity', 'Bilinear',
|
||||
'Dropout3d', 'AlphaDropout', 'EmbeddingBag', 'CosineSimilarity', 'PairwiseDistance', 'L1Loss',
|
||||
'MSELoss', 'CTCLoss', 'NLLLoss', 'PoissonNLLLoss', 'KLDivLoss', 'BCELoss', 'BCEWithLogitsLoss',
|
||||
'Dropout3d', 'AlphaDropout', 'EmbeddingBag', 'CosineSimilarity', 'PairwiseDistance', 'CTCLoss', 'NLLLoss', 'PoissonNLLLoss', 'KLDivLoss', 'BCEWithLogitsLoss',
|
||||
'MarginRankingLoss', 'HingeEmbeddingLoss', 'MultiLabelMarginLoss', 'SmoothL1Loss', 'SoftMarginLoss',
|
||||
'MultiLabelSoftMarginLoss', 'CosineEmbeddingLoss', 'MultiMarginLoss', 'TripletMarginLoss', 'UpsamplingNearest2d',
|
||||
'UpsamplingBilinear2d', 'DataParallel', 'DistributedDataParallel', 'clip_grad_norm_', 'clip_grad_value_',
|
||||
|
|
Loading…
Reference in New Issue