predict.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import argparse
  15. import os
  16. import sys
  17. import paddle
  18. import paddleseg
  19. from paddleseg.cvlibs import manager
  20. from paddleseg.utils import get_sys_env, logger
  21. import ppmatting
  22. from ppmatting.core import predict, load
  23. from ppmatting.utils import get_image_list, Config, MatBuilder
  24. current_path = os.path.abspath(os.path.dirname(__file__))
  25. def parse_args():
  26. parser = argparse.ArgumentParser(description='Model training')
  27. parser.add_argument(
  28. "--config", dest="cfg", help="The config file.", default=None, type=str)
  29. parser.add_argument(
  30. '--model_path',
  31. dest='model_path',
  32. help='The path of model for prediction',
  33. type=str,
  34. default=None)
  35. parser.add_argument(
  36. '--image_path',
  37. dest='image_path',
  38. help='The path of image, it can be a file or a directory including images',
  39. type=str,
  40. default=None)
  41. parser.add_argument(
  42. '--trimap_path',
  43. dest='trimap_path',
  44. help='The path of trimap, it can be a file or a directory including images. '
  45. 'The image should be the same as image when it is a directory.',
  46. type=str,
  47. default=None)
  48. parser.add_argument(
  49. '--save_dir',
  50. dest='save_dir',
  51. help='The directory for saving the model snapshot',
  52. type=str,
  53. default='./output/results')
  54. parser.add_argument(
  55. '--fg_estimate',
  56. default=True,
  57. type=eval,
  58. choices=[True, False],
  59. help='Whether to estimate foreground when predicting.')
  60. parser.add_argument(
  61. '--device',
  62. dest='device',
  63. help='Set the device type, which may be GPU, CPU or XPU.',
  64. default='gpu',
  65. type=str)
  66. return parser.parse_args()
  67. def main(args):
  68. assert args.cfg is not None, \
  69. 'No configuration file specified, please set --config'
  70. cfg = Config(args.cfg)
  71. builder = MatBuilder(cfg)
  72. paddleseg.utils.show_env_info()
  73. paddleseg.utils.show_cfg_info(cfg)
  74. paddleseg.utils.set_device(args.device)
  75. model = builder.model
  76. transforms = ppmatting.transforms.Compose(builder.val_transforms)
  77. image_list, image_dir = get_image_list(args.image_path)
  78. if args.trimap_path is None:
  79. trimap_list = None
  80. else:
  81. trimap_list, _ = get_image_list(args.trimap_path)
  82. logger.info('Number of predict images = {}'.format(len(image_list)))
  83. predict(
  84. model,
  85. model_path=args.model_path,
  86. transforms=transforms,
  87. image_list=image_list,
  88. image_dir=image_dir,
  89. trimap_list=trimap_list,
  90. save_dir=args.save_dir,
  91. fg_estimate=args.fg_estimate)
  92. def get_rel_path(path: str):
  93. return "{}/../{}".format(current_path, path)
  94. global model, transforms
  95. def load_model():
  96. cfg = Config(get_rel_path("configs/ppmattingv2/ppmattingv2-stdc1-human_512.yml"))
  97. builder = MatBuilder(cfg)
  98. paddleseg.utils.show_env_info()
  99. paddleseg.utils.show_cfg_info(cfg)
  100. paddleseg.utils.set_device("cpu")
  101. global model, transforms
  102. model = builder.model
  103. model_path = get_rel_path("models/ppmattingv2-stdc1-human_512.pdparams")
  104. paddleseg.utils.show_env_info()
  105. paddleseg.utils.show_cfg_info(cfg)
  106. paddleseg.utils.set_device("cpu")
  107. transforms = ppmatting.transforms.Compose(builder.val_transforms)
  108. load(model, model_path)
  109. def seg(img_path: str, save_dir: str):
  110. # cfg = Config(get_rel_path("configs/quick_start/ppmattingv2-stdc1-human_512.yml"))
  111. # builder = MatBuilder(cfg)
  112. # paddleseg.utils.show_env_info()
  113. # paddleseg.utils.show_cfg_info(cfg)
  114. # paddleseg.utils.set_device("cpu")
  115. # model = builder.model
  116. # transforms = ppmatting.transforms.Compose(builder.val_transforms)
  117. image_list, image_dir = get_image_list(img_path)
  118. logger.info('Number of predict images = {}'.format(len(image_list)))
  119. model_path = get_rel_path("models/ppmatting-hrnet_w18-human_512.pdparams")
  120. predict(
  121. model,
  122. model_path=model_path,
  123. transforms=transforms,
  124. image_list=image_list,
  125. image_dir=image_dir,
  126. trimap_list=None,
  127. save_dir=save_dir,
  128. fg_estimate=True)
  129. if __name__ == '__main__':
  130. args = parse_args()
  131. main(args)