# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddleseg.models import layers

from ppmatting.models.layers import tensor_fusion_helper as helper
  20. class MLFF(nn.Layer):
  21. """
  22. Multi-level features are fused adaptively by obtaining spatial attention.
  23. Args:
  24. in_channels(list): The channels of input tensors.
  25. mid_channles(list): The middle channels while fusing the features.
  26. out_channel(int): The output channel after fusing.
  27. merge_type(str): Which type to merge the multi features before output.
  28. It should be one of ('add', 'concat'). Default: 'concat'.
  29. """
  30. def __init__(self,
  31. in_channels,
  32. mid_channels,
  33. out_channel,
  34. merge_type='concat'):
  35. super().__init__()
  36. self.merge_type = merge_type
  37. # Check arguments
  38. if len(in_channels) != len(mid_channels):
  39. raise ValueError(
  40. "`mid_channels` should have the same length as `in_channels`, but they are {} and {}".
  41. format(mid_channels, in_channels))
  42. if self.merge_type == 'add' and len(np.unique(np.array(
  43. mid_channels))) != 1:
  44. raise ValueError(
  45. "if `merge_type='add', `mid_channels` should be same of all input features, but it is {}.".
  46. format(mid_channels))
  47. self.pwconvs = nn.LayerList()
  48. self.dwconvs = nn.LayerList()
  49. for in_channel, mid_channel in zip(in_channels, mid_channels):
  50. self.pwconvs.append(
  51. layers.ConvBN(
  52. in_channel, mid_channel, 1, bias_attr=False))
  53. self.dwconvs.append(
  54. layers.ConvBNReLU(
  55. mid_channel,
  56. mid_channel,
  57. 3,
  58. padding=1,
  59. groups=mid_channel,
  60. bias_attr=False))
  61. num_feas = len(in_channels)
  62. self.conv_atten = nn.Sequential(
  63. layers.ConvBNReLU(
  64. 2 * num_feas,
  65. num_feas,
  66. kernel_size=3,
  67. padding=1,
  68. bias_attr=False),
  69. layers.ConvBN(
  70. num_feas, num_feas, kernel_size=3, padding=1, bias_attr=False))
  71. if self.merge_type == 'add':
  72. in_chan = mid_channels[0]
  73. else:
  74. in_chan = sum(mid_channels)
  75. self.conv_out = layers.ConvBNReLU(
  76. in_chan, out_channel, kernel_size=3, padding=1, bias_attr=False)
  77. def forward(self, inputs, shape):
  78. """
  79. args:
  80. inputs(list): List of tensor to be fused.
  81. shape(Tensor): A tensor with two elements like (H, W).
  82. """
  83. feas = []
  84. for i, input in enumerate(inputs):
  85. x = self.pwconvs[i](input)
  86. x = F.interpolate(
  87. x, size=shape, mode='bilinear', align_corners=False)
  88. x = self.dwconvs[i](x)
  89. feas.append(x)
  90. atten = helper.avg_max_reduce_channel(feas)
  91. atten = F.sigmoid(self.conv_atten(atten))
  92. feas_att = []
  93. for i, fea in enumerate(feas):
  94. fea = fea * (atten[:, i, :, :].unsqueeze(1))
  95. feas_att.append(fea)
  96. if self.merge_type == 'concat':
  97. out = paddle.concat(feas_att, axis=1)
  98. else:
  99. out = sum(feas_att)
  100. out = self.conv_out(out)
  101. return out