predict.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
import argparse
import os
import sys
import paddle
import paddleseg
from paddleseg.cvlibs import manager
from paddleseg.utils import get_sys_env, logger

# Directory containing this script. Its parent directory is appended to
# sys.path so that the local `ppmatting` package imported below can be found.
LOCAL_PATH = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(LOCAL_PATH, '..'))

# Empty PaddleSeg's BACKBONES and TRANSFORMS registries before `ppmatting` is
# imported. NOTE(review): presumably ppmatting registers components under the
# same names when imported and this avoids duplicate-registration clashes —
# confirm before moving these two lines relative to `import ppmatting`.
manager.BACKBONES._components_dict.clear()
manager.TRANSFORMS._components_dict.clear()
import ppmatting
from ppmatting.core import predict, load
from ppmatting.utils import get_image_list, Config, MatBuilder

# Absolute directory of this script (same value as LOCAL_PATH); get_rel_path
# resolves project files relative to its parent directory.
current_path = os.path.abspath(os.path.dirname(__file__))
  29. def parse_args():
  30. parser = argparse.ArgumentParser(description='Model training')
  31. parser.add_argument(
  32. "--config", dest="cfg", help="The config file.", default=None, type=str)
  33. parser.add_argument(
  34. '--model_path',
  35. dest='model_path',
  36. help='The path of model for prediction',
  37. type=str,
  38. default=None)
  39. parser.add_argument(
  40. '--image_path',
  41. dest='image_path',
  42. help='The path of image, it can be a file or a directory including images',
  43. type=str,
  44. default=None)
  45. parser.add_argument(
  46. '--trimap_path',
  47. dest='trimap_path',
  48. help='The path of trimap, it can be a file or a directory including images. '
  49. 'The image should be the same as image when it is a directory.',
  50. type=str,
  51. default=None)
  52. parser.add_argument(
  53. '--save_dir',
  54. dest='save_dir',
  55. help='The directory for saving the model snapshot',
  56. type=str,
  57. default='./output/results')
  58. parser.add_argument(
  59. '--fg_estimate',
  60. default=True,
  61. type=eval,
  62. choices=[True, False],
  63. help='Whether to estimate foreground when predicting.')
  64. parser.add_argument(
  65. '--device',
  66. dest='device',
  67. help='Set the device type, which may be GPU, CPU or XPU.',
  68. default='gpu',
  69. type=str)
  70. return parser.parse_args()
  71. def main(args):
  72. assert args.cfg is not None, \
  73. 'No configuration file specified, please set --config'
  74. cfg = Config(args.cfg)
  75. builder = MatBuilder(cfg)
  76. paddleseg.utils.show_env_info()
  77. paddleseg.utils.show_cfg_info(cfg)
  78. paddleseg.utils.set_device(args.device)
  79. model = builder.model
  80. transforms = ppmatting.transforms.Compose(builder.val_transforms)
  81. image_list, image_dir = get_image_list(args.image_path)
  82. if args.trimap_path is None:
  83. trimap_list = None
  84. else:
  85. trimap_list, _ = get_image_list(args.trimap_path)
  86. logger.info('Number of predict images = {}'.format(len(image_list)))
  87. predict(
  88. model,
  89. model_path=args.model_path,
  90. transforms=transforms,
  91. image_list=image_list,
  92. image_dir=image_dir,
  93. trimap_list=trimap_list,
  94. save_dir=args.save_dir,
  95. fg_estimate=args.fg_estimate)
  96. def get_rel_path(path: str):
  97. return "{}/../{}".format(current_path, path)
  98. global model, transforms
  99. def load_model():
  100. cfg = Config(get_rel_path("configs/quick_start/ppmattingv2-stdc1-human_512.yml"))
  101. builder = MatBuilder(cfg)
  102. paddleseg.utils.show_env_info()
  103. paddleseg.utils.show_cfg_info(cfg)
  104. paddleseg.utils.set_device("cpu")
  105. global model, transforms
  106. model = builder.model
  107. model_path = get_rel_path("models/ppmatting-hrnet_w18-human_512.pdparams")
  108. paddleseg.utils.show_env_info()
  109. paddleseg.utils.show_cfg_info(cfg)
  110. paddleseg.utils.set_device("cpu")
  111. transforms = ppmatting.transforms.Compose(builder.val_transforms)
  112. load(model, model_path)
  113. def seg(img_path: str, save_dir: str):
  114. # cfg = Config(get_rel_path("configs/quick_start/ppmattingv2-stdc1-human_512.yml"))
  115. # builder = MatBuilder(cfg)
  116. # paddleseg.utils.show_env_info()
  117. # paddleseg.utils.show_cfg_info(cfg)
  118. # paddleseg.utils.set_device("cpu")
  119. # model = builder.model
  120. # transforms = ppmatting.transforms.Compose(builder.val_transforms)
  121. image_list, image_dir = get_image_list(img_path)
  122. logger.info('Number of predict images = {}'.format(len(image_list)))
  123. model_path = get_rel_path("models/ppmatting-hrnet_w18-human_512.pdparams")
  124. predict(
  125. model,
  126. model_path=model_path,
  127. transforms=transforms,
  128. image_list=image_list,
  129. image_dir=image_dir,
  130. trimap_list=None,
  131. save_dir=save_dir,
  132. fg_estimate=True)
  133. if __name__ == '__main__':
  134. args = parse_args()
  135. main(args)