# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

from paddleseg.utils import logger
from tools.model import get_model

import cv2
import numpy as np

from ppmatting.core import predict
from ppmatting.utils import get_image_list

# Absolute directory containing this module; used as the anchor for
# get_rel_path.
current_path = os.path.abspath(os.path.dirname(__file__))


def get_rel_path(path: str):
    """Return ``path`` joined onto the parent of this module's directory.

    This is a plain string join of the form ``<module dir>/../<path>``; the
    result is not normalised, so the ``..`` component is left in place.
    """
    return "{}/../{}".format(current_path, path)


def seg(img_path: str, save_dir: str):
    """Run matting on every image found at ``img_path``.

    Args:
        img_path: Anything accepted by ``get_image_list`` (e.g. a single image
            file or a directory of images).
        save_dir: Output directory handed to ``predict``.

    Returns:
        Whatever ``ppmatting.core.predict`` returns.
    """
    image_list, image_dir = get_image_list(img_path)
    logger.info('Number of predict images = {}'.format(len(image_list)))
    model = get_model()
    return predict(
        model=model.model,
        model_path=model.path,
        transforms=model.transforms,
        image_list=image_list,
        image_dir=image_dir,
        trimap_list=None,
        save_dir=save_dir,
        fg_estimate=True)


def replace(img_path: str, save_dir: str, bg_color="r"):
    """Replace the background of a single image using the matting model.

    The model predicts an alpha matte for ``img_path``; the returned
    foreground is then alpha-blended over a background built by ``get_bg``
    and the composite is written to ``save_dir`` under the input image's
    file name.

    Args:
        img_path: Path to a single image file. It is also read directly with
            ``cv2.imread`` for compositing, so inputs that OpenCV cannot read
            as one image (directories, list files) are rejected.
        save_dir: Output directory for ``predict`` and for the composite.
        bg_color: Background spec forwarded to ``get_bg``: 'r', 'g', 'b',
            'w', None (green), or a path to a background image file.

    Raises:
        ValueError: ``img_path`` cannot be read as an image.
    """
    image_list, image_dir = get_image_list(img_path)

    # Read the source image and build the background *before* loading the
    # model and running inference, so a bad ``img_path`` or ``bg_color``
    # fails fast instead of after predict() has already done its work.
    img_ori = cv2.imread(img_path)
    if img_ori is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise ValueError(
            'Cannot read image for background replacement: {}'.format(
                img_path))
    bg = get_bg(bg_color, img_ori.shape)

    model = get_model()
    alpha, fg = predict(
        model=model.model,
        model_path=model.path,
        transforms=model.transforms,
        image_list=image_list,
        image_dir=image_dir,
        trimap_list=None,
        save_dir=save_dir,
        fg_estimate=False)

    # The matte is divided by 255, i.e. it is assumed to be on an 8-bit
    # scale (TODO confirm against predict). The new trailing axis lets the
    # 2-D matte broadcast over the colour channels of fg and bg.
    alpha = alpha / 255.0
    alpha = alpha[:, :, np.newaxis]
    com = alpha * fg + (1 - alpha) * bg
    com = com.astype('uint8')
    com_save_path = os.path.join(save_dir, os.path.basename(img_path))
    cv2.imwrite(com_save_path, com)


def get_bg(background, img_shape):
    """Build a background array matching ``img_shape``.

    Solid colours are written in BGR channel order (the order ``cv2.imread``
    uses): 'r' fills channel 2, 'g' channel 1, 'b' channel 0, and 'w' all
    channels.

    Args:
        background: 'r', 'g', 'b' or 'w' for a solid colour, None for green,
            or a path to an image file that is resized to fit.
        img_shape: Target shape, indexed as (height, width, channels).

    Returns:
        A float array of shape ``img_shape`` for solid colours, or the image
        loaded by ``cv2.imread`` and resized to (height, width) for a file.

    Raises:
        FileNotFoundError: ``background`` is a path that does not exist.
        ValueError: ``background`` exists but cannot be decoded as an image.
    """
    bg = np.zeros(img_shape)
    if background == 'r':
        bg[:, :, 2] = 255
    elif background is None or background == 'g':
        bg[:, :, 1] = 255
    elif background == 'b':
        bg[:, :, 0] = 255
    elif background == 'w':
        bg[:, :, :] = 255
    elif not os.path.exists(background):
        # FileNotFoundError subclasses Exception, so callers that caught the
        # previous bare Exception keep working.
        raise FileNotFoundError(
            'The background image does not exist: {}'.format(background))
    else:
        bg = cv2.imread(background)
        if bg is None:
            # Without this check cv2.resize fails with an opaque error.
            raise ValueError(
                'Cannot read background image: {}'.format(background))
        # cv2.resize takes dsize as (width, height).
        bg = cv2.resize(bg, (img_shape[1], img_shape[0]))
    return bg