matting_dataset.py

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import math
import random

import cv2
import numpy as np
import paddle
from paddleseg.cvlibs import manager

import ppmatting.transforms as T


@manager.DATASETS.add_component
class MattingDataset(paddle.io.Dataset):
    """
    Pass in a dataset that conforms to the following format:

        matting_dataset/
        |--bg/
        |
        |--train/
        |  |--fg/
        |  |--alpha/
        |
        |--val/
        |  |--fg/
        |  |--alpha/
        |  |--trimap/ (if existing)
        |
        |--train.txt
        |
        |--val.txt

    See README.md for more information about the dataset.

    Args:
        dataset_root (str): The root path of the dataset.
        transforms (list): Transforms applied to the image.
        mode (str, optional): Which part of the dataset to use. One of ('train', 'val', 'trainval'). Default: 'train'.
        train_file (str|list, optional): File list used for training. Each line should be `foreground_image.png background_image.png`
            or `foreground_image.png`. It must be provided if mode is 'train'. Default: None.
        val_file (str|list, optional): File list used for evaluation. Each line should be `foreground_image.png background_image.png`,
            `foreground_image.png`, or `foreground_image.png background_image.png trimap_image.png`.
            It must be provided if mode is 'val'. Default: None.
        get_trimap (bool, optional): Whether to get a trimap. Default: True.
        separator (str, optional): The separator used in train_file or val_file. If the file names contain spaces, '|' may be a better choice. Default: ' '.
        key_del (tuple|list, optional): Keys that are not needed and will be deleted to speed up the data reader. Default: None.
        if_rssn (bool, optional): Whether to use RSSN when compositing images, including denoising and blurring. Default: False.
    """

    def __init__(self,
                 dataset_root,
                 transforms,
                 mode='train',
                 train_file=None,
                 val_file=None,
                 get_trimap=True,
                 separator=' ',
                 key_del=None,
                 if_rssn=False):
        super().__init__()
        self.dataset_root = dataset_root
        self.transforms = T.Compose(transforms)
        self.mode = mode
        self.get_trimap = get_trimap
        self.separator = separator
        self.key_del = key_del
        self.if_rssn = if_rssn

        # Check that the required file lists are provided.
        if mode == 'train' or mode == 'trainval':
            if train_file is None:
                raise ValueError(
                    "When `mode` is 'train' or 'trainval', `train_file` must be provided!"
                )
            if isinstance(train_file, str):
                train_file = [train_file]
            file_list = train_file

        if mode == 'val' or mode == 'trainval':
            if val_file is None:
                raise ValueError(
                    "When `mode` is 'val' or 'trainval', `val_file` must be provided!"
                )
            if isinstance(val_file, str):
                val_file = [val_file]
            file_list = val_file

        if mode == 'trainval':
            file_list = train_file + val_file

        # Read the file lists.
        self.fg_bg_list = []
        for file in file_list:
            file = os.path.join(dataset_root, file)
            with open(file, 'r') as f:
                lines = f.readlines()
                for line in lines:
                    line = line.strip()
                    self.fg_bg_list.append(line)
        if mode != 'val':
            random.shuffle(self.fg_bg_list)

    def __getitem__(self, idx):
        data = {}
        fg_bg_file = self.fg_bg_list[idx]
        fg_bg_file = fg_bg_file.split(self.separator)
        data['img_name'] = fg_bg_file[0]  # used when saving prediction results
        fg_file = os.path.join(self.dataset_root, fg_bg_file[0])
        alpha_file = fg_file.replace('/fg', '/alpha')
        fg = cv2.imread(fg_file)
        alpha = cv2.imread(alpha_file, 0)
        data['alpha'] = alpha
        data['gt_fields'] = []

        # Each line is: fg [bg] [trimap]
        if len(fg_bg_file) >= 2:
            bg_file = os.path.join(self.dataset_root, fg_bg_file[1])
            bg = cv2.imread(bg_file)
            data['img'], data['fg'], data['bg'] = self.composite(fg, alpha, bg)
            if self.mode in ['train', 'trainval']:
                data['gt_fields'].append('fg')
                data['gt_fields'].append('bg')
                data['gt_fields'].append('alpha')
            if len(fg_bg_file) == 3 and self.get_trimap:
                if self.mode == 'val':
                    trimap_path = os.path.join(self.dataset_root,
                                               fg_bg_file[2])
                    if os.path.exists(trimap_path):
                        data['trimap'] = trimap_path
                        data['gt_fields'].append('trimap')
                        data['ori_trimap'] = cv2.imread(trimap_path, 0)
                    else:
                        raise FileNotFoundError(
                            'trimap is not found: {}'.format(fg_bg_file[2]))
        else:
            data['img'] = fg
            if self.mode in ['train', 'trainval']:
                data['fg'] = fg.copy()
                data['bg'] = fg.copy()
                data['gt_fields'].append('fg')
                data['gt_fields'].append('bg')
                data['gt_fields'].append('alpha')

        data['trans_info'] = []  # Record shape change information.

        # Generate a trimap from alpha if no trimap file is provided.
        if self.get_trimap:
            if 'trimap' not in data:
                data['trimap'] = self.gen_trimap(
                    data['alpha'], mode=self.mode).astype('float32')
                data['gt_fields'].append('trimap')
                if self.mode == 'val':
                    data['ori_trimap'] = data['trimap'].copy()

        # Delete keys that are not needed.
        if self.key_del is not None:
            for key in self.key_del:
                if key in data.keys():
                    data.pop(key)
                if key in data['gt_fields']:
                    data['gt_fields'].remove(key)

        data = self.transforms(data)

        # During evaluation, the ground truth should not be transformed.
        if self.mode == 'val':
            data['gt_fields'].append('alpha')

        data['img'] = data['img'].astype('float32')
        for key in data.get('gt_fields', []):
            data[key] = data[key].astype('float32')

        if 'trimap' in data:
            data['trimap'] = data['trimap'][np.newaxis, :, :]
        if 'ori_trimap' in data:
            data['ori_trimap'] = data['ori_trimap'][np.newaxis, :, :]

        data['alpha'] = data['alpha'][np.newaxis, :, :] / 255.

        return data

    def __len__(self):
        return len(self.fg_bg_list)

    def composite(self, fg, alpha, ori_bg):
        if self.if_rssn:
            if np.random.rand() < 0.5:
                fg = cv2.fastNlMeansDenoisingColored(fg, None, 3, 3, 7, 21)
                ori_bg = cv2.fastNlMeansDenoisingColored(ori_bg, None, 3, 3, 7,
                                                         21)
            if np.random.rand() < 0.5:
                radius = np.random.choice([19, 29, 39, 49, 59])
                ori_bg = cv2.GaussianBlur(ori_bg, (radius, radius), 0, 0)

        fg_h, fg_w = fg.shape[:2]
        ori_bg_h, ori_bg_w = ori_bg.shape[:2]

        wratio = fg_w / ori_bg_w
        hratio = fg_h / ori_bg_h
        ratio = wratio if wratio > hratio else hratio

        # Resize ori_bg if it is smaller than fg.
        if ratio > 1:
            resize_h = math.ceil(ori_bg_h * ratio)
            resize_w = math.ceil(ori_bg_w * ratio)
            bg = cv2.resize(
                ori_bg, (resize_w, resize_h), interpolation=cv2.INTER_LINEAR)
        else:
            bg = ori_bg

        bg = bg[0:fg_h, 0:fg_w, :]
        alpha = alpha / 255
        alpha = np.expand_dims(alpha, axis=2)
        image = alpha * fg + (1 - alpha) * bg
        image = image.astype(np.uint8)
        return image, fg, bg

    @staticmethod
    def gen_trimap(alpha, mode='train', eval_kernel=25):
        if mode == 'train':
            k_size = random.choice(range(2, 5))
            iterations = np.random.randint(5, 15)
            kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,
                                               (k_size, k_size))
            dilated = cv2.dilate(alpha, kernel, iterations=iterations)
            eroded = cv2.erode(alpha, kernel, iterations=iterations)
            trimap = np.zeros(alpha.shape)
            trimap.fill(128)
            trimap[eroded > 254.5] = 255
            trimap[dilated < 0.5] = 0
        else:
            k_size = eval_kernel
            kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,
                                               (k_size, k_size))
            dilated = cv2.dilate(alpha, kernel)
            eroded = cv2.erode(alpha, kernel)
            trimap = np.zeros(alpha.shape)
            trimap.fill(128)
            trimap[eroded > 254.5] = 255
            trimap[dilated < 0.5] = 0

        return trimap
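
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module). The
# dataset_root, file-list name, and transform choices below are assumptions;
# adapt them to your own data following the directory layout described in the
# MattingDataset docstring and the ppmatting configs.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    # A minimal transform pipeline; real configs usually add resizing,
    # cropping, and other augmentations on top of these.
    transforms = [
        T.LoadImages(),
        T.Normalize(),
    ]

    dataset = MattingDataset(
        dataset_root='data/matting_dataset',  # hypothetical path
        transforms=transforms,
        mode='train',
        train_file='train.txt',  # one `fg [bg] [trimap]` entry per line
        get_trimap=True)

    sample = dataset[0]
    # 'img' is cast to float32 in __getitem__, 'alpha' is scaled to [0, 1]
    # with a leading channel axis, and 'trimap' holds the values 0/128/255.
    print('img:', sample['img'].shape,
          'alpha:', sample['alpha'].shape,
          'trimap:', sample['trimap'].shape)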