123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338 |
- # copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from collections import defaultdict
- import time
- import paddle
- import paddle.nn as nn
- import paddle.nn.functional as F
- import paddleseg
- from paddleseg.models import layers
- from paddleseg import utils
- from paddleseg.cvlibs import manager
- from ppmatting.models.losses import MRSD, GradientLoss
- from ppmatting.models.backbone import resnet_vd
@manager.MODELS.add_component
class PPMatting(nn.Layer):
    """
    The PP-Matting implementation based on PaddlePaddle.

    The original article refers to
    Guowei Chen, et al. "PP-Matting: High-Accuracy Natural Image Matting"
    (https://arxiv.org/pdf/2204.09433.pdf).

    The last backbone feature feeds the semantic branch (SCB); the first two
    backbone features, resized to the input resolution, feed the
    high-resolution detail branch (HRDB). Both predictions are then fused
    into a single alpha map.

    Args:
        backbone: Backbone model. It must expose ``feat_channels`` and return
            an indexable collection of feature maps when called.
        pretrained(str, optional): The path of the pretrained model.
            Default: None.
    """

    def __init__(self, backbone, pretrained=None):
        super().__init__()

        self.backbone = backbone
        self.pretrained = pretrained
        self.loss_func_dict = self.get_loss_func_dict()

        feat_channels = backbone.feat_channels
        self.backbone_channels = feat_channels

        # SCB works on the deepest feature; HRDB works on the two shallowest
        # features concatenated along the channel axis.
        self.scb = SCB(feat_channels[-1])
        self.hrdb = HRDB(
            feat_channels[0] + feat_channels[1],
            scb_channels=self.scb.out_channels,
            gf_index=[0, 2, 4])

        self.init_weight()

    def forward(self, inputs):
        """
        Run the network on ``inputs['img']``.

        Args:
            inputs(dict): Must contain 'img'. In training mode it is also
                handed to ``self.loss`` and therefore needs the label
                entries read there ('trimap', 'alpha', 'fg', 'bg').

        Returns:
            In training mode, a tuple ``(logit_dict, loss_dict)``; otherwise
            only the fused alpha prediction.
        """
        image = inputs['img']
        input_shape = paddle.shape(image)

        def resize_to_input(feature):
            # Bring a backbone feature map to the spatial size of the image.
            return F.interpolate(
                feature,
                input_shape[2:],
                mode='bilinear',
                align_corners=False)

        fea_list = self.backbone(image)

        # Semantic branch: only the last SCB output is turned into the
        # class probability map.
        scb_logits = self.scb(fea_list[-1])
        semantic_map = F.softmax(scb_logits[-1], axis=1)

        # Detail branch: guided by the SCB outputs.
        fea0 = resize_to_input(fea_list[0])
        fea1 = resize_to_input(fea_list[1])
        hrdb_logit = self.hrdb(paddle.concat([fea0, fea1], 1), scb_logits)
        detail_map = F.sigmoid(hrdb_logit)

        fusion = self.fusion(semantic_map, detail_map)
        if not self.training:
            return fusion

        logit_dict = {
            'semantic': semantic_map,
            'detail': detail_map,
            'fusion': fusion
        }
        loss_dict = self.loss(logit_dict, inputs)
        return logit_dict, loss_dict

    def get_loss_func_dict(self):
        """
        Build the loss functions, grouped by the prediction they apply to.

        Returns:
            defaultdict(list): 'semantic' -> [NLLLoss],
            'detail' -> [MRSD, GradientLoss],
            'fusion' -> [MRSD, MRSD, GradientLoss].
        """
        loss_funcs = defaultdict(list)
        loss_funcs['semantic'].extend([nn.NLLLoss()])
        loss_funcs['detail'].extend([MRSD(), GradientLoss()])
        loss_funcs['fusion'].extend([MRSD(), MRSD(), GradientLoss()])
        return loss_funcs

    def loss(self, logit_dict, label_dict):
        """
        Compute all training losses.

        Args:
            logit_dict(dict): Predictions under 'semantic', 'detail' and
                'fusion'.
            label_dict(dict): Labels under 'trimap', 'alpha', 'fg' and 'bg'.

        Returns:
            dict: The individual loss terms plus 'all', the weighted total
            (0.25 for each of the semantic, detail and fusion losses).
        """
        trimap = label_dict['trimap']
        is_transition = trimap == 128

        # ---- semantic loss ----
        # Class ids: trimap == 128 -> 1, trimap == 0 -> 2, anything else -> 0.
        transition_id = is_transition.astype('int64')
        background_id = (trimap == 0).astype('int64')
        semantic_target = transition_id + background_id * 2
        # NLLLoss expects log-probabilities; the epsilon avoids log(0).
        semantic_loss = self.loss_func_dict['semantic'][0](
            paddle.log(logit_dict['semantic'] + 1e-6),
            semantic_target.squeeze(1))

        # ---- detail loss ----
        # The transition-region mask is passed as the third argument
        # (presumably restricting the loss to that region; see MRSD and
        # GradientLoss for the exact semantics).
        detail_funcs = self.loss_func_dict['detail']
        detail_pred = logit_dict['detail']
        alpha_gt = label_dict['alpha']
        detail_alpha = detail_funcs[0](detail_pred, alpha_gt, is_transition)
        detail_gradient = detail_funcs[1](detail_pred, alpha_gt,
                                          is_transition)
        detail_loss = detail_alpha + detail_gradient

        # ---- fusion loss ----
        fusion_funcs = self.loss_func_dict['fusion']
        fusion_pred = logit_dict['fusion']
        # alpha loss
        fusion_alpha = fusion_funcs[0](fusion_pred, alpha_gt)
        # composition loss: compare images composited with the predicted
        # alpha and with the ground-truth alpha
        foreground = label_dict['fg']
        background = label_dict['bg']
        comp_pred = fusion_pred * foreground + (1 - fusion_pred) * background
        comp_gt = alpha_gt * foreground + (1 - alpha_gt) * background
        fusion_composition = fusion_funcs[1](comp_pred, comp_gt)
        # gradient loss
        fusion_gradient = fusion_funcs[2](fusion_pred, alpha_gt)
        fusion_loss = fusion_alpha + fusion_composition + fusion_gradient

        total = 0.25 * semantic_loss + 0.25 * detail_loss + 0.25 * fusion_loss

        return {
            'semantic': semantic_loss,
            'detail': detail_loss,
            'detail_alpha': detail_alpha,
            'detail_gradient': detail_gradient,
            'fusion': fusion_loss,
            'fusion_alpha': fusion_alpha,
            'fusion_composition': fusion_composition,
            'fusion_gradient': fusion_gradient,
            'all': total,
        }

    def fusion(self, semantic_map, detail_map):
        """
        Merge the semantic and detail predictions into one alpha map.

        The class of each pixel is the argmax of ``semantic_map`` over
        axis 1: 0 is foreground, 1 is transition, 2 is background.
        Foreground pixels get 1, background pixels get 0, and transition
        pixels take the value of ``detail_map``.
        """
        class_index = paddle.argmax(semantic_map, axis=1, keepdim=True)
        in_transition = (class_index == 1).astype('float32')
        in_foreground = (class_index == 0).astype('float32')
        return detail_map * in_transition + in_foreground

    def init_weight(self):
        """Load the weights from ``self.pretrained`` when a path is given."""
        if self.pretrained is None:
            return
        utils.load_entire_model(self, self.pretrained)
