ppmatting.py

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict
import time

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import paddleseg
from paddleseg.models import layers
from paddleseg import utils
from paddleseg.cvlibs import manager

from ppmatting.models.losses import MRSD, GradientLoss
from ppmatting.models.backbone import resnet_vd


@manager.MODELS.add_component
class PPMatting(nn.Layer):
    """
    The PPMatting implementation based on PaddlePaddle.

    The original article refers to
    Guowei Chen, et al. "PP-Matting: High-Accuracy Natural Image Matting"
    (https://arxiv.org/pdf/2204.09433.pdf).

    Args:
        backbone: The backbone model.
        pretrained (str, optional): The path of the pretrained model. Default: None.
    """

    def __init__(self, backbone, pretrained=None):
        super().__init__()
        self.backbone = backbone
        self.pretrained = pretrained
        self.loss_func_dict = self.get_loss_func_dict()

        self.backbone_channels = backbone.feat_channels

        self.scb = SCB(self.backbone_channels[-1])
        self.hrdb = HRDB(
            self.backbone_channels[0] + self.backbone_channels[1],
            scb_channels=self.scb.out_channels,
            gf_index=[0, 2, 4])

        self.init_weight()

    def forward(self, inputs):
        x = inputs['img']
        input_shape = paddle.shape(x)
        fea_list = self.backbone(x)

        # Semantic Context Branch: coarse 3-class (fg/transition/bg) prediction
        scb_logits = self.scb(fea_list[-1])
        semantic_map = F.softmax(scb_logits[-1], axis=1)

        # High-Resolution Detail Branch: fine detail prediction
        fea0 = F.interpolate(
            fea_list[0], input_shape[2:], mode='bilinear', align_corners=False)
        fea1 = F.interpolate(
            fea_list[1], input_shape[2:], mode='bilinear', align_corners=False)
        hrdb_input = paddle.concat([fea0, fea1], 1)
        hrdb_logit = self.hrdb(hrdb_input, scb_logits)
        detail_map = F.sigmoid(hrdb_logit)
        fusion = self.fusion(semantic_map, detail_map)

        if self.training:
            logit_dict = {
                'semantic': semantic_map,
                'detail': detail_map,
                'fusion': fusion
            }
            loss_dict = self.loss(logit_dict, inputs)
            return logit_dict, loss_dict
        else:
            return fusion

    def get_loss_func_dict(self):
        loss_func_dict = defaultdict(list)
        loss_func_dict['semantic'].append(nn.NLLLoss())
        loss_func_dict['detail'].append(MRSD())
        loss_func_dict['detail'].append(GradientLoss())
        loss_func_dict['fusion'].append(MRSD())
        loss_func_dict['fusion'].append(MRSD())
        loss_func_dict['fusion'].append(GradientLoss())
        return loss_func_dict

    def loss(self, logit_dict, label_dict):
        loss = {}

        # semantic loss computation
        # Build the semantic label from the trimap:
        # 255 (foreground) -> class 0, 128 (transition) -> class 1, 0 (background) -> class 2
        semantic_label = label_dict['trimap']
        semantic_label_trans = (semantic_label == 128).astype('int64')
        semantic_label_bg = (semantic_label == 0).astype('int64')
        semantic_label = semantic_label_trans + semantic_label_bg * 2
        loss_semantic = self.loss_func_dict['semantic'][0](
            paddle.log(logit_dict['semantic'] + 1e-6),
            semantic_label.squeeze(1))
        loss['semantic'] = loss_semantic

        # detail loss computation, restricted to the transition region
        transparent = label_dict['trimap'] == 128
        detail_alpha_loss = self.loss_func_dict['detail'][0](
            logit_dict['detail'], label_dict['alpha'], transparent)
        # gradient loss
        detail_gradient_loss = self.loss_func_dict['detail'][1](
            logit_dict['detail'], label_dict['alpha'], transparent)
        loss_detail = detail_alpha_loss + detail_gradient_loss
        loss['detail'] = loss_detail
        loss['detail_alpha'] = detail_alpha_loss
        loss['detail_gradient'] = detail_gradient_loss

        # fusion loss
        loss_fusion_func = self.loss_func_dict['fusion']
        # alpha loss on the fused prediction
        fusion_alpha_loss = loss_fusion_func[0](logit_dict['fusion'],
                                                label_dict['alpha'])
        # composition loss
        comp_pred = logit_dict['fusion'] * label_dict['fg'] + (
            1 - logit_dict['fusion']) * label_dict['bg']
        comp_gt = label_dict['alpha'] * label_dict['fg'] + (
            1 - label_dict['alpha']) * label_dict['bg']
        fusion_composition_loss = loss_fusion_func[1](comp_pred, comp_gt)
        # gradient loss
        fusion_grad_loss = loss_fusion_func[2](logit_dict['fusion'],
                                               label_dict['alpha'])
        loss_fusion = fusion_alpha_loss + fusion_composition_loss + fusion_grad_loss
        loss['fusion'] = loss_fusion
        loss['fusion_alpha'] = fusion_alpha_loss
        loss['fusion_composition'] = fusion_composition_loss
        loss['fusion_gradient'] = fusion_grad_loss

        loss['all'] = 0.25 * loss_semantic + 0.25 * loss_detail + 0.25 * loss_fusion

        return loss

    def fusion(self, semantic_map, detail_map):
        # semantic_map [N, 3, H, W]: in the argmax index, 0 is foreground,
        # 1 is transition and 2 is background.
        # After fusion, the foreground is 1, the background is 0, and the
        # transition region takes the detail value in [0, 1].
        # (A standalone toy check of this rule is sketched at the bottom of
        # this file.)
        index = paddle.argmax(semantic_map, axis=1, keepdim=True)
        transition_mask = (index == 1).astype('float32')
        fg = (index == 0).astype('float32')
        alpha = detail_map * transition_mask + fg
        return alpha

    def init_weight(self):
        if self.pretrained is not None:
            utils.load_entire_model(self, self.pretrained)
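

# Usage sketch (illustrative, not part of the original file). The model is
# normally built from a YAML config through the `manager` registry, but it can
# also be constructed directly, assuming a backbone that exposes
# `feat_channels` (the released PP-Matting configs pair it with an HRNet
# backbone); `some_backbone` below is a placeholder:
#
#     model = PPMatting(backbone=some_backbone)
#     model.eval()
#     alpha = model({'img': paddle.rand([1, 3, 512, 512])})  # [N, 1, H, W] matte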


class SCB(nn.Layer):
    """
    Semantic Context Branch: predicts 3-class (foreground/transition/background)
    semantic logits from the deepest backbone feature.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = [512 + in_channels, 512, 256, 128, 128, 64]
        self.mid_channels = [512, 256, 128, 128, 64, 64]
        self.out_channels = [256, 128, 64, 64, 64, 3]

        self.psp_module = layers.PPModule(
            in_channels,
            512,
            bin_sizes=(1, 3, 5),
            dim_reduction=False,
            align_corners=False)

        psp_upsamples = [2, 4, 8, 16]
        self.psps = nn.LayerList([
            self.conv_up_psp(512, self.out_channels[i], psp_upsamples[i])
            for i in range(4)
        ])

        scb_list = [
            self._make_stage(
                self.in_channels[i],
                self.mid_channels[i],
                self.out_channels[i],
                padding=int(i == 0) + 1,
                dilation=int(i == 0) + 1)
            for i in range(len(self.in_channels) - 1)
        ]
        scb_list += [
            nn.Sequential(
                layers.ConvBNReLU(
                    self.in_channels[-1], self.mid_channels[-1], 3, padding=1),
                layers.ConvBNReLU(
                    self.mid_channels[-1], self.mid_channels[-1], 3, padding=1),
                nn.Conv2D(
                    self.mid_channels[-1], self.out_channels[-1], 3, padding=1))
        ]
        self.scb_stages = nn.LayerList(scb_list)

    def forward(self, x):
        psp_x = self.psp_module(x)
        psps = [psp(psp_x) for psp in self.psps]

        scb_logits = []
        for i, scb_stage in enumerate(self.scb_stages):
            if i == 0:
                x = scb_stage(paddle.concat((psp_x, x), 1))
            elif i <= len(psps):
                x = scb_stage(paddle.concat((psps[i - 1], x), 1))
            else:
                x = scb_stage(x)
            scb_logits.append(x)
        return scb_logits

    def conv_up_psp(self, in_channels, out_channels, up_sample):
        return nn.Sequential(
            layers.ConvBNReLU(
                in_channels, out_channels, 3, padding=1),
            nn.Upsample(
                scale_factor=up_sample, mode='bilinear', align_corners=False))

    def _make_stage(self,
                    in_channels,
                    mid_channels,
                    out_channels,
                    padding=1,
                    dilation=1):
        layer_list = [
            layers.ConvBNReLU(
                in_channels, mid_channels, 3, padding=1),
            layers.ConvBNReLU(
                mid_channels, mid_channels, 3, padding=padding,
                dilation=dilation),
            layers.ConvBNReLU(
                mid_channels, out_channels, 3, padding=padding,
                dilation=dilation),
            nn.Upsample(
                scale_factor=2, mode='bilinear', align_corners=False)
        ]
        return nn.Sequential(*layer_list)
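

# Shape walkthrough for SCB.forward (illustrative; it assumes the deepest
# backbone feature has stride 32, i.e. spatial size [H/32, W/32] -- the actual
# stride depends on the backbone configured):
#
#   stage 0: concat(psp_x, x)         -> [N, 256, H/16, W/16]
#   stage 1: concat(psps[0], stage 0) -> [N, 128, H/8,  W/8]
#   stage 2: concat(psps[1], stage 1) -> [N, 64,  H/4,  W/4]
#   stage 3: concat(psps[2], stage 2) -> [N, 64,  H/2,  W/2]
#   stage 4: concat(psps[3], stage 3) -> [N, 64,  H,    W]
#   stage 5: stage 4 only             -> [N, 3,   H,    W]   (semantic logits)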


class HRDB(nn.Layer):
    """
    The High-Resolution Detail Branch.

    Args:
        in_channels (int): The number of input channels.
        scb_channels (list|tuple): The channels of the SCB logits.
        gf_index (list|tuple, optional): Which SCB logits are selected as
            guidance flow. Default: (0, 2, 4).
    """

    def __init__(self, in_channels, scb_channels, gf_index=(0, 2, 4)):
        super().__init__()
        self.gf_index = gf_index
        self.gf_list = nn.LayerList(
            [nn.Conv2D(scb_channels[i], 1, 1) for i in gf_index])

        channels = [64, 32, 16, 8]
        self.res_list = [
            resnet_vd.BasicBlock(
                in_channels, channels[0], stride=1, shortcut=False)
        ]
        self.res_list += [
            resnet_vd.BasicBlock(
                i, i, stride=1) for i in channels[1:-1]
        ]
        self.res_list = nn.LayerList(self.res_list)

        self.convs = nn.LayerList([
            nn.Conv2D(
                channels[i], channels[i + 1], kernel_size=1)
            for i in range(len(channels) - 1)
        ])
        self.gates = nn.LayerList(
            [GatedSpatailConv2d(i, i) for i in channels[1:]])

        self.detail_conv = nn.Conv2D(channels[-1], 1, 1, bias_attr=False)

    def forward(self, x, scb_logits):
        for i in range(len(self.res_list)):
            x = self.res_list[i](x)
            x = self.convs[i](x)
            gf = self.gf_list[i](scb_logits[self.gf_index[i]])
            gf = F.interpolate(
                gf, paddle.shape(x)[-2:], mode='bilinear', align_corners=False)
            x = self.gates[i](x, gf)
        return self.detail_conv(x)
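

# Data flow in HRDB.forward (descriptive note, not part of the original file):
# each iteration applies a residual block, a 1x1 conv that halves the channel
# count (64 -> 32 -> 16 -> 8), then gates the features with a 1-channel
# guidance map projected from the selected SCB logit and resized to match.
# A final 1-channel conv produces the detail logit.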


class GatedSpatailConv2d(nn.Layer):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=1,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias_attr=False):
        super().__init__()
        self._gate_conv = nn.Sequential(
            layers.SyncBatchNorm(in_channels + 1),
            nn.Conv2D(
                in_channels + 1, in_channels + 1, kernel_size=1),
            nn.ReLU(),
            nn.Conv2D(
                in_channels + 1, 1, kernel_size=1),
            layers.SyncBatchNorm(1),
            nn.Sigmoid())
        self.conv = nn.Conv2D(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias_attr=bias_attr)

    def forward(self, input_features, gating_features):
        # Compute a spatial attention map in (0, 1) from the concatenation of
        # the features and the 1-channel guidance map, then rescale the input
        # features by (attention + 1) before the final convolution.
        cat = paddle.concat([input_features, gating_features], axis=1)
        alphas = self._gate_conv(cat)
        x = input_features * (alphas + 1)
        x = self.conv(x)
        return x
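

# A minimal, self-contained check of the fusion rule in PPMatting.fusion
# (illustrative only, not part of the original file): pixels predicted as
# foreground get alpha 1, background pixels get alpha 0, and transition
# pixels take the detail value.
if __name__ == '__main__':
    semantic_map = F.softmax(paddle.rand([1, 3, 8, 8]), axis=1)  # toy 3-class map
    detail_map = F.sigmoid(paddle.rand([1, 1, 8, 8]))  # toy detail prediction

    index = paddle.argmax(semantic_map, axis=1, keepdim=True)
    transition_mask = (index == 1).astype('float32')
    fg = (index == 0).astype('float32')
    alpha = detail_map * transition_mask + fg

    # Outside the transition region alpha is exactly 0 or 1.
    print(alpha.shape)  # [1, 1, 8, 8]
    print(alpha.min().item(), alpha.max().item())  # values stay within [0, 1]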