123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494 |
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
from collections import defaultdict

import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import scipy
import scipy.ndimage

import paddleseg
from paddleseg import utils
from paddleseg.cvlibs import manager, param_init
from paddleseg.models import layers, losses
@manager.MODELS.add_component
class MODNet(nn.Layer):
    """
    The MODNet implementation based on PaddlePaddle.

    The original article refers to
    Zhanghan Ke, et al. "Is a Green Screen Really Necessary for Real-Time
    Portrait Matting?" (https://arxiv.org/pdf/2011.11961.pdf).

    Args:
        backbone: Backbone model.
        hr_channels (int, optional): The channels of the high-resolution
            branch. Default: 32.
        pretrained (str, optional): The path of the pretrained model.
            Default: None.
    """

    def __init__(self, backbone, hr_channels=32, pretrained=None):
        super().__init__()
        self.backbone = backbone
        self.pretrained = pretrained
        self.head = MODNetHead(
            hr_channels=hr_channels, backbone_channels=backbone.feat_channels)
        self.init_weight()
        # Fixed (non-trainable) blur used to soften the downsampled alpha when
        # building the semantic target.
        self.blurer = GaussianBlurLayer(1, 3)
        self.loss_func_dict = None

    def forward(self, inputs):
        """
        If training, return a tuple (logit_dict, loss_dict).
        If evaluation, return the final alpha prediction.
        """
        x = inputs['img']
        feat_list = self.backbone(x)
        y = self.head(inputs=inputs, feat_list=feat_list)
        if self.training:
            loss = self.loss(y, inputs)
            return y, loss
        else:
            return y

    def loss(self, logit_dict, label_dict, loss_func_dict=None):
        """Compute the MODNet training losses.

        Args:
            logit_dict (dict): Predictions with keys 'semantic', 'detail'
                and 'matte'.
            label_dict (dict): Labels with keys 'img', 'alpha' and 'trimap'
                (trimap value 128 marks the transition region).
            loss_func_dict (dict, optional): Custom loss functions keyed by
                'semantic', 'detail' and 'fusion'. Default: None.

        Returns:
            dict: Every individual loss term plus their sum under key 'all'.
        """
        if loss_func_dict is None:
            if self.loss_func_dict is None:
                self.loss_func_dict = defaultdict(list)
                self.loss_func_dict['semantic'].append(
                    paddleseg.models.MSELoss())
                self.loss_func_dict['detail'].append(paddleseg.models.L1Loss())
                # Two fusion losses: index 0 for the L1 term, index 1 for the
                # composition term below.
                self.loss_func_dict['fusion'].append(paddleseg.models.L1Loss())
                self.loss_func_dict['fusion'].append(paddleseg.models.L1Loss())
        else:
            self.loss_func_dict = loss_func_dict

        loss = {}

        # Semantic loss: MSE against a blurred, 1/16-scale alpha.
        semantic_gt = F.interpolate(
            label_dict['alpha'],
            scale_factor=1 / 16,
            mode='bilinear',
            align_corners=False)
        semantic_gt = self.blurer(semantic_gt)
        loss['semantic'] = self.loss_func_dict['semantic'][0](
            logit_dict['semantic'], semantic_gt)

        # Detail loss: L1 restricted to the trimap transition region.
        trimap = label_dict['trimap']
        mask = (trimap == 128).astype('float32')
        logit_detail = logit_dict['detail'] * mask
        label_detail = label_dict['alpha'] * mask
        loss_detail = self.loss_func_dict['detail'][0](logit_detail,
                                                       label_detail)
        # Normalize by the transition-area fraction; epsilon guards an
        # empty transition mask.
        loss_detail = loss_detail / (mask.mean() + 1e-6)
        loss['detail'] = 10 * loss_detail

        # Fusion loss.
        matte = logit_dict['matte']
        alpha = label_dict['alpha']
        transition_mask = label_dict['trimap'] == 128
        matte_boundary = paddle.where(transition_mask, matte, alpha)
        # L1 term, with the transition region weighted 4x extra.
        loss_fusion_l1 = self.loss_func_dict['fusion'][0](
            matte, alpha) + 4 * self.loss_func_dict['fusion'][0](matte_boundary,
                                                                 alpha)
        # Composition term on the RGB composites.
        loss_fusion_comp = self.loss_func_dict['fusion'][1](
            matte * label_dict['img'], alpha *
            label_dict['img']) + 4 * self.loss_func_dict['fusion'][1](
                matte_boundary * label_dict['img'], alpha * label_dict['img'])
        # Consistency term between the matte and the semantic branch.
        # NOTE(review): align_corners is only meaningful for (bi)linear
        # interpolation; confirm Paddle accepts it with mode='nearest'.
        transition_mask = F.interpolate(
            label_dict['trimap'],
            scale_factor=1 / 16,
            mode='nearest',
            align_corners=False)
        transition_mask = transition_mask == 128
        matte_con_sem = F.interpolate(
            matte, scale_factor=1 / 16, mode='bilinear', align_corners=False)
        matte_con_sem = self.blurer(matte_con_sem)
        # The semantic prediction serves as a fixed target here, so detach it.
        logit_semantic = logit_dict['semantic'].clone()
        logit_semantic.stop_gradient = True
        matte_con_sem = paddle.where(transition_mask, logit_semantic,
                                     matte_con_sem)
        mse_loss = paddleseg.models.MSELoss()
        loss_fusion_con_sem = mse_loss(matte_con_sem, logit_dict['semantic'])
        loss_fusion = loss_fusion_l1 + loss_fusion_comp + loss_fusion_con_sem
        loss['fusion'] = loss_fusion
        loss['fusion_l1'] = loss_fusion_l1
        loss['fusion_comp'] = loss_fusion_comp
        loss['fusion_con_sem'] = loss_fusion_con_sem

        loss['all'] = loss['semantic'] + loss['detail'] + loss['fusion']
        return loss

    def init_weight(self):
        """Load the pretrained weights if a checkpoint path was given."""
        if self.pretrained is not None:
            utils.load_entire_model(self, self.pretrained)
