# estimate_foreground_ml.py
  1. import numpy as np
  2. from numba import njit, prange
# The foreground estimation is adapted from pymatting:
# https://github.com/pymatting/pymatting/blob/master/pymatting/foreground/estimate_foreground_ml.py
  4. @njit("void(f4[:, :, :], f4[:, :, :])", cache=True, nogil=True, parallel=True)
  5. def _resize_nearest_multichannel(dst, src):
  6. """
  7. Internal method.
  8. Resize image src to dst using nearest neighbors filtering.
  9. Images must have multiple color channels, i.e. :code:`len(shape) == 3`.
  10. Parameters
  11. ----------
  12. dst: numpy.ndarray of type np.float32
  13. output image
  14. src: numpy.ndarray of type np.float32
  15. input image
  16. """
  17. h_src, w_src, depth = src.shape
  18. h_dst, w_dst, depth = dst.shape
  19. for y_dst in prange(h_dst):
  20. for x_dst in range(w_dst):
  21. x_src = max(0, min(w_src - 1, x_dst * w_src // w_dst))
  22. y_src = max(0, min(h_src - 1, y_dst * h_src // h_dst))
  23. for c in range(depth):
  24. dst[y_dst, x_dst, c] = src[y_src, x_src, c]
  25. @njit("void(f4[:, :], f4[:, :])", cache=True, nogil=True, parallel=True)
  26. def _resize_nearest(dst, src):
  27. """
  28. Internal method.
  29. Resize image src to dst using nearest neighbors filtering.
  30. Images must be grayscale, i.e. :code:`len(shape) == 3`.
  31. Parameters
  32. ----------
  33. dst: numpy.ndarray of type np.float32
  34. output image
  35. src: numpy.ndarray of type np.float32
  36. input image
  37. """
  38. h_src, w_src = src.shape
  39. h_dst, w_dst = dst.shape
  40. for y_dst in prange(h_dst):
  41. for x_dst in range(w_dst):
  42. x_src = max(0, min(w_src - 1, x_dst * w_src // w_dst))
  43. y_src = max(0, min(h_src - 1, y_dst * h_src // h_dst))
  44. dst[y_dst, x_dst] = src[y_src, x_src]
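# TODO: Provide an option to switch @njit(parallel=True) on or off for
# _estimate_fb_ml. parallel=True would be faster, but the solver is not
# race-free when rows run concurrently: its per-pixel scratch buffer `b` is
# shared by all prange rows, and F/B are updated in place while neighboring
# rows read them. Users should be able to choose which trade-off they want.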
  49. @njit(
  50. "Tuple((f4[:, :, :], f4[:, :, :]))(f4[:, :, :], f4[:, :], f4, i4, i4, i4, f4)",
  51. cache=True,
  52. nogil=True)
  53. def _estimate_fb_ml(
  54. input_image,
  55. input_alpha,
  56. regularization,
  57. n_small_iterations,
  58. n_big_iterations,
  59. small_size,
  60. gradient_weight, ):
  61. h0, w0, depth = input_image.shape
  62. dtype = np.float32
  63. w_prev = 1
  64. h_prev = 1
  65. F_prev = np.empty((h_prev, w_prev, depth), dtype=dtype)
  66. B_prev = np.empty((h_prev, w_prev, depth), dtype=dtype)
  67. n_levels = int(np.ceil(np.log2(max(w0, h0))))
  68. for i_level in range(n_levels + 1):
  69. w = round(w0**(i_level / n_levels))
  70. h = round(h0**(i_level / n_levels))
  71. image = np.empty((h, w, depth), dtype=dtype)
  72. alpha = np.empty((h, w), dtype=dtype)
  73. _resize_nearest_multichannel(image, input_image)
  74. _resize_nearest(alpha, input_alpha)
  75. F = np.empty((h, w, depth), dtype=dtype)
  76. B = np.empty((h, w, depth), dtype=dtype)
  77. _resize_nearest_multichannel(F, F_prev)
  78. _resize_nearest_multichannel(B, B_prev)
  79. if w <= small_size and h <= small_size:
  80. n_iter = n_small_iterations
  81. else:
  82. n_iter = n_big_iterations
  83. b = np.zeros((2, depth), dtype=dtype)
  84. dx = [-1, 1, 0, 0]
  85. dy = [0, 0, -1, 1]
  86. for i_iter in range(n_iter):
  87. for y in prange(h):
  88. for x in range(w):
  89. a0 = alpha[y, x]
  90. a1 = 1.0 - a0
  91. a00 = a0 * a0
  92. a01 = a0 * a1
  93. # a10 = a01 can be omitted due to symmetry of matrix
  94. a11 = a1 * a1
  95. for c in range(depth):
  96. b[0, c] = a0 * image[y, x, c]
  97. b[1, c] = a1 * image[y, x, c]
  98. for d in range(4):
  99. x2 = max(0, min(w - 1, x + dx[d]))
  100. y2 = max(0, min(h - 1, y + dy[d]))
  101. gradient = abs(a0 - alpha[y2, x2])
  102. da = regularization + gradient_weight * gradient
  103. a00 += da
  104. a11 += da
  105. for c in range(depth):
  106. b[0, c] += da * F[y2, x2, c]
  107. b[1, c] += da * B[y2, x2, c]
  108. determinant = a00 * a11 - a01 * a01
  109. inv_det = 1.0 / determinant
  110. b00 = inv_det * a11
  111. b01 = inv_det * -a01
  112. b11 = inv_det * a00
  113. for c in range(depth):
  114. F_c = b00 * b[0, c] + b01 * b[1, c]
  115. B_c = b01 * b[0, c] + b11 * b[1, c]
  116. F_c = max(0.0, min(1.0, F_c))
  117. B_c = max(0.0, min(1.0, B_c))
  118. F[y, x, c] = F_c
  119. B[y, x, c] = B_c
  120. F_prev = F
  121. B_prev = B
  122. w_prev = w
  123. h_prev = h
  124. return F, B
  125. def estimate_foreground_ml(
  126. image,
  127. alpha,
  128. regularization=1e-5,
  129. n_small_iterations=10,
  130. n_big_iterations=2,
  131. small_size=32,
  132. return_background=False,
  133. gradient_weight=1.0, ):
  134. """Estimates the foreground of an image given its alpha matte.
  135. See :cite:`germer2020multilevel` for reference.
  136. Parameters
  137. ----------
  138. image: numpy.ndarray
  139. Input image with shape :math:`h \\times w \\times d`
  140. alpha: numpy.ndarray
  141. Input alpha matte shape :math:`h \\times w`
  142. regularization: float
  143. Regularization strength :math:`\\epsilon`, defaults to :math:`10^{-5}`.
  144. Higher regularization results in smoother colors.
  145. n_small_iterations: int
  146. Number of iterations performed on small scale, defaults to :math:`10`
  147. n_big_iterations: int
  148. Number of iterations performed on large scale, defaults to :math:`2`
  149. small_size: int
  150. Threshold that determines at which size `n_small_iterations` should be used
  151. return_background: bool
  152. Whether to return the estimated background in addition to the foreground
  153. gradient_weight: float
  154. Larger values enforce smoother foregrounds, defaults to :math:`1`
  155. Returns
  156. -------
  157. F: numpy.ndarray
  158. Extracted foreground
  159. B: numpy.ndarray
  160. Extracted background
  161. Example
  162. -------
  163. >>> from pymatting import *
  164. >>> image = load_image("data/lemur/lemur.png", "RGB")
  165. >>> alpha = load_image("data/lemur/lemur_alpha.png", "GRAY")
  166. >>> F = estimate_foreground_ml(image, alpha, return_background=False)
  167. >>> F, B = estimate_foreground_ml(image, alpha, return_background=True)
  168. See Also
  169. ----
  170. stack_images: This function can be used to place the foreground on a new background.
  171. """
  172. foreground, background = _estimate_fb_ml(
  173. image.astype(np.float32),
  174. alpha.astype(np.float32),
  175. regularization,
  176. n_small_iterations,
  177. n_big_iterations,
  178. small_size,
  179. gradient_weight, )
  180. if return_background:
  181. return foreground, background
  182. return foreground