rvm.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570
  1. # This program is about RVM implementation based on PaddlePaddle according to
  2. # https://github.com/PeterL1n/RobustVideoMatting.
  3. # Copyright (C) 2022 PaddlePaddle Authors.
  4. # This program is free software: you can redistribute it and/or modify
  5. # it under the terms of the GNU General Public License as published by
  6. # the Free Software Foundation, either version 3 of the License, or
  7. # (at your option) any later version.
  8. # This program is distributed in the hope that it will be useful,
  9. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  10. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  11. # GNU General Public License for more details.
  12. # You should have received a copy of the GNU General Public License
  13. # along with this program. If not, see <https://www.gnu.org/licenses/>.
  14. import paddle
  15. import paddle.nn as nn
  16. import paddle.nn.functional as F
  17. from paddle import Tensor
  18. import paddleseg
  19. from paddleseg import utils
  20. from paddleseg.models import layers
  21. from paddleseg.cvlibs import manager
  22. from typing import Tuple, Optional
  23. from ppmatting.models import FastGuidedFilter
  24. @manager.MODELS.add_component
  25. class RVM(nn.Layer):
  26. """
  27. The RVM implementation based on PaddlePaddle.
  28. The original article refers to
  29. Shanchuan Lin1, et, al. "Robust High-Resolution Video Matting with Temporal Guidance"
  30. (https://arxiv.org/pdf/2108.11515.pdf).
  31. Args:
  32. backbone: backbone model.
  33. lraspp_in_channels (int, optional):
  34. lraspp_out_channels (int, optional):
  35. decoder_channels (int, optional):
  36. refiner (str, optional):
  37. downsample_ratio (float, optional):
  38. pretrained(str, optional): The path of pretrianed model. Defautl: None.
  39. to_rgb(bool, optional): The fgr results change to rgb format. Default: True.
  40. """
  41. def __init__(self,
  42. backbone,
  43. lraspp_in_channels=960,
  44. lraspp_out_channels=128,
  45. decoder_channels=(80, 40, 32, 16),
  46. refiner='deep_guided_filter',
  47. downsample_ratio=1.,
  48. pretrained=None,
  49. to_rgb=True):
  50. super().__init__()
  51. self.backbone = backbone
  52. self.aspp = LRASPP(lraspp_in_channels, lraspp_out_channels)
  53. rd_fea_channels = self.backbone.feat_channels[:-1] + [
  54. lraspp_out_channels
  55. ]
  56. self.decoder = RecurrentDecoder(rd_fea_channels, decoder_channels)
  57. self.project_mat = Projection(decoder_channels[-1], 4)
  58. self.project_seg = Projection(decoder_channels[-1], 1)
  59. if refiner == 'deep_guided_filter':
  60. self.refiner = DeepGuidedFilterRefiner()
  61. else:
  62. self.refiner = FastGuidedFilterRefiner()
  63. self.downsample_ratio = downsample_ratio
  64. self.pretrained = pretrained
  65. self.to_rgb = to_rgb
  66. self.r1 = None
  67. self.r2 = None
  68. self.r3 = None
  69. self.r4 = None
  70. def forward(self,
  71. data,
  72. r1=None,
  73. r2=None,
  74. r3=None,
  75. r4=None,
  76. downsample_ratio=None,
  77. segmentation_pass=False):
  78. src = data['img']
  79. if downsample_ratio is None:
  80. downsample_ratio = self.downsample_ratio
  81. if r1 is not None and r2 is not None and r3 is not None and r4 is not None:
  82. self.r1, self.r2, self.r3, self.r4 = r1, r2, r3, r4
  83. result = self.forward_(
  84. src,
  85. r1=self.r1,
  86. r2=self.r2,
  87. r3=self.r3,
  88. r4=self.r4,
  89. downsample_ratio=downsample_ratio,
  90. segmentation_pass=segmentation_pass)
  91. if self.training:
  92. raise RuntimeError('Sorry! RVM now do not support training')
  93. else:
  94. if segmentation_pass:
  95. seg, self.r1, self.r2, self.r3, self.r4 = result
  96. return {'alpha': seg}
  97. else:
  98. fgr, pha, self.r1, self.r2, self.r3, self.r4 = result
  99. if self.to_rgb:
  100. fgr = paddle.flip(fgr, axis=-3)
  101. return {
  102. 'alpha': pha,
  103. "fg": fgr,
  104. "r1": self.r1,
  105. "r2": self.r2,
  106. "r3": self.r3,
  107. "r4": self.r4
  108. }
  109. def forward_(self,
  110. src,
  111. r1=None,
  112. r2=None,
  113. r3=None,
  114. r4=None,
  115. downsample_ratio=1.,
  116. segmentation_pass=False):
  117. if isinstance(downsample_ratio, paddle.fluid.framework.Variable):
  118. # for export
  119. src_sm = self._interpolate(src, scale_factor=downsample_ratio)
  120. elif downsample_ratio != 1:
  121. src_sm = self._interpolate(src, scale_factor=downsample_ratio)
  122. else:
  123. src_sm = src
  124. f1, f2, f3, f4 = self.backbone_forward(src_sm)
  125. f4 = self.aspp(f4)
  126. hid, *rec = self.decoder(src_sm, f1, f2, f3, f4, r1, r2, r3, r4)
  127. if not segmentation_pass:
  128. fgr_residual, pha = self.project_mat(hid).split([3, 1], axis=-3)
  129. if downsample_ratio != 1:
  130. fgr_residual, pha = self.refiner(src, src_sm, fgr_residual, pha,
  131. hid)
  132. fgr = fgr_residual + src
  133. fgr = fgr.clip(0., 1.)
  134. pha = pha.clip(0., 1.)
  135. return [fgr, pha, *rec]
  136. else:
  137. seg = self.project_seg(hid)
  138. return [seg, *rec]
  139. def reset(self):
  140. """
  141. When a video is predicted, the history memory shoulb be reset.
  142. """
  143. self.r1 = None
  144. self.r2 = None
  145. self.r3 = None
  146. self.r4 = None
  147. def backbone_forward(self, x):
  148. if x.ndim == 5:
  149. B, T = paddle.shape(x)[:2]
  150. features = self.backbone(x.flatten(0, 1))
  151. for i, f in enumerate(features):
  152. features[i] = f.reshape((B, T, *(paddle.shape(f)[1:])))
  153. else:
  154. features = self.backbone(x)
  155. return features
  156. def _interpolate(self, x: Tensor, scale_factor: float):
  157. if x.ndim == 5:
  158. B, T = paddle.shape(x)[:2]
  159. x = F.interpolate(
  160. x.flatten(0, 1),
  161. scale_factor=scale_factor,
  162. mode='bilinear',
  163. align_corners=False)
  164. *_, C, H, W = paddle.shape(x)[-3:]
  165. x = x.reshape((B, T, C, H, W))
  166. else:
  167. x = F.interpolate(
  168. x,
  169. scale_factor=scale_factor,
  170. mode='bilinear',
  171. align_corners=False)
  172. return x
  173. def init_weight(self):
  174. if self.pretrained is not None:
  175. utils.load_entire_model(self, self.pretrained)
  176. class LRASPP(nn.Layer):
  177. def __init__(self, in_channels, out_channels):
  178. super().__init__()
  179. self.aspp1 = nn.Sequential(
  180. nn.Conv2D(
  181. in_channels, out_channels, 1, bias_attr=False),
  182. nn.BatchNorm2D(out_channels),
  183. nn.ReLU())
  184. self.aspp2 = nn.Sequential(
  185. nn.AdaptiveAvgPool2D(1),
  186. nn.Conv2D(
  187. in_channels, out_channels, 1, bias_attr=False),
  188. nn.Sigmoid())
  189. def forward_single_frame(self, x):
  190. return self.aspp1(x) * self.aspp2(x)
  191. def forward_time_series(self, x):
  192. B, T = x.shape[:2]
  193. x = self.forward_single_frame(x.flatten(0, 1))
  194. x = x.reshape((B, T, *(paddle.shape(x)[1:])))
  195. return x
  196. def forward(self, x):
  197. if x.ndim == 5:
  198. return self.forward_time_series(x)
  199. else:
  200. return self.forward_single_frame(x)
  201. class RecurrentDecoder(nn.Layer):
  202. def __init__(self, feature_channels, decoder_channels):
  203. super().__init__()
  204. self.avgpool = AvgPool()
  205. self.decode4 = BottleneckBlock(feature_channels[3])
  206. self.decode3 = UpsamplingBlock(feature_channels[3], feature_channels[2],
  207. 3, decoder_channels[0])
  208. self.decode2 = UpsamplingBlock(decoder_channels[0], feature_channels[1],
  209. 3, decoder_channels[1])
  210. self.decode1 = UpsamplingBlock(decoder_channels[1], feature_channels[0],
  211. 3, decoder_channels[2])
  212. self.decode0 = OutputBlock(decoder_channels[2], 3, decoder_channels[3])
  213. def forward(self,
  214. s0: Tensor,
  215. f1: Tensor,
  216. f2: Tensor,
  217. f3: Tensor,
  218. f4: Tensor,
  219. r1: Optional[Tensor],
  220. r2: Optional[Tensor],
  221. r3: Optional[Tensor],
  222. r4: Optional[Tensor]):
  223. s1, s2, s3 = self.avgpool(s0)
  224. x4, r4 = self.decode4(f4, r4)
  225. x3, r3 = self.decode3(x4, f3, s3, r3)
  226. x2, r2 = self.decode2(x3, f2, s2, r2)
  227. x1, r1 = self.decode1(x2, f1, s1, r1)
  228. x0 = self.decode0(x1, s0)
  229. return x0, r1, r2, r3, r4
  230. class AvgPool(nn.Layer):
  231. def __init__(self):
  232. super().__init__()
  233. self.avgpool = nn.AvgPool2D(2, 2, ceil_mode=True)
  234. def forward_single_frame(self, s0):
  235. s1 = self.avgpool(s0)
  236. s2 = self.avgpool(s1)
  237. s3 = self.avgpool(s2)
  238. return s1, s2, s3
  239. def forward_time_series(self, s0):
  240. B, T = paddle.shape(s0)[:2]
  241. s0 = s0.flatten(0, 1)
  242. s1, s2, s3 = self.forward_single_frame(s0)
  243. s1 = s1.reshape((B, T, *(paddle.shape(s1)[1:])))
  244. s2 = s2.reshape((B, T, *(paddle.shape(s2)[1:])))
  245. s3 = s3.reshape((B, T, *(paddle.shape(s3)[1:])))
  246. return s1, s2, s3
  247. def forward(self, s0):
  248. if s0.ndim == 5:
  249. return self.forward_time_series(s0)
  250. else:
  251. return self.forward_single_frame(s0)
  252. class BottleneckBlock(nn.Layer):
  253. def __init__(self, channels):
  254. super().__init__()
  255. self.channels = channels
  256. self.gru = ConvGRU(channels // 2)
  257. def forward(self, x, r=None):
  258. a, b = x.split(2, axis=-3)
  259. b, r = self.gru(b, r)
  260. x = paddle.concat([a, b], axis=-3)
  261. return x, r
  262. class UpsamplingBlock(nn.Layer):
  263. def __init__(self, in_channels, skip_channels, src_channels, out_channels):
  264. super().__init__()
  265. self.out_channels = out_channels
  266. self.upsample = nn.Upsample(
  267. scale_factor=2, mode='bilinear', align_corners=False)
  268. self.conv = nn.Sequential(
  269. nn.Conv2D(
  270. in_channels + skip_channels + src_channels,
  271. out_channels,
  272. 3,
  273. 1,
  274. 1,
  275. bias_attr=False),
  276. nn.BatchNorm2D(out_channels),
  277. nn.ReLU(), )
  278. self.gru = ConvGRU(out_channels // 2)
  279. def forward_single_frame(self, x, f, s, r: Optional[Tensor]):
  280. x = self.upsample(x)
  281. x = x[:, :, :paddle.shape(s)[2], :paddle.shape(s)[3]]
  282. x = paddle.concat([x, f, s], axis=1)
  283. x = self.conv(x)
  284. a, b = x.split(2, axis=1)
  285. b, r = self.gru(b, r)
  286. x = paddle.concat([a, b], axis=1)
  287. return x, r
  288. def forward_time_series(self, x, f, s, r: Optional[Tensor]):
  289. B, T, _, H, W = s.shape
  290. x = x.flatten(0, 1)
  291. f = f.flatten(0, 1)
  292. s = s.flatten(0, 1)
  293. x = self.upsample(x)
  294. x = x[:, :, :H, :W]
  295. x = paddle.concat([x, f, s], axis=1)
  296. x = self.conv(x)
  297. _, c, h, w = paddle.shape(x)
  298. x = x.reshape((B, T, c, h, w))
  299. a, b = x.split(2, axis=2)
  300. b, r = self.gru(b, r)
  301. x = paddle.concat([a, b], axis=2)
  302. return x, r
  303. def forward(self, x, f, s, r: Optional[Tensor]):
  304. if x.ndim == 5:
  305. return self.forward_time_series(x, f, s, r)
  306. else:
  307. return self.forward_single_frame(x, f, s, r)
  308. class OutputBlock(nn.Layer):
  309. def __init__(self, in_channels, src_channels, out_channels):
  310. super().__init__()
  311. self.upsample = nn.Upsample(
  312. scale_factor=2, mode='bilinear', align_corners=False)
  313. self.conv = nn.Sequential(
  314. nn.Conv2D(
  315. in_channels + src_channels,
  316. out_channels,
  317. 3,
  318. 1,
  319. 1,
  320. bias_attr=False),
  321. nn.BatchNorm2D(out_channels),
  322. nn.ReLU(),
  323. nn.Conv2D(
  324. out_channels, out_channels, 3, 1, 1, bias_attr=False),
  325. nn.BatchNorm2D(out_channels),
  326. nn.ReLU(), )
  327. def forward_single_frame(self, x, s):
  328. _, _, H, W = paddle.shape(s)
  329. x = self.upsample(x)
  330. x = x[:, :, :H, :W]
  331. x = paddle.concat([x, s], axis=1)
  332. x = self.conv(x)
  333. return x
  334. def forward_time_series(self, x, s):
  335. B, T, C, H, W = paddle.shape(s)
  336. x = x.flatten(0, 1)
  337. s = s.flatten(0, 1)
  338. x = self.upsample(x)
  339. x = x[:, :, :H, :W]
  340. x = paddle.concat([x, s], axis=1)
  341. x = self.conv(x)
  342. x = paddle.reshape(x, (B, T, paddle.shape(x)[1], H, W))
  343. return x
  344. def forward(self, x, s):
  345. if x.ndim == 5:
  346. return self.forward_time_series(x, s)
  347. else:
  348. return self.forward_single_frame(x, s)
  349. class ConvGRU(nn.Layer):
  350. def __init__(self, channels, kernel_size=3, padding=1):
  351. super().__init__()
  352. self.channels = channels
  353. self.ih = nn.Sequential(
  354. nn.Conv2D(
  355. channels * 2, channels * 2, kernel_size, padding=padding),
  356. nn.Sigmoid())
  357. self.hh = nn.Sequential(
  358. nn.Conv2D(
  359. channels * 2, channels, kernel_size, padding=padding),
  360. nn.Tanh())
  361. def forward_single_frame(self, x, h):
  362. r, z = self.ih(paddle.concat([x, h], axis=1)).split(2, axis=1)
  363. c = self.hh(paddle.concat([x, r * h], axis=1))
  364. h = (1 - z) * h + z * c
  365. return h, h
  366. def forward_time_series(self, x, h):
  367. o = []
  368. for xt in x.unbind(axis=1):
  369. ot, h = self.forward_single_frame(xt, h)
  370. o.append(ot)
  371. o = paddle.stack(o, axis=1)
  372. return o, h
  373. def forward(self, x, h=None):
  374. if h is None:
  375. h = paddle.zeros(
  376. (paddle.shape(x)[0], paddle.shape(x)[-3], paddle.shape(x)[-2],
  377. paddle.shape(x)[-1]),
  378. dtype=x.dtype)
  379. if x.ndim == 5:
  380. return self.forward_time_series(x, h)
  381. else:
  382. return self.forward_single_frame(x, h)
  383. class Projection(nn.Layer):
  384. def __init__(self, in_channels, out_channels):
  385. super().__init__()
  386. self.conv = nn.Conv2D(in_channels, out_channels, 1)
  387. def forward_single_frame(self, x):
  388. return self.conv(x)
  389. def forward_time_series(self, x):
  390. B, T = paddle.shape(x)[:2]
  391. x = self.conv(x.flatten(0, 1))
  392. _, C, H, W = paddle.shape(x)
  393. x = x.reshape((B, T, C, H, W))
  394. return x
  395. def forward(self, x):
  396. if x.ndim == 5:
  397. return self.forward_time_series(x)
  398. else:
  399. return self.forward_single_frame(x)
  400. class FastGuidedFilterRefiner(nn.Layer):
  401. def __init__(self, *args, **kwargs):
  402. super().__init__()
  403. self.guilded_filter = FastGuidedFilter(1)
  404. def forward_single_frame(self, fine_src, base_src, base_fgr, base_pha):
  405. fine_src_gray = fine_src.mean(1, keepdim=True)
  406. base_src_gray = base_src.mean(1, keepdim=True)
  407. fgr, pha = self.guilded_filter(
  408. paddle.concat(
  409. [base_src, base_src_gray], axis=1),
  410. paddle.concat(
  411. [base_fgr, base_pha], axis=1),
  412. paddle.concat(
  413. [fine_src, fine_src_gray], axis=1)).split(
  414. [3, 1], axis=1)
  415. return fgr, pha
  416. def forward_time_series(self, fine_src, base_src, base_fgr, base_pha):
  417. B, T = fine_src.shape[:2]
  418. fgr, pha = self.forward_single_frame(
  419. fine_src.flatten(0, 1),
  420. base_src.flatten(0, 1),
  421. base_fgr.flatten(0, 1), base_pha.flatten(0, 1))
  422. *_, C, H, W = paddle.shape(fgr)
  423. fgr = fgr.reshape((B, T, C, H, W))
  424. pha = pha.reshape((B, T, 1, H, W))
  425. return fgr, pha
  426. def forward(self, fine_src, base_src, base_fgr, base_pha, *args, **kwargs):
  427. if fine_src.ndim == 5:
  428. return self.forward_time_series(fine_src, base_src, base_fgr,
  429. base_pha)
  430. else:
  431. return self.forward_single_frame(fine_src, base_src, base_fgr,
  432. base_pha)
  433. class DeepGuidedFilterRefiner(nn.Layer):
  434. def __init__(self, hid_channels=16):
  435. super().__init__()
  436. self.box_filter = nn.Conv2D(
  437. 4, 4, kernel_size=3, padding=1, bias_attr=False, groups=4)
  438. self.box_filter.weight.set_value(
  439. paddle.zeros_like(self.box_filter.weight) + 1 / 9)
  440. self.conv = nn.Sequential(
  441. nn.Conv2D(
  442. 4 * 2 + hid_channels,
  443. hid_channels,
  444. kernel_size=1,
  445. bias_attr=False),
  446. nn.BatchNorm2D(hid_channels),
  447. nn.ReLU(),
  448. nn.Conv2D(
  449. hid_channels, hid_channels, kernel_size=1, bias_attr=False),
  450. nn.BatchNorm2D(hid_channels),
  451. nn.ReLU(),
  452. nn.Conv2D(
  453. hid_channels, 4, kernel_size=1, bias_attr=True))
  454. def forward_single_frame(self, fine_src, base_src, base_fgr, base_pha,
  455. base_hid):
  456. fine_x = paddle.concat(
  457. [fine_src, fine_src.mean(
  458. 1, keepdim=True)], axis=1)
  459. base_x = paddle.concat(
  460. [base_src, base_src.mean(
  461. 1, keepdim=True)], axis=1)
  462. base_y = paddle.concat([base_fgr, base_pha], axis=1)
  463. mean_x = self.box_filter(base_x)
  464. mean_y = self.box_filter(base_y)
  465. cov_xy = self.box_filter(base_x * base_y) - mean_x * mean_y
  466. var_x = self.box_filter(base_x * base_x) - mean_x * mean_x
  467. A = self.conv(paddle.concat([cov_xy, var_x, base_hid], axis=1))
  468. b = mean_y - A * mean_x
  469. H, W = paddle.shape(fine_src)[2:]
  470. A = F.interpolate(A, (H, W), mode='bilinear', align_corners=False)
  471. b = F.interpolate(b, (H, W), mode='bilinear', align_corners=False)
  472. out = A * fine_x + b
  473. fgr, pha = out.split([3, 1], axis=1)
  474. return fgr, pha
  475. def forward_time_series(self, fine_src, base_src, base_fgr, base_pha,
  476. base_hid):
  477. B, T = fine_src.shape[:2]
  478. fgr, pha = self.forward_single_frame(
  479. fine_src.flatten(0, 1),
  480. base_src.flatten(0, 1),
  481. base_fgr.flatten(0, 1),
  482. base_pha.flatten(0, 1), base_hid.flatten(0, 1))
  483. *_, C, H, W = paddle.shape(fgr)
  484. fgr = fgr.reshape((B, T, C, H, W))
  485. pha = pha.reshape((B, T, 1, H, W))
  486. return fgr, pha
  487. def forward(self, fine_src, base_src, base_fgr, base_pha, base_hid):
  488. if fine_src.ndim == 5:
  489. return self.forward_time_series(fine_src, base_src, base_fgr,
  490. base_pha, base_hid)
  491. else:
  492. return self.forward_single_frame(fine_src, base_src, base_fgr,
  493. base_pha, base_hid)