utils.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import paddle
  16. from urllib.parse import urlparse
  17. from paddleseg.utils import logger, download_pretrained_model
  18. def get_files(root_path):
  19. res = []
  20. for root, dirs, files in os.walk(root_path, followlinks=True):
  21. for f in files:
  22. if f.endswith(('.jpg', '.png', '.jpeg', 'JPG')):
  23. res.append(os.path.join(root, f))
  24. return res
  25. def get_image_list(image_path):
  26. """Get image list"""
  27. valid_suffix = [
  28. '.JPEG', '.jpeg', '.JPG', '.jpg', '.BMP', '.bmp', '.PNG', '.png'
  29. ]
  30. image_list = []
  31. image_dir = None
  32. if os.path.isfile(image_path):
  33. image_dir = None
  34. if os.path.splitext(image_path)[-1] in valid_suffix:
  35. image_list.append(image_path)
  36. else:
  37. image_dir = os.path.dirname(image_path)
  38. with open(image_path, 'r') as f:
  39. for line in f:
  40. line = line.strip()
  41. if len(line.split()) > 1:
  42. raise RuntimeError(
  43. 'There should be only one image path per line in `image_path` file. Wrong line: {}'
  44. .format(line))
  45. image_list.append(os.path.join(image_dir, line))
  46. elif os.path.isdir(image_path):
  47. image_dir = image_path
  48. for root, dirs, files in os.walk(image_path):
  49. for f in files:
  50. if '.ipynb_checkpoints' in root:
  51. continue
  52. if os.path.splitext(f)[-1] in valid_suffix:
  53. image_list.append(os.path.join(root, f))
  54. image_list.sort()
  55. else:
  56. raise FileNotFoundError(
  57. '`image_path` is not found. it should be an image file or a directory including images'
  58. )
  59. if len(image_list) == 0:
  60. raise RuntimeError('There are not image file in `image_path`')
  61. return image_list, image_dir
  62. def mkdir(path):
  63. sub_dir = os.path.dirname(path)
  64. if not os.path.exists(sub_dir):
  65. os.makedirs(sub_dir)
  66. def load_pretrained_model(model, pretrained_model):
  67. if pretrained_model is not None:
  68. logger.info('Loading pretrained model from {}'.format(pretrained_model))
  69. if urlparse(pretrained_model).netloc:
  70. pretrained_model = download_pretrained_model(pretrained_model)
  71. if os.path.exists(pretrained_model):
  72. para_state_dict = paddle.load(pretrained_model)
  73. model_state_dict = model.state_dict()
  74. keys = model_state_dict.keys()
  75. num_params_loaded = 0
  76. for k in keys:
  77. if k not in para_state_dict:
  78. logger.warning("{} is not in pretrained model".format(k))
  79. elif list(para_state_dict[k].shape) != list(model_state_dict[k]
  80. .shape):
  81. # When the input is more than 3 channels such as trimap-based method, padding zeros to load.
  82. para_shape = list(para_state_dict[k].shape)
  83. model_shape = list(model_state_dict[k].shape)
  84. if 'weight' in k \
  85. and len(para_shape) > 3 \
  86. and len(para_shape) > 3 \
  87. and para_shape[1] < model_shape[1] \
  88. and para_shape[0] == model_shape[0] \
  89. and para_shape[2] == model_shape[2] \
  90. and para_shape[3] == model_shape[3]:
  91. zeros_pad = paddle.zeros(
  92. (para_shape[0], model_shape[1] - para_shape[1],
  93. para_shape[2], para_shape[3]))
  94. para_state_dict[k] = paddle.concat(
  95. [para_state_dict[k], zeros_pad], axis=1)
  96. model_state_dict[k] = para_state_dict[k]
  97. num_params_loaded += 1
  98. else:
  99. logger.warning(
  100. "[SKIP] Shape of pretrained params {} doesn't match.(Pretrained: {}, Actual: {})"
  101. .format(k, para_state_dict[k].shape,
  102. model_state_dict[k].shape))
  103. else:
  104. model_state_dict[k] = para_state_dict[k]
  105. num_params_loaded += 1
  106. model.set_dict(model_state_dict)
  107. logger.info("There are {}/{} variables loaded into {}.".format(
  108. num_params_loaded,
  109. len(model_state_dict), model.__class__.__name__))
  110. else:
  111. raise ValueError('The pretrained model directory is not Found: {}'.
  112. format(pretrained_model))
  113. else:
  114. logger.info(
  115. 'No pretrained model to load, {} will be trained from scratch.'.
  116. format(model.__class__.__name__))