class SCB(nn.Layer):
    """
    The semantic branch (SCB) used by PPMatting.

    A pyramid pooling module is applied to the input feature. Six
    convolutional stages follow; the first five each end with a 2x bilinear
    upsample, and the last one outputs 3 channels without upsampling.
    The first stage also receives the pyramid pooling output, and the next
    four stages receive upsampled projections of it.

    Args:
        in_channels(int): The number of channels of the input feature.
    """

    def __init__(self, in_channels):
        super().__init__()
        # Per-stage channel configuration (input / middle / output).
        self.in_channels = [512 + in_channels, 512, 256, 128, 128, 64]
        self.mid_channels = [512, 256, 128, 128, 64, 64]
        self.out_channels = [256, 128, 64, 64, 64, 3]

        self.psp_module = layers.PPModule(
            in_channels,
            512,
            bin_sizes=(1, 3, 5),
            dim_reduction=False,
            align_corners=False)

        # One pyramid-pooling projection for each of stages 1..4. Its channel
        # count equals the output of the preceding stage, and its upsample
        # factor is 2, 4, 8, 16 in turn.
        psp_upsamples = [2, 4, 8, 16]
        self.psps = nn.LayerList([
            self.conv_up_psp(512, self.out_channels[idx], factor)
            for idx, factor in enumerate(psp_upsamples)
        ])

        stages = []
        for idx in range(len(self.in_channels) - 1):
            # Only the first stage uses dilated convolutions (rate 2).
            rate = 2 if idx == 0 else 1
            stages.append(
                self._make_stage(
                    self.in_channels[idx],
                    self.mid_channels[idx],
                    self.out_channels[idx],
                    padding=rate,
                    dilation=rate))

        # Final stage: plain convolution head, no upsampling.
        last_in = self.in_channels[-1]
        last_mid = self.mid_channels[-1]
        last_out = self.out_channels[-1]
        stages.append(
            nn.Sequential(
                layers.ConvBNReLU(
                    last_in, last_mid, 3, padding=1),
                layers.ConvBNReLU(
                    last_mid, last_mid, 3, padding=1),
                nn.Conv2D(
                    last_mid, last_out, 3, padding=1)))

        self.scb_stages = nn.LayerList(stages)

    def forward(self, x):
        """
        Args:
            x: The input feature map.

        Returns:
            list: The output of every stage, in order.
        """
        psp_x = self.psp_module(x)
        # Skip feature for stage i: the raw pyramid pooling output for the
        # first stage, then its upsampled projections. Later stages have none.
        skip_feats = [psp_x] + [psp(psp_x) for psp in self.psps]

        scb_logits = []
        for idx, stage in enumerate(self.scb_stages):
            if idx < len(skip_feats):
                x = stage(paddle.concat((skip_feats[idx], x), 1))
            else:
                x = stage(x)
            scb_logits.append(x)
        return scb_logits

    def conv_up_psp(self, in_channels, out_channels, up_sample):
        """Return a 3x3 ConvBNReLU followed by a bilinear upsample."""
        projection = layers.ConvBNReLU(in_channels, out_channels, 3, padding=1)
        upsample = nn.Upsample(
            scale_factor=up_sample, mode='bilinear', align_corners=False)
        return nn.Sequential(projection, upsample)

    def _make_stage(self,
                    in_channels,
                    mid_channels,
                    out_channels,
                    padding=1,
                    dilation=1):
        """
        Return three 3x3 ConvBNReLU layers followed by a 2x bilinear upsample.

        ``padding`` and ``dilation`` apply to the second and third
        convolutions only; the first always uses padding 1.
        """
        conv_in = layers.ConvBNReLU(in_channels, mid_channels, 3, padding=1)
        conv_mid = layers.ConvBNReLU(
            mid_channels,
            mid_channels,
            3,
            padding=padding,
            dilation=dilation)
        conv_out = layers.ConvBNReLU(
            mid_channels,
            out_channels,
            3,
            padding=padding,
            dilation=dilation)
        upsample = nn.Upsample(
            scale_factor=2, mode='bilinear', align_corners=False)
        return nn.Sequential(conv_in, conv_mid, conv_out, upsample)
