methods.py 3.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import pymatting
  15. from paddleseg.cvlibs import manager
  16. class BaseMLMatting(object):
  17. def __init__(self, alpha_estimator, **kargs):
  18. self.alpha_estimator = alpha_estimator
  19. self.kargs = kargs
  20. def __call__(self, image, trimap):
  21. image = self.__to_float64(image)
  22. trimap = self.__to_float64(trimap)
  23. alpha_matte = self.alpha_estimator(image, trimap, **self.kargs)
  24. return alpha_matte
  25. def __to_float64(self, x):
  26. x_dtype = x.dtype
  27. assert x_dtype in ["float32", "float64"]
  28. x = x.astype("float64")
  29. return x
  30. @manager.MODELS.add_component
  31. class CloseFormMatting(BaseMLMatting):
  32. def __init__(self, **kargs):
  33. cf_alpha_estimator = pymatting.estimate_alpha_cf
  34. super().__init__(cf_alpha_estimator, **kargs)
  35. @manager.MODELS.add_component
  36. class KNNMatting(BaseMLMatting):
  37. def __init__(self, **kargs):
  38. knn_alpha_estimator = pymatting.estimate_alpha_knn
  39. super().__init__(knn_alpha_estimator, **kargs)
  40. @manager.MODELS.add_component
  41. class LearningBasedMatting(BaseMLMatting):
  42. def __init__(self, **kargs):
  43. lbdm_alpha_estimator = pymatting.estimate_alpha_lbdm
  44. super().__init__(lbdm_alpha_estimator, **kargs)
  45. @manager.MODELS.add_component
  46. class FastMatting(BaseMLMatting):
  47. def __init__(self, **kargs):
  48. lkm_alpha_estimator = pymatting.estimate_alpha_lkm
  49. super().__init__(lkm_alpha_estimator, **kargs)
  50. @manager.MODELS.add_component
  51. class RandomWalksMatting(BaseMLMatting):
  52. def __init__(self, **kargs):
  53. rw_alpha_estimator = pymatting.estimate_alpha_rw
  54. super().__init__(rw_alpha_estimator, **kargs)
  55. if __name__ == "__main__":
  56. from pymatting.util.util import load_image, save_image, stack_images
  57. from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
  58. import cv2
  59. root = "/mnt/liuyi22/PaddlePaddle/PaddleSeg/Matting/data/examples/"
  60. image_path = root + "lemur.png"
  61. trimap_path = root + "lemur_trimap.png"
  62. cutout_path = root + "lemur_cutout.png"
  63. image = cv2.cvtColor(
  64. cv2.imread(image_path).astype("float64"), cv2.COLOR_BGR2RGB) / 255.0
  65. cv2.imwrite("image.png", (image * 255).astype('uint8'))
  66. trimap = load_image(trimap_path, "GRAY")
  67. print(image.shape, trimap.shape)
  68. print(image.dtype, trimap.dtype)
  69. cf = CloseFormMatting()
  70. alpha = cf(image, trimap)
  71. # alpha = pymatting.estimate_alpha_lkm(image, trimap)
  72. foreground = estimate_foreground_ml(image, alpha)
  73. cutout = stack_images(foreground, alpha)
  74. save_image(cutout_path, cutout)