|
@@ -0,0 +1,87 @@
|
|
|
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
|
|
+#
|
|
|
+# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
+# you may not use this file except in compliance with the License.
|
|
|
+# You may obtain a copy of the License at
|
|
|
+#
|
|
|
+# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
+#
|
|
|
+# Unless required by applicable law or agreed to in writing, software
|
|
|
+# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
+# See the License for the specific language governing permissions and
|
|
|
+# limitations under the License.
|
|
|
+
|
|
|
+import os
|
|
|
+from paddleseg.utils import logger
|
|
|
+
|
|
|
+from tools.model import get_model
|
|
|
+import cv2
|
|
|
+import numpy as np
|
|
|
+from ppmatting.core import predict
|
|
|
+from ppmatting.utils import get_image_list
|
|
|
+
|
|
|
+current_path = os.path.abspath(os.path.dirname(__file__))
|
|
|
+
|
|
|
+
|
|
|
+def get_rel_path(path: str):
|
|
|
+ return "{}/../{}".format(current_path, path)
|
|
|
+
|
|
|
+
|
|
|
+def seg(img_path: str, save_dir: str):
|
|
|
+ image_list, image_dir = get_image_list(img_path)
|
|
|
+ logger.info('Number of predict images = {}'.format(len(image_list)))
|
|
|
+
|
|
|
+ model = get_model()
|
|
|
+
|
|
|
+ return predict(
|
|
|
+ model=model.model,
|
|
|
+ model_path=model.path,
|
|
|
+ transforms=model.transforms,
|
|
|
+ image_list=image_list,
|
|
|
+ image_dir=image_dir,
|
|
|
+ trimap_list=None,
|
|
|
+ save_dir=save_dir,
|
|
|
+ fg_estimate=True)
|
|
|
+
|
|
|
+
|
|
|
+def replace(img_path: str, save_dir: str, bg_color = "r"):
|
|
|
+ image_list, image_dir = get_image_list(img_path)
|
|
|
+ model = get_model()
|
|
|
+ alpha, fg = predict(
|
|
|
+ model=model.model,
|
|
|
+ model_path=model.path,
|
|
|
+ transforms=model.transforms,
|
|
|
+ image_list=image_list,
|
|
|
+ trimap_list=None,
|
|
|
+ save_dir=save_dir,
|
|
|
+ fg_estimate=False)
|
|
|
+
|
|
|
+ img_ori = cv2.imread(img_path)
|
|
|
+ bg = get_bg(bg_color, img_ori.shape)
|
|
|
+ alpha = alpha / 255.0
|
|
|
+ alpha = alpha[:, :, np.newaxis]
|
|
|
+ com = alpha * fg + (1 - alpha) * bg
|
|
|
+ com = com.astype('uint8')
|
|
|
+ com_save_path = os.path.join(save_dir, os.path.basename(img_path))
|
|
|
+ cv2.imwrite(com_save_path, com)
|
|
|
+
|
|
|
+
|
|
|
+def get_bg(background, img_shape):
|
|
|
+ bg = np.zeros(img_shape)
|
|
|
+ if background == 'r':
|
|
|
+ bg[:, :, 2] = 255
|
|
|
+ elif background is None or background == 'g':
|
|
|
+ bg[:, :, 1] = 255
|
|
|
+ elif background == 'b':
|
|
|
+ bg[:, :, 0] = 255
|
|
|
+ elif background == 'w':
|
|
|
+ bg[:, :, :] = 255
|
|
|
+
|
|
|
+ elif not os.path.exists(background):
|
|
|
+ raise Exception('The --background is not existed: {}'.format(
|
|
|
+ background))
|
|
|
+ else:
|
|
|
+ bg = cv2.imread(background)
|
|
|
+ bg = cv2.resize(bg, (img_shape[1], img_shape[0]))
|
|
|
+ return bg
|