class HRDB(nn.Layer):
    """
    The High-Resolution Detail Branch

    The branch has three stages. Each stage applies a residual block, a 1x1
    convolution that reduces the channels (64 -> 32 -> 16 -> 8), and a gated
    convolution guided by one of the scb logits. A final 1x1 convolution
    produces a single-channel output.

    Args:
        in_channels(int): The number of input channels.
        scb_channels(list|tuple): The channels of scb logits
        gf_index(list|tuple, optional): Which logit is selected as guidance
            flow from scb logits. It needs one entry per stage (three);
            entries beyond the third are not used in forward.
            Default: (0, 2, 4)

    Raises:
        ValueError: If ``gf_index`` has fewer entries than there are stages.
    """

    def __init__(self, in_channels, scb_channels, gf_index=(0, 2, 4)):
        super().__init__()
        channels = [64, 32, 16, 8]
        num_stages = len(channels) - 1
        # Fail fast: every stage needs its own guidance logit. Without this
        # check the problem only shows up as an IndexError in forward.
        if len(gf_index) < num_stages:
            raise ValueError(
                'gf_index must provide at least {} indices, but got {}.'.
                format(num_stages, len(gf_index)))

        self.gf_index = gf_index
        # One 1x1 convolution per selected logit, reducing it to one channel.
        self.gf_list = nn.LayerList(
            [nn.Conv2D(scb_channels[i], 1, 1) for i in gf_index])

        # The first block maps the input channels to channels[0]; the others
        # keep their channel count.
        res_blocks = [
            resnet_vd.BasicBlock(
                in_channels, channels[0], stride=1, shortcut=False)
        ]
        res_blocks += [
            resnet_vd.BasicBlock(
                c, c, stride=1) for c in channels[1:-1]
        ]
        self.res_list = nn.LayerList(res_blocks)

        self.convs = nn.LayerList([
            nn.Conv2D(
                c_in, c_out, kernel_size=1)
            for c_in, c_out in zip(channels[:-1], channels[1:])
        ])
        self.gates = nn.LayerList(
            [GatedSpatailConv2d(c, c) for c in channels[1:]])

        self.detail_conv = nn.Conv2D(channels[-1], 1, 1, bias_attr=False)

    def forward(self, x, scb_logits):
        """
        Args:
            x: The input feature map of the detail branch.
            scb_logits(list|tuple): The scb outputs, indexed by ``gf_index``.

        Returns:
            The single-channel detail logit.
        """
        # zip stops after the last stage, so extra gf_index entries (and
        # their convolutions) are ignored, as before.
        stages = zip(self.res_list, self.convs, self.gates, self.gf_list,
                     self.gf_index)
        for res_block, conv, gate, gf_conv, logit_idx in stages:
            x = res_block(x)
            x = conv(x)
            gf = gf_conv(scb_logits[logit_idx])
            # Resize the guidance map to the spatial size of the feature.
            gf = F.interpolate(
                gf, paddle.shape(x)[-2:], mode='bilinear', align_corners=False)
            x = gate(x, gf)
        return self.detail_conv(x)
