# This program is about RVM implementation based on PaddlePaddle according to # https://github.com/PeterL1n/RobustVideoMatting. # Copyright (C) 2022 PaddlePaddle Authors. # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. # You should have received a copy of the GNU General Public License # along with this program. If not, see . import paddle import paddle.nn as nn import paddle.nn.functional as F from paddle import Tensor import paddleseg from paddleseg import utils from paddleseg.models import layers from paddleseg.cvlibs import manager from typing import Tuple, Optional from ppmatting.models import FastGuidedFilter @manager.MODELS.add_component class RVM(nn.Layer): """ The RVM implementation based on PaddlePaddle. The original article refers to Shanchuan Lin1, et, al. "Robust High-Resolution Video Matting with Temporal Guidance" (https://arxiv.org/pdf/2108.11515.pdf). Args: backbone: backbone model. lraspp_in_channels (int, optional): lraspp_out_channels (int, optional): decoder_channels (int, optional): refiner (str, optional): downsample_ratio (float, optional): pretrained(str, optional): The path of pretrianed model. Defautl: None. to_rgb(bool, optional): The fgr results change to rgb format. Default: True. """ def __init__(self, backbone, lraspp_in_channels=960, lraspp_out_channels=128, decoder_channels=(80, 40, 32, 16), refiner='deep_guided_filter', downsample_ratio=1., pretrained=None, to_rgb=True): super().__init__() self.backbone = backbone self.aspp = LRASPP(lraspp_in_channels, lraspp_out_channels) rd_fea_channels = self.backbone.feat_channels[:-1] + [ lraspp_out_channels ] self.decoder = RecurrentDecoder(rd_fea_channels, decoder_channels) self.project_mat = Projection(decoder_channels[-1], 4) self.project_seg = Projection(decoder_channels[-1], 1) if refiner == 'deep_guided_filter': self.refiner = DeepGuidedFilterRefiner() else: self.refiner = FastGuidedFilterRefiner() self.downsample_ratio = downsample_ratio self.pretrained = pretrained self.to_rgb = to_rgb self.r1 = None self.r2 = None self.r3 = None self.r4 = None def forward(self, data, r1=None, r2=None, r3=None, r4=None, downsample_ratio=None, segmentation_pass=False): src = data['img'] if downsample_ratio is None: downsample_ratio = self.downsample_ratio if r1 is not None and r2 is not None and r3 is not None and r4 is not None: self.r1, self.r2, self.r3, self.r4 = r1, r2, r3, r4 result = self.forward_( src, r1=self.r1, r2=self.r2, r3=self.r3, r4=self.r4, downsample_ratio=downsample_ratio, segmentation_pass=segmentation_pass) if self.training: raise RuntimeError('Sorry! RVM now do not support training') else: if segmentation_pass: seg, self.r1, self.r2, self.r3, self.r4 = result return {'alpha': seg} else: fgr, pha, self.r1, self.r2, self.r3, self.r4 = result if self.to_rgb: fgr = paddle.flip(fgr, axis=-3) return { 'alpha': pha, "fg": fgr, "r1": self.r1, "r2": self.r2, "r3": self.r3, "r4": self.r4 } def forward_(self, src, r1=None, r2=None, r3=None, r4=None, downsample_ratio=1., segmentation_pass=False): if isinstance(downsample_ratio, paddle.fluid.framework.Variable): # for export src_sm = self._interpolate(src, scale_factor=downsample_ratio) elif downsample_ratio != 1: src_sm = self._interpolate(src, scale_factor=downsample_ratio) else: src_sm = src f1, f2, f3, f4 = self.backbone_forward(src_sm) f4 = self.aspp(f4) hid, *rec = self.decoder(src_sm, f1, f2, f3, f4, r1, r2, r3, r4) if not segmentation_pass: fgr_residual, pha = self.project_mat(hid).split([3, 1], axis=-3) if downsample_ratio != 1: fgr_residual, pha = self.refiner(src, src_sm, fgr_residual, pha, hid) fgr = fgr_residual + src fgr = fgr.clip(0., 1.) pha = pha.clip(0., 1.) return [fgr, pha, *rec] else: seg = self.project_seg(hid) return [seg, *rec] def reset(self): """ When a video is predicted, the history memory shoulb be reset. """ self.r1 = None self.r2 = None self.r3 = None self.r4 = None def backbone_forward(self, x): if x.ndim == 5: B, T = paddle.shape(x)[:2] features = self.backbone(x.flatten(0, 1)) for i, f in enumerate(features): features[i] = f.reshape((B, T, *(paddle.shape(f)[1:]))) else: features = self.backbone(x) return features def _interpolate(self, x: Tensor, scale_factor: float): if x.ndim == 5: B, T = paddle.shape(x)[:2] x = F.interpolate( x.flatten(0, 1), scale_factor=scale_factor, mode='bilinear', align_corners=False) *_, C, H, W = paddle.shape(x)[-3:] x = x.reshape((B, T, C, H, W)) else: x = F.interpolate( x, scale_factor=scale_factor, mode='bilinear', align_corners=False) return x def init_weight(self): if self.pretrained is not None: utils.load_entire_model(self, self.pretrained) class LRASPP(nn.Layer): def __init__(self, in_channels, out_channels): super().__init__() self.aspp1 = nn.Sequential( nn.Conv2D( in_channels, out_channels, 1, bias_attr=False), nn.BatchNorm2D(out_channels), nn.ReLU()) self.aspp2 = nn.Sequential( nn.AdaptiveAvgPool2D(1), nn.Conv2D( in_channels, out_channels, 1, bias_attr=False), nn.Sigmoid()) def forward_single_frame(self, x): return self.aspp1(x) * self.aspp2(x) def forward_time_series(self, x): B, T = x.shape[:2] x = self.forward_single_frame(x.flatten(0, 1)) x = x.reshape((B, T, *(paddle.shape(x)[1:]))) return x def forward(self, x): if x.ndim == 5: return self.forward_time_series(x) else: return self.forward_single_frame(x) class RecurrentDecoder(nn.Layer): def __init__(self, feature_channels, decoder_channels): super().__init__() self.avgpool = AvgPool() self.decode4 = BottleneckBlock(feature_channels[3]) self.decode3 = UpsamplingBlock(feature_channels[3], feature_channels[2], 3, decoder_channels[0]) self.decode2 = UpsamplingBlock(decoder_channels[0], feature_channels[1], 3, decoder_channels[1]) self.decode1 = UpsamplingBlock(decoder_channels[1], feature_channels[0], 3, decoder_channels[2]) self.decode0 = OutputBlock(decoder_channels[2], 3, decoder_channels[3]) def forward(self, s0: Tensor, f1: Tensor, f2: Tensor, f3: Tensor, f4: Tensor, r1: Optional[Tensor], r2: Optional[Tensor], r3: Optional[Tensor], r4: Optional[Tensor]): s1, s2, s3 = self.avgpool(s0) x4, r4 = self.decode4(f4, r4) x3, r3 = self.decode3(x4, f3, s3, r3) x2, r2 = self.decode2(x3, f2, s2, r2) x1, r1 = self.decode1(x2, f1, s1, r1) x0 = self.decode0(x1, s0) return x0, r1, r2, r3, r4 class AvgPool(nn.Layer): def __init__(self): super().__init__() self.avgpool = nn.AvgPool2D(2, 2, ceil_mode=True) def forward_single_frame(self, s0): s1 = self.avgpool(s0) s2 = self.avgpool(s1) s3 = self.avgpool(s2) return s1, s2, s3 def forward_time_series(self, s0): B, T = paddle.shape(s0)[:2] s0 = s0.flatten(0, 1) s1, s2, s3 = self.forward_single_frame(s0) s1 = s1.reshape((B, T, *(paddle.shape(s1)[1:]))) s2 = s2.reshape((B, T, *(paddle.shape(s2)[1:]))) s3 = s3.reshape((B, T, *(paddle.shape(s3)[1:]))) return s1, s2, s3 def forward(self, s0): if s0.ndim == 5: return self.forward_time_series(s0) else: return self.forward_single_frame(s0) class BottleneckBlock(nn.Layer): def __init__(self, channels): super().__init__() self.channels = channels self.gru = ConvGRU(channels // 2) def forward(self, x, r=None): a, b = x.split(2, axis=-3) b, r = self.gru(b, r) x = paddle.concat([a, b], axis=-3) return x, r class UpsamplingBlock(nn.Layer): def __init__(self, in_channels, skip_channels, src_channels, out_channels): super().__init__() self.out_channels = out_channels self.upsample = nn.Upsample( scale_factor=2, mode='bilinear', align_corners=False) self.conv = nn.Sequential( nn.Conv2D( in_channels + skip_channels + src_channels, out_channels, 3, 1, 1, bias_attr=False), nn.BatchNorm2D(out_channels), nn.ReLU(), ) self.gru = ConvGRU(out_channels // 2) def forward_single_frame(self, x, f, s, r: Optional[Tensor]): x = self.upsample(x) x = x[:, :, :paddle.shape(s)[2], :paddle.shape(s)[3]] x = paddle.concat([x, f, s], axis=1) x = self.conv(x) a, b = x.split(2, axis=1) b, r = self.gru(b, r) x = paddle.concat([a, b], axis=1) return x, r def forward_time_series(self, x, f, s, r: Optional[Tensor]): B, T, _, H, W = s.shape x = x.flatten(0, 1) f = f.flatten(0, 1) s = s.flatten(0, 1) x = self.upsample(x) x = x[:, :, :H, :W] x = paddle.concat([x, f, s], axis=1) x = self.conv(x) _, c, h, w = paddle.shape(x) x = x.reshape((B, T, c, h, w)) a, b = x.split(2, axis=2) b, r = self.gru(b, r) x = paddle.concat([a, b], axis=2) return x, r def forward(self, x, f, s, r: Optional[Tensor]): if x.ndim == 5: return self.forward_time_series(x, f, s, r) else: return self.forward_single_frame(x, f, s, r) class OutputBlock(nn.Layer): def __init__(self, in_channels, src_channels, out_channels): super().__init__() self.upsample = nn.Upsample( scale_factor=2, mode='bilinear', align_corners=False) self.conv = nn.Sequential( nn.Conv2D( in_channels + src_channels, out_channels, 3, 1, 1, bias_attr=False), nn.BatchNorm2D(out_channels), nn.ReLU(), nn.Conv2D( out_channels, out_channels, 3, 1, 1, bias_attr=False), nn.BatchNorm2D(out_channels), nn.ReLU(), ) def forward_single_frame(self, x, s): _, _, H, W = paddle.shape(s) x = self.upsample(x) x = x[:, :, :H, :W] x = paddle.concat([x, s], axis=1) x = self.conv(x) return x def forward_time_series(self, x, s): B, T, C, H, W = paddle.shape(s) x = x.flatten(0, 1) s = s.flatten(0, 1) x = self.upsample(x) x = x[:, :, :H, :W] x = paddle.concat([x, s], axis=1) x = self.conv(x) x = paddle.reshape(x, (B, T, paddle.shape(x)[1], H, W)) return x def forward(self, x, s): if x.ndim == 5: return self.forward_time_series(x, s) else: return self.forward_single_frame(x, s) class ConvGRU(nn.Layer): def __init__(self, channels, kernel_size=3, padding=1): super().__init__() self.channels = channels self.ih = nn.Sequential( nn.Conv2D( channels * 2, channels * 2, kernel_size, padding=padding), nn.Sigmoid()) self.hh = nn.Sequential( nn.Conv2D( channels * 2, channels, kernel_size, padding=padding), nn.Tanh()) def forward_single_frame(self, x, h): r, z = self.ih(paddle.concat([x, h], axis=1)).split(2, axis=1) c = self.hh(paddle.concat([x, r * h], axis=1)) h = (1 - z) * h + z * c return h, h def forward_time_series(self, x, h): o = [] for xt in x.unbind(axis=1): ot, h = self.forward_single_frame(xt, h) o.append(ot) o = paddle.stack(o, axis=1) return o, h def forward(self, x, h=None): if h is None: h = paddle.zeros( (paddle.shape(x)[0], paddle.shape(x)[-3], paddle.shape(x)[-2], paddle.shape(x)[-1]), dtype=x.dtype) if x.ndim == 5: return self.forward_time_series(x, h) else: return self.forward_single_frame(x, h) class Projection(nn.Layer): def __init__(self, in_channels, out_channels): super().__init__() self.conv = nn.Conv2D(in_channels, out_channels, 1) def forward_single_frame(self, x): return self.conv(x) def forward_time_series(self, x): B, T = paddle.shape(x)[:2] x = self.conv(x.flatten(0, 1)) _, C, H, W = paddle.shape(x) x = x.reshape((B, T, C, H, W)) return x def forward(self, x): if x.ndim == 5: return self.forward_time_series(x) else: return self.forward_single_frame(x) class FastGuidedFilterRefiner(nn.Layer): def __init__(self, *args, **kwargs): super().__init__() self.guilded_filter = FastGuidedFilter(1) def forward_single_frame(self, fine_src, base_src, base_fgr, base_pha): fine_src_gray = fine_src.mean(1, keepdim=True) base_src_gray = base_src.mean(1, keepdim=True) fgr, pha = self.guilded_filter( paddle.concat( [base_src, base_src_gray], axis=1), paddle.concat( [base_fgr, base_pha], axis=1), paddle.concat( [fine_src, fine_src_gray], axis=1)).split( [3, 1], axis=1) return fgr, pha def forward_time_series(self, fine_src, base_src, base_fgr, base_pha): B, T = fine_src.shape[:2] fgr, pha = self.forward_single_frame( fine_src.flatten(0, 1), base_src.flatten(0, 1), base_fgr.flatten(0, 1), base_pha.flatten(0, 1)) *_, C, H, W = paddle.shape(fgr) fgr = fgr.reshape((B, T, C, H, W)) pha = pha.reshape((B, T, 1, H, W)) return fgr, pha def forward(self, fine_src, base_src, base_fgr, base_pha, *args, **kwargs): if fine_src.ndim == 5: return self.forward_time_series(fine_src, base_src, base_fgr, base_pha) else: return self.forward_single_frame(fine_src, base_src, base_fgr, base_pha) class DeepGuidedFilterRefiner(nn.Layer): def __init__(self, hid_channels=16): super().__init__() self.box_filter = nn.Conv2D( 4, 4, kernel_size=3, padding=1, bias_attr=False, groups=4) self.box_filter.weight.set_value( paddle.zeros_like(self.box_filter.weight) + 1 / 9) self.conv = nn.Sequential( nn.Conv2D( 4 * 2 + hid_channels, hid_channels, kernel_size=1, bias_attr=False), nn.BatchNorm2D(hid_channels), nn.ReLU(), nn.Conv2D( hid_channels, hid_channels, kernel_size=1, bias_attr=False), nn.BatchNorm2D(hid_channels), nn.ReLU(), nn.Conv2D( hid_channels, 4, kernel_size=1, bias_attr=True)) def forward_single_frame(self, fine_src, base_src, base_fgr, base_pha, base_hid): fine_x = paddle.concat( [fine_src, fine_src.mean( 1, keepdim=True)], axis=1) base_x = paddle.concat( [base_src, base_src.mean( 1, keepdim=True)], axis=1) base_y = paddle.concat([base_fgr, base_pha], axis=1) mean_x = self.box_filter(base_x) mean_y = self.box_filter(base_y) cov_xy = self.box_filter(base_x * base_y) - mean_x * mean_y var_x = self.box_filter(base_x * base_x) - mean_x * mean_x A = self.conv(paddle.concat([cov_xy, var_x, base_hid], axis=1)) b = mean_y - A * mean_x H, W = paddle.shape(fine_src)[2:] A = F.interpolate(A, (H, W), mode='bilinear', align_corners=False) b = F.interpolate(b, (H, W), mode='bilinear', align_corners=False) out = A * fine_x + b fgr, pha = out.split([3, 1], axis=1) return fgr, pha def forward_time_series(self, fine_src, base_src, base_fgr, base_pha, base_hid): B, T = fine_src.shape[:2] fgr, pha = self.forward_single_frame( fine_src.flatten(0, 1), base_src.flatten(0, 1), base_fgr.flatten(0, 1), base_pha.flatten(0, 1), base_hid.flatten(0, 1)) *_, C, H, W = paddle.shape(fgr) fgr = fgr.reshape((B, T, C, H, W)) pha = pha.reshape((B, T, 1, H, W)) return fgr, pha def forward(self, fine_src, base_src, base_fgr, base_pha, base_hid): if fine_src.ndim == 5: return self.forward_time_series(fine_src, base_src, base_fgr, base_pha, base_hid) else: return self.forward_single_frame(fine_src, base_src, base_fgr, base_pha, base_hid)