123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252 |
- import os
- import math
- import cv2
- import numpy as np
- import random
- import paddle
- from paddleseg.cvlibs import manager
- import ppmatting.transforms as T
@manager.DATASETS.add_component
class MattingDataset(paddle.io.Dataset):
    """
    Pass in a dataset that conforms to the format.

        matting_dataset/
        |--bg/
        |
        |--train/
        |  |--fg/
        |  |--alpha/
        |
        |--val/
        |  |--fg/
        |  |--alpha/
        |  |--trimap/ (if existing)
        |
        |--train.txt
        |
        |--val.txt

    See README.md for more information of dataset.

    Args:
        dataset_root (str): The root path of dataset.
        transforms (list): Transforms for image.
        mode (str, optional): Which part of dataset to use. It is one of
            ('train', 'val', 'trainval'). Default: 'train'.
        train_file (str|list, optional): File list used to train. Each line
            should be `foreground_image.png background_image.png` or
            `foreground_image.png`. It should be provided if mode is 'train'
            or 'trainval'. Default: None.
        val_file (str|list, optional): File list used for evaluation. Each
            line should be `foreground_image.png background_image.png` or
            `foreground_image.png` or
            `foreground_image.png background_image.png trimap_image.png`.
            It should be provided if mode is 'val' or 'trainval'.
            Default: None.
        get_trimap (bool, optional): Whether to get trimap. Default: True.
        separator (str, optional): The separator of train_file or val_file.
            If a file name contains ' ', '|' may be a better choice.
            Default: ' '.
        key_del (tuple|list, optional): Keys that are not needed; they are
            deleted from each sample to accelerate the data reader.
            Default: None.
        if_rssn (bool, optional): Whether to use RSSN while compositing
            image, including denoise and blur. Default: False.

    Raises:
        ValueError: If `mode` is not supported, or the file list required by
            `mode` is not provided.
    """

    def __init__(self,
                 dataset_root,
                 transforms,
                 mode='train',
                 train_file=None,
                 val_file=None,
                 get_trimap=True,
                 separator=' ',
                 key_del=None,
                 if_rssn=False):
        super().__init__()
        # Fail early with a clear message; an unknown mode would otherwise
        # surface later as an unbound `file_list`.
        if mode not in ('train', 'val', 'trainval'):
            raise ValueError(
                "`mode` should be one of ('train', 'val', 'trainval'), "
                "but got {}.".format(mode))

        self.dataset_root = dataset_root
        self.transforms = T.Compose(transforms)
        self.mode = mode
        self.get_trimap = get_trimap
        self.separator = separator
        self.key_del = key_del
        self.if_rssn = if_rssn

        # Collect the list files to read: train files first, then val files.
        file_list = []
        if mode == 'train' or mode == 'trainval':
            if train_file is None:
                raise ValueError(
                    "When `mode` is 'train' or 'trainval', `train_file must be provided!"
                )
            if isinstance(train_file, str):
                train_file = [train_file]
            file_list.extend(train_file)
        if mode == 'val' or mode == 'trainval':
            if val_file is None:
                raise ValueError(
                    "When `mode` is 'val' or 'trainval', `val_file must be provided!"
                )
            if isinstance(val_file, str):
                val_file = [val_file]
            file_list.extend(val_file)

        # Every non-empty line of every list file is one sample description.
        self.fg_bg_list = []
        for file in file_list:
            list_path = os.path.join(dataset_root, file)
            with open(list_path, 'r') as f:
                for line in f:
                    line = line.strip()
                    # Blank lines (e.g. a trailing newline) are not samples.
                    if line:
                        self.fg_bg_list.append(line)
        if mode != 'val':
            random.shuffle(self.fg_bg_list)

    @staticmethod
    def _ensure_loaded(image, path):
        """Return `image`, raising FileNotFoundError if `cv2.imread` gave None."""
        if image is None:
            raise FileNotFoundError(
                'image is not found or can not be read: {}'.format(path))
        return image

    def __getitem__(self, idx):
        """
        Build the sample described by line `idx` of the file list.

        The line is split by `self.separator` into
        `foreground [background [trimap]]`. The alpha path is derived from
        the foreground path by replacing '/fg' with '/alpha'.

        Returns:
            dict: The sample after `self.transforms`. Before the transforms
                it holds 'img_name', 'img', 'alpha', 'gt_fields' and
                'trans_info', plus 'fg'/'bg'/'trimap'/'ori_trimap' depending
                on mode and inputs, minus any key listed in `key_del`.

        Raises:
            FileNotFoundError: If a required image or the trimap file can
                not be read.
        """
        data = {}
        fields = self.fg_bg_list[idx].split(self.separator)
        data['img_name'] = fields[0]

        fg_file = os.path.join(self.dataset_root, fields[0])
        # NOTE: every '/fg' occurrence in the path is rewritten, not only the
        # directory component.
        alpha_file = fg_file.replace('/fg', '/alpha')
        fg = self._ensure_loaded(cv2.imread(fg_file), fg_file)
        alpha = self._ensure_loaded(cv2.imread(alpha_file, 0), alpha_file)
        data['alpha'] = alpha
        data['gt_fields'] = []

        is_training = self.mode in ['train', 'trainval']
        if len(fields) >= 2:
            # A background is given: composite the foreground onto it.
            bg_file = os.path.join(self.dataset_root, fields[1])
            bg = self._ensure_loaded(cv2.imread(bg_file), bg_file)
            data['img'], data['fg'], data['bg'] = self.composite(fg, alpha, bg)
            if is_training:
                data['gt_fields'].append('fg')
                data['gt_fields'].append('bg')
                data['gt_fields'].append('alpha')
            if len(fields) == 3 and self.get_trimap:
                if self.mode == 'val':
                    trimap_path = os.path.join(self.dataset_root, fields[2])
                    if os.path.exists(trimap_path):
                        # 'trimap' stays a path here; the untransformed
                        # trimap is kept separately as 'ori_trimap'.
                        data['trimap'] = trimap_path
                        data['gt_fields'].append('trimap')
                        data['ori_trimap'] = self._ensure_loaded(
                            cv2.imread(trimap_path, 0), trimap_path)
                    else:
                        raise FileNotFoundError(
                            'trimap is not Found: {}'.format(fields[2]))
        else:
            # No background: the foreground file is used as the image itself.
            data['img'] = fg
            if is_training:
                data['fg'] = fg.copy()
                data['bg'] = fg.copy()
                data['gt_fields'].append('fg')
                data['gt_fields'].append('bg')
                data['gt_fields'].append('alpha')

        data['trans_info'] = []

        # Generate a trimap from alpha if no trimap file was provided.
        if self.get_trimap:
            if 'trimap' not in data:
                data['trimap'] = self.gen_trimap(
                    data['alpha'], mode=self.mode).astype('float32')
                data['gt_fields'].append('trimap')
                if self.mode == 'val':
                    data['ori_trimap'] = data['trimap'].copy()

        # Delete the keys which are not needed.
        if self.key_del is not None:
            for key in self.key_del:
                if key in data.keys():
                    data.pop(key)
                if key in data['gt_fields']:
                    data['gt_fields'].remove(key)

        data = self.transforms(data)

        # In evaluation, alpha is registered as a gt field only after the
        # transforms have run, so it is still cast to float32 below.
        if self.mode == 'val':
            data['gt_fields'].append('alpha')

        data['img'] = data['img'].astype('float32')
        for key in data.get('gt_fields', []):
            data[key] = data[key].astype('float32')

        # Add a leading axis to the single-channel maps.
        if 'trimap' in data:
            data['trimap'] = data['trimap'][np.newaxis, :, :]
        if 'ori_trimap' in data:
            data['ori_trimap'] = data['ori_trimap'][np.newaxis, :, :]
        data['alpha'] = data['alpha'][np.newaxis, :, :] / 255.

        return data

    def __len__(self):
        """Return the number of samples read from the file lists."""
        return len(self.fg_bg_list)

    def composite(self, fg, alpha, ori_bg):
        """
        Composite `fg` onto `ori_bg` as `alpha * fg + (1 - alpha) * bg`.

        Args:
            fg (np.ndarray): Foreground image.
            alpha (np.ndarray): Single-channel alpha; it is divided by 255
                before blending.
            ori_bg (np.ndarray): Background image. It is enlarged when it is
                smaller than `fg` in either dimension, then cropped from the
                top-left corner to the size of `fg`.

        Returns:
            tuple: `(image, fg, bg)` where `image` is the uint8 composite,
                `fg` is the (possibly denoised) foreground and `bg` is the
                background actually used for blending.
        """
        if self.if_rssn:
            # Each augmentation is applied with probability 0.5.
            if np.random.rand() < 0.5:
                fg = cv2.fastNlMeansDenoisingColored(fg, None, 3, 3, 7, 21)
                ori_bg = cv2.fastNlMeansDenoisingColored(ori_bg, None, 3, 3,
                                                         7, 21)
            if np.random.rand() < 0.5:
                # Cast to a plain int so the kernel size tuple does not
                # carry a NumPy integer into OpenCV.
                radius = int(np.random.choice([19, 29, 39, 49, 59]))
                ori_bg = cv2.GaussianBlur(ori_bg, (radius, radius), 0, 0)

        fg_h, fg_w = fg.shape[:2]
        ori_bg_h, ori_bg_w = ori_bg.shape[:2]

        # Scale the background up just enough to cover the foreground.
        wratio = fg_w / ori_bg_w
        hratio = fg_h / ori_bg_h
        ratio = wratio if wratio > hratio else hratio
        if ratio > 1:
            resize_h = math.ceil(ori_bg_h * ratio)
            resize_w = math.ceil(ori_bg_w * ratio)
            bg = cv2.resize(
                ori_bg, (resize_w, resize_h), interpolation=cv2.INTER_LINEAR)
        else:
            bg = ori_bg

        bg = bg[0:fg_h, 0:fg_w, :]
        alpha = alpha / 255
        alpha = np.expand_dims(alpha, axis=2)
        image = alpha * fg + (1 - alpha) * bg
        image = image.astype(np.uint8)
        return image, fg, bg

    @staticmethod
    def gen_trimap(alpha, mode='train', eval_kernel=25):
        """
        Generate a trimap from `alpha` by dilation and erosion.

        Args:
            alpha (np.ndarray): Alpha matte with values in 0-255.
            mode (str, optional): With 'train' the kernel size and iteration
                count are drawn at random; any other value (including
                'trainval') uses the fixed `eval_kernel` with one iteration.
                Default: 'train'.
            eval_kernel (int, optional): Kernel size used when `mode` is not
                'train'. Default: 25.

        Returns:
            np.ndarray: float64 array shaped like `alpha`: 255 where the
                eroded alpha is above 254.5, 0 where the dilated alpha is
                below 0.5, and 128 elsewhere.
        """
        if mode == 'train':
            k_size = random.choice(range(2, 5))
            iterations = np.random.randint(5, 15)
        else:
            k_size = eval_kernel
            # One iteration is what dilate/erode do when none is given.
            iterations = 1

        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,
                                           (k_size, k_size))
        dilated = cv2.dilate(alpha, kernel, iterations=iterations)
        eroded = cv2.erode(alpha, kernel, iterations=iterations)

        # Start as unknown everywhere, then mark sure foreground/background.
        trimap = np.full(alpha.shape, 128, dtype=np.float64)
        trimap[eroded > 254.5] = 255
        trimap[dilated < 0.5] = 0
        return trimap