predict_video.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import math
  16. import time
  17. import cv2
  18. import numpy as np
  19. import paddle
  20. import paddle.nn.functional as F
  21. from paddleseg import utils
  22. from paddleseg.core import infer
  23. from paddleseg.utils import logger, progbar, TimeAverager
  24. from ppmatting.utils import mkdir, estimate_foreground_ml, VideoReader, VideoWriter
  25. def build_loader_writter(video_path, transforms, save_dir):
  26. reader = VideoReader(video_path, transforms)
  27. loader = paddle.io.DataLoader(reader)
  28. base_name = os.path.basename(video_path)
  29. name = os.path.splitext(base_name)[0]
  30. alpha_save_path = os.path.join(save_dir, name + '_alpha.avi')
  31. fg_save_path = os.path.join(save_dir, name + '_fg.avi')
  32. writer_alpha = VideoWriter(
  33. alpha_save_path,
  34. reader.fps,
  35. frame_size=(reader.width, reader.height),
  36. is_color=False)
  37. writer_fg = VideoWriter(
  38. fg_save_path,
  39. reader.fps,
  40. frame_size=(reader.width, reader.height),
  41. is_color=True)
  42. writers = {'alpha': writer_alpha, 'fg': writer_fg}
  43. return loader, writers
  44. def reverse_transform(img, trans_info):
  45. """recover pred to origin shape"""
  46. for item in trans_info[::-1]:
  47. if item[0][0] == 'resize':
  48. h, w = item[1][0], item[1][1]
  49. img = F.interpolate(img, [h, w], mode='bilinear')
  50. elif item[0][0] == 'padding':
  51. h, w = item[1][0], item[1][1]
  52. img = img[:, :, 0:h, 0:w]
  53. else:
  54. raise Exception("Unexpected info '{}' in im_info".format(item[0]))
  55. return img
  56. def postprocess(fg, alpha, img, trans_info, writers, fg_estimate):
  57. """
  58. Postprocess for prediction results.
  59. Args:
  60. fg (Tensor): The foreground, value should be in [0, 1].
  61. alpha (Tensor): The alpha, value should be in [0, 1].
  62. img (Tensor): The original image, value should be in [0, 1].
  63. trans_info (list): A list of the shape transformations.
  64. writers (dict): A dict of VideoWriter instance.
  65. fg_estimate (bool): Whether to estimate foreground. It is invalid when fg is not None.
  66. """
  67. alpha = reverse_transform(alpha, trans_info)
  68. if fg is None:
  69. if fg_estimate:
  70. img = img.transpose((0, 2, 3, 1)).squeeze().numpy()
  71. alpha = alpha.squeeze().numpy()
  72. fg = estimate_foreground_ml(img, alpha)
  73. else:
  74. fg = img
  75. else:
  76. fg = reverse_transform(fg, trans_info)
  77. if len(alpha.shape) == 2:
  78. fg = alpha[:, :, None] * fg
  79. else:
  80. fg = alpha * fg
  81. writers['alpha'].write(alpha)
  82. writers['fg'].write(fg)
  83. def predict_video(model,
  84. model_path,
  85. transforms,
  86. video_path,
  87. save_dir='output',
  88. fg_estimate=True):
  89. """
  90. predict and visualize the video.
  91. Args:
  92. model (nn.Layer): Used to predict for input video.
  93. model_path (str): The path of pretrained model.
  94. transforms (transforms.Compose): Preprocess for frames of video.
  95. video_path (str): the video path to be predicted.
  96. save_dir (str, optional): The directory to save the visualized results. Default: 'output'.
  97. fg_estimate (bool, optional): Whether to estimate foreground when predicting. It is invalid if the foreground is predicted by model. Default: True
  98. """
  99. utils.utils.load_entire_model(model, model_path)
  100. model.eval()
  101. # Build loader and writer for video
  102. loader, writers = build_loader_writter(
  103. video_path, transforms, save_dir=save_dir)
  104. logger.info("Start to predict...")
  105. progbar_pred = progbar.Progbar(target=len(loader), verbose=1)
  106. preprocess_cost_averager = TimeAverager()
  107. infer_cost_averager = TimeAverager()
  108. postprocess_cost_averager = TimeAverager()
  109. batch_start = time.time()
  110. with paddle.no_grad():
  111. for i, data in enumerate(loader):
  112. preprocess_cost_averager.record(time.time() - batch_start)
  113. infer_start = time.time()
  114. result = model(data) # result maybe a Tensor or a dict
  115. if isinstance(result, paddle.Tensor):
  116. alpha = result
  117. fg = None
  118. else:
  119. alpha = result['alpha']
  120. fg = result.get('fg', None)
  121. infer_cost_averager.record(time.time() - infer_start)
  122. postprocess_start = time.time()
  123. postprocess(
  124. fg,
  125. alpha,
  126. data['ori_img'],
  127. trans_info=data['trans_info'],
  128. writers=writers,
  129. fg_estimate=fg_estimate)
  130. postprocess_cost_averager.record(time.time() - postprocess_start)
  131. preprocess_cost = preprocess_cost_averager.get_average()
  132. infer_cost = infer_cost_averager.get_average()
  133. postprocess_cost = postprocess_cost_averager.get_average()
  134. progbar_pred.update(i + 1, [('preprocess_cost', preprocess_cost),
  135. ('infer_cost cost', infer_cost),
  136. ('postprocess_cost', postprocess_cost)])
  137. preprocess_cost_averager.reset()
  138. infer_cost_averager.reset()
  139. postprocess_cost_averager.reset()
  140. batch_start = time.time()
  141. if hasattr(model, 'reset'):
  142. model.reset()
  143. loader.dataset.release()
  144. for k, v in writers.items():
  145. v.release()