val.py

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import time

import cv2
import numpy as np
import paddle
import paddle.nn.functional as F
from paddleseg.utils import TimeAverager, calculate_eta, logger, progbar

from ppmatting.metrics import metrics_class_dict

np.set_printoptions(suppress=True)

def save_alpha_pred(alpha, path):
    """
    Save a predicted alpha matte to disk.

    `alpha` is expected to hold values in [0, 255] with shape [h, w]; it is
    cast to uint8 and written with OpenCV.
    """
    dirname = os.path.dirname(path)
    if not os.path.exists(dirname):
        os.makedirs(dirname)

    alpha = alpha.astype('uint8')
    cv2.imwrite(path, alpha)

def reverse_transform(alpha, trans_info):
    """Recover the prediction to the original image shape by undoing, in
    reverse order, the resize/padding steps recorded in `trans_info`."""
    for item in trans_info[::-1]:
        if item[0][0] == 'resize':
            h, w = item[1][0], item[1][1]
            alpha = F.interpolate(alpha, [h, w], mode='bilinear')
        elif item[0][0] == 'padding':
            h, w = item[1][0], item[1][1]
            alpha = alpha[:, :, 0:h, 0:w]
        else:
            raise Exception("Unexpected info '{}' in im_info".format(item[0]))
    return alpha
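
# Illustrative note (assumption, not from the original file): each trans_info
# entry pairs an op name with the tensor's height/width before that op, e.g.
# conceptually [('resize', (h0, w0)), ('padding', (h1, w1))]. Undoing the ops
# in reverse order crops the padding back to (h1, w1) first, then resizes the
# result back to the original (h0, w0). The item[0][0] indexing above reflects
# how these entries arrive after being collated by the DataLoader.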

def evaluate(model,
             eval_dataset,
             num_workers=0,
             print_detail=True,
             save_dir='output/results',
             save_results=True,
             metrics='sad',
             precision='fp32',
             amp_level='O1'):
    model.eval()
    nranks = paddle.distributed.ParallelEnv().nranks
    local_rank = paddle.distributed.ParallelEnv().local_rank
    if nranks > 1:
        # Initialize the parallel environment if it has not been done yet.
        if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
        ):
            paddle.distributed.init_parallel_env()

    loader = paddle.io.DataLoader(
        eval_dataset,
        batch_size=1,
        drop_last=False,
        num_workers=num_workers,
        return_list=True, )

    total_iters = len(loader)

    # Build one metric instance per requested metric name.
    metrics_ins = {}
    metrics_data = {}
    if isinstance(metrics, str):
        metrics = [metrics]
    elif not isinstance(metrics, list):
        metrics = ['sad']
    for key in metrics:
        key = key.lower()
        metrics_ins[key] = metrics_class_dict[key]()
        metrics_data[key] = None

    if print_detail:
        logger.info("Start evaluating (total_samples: {}, total_iters: {})...".
                    format(len(eval_dataset), total_iters))
    progbar_val = progbar.Progbar(
        target=total_iters, verbose=1 if nranks < 2 else 2)
    reader_cost_averager = TimeAverager()
    batch_cost_averager = TimeAverager()
    batch_start = time.time()

    img_name = ''
    i = 0
    with paddle.no_grad():
        for iter, data in enumerate(loader):
            reader_cost_averager.record(time.time() - batch_start)

            if precision == 'fp16':
                # Run the forward pass under AMP when fp16 evaluation is requested.
                with paddle.amp.auto_cast(
                        level=amp_level,
                        enable=True,
                        custom_white_list={
                            "elementwise_add", "batch_norm", "sync_batch_norm"
                        },
                        custom_black_list={'bilinear_interp_v2', 'pad3d'}):
                    alpha_pred = model(data)
                    alpha_pred = reverse_transform(alpha_pred,
                                                   data['trans_info'])
            else:
                alpha_pred = model(data)
                alpha_pred = reverse_transform(alpha_pred, data['trans_info'])

            alpha_pred = alpha_pred.numpy()

            # Metrics are computed on [0, 255] mattes; the original trimap, if
            # present, is forwarded to each metric's update().
            alpha_gt = data['alpha'].numpy() * 255
            trimap = data.get('ori_trimap')
            if trimap is not None:
                trimap = trimap.numpy().astype('uint8')
            alpha_pred = np.round(alpha_pred * 255)
            for key in metrics_ins.keys():
                metrics_data[key] = metrics_ins[key].update(alpha_pred,
                                                            alpha_gt, trimap)

            if save_results:
                alpha_pred_one = alpha_pred[0].squeeze()
                if trimap is not None:
                    # Force known foreground/background pixels to 255/0.
                    trimap = trimap.squeeze().astype('uint8')
                    alpha_pred_one[trimap == 255] = 255
                    alpha_pred_one[trimap == 0] = 0

                # Append a running index so repeated image names do not
                # overwrite each other.
                save_name = data['img_name'][0]
                name, ext = os.path.splitext(save_name)
                if save_name == img_name:
                    save_name = name + '_' + str(i) + ext
                    i += 1
                else:
                    img_name = save_name
                    save_name = name + '_' + str(i) + ext
                    i = 1

                save_alpha_pred(alpha_pred_one,
                                os.path.join(save_dir, save_name))

            batch_cost_averager.record(
                time.time() - batch_start, num_samples=len(alpha_gt))
            batch_cost = batch_cost_averager.get_average()
            reader_cost = reader_cost_averager.get_average()

            if local_rank == 0 and print_detail:
                show_list = [(k, v) for k, v in metrics_data.items()]
                show_list = show_list + [('batch_cost', batch_cost),
                                         ('reader cost', reader_cost)]
                progbar_val.update(iter + 1, show_list)

            reader_cost_averager.reset()
            batch_cost_averager.reset()
            batch_start = time.time()

    # Aggregate the per-batch statistics into final metric values.
    for key in metrics_ins.keys():
        metrics_data[key] = metrics_ins[key].evaluate()
    log_str = '[EVAL] '
    for key, value in metrics_data.items():
        log_str = log_str + key + ': {:.4f}, '.format(value)
    log_str = log_str[:-2]
    logger.info(log_str)

    return metrics_data
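

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). `build_model` and
# `build_val_dataset` are hypothetical placeholders for however the calling
# script constructs the matting model and validation dataset (PaddleSeg's
# Matting tools normally do this from a config file); only the evaluate()
# call itself relies on the signature defined above. Any key registered in
# metrics_class_dict can be passed via `metrics`.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from my_eval_setup import build_model, build_val_dataset  # hypothetical helpers

    model = build_model()
    eval_dataset = build_val_dataset()
    results = evaluate(
        model,
        eval_dataset,
        num_workers=4,
        save_dir='output/results',
        save_results=True,
        metrics=['sad'],
        precision='fp32')
    print(results)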