dim.py

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from paddleseg.models import layers
from paddleseg import utils
from paddleseg.cvlibs import manager

from ppmatting.models.losses import MRSD


@manager.MODELS.add_component
class DIM(nn.Layer):
    """
    The DIM implementation based on PaddlePaddle.

    The original article refers to
    Ning Xu, et al. "Deep Image Matting"
    (https://arxiv.org/pdf/1703.03872.pdf).

    Args:
        backbone: Backbone model.
        stage (int, optional): The stage of the model. Default: 3.
        decoder_input_channels (int, optional): The number of channels of the decoder input. Default: 512.
        pretrained (str, optional): The path of the pretrained model. Default: None.
    """
    def __init__(self,
                 backbone,
                 stage=3,
                 decoder_input_channels=512,
                 pretrained=None):
        super().__init__()
        self.backbone = backbone
        self.pretrained = pretrained
        self.stage = stage
        self.loss_func_dict = None

        decoder_output_channels = [64, 128, 256, 512]
        self.decoder = Decoder(
            input_channels=decoder_input_channels,
            output_channels=decoder_output_channels)
        if self.stage == 2:
            # Stage 2 trains only the refinement head, so freeze the
            # encoder-decoder parameters.
            for param in self.backbone.parameters():
                param.stop_gradient = True
            for param in self.decoder.parameters():
                param.stop_gradient = True
        if self.stage >= 2:
            self.refine = Refine()

        self.init_weight()

    def forward(self, inputs):
        input_shape = paddle.shape(inputs['img'])[-2:]
        x = paddle.concat([inputs['img'], inputs['trimap'] / 255], axis=1)
        fea_list = self.backbone(x)

        # decoder stage
        up_shape = []
        for i in range(5):
            up_shape.append(paddle.shape(fea_list[i])[-2:])
        alpha_raw = self.decoder(fea_list, up_shape)
        alpha_raw = F.interpolate(
            alpha_raw, input_shape, mode='bilinear', align_corners=False)
        logit_dict = {'alpha_raw': alpha_raw}
        if self.stage < 2:
            return logit_dict

        if self.stage >= 2:
            # refine stage
            refine_input = paddle.concat([inputs['img'], alpha_raw], axis=1)
            alpha_refine = self.refine(refine_input)

            # final alpha
            alpha_pred = alpha_refine + alpha_raw
            alpha_pred = F.interpolate(
                alpha_pred, input_shape, mode='bilinear', align_corners=False)
            if not self.training:
                alpha_pred = paddle.clip(alpha_pred, min=0, max=1)
            logit_dict['alpha_pred'] = alpha_pred
        if self.training:
            loss_dict = self.loss(logit_dict, inputs)
            return logit_dict, loss_dict
        else:
            return alpha_pred

    def loss(self, logit_dict, label_dict, loss_func_dict=None):
        if loss_func_dict is None:
            if self.loss_func_dict is None:
                self.loss_func_dict = defaultdict(list)
                self.loss_func_dict['alpha_raw'].append(MRSD())
                self.loss_func_dict['comp'].append(MRSD())
                self.loss_func_dict['alpha_pred'].append(MRSD())
        else:
            self.loss_func_dict = loss_func_dict

        loss = {}
        # Losses are only computed inside the unknown region of the trimap (value 128).
        mask = label_dict['trimap'] == 128
        loss['all'] = 0

        if self.stage != 2:
            loss['alpha_raw'] = self.loss_func_dict['alpha_raw'][0](
                logit_dict['alpha_raw'], label_dict['alpha'], mask)
            loss['alpha_raw'] = 0.5 * loss['alpha_raw']
            loss['all'] = loss['all'] + loss['alpha_raw']

        if self.stage == 1 or self.stage == 3:
            comp_pred = logit_dict['alpha_raw'] * label_dict['fg'] + \
                (1 - logit_dict['alpha_raw']) * label_dict['bg']
            loss['comp'] = self.loss_func_dict['comp'][0](
                comp_pred, label_dict['img'], mask)
            loss['comp'] = 0.5 * loss['comp']
            loss['all'] = loss['all'] + loss['comp']

        if self.stage == 2 or self.stage == 3:
            loss['alpha_pred'] = self.loss_func_dict['alpha_pred'][0](
                logit_dict['alpha_pred'], label_dict['alpha'], mask)
            loss['all'] = loss['all'] + loss['alpha_pred']

        return loss

    def init_weight(self):
        if self.pretrained is not None:
            utils.load_entire_model(self, self.pretrained)


# Upsampling block: bilinear interpolation followed by a skip connection.
class Up(nn.Layer):
    def __init__(self, input_channels, output_channels):
        super().__init__()
        self.conv = layers.ConvBNReLU(
            input_channels,
            output_channels,
            kernel_size=5,
            padding=2,
            bias_attr=False)

    def forward(self, x, skip, output_shape):
        x = F.interpolate(
            x, size=output_shape, mode='bilinear', align_corners=False)
        x = x + skip
        x = self.conv(x)
        x = F.relu(x)
        return x


class Decoder(nn.Layer):
    def __init__(self, input_channels, output_channels=(64, 128, 256, 512)):
        super().__init__()
        self.deconv6 = nn.Conv2D(
            input_channels, input_channels, kernel_size=1, bias_attr=False)
        self.deconv5 = Up(input_channels, output_channels[-1])
        self.deconv4 = Up(output_channels[-1], output_channels[-2])
        self.deconv3 = Up(output_channels[-2], output_channels[-3])
        self.deconv2 = Up(output_channels[-3], output_channels[-4])
        self.deconv1 = Up(output_channels[-4], 64)

        self.alpha_conv = nn.Conv2D(
            64, 1, kernel_size=5, padding=2, bias_attr=False)

    def forward(self, fea_list, shape_list):
        x = fea_list[-1]
        x = self.deconv6(x)
        x = self.deconv5(x, fea_list[4], shape_list[4])
        x = self.deconv4(x, fea_list[3], shape_list[3])
        x = self.deconv3(x, fea_list[2], shape_list[2])
        x = self.deconv2(x, fea_list[1], shape_list[1])
        x = self.deconv1(x, fea_list[0], shape_list[0])
        alpha = self.alpha_conv(x)
        alpha = F.sigmoid(alpha)
        return alpha


class Refine(nn.Layer):
    def __init__(self):
        super().__init__()
        self.conv1 = layers.ConvBNReLU(
            4, 64, kernel_size=3, padding=1, bias_attr=False)
        self.conv2 = layers.ConvBNReLU(
            64, 64, kernel_size=3, padding=1, bias_attr=False)
        self.conv3 = layers.ConvBNReLU(
            64, 64, kernel_size=3, padding=1, bias_attr=False)
        self.alpha_pred = layers.ConvBNReLU(
            64, 1, kernel_size=3, padding=1, bias_attr=False)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        alpha = self.alpha_pred(x)
        return alpha
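

# A minimal smoke-test sketch (not part of the original file). It shows the
# input contract of DIM.forward: a dict with a 3-channel 'img' and a 1-channel
# 'trimap' whose values lie in [0, 255]. `_StubBackbone` is a hypothetical
# stand-in for the VGG16-style encoder normally passed as `backbone`; it only
# mimics the expected interface, i.e. a list of six feature maps with channels
# [64, 128, 256, 512, 512, 512] at strides 1 to 32.
if __name__ == "__main__":

    class _StubBackbone(nn.Layer):
        def __init__(self):
            super().__init__()
            channels = [64, 128, 256, 512, 512, 512]
            in_channels = [4] + channels[:-1]
            self.blocks = nn.LayerList([
                nn.Conv2D(ic, oc, kernel_size=3, padding=1)
                for ic, oc in zip(in_channels, channels)
            ])

        def forward(self, x):
            fea_list = []
            for i, block in enumerate(self.blocks):
                if i > 0:
                    # Halve the resolution before every block except the first.
                    x = F.max_pool2d(x, kernel_size=2, stride=2)
                x = block(x)
                fea_list.append(x)
            return fea_list

    paddle.seed(0)
    model = DIM(backbone=_StubBackbone(), stage=3)
    model.eval()
    inputs = {
        'img': paddle.rand([1, 3, 64, 64]),
        'trimap': paddle.full([1, 1, 64, 64], 128.0),
    }
    alpha_pred = model(inputs)
    # In eval mode the model returns the clipped alpha matte at input resolution.
    print(alpha_pred.shape)  # expected: [1, 1, 64, 64]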