# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddleseg.models import layers

from ppmatting.models.layers import tensor_fusion_helper as helper

class MLFF(nn.Layer):
    """
    Multi-level features are fused adaptively by obtaining spatial attention.

    Args:
        in_channels (list): The channels of the input tensors.
        mid_channels (list): The middle channels used while fusing the features.
        out_channel (int): The output channel after fusing.
        merge_type (str): How to merge the multi-level features before the
            output convolution. It should be one of ('add', 'concat').
            Default: 'concat'.
    """
    def __init__(self,
                 in_channels,
                 mid_channels,
                 out_channel,
                 merge_type='concat'):
        super().__init__()
        self.merge_type = merge_type

        # Check arguments
        if len(in_channels) != len(mid_channels):
            raise ValueError(
                "`mid_channels` should have the same length as `in_channels`, "
                "but they are {} and {}.".format(mid_channels, in_channels))
        if self.merge_type == 'add' and len(np.unique(np.array(
                mid_channels))) != 1:
            raise ValueError(
                "If `merge_type='add'`, all elements of `mid_channels` should "
                "be the same, but it is {}.".format(mid_channels))
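
        # Per level: a pointwise conv projects the input to `mid_channel`,
        # then a depthwise conv refines the feature after it has been resized
        # to a common spatial size in `forward`.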
        self.pwconvs = nn.LayerList()
        self.dwconvs = nn.LayerList()
        for in_channel, mid_channel in zip(in_channels, mid_channels):
            self.pwconvs.append(
                layers.ConvBN(
                    in_channel, mid_channel, 1, bias_attr=False))
            self.dwconvs.append(
                layers.ConvBNReLU(
                    mid_channel,
                    mid_channel,
                    3,
                    padding=1,
                    groups=mid_channel,
                    bias_attr=False))

        num_feas = len(in_channels)
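        # The attention head consumes 2 reduced channels per level (see
        # `forward`) and predicts one spatial attention map per level.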
        self.conv_atten = nn.Sequential(
            layers.ConvBNReLU(
                2 * num_feas,
                num_feas,
                kernel_size=3,
                padding=1,
                bias_attr=False),
            layers.ConvBN(
                num_feas, num_feas, kernel_size=3, padding=1, bias_attr=False))

        if self.merge_type == 'add':
            in_chan = mid_channels[0]
        else:
            in_chan = sum(mid_channels)
        self.conv_out = layers.ConvBNReLU(
            in_chan, out_channel, kernel_size=3, padding=1, bias_attr=False)

    def forward(self, inputs, shape):
        """
        Args:
            inputs (list): List of tensors to be fused.
            shape (Tensor): A tensor with two elements, (H, W), giving the
                spatial size that all features are resized to before fusion.
        """
        feas = []
        for i, x_in in enumerate(inputs):
            x = self.pwconvs[i](x_in)
            x = F.interpolate(
                x, size=shape, mode='bilinear', align_corners=False)
            x = self.dwconvs[i](x)
            feas.append(x)
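
        # The helper reduces each level to channel-wise statistics and
        # concatenates them, which matches the `2 * num_feas` input channels
        # expected by `self.conv_atten`.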
        atten = helper.avg_max_reduce_channel(feas)
        atten = F.sigmoid(self.conv_atten(atten))
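
        # Re-weight each level by its own spatial attention map.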
        feas_att = []
        for i, fea in enumerate(feas):
            fea = fea * (atten[:, i, :, :].unsqueeze(1))
            feas_att.append(fea)

        if self.merge_type == 'concat':
            out = paddle.concat(feas_att, axis=1)
        else:
            out = sum(feas_att)
        out = self.conv_out(out)
        return out
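

# A minimal usage sketch (not part of the original module): the channel
# counts and feature shapes below are hypothetical, chosen only to show how
# `MLFF` fuses a multi-level feature pyramid into a single map.
if __name__ == "__main__":
    # Three hypothetical pyramid levels at decreasing spatial resolution.
    feats = [
        paddle.rand([1, 32, 16, 16]),
        paddle.rand([1, 64, 8, 8]),
        paddle.rand([1, 128, 4, 4]),
    ]
    mlff = MLFF(
        in_channels=[32, 64, 128],
        mid_channels=[32, 32, 32],
        out_channel=64,
        merge_type='concat')
    # All levels are resized to the size of the finest level before fusion.
    out = mlff(feats, shape=paddle.shape(feats[0])[2:])
    print(out.shape)  # [1, 64, 16, 16]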