val.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
import argparse
import os
import sys

# Append the directory one level above this script's directory to the module
# search path. NOTE(review): presumably this is so the local `ppmatting`
# package (imported below) is found when the script is run directly — confirm.
LOCAL_PATH = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(LOCAL_PATH, '..'))

import paddle
import paddleseg
from paddleseg.cvlibs import manager
from paddleseg.utils import get_sys_env, logger, utils

# Empty the BACKBONES and TRANSFORMS component registries after importing
# paddleseg and BEFORE importing ppmatting. The ordering matters: these lines
# must stay between the two groups of imports.
# NOTE(review): presumably this prevents clashes with components of the same
# name that ppmatting registers on import — confirm against ppmatting.
manager.BACKBONES._components_dict.clear()
manager.TRANSFORMS._components_dict.clear()

import ppmatting
from ppmatting.core import evaluate, evaluate_ml
from ppmatting.utils import Config, MatBuilder


  28. def parse_args():
  29. parser = argparse.ArgumentParser(description='Model training')
  30. parser.add_argument(
  31. "--config", dest="cfg", help="The config file.", default=None, type=str)
  32. parser.add_argument(
  33. '--opts',
  34. help='Update the key-value pairs of all options.',
  35. default=None,
  36. nargs='+')
  37. parser.add_argument(
  38. '--model_path',
  39. dest='model_path',
  40. help='The path of model for evaluation',
  41. type=str,
  42. default=None)
  43. parser.add_argument(
  44. '--save_dir',
  45. dest='save_dir',
  46. help='The directory for saving the model snapshot',
  47. type=str,
  48. default='./output/results')
  49. parser.add_argument(
  50. '--num_workers',
  51. dest='num_workers',
  52. help='Num workers for data loader',
  53. type=int,
  54. default=0)
  55. parser.add_argument(
  56. '--save_results',
  57. dest='save_results',
  58. help='save prediction alpha while evaluating',
  59. action='store_true')
  60. parser.add_argument(
  61. '--metrics',
  62. dest='metrics',
  63. nargs='+',
  64. help='The metrics to evaluate, it may be the combination of ("sad", "mse", "grad", "conn")',
  65. type=str,
  66. default='sad')
  67. parser.add_argument(
  68. '--device',
  69. dest='device',
  70. help='Set the device type, which may be GPU, CPU or XPU.',
  71. default='gpu',
  72. type=str)
  73. return parser.parse_args()
  74. def main(args):
  75. assert args.cfg is not None, \
  76. 'No configuration file specified, please set --config'
  77. cfg = Config(args.cfg, opts=args.opts)
  78. builder = MatBuilder(cfg)
  79. paddleseg.utils.show_env_info()
  80. paddleseg.utils.show_cfg_info(cfg)
  81. paddleseg.utils.set_device(args.device)
  82. model = builder.model
  83. val_dataset = builder.val_dataset
  84. if isinstance(model, paddle.nn.Layer):
  85. if args.model_path:
  86. utils.load_entire_model(model, args.model_path)
  87. logger.info('Loaded trained params of model successfully')
  88. evaluate(
  89. model,
  90. val_dataset,
  91. num_workers=args.num_workers,
  92. save_dir=args.save_dir,
  93. save_results=args.save_results,
  94. metrics=args.metrics)
  95. else:
  96. evaluate_ml(
  97. model,
  98. val_dataset,
  99. save_dir=args.save_dir,
  100. save_results=args.save_results)
  101. if __name__ == '__main__':
  102. args = parse_args()
  103. main(args)