# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# The GCA code is heavily based on https://github.com/Yaoyi-Li/GCA-Matting
# and https://github.com/open-mmlab/mmediting
import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from paddleseg.cvlibs import param_init


class GuidedCxtAtten(nn.Layer):
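    """Guided contextual attention (GCA) block for image matting.

    Alpha features are propagated between spatial locations according to the
    similarity of the guidance image features, following GCA-Matting
    (https://github.com/Yaoyi-Li/GCA-Matting) and the MMEditing implementation.

    Args:
        out_channels (int): Number of channels of the alpha feature map.
        guidance_channels (int): Number of channels of the guidance image
            feature map.
        kernel_size (int, optional): Patch size used when matching guidance
            features. Default: 3.
        stride (int, optional): Stride used when extracting patches.
            Default: 1.
        rate (int, optional): Downsampling rate; the attention is computed at
            1/rate of the input resolution. Default: 2.
    """
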
    def __init__(self,
                 out_channels,
                 guidance_channels,
                 kernel_size=3,
                 stride=1,
                 rate=2):
        super().__init__()
        self.kernel_size = kernel_size
        self.rate = rate
        self.stride = stride
        self.guidance_conv = nn.Conv2D(
            in_channels=guidance_channels,
            out_channels=guidance_channels // 2,
            kernel_size=1)
        self.out_conv = nn.Sequential(
            nn.Conv2D(
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=1,
                bias_attr=False),
            nn.BatchNorm(out_channels))
        self.init_weight()

    def init_weight(self):
        param_init.xavier_uniform(self.guidance_conv.weight)
        param_init.constant_init(self.guidance_conv.bias, value=0.0)
        param_init.xavier_uniform(self.out_conv[0].weight)
        param_init.constant_init(self.out_conv[1].weight, value=1e-3)
        param_init.constant_init(self.out_conv[1].bias, value=0.0)

    def forward(self, img_feat, alpha_feat, unknown=None, softmax_scale=1.):
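        """Refine alpha features with guided contextual attention.

        Args:
            img_feat (Tensor): Guidance image features of shape
                (N, guidance_channels, H, W).
            alpha_feat (Tensor): Alpha features of shape
                (N, out_channels, H, W).
            unknown (Tensor, optional): Binary mask of the unknown region of
                shape (N, 1, H, W). If None, the whole image is treated as
                unknown. Default: None.
            softmax_scale (float, optional): Fallback softmax scale used when
                ``unknown`` is None. Default: 1.

        Returns:
            Tensor: Refined alpha features with the same shape as
            ``alpha_feat``.
        """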
        img_feat = self.guidance_conv(img_feat)
        img_feat = F.interpolate(
            img_feat, scale_factor=1 / self.rate, mode='nearest')

        # process unknown mask
        unknown, softmax_scale = self.process_unknown_mask(unknown, img_feat,
                                                           softmax_scale)

        img_ps, alpha_ps, unknown_ps = self.extract_feature_maps_patches(
            img_feat, alpha_feat, unknown)

        self_mask = self.get_self_correlation_mask(img_feat)

        # split tensors along the batch dimension so attention is computed
        # per sample; a list of (1, ...) tensors is returned
        img_groups = paddle.split(img_feat, img_feat.shape[0], axis=0)
        img_ps_groups = paddle.split(img_ps, img_ps.shape[0], axis=0)
        alpha_ps_groups = paddle.split(alpha_ps, alpha_ps.shape[0], axis=0)
        unknown_ps_groups = paddle.split(unknown_ps, unknown_ps.shape[0],
                                         axis=0)
        scale_groups = paddle.split(softmax_scale, softmax_scale.shape[0],
                                    axis=0)
        groups = (img_groups, img_ps_groups, alpha_ps_groups,
                  unknown_ps_groups, scale_groups)

        y = []
        for img_i, img_ps_i, alpha_ps_i, unknown_ps_i, scale_i in zip(*groups):
            # correlate the guidance feature of this sample with its own
            # patches, then propagate alpha patches with the attention scores
            similarity_map = self.compute_similarity_map(img_i, img_ps_i)
            gca_score = self.compute_guided_attention_score(
                similarity_map, unknown_ps_i, scale_i, self_mask)
            yi = self.propagate_alpha_feature(gca_score, alpha_ps_i)
            y.append(yi)
        y = paddle.concat(y, axis=0)  # back to the mini-batch
        y = paddle.reshape(y, alpha_feat.shape)
        y = self.out_conv(y) + alpha_feat

        return y

    def extract_feature_maps_patches(self, img_feat, alpha_feat, unknown):
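        """Extract the image, alpha, and unknown-mask patches used for
        matching.

        Returns:
            tuple: ``(img_ps, alpha_ps, unknown_ps)`` with shapes
            ``(N, H*W, C_img, ks, ks)``, ``(N, H*W, C_alpha, 2*rate, 2*rate)``
            and ``(N, H*W, 1, 1)``, where H and W are the spatial dimensions
            of the (downsampled) guidance feature and ``ks`` is
            ``self.kernel_size``.
        """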
        # extract image feature patches with shape:
        # (N, img_h*img_w, img_c, img_ks, img_ks)
        img_ks = self.kernel_size
        img_ps = self.extract_patches(img_feat, img_ks, self.stride)

        # extract alpha feature patches with shape:
        # (N, img_h*img_w, alpha_c, alpha_ks, alpha_ks)
        alpha_ps = self.extract_patches(alpha_feat, self.rate * 2, self.rate)

        # extract unknown mask patches with shape: (N, img_h*img_w, 1, 1)
        unknown_ps = self.extract_patches(unknown, img_ks, self.stride)
        unknown_ps = unknown_ps.squeeze(axis=2)  # squeeze channel dimension
        unknown_ps = unknown_ps.mean(axis=[2, 3], keepdim=True)

        return img_ps, alpha_ps, unknown_ps

    def extract_patches(self, x, kernel_size, stride):
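        """Unfold ``x`` into sliding-window patches of shape
        (N, num_patches, C, kernel_size, kernel_size), using reflection
        padding so that the window covers every spatial location.
        """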
        n, c, _, _ = x.shape
        x = self.pad(x, kernel_size, stride)
        x = F.unfold(x, [kernel_size, kernel_size], strides=[stride, stride])
        x = paddle.transpose(x, (0, 2, 1))
        x = paddle.reshape(x, (n, -1, c, kernel_size, kernel_size))

        return x

    def pad(self, x, kernel_size, stride):
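        """Reflection-pad ``x`` so that a window of size ``kernel_size`` slid
        with the given ``stride`` covers the whole feature map.
        """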
        left = (kernel_size - stride + 1) // 2
        right = (kernel_size - stride) // 2
        pad = (left, right, left, right)
        return F.pad(x, pad, mode='reflect')

    def compute_guided_attention_score(self, similarity_map, unknown_ps, scale,
                                       self_mask):
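        """Turn the similarity map into attention scores.

        The similarity map is scaled differently for patches from the known
        and unknown regions, self-matching inside the unknown region is
        penalised with a large negative bias, and a softmax is applied over
        the candidate patches.
        """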
        # scale the correlation with the predicted scale factor for the known
        # and unknown area
        unknown_scale, known_scale = scale[0]
        is_unknown = paddle.greater_than(
            unknown_ps, paddle.to_tensor([0.])).astype(similarity_map.dtype)
        is_known = paddle.less_equal(
            unknown_ps, paddle.to_tensor([0.])).astype(similarity_map.dtype)
        out = similarity_map * (
            unknown_scale * is_unknown + known_scale * is_known)
        # mask itself; the self-mask is only applied to the unknown area
        out = out + self_mask * unknown_ps
        gca_score = F.softmax(out, axis=1)

        return gca_score

    def propagate_alpha_feature(self, gca_score, alpha_ps):
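        """Reassemble the alpha feature map by weighting the extracted alpha
        patches with the attention scores (a transposed convolution when
        ``rate > 1``).
        """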
        alpha_ps = alpha_ps[0]  # squeeze dim 0
        if self.rate == 1:
            gca_score = self.pad(gca_score, kernel_size=2, stride=1)
            alpha_ps = paddle.transpose(alpha_ps, (1, 0, 2, 3))
            out = F.conv2d(gca_score, alpha_ps) / 4.
        else:
            out = F.conv2d_transpose(
                gca_score, alpha_ps, stride=self.rate, padding=1) / 4.

        return out

    def compute_similarity_map(self, img_feat, img_ps):
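        """Correlate ``img_feat`` with its own L2-normalised patches to get a
        (1, H*W, H, W) similarity map for a single sample.
        """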
        img_ps = img_ps[0]  # squeeze dim 0
        # convolve the feature with its own patches to get the correlation
        # (similarity) map
        img_ps_normed = img_ps / paddle.clip(self.l2_norm(img_ps), 1e-4)
        img_feat = F.pad(img_feat, (1, 1, 1, 1), mode='reflect')
        similarity_map = F.conv2d(img_feat, img_ps_normed)

        return similarity_map

    def get_self_correlation_mask(self, img_feat):
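        """Build a (1, H*W, H, W) mask holding a large negative value at each
        location's own position, used to discourage a patch from matching
        itself.
        """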
        _, _, h, w = img_feat.shape
        self_mask = F.one_hot(
            paddle.reshape(paddle.arange(h * w), (h, w)),
            num_classes=int(h * w))
        self_mask = paddle.transpose(self_mask, (2, 0, 1))
        self_mask = paddle.reshape(self_mask, (1, h * w, h, w))

        return self_mask * (-1e4)

    def process_unknown_mask(self, unknown, img_feat, softmax_scale):
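        """Downsample the unknown mask to the attention resolution and derive
        per-sample softmax scales for the known and unknown regions. If no
        mask is given, the whole image is treated as unknown and
        ``softmax_scale`` is used for both regions.
        """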
        n, _, h, w = img_feat.shape

        if unknown is not None:
            unknown = unknown.clone()
            unknown = F.interpolate(
                unknown, scale_factor=1 / self.rate, mode='nearest')
            unknown_mean = unknown.mean(axis=[2, 3])
            known_mean = 1 - unknown_mean
            unknown_scale = paddle.clip(
                paddle.sqrt(unknown_mean / known_mean), 0.1, 10)
            known_scale = paddle.clip(
                paddle.sqrt(known_mean / unknown_mean), 0.1, 10)
            softmax_scale = paddle.concat([unknown_scale, known_scale], axis=1)
        else:
            unknown = paddle.ones([n, 1, h, w])
            softmax_scale = paddle.reshape(
                paddle.to_tensor([softmax_scale, softmax_scale]), (1, 2))
            softmax_scale = paddle.expand(softmax_scale, (n, 2))

        return unknown, softmax_scale

    @staticmethod
    def l2_norm(x):
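        """Per-patch L2 norm over the channel and spatial dimensions, keeping
        the reduced dimensions for broadcasting.
        """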
        x = x**2
        x = x.sum(axis=[1, 2, 3], keepdim=True)
        return paddle.sqrt(x)
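

if __name__ == '__main__':
    # Minimal smoke-test sketch (not part of the original module): it builds a
    # small GuidedCxtAtten block and runs one forward pass on random tensors.
    # The channel counts and spatial sizes below are illustrative assumptions,
    # not values taken from any PaddleSeg configuration; a working
    # paddlepaddle + paddleseg install is assumed.
    paddle.seed(0)

    gca = GuidedCxtAtten(out_channels=8, guidance_channels=16, rate=2)

    n, h, w = 2, 16, 16
    img_feat = paddle.rand([n, 16, h, w])   # guidance image features
    alpha_feat = paddle.rand([n, 8, h, w])  # alpha features (same H, W)
    # binary unknown-region mask with the same spatial size
    unknown = paddle.cast(paddle.rand([n, 1, h, w]) > 0.5, 'float32')

    out = gca(img_feat, alpha_feat, unknown)
    print(out.shape)  # expected: [2, 8, 16, 16]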