gca_module.py

# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# The gca code was heavily based on https://github.com/Yaoyi-Li/GCA-Matting
# and https://github.com/open-mmlab/mmediting

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddleseg.cvlibs import param_init
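

# GuidedCxtAtten implements the guided contextual attention (GCA) block from
# GCA-Matting ("Natural Image Matting via Guided Contextual Attention",
# Li and Lu, AAAI 2020): image features act as a guidance stream for scoring
# patch-to-patch similarity, and the resulting attention map is used to
# propagate alpha (matte) features from known into unknown trimap regions.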
class GuidedCxtAtten(nn.Layer):
    def __init__(self,
                 out_channels,
                 guidance_channels,
                 kernel_size=3,
                 stride=1,
                 rate=2):
        super().__init__()
        self.kernel_size = kernel_size
        self.rate = rate
        self.stride = stride
        self.guidance_conv = nn.Conv2D(
            in_channels=guidance_channels,
            out_channels=guidance_channels // 2,
            kernel_size=1)
        self.out_conv = nn.Sequential(
            nn.Conv2D(
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=1,
                bias_attr=False),
            nn.BatchNorm(out_channels))
        self.init_weight()
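
    # init_weight(): the BatchNorm scale in out_conv starts at 1e-3, so the
    # attention branch contributes almost nothing at first and the residual
    # connection in forward() (out_conv(y) + alpha_feat) initially behaves
    # close to an identity mapping over the alpha features.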

    def init_weight(self):
        param_init.xavier_uniform(self.guidance_conv.weight)
        param_init.constant_init(self.guidance_conv.bias, value=0.0)
        param_init.xavier_uniform(self.out_conv[0].weight)
        param_init.constant_init(self.out_conv[1].weight, value=1e-3)
        param_init.constant_init(self.out_conv[1].bias, value=0.0)
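
    # forward(): a shape sketch for the default rate=2, as an illustration.
    # Given img_feat of shape (N, C_g, H, W) and alpha_feat of shape
    # (N, C_a, H, W), guidance_conv plus interpolation reduce img_feat to
    # (N, C_g // 2, H/2, W/2); attention scores are computed on that coarse
    # grid, and propagate_alpha_feature upsamples the result back to
    # (N, C_a, H, W) so the residual addition with alpha_feat lines up.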

    def forward(self, img_feat, alpha_feat, unknown=None, softmax_scale=1.):
        img_feat = self.guidance_conv(img_feat)
        img_feat = F.interpolate(
            img_feat, scale_factor=1 / self.rate, mode='nearest')

        # process unknown mask
        unknown, softmax_scale = self.process_unknown_mask(unknown, img_feat,
                                                           softmax_scale)

        img_ps, alpha_ps, unknown_ps = self.extract_feature_maps_patches(
            img_feat, alpha_feat, unknown)

        self_mask = self.get_self_correlation_mask(img_feat)

        # split tensors along the batch dimension; a list is returned. Note
        # that paddle.split takes the number of sections (unlike torch.split,
        # which takes the section size), so pass the batch size to get one
        # chunk per sample.
        n = img_feat.shape[0]
        img_groups = paddle.split(img_feat, n, axis=0)
        img_ps_groups = paddle.split(img_ps, n, axis=0)
        alpha_ps_groups = paddle.split(alpha_ps, n, axis=0)
        unknown_ps_groups = paddle.split(unknown_ps, n, axis=0)
        scale_groups = paddle.split(softmax_scale, n, axis=0)
        groups = (img_groups, img_ps_groups, alpha_ps_groups,
                  unknown_ps_groups, scale_groups)

        y = []
        for img_i, img_ps_i, alpha_ps_i, unknown_ps_i, scale_i in zip(*groups):
            # compute the attention scores and propagate alpha features
            # sample by sample
            similarity_map = self.compute_similarity_map(img_i, img_ps_i)
            gca_score = self.compute_guided_attention_score(
                similarity_map, unknown_ps_i, scale_i, self_mask)
            yi = self.propagate_alpha_feature(gca_score, alpha_ps_i)
            y.append(yi)

        y = paddle.concat(y, axis=0)  # back to the mini-batch
        y = paddle.reshape(y, alpha_feat.shape)
        y = self.out_conv(y) + alpha_feat

        return y

    def extract_feature_maps_patches(self, img_feat, alpha_feat, unknown):
        # extract image feature patches with shape:
        # (N, img_h*img_w, img_c, img_ks, img_ks)
        img_ks = self.kernel_size
        img_ps = self.extract_patches(img_feat, img_ks, self.stride)

        # extract alpha feature patches with shape:
        # (N, img_h*img_w, alpha_c, alpha_ks, alpha_ks)
        alpha_ps = self.extract_patches(alpha_feat, self.rate * 2, self.rate)

        # extract unknown mask patches with shape: (N, img_h*img_w, 1, 1)
        unknown_ps = self.extract_patches(unknown, img_ks, self.stride)
        unknown_ps = unknown_ps.squeeze(axis=2)  # squeeze channel dimension
        unknown_ps = unknown_ps.mean(axis=[2, 3], keepdim=True)

        return img_ps, alpha_ps, unknown_ps

    def extract_patches(self, x, kernel_size, stride):
        n, c, _, _ = x.shape
        x = self.pad(x, kernel_size, stride)
        x = F.unfold(x, [kernel_size, kernel_size], strides=[stride, stride])
        x = paddle.transpose(x, (0, 2, 1))
        x = paddle.reshape(x, (n, -1, c, kernel_size, kernel_size))

        return x
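
    # pad(): with the default kernel_size=3 and stride=1 the padding below is
    # (1, 1, 1, 1), so the unfold in extract_patches yields exactly one patch
    # per spatial location and the spatial size is preserved.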

    def pad(self, x, kernel_size, stride):
        left = (kernel_size - stride + 1) // 2
        right = (kernel_size - stride) // 2
        pad = (left, right, left, right)

        return F.pad(x, pad, mode='reflect')
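
    # compute_guided_attention_score(): the softmax runs over axis 1, i.e.
    # over all h*w candidate patches, so gca_score[:, k, i, j] is the weight
    # that location (i, j) assigns to patch k. Adding self_mask * unknown_ps
    # drives the logit of a location's own patch toward -1e4 wherever that
    # patch is unknown, preventing unknown pixels from attending to themselves.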

    def compute_guided_attention_score(self, similarity_map, unknown_ps, scale,
                                       self_mask):
        # scale the correlation with the predicted scale factors for the known
        # and unknown areas; the boolean masks are cast to float so that the
        # elementwise multiply is well-defined
        unknown_scale, known_scale = scale[0]
        unknown_mask = paddle.greater_than(
            unknown_ps, paddle.to_tensor([0.])).astype(similarity_map.dtype)
        known_mask = paddle.less_equal(
            unknown_ps, paddle.to_tensor([0.])).astype(similarity_map.dtype)
        out = similarity_map * (
            unknown_scale * unknown_mask + known_scale * known_mask)
        # mask itself; the self-mask is only applied to the unknown area
        out = out + self_mask * unknown_ps
        gca_score = F.softmax(out, axis=1)

        return gca_score
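
    # propagate_alpha_feature(): with kernel size rate * 2 and stride rate,
    # every output pixel is covered by 2 x 2 = 4 overlapping alpha patches, so
    # the division by 4 averages the overlapping contributions; the rate == 1
    # branch emulates the same averaging with a plain convolution over the
    # padded score map.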

    def propagate_alpha_feature(self, gca_score, alpha_ps):
        alpha_ps = alpha_ps[0]  # squeeze dim 0
        if self.rate == 1:
            gca_score = self.pad(gca_score, kernel_size=2, stride=1)
            alpha_ps = paddle.transpose(alpha_ps, (1, 0, 2, 3))
            out = F.conv2d(gca_score, alpha_ps) / 4.
        else:
            out = F.conv2d_transpose(
                gca_score, alpha_ps, stride=self.rate, padding=1) / 4.

        return out
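
    # compute_similarity_map(): each patch is L2-normalized before being used
    # as a convolution kernel, so the response behaves like a scaled cosine
    # similarity between the patch and every spatial location; paddle.clip
    # guards against division by a near-zero norm.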

    def compute_similarity_map(self, img_feat, img_ps):
        img_ps = img_ps[0]  # squeeze dim 0
        # convolve the feature to get correlation (similarity) map
        img_ps_normed = img_ps / paddle.clip(self.l2_norm(img_ps), 1e-4)
        img_feat = F.pad(img_feat, (1, 1, 1, 1), mode='reflect')
        similarity_map = F.conv2d(img_feat, img_ps_normed)

        return similarity_map
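
    # get_self_correlation_mask(): the one-hot construction gives, for each
    # flattened location k, a map that is 1 only at its own position; scaling
    # by -1e4 turns it into a large negative logit for masking
    # self-correlation before the softmax.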

    def get_self_correlation_mask(self, img_feat):
        _, _, h, w = img_feat.shape
        self_mask = F.one_hot(
            paddle.reshape(paddle.arange(h * w), (h, w)),
            num_classes=int(h * w))
        self_mask = paddle.transpose(self_mask, (2, 0, 1))
        self_mask = paddle.reshape(self_mask, (1, h * w, h, w))

        return self_mask * (-1e4)
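
    # process_unknown_mask(): when a trimap-derived unknown mask is given, the
    # two softmax scales are the (clipped) square roots of the unknown/known
    # area ratio, balancing the attention temperature between sparse and dense
    # unknown regions; without a mask, everything is treated as unknown and
    # the caller-provided softmax_scale is used for both.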

    def process_unknown_mask(self, unknown, img_feat, softmax_scale):
        n, _, h, w = img_feat.shape
        if unknown is not None:
            unknown = unknown.clone()
            unknown = F.interpolate(
                unknown, scale_factor=1 / self.rate, mode='nearest')
            unknown_mean = unknown.mean(axis=[2, 3])
            known_mean = 1 - unknown_mean
            unknown_scale = paddle.clip(
                paddle.sqrt(unknown_mean / known_mean), 0.1, 10)
            known_scale = paddle.clip(
                paddle.sqrt(known_mean / unknown_mean), 0.1, 10)
            softmax_scale = paddle.concat([unknown_scale, known_scale], axis=1)
        else:
            unknown = paddle.ones([n, 1, h, w])
            softmax_scale = paddle.reshape(
                paddle.to_tensor([softmax_scale, softmax_scale]), (1, 2))
            softmax_scale = paddle.expand(softmax_scale, (n, 2))

        return unknown, softmax_scale
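
    # l2_norm(): per-patch L2 norm over the channel and kernel axes, used by
    # compute_similarity_map to normalize the patch kernels.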

    @staticmethod
    def l2_norm(x):
        x = x**2
        x = x.sum(axis=[1, 2, 3], keepdim=True)

        return paddle.sqrt(x)
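

if __name__ == '__main__':
    # Minimal smoke-test sketch; the channel and spatial sizes below are
    # illustrative assumptions, chosen so that the guidance stream
    # (downsampled by rate=2) lines up with the alpha stream.
    gca = GuidedCxtAtten(out_channels=128, guidance_channels=512)

    n, h, w = 2, 32, 32
    img_feat = paddle.rand([n, 512, h, w])    # guidance (image) features
    alpha_feat = paddle.rand([n, 128, h, w])  # alpha (matte) features
    # binary unknown-region mask at the same resolution as img_feat
    unknown = (paddle.rand([n, 1, h, w]) > 0.5).astype('float32')

    out = gca(img_feat, alpha_feat, unknown)
    print(out.shape)  # expected: [2, 128, 32, 32]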