123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687 |
- # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import os
- from paddleseg.utils import logger
- from tools.model import get_model
- import cv2
- import numpy as np
- from ppmatting.core import predict
- from ppmatting.utils import get_image_list
# Absolute form of the directory that contains this module file.
current_path = os.path.abspath(os.path.dirname(__file__))


def get_rel_path(path: str):
    """Return ``path`` anchored at the parent of this module's directory.

    The result is plain string interpolation of the form
    ``<current_path>/../<path>``; it is not normalised, so the literal
    ``..`` component is kept in the returned string.
    """
    return f"{current_path}/../{path}"
def seg(img_path: str, save_dir: str):
    """Run the matting model over every image that ``img_path`` resolves to.

    ``img_path`` is expanded by ``get_image_list`` into an image list plus
    its base directory; the image count is logged, then both are handed to
    ``predict`` together with the model bundle from ``get_model()``. No
    trimaps are supplied, foreground estimation is enabled, and
    ``save_dir`` is forwarded to ``predict`` unchanged.

    Returns:
        Whatever ``predict`` returns.
    """
    images, images_root = get_image_list(img_path)
    logger.info(f'Number of predict images = {len(images)}')

    bundle = get_model()
    predict_options = {
        'model': bundle.model,
        'model_path': bundle.path,
        'transforms': bundle.transforms,
        'image_list': images,
        'image_dir': images_root,
        'trimap_list': None,
        'save_dir': save_dir,
        'fg_estimate': True,
    }
    return predict(**predict_options)
def replace(img_path: str, save_dir: str, bg_color="r"):
    """Matte one image and composite its foreground onto a new background.

    Args:
        img_path: Input accepted by ``get_image_list`` (e.g. an image file);
            it must resolve to exactly one image.
        save_dir: Forwarded to ``predict`` as its ``save_dir``; the composite
            is also written into this directory under the image's base name.
        bg_color: Background spec passed to ``get_bg``: 'r', 'g', 'b', 'w',
            None, or a path to a background image.

    Returns:
        The path the composite image was written to.

    Raises:
        ValueError: ``img_path`` does not resolve to exactly one image, or
            that image cannot be decoded by ``cv2.imread``.
        OSError: ``cv2.imwrite`` reports that the composite was not written.
        Any exception raised by ``get_bg`` for an invalid ``bg_color``.
    """
    image_list, image_dir = get_image_list(img_path)
    # Only a single (alpha, fg) pair is unpacked from predict() and a single
    # background is built below, so several input images cannot be
    # composited. Reject that before spending time on model inference.
    if len(image_list) != 1:
        raise ValueError(
            'replace() expects exactly one image, but {!r} resolved to {} '
            'images'.format(img_path, len(image_list)))
    image_file = image_list[0]

    # Read the resolved image file, not img_path itself (which may be a
    # directory or a list file), and fail fast before loading the model.
    img_ori = cv2.imread(image_file)
    if img_ori is None:
        raise ValueError('Cannot read image: {}'.format(image_file))
    # Build the background up front too, so a bad bg_color is reported
    # before inference runs.
    bg = get_bg(bg_color, img_ori.shape)

    model = get_model()
    alpha, fg = predict(
        model=model.model,
        model_path=model.path,
        transforms=model.transforms,
        image_list=image_list,
        image_dir=image_dir,  # same as seg(), so predict names outputs consistently
        trimap_list=None,
        save_dir=save_dir,
        fg_estimate=False)

    # Divide by 255 (alpha is presumably a 0-255 matte) and add a trailing
    # axis so it broadcasts over the colour channels of fg and bg.
    alpha = alpha / 255.0
    alpha = alpha[:, :, np.newaxis]
    com = alpha * fg + (1 - alpha) * bg
    com = com.astype('uint8')

    os.makedirs(save_dir, exist_ok=True)
    com_save_path = os.path.join(save_dir, os.path.basename(image_file))
    # cv2.imwrite signals failure via its boolean return, not an exception.
    if not cv2.imwrite(com_save_path, com):
        raise OSError(
            'Failed to write composite image: {}'.format(com_save_path))
    return com_save_path
def get_bg(background, img_shape):
    """Build a background for compositing a matted foreground.

    Args:
        background: 'r', 'g', 'b' or 'w' for a solid red, green, blue or
            white fill ('r' fills channel 2 and 'b' channel 0, matching
            OpenCV's BGR order); ``None``, which behaves like 'g'; or a path
            to an image file that is loaded and resized.
        img_shape: Target shape, indexed as (height, width, channels, ...);
            solid fills have exactly this shape.

    Returns:
        For a colour code, a float64 array of shape ``img_shape`` with the
        selected channel(s) set to 255 and the rest 0. For a path, the image
        read by ``cv2.imread`` and resized to width ``img_shape[1]`` and
        height ``img_shape[0]``.

    Raises:
        FileNotFoundError: ``background`` is not a colour code and nothing
            exists at that path. It subclasses ``Exception``, so existing
            ``except Exception`` handlers still catch it.
        ValueError: the path exists but OpenCV cannot decode it as an image.
    """
    # Channel index (BGR) to fill for each colour code; 'w' fills all.
    color_channels = {'r': 2, 'g': 1, 'b': 0, 'w': slice(None)}
    if background is None:
        background = 'g'

    if background in ('r', 'g', 'b', 'w'):
        bg = np.zeros(img_shape)
        bg[:, :, color_channels[background]] = 255
        return bg

    if not os.path.exists(background):
        raise FileNotFoundError('The --background is not existed: {}'.format(
            background))

    bg = cv2.imread(background)
    # cv2.imread returns None (it does not raise) for unreadable files.
    if bg is None:
        raise ValueError(
            'Cannot read background image: {}'.format(background))
    return cv2.resize(bg, (img_shape[1], img_shape[0]))
|