img.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. from paddleseg.utils import logger
  16. from tools.model import get_model
  17. import cv2
  18. import numpy as np
  19. from ppmatting.core import predict
  20. from ppmatting.utils import get_image_list
  21. current_path = os.path.abspath(os.path.dirname(__file__))
  22. def get_rel_path(path: str):
  23. return "{}/../{}".format(current_path, path)
  24. def seg(img_path: str, save_dir: str):
  25. image_list, image_dir = get_image_list(img_path)
  26. logger.info('Number of predict images = {}'.format(len(image_list)))
  27. model = get_model()
  28. return predict(
  29. model=model.model,
  30. model_path=model.path,
  31. transforms=model.transforms,
  32. image_list=image_list,
  33. image_dir=image_dir,
  34. trimap_list=None,
  35. save_dir=save_dir,
  36. fg_estimate=True)
  37. def replace(img_path: str, save_dir: str, bg_color = "r"):
  38. image_list, image_dir = get_image_list(img_path)
  39. model = get_model()
  40. alpha, fg = predict(
  41. model=model.model,
  42. model_path=model.path,
  43. transforms=model.transforms,
  44. image_list=image_list,
  45. trimap_list=None,
  46. save_dir=save_dir,
  47. fg_estimate=False)
  48. img_ori = cv2.imread(img_path)
  49. bg = get_bg(bg_color, img_ori.shape)
  50. alpha = alpha / 255.0
  51. alpha = alpha[:, :, np.newaxis]
  52. com = alpha * fg + (1 - alpha) * bg
  53. com = com.astype('uint8')
  54. com_save_path = os.path.join(save_dir, os.path.basename(img_path))
  55. cv2.imwrite(com_save_path, com)
  56. def get_bg(background, img_shape):
  57. bg = np.zeros(img_shape)
  58. if background == 'r':
  59. bg[:, :, 2] = 255
  60. elif background is None or background == 'g':
  61. bg[:, :, 1] = 255
  62. elif background == 'b':
  63. bg[:, :, 0] = 255
  64. elif background == 'w':
  65. bg[:, :, :] = 255
  66. elif not os.path.exists(background):
  67. raise Exception('The --background is not existed: {}'.format(
  68. background))
  69. else:
  70. bg = cv2.imread(background)
  71. bg = cv2.resize(bg, (img_shape[1], img_shape[0]))
  72. return bg