metric.py 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # Grad and Conn are adapted from https://github.com/yucornetto/MGMatting/blob/main/code-base/utils/evaluate.py
  15. # Output of `Grad` is slightly different from the MATLAB version provided by Adobe (less than 0.1%)
  16. # Output of `Conn` is smaller than the MATLAB version (by about 5%; MATLAB may use a different algorithm)
  17. # So do not report results calculated by these functions in your paper.
  18. # Evaluate your inference with the MATLAB file `DIM_evaluation_code/evaluate.m`.
  19. import cv2
  20. import numpy as np
  21. from scipy.ndimage.filters import convolve
  22. from scipy.special import gamma
  23. from skimage.measure import label
  24. class MSE:
  25. """
  26. Only calculate the unknown region if trimap provided.
  27. """
  28. def __init__(self):
  29. self.mse_diffs = 0
  30. self.count = 0
  31. def update(self, pred, gt, trimap=None):
  32. """
  33. update metric.
  34. Args:
  35. pred (np.ndarray): The value range is [0., 255.].
  36. gt (np.ndarray): The value range is [0, 255].
  37. trimap (np.ndarray, optional) The value is in {0, 128, 255}. Default: None.
  38. """
  39. if trimap is None:
  40. trimap = np.ones_like(gt) * 128
  41. if not (pred.shape == gt.shape == trimap.shape):
  42. raise ValueError(
  43. 'The shape of `pred`, `gt` and `trimap` should be equal. '
  44. 'but they are {}, {} and {}'.format(pred.shape, gt.shape,
  45. trimap.shape))
  46. pred[trimap == 0] = 0
  47. pred[trimap == 255] = 255
  48. mask = trimap == 128
  49. pixels = float(mask.sum())
  50. pred = pred / 255.
  51. gt = gt / 255.
  52. diff = (pred - gt) * mask
  53. mse_diff = (diff**2).sum() / pixels if pixels > 0 else 0
  54. self.mse_diffs += mse_diff
  55. self.count += 1
  56. return mse_diff
  57. def evaluate(self):
  58. mse = self.mse_diffs / self.count if self.count > 0 else 0
  59. return mse
  60. class SAD:
  61. """
  62. Only calculate the unknown region if trimap provided.
  63. """
  64. def __init__(self):
  65. self.sad_diffs = 0
  66. self.count = 0
  67. def update(self, pred, gt, trimap=None):
  68. """
  69. update metric.
  70. Args:
  71. pred (np.ndarray): The value range is [0., 255.].
  72. gt (np.ndarray): The value range is [0., 255.].
  73. trimap (np.ndarray, optional)L The value is in {0, 128, 255}. Default: None.
  74. """
  75. if trimap is None:
  76. trimap = np.ones_like(gt) * 128
  77. if not (pred.shape == gt.shape == trimap.shape):
  78. raise ValueError(
  79. 'The shape of `pred`, `gt` and `trimap` should be equal. '
  80. 'but they are {}, {} and {}'.format(pred.shape, gt.shape,
  81. trimap.shape))
  82. pred[trimap == 0] = 0
  83. pred[trimap == 255] = 255
  84. mask = trimap == 128
  85. pred = pred / 255.
  86. gt = gt / 255.
  87. diff = (pred - gt) * mask
  88. sad_diff = (np.abs(diff)).sum()
  89. sad_diff /= 1000
  90. self.sad_diffs += sad_diff
  91. self.count += 1
  92. return sad_diff
  93. def evaluate(self):
  94. sad = self.sad_diffs / self.count if self.count > 0 else 0
  95. return sad
  96. class Grad:
  97. """
  98. Only calculate the unknown region if trimap provided.
  99. Refer to: https://github.com/open-mlab/mmediting/blob/master/mmedit/core/evaluation/metrics.py
  100. """
  101. def __init__(self):
  102. self.grad_diffs = 0
  103. self.count = 0
  104. def gaussian(self, x, sigma):
  105. return np.exp(-x**2 / (2 * sigma**2)) / (sigma * np.sqrt(2 * np.pi))
  106. def dgaussian(self, x, sigma):
  107. return -x * self.gaussian(x, sigma) / sigma**2
  108. def gauss_filter(self, sigma, epsilon=1e-2):
  109. half_size = np.ceil(
  110. sigma * np.sqrt(-2 * np.log(np.sqrt(2 * np.pi) * sigma * epsilon)))
  111. size = int(2 * half_size + 1)
  112. # create filter in x axis
  113. filter_x = np.zeros((size, size))
  114. for i in range(size):
  115. for j in range(size):
  116. filter_x[i, j] = self.gaussian(
  117. i - half_size, sigma) * self.dgaussian(j - half_size, sigma)
  118. # normalize filter
  119. norm = np.sqrt((filter_x**2).sum())
  120. filter_x = filter_x / norm
  121. filter_y = np.transpose(filter_x)
  122. return filter_x, filter_y
  123. def gauss_gradient(self, img, sigma):
  124. filter_x, filter_y = self.gauss_filter(sigma)
  125. img_filtered_x = cv2.filter2D(
  126. img, -1, filter_x, borderType=cv2.BORDER_REPLICATE)
  127. img_filtered_y = cv2.filter2D(
  128. img, -1, filter_y, borderType=cv2.BORDER_REPLICATE)
  129. return np.sqrt(img_filtered_x**2 + img_filtered_y**2)
  130. def update(self, pred, gt, trimap=None, sigma=1.4):
  131. """
  132. update metric.
  133. Args:
  134. pred (np.ndarray): The value range is [0., 1.].
  135. gt (np.ndarray): The value range is [0, 255].
  136. trimap (np.ndarray, optional)L The value is in {0, 128, 255}. Default: None.
  137. sigma (float, optional): Standard deviation of the gaussian kernel. Default: 1.4.
  138. """
  139. if trimap is None:
  140. trimap = np.ones_like(gt) * 128
  141. if not (pred.shape == gt.shape == trimap.shape):
  142. raise ValueError(
  143. 'The shape of `pred`, `gt` and `trimap` should be equal. '
  144. 'but they are {}, {} and {}'.format(pred.shape, gt.shape,
  145. trimap.shape))
  146. pred[trimap == 0] = 0
  147. pred[trimap == 255] = 255
  148. gt = gt.squeeze()
  149. pred = pred.squeeze()
  150. gt = gt.astype(np.float64)
  151. pred = pred.astype(np.float64)
  152. gt_normed = np.zeros_like(gt)
  153. pred_normed = np.zeros_like(pred)
  154. cv2.normalize(gt, gt_normed, 1., 0., cv2.NORM_MINMAX)
  155. cv2.normalize(pred, pred_normed, 1., 0., cv2.NORM_MINMAX)
  156. gt_grad = self.gauss_gradient(gt_normed, sigma).astype(np.float32)
  157. pred_grad = self.gauss_gradient(pred_normed, sigma).astype(np.float32)
  158. grad_diff = ((gt_grad - pred_grad)**2 * (trimap == 128)).sum()
  159. grad_diff /= 1000
  160. self.grad_diffs += grad_diff
  161. self.count += 1
  162. return grad_diff
  163. def evaluate(self):
  164. grad = self.grad_diffs / self.count if self.count > 0 else 0
  165. return grad
  166. class Conn:
  167. """
  168. Only calculate the unknown region if trimap provided.
  169. Refer to: Refer to: https://github.com/open-mlab/mmediting/blob/master/mmedit/core/evaluation/metrics.py
  170. """
  171. def __init__(self):
  172. self.conn_diffs = 0
  173. self.count = 0
  174. def update(self, pred, gt, trimap=None, step=0.1):
  175. """
  176. update metric.
  177. Args:
  178. pred (np.ndarray): The value range is [0., 1.].
  179. gt (np.ndarray): The value range is [0, 255].
  180. trimap (np.ndarray, optional)L The value is in {0, 128, 255}. Default: None.
  181. step (float, optional): Step of threshold when computing intersection between
  182. `gt` and `pred`. Default: 0.1.
  183. """
  184. if trimap is None:
  185. trimap = np.ones_like(gt) * 128
  186. if not (pred.shape == gt.shape == trimap.shape):
  187. raise ValueError(
  188. 'The shape of `pred`, `gt` and `trimap` should be equal. '
  189. 'but they are {}, {} and {}'.format(pred.shape, gt.shape,
  190. trimap.shape))
  191. pred[trimap == 0] = 0
  192. pred[trimap == 255] = 255
  193. gt = gt.squeeze()
  194. pred = pred.squeeze()
  195. gt = gt.astype(np.float32) / 255
  196. pred = pred.astype(np.float32) / 255
  197. thresh_steps = np.arange(0, 1 + step, step)
  198. round_down_map = -np.ones_like(gt)
  199. for i in range(1, len(thresh_steps)):
  200. gt_thresh = gt >= thresh_steps[i]
  201. pred_thresh = pred >= thresh_steps[i]
  202. intersection = (gt_thresh & pred_thresh).astype(np.uint8)
  203. # connected components
  204. _, output, stats, _ = cv2.connectedComponentsWithStats(
  205. intersection, connectivity=4)
  206. # start from 1 in dim 0 to exclude background
  207. size = stats[1:, -1]
  208. # largest connected component of the intersection
  209. omega = np.zeros_like(gt)
  210. if len(size) != 0:
  211. max_id = np.argmax(size)
  212. # plus one to include background
  213. omega[output == max_id + 1] = 1
  214. mask = (round_down_map == -1) & (omega == 0)
  215. round_down_map[mask] = thresh_steps[i - 1]
  216. round_down_map[round_down_map == -1] = 1
  217. gt_diff = gt - round_down_map
  218. pred_diff = pred - round_down_map
  219. # only calculate difference larger than or equal to 0.15
  220. gt_phi = 1 - gt_diff * (gt_diff >= 0.15)
  221. pred_phi = 1 - pred_diff * (pred_diff >= 0.15)
  222. conn_diff = np.sum(np.abs(gt_phi - pred_phi) * (trimap == 128))
  223. conn_diff /= 1000
  224. self.conn_diffs += conn_diff
  225. self.count += 1
  226. return conn_diff
  227. def evaluate(self):
  228. conn = self.conn_diffs / self.count if self.count > 0 else 0
  229. return conn