export.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import argparse
  15. import os
  16. import sys
  17. import paddle
  18. import yaml
  19. import paddleseg
  20. from paddleseg.cvlibs import manager
  21. from paddleseg.utils import logger
  # Directory containing this script; its parent directory is appended to
  # sys.path so the local ``ppmatting`` package can be imported when the
  # script is run directly.
  LOCAL_PATH = os.path.dirname(os.path.abspath(__file__))
  sys.path.append(os.path.join(LOCAL_PATH, '..'))

  # Empty paddleseg's backbone and transform registries *before* importing
  # ppmatting. NOTE(review): presumably ppmatting registers components under
  # names paddleseg already registered, and the registry would otherwise
  # reject or shadow them — confirm against paddleseg.cvlibs.manager.
  for _registry in (manager.BACKBONES, manager.TRANSFORMS):
      _registry._components_dict.clear()
  del _registry  # keep the loop variable out of the module namespace

  # These imports must stay below the sys.path tweak and the registry reset.
  import ppmatting  # noqa: E402
  from ppmatting.utils import get_input_spec, Config, MatBuilder  # noqa: E402
  def parse_args(argv=None):
      """Parse the command-line options of the model export script.

      Args:
          argv (list[str], optional): Argument strings to parse. When None
              (the default) argparse falls back to ``sys.argv[1:]``, so
              existing ``parse_args()`` calls behave exactly as before.

      Returns:
          argparse.Namespace: Parsed options with the attributes ``cfg``,
          ``save_dir``, ``model_path``, ``trimap`` and ``input_shape``.
          ``input_shape`` is a list of ints, or None when not given.
      """
      parser = argparse.ArgumentParser(description='Model export.')
      # Params of model export.
      parser.add_argument(
          "--config",
          dest="cfg",
          help="The config file.",
          type=str,
          required=True)  # required, so a default would never be used
      parser.add_argument(
          '--save_dir',
          dest='save_dir',
          help='The directory for saving the exported model',
          type=str,
          default='./output')
      parser.add_argument(
          '--model_path',
          dest='model_path',
          help='The path of model for export',
          type=str,
          default=None)
      parser.add_argument(
          '--trimap',
          dest='trimap',
          help='Whether to input trimap',
          action='store_true')
      parser.add_argument(
          "--input_shape",
          nargs='+',
          help="Export the model with fixed input shape, such as 1 3 1024 1024.",
          type=int,
          default=None)
      return parser.parse_args(argv)
  def main(args):
      """Export a matting model to a static-graph inference model.

      Builds the model described by ``args.cfg``, optionally loads trained
      weights from ``args.model_path``, converts it with
      ``paddle.jit.to_static`` and saves it with ``paddle.jit.save`` under the
      path prefix ``<save_dir>/model``. It also writes ``<save_dir>/deploy.yaml``
      recording the validation transforms, the model/params file names and
      the input shape used for export.

      Args:
          args (argparse.Namespace): Options as produced by ``parse_args``
              (``cfg``, ``save_dir``, ``model_path``, ``trimap``,
              ``input_shape``).

      Raises:
          ValueError: If no configuration file was given.
      """
      # Explicit check rather than ``assert`` so it still runs under -O.
      if args.cfg is None:
          raise ValueError(
              'No configuration file specified, please set --config')

      cfg = Config(args.cfg)
      builder = MatBuilder(cfg)

      paddleseg.utils.show_env_info()
      paddleseg.utils.show_cfg_info(cfg)
      # Set before ``builder.model`` is accessed. NOTE(review): presumably the
      # model construction reads this flag to switch to export behaviour.
      os.environ['PADDLESEG_EXPORT_STAGE'] = 'True'

      net = builder.model
      net.eval()

      if args.model_path:
          para_state_dict = paddle.load(args.model_path)
          net.set_dict(para_state_dict)
          logger.info('Loaded trained params of model successfully.')

      # Without --input_shape, batch size and spatial dims stay dynamic.
      if args.input_shape is None:
          shape = [None, 3, None, None]
      else:
          shape = args.input_shape

      # Read the architecture name once, before conversion, so the input spec
      # and deploy.yaml's ModelName cannot disagree regardless of what
      # paddle.jit.to_static returns.
      model_name = net.__class__.__name__
      input_spec = get_input_spec(model_name, shape=shape, trimap=args.trimap)
      net = paddle.jit.to_static(net, input_spec=input_spec)
      save_path = os.path.join(args.save_dir, 'model')
      paddle.jit.save(net, save_path)

      # Assemble the full payload before opening the file so a failure while
      # reading the config cannot leave an empty deploy.yaml behind.
      transforms = cfg.val_dataset_cfg.get('transforms', [{
          'type': 'Normalize'
      }])
      data = {
          'Deploy': {
              'transforms': transforms,
              # File names assume paddle.jit.save's suffixes for the
              # 'model' prefix used above.
              'model': 'model.pdmodel',
              'params': 'model.pdiparams',
              'input_shape': shape
          },
          'ModelName': model_name
      }
      yml_file = os.path.join(args.save_dir, 'deploy.yaml')
      with open(yml_file, 'w', encoding='utf-8') as yml_fp:
          yaml.dump(data, yml_fp)
      logger.info(f'Model is saved in {args.save_dir}.')
  if __name__ == '__main__':
      # Script entry point: read the command-line options, then export.
      args = parse_args()
      main(args=args)