# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import string

from paddleseg.utils import logger
from tools.model import get_model
import cv2
import numpy as np

from ppmatting.core import predict
from ppmatting.utils import get_image_list

current_path = os.path.abspath(os.path.dirname(__file__))

# BGR values for the single-letter background shortcuts accepted by get_bg.
# cv2 images store channels in B, G, R order, hence the reversed tuples.
_NAMED_BG_COLORS = {
    'r': (0, 0, 255),
    'g': (0, 255, 0),
    'b': (255, 0, 0),
    'w': (255, 255, 255),
}


def get_rel_path(path: str):
    """Return ``path`` joined onto the parent of this module's directory.

    The result is not normalised; it has the literal form
    ``<module_dir>/../<path>``.
    """
    return "{}/../{}".format(current_path, path)


def seg(img_path: str, save_dir: str):
    """Run the matting model on ``img_path`` and save results under ``save_dir``.

    Args:
        img_path: Input understood by ``get_image_list`` (e.g. a single
            image file or a directory of images).
        save_dir: Output directory handed to ``predict``.

    Returns:
        Whatever ``ppmatting.core.predict`` returns.
    """
    image_list, image_dir = get_image_list(img_path)
    logger.info('Number of predict images = {}'.format(len(image_list)))
    model = get_model()
    return predict(
        model=model.model,
        model_path=model.path,
        transforms=model.transforms,
        image_list=image_list,
        image_dir=image_dir,
        trimap_list=None,
        save_dir=save_dir,
        fg_estimate=True)


def replace(img_path: str, save_dir: str, background: str = None,
            width: int = None, height: int = None):
    """Matte ``img_path`` and optionally composite it over a new background.

    Args:
        img_path: Path to the input image. Compositing needs a single image
            that ``cv2.imread`` can decode.
        save_dir: Directory for the ``predict`` outputs and the composite.
        background: Background spec understood by ``get_bg`` (image path,
            'r'/'g'/'b'/'w', or '#RRGGBB[AA]'). None skips compositing.
        width, height: Currently only logged; not used by this implementation.

    Returns:
        The path of the composited image when a composite is written.
        Otherwise, the fourth value returned by ``predict`` (presumably the
        saved prediction path; confirm against ppmatting.core.predict).
    """
    logger.info("replace: {}, {},{},{},{}".format(img_path, save_dir, background, width, height))
    image_list, image_dir = get_image_list(img_path)
    model = get_model()
    alpha, fg, _, p = predict(
        model=model.model,
        model_path=model.path,
        transforms=model.transforms,
        image_list=image_list,
        image_dir=image_dir,  # passed for consistency with seg()
        trimap_list=None,
        save_dir=save_dir,
        fg_estimate=False)
    if background is None:
        return p

    img_ori = cv2.imread(img_path)
    if img_ori is None:
        # cv2.imread returns None for directories and undecodable files.
        # Compositing needs one readable image, so fall back to the
        # prediction result instead of crashing on `.shape`.
        logger.warning("replace: cannot read {} as a single image, "
                       "skipping background compositing".format(img_path))
        return p
    bg = get_bg(background, img_ori.shape)
    if bg is None:
        return p

    # Assumes alpha is a 2-D 0..255 matte (the indexing below requires
    # 2-D). Scale it to 0..1 and add a channel axis so it broadcasts
    # against fg and bg.
    alpha = alpha / 255.0
    alpha = alpha[:, :, np.newaxis]
    com = alpha * fg + (1 - alpha) * bg
    com = com.astype('uint8')
    com_save_path = os.path.join(save_dir, os.path.basename(img_path))
    cv2.imwrite(com_save_path, com)
    return com_save_path


def get_bg(background, img_shape):
    """Build a background array of shape ``img_shape`` in BGR channel order.

    ``background`` may be:
      1. a path to an image file, which is read with ``cv2.imread`` and
         resized to the height/width of ``img_shape``;
      2. a single-letter solid colour: 'r', 'g', 'b' or 'w';
      3. a hex colour string '#RRGGBB' or '#RRGGBBAA' (alpha is ignored).

    Returns:
        The background array, or None when ``background`` is None, not
        recognised, or names a path cv2 cannot decode as an image.
    """
    if background is None:
        return None

    if os.path.exists(background):
        bg = cv2.imread(background)
        if bg is None:  # e.g. a directory or a non-image file
            return None
        return cv2.resize(bg, (img_shape[1], img_shape[0]))

    if background in _NAMED_BG_COLORS:
        bgr = _NAMED_BG_COLORS[background]
    elif is_color_hex(background):
        r, g, b, _ = hex_to_rgb(background)
        bgr = (b, g, r)
    else:
        return None

    bg = np.zeros(img_shape)
    # Fill the first three channels (B, G, R) with the solid colour.
    bg[:, :, :3] = bgr
    return bg


def is_color_hex(color: str):
    """Return True if ``color`` is '#RRGGBB' or '#RRGGBBAA' with valid hex digits."""
    if not isinstance(color, str) or not color.startswith("#"):
        return False
    digits = color[1:]
    return len(digits) in (6, 8) and all(c in string.hexdigits for c in digits)


def hex_to_rgb(color: str):
    """Parse a hex colour into an ``(r, g, b, a)`` tuple of ints in 0..255.

    A leading '#' is optional. When no alpha byte is present (anything other
    than 8 hex digits after the '#'), alpha defaults to 255 (fully opaque),
    which is the same scale as a parsed alpha byte.

    Raises:
        ValueError: if the colour components are not valid hex.
    """
    if color.startswith("#"):
        color = color[1:]
    r = int(color[0:2], 16)
    g = int(color[2:4], 16)
    b = int(color[4:6], 16)
    a = 255
    if len(color) == 8:
        a = int(color[6:8], 16)
    return r, g, b, a