ppmattingv2.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601
  1. # copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from functools import partial
  15. from collections import defaultdict
  16. import paddle
  17. import paddle.nn as nn
  18. import paddle.nn.functional as F
  19. import paddleseg
  20. from paddleseg import utils
  21. from paddleseg.models import layers
  22. from paddleseg.cvlibs import manager
  23. from paddleseg.models.backbones.transformer_utils import Identity, DropPath
  24. from ppmatting.models.layers import MLFF
  25. from ppmatting.models.losses import MRSD, GradientLoss
  26. @manager.MODELS.add_component
  27. class PPMattingV2(nn.Layer):
  28. """
  29. The PPMattingV2 implementation based on PaddlePaddle.
  30. The original article refers to
  31. TODO Guowei Chen, et, al. "" ().
  32. Args:
  33. backbone: backobne model.
  34. pretrained(str, optional): The path of pretrianed model. Defautl: None.
  35. dpp_len_trans(int, optional): The depth of transformer block in dpp(DoublePyramidPoolModule). Default: 1.
  36. dpp_index(list, optional): The index of backone output which as the input in dpp. Default: [1, 2, 3, 4].
  37. dpp_mid_channel(int, optional): The output channels of the first pyramid pool in dpp. Default: 256.
  38. dpp_out_channel(int, optional): The output channels of dpp. Default: 512.
  39. dpp_bin_sizes(list, optional): The output size of the second pyramid pool in dpp. Default: (2, 4, 6).
  40. dpp_mlp_ratios(int, optional): The expandsion ratio of mlp in dpp. Default: 2.
  41. dpp_attn_ratio(int, optional): The expandsion ratio of attention. Default: 2.
  42. dpp_merge_type(str, optional): The merge type of the output of the second pyramid pool in dpp,
  43. which should be one of (`concat`, `add`). Default: 'concat'.
  44. mlff_merge_type(str, optional): The merge type of the multi features before output.
  45. It should be one of ('add', 'concat'). Default: 'concat'.
  46. """
  47. def __init__(self,
  48. backbone,
  49. pretrained=None,
  50. dpp_len_trans=1,
  51. dpp_index=[1, 2, 3, 4],
  52. dpp_mid_channel=256,
  53. dpp_output_channel=512,
  54. dpp_bin_sizes=(2, 4, 6),
  55. dpp_mlp_ratios=2,
  56. dpp_attn_ratio=2,
  57. dpp_merge_type='concat',
  58. mlff_merge_type='concat',
  59. decoder_channels=[128, 96, 64, 32, 32],
  60. head_channel=32):
  61. super().__init__()
  62. self.backbone = backbone
  63. self.backbone_channels = backbone.feat_channels
  64. # check
  65. assert len(backbone.feat_channels) == 5, \
  66. "Backbone should return 5 features with different scales"
  67. assert max(dpp_index) < len(backbone.feat_channels), \
  68. "The element of `dpp_index` should be less than the number of return features of backbone."
  69. # dpp module
  70. self.dpp_index = dpp_index
  71. self.dpp = DoublePyramidPoolModule(
  72. stride=2,
  73. input_channel=sum(self.backbone_channels[i]
  74. for i in self.dpp_index),
  75. mid_channel=dpp_mid_channel,
  76. output_channel=dpp_output_channel,
  77. len_trans=dpp_len_trans,
  78. bin_sizes=dpp_bin_sizes,
  79. mlp_ratios=dpp_mlp_ratios,
  80. attn_ratio=dpp_attn_ratio,
  81. merge_type=dpp_merge_type)
  82. # decoder
  83. self.mlff32x = MLFF(
  84. in_channels=[self.backbone_channels[-1], dpp_output_channel],
  85. mid_channels=[dpp_output_channel, dpp_output_channel],
  86. out_channel=decoder_channels[0],
  87. merge_type=mlff_merge_type)
  88. self.mlff16x = MLFF(
  89. in_channels=[
  90. self.backbone_channels[-2], decoder_channels[0],
  91. dpp_output_channel
  92. ],
  93. mid_channels=[
  94. decoder_channels[0], decoder_channels[0], decoder_channels[0]
  95. ],
  96. out_channel=decoder_channels[1],
  97. merge_type=mlff_merge_type)
  98. self.mlff8x = MLFF(
  99. in_channels=[
  100. self.backbone_channels[-3], decoder_channels[1],
  101. dpp_output_channel
  102. ],
  103. mid_channels=[
  104. decoder_channels[1], decoder_channels[1], decoder_channels[1]
  105. ],
  106. out_channel=decoder_channels[2],
  107. merge_type=mlff_merge_type)
  108. self.mlff4x = MLFF(
  109. in_channels=[self.backbone_channels[-4], decoder_channels[2], 3],
  110. mid_channels=[decoder_channels[2], decoder_channels[2], 3],
  111. out_channel=decoder_channels[3])
  112. self.mlff2x = MLFF(
  113. in_channels=[self.backbone_channels[-5], decoder_channels[3], 3],
  114. mid_channels=[decoder_channels[3], decoder_channels[3], 3],
  115. out_channel=decoder_channels[4])
  116. self.matting_head_mlff8x = MattingHead(
  117. in_chan=decoder_channels[2], mid_chan=32)
  118. self.matting_head_mlff2x = MattingHead(
  119. in_chan=decoder_channels[4] + 3, mid_chan=head_channel, mid_num=2)
  120. # loss
  121. self.loss_func_dict = None
  122. # pretrained
  123. self.pretrained = pretrained
  124. self.init_weight()
  125. def forward(self, inputs):
  126. img = inputs['img']
  127. input_shape = paddle.shape(img)
  128. feats_backbone = self.backbone(
  129. img) # stdc1 [2x, 4x, 8x, 16x, 32x] [32, 64, 256, 512, 1024]
  130. x = self.dpp([feats_backbone[i] for i in self.dpp_index])
  131. dpp_out = x
  132. input_32x = [feats_backbone[-1], x]
  133. x = self.mlff32x(input_32x,
  134. paddle.shape(feats_backbone[-1])[-2:]) # 32x
  135. input_16x = [feats_backbone[-2], x, dpp_out]
  136. x = self.mlff16x(input_16x,
  137. paddle.shape(feats_backbone[-2])[-2:]) # 16x
  138. input_8x = [feats_backbone[-3], x, dpp_out]
  139. x = self.mlff8x(input_8x, paddle.shape(feats_backbone[-3])[-2:]) # 8x
  140. mlff8x_output = x
  141. input_4x = [feats_backbone[-4], x]
  142. input_4x.append(
  143. F.interpolate(
  144. img, feats_backbone[-4].shape[2:], mode='area'))
  145. x = self.mlff4x(input_4x, paddle.shape(feats_backbone[-4])[-2:]) # 4x
  146. input_2x = [feats_backbone[-5], x]
  147. input_2x.append(
  148. F.interpolate(
  149. img, feats_backbone[-5].shape[2:], mode='area'))
  150. x = self.mlff2x(input_2x, paddle.shape(feats_backbone[-5])[-2:]) # 2x
  151. x = F.interpolate(
  152. x, input_shape[-2:], mode='bilinear', align_corners=False)
  153. x = paddle.concat([x, img], axis=1)
  154. alpha = self.matting_head_mlff2x(x)
  155. if self.training:
  156. logit_dict = {}
  157. logit_dict['alpha'] = alpha
  158. logit_dict['alpha_8x'] = self.matting_head_mlff8x(mlff8x_output)
  159. loss_dict = self.loss(logit_dict, inputs)
  160. return logit_dict, loss_dict
  161. else:
  162. return alpha
  163. def loss(self, logit_dict, label_dict, loss_func_dict=None):
  164. if loss_func_dict is None:
  165. if self.loss_func_dict is None:
  166. self.loss_func_dict = defaultdict(list)
  167. self.loss_func_dict['alpha'].append(MRSD())
  168. self.loss_func_dict['alpha'].append(GradientLoss())
  169. self.loss_func_dict['alpha_8x'].append(MRSD())
  170. self.loss_func_dict['alpha_8x'].append(GradientLoss())
  171. else:
  172. self.loss_func_dict = loss_func_dict
  173. loss = {}
  174. alpha_8x_label = F.interpolate(
  175. label_dict['alpha'],
  176. size=logit_dict['alpha_8x'].shape[-2:],
  177. mode='area',
  178. align_corners=False)
  179. loss['alpha_8x_mrsd'] = self.loss_func_dict['alpha_8x'][0](
  180. logit_dict['alpha_8x'], alpha_8x_label)
  181. loss['alpha_8x_grad'] = self.loss_func_dict['alpha_8x'][1](
  182. logit_dict['alpha_8x'], alpha_8x_label)
  183. loss['alpha_8x'] = loss['alpha_8x_mrsd'] + loss['alpha_8x_grad']
  184. transition_mask = label_dict['trimap'] == 128
  185. loss['alpha_mrsd'] = self.loss_func_dict['alpha'][0](
  186. logit_dict['alpha'],
  187. label_dict['alpha']) + 2 * self.loss_func_dict['alpha'][0](
  188. logit_dict['alpha'], label_dict['alpha'], transition_mask)
  189. loss['alpha_grad'] = self.loss_func_dict['alpha'][1](
  190. logit_dict['alpha'],
  191. label_dict['alpha']) + 2 * self.loss_func_dict['alpha'][1](
  192. logit_dict['alpha'], label_dict['alpha'], transition_mask)
  193. loss['alpha'] = loss['alpha_mrsd'] + loss['alpha_grad']
  194. loss['all'] = loss['alpha'] + loss['alpha_8x']
  195. return loss
  196. def init_weight(self):
  197. if self.pretrained is not None:
  198. utils.load_entire_model(self, self.pretrained)
  199. class MattingHead(nn.Layer):
  200. def __init__(self, in_chan, mid_chan, mid_num=1, out_channels=1):
  201. super().__init__()
  202. self.conv = layers.ConvBNReLU(
  203. in_chan,
  204. mid_chan,
  205. kernel_size=3,
  206. stride=1,
  207. padding=1,
  208. bias_attr=False)
  209. self.mid_conv = nn.LayerList([
  210. layers.ConvBNReLU(
  211. mid_chan,
  212. mid_chan,
  213. kernel_size=3,
  214. stride=1,
  215. padding=1,
  216. bias_attr=False) for i in range(mid_num - 1)
  217. ])
  218. self.conv_out = nn.Conv2D(
  219. mid_chan, out_channels, kernel_size=1, bias_attr=False)
  220. def forward(self, x):
  221. x = self.conv(x)
  222. for mid_conv in self.mid_conv:
  223. x = mid_conv(x)
  224. x = self.conv_out(x)
  225. x = F.sigmoid(x)
  226. return x
  227. class DoublePyramidPoolModule(nn.Layer):
  228. """
  229. Extract global information through double pyramid pool structure and attention calculation by transformer block.
  230. Args:
  231. stride(int): The stride for the inputs.
  232. input_channel(int): The total channels of input features.
  233. mid_channel(int, optional): The output channels of the first pyramid pool. Default: 256.
  234. out_channel(int, optional): The output channels. Default: 512.
  235. len_trans(int, optional): The depth of transformer block. Default: 1.
  236. bin_sizes(list, optional): The output size of the second pyramid pool. Default: (2, 4, 6).
  237. mlp_ratios(int, optional): The expandsion ratio of the mlp. Default: 2.
  238. attn_ratio(int, optional): The expandsion ratio of the attention. Default: 2.
  239. merge_type(str, optional): The merge type of the output of the second pyramid pool, which should be one of (`concat`, `add`). Default: 'concat'.
  240. align_corners(bool, optional): Whether to use `align_corners` when interpolating. Default: False.
  241. """
  242. def __init__(self,
  243. stride,
  244. input_channel,
  245. mid_channel=256,
  246. output_channel=512,
  247. len_trans=1,
  248. bin_sizes=(2, 4, 6),
  249. mlp_ratios=2,
  250. attn_ratio=2,
  251. merge_type='concat',
  252. align_corners=False):
  253. super().__init__()
  254. self.mid_channel = mid_channel
  255. self.align_corners = align_corners
  256. self.mlp_rations = mlp_ratios
  257. self.attn_ratio = attn_ratio
  258. if isinstance(len_trans, int):
  259. self.len_trans = [len_trans] * len(bin_sizes)
  260. elif isinstance(len_trans, (list, tuple)):
  261. self.len_trans = len_trans
  262. if len(len_trans) != len(bin_sizes):
  263. raise ValueError(
  264. 'If len_trans is list or tuple, the length should be same as bin_sizes'
  265. )
  266. else:
  267. raise ValueError(
  268. '`len_trans` only support int, list and tuple type')
  269. if merge_type not in ['add', 'concat']:
  270. raise ('`merge_type only support `add` or `concat`.')
  271. self.merge_type = merge_type
  272. self.pp1 = PyramidPoolAgg(stride=stride)
  273. self.conv_mid = layers.ConvBN(input_channel, mid_channel, 1)
  274. self.pp2 = nn.LayerList([
  275. self._make_stage(
  276. embdeding_channels=mid_channel, size=size, block_num=block_num)
  277. for size, block_num in zip(bin_sizes, self.len_trans)
  278. ])
  279. if self.merge_type == 'concat':
  280. in_chan = mid_channel + mid_channel * len(bin_sizes)
  281. else:
  282. in_chan = mid_channel
  283. self.conv_out = layers.ConvBNReLU(
  284. in_chan, output_channel, kernel_size=1)
  285. def _make_stage(self, embdeding_channels, size, block_num):
  286. prior = nn.AdaptiveAvgPool2D(output_size=size)
  287. if size == 1:
  288. trans = layers.ConvBNReLU(
  289. in_channels=embdeding_channels,
  290. out_channels=embdeding_channels,
  291. kernel_size=1)
  292. else:
  293. trans = BasicLayer(
  294. block_num=block_num,
  295. embedding_dim=embdeding_channels,
  296. key_dim=16,
  297. num_heads=8,
  298. mlp_ratios=self.mlp_rations,
  299. attn_ratio=self.attn_ratio,
  300. drop=0,
  301. attn_drop=0,
  302. drop_path=0,
  303. act_layer=nn.ReLU6,
  304. lr_mult=1.0)
  305. return nn.Sequential(prior, trans)
  306. def forward(self, inputs):
  307. x = self.pp1(inputs)
  308. pp2_input = self.conv_mid(x)
  309. cat_layers = []
  310. for stage in self.pp2:
  311. x = stage(pp2_input)
  312. x = F.interpolate(
  313. x,
  314. paddle.shape(pp2_input)[2:],
  315. mode='bilinear',
  316. align_corners=self.align_corners)
  317. cat_layers.append(x)
  318. cat_layers = [pp2_input] + cat_layers[::-1]
  319. if self.merge_type == 'concat':
  320. cat = paddle.concat(cat_layers, axis=1)
  321. else:
  322. cat = sum(cat_layers)
  323. out = self.conv_out(cat)
  324. return out
  325. class Conv2DBN(nn.Layer):
  326. def __init__(self,
  327. in_channels,
  328. out_channels,
  329. ks=1,
  330. stride=1,
  331. pad=0,
  332. dilation=1,
  333. groups=1,
  334. bn_weight_init=1,
  335. lr_mult=1.0):
  336. super().__init__()
  337. conv_weight_attr = paddle.ParamAttr(learning_rate=lr_mult)
  338. self.c = nn.Conv2D(
  339. in_channels=in_channels,
  340. out_channels=out_channels,
  341. kernel_size=ks,
  342. stride=stride,
  343. padding=pad,
  344. dilation=dilation,
  345. groups=groups,
  346. weight_attr=conv_weight_attr,
  347. bias_attr=False)
  348. bn_weight_attr = paddle.ParamAttr(
  349. initializer=nn.initializer.Constant(bn_weight_init),
  350. learning_rate=lr_mult)
  351. bn_bias_attr = paddle.ParamAttr(
  352. initializer=nn.initializer.Constant(0), learning_rate=lr_mult)
  353. self.bn = nn.BatchNorm2D(
  354. out_channels, weight_attr=bn_weight_attr, bias_attr=bn_bias_attr)
  355. def forward(self, inputs):
  356. out = self.c(inputs)
  357. out = self.bn(out)
  358. return out
  359. class MLP(nn.Layer):
  360. def __init__(self,
  361. in_features,
  362. hidden_features=None,
  363. out_features=None,
  364. act_layer=nn.ReLU,
  365. drop=0.,
  366. lr_mult=1.0):
  367. super().__init__()
  368. out_features = out_features or in_features
  369. hidden_features = hidden_features or in_features
  370. self.fc1 = Conv2DBN(in_features, hidden_features, lr_mult=lr_mult)
  371. param_attr = paddle.ParamAttr(learning_rate=lr_mult)
  372. self.dwconv = nn.Conv2D(
  373. hidden_features,
  374. hidden_features,
  375. 3,
  376. 1,
  377. 1,
  378. groups=hidden_features,
  379. weight_attr=param_attr,
  380. bias_attr=param_attr)
  381. self.act = act_layer()
  382. self.fc2 = Conv2DBN(hidden_features, out_features, lr_mult=lr_mult)
  383. self.drop = nn.Dropout(drop)
  384. def forward(self, x):
  385. x = self.fc1(x)
  386. x = self.dwconv(x)
  387. x = self.act(x)
  388. x = self.drop(x)
  389. x = self.fc2(x)
  390. x = self.drop(x)
  391. return x
  392. class Attention(nn.Layer):
  393. def __init__(self,
  394. dim,
  395. key_dim,
  396. num_heads,
  397. attn_ratio=4,
  398. activation=None,
  399. lr_mult=1.0):
  400. super().__init__()
  401. self.num_heads = num_heads
  402. self.scale = key_dim**-0.5
  403. self.key_dim = key_dim
  404. self.nh_kd = nh_kd = key_dim * num_heads
  405. self.d = int(attn_ratio * key_dim)
  406. self.dh = int(attn_ratio * key_dim) * num_heads
  407. self.attn_ratio = attn_ratio
  408. self.to_q = Conv2DBN(dim, nh_kd, 1, lr_mult=lr_mult)
  409. self.to_k = Conv2DBN(dim, nh_kd, 1, lr_mult=lr_mult)
  410. self.to_v = Conv2DBN(dim, self.dh, 1, lr_mult=lr_mult)
  411. self.proj = nn.Sequential(
  412. activation(),
  413. Conv2DBN(
  414. self.dh, dim, bn_weight_init=0, lr_mult=lr_mult))
  415. def forward(self, x):
  416. x_shape = paddle.shape(x)
  417. H, W = x_shape[2], x_shape[3]
  418. qq = self.to_q(x).reshape(
  419. [0, self.num_heads, self.key_dim, -1]).transpose([0, 1, 3, 2])
  420. kk = self.to_k(x).reshape([0, self.num_heads, self.key_dim, -1])
  421. vv = self.to_v(x).reshape([0, self.num_heads, self.d, -1]).transpose(
  422. [0, 1, 3, 2])
  423. attn = paddle.matmul(qq, kk)
  424. attn = F.softmax(attn, axis=-1)
  425. xx = paddle.matmul(attn, vv)
  426. xx = xx.transpose([0, 1, 3, 2]).reshape([0, self.dh, H, W])
  427. xx = self.proj(xx)
  428. return xx
  429. class Block(nn.Layer):
  430. def __init__(self,
  431. dim,
  432. key_dim,
  433. num_heads,
  434. mlp_ratios=4.,
  435. attn_ratio=2.,
  436. drop=0.,
  437. drop_path=0.,
  438. act_layer=nn.ReLU,
  439. lr_mult=1.0):
  440. super().__init__()
  441. self.dim = dim
  442. self.num_heads = num_heads
  443. self.mlp_ratios = mlp_ratios
  444. self.attn = Attention(
  445. dim,
  446. key_dim=key_dim,
  447. num_heads=num_heads,
  448. attn_ratio=attn_ratio,
  449. activation=act_layer,
  450. lr_mult=lr_mult)
  451. # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
  452. self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity()
  453. mlp_hidden_dim = int(dim * mlp_ratios)
  454. self.mlp = MLP(in_features=dim,
  455. hidden_features=mlp_hidden_dim,
  456. act_layer=act_layer,
  457. drop=drop,
  458. lr_mult=lr_mult)
  459. def forward(self, x):
  460. h = x
  461. x = self.attn(x)
  462. x = self.drop_path(x)
  463. x = h + x
  464. h = x
  465. x = self.mlp(x)
  466. x = self.drop_path(x)
  467. x = x + h
  468. return x
  469. class BasicLayer(nn.Layer):
  470. def __init__(self,
  471. block_num,
  472. embedding_dim,
  473. key_dim,
  474. num_heads,
  475. mlp_ratios=4.,
  476. attn_ratio=2.,
  477. drop=0.,
  478. attn_drop=0.,
  479. drop_path=0.,
  480. act_layer=None,
  481. lr_mult=1.0):
  482. super().__init__()
  483. self.block_num = block_num
  484. self.transformer_blocks = nn.LayerList()
  485. for i in range(self.block_num):
  486. self.transformer_blocks.append(
  487. Block(
  488. embedding_dim,
  489. key_dim=key_dim,
  490. num_heads=num_heads,
  491. mlp_ratios=mlp_ratios,
  492. attn_ratio=attn_ratio,
  493. drop=drop,
  494. drop_path=drop_path[i]
  495. if isinstance(drop_path, list) else drop_path,
  496. act_layer=act_layer,
  497. lr_mult=lr_mult))
  498. def forward(self, x):
  499. # token * N
  500. for i in range(self.block_num):
  501. x = self.transformer_blocks[i](x)
  502. return x
  503. class PyramidPoolAgg(nn.Layer):
  504. def __init__(self, stride):
  505. super().__init__()
  506. self.stride = stride
  507. self.tmp = Identity() # avoid the error of paddle.flops
  508. def forward(self, inputs):
  509. '''
  510. # The F.adaptive_avg_pool2d does not support the (H, W) be Tensor,
  511. # so exporting the inference model will raise error.
  512. _, _, H, W = inputs[-1].shape
  513. H = (H - 1) // self.stride + 1
  514. W = (W - 1) // self.stride + 1
  515. return paddle.concat(
  516. [F.adaptive_avg_pool2d(inp, (H, W)) for inp in inputs], axis=1)
  517. '''
  518. out = []
  519. ks = 2**len(inputs)
  520. stride = self.stride**len(inputs)
  521. for x in inputs:
  522. x = F.avg_pool2d(x, int(ks), int(stride))
  523. ks /= 2
  524. stride /= 2
  525. out.append(x)
  526. out = paddle.concat(out, axis=1)
  527. return out