# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
from collections import defaultdict

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import paddleseg
from paddleseg import utils
from paddleseg.models import layers
from paddleseg.cvlibs import manager
from paddleseg.models.backbones.transformer_utils import Identity, DropPath

from ppmatting.models.layers import MLFF
from ppmatting.models.losses import MRSD, GradientLoss


@manager.MODELS.add_component
class PPMattingV2(nn.Layer):
    """
    The PPMattingV2 implementation based on PaddlePaddle.

    The original article refers to
    TODO Guowei Chen, et al. "" ().

    Args:
        backbone: backbone model.
        pretrained(str, optional): The path of the pretrained model. Default: None.
        dpp_len_trans(int, optional): The depth of the transformer block in dpp (DoublePyramidPoolModule). Default: 1.
        dpp_index(list, optional): The indexes of the backbone outputs that are used as the input of dpp. Default: [1, 2, 3, 4].
        dpp_mid_channel(int, optional): The output channels of the first pyramid pool in dpp. Default: 256.
        dpp_output_channel(int, optional): The output channels of dpp. Default: 512.
        dpp_bin_sizes(list, optional): The output sizes of the second pyramid pool in dpp. Default: (2, 4, 6).
        dpp_mlp_ratios(int, optional): The expansion ratio of the mlp in dpp. Default: 2.
        dpp_attn_ratio(int, optional): The expansion ratio of the attention. Default: 2.
        dpp_merge_type(str, optional): The merge type of the outputs of the second pyramid pool in dpp,
            which should be one of ('concat', 'add'). Default: 'concat'.
        mlff_merge_type(str, optional): The merge type of the multi-scale features before output.
            It should be one of ('add', 'concat'). Default: 'concat'.
    """

    def __init__(self,
                 backbone,
                 pretrained=None,
                 dpp_len_trans=1,
                 dpp_index=[1, 2, 3, 4],
                 dpp_mid_channel=256,
                 dpp_output_channel=512,
                 dpp_bin_sizes=(2, 4, 6),
                 dpp_mlp_ratios=2,
                 dpp_attn_ratio=2,
                 dpp_merge_type='concat',
                 mlff_merge_type='concat',
                 decoder_channels=[128, 96, 64, 32, 32],
                 head_channel=32):
        super().__init__()
        self.backbone = backbone
        self.backbone_channels = backbone.feat_channels

        # check
        assert len(backbone.feat_channels) == 5, \
            "Backbone should return 5 features with different scales"
        assert max(dpp_index) < len(backbone.feat_channels), \
            "The element of `dpp_index` should be less than the number of return features of backbone."

        # dpp module
        self.dpp_index = dpp_index
        self.dpp = DoublePyramidPoolModule(
            stride=2,
            input_channel=sum(self.backbone_channels[i]
                              for i in self.dpp_index),
            mid_channel=dpp_mid_channel,
            output_channel=dpp_output_channel,
            len_trans=dpp_len_trans,
            bin_sizes=dpp_bin_sizes,
            mlp_ratios=dpp_mlp_ratios,
            attn_ratio=dpp_attn_ratio,
            merge_type=dpp_merge_type)

        # decoder
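        # Each MLFF stage fuses a backbone feature with the previous decoder output;
        # the dpp output is additionally fed to the 16x and 8x stages, and a
        # downsampled copy of the input image to the 4x and 2x stages.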
        self.mlff32x = MLFF(
            in_channels=[self.backbone_channels[-1], dpp_output_channel],
            mid_channels=[dpp_output_channel, dpp_output_channel],
            out_channel=decoder_channels[0],
            merge_type=mlff_merge_type)
        self.mlff16x = MLFF(
            in_channels=[
                self.backbone_channels[-2], decoder_channels[0],
                dpp_output_channel
            ],
            mid_channels=[
                decoder_channels[0], decoder_channels[0], decoder_channels[0]
            ],
            out_channel=decoder_channels[1],
            merge_type=mlff_merge_type)
        self.mlff8x = MLFF(
            in_channels=[
                self.backbone_channels[-3], decoder_channels[1],
                dpp_output_channel
            ],
            mid_channels=[
                decoder_channels[1], decoder_channels[1], decoder_channels[1]
            ],
            out_channel=decoder_channels[2],
            merge_type=mlff_merge_type)
        self.mlff4x = MLFF(
            in_channels=[self.backbone_channels[-4], decoder_channels[2], 3],
            mid_channels=[decoder_channels[2], decoder_channels[2], 3],
            out_channel=decoder_channels[3])
        self.mlff2x = MLFF(
            in_channels=[self.backbone_channels[-5], decoder_channels[3], 3],
            mid_channels=[decoder_channels[3], decoder_channels[3], 3],
            out_channel=decoder_channels[4])

        self.matting_head_mlff8x = MattingHead(
            in_chan=decoder_channels[2], mid_chan=32)
        self.matting_head_mlff2x = MattingHead(
            in_chan=decoder_channels[4] + 3, mid_chan=head_channel, mid_num=2)

        # loss
        self.loss_func_dict = None

        # pretrained
        self.pretrained = pretrained
        self.init_weight()

    def forward(self, inputs):
        img = inputs['img']
        input_shape = paddle.shape(img)
        feats_backbone = self.backbone(
            img)  # stdc1 [2x, 4x, 8x, 16x, 32x] [32, 64, 256, 512, 1024]
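
        # Global context from double pyramid pooling over the selected backbone
        # features; dpp_out is reused as an extra input at the 16x and 8x stages.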
        x = self.dpp([feats_backbone[i] for i in self.dpp_index])
        dpp_out = x

        input_32x = [feats_backbone[-1], x]
        x = self.mlff32x(input_32x,
                         paddle.shape(feats_backbone[-1])[-2:])  # 32x

        input_16x = [feats_backbone[-2], x, dpp_out]
        x = self.mlff16x(input_16x,
                         paddle.shape(feats_backbone[-2])[-2:])  # 16x

        input_8x = [feats_backbone[-3], x, dpp_out]
        x = self.mlff8x(input_8x, paddle.shape(feats_backbone[-3])[-2:])  # 8x
        mlff8x_output = x
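
        # At 4x and 2x the raw image, downsampled with area interpolation to the
        # matching scale, is fused alongside the backbone and decoder features.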
        input_4x = [feats_backbone[-4], x]
        input_4x.append(
            F.interpolate(
                img, feats_backbone[-4].shape[2:], mode='area'))
        x = self.mlff4x(input_4x, paddle.shape(feats_backbone[-4])[-2:])  # 4x

        input_2x = [feats_backbone[-5], x]
        input_2x.append(
            F.interpolate(
                img, feats_backbone[-5].shape[2:], mode='area'))
        x = self.mlff2x(input_2x, paddle.shape(feats_backbone[-5])[-2:])  # 2x

        x = F.interpolate(
            x, input_shape[-2:], mode='bilinear', align_corners=False)
        x = paddle.concat([x, img], axis=1)
        alpha = self.matting_head_mlff2x(x)

        if self.training:
            logit_dict = {}
            logit_dict['alpha'] = alpha
            logit_dict['alpha_8x'] = self.matting_head_mlff8x(mlff8x_output)
            loss_dict = self.loss(logit_dict, inputs)
            return logit_dict, loss_dict
        else:
            return alpha

    def loss(self, logit_dict, label_dict, loss_func_dict=None):
        if loss_func_dict is None:
            if self.loss_func_dict is None:
                self.loss_func_dict = defaultdict(list)
                self.loss_func_dict['alpha'].append(MRSD())
                self.loss_func_dict['alpha'].append(GradientLoss())
                self.loss_func_dict['alpha_8x'].append(MRSD())
                self.loss_func_dict['alpha_8x'].append(GradientLoss())
        else:
            self.loss_func_dict = loss_func_dict

        loss = {}

        alpha_8x_label = F.interpolate(
            label_dict['alpha'],
            size=logit_dict['alpha_8x'].shape[-2:],
            mode='area',
            align_corners=False)
        loss['alpha_8x_mrsd'] = self.loss_func_dict['alpha_8x'][0](
            logit_dict['alpha_8x'], alpha_8x_label)
        loss['alpha_8x_grad'] = self.loss_func_dict['alpha_8x'][1](
            logit_dict['alpha_8x'], alpha_8x_label)
        loss['alpha_8x'] = loss['alpha_8x_mrsd'] + loss['alpha_8x_grad']
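
        # Pixels where the trimap equals 128 form the unknown (transition) region;
        # errors there are given an extra 2x weight below.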
        transition_mask = label_dict['trimap'] == 128
        loss['alpha_mrsd'] = self.loss_func_dict['alpha'][0](
            logit_dict['alpha'],
            label_dict['alpha']) + 2 * self.loss_func_dict['alpha'][0](
                logit_dict['alpha'], label_dict['alpha'], transition_mask)
        loss['alpha_grad'] = self.loss_func_dict['alpha'][1](
            logit_dict['alpha'],
            label_dict['alpha']) + 2 * self.loss_func_dict['alpha'][1](
                logit_dict['alpha'], label_dict['alpha'], transition_mask)
        loss['alpha'] = loss['alpha_mrsd'] + loss['alpha_grad']

        loss['all'] = loss['alpha'] + loss['alpha_8x']

        return loss

    def init_weight(self):
        if self.pretrained is not None:
            utils.load_entire_model(self, self.pretrained)


class MattingHead(nn.Layer):
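    """
    Predict the alpha matte from a feature map: a ConvBNReLU layer, (mid_num - 1)
    extra ConvBNReLU layers, a 1x1 convolution and a sigmoid, so the output lies in [0, 1].
    """
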
    def __init__(self, in_chan, mid_chan, mid_num=1, out_channels=1):
        super().__init__()
        self.conv = layers.ConvBNReLU(
            in_chan,
            mid_chan,
            kernel_size=3,
            stride=1,
            padding=1,
            bias_attr=False)
        self.mid_conv = nn.LayerList([
            layers.ConvBNReLU(
                mid_chan,
                mid_chan,
                kernel_size=3,
                stride=1,
                padding=1,
                bias_attr=False) for i in range(mid_num - 1)
        ])
        self.conv_out = nn.Conv2D(
            mid_chan, out_channels, kernel_size=1, bias_attr=False)

    def forward(self, x):
        x = self.conv(x)
        for mid_conv in self.mid_conv:
            x = mid_conv(x)
        x = self.conv_out(x)
        x = F.sigmoid(x)
        return x


class DoublePyramidPoolModule(nn.Layer):
    """
    Extract global information through a double pyramid pool structure, with attention
    computed by transformer blocks.

    Args:
        stride(int): The stride for the inputs.
        input_channel(int): The total channels of the input features.
        mid_channel(int, optional): The output channels of the first pyramid pool. Default: 256.
        output_channel(int, optional): The output channels. Default: 512.
        len_trans(int, optional): The depth of the transformer block. Default: 1.
        bin_sizes(list, optional): The output sizes of the second pyramid pool. Default: (2, 4, 6).
        mlp_ratios(int, optional): The expansion ratio of the mlp. Default: 2.
        attn_ratio(int, optional): The expansion ratio of the attention. Default: 2.
        merge_type(str, optional): The merge type of the outputs of the second pyramid pool,
            which should be one of ('concat', 'add'). Default: 'concat'.
        align_corners(bool, optional): Whether to use `align_corners` when interpolating. Default: False.
    """

    def __init__(self,
                 stride,
                 input_channel,
                 mid_channel=256,
                 output_channel=512,
                 len_trans=1,
                 bin_sizes=(2, 4, 6),
                 mlp_ratios=2,
                 attn_ratio=2,
                 merge_type='concat',
                 align_corners=False):
        super().__init__()
        self.mid_channel = mid_channel
        self.align_corners = align_corners
        self.mlp_ratios = mlp_ratios
        self.attn_ratio = attn_ratio

        if isinstance(len_trans, int):
            self.len_trans = [len_trans] * len(bin_sizes)
        elif isinstance(len_trans, (list, tuple)):
            self.len_trans = len_trans
            if len(len_trans) != len(bin_sizes):
                raise ValueError(
                    'If `len_trans` is a list or tuple, its length should be the same as `bin_sizes`.'
                )
        else:
            raise ValueError(
                '`len_trans` only supports int, list and tuple types.')

        if merge_type not in ['add', 'concat']:
            raise ValueError('`merge_type` only supports `add` or `concat`.')
        self.merge_type = merge_type

        self.pp1 = PyramidPoolAgg(stride=stride)
        self.conv_mid = layers.ConvBN(input_channel, mid_channel, 1)
        self.pp2 = nn.LayerList([
            self._make_stage(
                embedding_channels=mid_channel, size=size, block_num=block_num)
            for size, block_num in zip(bin_sizes, self.len_trans)
        ])

        if self.merge_type == 'concat':
            in_chan = mid_channel + mid_channel * len(bin_sizes)
        else:
            in_chan = mid_channel
        self.conv_out = layers.ConvBNReLU(
            in_chan, output_channel, kernel_size=1)

    def _make_stage(self, embedding_channels, size, block_num):
        prior = nn.AdaptiveAvgPool2D(output_size=size)
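        # A 1x1 pooled map holds a single token, so self-attention degenerates;
        # use a pointwise ConvBNReLU for that bin instead of a transformer stage.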
        if size == 1:
            trans = layers.ConvBNReLU(
                in_channels=embedding_channels,
                out_channels=embedding_channels,
                kernel_size=1)
        else:
            trans = BasicLayer(
                block_num=block_num,
                embedding_dim=embedding_channels,
                key_dim=16,
                num_heads=8,
                mlp_ratios=self.mlp_ratios,
                attn_ratio=self.attn_ratio,
                drop=0,
                attn_drop=0,
                drop_path=0,
                act_layer=nn.ReLU6,
                lr_mult=1.0)
        return nn.Sequential(prior, trans)

    def forward(self, inputs):
        x = self.pp1(inputs)
        pp2_input = self.conv_mid(x)
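
        # Second pyramid: pool pp2_input to each bin size, run the transformer
        # stage, then upsample back to the size of pp2_input.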
        cat_layers = []
        for stage in self.pp2:
            x = stage(pp2_input)
            x = F.interpolate(
                x,
                paddle.shape(pp2_input)[2:],
                mode='bilinear',
                align_corners=self.align_corners)
            cat_layers.append(x)
        cat_layers = [pp2_input] + cat_layers[::-1]
        if self.merge_type == 'concat':
            cat = paddle.concat(cat_layers, axis=1)
        else:
            cat = sum(cat_layers)
        out = self.conv_out(cat)

        return out


class Conv2DBN(nn.Layer):
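    """
    Conv2D followed by BatchNorm2D (no activation), with `lr_mult` applied to the
    learning rate of both the convolution and the batch-norm parameters.
    """
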
    def __init__(self,
                 in_channels,
                 out_channels,
                 ks=1,
                 stride=1,
                 pad=0,
                 dilation=1,
                 groups=1,
                 bn_weight_init=1,
                 lr_mult=1.0):
        super().__init__()
        conv_weight_attr = paddle.ParamAttr(learning_rate=lr_mult)
        self.c = nn.Conv2D(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=ks,
            stride=stride,
            padding=pad,
            dilation=dilation,
            groups=groups,
            weight_attr=conv_weight_attr,
            bias_attr=False)
        bn_weight_attr = paddle.ParamAttr(
            initializer=nn.initializer.Constant(bn_weight_init),
            learning_rate=lr_mult)
        bn_bias_attr = paddle.ParamAttr(
            initializer=nn.initializer.Constant(0), learning_rate=lr_mult)
        self.bn = nn.BatchNorm2D(
            out_channels, weight_attr=bn_weight_attr, bias_attr=bn_bias_attr)

    def forward(self, inputs):
        out = self.c(inputs)
        out = self.bn(out)
        return out


class MLP(nn.Layer):
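    """
    Convolutional MLP: a 1x1 expansion, a 3x3 depthwise convolution, and a 1x1
    projection, all applied on the 2D feature map, with dropout after the
    activation and after the projection.
    """
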
    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.ReLU,
                 drop=0.,
                 lr_mult=1.0):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = Conv2DBN(in_features, hidden_features, lr_mult=lr_mult)
        param_attr = paddle.ParamAttr(learning_rate=lr_mult)
        self.dwconv = nn.Conv2D(
            hidden_features,
            hidden_features,
            3,
            1,
            1,
            groups=hidden_features,
            weight_attr=param_attr,
            bias_attr=param_attr)
        self.act = act_layer()
        self.fc2 = Conv2DBN(hidden_features, out_features, lr_mult=lr_mult)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.dwconv(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Layer):
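    """
    Multi-head self-attention over a 2D feature map. Queries, keys and values are
    produced by 1x1 Conv2DBN projections; each head uses `key_dim` channels for
    q/k and `attn_ratio * key_dim` channels for v.
    """
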
    def __init__(self,
                 dim,
                 key_dim,
                 num_heads,
                 attn_ratio=4,
                 activation=None,
                 lr_mult=1.0):
        super().__init__()
        self.num_heads = num_heads
        self.scale = key_dim**-0.5
        self.key_dim = key_dim
        self.nh_kd = nh_kd = key_dim * num_heads
        self.d = int(attn_ratio * key_dim)
        self.dh = int(attn_ratio * key_dim) * num_heads
        self.attn_ratio = attn_ratio

        self.to_q = Conv2DBN(dim, nh_kd, 1, lr_mult=lr_mult)
        self.to_k = Conv2DBN(dim, nh_kd, 1, lr_mult=lr_mult)
        self.to_v = Conv2DBN(dim, self.dh, 1, lr_mult=lr_mult)

        self.proj = nn.Sequential(
            activation(),
            Conv2DBN(
                self.dh, dim, bn_weight_init=0, lr_mult=lr_mult))

    def forward(self, x):
        x_shape = paddle.shape(x)
        H, W = x_shape[2], x_shape[3]
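
        # The leading 0 in reshape keeps the batch dimension; the flattened H*W
        # positions become the token axis for attention.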
        qq = self.to_q(x).reshape(
            [0, self.num_heads, self.key_dim, -1]).transpose([0, 1, 3, 2])
        kk = self.to_k(x).reshape([0, self.num_heads, self.key_dim, -1])
        vv = self.to_v(x).reshape([0, self.num_heads, self.d, -1]).transpose(
            [0, 1, 3, 2])

        attn = paddle.matmul(qq, kk)
        attn = F.softmax(attn, axis=-1)

        xx = paddle.matmul(attn, vv)

        xx = xx.transpose([0, 1, 3, 2]).reshape([0, self.dh, H, W])
        xx = self.proj(xx)
        return xx


class Block(nn.Layer):
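    """
    Transformer block: multi-head attention followed by a convolutional MLP,
    each wrapped in a residual connection with optional DropPath.
    """
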
    def __init__(self,
                 dim,
                 key_dim,
                 num_heads,
                 mlp_ratios=4.,
                 attn_ratio=2.,
                 drop=0.,
                 drop_path=0.,
                 act_layer=nn.ReLU,
                 lr_mult=1.0):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.mlp_ratios = mlp_ratios

        self.attn = Attention(
            dim,
            key_dim=key_dim,
            num_heads=num_heads,
            attn_ratio=attn_ratio,
            activation=act_layer,
            lr_mult=lr_mult)

        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity()
        mlp_hidden_dim = int(dim * mlp_ratios)
        self.mlp = MLP(in_features=dim,
                       hidden_features=mlp_hidden_dim,
                       act_layer=act_layer,
                       drop=drop,
                       lr_mult=lr_mult)

    def forward(self, x):
        h = x
        x = self.attn(x)
        x = self.drop_path(x)
        x = h + x

        h = x
        x = self.mlp(x)
        x = self.drop_path(x)
        x = x + h
        return x


class BasicLayer(nn.Layer):
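    """
    A stack of `block_num` transformer Blocks applied sequentially.
    """
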
    def __init__(self,
                 block_num,
                 embedding_dim,
                 key_dim,
                 num_heads,
                 mlp_ratios=4.,
                 attn_ratio=2.,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 act_layer=None,
                 lr_mult=1.0):
        super().__init__()
        self.block_num = block_num

        self.transformer_blocks = nn.LayerList()
        for i in range(self.block_num):
            self.transformer_blocks.append(
                Block(
                    embedding_dim,
                    key_dim=key_dim,
                    num_heads=num_heads,
                    mlp_ratios=mlp_ratios,
                    attn_ratio=attn_ratio,
                    drop=drop,
                    drop_path=drop_path[i]
                    if isinstance(drop_path, list) else drop_path,
                    act_layer=act_layer,
                    lr_mult=lr_mult))

    def forward(self, x):
        # token * N
        for i in range(self.block_num):
            x = self.transformer_blocks[i](x)
        return x


class PyramidPoolAgg(nn.Layer):
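    """
    Downsample each input feature to a common spatial size (roughly the size of the
    deepest feature divided by `stride`) and concatenate them along the channel axis.
    """
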
    def __init__(self, stride):
        super().__init__()
        self.stride = stride
        self.tmp = Identity()  # avoid the error of paddle.flops

    def forward(self, inputs):
        '''
        # The F.adaptive_avg_pool2d does not support the (H, W) be Tensor,
        # so exporting the inference model will raise error.
        _, _, H, W = inputs[-1].shape
        H = (H - 1) // self.stride + 1
        W = (W - 1) // self.stride + 1
        return paddle.concat(
            [F.adaptive_avg_pool2d(inp, (H, W)) for inp in inputs], axis=1)
        '''
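        # Exportable approximation of the adaptive pooling above: fixed-kernel average
        # pooling whose kernel and stride are halved for each deeper (lower-resolution)
        # input, so all outputs share a common spatial size and can be concatenated.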
        out = []
        ks = 2**len(inputs)
        stride = self.stride**len(inputs)
        for x in inputs:
            x = F.avg_pool2d(x, int(ks), int(stride))
            ks /= 2
            stride /= 2
            out.append(x)
        out = paddle.concat(out, axis=1)
        return out