class MODNetHead(nn.Layer):
    """MODNet head combining the low-resolution, high-resolution and
    fusion branches into the semantic/detail/matte predictions."""

    def __init__(self, hr_channels, backbone_channels):
        super().__init__()
        self.lr_branch = LRBranch(backbone_channels)
        self.hr_branch = HRBranch(hr_channels, backbone_channels)
        self.f_branch = FusionBranch(hr_channels, backbone_channels)
        self.init_weight()

    def forward(self, inputs, feat_list):
        img = inputs['img']
        pred_semantic, lr8x, (enc2x, enc4x) = self.lr_branch(feat_list)
        pred_detail, hr2x = self.hr_branch(img, enc2x, enc4x, lr8x)
        pred_matte = self.f_branch(img, lr8x, hr2x)
        # Inference only needs the final matte; training needs all heads.
        if not self.training:
            return pred_matte
        return {
            'semantic': pred_semantic,
            'detail': pred_detail,
            'matte': pred_matte,
        }

    def init_weight(self):
        # Kaiming-uniform init for every convolution in the head.
        for sublayer in self.sublayers():
            if isinstance(sublayer, nn.Conv2D):
                param_init.kaiming_uniform(sublayer.weight)
class FusionBranch(nn.Layer):
    """Fusion branch: merges low- and high-resolution features with the
    input image to produce the final alpha matte."""

    def __init__(self, hr_channels, enc_channels):
        super().__init__()
        self.conv_lr4x = Conv2dIBNormRelu(
            enc_channels[2], hr_channels, 5, stride=1, padding=2)
        self.conv_f2x = Conv2dIBNormRelu(
            2 * hr_channels, hr_channels, 3, stride=1, padding=1)
        self.conv_f = nn.Sequential(
            Conv2dIBNormRelu(
                hr_channels + 3, hr_channels // 2, 3, stride=1, padding=1),
            Conv2dIBNormRelu(
                hr_channels // 2,
                1,
                1,
                stride=1,
                padding=0,
                with_ibn=False,
                with_relu=False))

    def forward(self, img, lr8x, hr2x):
        # Upsample low-resolution features in two x2 steps.
        lr4x = self.conv_lr4x(
            F.interpolate(
                lr8x, scale_factor=2, mode='bilinear', align_corners=False))
        lr2x = F.interpolate(
            lr4x, scale_factor=2, mode='bilinear', align_corners=False)
        # Fuse with the high-resolution branch, then with the raw image.
        f2x = self.conv_f2x(paddle.concat((lr2x, hr2x), axis=1))
        fused = F.interpolate(
            f2x, scale_factor=2, mode='bilinear', align_corners=False)
        fused = self.conv_f(paddle.concat((fused, img), axis=1))
        return F.sigmoid(fused)
class HRBranch(nn.Layer):
    """
    High Resolution Branch of MODNet.

    Refines the 1/2- and 1/4-scale encoder features together with an image
    pyramid into a boundary-detail prediction.
    """

    def __init__(self, hr_channels, enc_channels):
        super().__init__()
        self.tohr_enc2x = Conv2dIBNormRelu(
            enc_channels[0], hr_channels, 1, stride=1, padding=0)
        self.conv_enc2x = Conv2dIBNormRelu(
            hr_channels + 3, hr_channels, 3, stride=2, padding=1)
        self.tohr_enc4x = Conv2dIBNormRelu(
            enc_channels[1], hr_channels, 1, stride=1, padding=0)
        self.conv_enc4x = Conv2dIBNormRelu(
            2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1)

        self.conv_hr4x = nn.Sequential(
            Conv2dIBNormRelu(
                2 * hr_channels + enc_channels[2] + 3,
                2 * hr_channels,
                3,
                stride=1,
                padding=1),
            Conv2dIBNormRelu(
                2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1),
            Conv2dIBNormRelu(
                2 * hr_channels, hr_channels, 3, stride=1, padding=1))

        self.conv_hr2x = nn.Sequential(
            Conv2dIBNormRelu(
                2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1),
            Conv2dIBNormRelu(
                2 * hr_channels, hr_channels, 3, stride=1, padding=1),
            Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1),
            Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1))

        self.conv_hr = nn.Sequential(
            Conv2dIBNormRelu(
                hr_channels + 3, hr_channels, 3, stride=1, padding=1),
            Conv2dIBNormRelu(
                hr_channels,
                1,
                1,
                stride=1,
                padding=0,
                with_ibn=False,
                with_relu=False))

    def forward(self, img, enc2x, enc4x, lr8x):
        # Image pyramid at 1/2 and 1/4 resolution.
        img2x = F.interpolate(
            img, scale_factor=1 / 2, mode='bilinear', align_corners=False)
        img4x = F.interpolate(
            img, scale_factor=1 / 4, mode='bilinear', align_corners=False)

        enc2x = self.tohr_enc2x(enc2x)
        feat4x = self.conv_enc2x(paddle.concat((img2x, enc2x), axis=1))
        enc4x = self.tohr_enc4x(enc4x)
        feat4x = self.conv_enc4x(paddle.concat((feat4x, enc4x), axis=1))

        # Mix in the upsampled low-resolution features and the 1/4 image.
        lr4x = F.interpolate(
            lr8x, scale_factor=2, mode='bilinear', align_corners=False)
        feat4x = self.conv_hr4x(paddle.concat((feat4x, lr4x, img4x), axis=1))

        hr2x = F.interpolate(
            feat4x, scale_factor=2, mode='bilinear', align_corners=False)
        hr2x = self.conv_hr2x(paddle.concat((hr2x, enc2x), axis=1))

        # The detail head runs only during training.
        pred_detail = None
        if self.training:
            hr = F.interpolate(
                hr2x, scale_factor=2, mode='bilinear', align_corners=False)
            pred_detail = F.sigmoid(
                self.conv_hr(paddle.concat((hr, img), axis=1)))
        return pred_detail, hr2x
