predict.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import math
  16. import time
  17. import cv2
  18. import numpy as np
  19. import paddle
  20. import paddle.nn.functional as F
  21. from paddleseg import utils
  22. from paddleseg.core import infer
  23. from paddleseg.utils import logger, progbar, TimeAverager
  24. from ppmatting.utils import mkdir, estimate_foreground_ml
  25. def partition_list(arr, m):
  26. """split the list 'arr' into m pieces"""
  27. n = int(math.ceil(len(arr) / float(m)))
  28. return [arr[i:i + n] for i in range(0, len(arr), n)]
  29. def save_result(alpha, path, im_path, trimap=None, fg_estimate=True, fg=None):
  30. """
  31. Save alpha and rgba.
  32. Args:
  33. alpha (numpy.ndarray): The value of alpha should in [0, 255], shape should be [h,w].
  34. path (str): The save path
  35. im_path (str): The original image path.
  36. trimap (str, optional): The trimap if provided. Default: None.
  37. fg_estimate (bool, optional): Whether to estimate the foreground, Default: True.
  38. fg (numpy.ndarray, optional): The foreground, if provided, fg_estimate is invalid. Default: None.
  39. """
  40. dirname = os.path.dirname(path)
  41. if not os.path.exists(dirname):
  42. os.makedirs(dirname)
  43. basename = os.path.basename(path)
  44. name = os.path.splitext(basename)[0]
  45. alpha_save_path = os.path.join(dirname, name + '_alpha.png')
  46. rgba_save_path = os.path.join(dirname, name + '_rgba.png')
  47. # save alpha matte
  48. if trimap is not None:
  49. trimap = cv2.imread(trimap, 0)
  50. alpha[trimap == 0] = 0
  51. alpha[trimap == 255] = 255
  52. alpha = (alpha).astype('uint8')
  53. cv2.imwrite(alpha_save_path, alpha)
  54. # save rgba
  55. im = cv2.imread(im_path)
  56. if fg is None:
  57. if fg_estimate:
  58. fg = estimate_foreground_ml(im / 255.0, alpha / 255.0) * 255
  59. else:
  60. fg = im
  61. fg = fg.astype('uint8')
  62. alpha = alpha[:, :, np.newaxis]
  63. rgba = np.concatenate((fg, alpha), axis=-1)
  64. cv2.imwrite(rgba_save_path, rgba)
  65. return fg
  66. def reverse_transform(img, trans_info):
  67. """recover pred to origin shape"""
  68. for item in trans_info[::-1]:
  69. if item[0] == 'resize':
  70. h, w = item[1][0], item[1][1]
  71. img = F.interpolate(img, [h, w], mode='bilinear')
  72. elif item[0] == 'padding':
  73. h, w = item[1][0], item[1][1]
  74. img = img[:, :, 0:h, 0:w]
  75. else:
  76. raise Exception("Unexpected info '{}' in im_info".format(item[0]))
  77. return img
  78. def preprocess(img, transforms, trimap=None):
  79. data = {}
  80. data['img'] = img
  81. if trimap is not None:
  82. data['trimap'] = trimap
  83. data['gt_fields'] = ['trimap']
  84. data['trans_info'] = []
  85. data = transforms(data)
  86. data['img'] = paddle.to_tensor(data['img'])
  87. data['img'] = data['img'].unsqueeze(0)
  88. if trimap is not None:
  89. data['trimap'] = paddle.to_tensor(data['trimap'])
  90. data['trimap'] = data['trimap'].unsqueeze((0, 1))
  91. return data
  92. def predict(model,
  93. model_path,
  94. transforms,
  95. image_list,
  96. image_dir=None,
  97. trimap_list=None,
  98. save_dir='output',
  99. fg_estimate=True):
  100. """
  101. predict and visualize the image_list.
  102. Args:
  103. model (nn.Layer): Used to predict for input image.
  104. model_path (str): The path of pretrained model.
  105. transforms (transforms.Compose): Preprocess for input image.
  106. image_list (list): A list of image path to be predicted.
  107. image_dir (str, optional): The root directory of the images predicted. Default: None.
  108. trimap_list (list, optional): A list of trimap of image_list. Default: None.
  109. save_dir (str, optional): The directory to save the visualized results. Default: 'output'.
  110. """
  111. utils.utils.load_entire_model(model, model_path)
  112. model.eval()
  113. nranks = paddle.distributed.get_world_size()
  114. local_rank = paddle.distributed.get_rank()
  115. if nranks > 1:
  116. img_lists = partition_list(image_list, nranks)
  117. trimap_lists = partition_list(
  118. trimap_list, nranks) if trimap_list is not None else None
  119. else:
  120. img_lists = [image_list]
  121. trimap_lists = [trimap_list] if trimap_list is not None else None
  122. logger.info("Start to predict...")
  123. progbar_pred = progbar.Progbar(target=len(img_lists[0]), verbose=1)
  124. preprocess_cost_averager = TimeAverager()
  125. infer_cost_averager = TimeAverager()
  126. postprocess_cost_averager = TimeAverager()
  127. batch_start = time.time()
  128. with paddle.no_grad():
  129. for i, im_path in enumerate(img_lists[local_rank]):
  130. preprocess_start = time.time()
  131. trimap = trimap_lists[local_rank][
  132. i] if trimap_list is not None else None
  133. data = preprocess(img=im_path, transforms=transforms, trimap=trimap)
  134. preprocess_cost_averager.record(time.time() - preprocess_start)
  135. infer_start = time.time()
  136. result = model(data)
  137. infer_cost_averager.record(time.time() - infer_start)
  138. postprocess_start = time.time()
  139. if isinstance(result, paddle.Tensor):
  140. alpha = result
  141. fg = None
  142. else:
  143. alpha = result['alpha']
  144. fg = result.get('fg', None)
  145. alpha = reverse_transform(alpha, data['trans_info'])
  146. alpha = (alpha.numpy()).squeeze()
  147. alpha = (alpha * 255).astype('uint8')
  148. if fg is not None:
  149. fg = reverse_transform(fg, data['trans_info'])
  150. fg = (fg.numpy()).squeeze().transpose((1, 2, 0))
  151. fg = (fg * 255).astype('uint8')
  152. # get the saved name
  153. if image_dir is not None:
  154. im_file = im_path.replace(image_dir, '')
  155. else:
  156. im_file = os.path.basename(im_path)
  157. if im_file[0] == '/' or im_file[0] == '\\':
  158. im_file = im_file[1:]
  159. save_path = os.path.join(save_dir, im_file)
  160. mkdir(save_path)
  161. fg = save_result(
  162. alpha,
  163. save_path,
  164. im_path=im_path,
  165. trimap=trimap,
  166. fg_estimate=fg_estimate,
  167. fg=fg)
  168. # rvm have member which need to reset.
  169. if hasattr(model, 'reset'):
  170. model.reset()
  171. postprocess_cost_averager.record(time.time() - postprocess_start)
  172. preprocess_cost = preprocess_cost_averager.get_average()
  173. infer_cost = infer_cost_averager.get_average()
  174. postprocess_cost = postprocess_cost_averager.get_average()
  175. if local_rank == 0:
  176. progbar_pred.update(i + 1,
  177. [('preprocess_cost', preprocess_cost),
  178. ('infer_cost cost', infer_cost),
  179. ('postprocess_cost', postprocess_cost)])
  180. preprocess_cost_averager.reset()
  181. infer_cost_averager.reset()
  182. postprocess_cost_averager.reset()
  183. return alpha, fg