predict.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import math
  16. import time
  17. import cv2
  18. import numpy as np
  19. import paddle
  20. import paddle.nn.functional as F
  21. from paddleseg import utils
  22. from paddleseg.core import infer
  23. from paddleseg.utils import logger, progbar, TimeAverager
  24. from ppmatting.utils import mkdir, estimate_foreground_ml
  25. def partition_list(arr, m):
  26. """split the list 'arr' into m pieces"""
  27. n = int(math.ceil(len(arr) / float(m)))
  28. return [arr[i:i + n] for i in range(0, len(arr), n)]
  29. def save_result(alpha, path, im_path, trimap=None, fg_estimate=True, fg=None):
  30. """
  31. Save alpha and rgba.
  32. Args:
  33. alpha (numpy.ndarray): The value of alpha should in [0, 255], shape should be [h,w].
  34. path (str): The save path
  35. im_path (str): The original image path.
  36. trimap (str, optional): The trimap if provided. Default: None.
  37. fg_estimate (bool, optional): Whether to estimate the foreground, Default: True.
  38. fg (numpy.ndarray, optional): The foreground, if provided, fg_estimate is invalid. Default: None.
  39. """
  40. dirname = os.path.dirname(path)
  41. if not os.path.exists(dirname):
  42. os.makedirs(dirname)
  43. basename = os.path.basename(path)
  44. name = os.path.splitext(basename)[0]
  45. alpha_save_path = os.path.join(dirname, name + '_alpha.png')
  46. rgba_save_path = os.path.join(dirname, name + '_rgba.png')
  47. # save alpha matte
  48. if trimap is not None:
  49. trimap = cv2.imread(trimap, 0)
  50. alpha[trimap == 0] = 0
  51. alpha[trimap == 255] = 255
  52. alpha = (alpha).astype('uint8')
  53. cv2.imwrite(alpha_save_path, alpha)
  54. # save rgba
  55. im = cv2.imread(im_path)
  56. if fg is None:
  57. if fg_estimate:
  58. fg = estimate_foreground_ml(im / 255.0, alpha / 255.0) * 255
  59. else:
  60. fg = im
  61. fg = fg.astype('uint8')
  62. alpha = alpha[:, :, np.newaxis]
  63. rgba = np.concatenate((fg, alpha), axis=-1)
  64. cv2.imwrite(rgba_save_path, rgba)
  65. return fg, alpha_save_path, rgba_save_path
  66. def reverse_transform(img, trans_info):
  67. """recover pred to origin shape"""
  68. for item in trans_info[::-1]:
  69. if item[0] == 'resize':
  70. h, w = item[1][0], item[1][1]
  71. img = F.interpolate(img, [h, w], mode='bilinear')
  72. elif item[0] == 'padding':
  73. h, w = item[1][0], item[1][1]
  74. img = img[:, :, 0:h, 0:w]
  75. else:
  76. raise Exception("Unexpected info '{}' in im_info".format(item[0]))
  77. return img
  78. def preprocess(img, transforms, trimap=None):
  79. data = {}
  80. data['img'] = img
  81. if trimap is not None:
  82. data['trimap'] = trimap
  83. data['gt_fields'] = ['trimap']
  84. data['trans_info'] = []
  85. data = transforms(data)
  86. data['img'] = paddle.to_tensor(data['img'])
  87. data['img'] = data['img'].unsqueeze(0)
  88. if trimap is not None:
  89. data['trimap'] = paddle.to_tensor(data['trimap'])
  90. data['trimap'] = data['trimap'].unsqueeze((0, 1))
  91. return data
  92. def load(model, model_path):
  93. utils.utils.load_entire_model(model, model_path)
  94. model.eval()
  95. def predict(model,
  96. model_path,
  97. transforms,
  98. image_list,
  99. image_dir=None,
  100. trimap_list=None,
  101. save_dir='output',
  102. fg_estimate=True):
  103. """
  104. predict and visualize the image_list.
  105. Args:
  106. model (nn.Layer): Used to predict for input image.
  107. model_path (str): The path of pretrained model.
  108. transforms (transforms.Compose): Preprocess for input image.
  109. image_list (list): A list of image path to be predicted.
  110. image_dir (str, optional): The root directory of the images predicted. Default: None.
  111. trimap_list (list, optional): A list of trimap of image_list. Default: None.
  112. save_dir (str, optional): The directory to save the visualized results. Default: 'output'.
  113. """
  114. # utils.utils.load_entire_model(model, model_path)
  115. # model.eval()
  116. nranks = paddle.distributed.get_world_size()
  117. local_rank = paddle.distributed.get_rank()
  118. if nranks > 1:
  119. img_lists = partition_list(image_list, nranks)
  120. trimap_lists = partition_list(
  121. trimap_list, nranks) if trimap_list is not None else None
  122. else:
  123. img_lists = [image_list]
  124. trimap_lists = [trimap_list] if trimap_list is not None else None
  125. logger.info("Start to predict...")
  126. progbar_pred = progbar.Progbar(target=len(img_lists[0]), verbose=1)
  127. preprocess_cost_averager = TimeAverager()
  128. infer_cost_averager = TimeAverager()
  129. postprocess_cost_averager = TimeAverager()
  130. batch_start = time.time()
  131. with paddle.no_grad():
  132. for i, im_path in enumerate(img_lists[local_rank]):
  133. preprocess_start = time.time()
  134. trimap = trimap_lists[local_rank][
  135. i] if trimap_list is not None else None
  136. data = preprocess(img=im_path, transforms=transforms, trimap=trimap)
  137. preprocess_cost_averager.record(time.time() - preprocess_start)
  138. infer_start = time.time()
  139. result = model(data)
  140. infer_cost_averager.record(time.time() - infer_start)
  141. postprocess_start = time.time()
  142. if isinstance(result, paddle.Tensor):
  143. alpha = result
  144. fg = None
  145. else:
  146. alpha = result['alpha']
  147. fg = result.get('fg', None)
  148. alpha = reverse_transform(alpha, data['trans_info'])
  149. alpha = (alpha.numpy()).squeeze()
  150. alpha = (alpha * 255).astype('uint8')
  151. if fg is not None:
  152. fg = reverse_transform(fg, data['trans_info'])
  153. fg = (fg.numpy()).squeeze().transpose((1, 2, 0))
  154. fg = (fg * 255).astype('uint8')
  155. # get the saved name
  156. if image_dir is not None:
  157. im_file = im_path.replace(image_dir, '')
  158. else:
  159. im_file = os.path.basename(im_path)
  160. if im_file[0] == '/' or im_file[0] == '\\':
  161. im_file = im_file[1:]
  162. save_path = os.path.join(save_dir, im_file)
  163. mkdir(save_path)
  164. fg,alpha_save_path, rgba_save_path = save_result(
  165. alpha,
  166. save_path,
  167. im_path=im_path,
  168. trimap=trimap,
  169. fg_estimate=fg_estimate,
  170. fg=fg)
  171. # rvm have member which need to reset.
  172. if hasattr(model, 'reset'):
  173. model.reset()
  174. postprocess_cost_averager.record(time.time() - postprocess_start)
  175. preprocess_cost = preprocess_cost_averager.get_average()
  176. infer_cost = infer_cost_averager.get_average()
  177. postprocess_cost = postprocess_cost_averager.get_average()
  178. if local_rank == 0:
  179. progbar_pred.update(i + 1,
  180. [('preprocess_cost', preprocess_cost),
  181. ('infer_cost cost', infer_cost),
  182. ('postprocess_cost', postprocess_cost)])
  183. preprocess_cost_averager.reset()
  184. infer_cost_averager.reset()
  185. postprocess_cost_averager.reset()
  186. return alpha, fg, alpha_save_path, rgba_save_path