class LRBranch(nn.Layer):
    """Low-resolution branch: SE-refined backbone features upsampled to the
    1/8 scale, plus a coarse semantic prediction during training."""

    def __init__(self, backbone_channels):
        super().__init__()
        self.se_block = SEBlock(backbone_channels[4], reduction=4)
        self.conv_lr16x = Conv2dIBNormRelu(
            backbone_channels[4], backbone_channels[3], 5, stride=1, padding=2)
        self.conv_lr8x = Conv2dIBNormRelu(
            backbone_channels[3], backbone_channels[2], 5, stride=1, padding=2)
        self.conv_lr = Conv2dIBNormRelu(
            backbone_channels[2],
            1,
            3,
            stride=2,
            padding=1,
            with_ibn=False,
            with_relu=False)

    def forward(self, feat_list):
        enc2x, enc4x = feat_list[0], feat_list[1]
        # Channel attention on the deepest (1/32 scale) features.
        enc32x = self.se_block(feat_list[4])

        lr16x = F.interpolate(
            enc32x, scale_factor=2, mode='bilinear', align_corners=False)
        lr16x = self.conv_lr16x(lr16x)
        lr8x = F.interpolate(
            lr16x, scale_factor=2, mode='bilinear', align_corners=False)
        lr8x = self.conv_lr8x(lr8x)

        # The coarse semantic head is evaluated only while training.
        pred_semantic = None
        if self.training:
            pred_semantic = F.sigmoid(self.conv_lr(lr8x))
        return pred_semantic, lr8x, [enc2x, enc4x]
