transforms.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import random
  16. import string
  17. import cv2
  18. import numpy as np
  19. from paddleseg.transforms import functional
  20. from paddleseg.cvlibs import manager
  21. from paddleseg.utils import seg_env
  22. from PIL import Image
  23. @manager.TRANSFORMS.add_component
  24. class Compose:
  25. """
  26. Do transformation on input data with corresponding pre-processing and augmentation operations.
  27. The shape of input data to all operations is [height, width, channels].
  28. """
  29. def __init__(self, transforms, to_rgb=True):
  30. if not isinstance(transforms, list):
  31. raise TypeError('The transforms must be a list!')
  32. self.transforms = transforms
  33. self.to_rgb = to_rgb
  34. def __call__(self, data):
  35. """
  36. Args:
  37. data (dict): The data to transform.
  38. Returns:
  39. dict: Data after transformation
  40. """
  41. if 'trans_info' not in data:
  42. data['trans_info'] = []
  43. for op in self.transforms:
  44. data = op(data)
  45. if data is None:
  46. return None
  47. data['img'] = np.transpose(data['img'], (2, 0, 1))
  48. for key in data.get('gt_fields', []):
  49. if len(data[key].shape) == 2:
  50. continue
  51. data[key] = np.transpose(data[key], (2, 0, 1))
  52. return data
  53. @manager.TRANSFORMS.add_component
  54. class LoadImages:
  55. def __init__(self, to_rgb=True):
  56. self.to_rgb = to_rgb
  57. def __call__(self, data):
  58. if isinstance(data['img'], str):
  59. data['img'] = cv2.imread(data['img'])
  60. for key in data.get('gt_fields', []):
  61. if isinstance(data[key], str):
  62. data[key] = cv2.imread(data[key], cv2.IMREAD_UNCHANGED)
  63. # if alpha and trimap has 3 channels, extract one.
  64. if key in ['alpha', 'trimap']:
  65. if len(data[key].shape) > 2:
  66. data[key] = data[key][:, :, 0]
  67. if self.to_rgb:
  68. data['img'] = cv2.cvtColor(data['img'], cv2.COLOR_BGR2RGB)
  69. for key in data.get('gt_fields', []):
  70. if len(data[key].shape) == 2:
  71. continue
  72. data[key] = cv2.cvtColor(data[key], cv2.COLOR_BGR2RGB)
  73. return data
  74. @manager.TRANSFORMS.add_component
  75. class Resize:
  76. def __init__(self, target_size=(512, 512), random_interp=False):
  77. if isinstance(target_size, list) or isinstance(target_size, tuple):
  78. if len(target_size) != 2:
  79. raise ValueError(
  80. '`target_size` should include 2 elements, but it is {}'.
  81. format(target_size))
  82. else:
  83. raise TypeError(
  84. "Type of `target_size` is invalid. It should be list or tuple, but it is {}"
  85. .format(type(target_size)))
  86. self.target_size = target_size
  87. self.random_interp = random_interp
  88. self.interps = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC]
  89. def __call__(self, data):
  90. if self.random_interp:
  91. interp = np.random.choice(self.interps)
  92. else:
  93. interp = cv2.INTER_LINEAR
  94. data['trans_info'].append(('resize', data['img'].shape[0:2]))
  95. data['img'] = functional.resize(data['img'], self.target_size, interp)
  96. for key in data.get('gt_fields', []):
  97. if key == 'trimap':
  98. data[key] = functional.resize(data[key], self.target_size,
  99. cv2.INTER_NEAREST)
  100. else:
  101. data[key] = functional.resize(data[key], self.target_size,
  102. interp)
  103. return data
  104. @manager.TRANSFORMS.add_component
  105. class RandomResize:
  106. """
  107. Resize image to a size determinned by `scale` and `size`.
  108. Args:
  109. size(tuple|list): The reference size to resize. A tuple or list with length 2.
  110. scale(tupel|list, optional): A range of scale base on `size`. A tuple or list with length 2. Default: None.
  111. """
  112. def __init__(self, size=None, scale=None):
  113. if isinstance(size, list) or isinstance(size, tuple):
  114. if len(size) != 2:
  115. raise ValueError(
  116. '`size` should include 2 elements, but it is {}'.format(
  117. size))
  118. elif size is not None:
  119. raise TypeError(
  120. "Type of `size` is invalid. It should be list or tuple, but it is {}"
  121. .format(type(size)))
  122. if scale is not None:
  123. if isinstance(scale, list) or isinstance(scale, tuple):
  124. if len(scale) != 2:
  125. raise ValueError(
  126. '`scale` should include 2 elements, but it is {}'.
  127. format(scale))
  128. else:
  129. raise TypeError(
  130. "Type of `scale` is invalid. It should be list or tuple, but it is {}"
  131. .format(type(scale)))
  132. self.size = size
  133. self.scale = scale
  134. def __call__(self, data):
  135. h, w = data['img'].shape[:2]
  136. if self.scale is not None:
  137. scale = np.random.uniform(self.scale[0], self.scale[1])
  138. else:
  139. scale = 1.
  140. if self.size is not None:
  141. scale_factor = max(self.size[0] / w, self.size[1] / h)
  142. else:
  143. scale_factor = 1
  144. scale = scale * scale_factor
  145. w = int(round(w * scale))
  146. h = int(round(h * scale))
  147. data['img'] = functional.resize(data['img'], (w, h))
  148. for key in data.get('gt_fields', []):
  149. if key == 'trimap':
  150. data[key] = functional.resize(data[key], (w, h),
  151. cv2.INTER_NEAREST)
  152. else:
  153. data[key] = functional.resize(data[key], (w, h))
  154. return data
  155. @manager.TRANSFORMS.add_component
  156. class ResizeByLong:
  157. """
  158. Resize the long side of an image to given size, and then scale the other side proportionally.
  159. Args:
  160. long_size (int): The target size of long side.
  161. """
  162. def __init__(self, long_size):
  163. self.long_size = long_size
  164. def __call__(self, data):
  165. data['trans_info'].append(('resize', data['img'].shape[0:2]))
  166. data['img'] = functional.resize_long(data['img'], self.long_size)
  167. for key in data.get('gt_fields', []):
  168. if key == 'trimap':
  169. data[key] = functional.resize_long(data[key], self.long_size,
  170. cv2.INTER_NEAREST)
  171. else:
  172. data[key] = functional.resize_long(data[key], self.long_size)
  173. return data
  174. @manager.TRANSFORMS.add_component
  175. class ResizeByShort:
  176. """
  177. Resize the short side of an image to given size, and then scale the other side proportionally.
  178. Args:
  179. short_size (int): The target size of short side.
  180. """
  181. def __init__(self, short_size):
  182. self.short_size = short_size
  183. def __call__(self, data):
  184. data['trans_info'].append(('resize', data['img'].shape[0:2]))
  185. data['img'] = functional.resize_short(data['img'], self.short_size)
  186. for key in data.get('gt_fields', []):
  187. if key == 'trimap':
  188. data[key] = functional.resize_short(data[key], self.short_size,
  189. cv2.INTER_NEAREST)
  190. else:
  191. data[key] = functional.resize_short(data[key], self.short_size)
  192. return data
  193. @manager.TRANSFORMS.add_component
  194. class ResizeToIntMult:
  195. """
  196. Resize to some int muitple, d.g. 32.
  197. """
  198. def __init__(self, mult_int=32):
  199. self.mult_int = mult_int
  200. def __call__(self, data):
  201. data['trans_info'].append(('resize', data['img'].shape[0:2]))
  202. h, w = data['img'].shape[0:2]
  203. rw = w - w % self.mult_int
  204. rh = h - h % self.mult_int
  205. data['img'] = functional.resize(data['img'], (rw, rh))
  206. for key in data.get('gt_fields', []):
  207. if key == 'trimap':
  208. data[key] = functional.resize(data[key], (rw, rh),
  209. cv2.INTER_NEAREST)
  210. else:
  211. data[key] = functional.resize(data[key], (rw, rh))
  212. return data
  213. @manager.TRANSFORMS.add_component
  214. class Normalize:
  215. """
  216. Normalize an image.
  217. Args:
  218. mean (list, optional): The mean value of a data set. Default: [0.5, 0.5, 0.5].
  219. std (list, optional): The standard deviation of a data set. Default: [0.5, 0.5, 0.5].
  220. Raises:
  221. ValueError: When mean/std is not list or any value in std is 0.
  222. """
  223. def __init__(self, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
  224. self.mean = mean
  225. self.std = std
  226. if not (isinstance(self.mean,
  227. (list, tuple)) and isinstance(self.std,
  228. (list, tuple))):
  229. raise ValueError(
  230. "{}: input type is invalid. It should be list or tuple".format(
  231. self))
  232. from functools import reduce
  233. if reduce(lambda x, y: x * y, self.std) == 0:
  234. raise ValueError('{}: std is invalid!'.format(self))
  235. def __call__(self, data):
  236. mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
  237. std = np.array(self.std)[np.newaxis, np.newaxis, :]
  238. data['img'] = functional.normalize(data['img'], mean, std)
  239. if 'fg' in data.get('gt_fields', []):
  240. data['fg'] = functional.normalize(data['fg'], mean, std)
  241. if 'bg' in data.get('gt_fields', []):
  242. data['bg'] = functional.normalize(data['bg'], mean, std)
  243. return data
  244. @manager.TRANSFORMS.add_component
  245. class RandomCropByAlpha:
  246. """
  247. Randomly crop while centered on uncertain area by a certain probability.
  248. Args:
  249. crop_size (tuple|list): The size you want to crop from image.
  250. p (float): The probability centered on uncertain area.
  251. """
  252. def __init__(self, crop_size=((320, 320), (480, 480), (640, 640)),
  253. prob=0.5):
  254. self.crop_size = crop_size
  255. self.prob = prob
  256. def __call__(self, data):
  257. idex = np.random.randint(low=0, high=len(self.crop_size))
  258. crop_w, crop_h = self.crop_size[idex]
  259. img_h = data['img'].shape[0]
  260. img_w = data['img'].shape[1]
  261. if np.random.rand() < self.prob:
  262. crop_center = np.where((data['alpha'] > 0) & (data['alpha'] < 255))
  263. center_h_array, center_w_array = crop_center
  264. if len(center_h_array) == 0:
  265. return data
  266. rand_ind = np.random.randint(len(center_h_array))
  267. center_h = center_h_array[rand_ind]
  268. center_w = center_w_array[rand_ind]
  269. delta_h = crop_h // 2
  270. delta_w = crop_w // 2
  271. start_h = max(0, center_h - delta_h)
  272. start_w = max(0, center_w - delta_w)
  273. else:
  274. start_h = 0
  275. start_w = 0
  276. if img_h > crop_h:
  277. start_h = np.random.randint(img_h - crop_h + 1)
  278. if img_w > crop_w:
  279. start_w = np.random.randint(img_w - crop_w + 1)
  280. end_h = min(img_h, start_h + crop_h)
  281. end_w = min(img_w, start_w + crop_w)
  282. data['img'] = data['img'][start_h:end_h, start_w:end_w]
  283. for key in data.get('gt_fields', []):
  284. data[key] = data[key][start_h:end_h, start_w:end_w]
  285. return data
  286. @manager.TRANSFORMS.add_component
  287. class RandomCrop:
  288. """
  289. Randomly crop
  290. Args:
  291. crop_size (tuple|list): The size you want to crop from image.
  292. """
  293. def __init__(self, crop_size=((320, 320), (480, 480), (640, 640))):
  294. if not isinstance(crop_size[0], (list, tuple)):
  295. crop_size = [crop_size]
  296. self.crop_size = crop_size
  297. def __call__(self, data):
  298. idex = np.random.randint(low=0, high=len(self.crop_size))
  299. crop_w, crop_h = self.crop_size[idex]
  300. img_h, img_w = data['img'].shape[0:2]
  301. start_h = 0
  302. start_w = 0
  303. if img_h > crop_h:
  304. start_h = np.random.randint(img_h - crop_h + 1)
  305. if img_w > crop_w:
  306. start_w = np.random.randint(img_w - crop_w + 1)
  307. end_h = min(img_h, start_h + crop_h)
  308. end_w = min(img_w, start_w + crop_w)
  309. data['img'] = data['img'][start_h:end_h, start_w:end_w]
  310. for key in data.get('gt_fields', []):
  311. data[key] = data[key][start_h:end_h, start_w:end_w]
  312. return data
  313. @manager.TRANSFORMS.add_component
  314. class LimitLong:
  315. """
  316. Limit the long edge of image.
  317. If the long edge is larger than max_long, resize the long edge
  318. to max_long, while scale the short edge proportionally.
  319. If the long edge is smaller than min_long, resize the long edge
  320. to min_long, while scale the short edge proportionally.
  321. Args:
  322. max_long (int, optional): If the long edge of image is larger than max_long,
  323. it will be resize to max_long. Default: None.
  324. min_long (int, optional): If the long edge of image is smaller than min_long,
  325. it will be resize to min_long. Default: None.
  326. """
  327. def __init__(self, max_long=None, min_long=None):
  328. if max_long is not None:
  329. if not isinstance(max_long, int):
  330. raise TypeError(
  331. "Type of `max_long` is invalid. It should be int, but it is {}"
  332. .format(type(max_long)))
  333. if min_long is not None:
  334. if not isinstance(min_long, int):
  335. raise TypeError(
  336. "Type of `min_long` is invalid. It should be int, but it is {}"
  337. .format(type(min_long)))
  338. if (max_long is not None) and (min_long is not None):
  339. if min_long > max_long:
  340. raise ValueError(
  341. '`max_long should not smaller than min_long, but they are {} and {}'
  342. .format(max_long, min_long))
  343. self.max_long = max_long
  344. self.min_long = min_long
  345. def __call__(self, data):
  346. h, w = data['img'].shape[:2]
  347. long_edge = max(h, w)
  348. target = long_edge
  349. if (self.max_long is not None) and (long_edge > self.max_long):
  350. target = self.max_long
  351. elif (self.min_long is not None) and (long_edge < self.min_long):
  352. target = self.min_long
  353. data['trans_info'].append(('resize', data['img'].shape[0:2]))
  354. if target != long_edge:
  355. data['img'] = functional.resize_long(data['img'], target)
  356. for key in data.get('gt_fields', []):
  357. if key == 'trimap':
  358. data[key] = functional.resize_long(data[key], target,
  359. cv2.INTER_NEAREST)
  360. else:
  361. data[key] = functional.resize_long(data[key], target)
  362. return data
  363. @manager.TRANSFORMS.add_component
  364. class LimitShort:
  365. """
  366. Limit the short edge of image.
  367. If the short edge is larger than max_short, resize the short edge
  368. to max_short, while scale the long edge proportionally.
  369. If the short edge is smaller than min_short, resize the short edge
  370. to min_short, while scale the long edge proportionally.
  371. Args:
  372. max_short (int, optional): If the short edge of image is larger than max_short,
  373. it will be resize to max_short. Default: None.
  374. min_short (int, optional): If the short edge of image is smaller than min_short,
  375. it will be resize to min_short. Default: None.
  376. """
  377. def __init__(self, max_short=None, min_short=None):
  378. if max_short is not None:
  379. if not isinstance(max_short, int):
  380. raise TypeError(
  381. "Type of `max_short` is invalid. It should be int, but it is {}"
  382. .format(type(max_short)))
  383. if min_short is not None:
  384. if not isinstance(min_short, int):
  385. raise TypeError(
  386. "Type of `min_short` is invalid. It should be int, but it is {}"
  387. .format(type(min_short)))
  388. if (max_short is not None) and (min_short is not None):
  389. if min_short > max_short:
  390. raise ValueError(
  391. '`max_short should not smaller than min_short, but they are {} and {}'
  392. .format(max_short, min_short))
  393. self.max_short = max_short
  394. self.min_short = min_short
  395. def __call__(self, data):
  396. h, w = data['img'].shape[:2]
  397. short_edge = min(h, w)
  398. target = short_edge
  399. if (self.max_short is not None) and (short_edge > self.max_short):
  400. target = self.max_short
  401. elif (self.min_short is not None) and (short_edge < self.min_short):
  402. target = self.min_short
  403. data['trans_info'].append(('resize', data['img'].shape[0:2]))
  404. if target != short_edge:
  405. data['img'] = functional.resize_short(data['img'], target)
  406. for key in data.get('gt_fields', []):
  407. if key == 'trimap':
  408. data[key] = functional.resize_short(data[key], target,
  409. cv2.INTER_NEAREST)
  410. else:
  411. data[key] = functional.resize_short(data[key], target)
  412. return data
  413. @manager.TRANSFORMS.add_component
  414. class RandomHorizontalFlip:
  415. """
  416. Flip an image horizontally with a certain probability.
  417. Args:
  418. prob (float, optional): A probability of horizontally flipping. Default: 0.5.
  419. """
  420. def __init__(self, prob=0.5):
  421. self.prob = prob
  422. def __call__(self, data):
  423. if random.random() < self.prob:
  424. data['img'] = functional.horizontal_flip(data['img'])
  425. for key in data.get('gt_fields', []):
  426. data[key] = functional.horizontal_flip(data[key])
  427. return data
  428. @manager.TRANSFORMS.add_component
  429. class RandomBlur:
  430. """
  431. Blurring an image by a Gaussian function with a certain probability.
  432. Args:
  433. prob (float, optional): A probability of blurring an image. Default: 0.1.
  434. """
  435. def __init__(self, prob=0.1):
  436. self.prob = prob
  437. def __call__(self, data):
  438. if self.prob <= 0:
  439. n = 0
  440. elif self.prob >= 1:
  441. n = 1
  442. else:
  443. n = int(1.0 / self.prob)
  444. if n > 0:
  445. if np.random.randint(0, n) == 0:
  446. radius = np.random.randint(3, 10)
  447. if radius % 2 != 1:
  448. radius = radius + 1
  449. if radius > 9:
  450. radius = 9
  451. data['img'] = cv2.GaussianBlur(data['img'], (radius, radius), 0,
  452. 0)
  453. for key in data.get('gt_fields', []):
  454. if key == 'trimap':
  455. continue
  456. data[key] = cv2.GaussianBlur(data[key], (radius, radius), 0,
  457. 0)
  458. return data
  459. @manager.TRANSFORMS.add_component
  460. class RandomDistort:
  461. """
  462. Distort an image with random configurations.
  463. Args:
  464. brightness_range (float, optional): A range of brightness. Default: 0.5.
  465. brightness_prob (float, optional): A probability of adjusting brightness. Default: 0.5.
  466. contrast_range (float, optional): A range of contrast. Default: 0.5.
  467. contrast_prob (float, optional): A probability of adjusting contrast. Default: 0.5.
  468. saturation_range (float, optional): A range of saturation. Default: 0.5.
  469. saturation_prob (float, optional): A probability of adjusting saturation. Default: 0.5.
  470. hue_range (int, optional): A range of hue. Default: 18.
  471. hue_prob (float, optional): A probability of adjusting hue. Default: 0.5.
  472. """
  473. def __init__(self,
  474. brightness_range=0.5,
  475. brightness_prob=0.5,
  476. contrast_range=0.5,
  477. contrast_prob=0.5,
  478. saturation_range=0.5,
  479. saturation_prob=0.5,
  480. hue_range=18,
  481. hue_prob=0.5):
  482. self.brightness_range = brightness_range
  483. self.brightness_prob = brightness_prob
  484. self.contrast_range = contrast_range
  485. self.contrast_prob = contrast_prob
  486. self.saturation_range = saturation_range
  487. self.saturation_prob = saturation_prob
  488. self.hue_range = hue_range
  489. self.hue_prob = hue_prob
  490. def __call__(self, data):
  491. brightness_lower = 1 - self.brightness_range
  492. brightness_upper = 1 + self.brightness_range
  493. contrast_lower = 1 - self.contrast_range
  494. contrast_upper = 1 + self.contrast_range
  495. saturation_lower = 1 - self.saturation_range
  496. saturation_upper = 1 + self.saturation_range
  497. hue_lower = -self.hue_range
  498. hue_upper = self.hue_range
  499. ops = [
  500. functional.brightness, functional.contrast, functional.saturation,
  501. functional.hue
  502. ]
  503. random.shuffle(ops)
  504. params_dict = {
  505. 'brightness': {
  506. 'brightness_lower': brightness_lower,
  507. 'brightness_upper': brightness_upper
  508. },
  509. 'contrast': {
  510. 'contrast_lower': contrast_lower,
  511. 'contrast_upper': contrast_upper
  512. },
  513. 'saturation': {
  514. 'saturation_lower': saturation_lower,
  515. 'saturation_upper': saturation_upper
  516. },
  517. 'hue': {
  518. 'hue_lower': hue_lower,
  519. 'hue_upper': hue_upper
  520. }
  521. }
  522. prob_dict = {
  523. 'brightness': self.brightness_prob,
  524. 'contrast': self.contrast_prob,
  525. 'saturation': self.saturation_prob,
  526. 'hue': self.hue_prob
  527. }
  528. im = data['img'].astype('uint8')
  529. im = Image.fromarray(im)
  530. for id in range(len(ops)):
  531. params = params_dict[ops[id].__name__]
  532. params['im'] = im
  533. prob = prob_dict[ops[id].__name__]
  534. if np.random.uniform(0, 1) < prob:
  535. im = ops[id](**params)
  536. data['img'] = np.asarray(im)
  537. for key in data.get('gt_fields', []):
  538. if key in ['alpha', 'trimap']:
  539. continue
  540. else:
  541. im = data[key].astype('uint8')
  542. im = Image.fromarray(im)
  543. for id in range(len(ops)):
  544. params = params_dict[ops[id].__name__]
  545. params['im'] = im
  546. prob = prob_dict[ops[id].__name__]
  547. if np.random.uniform(0, 1) < prob:
  548. im = ops[id](**params)
  549. data[key] = np.asarray(im)
  550. return data
  551. @manager.TRANSFORMS.add_component
  552. class Padding:
  553. """
  554. Add bottom-right padding to a raw image or annotation image.
  555. Args:
  556. target_size (list|tuple): The target size after padding.
  557. im_padding_value (list, optional): The padding value of raw image.
  558. Default: [127.5, 127.5, 127.5].
  559. label_padding_value (int, optional): The padding value of annotation image. Default: 255.
  560. Raises:
  561. TypeError: When target_size is neither list nor tuple.
  562. ValueError: When the length of target_size is not 2.
  563. """
  564. def __init__(self, target_size, im_padding_value=(127.5, 127.5, 127.5)):
  565. if isinstance(target_size, list) or isinstance(target_size, tuple):
  566. if len(target_size) != 2:
  567. raise ValueError(
  568. '`target_size` should include 2 elements, but it is {}'.
  569. format(target_size))
  570. else:
  571. raise TypeError(
  572. "Type of target_size is invalid. It should be list or tuple, now is {}"
  573. .format(type(target_size)))
  574. self.target_size = target_size
  575. self.im_padding_value = im_padding_value
  576. def __call__(self, data):
  577. im_height, im_width = data['img'].shape[0], data['img'].shape[1]
  578. target_height = self.target_size[1]
  579. target_width = self.target_size[0]
  580. pad_height = max(0, target_height - im_height)
  581. pad_width = max(0, target_width - im_width)
  582. data['trans_info'].append(('padding', data['img'].shape[0:2]))
  583. if (pad_height == 0) and (pad_width == 0):
  584. return data
  585. else:
  586. data['img'] = cv2.copyMakeBorder(
  587. data['img'],
  588. 0,
  589. pad_height,
  590. 0,
  591. pad_width,
  592. cv2.BORDER_CONSTANT,
  593. value=self.im_padding_value)
  594. for key in data.get('gt_fields', []):
  595. if key in ['trimap', 'alpha']:
  596. value = 0
  597. else:
  598. value = self.im_padding_value
  599. data[key] = cv2.copyMakeBorder(
  600. data[key],
  601. 0,
  602. pad_height,
  603. 0,
  604. pad_width,
  605. cv2.BORDER_CONSTANT,
  606. value=value)
  607. return data
  608. @manager.TRANSFORMS.add_component
  609. class RandomSharpen:
  610. def __init__(self, prob=0.1):
  611. if prob < 0:
  612. self.prob = 0
  613. elif prob > 1:
  614. self.prob = 1
  615. else:
  616. self.prob = prob
  617. def __call__(self, data):
  618. if np.random.rand() > self.prob:
  619. return data
  620. radius = np.random.choice([0, 3, 5, 7, 9])
  621. w = np.random.uniform(0.1, 0.5)
  622. blur_img = cv2.GaussianBlur(data['img'], (radius, radius), 5)
  623. data['img'] = cv2.addWeighted(data['img'], 1 + w, blur_img, -w, 0)
  624. for key in data.get('gt_fields', []):
  625. if key == 'trimap' or key == 'alpha':
  626. continue
  627. blur_img = cv2.GaussianBlur(data[key], (0, 0), 5)
  628. data[key] = cv2.addWeighted(data[key], 1.5, blur_img, -0.5, 0)
  629. return data
  630. @manager.TRANSFORMS.add_component
  631. class RandomNoise:
  632. def __init__(self, prob=0.1):
  633. if prob < 0:
  634. self.prob = 0
  635. elif prob > 1:
  636. self.prob = 1
  637. else:
  638. self.prob = prob
  639. def __call__(self, data):
  640. if np.random.rand() > self.prob:
  641. return data
  642. mean = np.random.uniform(0, 0.04)
  643. var = np.random.uniform(0, 0.001)
  644. noise = np.random.normal(mean, var**0.5, data['img'].shape) * 255
  645. data['img'] = data['img'] + noise
  646. data['img'] = np.clip(data['img'], 0, 255)
  647. return data
  648. @manager.TRANSFORMS.add_component
  649. class RandomReJpeg:
  650. def __init__(self, prob=0.1):
  651. if prob < 0:
  652. self.prob = 0
  653. elif prob > 1:
  654. self.prob = 1
  655. else:
  656. self.prob = prob
  657. def __call__(self, data):
  658. if np.random.rand() > self.prob:
  659. return data
  660. q = np.random.randint(70, 95)
  661. img = data['img'].astype('uint8')
  662. # Ensure no conflicts between processes
  663. tmp_name = str(os.getpid()) + '.jpg'
  664. tmp_name = os.path.join(seg_env.TMP_HOME, tmp_name)
  665. cv2.imwrite(tmp_name, img, [int(cv2.IMWRITE_JPEG_QUALITY), q])
  666. data['img'] = cv2.imread(tmp_name)
  667. return data