predict_video.py 2.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import argparse
  15. import os
  16. import sys
  17. import paddle
  18. import paddleseg
  19. from paddleseg.cvlibs import manager
  20. from paddleseg.utils import get_sys_env, logger
# Absolute directory containing this script.
LOCAL_PATH = os.path.dirname(os.path.abspath(__file__))
# Append the script's parent directory to the import path so that the
# ``import ppmatting`` below can find a package located there.
# NOTE(review): assumes ppmatting lives one level above this script — confirm
# against the repository layout.
sys.path.append(os.path.join(LOCAL_PATH, '..'))
# Empty paddleseg's BACKBONES and TRANSFORMS component registries before
# ppmatting is imported. NOTE(review): presumably this lets ppmatting register
# its own backbones/transforms under names paddleseg already used without a
# duplicate-registration error — confirm. Keep these lines before
# ``import ppmatting``; the order matters.
manager.BACKBONES._components_dict.clear()
manager.TRANSFORMS._components_dict.clear()
  25. import ppmatting
  26. from ppmatting.core import predict_video
  27. from ppmatting.utils import Config, MatBuilder
  28. def parse_args():
  29. parser = argparse.ArgumentParser(description='Model training')
  30. parser.add_argument(
  31. "--config", dest="cfg", help="The config file.", default=None, type=str)
  32. parser.add_argument(
  33. '--model_path',
  34. dest='model_path',
  35. help='The path of model for prediction',
  36. type=str,
  37. default=None)
  38. parser.add_argument(
  39. '--video_path',
  40. dest='video_path',
  41. help='The path of video',
  42. default=None)
  43. parser.add_argument(
  44. '--save_dir',
  45. dest='save_dir',
  46. help='The directory for saving the model snapshot',
  47. type=str,
  48. default='./output/results')
  49. parser.add_argument(
  50. '--fg_estimate',
  51. default=True,
  52. type=eval,
  53. choices=[True, False],
  54. help='Whether to estimate foreground when predicting.')
  55. parser.add_argument(
  56. '--device',
  57. dest='device',
  58. help='Set the device type, which may be GPU, CPU or XPU.',
  59. default='gpu',
  60. type=str)
  61. return parser.parse_args()
  62. def main(args):
  63. assert args.cfg is not None, \
  64. 'No configuration file specified, please set --config'
  65. cfg = Config(args.cfg)
  66. builder = MatBuilder(cfg)
  67. paddleseg.utils.show_env_info()
  68. paddleseg.utils.show_cfg_info(cfg)
  69. paddleseg.utils.set_device(args.device)
  70. model = builder.model
  71. transforms = ppmatting.transforms.Compose(builder.val_transforms)
  72. predict_video(
  73. model,
  74. model_path=args.model_path,
  75. transforms=transforms,
  76. video_path=args.video_path,
  77. save_dir=args.save_dir,
  78. fg_estimate=args.fg_estimate)
  79. if __name__ == '__main__':
  80. args = parse_args()
  81. main(args)