class IBNorm(nn.Layer):
    """
    Combine Instance Norm and Batch Norm into One Layer.

    BatchNorm is applied to the first half of the channels and InstanceNorm
    to the remaining half; the results are concatenated back together.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.bnorm_channels = in_channels // 2
        self.inorm_channels = in_channels - self.bnorm_channels
        self.bnorm = nn.BatchNorm2D(self.bnorm_channels)
        self.inorm = nn.InstanceNorm2D(self.inorm_channels)

    def forward(self, x):
        split = self.bnorm_channels
        normed = (self.bnorm(x[:, :split, :, :]),
                  self.inorm(x[:, split:, :, :]))
        return paddle.concat(normed, 1)
class Conv2dIBNormRelu(nn.Layer):
    """
    Convolution + IBNorm + Relu.

    The normalization and activation stages are optional via ``with_ibn``
    and ``with_relu``.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias_attr=None,
                 with_ibn=True,
                 with_relu=True):
        super().__init__()
        conv = nn.Conv2D(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias_attr=bias_attr)
        modules = [conv]
        if with_ibn:
            modules.append(IBNorm(out_channels))
        if with_relu:
            modules.append(nn.ReLU())
        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        return self.layers(x)
class SEBlock(nn.Layer):
    """
    SE Block Proposed in https://arxiv.org/pdf/1709.01507.pdf
    """

    def __init__(self, num_channels, reduction=1):
        super().__init__()
        squeezed = int(num_channels // reduction)
        self.pool = nn.AdaptiveAvgPool2D(1)
        self.conv = nn.Sequential(
            nn.Conv2D(num_channels, squeezed, 1, bias_attr=False),
            nn.ReLU(),
            nn.Conv2D(squeezed, num_channels, 1, bias_attr=False),
            nn.Sigmoid())

    def forward(self, x):
        # Squeeze to per-channel statistics, excite, then rescale the input.
        attention = self.conv(self.pool(x))
        return attention * x
class GaussianBlurLayer(nn.Layer):
    """Add Gaussian Blur to a 4D tensor.

    This layer takes a 4D tensor of (N, C, H, W) as input. The Gaussian
    blur is performed per channel (depthwise convolution) with a fixed,
    non-trainable kernel.

    Args:
        channels (int): Channel count of the input tensor.
        kernel_size (int): Size of the kernel used in blurring. Must be odd.

    Raises:
        ValueError: If ``kernel_size`` is even.
    """

    def __init__(self, channels, kernel_size):
        super().__init__()
        self.channels = channels
        self.kernel_size = kernel_size
        # Validate with an exception rather than `assert`, which is
        # stripped when Python runs with -O.
        if self.kernel_size % 2 == 0:
            raise ValueError(
                "kernel_size of 'GaussianBlurLayer' must be odd, got {}".
                format(kernel_size))
        self.op = nn.Sequential(
            nn.Pad2D(
                self.kernel_size // 2, mode='reflect'),
            nn.Conv2D(
                channels,
                channels,
                self.kernel_size,
                stride=1,
                padding=0,
                bias_attr=False,
                groups=channels))
        self._init_kernel()
        # The Gaussian kernel is fixed; it must never be updated by training.
        self.op[1].weight.stop_gradient = True

    def forward(self, x):
        """Blur the input.

        Args:
            x (paddle.Tensor): Input 4D tensor of shape (N, C, H, W).

        Returns:
            paddle.Tensor: Blurred version of the input.

        Raises:
            ValueError: If the input is not 4D or its channel count does
                not match ``self.channels``.
        """
        # Raise instead of print-and-exit(): a library layer must not
        # terminate the whole process on bad input.
        if len(x.shape) != 4:
            raise ValueError(
                "'GaussianBlurLayer' requires a 4D tensor as input, got "
                "{} dims".format(len(x.shape)))
        if x.shape[1] != self.channels:
            raise ValueError(
                "In 'GaussianBlurLayer', the required channel ({0}) is "
                "not the same as input ({1})".format(self.channels,
                                                     x.shape[1]))
        return self.op(x)

    def _init_kernel(self):
        # Same sigma-from-size heuristic as OpenCV's getGaussianKernel.
        sigma = 0.3 * ((self.kernel_size - 1) * 0.5 - 1) + 0.8
        # Blur a unit impulse to obtain a normalized Gaussian kernel.
        impulse = np.zeros((self.kernel_size, self.kernel_size))
        center = self.kernel_size // 2
        impulse[center, center] = 1
        kernel = scipy.ndimage.gaussian_filter(impulse, sigma)
        kernel = kernel.astype('float32')[np.newaxis, np.newaxis, :, :]
        paddle.assign(kernel, self.op[1].weight)
|