# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddleseg.models import layers

from ppmatting.models.layers import tensor_fusion_helper as helper


class MLFF(nn.Layer):
    """
    Fuse multi-level features adaptively via learned spatial attention.

    Args:
        in_channels (list): The channels of the input tensors.
        mid_channels (list): The intermediate channels used while fusing
            the features.
        out_channel (int): The output channel after fusing.
        merge_type (str): How to merge the attended features before the
            output convolution. It should be one of ('add', 'concat').
            Default: 'concat'.
    """

    def __init__(self,
                 in_channels,
                 mid_channels,
                 out_channel,
                 merge_type='concat'):
        super().__init__()
        self.merge_type = merge_type

        # Check arguments
        if len(in_channels) != len(mid_channels):
            raise ValueError(
                "`mid_channels` should have the same length as `in_channels`, "
                "but they are {} and {}.".format(mid_channels, in_channels))
        if self.merge_type == 'add' and len(
                np.unique(np.array(mid_channels))) != 1:
            raise ValueError(
                "If `merge_type='add'`, all elements of `mid_channels` should "
                "be equal, but it is {}.".format(mid_channels))

        # Per-level branch: a 1x1 pointwise projection followed by a 3x3
        # depthwise convolution applied after resizing.
        self.pwconvs = nn.LayerList()
        self.dwconvs = nn.LayerList()
        for in_channel, mid_channel in zip(in_channels, mid_channels):
            self.pwconvs.append(
                layers.ConvBN(
                    in_channel, mid_channel, 1, bias_attr=False))
            self.dwconvs.append(
                layers.ConvBNReLU(
                    mid_channel,
                    mid_channel,
                    3,
                    padding=1,
                    groups=mid_channel,
                    bias_attr=False))

        # Predict one spatial attention map per input level from the
        # channel-reduced (avg + max) statistics of all levels.
        num_feas = len(in_channels)
        self.conv_atten = nn.Sequential(
            layers.ConvBNReLU(
                2 * num_feas,
                num_feas,
                kernel_size=3,
                padding=1,
                bias_attr=False),
            layers.ConvBN(
                num_feas, num_feas, kernel_size=3, padding=1,
                bias_attr=False))

        if self.merge_type == 'add':
            in_chan = mid_channels[0]
        else:
            in_chan = sum(mid_channels)
        self.conv_out = layers.ConvBNReLU(
            in_chan, out_channel, kernel_size=3, padding=1, bias_attr=False)

    def forward(self, inputs, shape):
        """
        Args:
            inputs (list): List of tensors to be fused.
            shape (Tensor): A tensor with two elements like (H, W), the
                spatial size all features are resized to before fusion.
        """
        # Project and resize each level to the target spatial size.
        feas = []
        for i, x in enumerate(inputs):
            x = self.pwconvs[i](x)
            x = F.interpolate(
                x, size=shape, mode='bilinear', align_corners=False)
            x = self.dwconvs[i](x)
            feas.append(x)

        # Compute one spatial attention map per level, squashed to [0, 1].
        atten = helper.avg_max_reduce_channel(feas)
        atten = F.sigmoid(self.conv_atten(atten))

        # Re-weight each level by its attention map.
        feas_att = []
        for i, fea in enumerate(feas):
            fea = fea * (atten[:, i, :, :].unsqueeze(1))
            feas_att.append(fea)

        if self.merge_type == 'concat':
            out = paddle.concat(feas_att, axis=1)
        else:
            out = sum(feas_att)
        out = self.conv_out(out)
        return out
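

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It assumes three
# backbone feature maps at strides 4/8/16 of a 64x64 input; the shapes and
# channel counts below are illustrative only, and `paddleseg` / `ppmatting`
# must be importable for the layer to build.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    feats = [
        paddle.rand([1, 64, 16, 16]),  # stride-4 feature
        paddle.rand([1, 128, 8, 8]),  # stride-8 feature
        paddle.rand([1, 256, 4, 4]),  # stride-16 feature
    ]
    mlff = MLFF(
        in_channels=[64, 128, 256],
        mid_channels=[32, 32, 32],
        out_channel=32,
        merge_type='concat')
    # Fuse all levels to the spatial size of the finest feature map.
    out = mlff(feats, shape=paddle.shape(feats[0])[2:])
    print(out.shape)  # Expected: [1, 32, 16, 16]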