change GroupNorm

This commit is contained in:
zhouwy19 2020-08-11 16:36:08 +08:00
parent 80cebbb389
commit d66606a205
2 changed files with 21 additions and 12 deletions

View File

@ -359,16 +359,14 @@ class GroupNorm(Module):
N,C,H,W = x.shape
assert C == self.num_channels
assert C % self.num_groups == 0
x_ = x.reindex([N, int(C/self.num_groups), self.num_groups, H, W], [
"i0", f"i2*{C/self.num_groups}+i1", "i3", "i4"
])
xmean = jt.mean(x_, dims=[1,3,4], keepdims=1).reindex(x.shape, ["i0", "0", f"i1/({C}/{self.num_groups})","0", "0"])
x2mean = jt.mean(x_*x_, dims=[1,3,4], keepdims=1).reindex(x.shape, ["i0", "0", f"i1/({C}/{self.num_groups})","0", "0"])
x = x.reshape((N, self.num_groups, int(C/self.num_groups), H*W))
xmean = jt.mean(x, dims=[2,3], keepdims=1)
x2mean = jt.mean(x*x, dims=[2,3], keepdims=1)
xvar = jt.maximum(x2mean-xmean*xmean, 0)
norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)
w = self.weight.broadcast(x, [0,2,3])
b = self.bias.broadcast(x, [0,2,3])
return norm_x * w + b
w = self.weight.reshape((1,self.num_groups,C//self.num_groups,1))
b = self.bias.reshape((1,self.num_groups,C//self.num_groups,1))
return (norm_x * w + b).reshape((N,C,H,W))
Relu = jt.make_module(relu)
ReLU = Relu

View File

@ -141,6 +141,18 @@ pjmap = {
'extras': {'affine': 'None'},
'delete': ['track_running_stats'],
},
'GroupNorm': {
'pytorch': {
'args': "num_groups, num_channels, eps=1e-05, momentum=0.1, affine=True"
},
'jittor': {
'module': 'nn',
'name': 'GroupNorm',
'args': 'num_groups, num_channels, eps=1e-05, affine=None, is_train=True',
},
'links': {},
'extras': {'affine': 'None'},
},
'Dropout2d': {
'pytorch': {
'args': 'p=0.5, inplace=False',
@ -349,13 +361,12 @@ unsupport_ops = [
'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d', 'AdaptiveAvgPool1d', 'AdaptiveAvgPool3d',
'ReflectionPad1d', 'ReplicationPad1d', 'ReplicationPad3d', 'ConstantPad1d', 'ConstantPad3d',
'ELU', 'Hardshrink', 'Hardtanh', 'LogSigmoid', 'MultiheadAttention',
'RReLU', 'SELU', 'CELU', 'GELU', 'Softplus', 'Softshrink', 'Softsign', 'Tanhshrink',
'RReLU', 'SELU', 'CELU', 'GELU', 'Softshrink', 'Softsign', 'Tanhshrink',
'Threshold', 'Softmin', 'Softmax2d', 'LogSoftmax', 'AdaptiveLogSoftmaxWithLoss',
'BatchNorm3d', 'GroupNorm', 'SyncBatchNorm', 'InstanceNorm1d', 'InstanceNorm3d', 'LocalResponseNorm',
'BatchNorm3d', 'SyncBatchNorm', 'InstanceNorm1d', 'InstanceNorm3d', 'LocalResponseNorm',
'RNNBase', 'RNN', 'LSTM', 'GRU', 'RNNCell', 'LSTMCell', 'GRUCell', 'Transformer', 'TransformerEncoder',
'TransformerDecoder', 'TransformerEncoderLayer', 'TransformerDecoderLayer', 'Identity', 'Bilinear',
'Dropout3d', 'AlphaDropout', 'EmbeddingBag', 'CosineSimilarity', 'PairwiseDistance', 'L1Loss',
'MSELoss', 'CTCLoss', 'NLLLoss', 'PoissonNLLLoss', 'KLDivLoss', 'BCELoss', 'BCEWithLogitsLoss',
'Dropout3d', 'AlphaDropout', 'EmbeddingBag', 'CosineSimilarity', 'PairwiseDistance', 'CTCLoss', 'NLLLoss', 'PoissonNLLLoss', 'KLDivLoss', 'BCEWithLogitsLoss',
'MarginRankingLoss', 'HingeEmbeddingLoss', 'MultiLabelMarginLoss', 'SmoothL1Loss', 'SoftMarginLoss',
'MultiLabelSoftMarginLoss', 'CosineEmbeddingLoss', 'MultiMarginLoss', 'TripletMarginLoss', 'UpsamplingNearest2d',
'UpsamplingBilinear2d', 'DataParallel', 'DistributedDataParallel', 'clip_grad_norm_', 'clip_grad_value_',