# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Define the grad rules of neural network related operations."""
import os
import numpy as np
from mindspore.ops import _selected_grad_ops as SG
from mindspore.ops.primitive import constexpr
from mindspore.common.tensor import Tensor
from mindspore.ops.operations import nn_ops as nps
from .grad_base import bprop_getters
from .. import functional as F
from .. import operations as P
from ...common import dtype as mstype
from ..composite.multitype_ops.zeros_like_impl import zeros_like
from ..operations import _grad_ops as G
from ..operations import _inner_ops as inner
from ... import context
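# When ENV_FORCE_BPROP_SEQ is set to '1', the convolution bprops below insert an
# F.depend edge from dx to x so that the filter gradient is computed only after
# the input gradient, forcing the two backward kernels to run sequentially.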
env_force_bprop_seq = os.getenv("ENV_FORCE_BPROP_SEQ")


@bprop_getters.register(P.BiasAdd)
def get_bprop_bias_add(self):
    """Grad definition for `BiasAdd` operation."""
    bias_grad = SG.BiasAddGrad(self.data_format)

    def bprop(x, w, out, dout):
        return dout, bias_grad(dout)

    return bprop


@bprop_getters.register(P.Conv2D)
def get_bprop_conv2d(self):
    """Grad definition for `Conv2D` operation."""
    input_grad = P.Conv2DBackpropInput(
        self.out_channel, self.kernel_size, self.pad_mode, self.pad, self.pad_list, mode=self.mode,
        dilation=self.dilation, stride=self.stride, group=self.group, data_format=self.format
    )
    filter_grad = G.Conv2DBackpropFilter(
        self.out_channel, self.kernel_size, self.pad_mode, self.pad, self.pad_list, mode=self.mode,
        dilation=self.dilation, stride=self.stride, group=self.group, data_format=self.format
    )
    get_shape = P.Shape()

    def bprop(x, w, out, dout):
        dx = input_grad(dout, w, get_shape(x))
        if env_force_bprop_seq == '1':
            x = F.depend(x, dx)
        dw = filter_grad(dout, x, get_shape(w))
        return dx, dw

    return bprop


@bprop_getters.register(nps.Conv3D)
def get_bprop_conv3d(self):
    """Grad definition for `Conv3D` operation."""
    input_grad = nps.Conv3DBackpropInput(
        self.out_channel, self.kernel_size, self.mode, pad_mode=self.pad_mode,
        pad=self.pad, stride=self.stride, dilation=self.dilation, group=self.group, data_format=self.data_format
    )
    filter_grad = G.Conv3DBackpropFilter(
        self.out_channel, self.kernel_size, self.mode, pad_mode=self.pad_mode,
        pad=self.pad, stride=self.stride, dilation=self.dilation, group=self.group, data_format=self.data_format
    )
    get_shape = P.Shape()

    def bprop(x, w, out, dout):
        dx = input_grad(w, dout, get_shape(x))
        dw = filter_grad(x, dout, get_shape(w))
        return dx, dw

    return bprop


@bprop_getters.register(nps.Conv3DTranspose)
def get_bprop_conv3d_transpose(self):
    """Grad definition for `Conv3DTranspose` operation."""
    input_grad = nps.Conv3D(
        out_channel=self.in_channel, kernel_size=self.kernel_size, mode=self.mode, pad_mode="pad",
        pad=self.pad, stride=self.stride, dilation=self.dilation, group=self.group, data_format=self.data_format
    )
    filter_grad = G.Conv3DBackpropFilter(
        out_channel=self.in_channel, kernel_size=self.kernel_size, mode=self.mode, pad_mode="pad",
        pad=self.pad, stride=self.stride, dilation=self.dilation, group=self.group, data_format=self.data_format
    )
    input_size = self.input_size

    def bprop(x, w, out, dout):
        dx = input_grad(dout, w)
        dw = filter_grad(dout, x, F.shape(w))
        return dx, dw, zeros_like(input_size)

    return bprop


@bprop_getters.register(inner.ExtractImagePatches)
def get_bprop_extract_image_patches(self):
    """Grad definition for `ExtractImagePatches` operation."""
    get_shape = P.Shape()
    reshape = P.Reshape()
    extract_image_patches = inner.ExtractImagePatches(ksizes=self.ksizes,
                                                      strides=self.strides,
                                                      rates=self.rates,
                                                      padding=self.padding)
    concat = P.Concat(axis=-1)
    expand_dims = P.ExpandDims()
    scatter_nd = P.ScatterNd()
    dtype = P.DType()
    fill = P.Fill()
    slice_op = P.Slice()
    transpose = P.Transpose()
    cast = P.Cast()
    matmul = P.MatMul()

    _, _, ksizes_row, ksizes_col = self.ksizes

    def bprop(x, out, dout):
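        # The gradient is routed back through an explicit (input position x patch
        # element) 0/1 matrix: every input pixel gets a unique index, the same patch
        # extraction is applied to that index map, and ScatterNd builds the sparse
        # mapping so a single MatMul can accumulate dout back into dx.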
        x_shape = get_shape(x)
        x_batch, x_depth, x_row, x_col = x_shape
        x_indices_num = x_row * x_col + 1
        x_idx = cast(F.tuple_to_array(range(1, x_indices_num)), mstype.float32)
        x_idx = reshape(x_idx, (1, 1, x_row, x_col))
        x_idx_patch = cast(extract_image_patches(x_idx), mstype.int32)
        x_idx_patch = transpose(x_idx_patch, (0, 2, 3, 1))

        out_shape = get_shape(out)
        _, _, out_row, out_col = out_shape
        out_indices_num = out_row * out_col * ksizes_row * ksizes_col
        out_idx = F.tuple_to_array(range(out_indices_num))
        out_idx = reshape(out_idx, (1, out_row, out_col, ksizes_row * ksizes_col))

        idx_tensor = concat((expand_dims(x_idx_patch, -1), expand_dims(out_idx, -1)))
        idx_tensor = reshape(idx_tensor, (-1, 2))
        sp_shape = (x_indices_num, out_indices_num)
        sp_tensor = scatter_nd(idx_tensor, fill(dtype(dout), (out_indices_num,), 1), sp_shape)
        sp_tensor = slice_op(sp_tensor, (1, 0), (x_indices_num - 1, out_indices_num))

        grad = transpose(dout, (0, 2, 3, 1))
        grad = reshape(grad, (x_batch, out_row, out_col, ksizes_row, ksizes_col, x_depth))
        grad = transpose(grad, (1, 2, 3, 4, 0, 5))
        grad = reshape(grad, (-1, x_batch * x_depth))

        jac = matmul(sp_tensor, grad)
        dx = reshape(jac, (x_row, x_col, x_batch, x_depth))
        dx = transpose(dx, (2, 3, 0, 1))
        return (dx,)

    return bprop


@bprop_getters.register(P.DepthwiseConv2dNative)
def get_bprop_depthwise_conv2d_native(self):
    """Grad definition for `DepthwiseConv2dNative` operation."""
    input_grad = G.DepthwiseConv2dNativeBackpropInput(
        self.channel_multiplier, self.kernel_size, self.pad_mode, self.pad, self.pad_list, self.mode, self.stride,
        self.dilation, self.group
    )
    filter_grad = G.DepthwiseConv2dNativeBackpropFilter(
        self.channel_multiplier, self.kernel_size, self.pad_mode, self.pad, self.pad_list, self.mode, self.stride,
        self.dilation, self.group
    )
    get_shape = P.Shape()

    def bprop(x, w, out, dout):
        dx = input_grad(get_shape(x), w, dout)
        if env_force_bprop_seq == '1':
            x = F.depend(x, dx)
        dw = filter_grad(x, get_shape(w), dout)
        return dx, dw

    return bprop


@bprop_getters.register(P.MaxPoolWithArgmax)
def get_bprop_max_pool_with_argmax(self):
    """Grad definition for `MaxPoolWithArgmax` operation."""
    maxpool_grad = G.MaxPoolGradWithArgmax(
        kernel_size=self.kernel_size,
        strides=self.strides,
        pad_mode=self.pad_mode)

    def bprop(x, out, dout):
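        # out is the (values, argmax) pair from the forward op and dout mirrors it;
        # the gradient flows only to the argmax positions, so the grad kernel takes
        # dout[0] together with the saved indices out[1].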
        dx = maxpool_grad(x, dout[0], out[1])
        return (dx,)

    return bprop


@bprop_getters.register(G.MaxPoolGrad)
def get_bprop_max_pool_grad_grad(self):
    """Grad definition for `MaxPoolGrad` operation."""
    maxpool_grad_grad = G.MaxPoolGradGrad(
        kernel_size=self.kernel_size,
        strides=self.strides,
        pad_mode=self.pad_mode)

    def bprop(x1, x2, grad, out, dout):
        dx1 = zeros_like(x1)
        dx2 = zeros_like(x2)
        dgrad = maxpool_grad_grad(x1, x2, dout)
        return (dx1, dx2, dgrad)

    return bprop


@bprop_getters.register(G.MaxPoolGradGrad)
def get_bprop_max_pool_grad_grad_grad(self):
    """Grad definition for `MaxPoolGradGrad` operation."""
    maxpool_grad = G.MaxPoolGrad(
        kernel_size=self.kernel_size,
        strides=self.strides,
        pad_mode=self.pad_mode)

    def bprop(x1, x2, grad, out, dout):
        dx1 = zeros_like(x1)
        dx2 = zeros_like(x2)
        dgrad = maxpool_grad(x1, x2, dout)
        return (dx1, dx2, dgrad)

    return bprop


@bprop_getters.register(P.MaxPool)
def get_bprop_max_pool_grad(self):
    """Grad definition for `MaxPool` operation."""
    maxpool_grad = G.MaxPoolGrad(
        kernel_size=self.kernel_size,
        strides=self.strides,
        pad_mode=self.pad_mode,
        data_format=self.format)

    def bprop(x, out, dout):
        dx = maxpool_grad(x, out, dout)
        return (dx,)

    return bprop


def _windowed_output_size(input_size, ksize, stride, pad_mode):
    """
    helper func for AvgPoolGrad
    """
    tmp_output = 0
    tmp_pad_need = 0
    tmp_pad_before = 0
    tmp_pad_after = 0
    if pad_mode == 'VALID':
        tmp_output = (input_size - ksize + stride) // stride
        tmp_pad_before = 0
        tmp_pad_after = 0
    elif pad_mode == 'SAME':
        tmp_output = (input_size + stride - 1) // stride
        tmp_pad_need = max(0, (tmp_output - 1) * stride + ksize - input_size)
        tmp_pad_before = tmp_pad_need // 2
        tmp_pad_after = tmp_pad_need - tmp_pad_before
    return tmp_output, tmp_pad_before, tmp_pad_after


@constexpr
def _get_mean_matrix(x_shape, ksize, stride, pad_mode, x_dtype):
    """
    helper func for AvgPoolGrad.

    `assist_input_matrix` is a 2d matrix with input_shape after padding:
    the padded elements are 0 and all other elements are 1.
    Each output element maps to the slide window `[h*h_stride : h*h_stride + h_ksize,
    w*w_stride : w*w_stride + w_ksize]` of `assist_input_matrix`, so the sum over the slide window is the
    number of inputs associated with that output element.
    """

    n_input, c_input, h_input, w_input = x_shape
    h_ksize, w_ksize = ksize[2], ksize[3]
    h_stride, w_stride = stride[2], stride[3]
    n_output = n_input
    c_output = c_input
    h_output, w_output = 0, 0
    pad_top, pad_bottom, pad_left, pad_right = 0, 0, 0, 0
    h_output, pad_top, pad_bottom = _windowed_output_size(h_input, h_ksize,
                                                          h_stride, pad_mode)
    w_output, pad_left, pad_right = _windowed_output_size(w_input, w_ksize,
                                                          w_stride, pad_mode)

    output_size = n_output * c_output * h_output * w_output
    output_shape = (n_output, c_output, h_output, w_output)
    output = np.array([0.0] * output_size)
    output = np.reshape(output, output_shape)

    in_shape_after_padding_2d = (h_input + pad_top + pad_bottom, w_input + pad_left + pad_right)
    assist_input_matrix = np.ones(in_shape_after_padding_2d).astype(np.float32)
    if pad_top > 0:
        assist_input_matrix[:pad_top, :] = 0
    if pad_bottom > 0:
        assist_input_matrix[-pad_bottom:, :] = 0
    if pad_left > 0:
        assist_input_matrix[:, :pad_left] = 0
    if pad_right > 0:
        assist_input_matrix[:, -pad_right:] = 0

    for h in range(h_output):
        for w in range(w_output):
            curr_input = assist_input_matrix[h * h_stride: h * h_stride + h_ksize, w * w_stride: w * w_stride + w_ksize]
            curr_sum = np.sum(curr_input)
            if curr_sum > 0:
                output[:, :, h, w] = 1. / curr_sum
    return Tensor(output, x_dtype)


@constexpr
def _get_kernel_matrix(x_shape_nchw, kernel_matrix_shape, pad_mode, x_dtype):
    kernel_matrix = np.ones(kernel_matrix_shape)
    return Tensor(kernel_matrix, x_dtype)


@bprop_getters.register(P.AvgPool)
def get_bprop_avg_pool_grad(self):
    """Grad definition for `AvgPool` operation."""

    # The parameters of AvgPoolGrad on GPU and TBE/CPU are not the same.
    if self.target == "GPU":
        avgpool_grad_gpu = G.AvgPoolGradGpu(
            kernel_size=self.kernel_size,
            strides=self.strides,
            pad_mode=self.pad_mode,
            data_format=self.format)

        def bprop_gpu(x, out, dout):
            dx = avgpool_grad_gpu(x, out, dout)
            return (dx,)

        bprop_fn = bprop_gpu

    elif self.target == "CPU":
        avgpool_grad_cpu = G.AvgPoolGradCpu(
            kernel_size=self.kernel_size,
            strides=self.strides,
            pad_mode=self.pad_mode,
            data_format=self.format)

        def bprop_cpu(x, out, dout):
            dx = avgpool_grad_cpu(x, out, dout)
            return (dx,)

        bprop_fn = bprop_cpu

    elif self.target == "GE":
        avgpool_grad_ge = G.AvgPoolGrad(
            kernel_size=self.kernel_size,
            strides=self.strides,
            pad_mode=self.pad_mode)
        shape_op = P.Shape()

        def bprop_ge(x, out, dout):
            dx = avgpool_grad_ge(shape_op(x), dout)
            return (dx,)

        bprop_fn = bprop_ge

    else:
        avgpool_grad_vm = G.AvgPoolGradVm(
            kernel_size=self.kernel_size,
            strides=self.strides,
            pad_mode=self.pad_mode)
        k_size_nchw = avgpool_grad_vm.kernel_size
        stride_nchw = avgpool_grad_vm.strides
        pad_mode = self.pad_mode

        def bprop_vm(x, out, dout):
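            # mean_matrix holds 1 / (number of valid inputs) for every output cell
            # and kernel_matrix is an all-ones (1, C, kh, kw) tensor; AvgPoolGradVm
            # uses them to spread the averaged gradient back over x_shape_nchw.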
            x_shape_nchw = F.shape(x)
            x_dtype = F.dtype(x)
            kernel_matrix_shape = (1, x_shape_nchw[1],
                                   k_size_nchw[2],
                                   k_size_nchw[3])
            mean_matrix = _get_mean_matrix(x_shape_nchw, k_size_nchw, stride_nchw, pad_mode, x_dtype)
            kernel_matrix = _get_kernel_matrix(x_shape_nchw, kernel_matrix_shape, pad_mode, x_dtype)
            dx = avgpool_grad_vm(x_shape_nchw, dout, mean_matrix, kernel_matrix)
            return (dx,)

        bprop_fn = bprop_vm

    return bprop_fn


@bprop_getters.register(P.DropoutGenMask)
def get_bprop_dropout_gen_mask(self):
    """Grad definition for `DropoutGenMask` operation."""

    def bprop(shape, keep_prob, out, dout):
        return (zeros_like(shape), zeros_like(keep_prob))

    return bprop


@bprop_getters.register(P.DropoutDoMask)
def get_bprop_dropout_do_mask(self):
    """Grad definition for `DropoutDoMask` operation."""
    do_mask = P.DropoutDoMask()

    def bprop(x, y, keep_prob, out, dout):
        return (do_mask(dout, y, keep_prob), zeros_like(y), zeros_like(keep_prob))

    return bprop


@bprop_getters.register(P.Mish)
def get_bprop_mish(self):
    """Grad definition for `Mish` operation."""
    tanh = P.Tanh()
    tanh_grad = SG.TanhGrad()
    softplus = P.Softplus()
    softplus_grad = G.SoftplusGrad()

    def bprop(x, out, dout):
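        # mish(x) = x * tanh(softplus(x)), so by the product and chain rules
        # d(mish)/dx = tanh(softplus(x)) + x * (1 - tanh(softplus(x))**2) * sigmoid(x);
        # dx1 and dx2 below are those two terms built from the existing grad kernels.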
        dx1 = tanh(softplus(x))
        dx2 = softplus_grad(tanh_grad(dx1, x * dout), x)
        dx = (dx1 * dout + dx2)
        return (dx,)

    return bprop


@bprop_getters.register(P.SeLU)
def get_bprop_selu(self):
    """Grad definition for `SeLU` operation."""
    scale = 1.0507009873554804934193349852946
    elu_grad = G.EluGrad()

    def bprop(x, out, dout):
        dx = elu_grad(dout, out) * scale
        return (dx,)

    return bprop


@bprop_getters.register(P.MulNoNan)
def get_bprop_mul_no_nan(self):
    """Grad definition for `MulNoNan` operation."""
    mul_no_nan = P.MulNoNan()
    reduce_sum = P.ReduceSum()
    reshape = P.Reshape()

    def bprop(x, y, out, dout):
        x_shape = F.shape(x)
        y_shape = F.shape(y)
        dx = mul_no_nan(dout, y)
        dy = mul_no_nan(x, dout)
        broadcast_x, broadcast_y = F.broadcast_gradient_args(x_shape, y_shape)
        if broadcast_x != ():
            dx = reshape(reduce_sum(dx, broadcast_x), x_shape)
        if broadcast_y != ():
            dy = reshape(reduce_sum(dy, broadcast_y), y_shape)
        return dx, dy

    return bprop


@bprop_getters.register(P.ReLU)
def get_bprop_relu(self):
    """Grad definition for `ReLU` operation."""
    input_grad = G.ReluGrad()

    def bprop(x, out, dout):
        dx = input_grad(dout, out)
        return (dx,)

    return bprop


@bprop_getters.register(G.ReluGrad)
def get_bprop_relu_grad(self):
    """Grad definition for `ReLUGrad` operation."""
    input_grad = G.ReluGrad()

    def bprop(grad, y, out, dout):
        dgrad = input_grad(dout, y)
        return dgrad, zeros_like(y)

    return bprop


@bprop_getters.register(P.ReLU6)
def get_bprop_relu6(self):
    """Grad definition for `ReLU6` operation."""
    input_grad = G.ReLU6Grad()

    def bprop(x, out, dout):
        dx = input_grad(dout, x)
        return (dx,)

    return bprop


@bprop_getters.register(P.ReLUV2)
def get_bprop_relu_v2(self):
    """Grad definition for `ReLUV2` operation."""
    input_grad = G.ReluGradV2()

    def bprop(x, out, dout):
        mask = out[1]
        dx = input_grad(dout[0], mask)
        return (dx,)

    return bprop


@bprop_getters.register(P.HSwish)
def get_bprop_hswish(self):
    """Grad definition for `HSwish` operation."""
    input_grad = G.HSwishGrad()

    def bprop(x, out, dout):
        dx = input_grad(dout, x)
        return (dx,)

    return bprop


@bprop_getters.register(P.HSigmoid)
def get_bprop_hsigmoid(self):
    """Grad definition for `HSigmoid` operation."""
    input_grad = G.HSigmoidGrad()

    def bprop(x, out, dout):
        dx = input_grad(dout, x)
        return (dx,)

    return bprop


@bprop_getters.register(P.Elu)
def get_bprop_elu(self):
    """Grad definition for `Elu` operation."""
    input_grad = G.EluGrad()

    def bprop(x, out, dout):
        dx = input_grad(dout, out)
        return (dx,)

    return bprop


@bprop_getters.register(P.Sigmoid)
def get_bprop_sigmoid(self):
    """Grad definition for `Sigmoid` operation."""
    input_grad = G.SigmoidGrad()

    def bprop(x, out, dout):
        dx = input_grad(out, dout)
        return (dx,)

    return bprop


@bprop_getters.register(G.SigmoidGrad)
def get_bprop_sigmoid_grad(self):
    """Grad definition for `SigmoidGrad` operation."""
    sigmoid_grad = G.SigmoidGrad()

    def bprop(y, grad, out, dout):
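        # SigmoidGrad(y, grad) computes grad * y * (1 - y); differentiating with
        # respect to y gives grad * (1 - 2 * y), and with respect to grad gives
        # y * (1 - y), which is sigmoid_grad(y, dout) below.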
        dy = dout * grad * (1. - 2 * y)
        dgrad = sigmoid_grad(y, dout)
        return dy, dgrad

    return bprop


@constexpr
def _get_transpose_axis(x_shp, axis):
    rank = len(x_shp)
    if axis < 0:
        axis += rank
    reverse_axis = [i for i in range(rank)]
    reverse_axis[axis] = rank - 1
    reverse_axis[rank - 1] = axis
    return tuple(reverse_axis)


@bprop_getters.register(P.Softmax)
def get_bprop_softmax(self):
    """Grad definition for `Softmax` operation."""
    sum_func = P.ReduceSum(keep_dims=True)
    sub = P.Sub()
    mul = P.Mul()
    get_shape = P.Shape()
    transpose = P.Transpose()
    axis = self.axis
    if not isinstance(axis, int):
        axis = axis[0]

    def bprop(x, out, dout):
        # dx = (dout - sum(dout * out)) * out
        # This formula is correct only when the `axis` is the last dimension.
        # In order to support the scenario where the `axis` is other values,
        # we transpose the data of the `axis` dimension to the last dimension for calculation,
        # and then transpose it back after the calculation.
        reverse_axis = _get_transpose_axis(get_shape(x), axis)
        out = transpose(out, reverse_axis)
        dout = transpose(dout, reverse_axis)
        dx = mul(out, sub(dout, sum_func(mul(out, dout), -1)))
        dx = transpose(dx, reverse_axis)
        return (dx,)

    return bprop


@bprop_getters.register(P.LogSoftmax)
def get_bprop_log_softmax(self):
    """Grad definition for `LogSoftmax` operation."""
    logsoftmax_grad = G.LogSoftmaxGrad(self.axis)

    def bprop(x, out, dout):
        dx = logsoftmax_grad(out, dout)
        return (dx,)

    return bprop


@bprop_getters.register(P.Softplus)
def get_bprop_softplus(self):
    """Grad definition for `Softplus` operation."""
    softplus_grad = G.SoftplusGrad()

    def bprop(x, out, dout):
        dx = softplus_grad(dout, x)
        return (dx,)

    return bprop


@bprop_getters.register(P.Softsign)
def get_bprop_softsign(self):
    """Grad definition for `Softsign` operation."""
    mul = P.Mul()
    absolute = P.Abs()
    div = P.Div()
    square = P.Square()

    def bprop(x, out, dout):
        dx = mul(dout, div(1, square(1 + absolute(x))))
        return (dx,)

    return bprop


@bprop_getters.register(P.Tanh)
def get_bprop_tanh(self):
    """Grad definition for `Tanh` operation."""
    tanh_grad = SG.TanhGrad()

    def bprop(x, out, dout):
        dx = tanh_grad(out, dout)
        return (dx,)

    return bprop


@bprop_getters.register(G.TanhGrad)
def get_bprop_tanh_grad(self):
    """Grad definition for `TanhGrad` operation."""
    tanh_grad = G.TanhGrad()

    def bprop(y, grad, out, dout):
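        # TanhGrad(y, grad) computes grad * (1 - y * y); differentiating with
        # respect to y gives -2 * y * grad, and with respect to grad gives
        # 1 - y * y, which is tanh_grad(y, dout) below.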
        dy = dout * -2.0 * grad * y
        dgrad = tanh_grad(y, dout)
        return dy, dgrad

    return bprop


@bprop_getters.register(P.GeLU)
def get_bprop_gelu(self):
    """Grad definition for `GeLU` operation."""
    input_grad = G.GeLUGrad()

    def bprop(x, out, dout):
        dx = input_grad(dout, x, out)
        return (dx,)

    return bprop


@bprop_getters.register(P.Gelu)
def get_bprop_gelu_2(self):
    """Grad definition for `GeLU` operation."""
    input_grad = G.GeLUGrad()

    def bprop(x, out, dout):
        dx = input_grad(dout, x, out)
        return (dx,)

    return bprop


@bprop_getters.register(P.FastGeLU)
def get_bprop_fast_gelu(self):
    """Grad definition for `FastGeLU` operation."""
    input_grad = G.FastGeLUGrad()

    def bprop(x, out, dout):
        dx = input_grad(dout, x)
        return (dx,)

    return bprop


@bprop_getters.register(P.FastGelu)
def get_bprop_fast_gelu_2(self):
    """Grad definition for `FastGeLU` operation."""
    input_grad = G.FastGeLUGrad()

    def bprop(x, out, dout):
        dx = input_grad(dout, x)
        return (dx,)

    return bprop


@bprop_getters.register(P.FusedBatchNorm)
def get_bprop_fused_batch_norm(self):
    """Grad definition for `FusedBatchNorm` operation."""
    input_grad = G.FusedBatchNormGrad(self.epsilon, self.momentum)
    target_cpu = False
    if self.target == "CPU":
        input_grad = G.FusedBatchNormGradCPU(self.epsilon, self.momentum)
        target_cpu = True

    def bprop(x, scale, b, mean, variance, out, dout):
        saved_mean = out[3]
        saved_variance = out[4]
        if target_cpu:
            out = input_grad(dout[0], x, scale, b, saved_mean, saved_variance)
        else:
            out = input_grad(dout[0], x, scale, saved_mean, saved_variance)
        dx = out[0]
        dscale = out[1]
        dbias = out[2]
        return dx, dscale, dbias, zeros_like(mean), zeros_like(variance)

    return bprop


@bprop_getters.register(P.FusedBatchNormEx)
def get_bprop_fused_batch_norm_ex(self):
    """Grad definition for `FusedBatchNormEx` operation."""
    input_grad = G.FusedBatchNormGradEx(self.epsilon, self.momentum, self.format)

    def bprop(x, scale, b, mean, variance, out, dout):
        saved_mean = out[3]
        saved_variance = out[4]
        reserve = out[5]
        out = input_grad(dout[0], x, scale, saved_mean, saved_variance, reserve)
        dx = out[0]
        dscale = out[1]
        dbias = out[2]
        return dx, dscale, dbias, zeros_like(mean), zeros_like(variance)

    return bprop


@bprop_getters.register(P.InstanceNorm)
def get_bprop_instance_norm(self):
    """Grad definition for `InstanceNorm` operation."""
    is_training = self.is_training
    input_grad = G.InstanceNormGrad(is_training, self.epsilon, self.momentum)

    def bprop(x, gamma, beta, mean, variance, out, dout):
        saved_mean = out[1]
        saved_variance = out[2]
        out = input_grad(dout[0], x, gamma, saved_mean, saved_variance)
        dx = out[0]
        dgamma = out[1]
        dbeta = out[2]
        return dx, dgamma, dbeta, zeros_like(mean), zeros_like(variance)

    return bprop


@bprop_getters.register(P.BatchNorm)
def get_bprop_batch_norm(self):
    """Grad definition for `BatchNorm` operation."""
    is_training = self.is_training
    input_grad = G.BatchNormGrad(is_training, self.epsilon)

    def bprop(x, scale, b, mean, variance, out, dout):
        if is_training:
            saved_reserve_1 = out[3]
            saved_reserve_2 = out[4]
        else:
            saved_reserve_1 = mean
            saved_reserve_2 = variance
        out = input_grad(dout[0], x, scale, saved_reserve_1, saved_reserve_2)
        dx = out[0]
        dscale = out[1]
        dbias = out[2]
        return dx, dscale, dbias, zeros_like(mean), zeros_like(variance)

    return bprop


@bprop_getters.register(P.LayerNorm)
def get_bprop_layer_norm(self):
    """Grad definition for `LayerNorm` operation."""
    layer_norm_grad = G.LayerNormGrad(self.begin_norm_axis, self.begin_params_axis)

    def bprop(x, gamma, beta, out, dout):
        dx, d_gamma, d_beta = layer_norm_grad(
            x, dout[0], out[2], out[1], gamma)
        return dx, d_gamma, d_beta

    return bprop


@bprop_getters.register(G.LayerNormGrad)
def get_bprop_layer_norm_grad(self):
    """Grad definition for `LayerNormGrad` operation."""
    layer_norm_grad_grad = G.LayerNormGradGrad(self.begin_norm_axis, self.begin_params_axis)

    def bprop(x, dy, variance, mean, gamma, out, dout):
        d_x, d_dy, d_gamma = layer_norm_grad_grad(
            x, dy, variance, mean, gamma, dout[0], dout[1], dout[2])
        return d_x, d_dy, zeros_like(variance), zeros_like(mean), d_gamma

    return bprop


@bprop_getters.register(P.L2Normalize)
def get_bprop_l2normalize(self):
    """Grad definition for `L2Normalize` operation."""
    input_grad = G.L2NormalizeGrad(self.axis, self.epsilon)

    def bprop(x, out, dout):
        dx = input_grad(x, out, dout)
        return (dx,)

    return bprop


@bprop_getters.register(P.SoftmaxCrossEntropyWithLogits)
def get_bprop_softmax_cross_entropy_with_logits(self):
    """Grad definition for `SoftmaxCrossEntropyWithLogits` operation."""
    expand = P.ExpandDims()

    def bprop(logits, labels, out, dout):
        grad = out[1]
        grad = grad * expand(dout[0], -1)
        return grad, zeros_like(labels)

    return bprop


@bprop_getters.register(P.NLLLoss)
def get_bprop_nll_loss(self):
    """Grad definition for `NLLLoss` operation."""
    nll_loss_grad = G.NLLLossGrad(reduction=self.reduction)

    def bprop(x, target, weight, out, dout):
        total_weight = out[1]
        dout_x = dout[0]
        dx = nll_loss_grad(x, dout_x, target, weight, total_weight)
        return dx, zeros_like(target), zeros_like(weight)

    return bprop


@bprop_getters.register(P.SparseSoftmaxCrossEntropyWithLogits)
def get_bprop_sparse_softmax_cross_entropy_with_logits(self):
    """Grad definition for `SparseSoftmaxCrossEntropyWithLogits` operation."""
    is_grad = self.is_grad
    grad_op = P.SparseSoftmaxCrossEntropyWithLogits(is_grad=True)

    def bprop(logits, labels, out, dout):
        grad = out[0]
        if not is_grad:
            # the forward op only computed the loss, so re-run it with
            # is_grad=True to obtain the gradient w.r.t. the logits
            grad = grad_op(logits, labels)
            grad = F.depend(grad, out)
        grad = grad * dout
        return grad, zeros_like(labels)

    return bprop


@bprop_getters.register(P.ResizeBilinear)
def get_bprop_resize_bilinear(self):
    """Grad definition for `ResizeBilinear` operation."""
    resize_grad = G.ResizeBilinearGrad(self.align_corners)

    def bprop(x, out, dout):
        dx = resize_grad(dout, x)
        return (dx,)

    return bprop


@bprop_getters.register(P.OneHot)
def get_bprop_onehot(self):
    """Grad definition for `OneHot` operation."""

    def bprop(indices, depth, on_value, off_value, out, dout):
        return zeros_like(indices), zeros_like(depth), zeros_like(on_value), zeros_like(off_value)

    return bprop


@constexpr
def _range_op(start, limit, delta, dtype):
    """helper function for Grad TopK"""
    output_tensor = Tensor(list(range(start, limit, delta)), dtype)
    return output_tensor


@constexpr
def _get_1d_shape(in_shape):
    """helper function for Grad TopK"""
    out_shape = 1
    for i in in_shape:
        out_shape *= i
    return (out_shape,)


@bprop_getters.register(P.TopK)
def get_bprop_top_kv2(self):
    """Grad definition for `TopK` operation."""
    scatter = P.ScatterNd()
    expand_dims = P.ExpandDims()
    shape_op = P.Shape()
    reshape_op = P.Reshape()
    dtype = P.DType()

    def bprop(input_x, k, out, dout):
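        # TopK returns per-row indices into the last dimension, so each row is
        # offset by row * in_lastdim to obtain indices into the flattened input;
        # ScatterNd then writes dout[0] back to those flat positions.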
        in_shape = shape_op(input_x)
        in_lastdim = in_shape[-1]

        indices = out[1]
        ind_shape = shape_op(indices)
        ind_lastdim = ind_shape[-1]

        ind_2d = reshape_op(indices, (-1, ind_lastdim))
        outerdim = shape_op(ind_2d)[0]

        # [0, in_lastdim, 2 * in_lastdim, ..., (outerdim - 1) * in_lastdim]
        indices_dtype = dtype(indices)
        range_flatten_index = _range_op(0, outerdim * in_lastdim, in_lastdim, indices_dtype)

        # expand_dims to (outerdim, 1), then broadcast-add to the per-row indices
        ind = reshape_op(ind_2d + expand_dims(range_flatten_index, -1), (-1,))
        in_shape_1d = _get_1d_shape(in_shape)

        out_grad = reshape_op(
            scatter(
                expand_dims(ind, -1),
                reshape_op(dout[0], (-1,)),
                in_shape_1d),
            in_shape)
        return out_grad, zeros_like(k)

    return bprop


@bprop_getters.register(P.SmoothL1Loss)
def get_bprop_smooth_l1_loss(self):
    """Grad definition for `SmoothL1Loss` operation."""
    grad = G.SmoothL1LossGrad(self.beta)

    def bprop(prediction, target, out, dout):
        dx = grad(prediction, target, dout)
        dy = grad(target, prediction, dout)
        return dx, dy

    return bprop


@bprop_getters.register(P.L2Loss)
def get_bprop_l2_loss(self):
    """Grad definition for `L2Loss` operation."""

    def bprop(x, out, dout):
        dx = x * dout
        return (dx,)

    return bprop


@bprop_getters.register(P.RNNTLoss)
def get_bprop_rnnt_loss(self):
    """Grad definition for `RNNTLoss` operation."""

    def bprop(acts, labels, act_lens, label_lens, out, dout):
        grad = out[1]
        return grad, zeros_like(labels), zeros_like(act_lens), zeros_like(label_lens)

    return bprop


@bprop_getters.register(P.PReLU)
def get_bprop_prelu(self):
    """Grad definition for `PReLU` operation."""
    grad = G.PReLUGrad()

    def bprop(x, w, out, dout):
        dx, dw = grad(dout, x, w)
        return dx, dw

    return bprop


@bprop_getters.register(P.LSTM)
def get_bprop_lstm(self):
    """Grad definition for `LSTM` operation."""
    lstm_grad_data = G.LSTMGradData(
        input_size=self.input_size,
        hidden_size=self.hidden_size,
        num_layers=self.num_layers,
        has_bias=self.has_bias,
        bidirectional=self.bidirectional,
        dropout=self.dropout
    )

    lstm_grad_weight = G.LSTMGradWeight(
        input_size=self.input_size,
        hidden_size=self.hidden_size,
        num_layers=self.num_layers,
        has_bias=self.has_bias,
        bidirectional=self.bidirectional,
        dropout=self.dropout
    )
    lstm_grad = G.LSTMGrad(
        input_size=self.input_size,
        hidden_size=self.hidden_size,
        num_layers=self.num_layers,
        has_bias=self.has_bias,
        bidirectional=self.bidirectional,
        dropout=self.dropout
    )

    def bprop(x, hx, cx, w, out, dout):
        y, _, _, reserve, state = out
        dy, dhy, dcy, _, _ = dout
        dx, dhx, dcx = lstm_grad_data(y, dy, dhy, dcy, w, hx, cx, reserve, state)
        dw = lstm_grad_weight(F.depend(x, dx), hx, y, reserve, state)
        return dx, dhx, dcx, dw

    # On CPU a single fused LSTMGrad kernel computes all gradients in one pass.
    def bprop_cpu(x, hx, cx, w, out, dout):
        y, hy, cy, reserve, _ = out
        dy, dhy, dcy, _, _ = dout
        dx, dhx, dcx, dw = lstm_grad(x, hx, cx, w, y, hy, cy, dy, dhy, dcy, reserve)
        return dx, dhx, dcx, dw

    if context.get_context('device_target') == "CPU":
        return bprop_cpu

    return bprop
@bprop_getters.register(P.DynamicRNN)
def get_bprop_dynamic_rnn(self):
    """Grad definition for `DynamicRNN` operation."""
    dynamic_rnn_grad = G.DynamicRNNGrad(cell_type=self.cell_type,
                                        direction=self.direction,
                                        cell_depth=self.cell_depth,
                                        use_peephole=self.use_peephole,
                                        keep_prob=self.keep_prob,
                                        cell_clip=self.cell_clip,
                                        num_proj=self.num_proj,
                                        time_major=self.time_major,
                                        forget_bias=self.forget_bias)
    expand_dims = P.ExpandDims()

    def bprop(x, w, b, seq_length, init_h, init_c, out, dout):
        dy, dh, dc, _, _, _, _, _ = dout
        dh = dh[-1]
        dc = dc[-1]
        y, h, c, i, j, f, o, tanhct = out
        dw, db, dx, dh_prev, dc_prev = dynamic_rnn_grad(x, w, b, y, init_h[0], init_c[0], h,
                                                        c, dy, dh, dc, i, j, f, o, tanhct)
        dh_prev = expand_dims(dh_prev, 0)
        dc_prev = expand_dims(dc_prev, 0)
        return dx, dw, db, (0), dh_prev, dc_prev

    return bprop


@bprop_getters.register(P.DynamicGRUV2)
def get_bprop_dynamic_gru_v2(self):
    """Grad definition for `DynamicGRUV2` operation."""
    dynamic_gru_v2_grad = G.DynamicGRUV2Grad(self.direction, self.cell_depth, self.keep_prob, self.cell_clip,
                                             self.num_proj, self.time_major, self.gate_order,
                                             self.reset_after)

    def bprop(x, winput, whidden, binput, bhidden, seq, init_h, out, dout):
        y, out_h, update, reset, new, hidden_new = out
        dy, dout_h, _, _, _, _ = dout

        dw_input, dw_hidden, db_input, db_hidden, dx, dh_prev = dynamic_gru_v2_grad(x, winput, whidden, y, init_h,
                                                                                    out_h, dy, dout_h[-1], update,
                                                                                    reset, new, hidden_new, None, None)
        return dx, dw_input, dw_hidden, db_input, db_hidden, (0), dh_prev

    return bprop


@bprop_getters.register(P.SigmoidCrossEntropyWithLogits)
def get_bprop_sigmoid_crossentropy_with_logits(self):
    """Grad definition for `SigmoidCrossEntropyWithLogits` operation."""
    op = G.SigmoidCrossEntropyWithLogitsGrad()

    def bprop(x, y, out, dout):
        dx = op(x, y, dout)
        return (dx, zeros_like(y))

    return bprop


@bprop_getters.register(P.Pad)
def get_bprop_pad(self):
    """Grad definition for `Pad` operation."""
    shape_op = P.Shape()
    paddings = self.paddings

    def bprop(x, out, dout):
        begin = ()
        for item in paddings:
            begin += (item[0],)
        shp = shape_op(x)
        dx = P.Slice()(dout, begin, shp)
        return (dx,)

    return bprop


@bprop_getters.register(P.MirrorPad)
def get_bprop_mirror_pad(self):
    """Grad definition for `MirrorPad` operation."""
    mirror_pad_grad = G.MirrorPadGrad(self.mode)

    def bprop(x, paddings, out, dout):
        dx = mirror_pad_grad(dout, paddings)
        return (dx, zeros_like(paddings))

    return bprop


@bprop_getters.register(P.ROIAlign)
def get_bprop_roi_align(self):
    """Grad definition for `ROIAlign` operation."""
    shape_op = P.Shape()
    pooled_height = self.pooled_height
    pooled_width = self.pooled_width
    spatial_scale = self.spatial_scale
    sample_num = self.sample_num

    def bprop(inputs, rois, out, dout):
        inputs_shape = shape_op(inputs)
        dx = G.ROIAlignGrad(inputs_shape,
                            pooled_height,
                            pooled_width,
                            spatial_scale,
                            sample_num,
                            )(dout, rois)
        return dx, zeros_like(rois)

    return bprop


@bprop_getters.register(P.Conv2DBackpropInput)
def get_bprop_conv2d_backprop_input(self):
    """Grad definition for `Conv2DBackpropInput` operation."""
    filter_grad = G.Conv2DBackpropFilter(
        self.out_channel, self.kernel_size, self.pad_mode, self.pad, self.pad_list, mode=self.mode,
        dilation=self.dilation, stride=self.stride, group=self.group, data_format=self.format
    )
    input_grad = P.Conv2D(
        self.out_channel, self.kernel_size, pad_mode=self.pad_mode.lower(), pad=self.pad,
        dilation=self.dilation, stride=self.stride, group=self.group, data_format=self.format
    )

    def bprop(x, w, f_sizes, out, dout):
        dx = input_grad(dout, w)
        if env_force_bprop_seq == '1':
            x = F.depend(x, dx)
        dw = filter_grad(x, dout, F.shape(w))
        return dx, dw, zeros_like(f_sizes)

    return bprop


@bprop_getters.register(P.BinaryCrossEntropy)
def get_bprop_binary_cross_entropy(self):
    """Grad definition for `BinaryCrossEntropy` operation."""
    grad = G.BinaryCrossEntropyGrad(self.reduction)

    def bprop(x, y, weight, out, dout):
        dx = grad(x, y, dout, weight)
        return dx, zeros_like(y), zeros_like(weight)

    return bprop


@bprop_getters.register(P.KLDivLoss)
def get_bprop_kl_div_loss(self):
    """Grad definition for `KLDivLoss` operation."""
    grad = G.KLDivLossGrad(self.reduction)

    def bprop(x, y, out, dout):
        dx, dy = grad(x, y, dout)
        return dx, dy

    return bprop


@bprop_getters.register(P.Dropout)
def get_bprop_dropout(self):
    """Grad definition for `Dropout` operation."""
    grad = G.DropoutGrad(self.keep_prob)

    def bprop(x, out, dout):
        _, mask = out
        dy, _ = dout
        dx = grad(dy, mask)
        return (dx,)

    return bprop


@bprop_getters.register(P.CTCLoss)
def get_bprop_ctc_loss(self):
    """Grad definition for `CTCLoss` operation"""
    expand = P.ExpandDims()

    def bprop(inputs, labels_indices, labels_values, sequence_length, out, dout):
        grad_loss = out[1]
        grad = grad_loss * expand(dout[0], -1)
        return grad, zeros_like(labels_indices), zeros_like(labels_values), zeros_like(sequence_length)

    return bprop


@bprop_getters.register(P.BasicLSTMCell)
def get_bprop_basic_lstm_cell(self):
    """Grad definition for `BasicLSTMCell` operation."""
    basic_lstm_cell_cstate_grad = G.BasicLSTMCellCStateGrad(
        forget_bias=self.forget_bias,
        activation=self.activation
    )

    basic_lstm_cell_weight_grad = G.BasicLSTMCellWeightGrad()

    basic_lstm_cell_input_grad = G.BasicLSTMCellInputGrad(keep_prob=self.keep_prob)

    def bprop(x, h, c, w, b, out, dout):
        _, _, it, jt, ft, ot, tanhct = out
        dct, dht, _, _, _, _, _ = dout
        dgate, dct_1 = basic_lstm_cell_cstate_grad(c, dht, dct, it, jt, ft, ot, tanhct)
        dxt, dht = basic_lstm_cell_input_grad(dgate, w)
        dw, db = basic_lstm_cell_weight_grad(F.depend(x, dxt), h, dgate)
        return dxt, dht, dct_1, dw, db

    return bprop


@bprop_getters.register(P.LRN)
def get_bprop_lrn(self):
    """Grad definition for `LRN` operation."""
    grad = G.LRNGrad(self.depth_radius, self.bias, self.alpha, self.beta)

    def bprop(x, out, dout):
        dx = grad(dout, x, out)
        return (dx,)

    return bprop