# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import paddleseg
from paddleseg.models import layers
from paddleseg import utils
from paddleseg.cvlibs import manager

from ppmatting.models.losses import MRSD

def conv_up_psp(in_channels, out_channels, up_sample):
    return nn.Sequential(
        layers.ConvBNReLU(
            in_channels, out_channels, 3, padding=1),
        nn.Upsample(
            scale_factor=up_sample, mode='bilinear', align_corners=False))
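
# Shape note (not in the original file): conv_up_psp(512, 256, 2) maps an
# (N, 512, h, w) feature map to (N, 256, 2h, 2w); it is used below to bring
# the PSP context features to each glance-decoder resolution.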

@manager.MODELS.add_component
class HumanMatting(nn.Layer):
    """A model for human matting."""

    def __init__(self,
                 backbone,
                 pretrained=None,
                 backbone_scale=0.25,
                 refine_kernel_size=3,
                 if_refine=True):
        super().__init__()
        if if_refine:
            if backbone_scale > 0.5:
                raise ValueError(
                    'backbone_scale should not be greater than 1/2, but it is {}'
                    .format(backbone_scale))
        else:
            backbone_scale = 1

        self.backbone = backbone
        self.backbone_scale = backbone_scale
        self.pretrained = pretrained
        self.if_refine = if_refine
        if if_refine:
            self.refiner = Refiner(kernel_size=refine_kernel_size)
        self.loss_func_dict = None

        self.backbone_channels = backbone.feat_channels
        ##########################
        ### Decoder part - GLANCE
        ##########################
        self.psp_module = layers.PPModule(
            self.backbone_channels[-1],
            512,
            bin_sizes=(1, 3, 5),
            dim_reduction=False,
            align_corners=False)
        self.psp4 = conv_up_psp(512, 256, 2)
        self.psp3 = conv_up_psp(512, 128, 4)
        self.psp2 = conv_up_psp(512, 64, 8)
        self.psp1 = conv_up_psp(512, 64, 16)
        # stage 5g
        self.decoder5_g = nn.Sequential(
            layers.ConvBNReLU(
                512 + self.backbone_channels[-1], 512, 3, padding=1),
            layers.ConvBNReLU(
                512, 512, 3, padding=2, dilation=2),
            layers.ConvBNReLU(
                512, 256, 3, padding=2, dilation=2),
            nn.Upsample(
                scale_factor=2, mode='bilinear', align_corners=False))
        # stage 4g
        self.decoder4_g = nn.Sequential(
            layers.ConvBNReLU(
                512, 256, 3, padding=1),
            layers.ConvBNReLU(
                256, 256, 3, padding=1),
            layers.ConvBNReLU(
                256, 128, 3, padding=1),
            nn.Upsample(
                scale_factor=2, mode='bilinear', align_corners=False))
        # stage 3g
        self.decoder3_g = nn.Sequential(
            layers.ConvBNReLU(
                256, 128, 3, padding=1),
            layers.ConvBNReLU(
                128, 128, 3, padding=1),
            layers.ConvBNReLU(
                128, 64, 3, padding=1),
            nn.Upsample(
                scale_factor=2, mode='bilinear', align_corners=False))
        # stage 2g
        self.decoder2_g = nn.Sequential(
            layers.ConvBNReLU(
                128, 128, 3, padding=1),
            layers.ConvBNReLU(
                128, 128, 3, padding=1),
            layers.ConvBNReLU(
                128, 64, 3, padding=1),
            nn.Upsample(
                scale_factor=2, mode='bilinear', align_corners=False))
        # stage 1g
        self.decoder1_g = nn.Sequential(
            layers.ConvBNReLU(
                128, 64, 3, padding=1),
            layers.ConvBNReLU(
                64, 64, 3, padding=1),
            layers.ConvBNReLU(
                64, 64, 3, padding=1),
            nn.Upsample(
                scale_factor=2, mode='bilinear', align_corners=False))
        # stage 0g
        self.decoder0_g = nn.Sequential(
            layers.ConvBNReLU(
                64, 64, 3, padding=1),
            layers.ConvBNReLU(
                64, 64, 3, padding=1),
            nn.Conv2D(
                64, 3, 3, padding=1))
        ##########################
        ### Decoder part - FOCUS
        ##########################
        self.bridge_block = nn.Sequential(
            layers.ConvBNReLU(
                self.backbone_channels[-1], 512, 3, dilation=2, padding=2),
            layers.ConvBNReLU(
                512, 512, 3, dilation=2, padding=2),
            layers.ConvBNReLU(
                512, 512, 3, dilation=2, padding=2))
        # stage 5f
        self.decoder5_f = nn.Sequential(
            layers.ConvBNReLU(
                512 + self.backbone_channels[-1], 512, 3, padding=1),
            layers.ConvBNReLU(
                512, 512, 3, padding=2, dilation=2),
            layers.ConvBNReLU(
                512, 256, 3, padding=2, dilation=2),
            nn.Upsample(
                scale_factor=2, mode='bilinear', align_corners=False))
        # stage 4f
        self.decoder4_f = nn.Sequential(
            layers.ConvBNReLU(
                256 + self.backbone_channels[-2], 256, 3, padding=1),
            layers.ConvBNReLU(
                256, 256, 3, padding=1),
            layers.ConvBNReLU(
                256, 128, 3, padding=1),
            nn.Upsample(
                scale_factor=2, mode='bilinear', align_corners=False))
        # stage 3f
        self.decoder3_f = nn.Sequential(
            layers.ConvBNReLU(
                128 + self.backbone_channels[-3], 128, 3, padding=1),
            layers.ConvBNReLU(
                128, 128, 3, padding=1),
            layers.ConvBNReLU(
                128, 64, 3, padding=1),
            nn.Upsample(
                scale_factor=2, mode='bilinear', align_corners=False))
        # stage 2f
        self.decoder2_f = nn.Sequential(
            layers.ConvBNReLU(
                64 + self.backbone_channels[-4], 128, 3, padding=1),
            layers.ConvBNReLU(
                128, 128, 3, padding=1),
            layers.ConvBNReLU(
                128, 64, 3, padding=1),
            nn.Upsample(
                scale_factor=2, mode='bilinear', align_corners=False))
        # stage 1f
        self.decoder1_f = nn.Sequential(
            layers.ConvBNReLU(
                64 + self.backbone_channels[-5], 64, 3, padding=1),
            layers.ConvBNReLU(
                64, 64, 3, padding=1),
            layers.ConvBNReLU(
                64, 64, 3, padding=1),
            nn.Upsample(
                scale_factor=2, mode='bilinear', align_corners=False))
        # stage 0f
        self.decoder0_f = nn.Sequential(
            layers.ConvBNReLU(
                64, 64, 3, padding=1),
            layers.ConvBNReLU(
                64, 64, 3, padding=1),
            nn.Conv2D(
                64, 1 + 1 + 32, 3, padding=1))
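        # Note (added): decoder0_f emits 1 + 1 + 32 = 34 channels; forward()
        # splits them into one focus-alpha channel, one error channel, and
        # 32 hidden channels consumed by the Refiner.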
        self.init_weight()

    def forward(self, data):
        src = data['img']
        src_h, src_w = paddle.shape(src)[2:]
        if self.if_refine:
            # This check is not needed when exporting.
            if isinstance(src_h, paddle.Tensor):
                if (src_h % 4 != 0) or (src_w % 4 != 0):
                    raise ValueError(
                        'The input image must have width and height that are divisible by 4'
                    )

        # Downsample src for backbone
        src_sm = F.interpolate(
            src,
            scale_factor=self.backbone_scale,
            mode='bilinear',
            align_corners=False)
        # Base
        fea_list = self.backbone(src_sm)

        ##########################
        ### Decoder part - GLANCE
        ##########################
        # Shape comments below are relative to src_sm.
        # psp: N, 512, H/32, W/32
        psp = self.psp_module(fea_list[-1])
        # d5_g: N, 256, H/16, W/16
        d5_g = self.decoder5_g(paddle.concat((psp, fea_list[-1]), 1))
        # d4_g: N, 128, H/8, W/8
        d4_g = self.decoder4_g(paddle.concat((self.psp4(psp), d5_g), 1))
        # d3_g: N, 64, H/4, W/4
        d3_g = self.decoder3_g(paddle.concat((self.psp3(psp), d4_g), 1))
        # d2_g: N, 64, H/2, W/2
        d2_g = self.decoder2_g(paddle.concat((self.psp2(psp), d3_g), 1))
        # d1_g: N, 64, H, W
        d1_g = self.decoder1_g(paddle.concat((self.psp1(psp), d2_g), 1))
        # d0_g: N, 3, H, W
        d0_g = self.decoder0_g(d1_g)
        # The 1st channel is foreground. The 2nd is the transition region.
        # The 3rd is background.
        glance_sigmoid = F.softmax(d0_g, axis=1)
        ##########################
        ### Decoder part - FOCUS
        ##########################
        # bb: N, 512, H/32, W/32
        bb = self.bridge_block(fea_list[-1])
        # d5_f: N, 256, H/16, W/16
        d5_f = self.decoder5_f(paddle.concat((bb, fea_list[-1]), 1))
        # d4_f: N, 128, H/8, W/8
        d4_f = self.decoder4_f(paddle.concat((d5_f, fea_list[-2]), 1))
        # d3_f: N, 64, H/4, W/4
        d3_f = self.decoder3_f(paddle.concat((d4_f, fea_list[-3]), 1))
        # d2_f: N, 64, H/2, W/2
        d2_f = self.decoder2_f(paddle.concat((d3_f, fea_list[-4]), 1))
        # d1_f: N, 64, H, W
        d1_f = self.decoder1_f(paddle.concat((d2_f, fea_list[-5]), 1))
        # d0_f: N, 34, H, W (1 alpha + 1 error + 32 hidden channels)
        d0_f = self.decoder0_f(d1_f)
        focus_sigmoid = F.sigmoid(d0_f[:, 0:1, :, :])

        pha_sm = self.fusion(glance_sigmoid, focus_sigmoid)
        err_sm = d0_f[:, 1:2, :, :]
        err_sm = paddle.clip(err_sm, 0., 1.)
        hid_sm = F.relu(d0_f[:, 2:, :, :])
        # Refiner
        if self.if_refine:
            pha = self.refiner(
                src=src, pha=pha_sm, err=err_sm, hid=hid_sm, tri=glance_sigmoid)
            # Clamp outputs
            pha = paddle.clip(pha, 0., 1.)

        if self.training:
            logit_dict = {
                'glance': glance_sigmoid,
                'focus': focus_sigmoid,
                'fusion': pha_sm,
                'error': err_sm
            }
            if self.if_refine:
                logit_dict['refine'] = pha
            loss_dict = self.loss(logit_dict, data)
            return logit_dict, loss_dict
        else:
            return pha if self.if_refine else pha_sm

    def loss(self, logit_dict, label_dict, loss_func_dict=None):
        if loss_func_dict is None:
            if self.loss_func_dict is None:
                self.loss_func_dict = defaultdict(list)
                self.loss_func_dict['glance'].append(nn.NLLLoss())
                self.loss_func_dict['focus'].append(MRSD())
                self.loss_func_dict['cm'].append(MRSD())
                self.loss_func_dict['err'].append(paddleseg.models.MSELoss())
                self.loss_func_dict['refine'].append(paddleseg.models.L1Loss())
        else:
            self.loss_func_dict = loss_func_dict

        loss = {}

        # glance loss computation
        # get glance label
        glance_label = F.interpolate(
            label_dict['trimap'],
            logit_dict['glance'].shape[2:],
            mode='nearest',
            align_corners=False)
        glance_label_trans = (glance_label == 128).astype('int64')
        glance_label_bg = (glance_label == 0).astype('int64')
        glance_label = glance_label_trans + glance_label_bg * 2
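        # Class-id mapping (illustrative note, not in the original file):
        # trimap value 255 (foreground) maps to class 0, 128 (transition)
        # to class 1, and 0 (background) to class 2, e.g. a trimap row
        # [0, 128, 255] becomes the label row [2, 1, 0].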
        loss_glance = self.loss_func_dict['glance'][0](
            paddle.log(logit_dict['glance'] + 1e-6), glance_label.squeeze(1))
        loss['glance'] = loss_glance

        # focus loss computation
        focus_label = F.interpolate(
            label_dict['alpha'],
            logit_dict['focus'].shape[2:],
            mode='bilinear',
            align_corners=False)
        loss_focus = self.loss_func_dict['focus'][0](
            logit_dict['focus'], focus_label, glance_label_trans)
        loss['focus'] = loss_focus

        # collaborative matting (fusion) loss
        loss_cm_func = self.loss_func_dict['cm']
        loss_cm = loss_cm_func[0](logit_dict['fusion'], focus_label)
        loss['cm'] = loss_cm

        # error loss
        err = F.interpolate(
            logit_dict['error'],
            label_dict['alpha'].shape[2:],
            mode='bilinear',
            align_corners=False)
        err_label = (F.interpolate(
            logit_dict['fusion'],
            label_dict['alpha'].shape[2:],
            mode='bilinear',
            align_corners=False) - label_dict['alpha']).abs()
        loss_err = self.loss_func_dict['err'][0](err, err_label)
        loss['err'] = loss_err

        loss_all = 0.25 * loss_glance + 0.25 * loss_focus + 0.25 * loss_cm \
            + loss_err

        # refine loss
        if self.if_refine:
            loss_refine = self.loss_func_dict['refine'][0](
                logit_dict['refine'], label_dict['alpha'])
            loss['refine'] = loss_refine
            loss_all = loss_all + loss_refine

        loss['all'] = loss_all
        return loss

    def fusion(self, glance_sigmoid, focus_sigmoid):
        # glance_sigmoid: [N, 3, H, W].
        # In index, 0 is foreground, 1 is transition, 2 is background.
        # After fusion, the foreground is 1, the background is 0, and the
        # transition region is in (0, 1).
        index = paddle.argmax(glance_sigmoid, axis=1, keepdim=True)
        transition_mask = (index == 1).astype('float32')
        fg = (index == 0).astype('float32')
        fusion_sigmoid = focus_sigmoid * transition_mask + fg
        return fusion_sigmoid
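
    # A worked example (illustrative, not from the original file): for a
    # pixel whose glance argmax is 1 (transition) with focus_sigmoid = 0.3,
    # the fused alpha is 0.3 * 1 + 0 = 0.3; for a foreground pixel
    # (argmax 0) it is focus * 0 + 1 = 1; for a background pixel (argmax 2)
    # both terms vanish and the alpha is 0.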

    def init_weight(self):
        if self.pretrained is not None:
            utils.load_entire_model(self, self.pretrained)

class Refiner(nn.Layer):
    '''
    Refiner refines the coarse output to full resolution.

    Args:
        kernel_size: The convolution kernel_size. Options: [1, 3]. Default: 3.
    '''

    def __init__(self, kernel_size=3):
        super().__init__()
        if kernel_size not in [1, 3]:
            raise ValueError("kernel_size must be in [1, 3]")
        self.kernel_size = kernel_size
        channels = [32, 24, 16, 12, 1]
        self.conv1 = layers.ConvBNReLU(
            channels[0] + 4 + 3,
            channels[1],
            kernel_size,
            padding=0,
            bias_attr=False)
        self.conv2 = layers.ConvBNReLU(
            channels[1], channels[2], kernel_size, padding=0, bias_attr=False)
        self.conv3 = layers.ConvBNReLU(
            channels[2] + 3,
            channels[3],
            kernel_size,
            padding=0,
            bias_attr=False)
        self.conv4 = nn.Conv2D(
            channels[3], channels[4], kernel_size, padding=0, bias_attr=True)

    def forward(self, src, pha, err, hid, tri):
        '''
        Args:
            src: (B, 3, H, W) full resolution source image.
            pha: (B, 1, Hc, Wc) coarse alpha prediction.
            err: (B, 1, Hc, Wc) coarse error prediction (currently unused
                here, but kept for interface compatibility).
            hid: (B, 32, Hc, Wc) coarse hidden encoding.
            tri: (B, 3, Hc, Wc) coarse glance prediction
                (foreground/transition/background).
        '''
        h_full, w_full = paddle.shape(src)[2:]
        h_half, w_half = h_full // 2, w_full // 2
        x = paddle.concat([hid, pha, tri], axis=1)
        x = F.interpolate(
            x,
            paddle.concat((h_half, w_half)),
            mode='bilinear',
            align_corners=False)
        y = F.interpolate(
            src,
            paddle.concat((h_half, w_half)),
            mode='bilinear',
            align_corners=False)
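        # Padding arithmetic (a reading of the code, not from the original
        # file): with kernel_size == 3, every conv here is unpadded and so
        # shrinks the feature map by 2 in each dimension. Padding the
        # half-size input by 3 per side leaves +2 after conv1/conv2; the
        # upsample to (H + 4, W + 4) then lets conv3/conv4 land exactly on
        # (H, W), so no resolution is lost despite the valid convolutions.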
        if self.kernel_size == 3:
            x = F.pad(x, [3, 3, 3, 3])
            y = F.pad(y, [3, 3, 3, 3])

        x = self.conv1(paddle.concat([x, y], axis=1))
        x = self.conv2(x)

        if self.kernel_size == 3:
            x = F.interpolate(x, paddle.concat((h_full + 4, w_full + 4)))
            y = F.pad(src, [2, 2, 2, 2])
        else:
            x = F.interpolate(
                x, paddle.concat((h_full, w_full)), mode='nearest')
            y = src

        x = self.conv3(paddle.concat([x, y], axis=1))
        x = self.conv4(x)
        pha = x
        return pha
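
# A minimal smoke test (illustrative only, not part of the original file).
# It stubs the backbone with a dummy five-level feature pyramid, since
# HumanMatting only needs `feat_channels` plus features at strides
# 2/4/8/16/32; real configs plug in a paddleseg backbone here instead.
# if_refine=False keeps the sketch light; the refine path additionally
# requires the input height and width to be divisible by 4.
if __name__ == '__main__':

    class _DummyBackbone(nn.Layer):
        feat_channels = [64, 64, 128, 256, 512]

        def forward(self, x):
            feats = []
            for c in self.feat_channels:
                # Halve the spatial size at each level and broadcast one
                # input channel to the expected channel count.
                x = F.avg_pool2d(x, kernel_size=2, stride=2)
                feats.append(paddle.tile(x[:, :1], [1, c, 1, 1]))
            return feats

    model = HumanMatting(backbone=_DummyBackbone(), if_refine=False)
    model.eval()
    img = paddle.rand([1, 3, 512, 512])
    with paddle.no_grad():
        alpha = model({'img': img})
    print(alpha.shape)  # expect [1, 1, 512, 512]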