# modnet.py
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  11. from collections import defaultdict
  12. import paddle
  13. import paddle.nn as nn
  14. import paddle.nn.functional as F
  15. import numpy as np
  16. import scipy
  17. import paddleseg
  18. from paddleseg.models import layers, losses
  19. from paddleseg import utils
  20. from paddleseg.cvlibs import manager, param_init
  21. @manager.MODELS.add_component
  22. class MODNet(nn.Layer):
  23. """
  24. The MODNet implementation based on PaddlePaddle.
  25. The original article refers to
  26. Zhanghan Ke, et, al. "Is a Green Screen Really Necessary for Real-Time Portrait Matting?"
  27. (https://arxiv.org/pdf/2011.11961.pdf).
  28. Args:
  29. backbone: backbone model.
  30. hr(int, optional): The channels of high resolutions branch. Defautl: None.
  31. pretrained(str, optional): The path of pretrianed model. Defautl: None.
  32. """
  33. def __init__(self, backbone, hr_channels=32, pretrained=None):
  34. super().__init__()
  35. self.backbone = backbone
  36. self.pretrained = pretrained
  37. self.head = MODNetHead(
  38. hr_channels=hr_channels, backbone_channels=backbone.feat_channels)
  39. self.init_weight()
  40. self.blurer = GaussianBlurLayer(1, 3)
  41. self.loss_func_dict = None
  42. def forward(self, inputs):
  43. """
  44. If training, return a dict.
  45. If evaluation, return the final alpha prediction.
  46. """
  47. x = inputs['img']
  48. feat_list = self.backbone(x)
  49. y = self.head(inputs=inputs, feat_list=feat_list)
  50. if self.training:
  51. loss = self.loss(y, inputs)
  52. return y, loss
  53. else:
  54. return y
  55. def loss(self, logit_dict, label_dict, loss_func_dict=None):
  56. if loss_func_dict is None:
  57. if self.loss_func_dict is None:
  58. self.loss_func_dict = defaultdict(list)
  59. self.loss_func_dict['semantic'].append(paddleseg.models.MSELoss(
  60. ))
  61. self.loss_func_dict['detail'].append(paddleseg.models.L1Loss())
  62. self.loss_func_dict['fusion'].append(paddleseg.models.L1Loss())
  63. self.loss_func_dict['fusion'].append(paddleseg.models.L1Loss())
  64. else:
  65. self.loss_func_dict = loss_func_dict
  66. loss = {}
  67. # semantic loss
  68. semantic_gt = F.interpolate(
  69. label_dict['alpha'],
  70. scale_factor=1 / 16,
  71. mode='bilinear',
  72. align_corners=False)
  73. semantic_gt = self.blurer(semantic_gt)
  74. # semantic_gt.stop_gradient=True
  75. loss['semantic'] = self.loss_func_dict['semantic'][0](
  76. logit_dict['semantic'], semantic_gt)
  77. # detail loss
  78. trimap = label_dict['trimap']
  79. mask = (trimap == 128).astype('float32')
  80. logit_detail = logit_dict['detail'] * mask
  81. label_detail = label_dict['alpha'] * mask
  82. loss_detail = self.loss_func_dict['detail'][0](logit_detail,
  83. label_detail)
  84. loss_detail = loss_detail / (mask.mean() + 1e-6)
  85. loss['detail'] = 10 * loss_detail
  86. # fusion loss
  87. matte = logit_dict['matte']
  88. alpha = label_dict['alpha']
  89. transition_mask = label_dict['trimap'] == 128
  90. matte_boundary = paddle.where(transition_mask, matte, alpha)
  91. # l1 loss
  92. loss_fusion_l1 = self.loss_func_dict['fusion'][0](
  93. matte, alpha) + 4 * self.loss_func_dict['fusion'][0](matte_boundary,
  94. alpha)
  95. # composition loss
  96. loss_fusion_comp = self.loss_func_dict['fusion'][1](
  97. matte * label_dict['img'], alpha *
  98. label_dict['img']) + 4 * self.loss_func_dict['fusion'][1](
  99. matte_boundary * label_dict['img'], alpha * label_dict['img'])
  100. # consisten loss with semantic
  101. transition_mask = F.interpolate(
  102. label_dict['trimap'],
  103. scale_factor=1 / 16,
  104. mode='nearest',
  105. align_corners=False)
  106. transition_mask = transition_mask == 128
  107. matte_con_sem = F.interpolate(
  108. matte, scale_factor=1 / 16, mode='bilinear', align_corners=False)
  109. matte_con_sem = self.blurer(matte_con_sem)
  110. logit_semantic = logit_dict['semantic'].clone()
  111. logit_semantic.stop_gradient = True
  112. matte_con_sem = paddle.where(transition_mask, logit_semantic,
  113. matte_con_sem)
  114. if False:
  115. import cv2
  116. matte_con_sem_num = matte_con_sem.numpy()
  117. matte_con_sem_num = matte_con_sem_num[0].squeeze()
  118. matte_con_sem_num = (matte_con_sem_num * 255).astype('uint8')
  119. semantic = logit_dict['semantic'].numpy()
  120. semantic = semantic[0].squeeze()
  121. semantic = (semantic * 255).astype('uint8')
  122. transition_mask = transition_mask.astype('uint8')
  123. transition_mask = transition_mask.numpy()
  124. transition_mask = (transition_mask[0].squeeze()) * 255
  125. cv2.imwrite('matte_con.png', matte_con_sem_num)
  126. cv2.imwrite('semantic.png', semantic)
  127. cv2.imwrite('transition.png', transition_mask)
  128. mse_loss = paddleseg.models.MSELoss()
  129. loss_fusion_con_sem = mse_loss(matte_con_sem, logit_dict['semantic'])
  130. loss_fusion = loss_fusion_l1 + loss_fusion_comp + loss_fusion_con_sem
  131. loss['fusion'] = loss_fusion
  132. loss['fusion_l1'] = loss_fusion_l1
  133. loss['fusion_comp'] = loss_fusion_comp
  134. loss['fusion_con_sem'] = loss_fusion_con_sem
  135. loss['all'] = loss['semantic'] + loss['detail'] + loss['fusion']
  136. return loss
  137. def init_weight(self):
  138. if self.pretrained is not None:
  139. utils.load_entire_model(self, self.pretrained)
  140. class MODNetHead(nn.Layer):
  141. def __init__(self, hr_channels, backbone_channels):
  142. super().__init__()
  143. self.lr_branch = LRBranch(backbone_channels)
  144. self.hr_branch = HRBranch(hr_channels, backbone_channels)
  145. self.f_branch = FusionBranch(hr_channels, backbone_channels)
  146. self.init_weight()
  147. def forward(self, inputs, feat_list):
  148. pred_semantic, lr8x, [enc2x, enc4x] = self.lr_branch(feat_list)
  149. pred_detail, hr2x = self.hr_branch(inputs['img'], enc2x, enc4x, lr8x)
  150. pred_matte = self.f_branch(inputs['img'], lr8x, hr2x)
  151. if self.training:
  152. logit_dict = {
  153. 'semantic': pred_semantic,
  154. 'detail': pred_detail,
  155. 'matte': pred_matte
  156. }
  157. return logit_dict
  158. else:
  159. return pred_matte
  160. def init_weight(self):
  161. for layer in self.sublayers():
  162. if isinstance(layer, nn.Conv2D):
  163. param_init.kaiming_uniform(layer.weight)
  164. class FusionBranch(nn.Layer):
  165. def __init__(self, hr_channels, enc_channels):
  166. super().__init__()
  167. self.conv_lr4x = Conv2dIBNormRelu(
  168. enc_channels[2], hr_channels, 5, stride=1, padding=2)
  169. self.conv_f2x = Conv2dIBNormRelu(
  170. 2 * hr_channels, hr_channels, 3, stride=1, padding=1)
  171. self.conv_f = nn.Sequential(
  172. Conv2dIBNormRelu(
  173. hr_channels + 3, int(hr_channels / 2), 3, stride=1, padding=1),
  174. Conv2dIBNormRelu(
  175. int(hr_channels / 2),
  176. 1,
  177. 1,
  178. stride=1,
  179. padding=0,
  180. with_ibn=False,
  181. with_relu=False))
  182. def forward(self, img, lr8x, hr2x):
  183. lr4x = F.interpolate(
  184. lr8x, scale_factor=2, mode='bilinear', align_corners=False)
  185. lr4x = self.conv_lr4x(lr4x)
  186. lr2x = F.interpolate(
  187. lr4x, scale_factor=2, mode='bilinear', align_corners=False)
  188. f2x = self.conv_f2x(paddle.concat((lr2x, hr2x), axis=1))
  189. f = F.interpolate(
  190. f2x, scale_factor=2, mode='bilinear', align_corners=False)
  191. f = self.conv_f(paddle.concat((f, img), axis=1))
  192. pred_matte = F.sigmoid(f)
  193. return pred_matte
  194. class HRBranch(nn.Layer):
  195. """
  196. High Resolution Branch of MODNet
  197. """
  198. def __init__(self, hr_channels, enc_channels):
  199. super().__init__()
  200. self.tohr_enc2x = Conv2dIBNormRelu(
  201. enc_channels[0], hr_channels, 1, stride=1, padding=0)
  202. self.conv_enc2x = Conv2dIBNormRelu(
  203. hr_channels + 3, hr_channels, 3, stride=2, padding=1)
  204. self.tohr_enc4x = Conv2dIBNormRelu(
  205. enc_channels[1], hr_channels, 1, stride=1, padding=0)
  206. self.conv_enc4x = Conv2dIBNormRelu(
  207. 2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1)
  208. self.conv_hr4x = nn.Sequential(
  209. Conv2dIBNormRelu(
  210. 2 * hr_channels + enc_channels[2] + 3,
  211. 2 * hr_channels,
  212. 3,
  213. stride=1,
  214. padding=1),
  215. Conv2dIBNormRelu(
  216. 2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1),
  217. Conv2dIBNormRelu(
  218. 2 * hr_channels, hr_channels, 3, stride=1, padding=1))
  219. self.conv_hr2x = nn.Sequential(
  220. Conv2dIBNormRelu(
  221. 2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1),
  222. Conv2dIBNormRelu(
  223. 2 * hr_channels, hr_channels, 3, stride=1, padding=1),
  224. Conv2dIBNormRelu(
  225. hr_channels, hr_channels, 3, stride=1, padding=1),
  226. Conv2dIBNormRelu(
  227. hr_channels, hr_channels, 3, stride=1, padding=1))
  228. self.conv_hr = nn.Sequential(
  229. Conv2dIBNormRelu(
  230. hr_channels + 3, hr_channels, 3, stride=1, padding=1),
  231. Conv2dIBNormRelu(
  232. hr_channels,
  233. 1,
  234. 1,
  235. stride=1,
  236. padding=0,
  237. with_ibn=False,
  238. with_relu=False))
  239. def forward(self, img, enc2x, enc4x, lr8x):
  240. img2x = F.interpolate(
  241. img, scale_factor=1 / 2, mode='bilinear', align_corners=False)
  242. img4x = F.interpolate(
  243. img, scale_factor=1 / 4, mode='bilinear', align_corners=False)
  244. enc2x = self.tohr_enc2x(enc2x)
  245. hr4x = self.conv_enc2x(paddle.concat((img2x, enc2x), axis=1))
  246. enc4x = self.tohr_enc4x(enc4x)
  247. hr4x = self.conv_enc4x(paddle.concat((hr4x, enc4x), axis=1))
  248. lr4x = F.interpolate(
  249. lr8x, scale_factor=2, mode='bilinear', align_corners=False)
  250. hr4x = self.conv_hr4x(paddle.concat((hr4x, lr4x, img4x), axis=1))
  251. hr2x = F.interpolate(
  252. hr4x, scale_factor=2, mode='bilinear', align_corners=False)
  253. hr2x = self.conv_hr2x(paddle.concat((hr2x, enc2x), axis=1))
  254. pred_detail = None
  255. if self.training:
  256. hr = F.interpolate(
  257. hr2x, scale_factor=2, mode='bilinear', align_corners=False)
  258. hr = self.conv_hr(paddle.concat((hr, img), axis=1))
  259. pred_detail = F.sigmoid(hr)
  260. return pred_detail, hr2x
  261. class LRBranch(nn.Layer):
  262. def __init__(self, backbone_channels):
  263. super().__init__()
  264. self.se_block = SEBlock(backbone_channels[4], reduction=4)
  265. self.conv_lr16x = Conv2dIBNormRelu(
  266. backbone_channels[4], backbone_channels[3], 5, stride=1, padding=2)
  267. self.conv_lr8x = Conv2dIBNormRelu(
  268. backbone_channels[3], backbone_channels[2], 5, stride=1, padding=2)
  269. self.conv_lr = Conv2dIBNormRelu(
  270. backbone_channels[2],
  271. 1,
  272. 3,
  273. stride=2,
  274. padding=1,
  275. with_ibn=False,
  276. with_relu=False)
  277. def forward(self, feat_list):
  278. enc2x, enc4x, enc32x = feat_list[0], feat_list[1], feat_list[4]
  279. enc32x = self.se_block(enc32x)
  280. lr16x = F.interpolate(
  281. enc32x, scale_factor=2, mode='bilinear', align_corners=False)
  282. lr16x = self.conv_lr16x(lr16x)
  283. lr8x = F.interpolate(
  284. lr16x, scale_factor=2, mode='bilinear', align_corners=False)
  285. lr8x = self.conv_lr8x(lr8x)
  286. pred_semantic = None
  287. if self.training:
  288. lr = self.conv_lr(lr8x)
  289. pred_semantic = F.sigmoid(lr)
  290. return pred_semantic, lr8x, [enc2x, enc4x]
  291. class IBNorm(nn.Layer):
  292. """
  293. Combine Instance Norm and Batch Norm into One Layer
  294. """
  295. def __init__(self, in_channels):
  296. super().__init__()
  297. self.bnorm_channels = in_channels // 2
  298. self.inorm_channels = in_channels - self.bnorm_channels
  299. self.bnorm = nn.BatchNorm2D(self.bnorm_channels)
  300. self.inorm = nn.InstanceNorm2D(self.inorm_channels)
  301. def forward(self, x):
  302. bn_x = self.bnorm(x[:, :self.bnorm_channels, :, :])
  303. in_x = self.inorm(x[:, self.bnorm_channels:, :, :])
  304. return paddle.concat((bn_x, in_x), 1)
  305. class Conv2dIBNormRelu(nn.Layer):
  306. """
  307. Convolution + IBNorm + Relu
  308. """
  309. def __init__(self,
  310. in_channels,
  311. out_channels,
  312. kernel_size,
  313. stride=1,
  314. padding=0,
  315. dilation=1,
  316. groups=1,
  317. bias_attr=None,
  318. with_ibn=True,
  319. with_relu=True):
  320. super().__init__()
  321. layers = [
  322. nn.Conv2D(
  323. in_channels,
  324. out_channels,
  325. kernel_size,
  326. stride=stride,
  327. padding=padding,
  328. dilation=dilation,
  329. groups=groups,
  330. bias_attr=bias_attr)
  331. ]
  332. if with_ibn:
  333. layers.append(IBNorm(out_channels))
  334. if with_relu:
  335. layers.append(nn.ReLU())
  336. self.layers = nn.Sequential(*layers)
  337. def forward(self, x):
  338. return self.layers(x)
  339. class SEBlock(nn.Layer):
  340. """
  341. SE Block Proposed in https://arxiv.org/pdf/1709.01507.pdf
  342. """
  343. def __init__(self, num_channels, reduction=1):
  344. super().__init__()
  345. self.pool = nn.AdaptiveAvgPool2D(1)
  346. self.conv = nn.Sequential(
  347. nn.Conv2D(
  348. num_channels,
  349. int(num_channels // reduction),
  350. 1,
  351. bias_attr=False),
  352. nn.ReLU(),
  353. nn.Conv2D(
  354. int(num_channels // reduction),
  355. num_channels,
  356. 1,
  357. bias_attr=False),
  358. nn.Sigmoid())
  359. def forward(self, x):
  360. w = self.pool(x)
  361. w = self.conv(w)
  362. return w * x
  363. class GaussianBlurLayer(nn.Layer):
  364. """ Add Gaussian Blur to a 4D tensors
  365. This layer takes a 4D tensor of {N, C, H, W} as input.
  366. The Gaussian blur will be performed in given channel number (C) splitly.
  367. """
  368. def __init__(self, channels, kernel_size):
  369. """
  370. Args:
  371. channels (int): Channel for input tensor
  372. kernel_size (int): Size of the kernel used in blurring
  373. """
  374. super(GaussianBlurLayer, self).__init__()
  375. self.channels = channels
  376. self.kernel_size = kernel_size
  377. assert self.kernel_size % 2 != 0
  378. self.op = nn.Sequential(
  379. nn.Pad2D(
  380. int(self.kernel_size / 2), mode='reflect'),
  381. nn.Conv2D(
  382. channels,
  383. channels,
  384. self.kernel_size,
  385. stride=1,
  386. padding=0,
  387. bias_attr=False,
  388. groups=channels))
  389. self._init_kernel()
  390. self.op[1].weight.stop_gradient = True
  391. def forward(self, x):
  392. """
  393. Args:
  394. x (paddle.Tensor): input 4D tensor
  395. Returns:
  396. paddle.Tensor: Blurred version of the input
  397. """
  398. if not len(list(x.shape)) == 4:
  399. print('\'GaussianBlurLayer\' requires a 4D tensor as input\n')
  400. exit()
  401. elif not x.shape[1] == self.channels:
  402. print('In \'GaussianBlurLayer\', the required channel ({0}) is'
  403. 'not the same as input ({1})\n'.format(self.channels, x.shape[
  404. 1]))
  405. exit()
  406. return self.op(x)
  407. def _init_kernel(self):
  408. sigma = 0.3 * ((self.kernel_size - 1) * 0.5 - 1) + 0.8
  409. n = np.zeros((self.kernel_size, self.kernel_size))
  410. i = int(self.kernel_size / 2)
  411. n[i, i] = 1
  412. kernel = scipy.ndimage.gaussian_filter(n, sigma)
  413. kernel = kernel.astype('float32')
  414. kernel = kernel[np.newaxis, np.newaxis, :, :]
  415. paddle.assign(kernel, self.op[1].weight)