class GatedSpatailConv2d(nn.Layer):
    """
    A convolution whose input is modulated by a learned spatial gate.

    The gate is a single-channel map in (0, 1), computed from the input
    features concatenated with the gating features. The input is scaled by
    ``gate + 1`` before the convolution is applied.

    Args:
        in_channels(int): The number of channels of the input features.
        out_channels(int): The number of output channels.
        kernel_size(int, optional): Kernel size of the output convolution.
            Default: 1.
        stride(int, optional): Stride of the output convolution. Default: 1.
        padding(int, optional): Padding of the output convolution.
            Default: 0.
        dilation(int, optional): Dilation of the output convolution.
            Default: 1.
        groups(int, optional): Groups of the output convolution. Default: 1.
        bias_attr(optional): ``bias_attr`` of the output convolution.
            Default: False.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=1,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias_attr=False):
        super().__init__()
        # The gate sees the input features plus one extra channel, so the
        # gating features are expected to have a single channel.
        gate_channels = in_channels + 1
        self._gate_conv = nn.Sequential(
            layers.SyncBatchNorm(gate_channels),
            nn.Conv2D(
                gate_channels, gate_channels, kernel_size=1),
            nn.ReLU(),
            nn.Conv2D(
                gate_channels, 1, kernel_size=1),
            layers.SyncBatchNorm(1),
            nn.Sigmoid())

        conv_kwargs = dict(
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias_attr=bias_attr)
        self.conv = nn.Conv2D(in_channels, out_channels, **conv_kwargs)

    def forward(self, input_features, gating_features):
        """
        Args:
            input_features: The features to be gated.
            gating_features: The guidance map, concatenated with the input
                along axis 1.

        Returns:
            The convolved, gated features.
        """
        gate_input = paddle.concat([input_features, gating_features], axis=1)
        alphas = self._gate_conv(gate_input)
        # Scaling by (alpha + 1) means the gate can amplify a location up to
        # 2x but never scale it below the original input.
        gated = input_features * (alphas + 1)
        return self.conv(gated)